use crate::ast::Program;
use crate::module::{module_path_to_string, LoadedModule, ModuleError, ModulePath};
use crate::parser::parse_program;
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};
/// Resolves module paths to `.xlog` files, parses them, recursively loads their
/// imports, and caches each [`LoadedModule`] so it is loaded at most once.
pub struct ModuleResolver {
    /// Directories consulted, in order, after the importing module's own directory.
    search_paths: Vec<PathBuf>,
    /// Successfully loaded modules, keyed by `module_path_to_string` of their path.
    loaded: HashMap<String, LoadedModule>,
    /// Modules whose loading is in progress (outermost first); used to detect cycles.
    loading: Vec<ModulePath>,
}
impl ModuleResolver {
    /// Creates a resolver with no modules loaded.
    ///
    /// `search_paths` are consulted, in order, when a module is not found
    /// relative to the base directory passed to [`ModuleResolver::load_module`].
    pub fn new(search_paths: Vec<PathBuf>) -> Self {
        Self {
            search_paths,
            loaded: HashMap::new(),
            loading: Vec::new(),
        }
    }

    /// Returns the first candidate file for `module_path` that is a regular file,
    /// checking `base_dir` first and then each search path in order.
    ///
    /// The candidates are exactly those reported by `searched_paths`, so the lookup
    /// order and the paths listed in `ModuleError::NotFound` always agree.
    pub fn find_module_file(&self, base_dir: &Path, module_path: &[String]) -> Option<PathBuf> {
        // `is_file` rather than `exists`: a directory named `foo.xlog` must not be
        // selected, since reading it would only fail later.
        self.searched_paths(base_dir, module_path)
            .into_iter()
            .find(|candidate| candidate.is_file())
    }

    /// Lists every candidate path for `module_path`, in lookup order.
    ///
    /// Segments are joined with `/` and suffixed with `.xlog`, then resolved
    /// against `base_dir` followed by each configured search path.
    fn searched_paths(&self, base_dir: &Path, module_path: &[String]) -> Vec<PathBuf> {
        let relative_path = format!("{}.xlog", module_path.join("/"));
        let mut searched = Vec::with_capacity(self.search_paths.len() + 1);
        searched.push(base_dir.join(&relative_path));
        searched.extend(self.search_paths.iter().map(|dir| dir.join(&relative_path)));
        searched
    }

    /// If `module_path` is already on the loading stack, returns the cycle: the
    /// stack from its first occurrence onward, followed by `module_path` again.
    fn check_cycle(&self, module_path: &[String]) -> Option<Vec<ModulePath>> {
        let path_str = module_path_to_string(module_path);
        let start = self
            .loading
            .iter()
            .position(|loading_path| module_path_to_string(loading_path) == path_str)?;
        let mut cycle = self.loading[start..].to_vec();
        cycle.push(module_path.to_vec());
        Some(cycle)
    }

    /// Computes the public surface of `program`.
    ///
    /// Returns `(predicate_exports, function_exports)` where predicate exports are
    /// every predicate declared without `private`, plus every rule-head predicate
    /// that is not declared `private` (including heads with no declaration), and
    /// function exports are every function declared without `private`.
    pub fn extract_exports(program: &Program) -> (HashSet<String>, HashSet<String>) {
        // Collected once so each rule head is an O(1) lookup instead of a scan.
        let private_preds: HashSet<&str> = program
            .predicates
            .iter()
            .filter(|pred| pred.is_private)
            .map(|pred| pred.name.as_str())
            .collect();
        let mut pred_exports: HashSet<String> = program
            .predicates
            .iter()
            .filter(|pred| !pred.is_private)
            .map(|pred| pred.name.clone())
            .collect();
        pred_exports.extend(
            program
                .rules
                .iter()
                .map(|rule| &rule.head.predicate)
                .filter(|head| !private_preds.contains(head.as_str()))
                .cloned(),
        );
        let func_exports: HashSet<String> = program
            .functions
            .iter()
            .filter(|func| !func.is_private)
            .map(|func| func.name.clone())
            .collect();
        (pred_exports, func_exports)
    }

    /// Loads `module_path` and, recursively, everything it imports, caching the
    /// result. A module that was loaded before is returned from the cache.
    ///
    /// The file is located with [`ModuleResolver::find_module_file`] from
    /// `base_dir`; the imports of the loaded file are resolved relative to that
    /// file's own directory.
    ///
    /// # Errors
    /// - `ModuleError::CircularImport` if `module_path` is already being loaded.
    /// - `ModuleError::NotFound` if no candidate file exists.
    /// - `ModuleError::ParseError` if the file cannot be read or parsed.
    /// - Any error produced while loading one of the module's imports.
    ///
    /// On error nothing is cached and the in-progress stack is restored, so a
    /// later attempt does not report a spurious circular import.
    pub fn load_module(
        &mut self,
        base_dir: &Path,
        module_path: &[String],
    ) -> Result<&LoadedModule, ModuleError> {
        let path_key = module_path_to_string(module_path);
        // `contains_key` + index keeps the returned borrow confined to this branch,
        // so it does not conflict with the mutations below.
        if self.loaded.contains_key(&path_key) {
            return Ok(&self.loaded[&path_key]);
        }
        if let Some(cycle) = self.check_cycle(module_path) {
            return Err(ModuleError::CircularImport { cycle });
        }
        let source_file = self
            .find_module_file(base_dir, module_path)
            .ok_or_else(|| ModuleError::NotFound {
                path: module_path.to_vec(),
                searched: self.searched_paths(base_dir, module_path),
            })?;

        self.loading.push(module_path.to_vec());
        let parsed = self.parse_with_dependencies(base_dir, &source_file);
        // Pop on success AND failure: a stale entry would make every later load of
        // this module report a bogus `CircularImport`.
        self.loading.pop();
        let program = parsed?;

        let (exports, function_exports) = Self::extract_exports(&program);
        let module = LoadedModule {
            path: module_path.to_vec(),
            source_file,
            exports,
            function_exports,
            program,
        };
        let stored: &LoadedModule = self.loaded.entry(path_key).or_insert(module);
        Ok(stored)
    }

    /// Reads and parses `source_file`, then loads each module it imports,
    /// resolving them relative to the file's directory (or `base_dir` if the
    /// file path has no parent).
    fn parse_with_dependencies(
        &mut self,
        base_dir: &Path,
        source_file: &Path,
    ) -> Result<Program, ModuleError> {
        // NOTE(review): read failures are reported as `ParseError`, as before; a
        // dedicated I/O variant would be clearer if `ModuleError` gains one.
        let source = fs::read_to_string(source_file).map_err(|e| ModuleError::ParseError {
            path: source_file.to_path_buf(),
            message: e.to_string(),
        })?;
        let program = parse_program(&source).map_err(|e| ModuleError::ParseError {
            path: source_file.to_path_buf(),
            message: e.to_string(),
        })?;
        let module_dir = source_file.parent().unwrap_or(base_dir);
        for import in &program.imports {
            self.load_module(module_dir, &import.module_path)?;
        }
        Ok(program)
    }

    /// Checks that the already-loaded module `module_path` exports `predicate`.
    ///
    /// # Errors
    /// - `ModuleError::NotFound` (with no searched paths) if the module is not loaded.
    /// - `ModuleError::PredicateNotFound` if it does not export `predicate`.
    pub fn check_import(&self, module_path: &[String], predicate: &str) -> Result<(), ModuleError> {
        let module = self
            .get_module(module_path)
            .ok_or_else(|| ModuleError::NotFound {
                path: module_path.to_vec(),
                searched: vec![],
            })?;
        if module.exports.contains(predicate) {
            Ok(())
        } else {
            Err(ModuleError::PredicateNotFound {
                name: predicate.to_string(),
                module: module_path.to_vec(),
            })
        }
    }

    /// Validates every `use` declaration of `program` against the loaded modules.
    ///
    /// Returns `(imported_predicates, imported_functions)`, each mapping an
    /// imported name to the module it comes from. A declaration with no item list
    /// or an empty one imports every export, matching
    /// [`ModuleResolver::merge_imports`].
    ///
    /// # Errors
    /// - `ModuleError::NotFound` if an imported module has not been loaded.
    /// - `ModuleError::PredicateNotFound` if a named item is not exported.
    /// - `ModuleError::ImportConflict` if one name is imported from two modules.
    // The nested tuple-of-maps return type is part of the public API.
    #[allow(clippy::type_complexity)]
    pub fn validate_imports(
        &self,
        program: &Program,
    ) -> Result<(HashMap<String, ModulePath>, HashMap<String, ModulePath>), ModuleError> {
        let mut imported_predicates: HashMap<String, ModulePath> = HashMap::new();
        let mut imported_functions: HashMap<String, ModulePath> = HashMap::new();
        for use_decl in &program.imports {
            let module = self
                .get_module(&use_decl.module_path)
                .ok_or_else(|| ModuleError::NotFound {
                    path: use_decl.module_path.clone(),
                    searched: vec![],
                })?;
            let names_to_import: Vec<String> = match &use_decl.imports {
                Some(specific) if !specific.is_empty() => specific.clone(),
                _ => {
                    // Sorted so the first reported conflict is deterministic.
                    let mut all: Vec<String> = module
                        .exports
                        .union(&module.function_exports)
                        .cloned()
                        .collect();
                    all.sort();
                    all
                }
            };
            for name in names_to_import {
                let is_predicate = module.exports.contains(&name);
                let is_function = module.function_exports.contains(&name);
                if !is_predicate && !is_function {
                    return Err(ModuleError::PredicateNotFound {
                        name,
                        module: use_decl.module_path.clone(),
                    });
                }
                if is_predicate {
                    Self::record_import(&mut imported_predicates, &name, &use_decl.module_path)?;
                }
                if is_function {
                    Self::record_import(&mut imported_functions, &name, &use_decl.module_path)?;
                }
            }
        }
        Ok((imported_predicates, imported_functions))
    }

    /// Records that `name` is imported from `module_path`.
    ///
    /// Re-importing a name from the same module is allowed; importing it from a
    /// different module yields `ModuleError::ImportConflict`.
    fn record_import(
        imported: &mut HashMap<String, ModulePath>,
        name: &str,
        module_path: &ModulePath,
    ) -> Result<(), ModuleError> {
        if let Some(prev_module) = imported.get(name) {
            if prev_module != module_path {
                return Err(ModuleError::ImportConflict {
                    name: name.to_string(),
                    module1: prev_module.clone(),
                    module2: module_path.clone(),
                });
            }
            return Ok(());
        }
        imported.insert(name.to_string(), module_path.clone());
        Ok(())
    }

    /// Returns the loaded module for `module_path`, if any.
    pub fn get_module(&self, module_path: &[String]) -> Option<&LoadedModule> {
        self.loaded.get(&module_path_to_string(module_path))
    }

    /// Returns whether a module with the given key (as produced by
    /// `module_path_to_string`) has been loaded.
    pub fn is_loaded(&self, module_path: &str) -> bool {
        self.loaded.contains_key(module_path)
    }

    /// Returns the keys of all loaded modules, sorted for deterministic output.
    pub fn loaded_modules(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.loaded.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Merges the program of every module imported by `program` into it and
    /// returns the result.
    ///
    /// A declaration with an item list merges only those items; no list or an
    /// empty list passes `None` to `merge_from`.
    ///
    /// # Errors
    /// `ModuleError::NotFound` (with no searched paths) if an imported module has
    /// not been loaded.
    pub fn merge_imports(&self, mut program: Program) -> Result<Program, ModuleError> {
        // `program` is mutated inside the loop, so iterate over a snapshot of its imports.
        let imports = program.imports.clone();
        for use_decl in &imports {
            let loaded_module = self
                .get_module(&use_decl.module_path)
                .ok_or_else(|| ModuleError::NotFound {
                    path: use_decl.module_path.clone(),
                    searched: vec![],
                })?;
            let imported_items = match &use_decl.imports {
                Some(items) if !items.is_empty() => Some(items.iter().cloned().collect()),
                _ => None,
            };
            program.merge_from(&loaded_module.program, imported_items.as_ref());
        }
        Ok(program)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `content` to `<dir>/<name>.xlog` and returns the file's path.
    fn create_test_module(dir: &Path, name: &str, content: &str) -> PathBuf {
        let path = dir.join(format!("{}.xlog", name));
        fs::write(&path, content).unwrap();
        path
    }

    #[test]
    fn test_find_module_file() {
        let tmp = TempDir::new().unwrap();
        let expected = create_test_module(tmp.path(), "graph", "edge(1, 2).");
        let resolver = ModuleResolver::new(vec![]);
        let found = resolver.find_module_file(tmp.path(), &["graph".into()]);
        assert_eq!(found, Some(expected));
    }

    #[test]
    fn test_base_dir_takes_precedence_over_search_paths() {
        let tmp = TempDir::new().unwrap();
        let lib_dir = tmp.path().join("lib");
        fs::create_dir(&lib_dir).unwrap();
        let local = create_test_module(tmp.path(), "shared", "a(1).");
        create_test_module(&lib_dir, "shared", "b(1).");
        let resolver = ModuleResolver::new(vec![lib_dir]);
        let found = resolver.find_module_file(tmp.path(), &["shared".into()]);
        assert_eq!(found, Some(local));
    }

    #[test]
    fn test_module_not_found() {
        let tmp = TempDir::new().unwrap();
        let mut resolver = ModuleResolver::new(vec![]);
        let result = resolver.load_module(tmp.path(), &["nonexistent".into()]);
        match result {
            Err(ModuleError::NotFound { path, searched }) => {
                assert_eq!(path, vec!["nonexistent".to_string()]);
                assert_eq!(searched, vec![tmp.path().join("nonexistent.xlog")]);
            }
            _ => panic!("expected ModuleError::NotFound"),
        }
    }

    #[test]
    fn test_circular_import() {
        let tmp = TempDir::new().unwrap();
        create_test_module(tmp.path(), "a", "use b.");
        create_test_module(tmp.path(), "b", "use a.");
        let mut resolver = ModuleResolver::new(vec![]);
        let result = resolver.load_module(tmp.path(), &["a".into()]);
        match result {
            Err(ModuleError::CircularImport { cycle }) => {
                let name = |s: &str| vec![s.to_string()];
                assert_eq!(cycle, vec![name("a"), name("b"), name("a")]);
            }
            _ => panic!("expected ModuleError::CircularImport"),
        }
    }

    #[test]
    fn test_load_simple_module() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "math",
            r#"
pred add(u32, u32, u32).
add(1, 2, 3).
"#,
        );
        let mut resolver = ModuleResolver::new(vec![]);
        let module = resolver
            .load_module(tmp.path(), &["math".into()])
            .expect("math module should load");
        assert!(module.exports.contains("add"));
        assert!(resolver.get_module(&["math".to_string()]).is_some());
    }

    #[test]
    fn test_private_not_exported() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "graph",
            r#"
pred edge(u32, u32).
private pred helper(u32).
edge(1, 2).
helper(1).
"#,
        );
        let mut resolver = ModuleResolver::new(vec![]);
        let module = resolver
            .load_module(tmp.path(), &["graph".into()])
            .expect("graph module should load");
        assert!(module.exports.contains("edge"));
        assert!(!module.exports.contains("helper"));
    }

    #[test]
    fn test_search_paths() {
        let tmp = TempDir::new().unwrap();
        let lib_dir = tmp.path().join("lib");
        fs::create_dir(&lib_dir).unwrap();
        create_test_module(&lib_dir, "stdlib", "helper(1).");
        let resolver = ModuleResolver::new(vec![lib_dir.clone()]);
        let found = resolver
            .find_module_file(tmp.path(), &["stdlib".into()])
            .expect("stdlib should be found via the search path");
        assert!(found.starts_with(&lib_dir));
    }

    #[test]
    fn test_function_exports() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "mathfuncs",
            r#"
func square(X) = X * X.
func cube(X) = X * X * X.
private func helper(X) = X.
"#,
        );
        let mut resolver = ModuleResolver::new(vec![]);
        let module = resolver
            .load_module(tmp.path(), &["mathfuncs".into()])
            .expect("mathfuncs module should load");
        assert!(module.function_exports.contains("square"));
        assert!(module.function_exports.contains("cube"));
        assert!(!module.function_exports.contains("helper"));
    }

    #[test]
    fn test_mixed_exports() {
        let tmp = TempDir::new().unwrap();
        create_test_module(
            tmp.path(),
            "mixed",
            r#"
pred value(i64).
value(42).
func double(X) = X * 2.
"#,
        );
        let mut resolver = ModuleResolver::new(vec![]);
        let module = resolver
            .load_module(tmp.path(), &["mixed".into()])
            .expect("mixed module should load");
        assert!(module.exports.contains("value"));
        assert!(module.function_exports.contains("double"));
    }
}