Skip to main content

xlog_logic/
resolver.rs

1//! Module resolution for XLOG programs.
2
3use crate::ast::Program;
4use crate::module::{module_path_to_string, LoadedModule, ModuleError, ModulePath};
5use crate::parser::parse_program;
6use std::collections::{HashMap, HashSet};
7use std::fs;
8use std::path::{Path, PathBuf};
9
10/// Resolves and loads modules
11pub struct ModuleResolver {
12    /// Directories to search for modules
13    search_paths: Vec<PathBuf>,
14    /// Already loaded modules (path string -> module)
15    loaded: HashMap<String, LoadedModule>,
16    /// Currently loading (for cycle detection)
17    loading: Vec<ModulePath>,
18}
19
20impl ModuleResolver {
21    /// Create a new resolver with given search paths
22    pub fn new(search_paths: Vec<PathBuf>) -> Self {
23        Self {
24            search_paths,
25            loaded: HashMap::new(),
26            loading: Vec::new(),
27        }
28    }
29
30    /// Find the file for a module path
31    pub fn find_module_file(&self, base_dir: &Path, module_path: &[String]) -> Option<PathBuf> {
32        let relative_path = format!("{}.xlog", module_path.join("/"));
33
34        // Try relative to base_dir first
35        let candidate = base_dir.join(&relative_path);
36        if candidate.exists() {
37            return Some(candidate);
38        }
39
40        // Try search paths
41        for search_path in &self.search_paths {
42            let candidate = search_path.join(&relative_path);
43            if candidate.exists() {
44                return Some(candidate);
45            }
46        }
47
48        None
49    }
50
51    /// Get the list of searched paths for error reporting
52    fn searched_paths(&self, base_dir: &Path, module_path: &[String]) -> Vec<PathBuf> {
53        let relative_path = format!("{}.xlog", module_path.join("/"));
54        let mut searched = vec![base_dir.join(&relative_path)];
55        for sp in &self.search_paths {
56            searched.push(sp.join(&relative_path));
57        }
58        searched
59    }
60
61    /// Check if we're in a circular import
62    fn check_cycle(&self, module_path: &[String]) -> Option<Vec<ModulePath>> {
63        let path_str = module_path_to_string(module_path);
64        for (i, loading_path) in self.loading.iter().enumerate() {
65            if module_path_to_string(loading_path) == path_str {
66                // Found cycle - return the cycle path
67                let mut cycle: Vec<ModulePath> = self.loading[i..].to_vec();
68                cycle.push(module_path.to_vec());
69                return Some(cycle);
70            }
71        }
72        None
73    }
74
75    /// Extract exports from a parsed program
76    /// Returns (predicate exports, function exports)
77    pub fn extract_exports(program: &Program) -> (HashSet<String>, HashSet<String>) {
78        let mut pred_exports = HashSet::new();
79        let mut func_exports = HashSet::new();
80
81        // Add declared predicates that aren't private
82        for pred in &program.predicates {
83            if !pred.is_private {
84                pred_exports.insert(pred.name.clone());
85            }
86        }
87
88        // Add rule heads (all rules define public predicates unless declared private)
89        for rule in &program.rules {
90            // Check if this predicate was declared as private
91            let is_private = program
92                .predicates
93                .iter()
94                .any(|p| p.name == rule.head.predicate && p.is_private);
95            if !is_private {
96                pred_exports.insert(rule.head.predicate.clone());
97            }
98        }
99
100        // Add functions that aren't private
101        for func in &program.functions {
102            if !func.is_private {
103                func_exports.insert(func.name.clone());
104            }
105        }
106
107        (pred_exports, func_exports)
108    }
109
110    /// Load a module from a path
111    pub fn load_module(
112        &mut self,
113        base_dir: &Path,
114        module_path: &[String],
115    ) -> Result<&LoadedModule, ModuleError> {
116        let path_key = module_path_to_string(module_path);
117
118        // Already loaded?
119        if self.loaded.contains_key(&path_key) {
120            return Ok(self.loaded.get(&path_key).unwrap());
121        }
122
123        // Check for cycle
124        if let Some(cycle) = self.check_cycle(module_path) {
125            return Err(ModuleError::CircularImport { cycle });
126        }
127
128        // Find the file
129        let source_file = self
130            .find_module_file(base_dir, module_path)
131            .ok_or_else(|| ModuleError::NotFound {
132                path: module_path.to_vec(),
133                searched: self.searched_paths(base_dir, module_path),
134            })?;
135
136        // Mark as loading
137        self.loading.push(module_path.to_vec());
138
139        // Read and parse
140        let source = fs::read_to_string(&source_file).map_err(|e| ModuleError::ParseError {
141            path: source_file.clone(),
142            message: e.to_string(),
143        })?;
144
145        let program = parse_program(&source).map_err(|e| ModuleError::ParseError {
146            path: source_file.clone(),
147            message: e.to_string(),
148        })?;
149
150        // Extract exports
151        let (exports, function_exports) = Self::extract_exports(&program);
152
153        // Recursively load imports
154        let module_dir = source_file.parent().unwrap_or(base_dir);
155        for import in &program.imports {
156            self.load_module(module_dir, &import.module_path)?;
157        }
158
159        // Remove from loading
160        self.loading.pop();
161
162        // Store loaded module
163        let module = LoadedModule {
164            path: module_path.to_vec(),
165            source_file,
166            exports,
167            function_exports,
168            program,
169        };
170
171        self.loaded.insert(path_key.clone(), module);
172        Ok(self.loaded.get(&path_key).unwrap())
173    }
174
175    /// Check if a predicate can be imported from a module
176    pub fn check_import(&self, module_path: &[String], predicate: &str) -> Result<(), ModuleError> {
177        let path_key = module_path_to_string(module_path);
178        let module = self
179            .loaded
180            .get(&path_key)
181            .ok_or_else(|| ModuleError::NotFound {
182                path: module_path.to_vec(),
183                searched: vec![],
184            })?;
185
186        if !module.exports.contains(predicate) {
187            return Err(ModuleError::PredicateNotFound {
188                name: predicate.to_string(),
189                module: module_path.to_vec(),
190            });
191        }
192
193        Ok(())
194    }
195
196    /// Validate all imports in a program
197    /// Returns (predicate imports, function imports) mapped to their source modules
198    #[allow(clippy::type_complexity)]
199    pub fn validate_imports(
200        &self,
201        program: &Program,
202    ) -> Result<(HashMap<String, ModulePath>, HashMap<String, ModulePath>), ModuleError> {
203        let mut imported_predicates: HashMap<String, ModulePath> = HashMap::new();
204        let mut imported_functions: HashMap<String, ModulePath> = HashMap::new();
205
206        for use_decl in &program.imports {
207            let module = self
208                .loaded
209                .get(&module_path_to_string(&use_decl.module_path))
210                .expect("module should be loaded");
211
212            // Combine all available exports for wildcard imports
213            let all_exports: HashSet<String> = module
214                .exports
215                .iter()
216                .chain(module.function_exports.iter())
217                .cloned()
218                .collect();
219
220            let names_to_import: Vec<String> = match &use_decl.imports {
221                Some(specific) => specific.clone(),
222                None => all_exports.iter().cloned().collect(),
223            };
224
225            for name in names_to_import {
226                // Check if name exists as predicate or function
227                let is_predicate = module.exports.contains(&name);
228                let is_function = module.function_exports.contains(&name);
229
230                if !is_predicate && !is_function {
231                    return Err(ModuleError::PredicateNotFound {
232                        name: name.clone(),
233                        module: use_decl.module_path.clone(),
234                    });
235                }
236
237                // Check for conflicts with predicates
238                if is_predicate {
239                    if let Some(prev_module) = imported_predicates.get(&name) {
240                        if prev_module != &use_decl.module_path {
241                            return Err(ModuleError::ImportConflict {
242                                name,
243                                module1: prev_module.clone(),
244                                module2: use_decl.module_path.clone(),
245                            });
246                        }
247                    }
248                    imported_predicates.insert(name.clone(), use_decl.module_path.clone());
249                }
250
251                // Check for conflicts with functions
252                if is_function {
253                    if let Some(prev_module) = imported_functions.get(&name) {
254                        if prev_module != &use_decl.module_path {
255                            return Err(ModuleError::ImportConflict {
256                                name,
257                                module1: prev_module.clone(),
258                                module2: use_decl.module_path.clone(),
259                            });
260                        }
261                    }
262                    imported_functions.insert(name.clone(), use_decl.module_path.clone());
263                }
264            }
265        }
266
267        Ok((imported_predicates, imported_functions))
268    }
269
270    /// Get a loaded module by path
271    pub fn get_module(&self, module_path: &[String]) -> Option<&LoadedModule> {
272        self.loaded.get(&module_path_to_string(module_path))
273    }
274
275    /// Check if a module is loaded
276    pub fn is_loaded(&self, module_path: &str) -> bool {
277        self.loaded.contains_key(module_path)
278    }
279
280    /// Get all loaded module paths (for testing)
281    pub fn loaded_modules(&self) -> Vec<&str> {
282        self.loaded.keys().map(|s| s.as_str()).collect()
283    }
284
285    /// Merge all imported modules into a program.
286    /// Returns a new program with all imports resolved and merged.
287    ///
288    /// # Arguments
289    /// * `program` - The main program with imports to resolve
290    ///
291    /// # Returns
292    /// The program with all imports merged in
293    pub fn merge_imports(&self, mut program: Program) -> Result<Program, ModuleError> {
294        for use_decl in &program.imports.clone() {
295            let path_key = module_path_to_string(&use_decl.module_path);
296            let loaded_module =
297                self.loaded
298                    .get(&path_key)
299                    .ok_or_else(|| ModuleError::NotFound {
300                        path: use_decl.module_path.clone(),
301                        searched: vec![],
302                    })?;
303
304            // Determine which items to import
305            let imported_items = match &use_decl.imports {
306                Some(items) if !items.is_empty() => {
307                    // Import specific items
308                    Some(items.iter().cloned().collect())
309                }
310                _ => {
311                    // Import all public items
312                    None
313                }
314            };
315
316            // Merge the module into the program
317            program.merge_from(&loaded_module.program, imported_items.as_ref());
318        }
319
320        Ok(program)
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Write `content` to `<dir>/<name>.xlog` (creating or truncating it) and return the path.
    fn create_test_module(dir: &Path, name: &str, content: &str) -> PathBuf {
        let path = dir.join(format!("{}.xlog", name));
        fs::write(&path, content).expect("failed to write test module");
        path
    }

    #[test]
    fn test_find_module_file() {
        let tmp = TempDir::new().unwrap();
        let written = create_test_module(tmp.path(), "graph", "edge(1, 2).");

        let resolver = ModuleResolver::new(vec![]);
        let found = resolver.find_module_file(tmp.path(), &["graph".into()]);
        assert_eq!(found, Some(written));
    }

    #[test]
    fn test_module_not_found() {
        let tmp = TempDir::new().unwrap();
        let mut resolver = ModuleResolver::new(vec![]);

        match resolver.load_module(tmp.path(), &["nonexistent".into()]) {
            Err(ModuleError::NotFound { path, searched }) => {
                assert_eq!(path, vec!["nonexistent".to_string()]);
                // With no search paths, only the base directory is tried.
                assert_eq!(searched, vec![tmp.path().join("nonexistent.xlog")]);
            }
            _ => panic!("expected a NotFound error"),
        }
    }

    #[test]
    fn test_circular_import() {
        let tmp = TempDir::new().unwrap();
        create_test_module(tmp.path(), "a", "use b.");
        create_test_module(tmp.path(), "b", "use a.");

        let mut resolver = ModuleResolver::new(vec![]);
        match resolver.load_module(tmp.path(), &["a".into()]) {
            Err(ModuleError::CircularImport { cycle }) => {
                let expected: Vec<ModulePath> = vec![
                    vec!["a".to_string()],
                    vec!["b".to_string()],
                    vec!["a".to_string()],
                ];
                assert_eq!(cycle, expected);
            }
            _ => panic!("expected a CircularImport error"),
        }
    }

    #[test]
    fn test_shared_dependency_is_not_a_cycle() {
        // Diamond: a -> {b, c}, b -> d, c -> d. Reaching d twice must not be a cycle.
        let tmp = TempDir::new().unwrap();
        create_test_module(tmp.path(), "a", "use b.\nuse c.");
        create_test_module(tmp.path(), "b", "use d.");
        create_test_module(tmp.path(), "c", "use d.");
        create_test_module(tmp.path(), "d", "pred edge(u32, u32).\nedge(1, 2).");

        let mut resolver = ModuleResolver::new(vec![]);
        resolver
            .load_module(tmp.path(), &["a".into()])
            .expect("diamond imports should load");

        assert_eq!(resolver.loaded_modules().len(), 4);
        let shared = resolver
            .get_module(&["d".into()])
            .expect("shared dependency should be cached");
        assert!(shared.exports.contains("edge"));
    }

    #[test]
    fn test_load_simple_module() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "math",
            r#"
            pred add(u32, u32, u32).
            add(1, 2, 3).
        "#,
        );

        let mut resolver = ModuleResolver::new(vec![]);
        let module = resolver
            .load_module(tmp.path(), &["math".into()])
            .expect("math module should load");
        assert!(module.exports.contains("add"));
    }

    #[test]
    fn test_private_not_exported() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "graph",
            r#"
            pred edge(u32, u32).
            private pred helper(u32).
            edge(1, 2).
            helper(1).
        "#,
        );

        let mut resolver = ModuleResolver::new(vec![]);
        let module = resolver
            .load_module(tmp.path(), &["graph".into()])
            .expect("graph module should load");
        assert!(module.exports.contains("edge"));
        assert!(!module.exports.contains("helper"));
    }

    #[test]
    fn test_search_paths() {
        let tmp = TempDir::new().unwrap();
        let lib_dir = tmp.path().join("lib");
        fs::create_dir(&lib_dir).unwrap();
        create_test_module(&lib_dir, "stdlib", "helper(1).");

        let resolver = ModuleResolver::new(vec![lib_dir.clone()]);
        let found = resolver
            .find_module_file(tmp.path(), &["stdlib".into()])
            .expect("module should be found via the search path");
        assert!(found.starts_with(&lib_dir));
    }

    #[test]
    fn test_function_exports() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "mathfuncs",
            r#"
            func square(X) = X * X.
            func cube(X) = X * X * X.
            private func helper(X) = X.
        "#,
        );

        let mut resolver = ModuleResolver::new(vec![]);
        let module = resolver
            .load_module(tmp.path(), &["mathfuncs".into()])
            .expect("mathfuncs module should load");

        // Public functions should be exported
        assert!(module.function_exports.contains("square"));
        assert!(module.function_exports.contains("cube"));

        // Private function should not be exported
        assert!(!module.function_exports.contains("helper"));
    }

    #[test]
    fn test_mixed_exports() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "mixed",
            r#"
            pred value(i64).
            value(42).
            func double(X) = X * 2.
        "#,
        );

        let mut resolver = ModuleResolver::new(vec![]);
        let module = resolver
            .load_module(tmp.path(), &["mixed".into()])
            .expect("mixed module should load");

        // Both predicate and function exports should be present
        assert!(module.exports.contains("value"));
        assert!(module.function_exports.contains("double"));
    }
}