1use std::collections::BTreeSet;
2use std::path::{Path, PathBuf};
3
4use crate::sandbox::{LuaError, LuaRuntime};
5
/// A Lua extension located by [`discover_extensions`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LuaExtension {
    /// Extension name: the file stem for a `<name>.lua` file, or the
    /// directory name for a `<name>/init.lua` extension. Used to de-duplicate
    /// extensions across search directories.
    pub name: String,
    /// Path of the Lua file to execute: the `.lua` file itself, or the
    /// directory's `init.lua`.
    pub path: PathBuf,
}
12
13pub fn discover_extensions(
15 user_config_dir: &Path,
16 project_dir: Option<&Path>,
17) -> Vec<LuaExtension> {
18 let mut extensions = Vec::new();
19
20 let mut dirs = vec![user_config_dir.join("lua")];
21 if let Some(project) = project_dir {
22 dirs.push(project.join(".imp").join("lua"));
23 }
24
25 let mut seen_names = BTreeSet::new();
26 for dir in &dirs {
27 if let Ok(entries) = std::fs::read_dir(dir) {
28 for entry in entries.flatten() {
29 let path = entry.path();
30
31 if path.extension().is_some_and(|e| e == "lua") {
33 let name = path
34 .file_stem()
35 .map(|s| s.to_string_lossy().to_string())
36 .unwrap_or_default();
37 if seen_names.insert(name.clone()) {
38 extensions.push(LuaExtension { name, path });
39 }
40 continue;
41 }
42
43 if path.is_dir() {
45 let init = path.join("init.lua");
46 if init.exists() {
47 let name = path
48 .file_name()
49 .map(|s| s.to_string_lossy().to_string())
50 .unwrap_or_default();
51 if seen_names.insert(name.clone()) {
52 extensions.push(LuaExtension { name, path: init });
53 }
54 }
55 }
56 }
57 }
58 }
59
60 extensions
61}
62
63pub fn load_extensions(
65 runtime: &LuaRuntime,
66 extensions: &[LuaExtension],
67) -> Vec<(String, Result<(), LuaError>)> {
68 extensions
69 .iter()
70 .map(|ext| {
71 let result = runtime.exec_file(&ext.path);
72 (ext.name.clone(), result)
73 })
74 .collect()
75}
76
77pub fn reload(
79 user_config_dir: &Path,
80 project_dir: Option<&Path>,
81 policy: &imp_core::config::LuaCapabilityPolicy,
82) -> Result<(LuaRuntime, Vec<LuaExtension>), LuaError> {
83 let extensions = discover_extensions(user_config_dir, project_dir);
84 let runtime = LuaRuntime::new()?;
85 crate::bridge::setup_host_api(&runtime)?;
86 runtime.apply_capability_policy(policy);
87 load_extensions(&runtime, &extensions);
88 Ok((runtime, extensions))
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn discover_extensions_deduplicates_global_and_project_names() {
        let root = tempfile::tempdir().unwrap();
        let global_dir = root.path().join("user");
        let project_dir = root.path().join("project");
        let global_lua = global_dir.join("lua");
        let project_lua = project_dir.join(".imp").join("lua");

        // The same extension name exists at both the user and project level.
        for dir in [&global_lua, &project_lua] {
            std::fs::create_dir_all(dir).unwrap();
            std::fs::write(dir.join("imp-update.lua"), "").unwrap();
        }

        let found = discover_extensions(&global_dir, Some(&project_dir));

        // Only one entry survives, and it is the user-level copy.
        assert_eq!(found.len(), 1);
        let only = &found[0];
        assert_eq!(only.name, "imp-update");
        assert_eq!(only.path, global_lua.join("imp-update.lua"));
    }
}