Skip to main content

imp_lua/
loader.rs

1use std::collections::BTreeSet;
2use std::path::{Path, PathBuf};
3
4use crate::sandbox::{LuaError, LuaRuntime};
5
/// A Lua extension located on disk by [`discover_extensions`].
#[derive(Debug, Clone)]
pub struct LuaExtension {
    /// Extension name: the file stem of a `<name>.lua` file, or the directory
    /// name of a `<name>/init.lua` layout. Discovery uses it to de-duplicate
    /// extensions across search directories.
    pub name: String,
    /// Lua file to execute: the `.lua` file itself, or the directory's
    /// `init.lua`.
    pub path: PathBuf,
}
12
13/// Discover Lua extensions from user and project directories.
14pub fn discover_extensions(
15    user_config_dir: &Path,
16    project_dir: Option<&Path>,
17) -> Vec<LuaExtension> {
18    let mut extensions = Vec::new();
19
20    let mut dirs = vec![user_config_dir.join("lua")];
21    if let Some(project) = project_dir {
22        dirs.push(project.join(".imp").join("lua"));
23    }
24
25    let mut seen_names = BTreeSet::new();
26    for dir in &dirs {
27        if let Ok(entries) = std::fs::read_dir(dir) {
28            for entry in entries.flatten() {
29                let path = entry.path();
30
31                // Direct .lua file
32                if path.extension().is_some_and(|e| e == "lua") {
33                    let name = path
34                        .file_stem()
35                        .map(|s| s.to_string_lossy().to_string())
36                        .unwrap_or_default();
37                    if seen_names.insert(name.clone()) {
38                        extensions.push(LuaExtension { name, path });
39                    }
40                    continue;
41                }
42
43                // Directory with init.lua
44                if path.is_dir() {
45                    let init = path.join("init.lua");
46                    if init.exists() {
47                        let name = path
48                            .file_name()
49                            .map(|s| s.to_string_lossy().to_string())
50                            .unwrap_or_default();
51                        if seen_names.insert(name.clone()) {
52                            extensions.push(LuaExtension { name, path: init });
53                        }
54                    }
55                }
56            }
57        }
58    }
59
60    extensions
61}
62
63/// Load all discovered extensions into a Lua runtime.
64pub fn load_extensions(
65    runtime: &LuaRuntime,
66    extensions: &[LuaExtension],
67) -> Vec<(String, Result<(), LuaError>)> {
68    extensions
69        .iter()
70        .map(|ext| {
71            let result = runtime.exec_file(&ext.path);
72            (ext.name.clone(), result)
73        })
74        .collect()
75}
76
77/// Hot reload: drop old state, create new runtime, re-load extensions.
78pub fn reload(
79    user_config_dir: &Path,
80    project_dir: Option<&Path>,
81    policy: &imp_core::config::LuaCapabilityPolicy,
82) -> Result<(LuaRuntime, Vec<LuaExtension>), LuaError> {
83    let extensions = discover_extensions(user_config_dir, project_dir);
84    let runtime = LuaRuntime::new()?;
85    crate::bridge::setup_host_api(&runtime)?;
86    runtime.apply_capability_policy(policy);
87    load_extensions(&runtime, &extensions);
88    Ok((runtime, extensions))
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn discover_extensions_deduplicates_global_and_project_names() {
        let root = tempfile::tempdir().unwrap();
        let user_config = root.path().join("user");
        let project = root.path().join("project");

        // The same extension name exists at both the user and the project level.
        let user_lua = user_config.join("lua");
        let project_lua = project.join(".imp").join("lua");
        for dir in [&user_lua, &project_lua] {
            std::fs::create_dir_all(dir).unwrap();
            std::fs::write(dir.join("imp-update.lua"), "").unwrap();
        }

        let found = discover_extensions(&user_config, Some(&project));

        // Only one entry survives, and it is the user-level copy.
        assert_eq!(found.len(), 1);
        let only = &found[0];
        assert_eq!(only.name, "imp-update");
        assert_eq!(only.path, user_lua.join("imp-update.lua"));
    }
}