Skip to main content

inauguration/
module_resolver.rs

1use std::collections::HashSet;
2use std::path::PathBuf;
3
4use crate::core_ir::UnifiedModule;
5use crate::in_lang_parse;
6
/// Upper bound on `use` nesting: imports at this depth or deeper are skipped with a warning.
const MAX_DEPTH: usize = 16;

/// Locates, parses, and collects the modules named by `use` declarations.
///
/// Module files are looked up in an ordered list of search directories.
#[derive(Clone, Debug)]
pub struct ModuleResolver {
    /// Directories consulted, in order, when resolving a module name.
    search_paths: Vec<PathBuf>,
}
13
14impl Default for ModuleResolver {
15    fn default() -> Self {
16        Self::new()
17    }
18}
19
20impl ModuleResolver {
21    pub fn new() -> Self {
22        Self {
23            search_paths: vec![PathBuf::from(".")],
24        }
25    }
26
27    pub fn add_search_path(&mut self, path: PathBuf) {
28        self.search_paths.push(path);
29    }
30
31    /// Resolve `use` imports declared in `source`, returning parsed modules (non-fatal).
32    pub fn resolve_imports(&self, source: &str) -> Result<Vec<UnifiedModule>, String> {
33        let surface = in_lang_parse::parse_in_surface_info(source)
34            .map_err(|e| format!("surface info: {e}"))?;
35        let mut modules = Vec::new();
36        let mut seen = HashSet::new();
37        for name in &surface.semantic_imports {
38            self.resolve_recursive(name, &mut modules, &mut seen, 0);
39        }
40        Ok(modules)
41    }
42
43    fn resolve_recursive(
44        &self,
45        name: &str,
46        out: &mut Vec<UnifiedModule>,
47        seen: &mut HashSet<PathBuf>,
48        depth: usize,
49    ) {
50        if depth >= MAX_DEPTH {
51            eprintln!("[import] warning: max depth ({MAX_DEPTH}) reached for `{name}`");
52            return;
53        }
54        let Some(path) = self.find_module(name) else {
55            eprintln!("[import] warning: module not found: `{name}`");
56            return;
57        };
58        let key = path.canonicalize().unwrap_or_else(|_| path.clone());
59        if !seen.insert(key) {
60            return;
61        }
62        let imported = match in_lang_parse::parse_in_library_file(&path) {
63            Ok(m) => m,
64            Err(e) => {
65                eprintln!("[import] warning: `{name}` ({}): {e}", path.display());
66                return;
67            }
68        };
69        // Resolve the imported module's own `use` imports
70        let nested_source = match std::fs::read_to_string(&path) {
71            Ok(s) => s,
72            Err(e) => {
73                eprintln!(
74                    "[import] warning: cannot read `{name}` ({}): {e}",
75                    path.display()
76                );
77                return;
78            }
79        };
80        let nested_surface = match in_lang_parse::parse_in_surface_info(&nested_source) {
81            Ok(s) => s,
82            Err(e) => {
83                eprintln!(
84                    "[import] warning: surface info for `{name}` ({}): {e}",
85                    path.display()
86                );
87                return;
88            }
89        };
90        for nested_name in &nested_surface.semantic_imports {
91            self.resolve_recursive(nested_name, out, seen, depth + 1);
92        }
93        out.push(imported);
94    }
95
96    fn find_module(&self, name: &str) -> Option<PathBuf> {
97        let dotted = name.replace('.', "/");
98        for dir in &self.search_paths {
99            let candidates = [
100                dir.join(&dotted).with_extension("in"),
101                dir.join(&dotted),
102                dir.join(format!("{name}.in")),
103                dir.join(name),
104            ];
105            for c in &candidates {
106                if c.is_file() {
107                    return Some(c.clone());
108                }
109            }
110        }
111        None
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core_ir::Decl;
    use std::fs;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Scratch directory under the system temp dir.
    ///
    /// It is removed on drop, so a panicking `unwrap` or a failed assertion does not leave it
    /// behind.
    struct TempDir {
        path: PathBuf,
    }

    impl TempDir {
        /// Creates a fresh directory whose name combines the process id, the current time,
        /// and `label`.
        fn new(label: &str) -> Self {
            let path = std::env::temp_dir().join(format!(
                "inauguration-module-resolver-{}-{}-{label}",
                std::process::id(),
                SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_nanos()
            ));
            fs::create_dir_all(&path).unwrap();
            Self { path }
        }

        /// Writes `contents` to `rel` inside the directory, creating parent directories.
        fn write(&self, rel: &str, contents: &str) {
            let file = self.path.join(rel);
            if let Some(parent) = file.parent() {
                fs::create_dir_all(parent).unwrap();
            }
            fs::write(file, contents).unwrap();
        }

        /// Resolves the imports of `source` with this directory added as a search path.
        fn resolve(&self, source: &str) -> Vec<UnifiedModule> {
            let mut resolver = ModuleResolver::new();
            resolver.add_search_path(self.path.clone());
            resolver.resolve_imports(source).unwrap()
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a failure here must not mask the test's own outcome.
            let _ = fs::remove_dir_all(&self.path);
        }
    }

    /// Names of all function declarations across `modules`, in module/declaration order.
    fn function_names(modules: &[UnifiedModule]) -> Vec<&str> {
        modules
            .iter()
            .flat_map(|m| m.decls.iter())
            .filter_map(|d| match d {
                Decl::Function { name, .. } => Some(name.as_str()),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn resolve_simple_import() {
        let dir = TempDir::new("simple");
        dir.write("lib.in", "fn helper() -> Int { return 42; }\n");

        let modules = dir.resolve("use lib;\nfn main() -> void { helper(); return; }\n");

        assert_eq!(modules.len(), 1, "expected one imported module");
        assert!(
            function_names(&modules).contains(&"helper"),
            "expected helper in imported module"
        );
    }

    #[test]
    fn import_not_found_returns_empty_with_warning() {
        // The directory is on the search path but does not contain the module.
        let dir = TempDir::new("missing");

        let modules = dir.resolve("use nonexistent;\nfn main() -> void { return; }\n");

        assert!(
            modules.is_empty(),
            "missing import should produce empty result"
        );
    }

    #[test]
    fn resolve_recursive_import() {
        let dir = TempDir::new("recursive");
        dir.write("base.in", "fn base_fn() -> Int { return 1; }\n");
        dir.write(
            "lib.in",
            "use base;\nfn lib_fn() -> Int { return base_fn(); }\n",
        );

        let modules = dir.resolve("use lib;\nfn main() -> void { lib_fn(); return; }\n");

        let names = function_names(&modules);
        let base = names.iter().position(|n| *n == "base_fn");
        let lib = names.iter().position(|n| *n == "lib_fn");
        assert!(base.is_some(), "expected base_fn, got {names:?}");
        assert!(lib.is_some(), "expected lib_fn, got {names:?}");
        assert!(
            base < lib,
            "dependencies must precede dependents, got {names:?}"
        );
    }

    #[test]
    fn resolve_dotted_import_path() {
        let dir = TempDir::new("dotted");
        dir.write("data/models.in", "fn make() -> Int { return 0; }\n");

        let modules = dir.resolve("use data.models;\nfn main() -> void { make(); return; }\n");

        assert_eq!(modules.len(), 1);
        assert!(
            function_names(&modules).contains(&"make"),
            "expected make in data.models module"
        );
    }

    #[test]
    fn import_cycle_resolves_each_module_once() {
        let dir = TempDir::new("cycle");
        dir.write("alpha.in", "use beta;\nfn alpha_fn() -> Int { return 1; }\n");
        dir.write("beta.in", "use alpha;\nfn beta_fn() -> Int { return 2; }\n");

        let modules = dir.resolve("use alpha;\nfn main() -> void { return; }\n");

        let names = function_names(&modules);
        assert_eq!(
            modules.len(),
            2,
            "each module in the cycle should appear once, got {names:?}"
        );
        assert!(names.contains(&"alpha_fn"), "expected alpha_fn, got {names:?}");
        assert!(names.contains(&"beta_fn"), "expected beta_fn, got {names:?}");
    }

    #[test]
    fn shared_dependency_is_resolved_once() {
        let dir = TempDir::new("diamond");
        dir.write("common.in", "fn common_fn() -> Int { return 0; }\n");
        dir.write(
            "left_lib.in",
            "use common;\nfn left_fn() -> Int { return common_fn(); }\n",
        );
        dir.write(
            "right_lib.in",
            "use common;\nfn right_fn() -> Int { return common_fn(); }\n",
        );

        let modules =
            dir.resolve("use left_lib;\nuse right_lib;\nfn main() -> void { return; }\n");

        let names = function_names(&modules);
        assert_eq!(modules.len(), 3, "expected three modules, got {names:?}");
        assert_eq!(
            names.iter().filter(|&&n| n == "common_fn").count(),
            1,
            "shared dependency should be included once, got {names:?}"
        );
    }
}