use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use async_trait::async_trait;
use crate::errors::{ErrorCode, ModuleError};
use crate::module::Module;
use crate::registry::dependencies::resolve_dependencies;
use crate::registry::metadata::{load_id_map, load_metadata, parse_dependencies};
use crate::registry::registry::{
DiscoveredModule, Discoverer, ModuleDescriptor, MAX_MODULE_ID_LENGTH, RESERVED_WORDS,
};
use crate::registry::scanner::scan_extensions;
use crate::registry::types::{DepInfo, DiscoveredFile};
/// Callback used by [`DefaultDiscoverer`] to turn a discovered file into a
/// module instance.
///
/// It is called once per candidate file with the file record and the
/// entry-point name chosen for that file. The return value steers discovery:
///
/// * `Ok(Some(module))` — the file is kept and `module` is registered for it.
/// * `Ok(None)` — the file is skipped (logged at debug level).
/// * `Err(_)` — discovery stops and the error is returned to the caller.
///
/// The closure is shared behind an `Arc` and must be `Send + Sync`.
pub type ModuleFactory = Arc<
dyn Fn(&DiscoveredFile, &str) -> Result<Option<Arc<dyn Module>>, ModuleError> + Send + Sync,
>;
/// File-system based [`Discoverer`] configured through `with_*` builder methods.
///
/// It scans the given roots for files with the configured extensions, asks the
/// [`ModuleFactory`] to instantiate each one, and returns the resulting modules
/// in dependency order.
pub struct DefaultDiscoverer {
/// Optional path handed to `load_id_map`; entries keyed by a file's path can
/// replace that file's canonical ID. `None` disables overrides.
id_map_path: Option<PathBuf>,
/// Extension filter passed to `scan_extensions` (defaults to `[".rs"]`).
extensions: Vec<String>,
/// Depth limit passed to `scan_extensions` (defaults to `8`).
max_depth: u32,
/// Symlink flag passed to `scan_extensions` (defaults to `false`).
follow_symlinks: bool,
/// Callback that instantiates a module for each discovered file. The default
/// factory always returns `Ok(None)`, so nothing is discovered until a real
/// factory is installed.
factory: ModuleFactory,
}
impl std::fmt::Debug for DefaultDiscoverer {
    /// Writes every configuration field. The factory is a closure without a
    /// `Debug` representation, so a fixed placeholder string is printed for it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructuring (rather than `self.field` access) makes the compiler
        // flag this impl whenever a field is added to the struct.
        let Self {
            id_map_path,
            extensions,
            max_depth,
            follow_symlinks,
            factory: _,
        } = self;

        let mut out = f.debug_struct("DefaultDiscoverer");
        out.field("id_map_path", id_map_path);
        out.field("extensions", extensions);
        out.field("max_depth", max_depth);
        out.field("follow_symlinks", follow_symlinks);
        out.field("factory", &"<ModuleFactory>");
        out.finish()
    }
}
impl DefaultDiscoverer {
#[must_use]
pub fn new() -> Self {
Self {
id_map_path: None,
extensions: vec![".rs".to_string()],
max_depth: 8,
follow_symlinks: false,
factory: Arc::new(|_file, _entry| Ok(None)),
}
}
#[must_use]
pub fn with_id_map(mut self, path: Option<impl AsRef<Path>>) -> Self {
self.id_map_path = path.map(|p| p.as_ref().to_path_buf());
self
}
#[must_use]
pub fn with_extensions(mut self, exts: &[&str]) -> Self {
self.extensions = exts.iter().map(|s| (*s).to_string()).collect();
self
}
#[must_use]
pub fn with_max_depth(mut self, depth: u32) -> Self {
self.max_depth = depth;
self
}
#[must_use]
pub fn with_follow_symlinks(mut self, follow: bool) -> Self {
self.follow_symlinks = follow;
self
}
#[must_use]
pub fn with_factory(mut self, factory: ModuleFactory) -> Self {
self.factory = factory;
self
}
}
impl Default for DefaultDiscoverer {
fn default() -> Self {
Self::new()
}
}
/// A module that has been instantiated by the factory but not yet placed in
/// dependency order; intermediate state used only inside `discover`.
struct Pending {
/// The scanned file the module came from (carries the canonical ID and path).
file: DiscoveredFile,
/// The instance returned by the factory.
module: Arc<dyn Module>,
/// Descriptor assembled by `build_descriptor` from the file, module and metadata.
descriptor: ModuleDescriptor,
/// Dependencies parsed from the `dependencies` array of the file's metadata;
/// empty when there is no metadata or no such array.
deps: Vec<DepInfo>,
}
#[async_trait]
impl Discoverer for DefaultDiscoverer {
    /// Discovers modules under `roots` and returns them in dependency order.
    ///
    /// The stages run strictly in this order:
    ///
    /// 1. scan every root for files with the configured extensions;
    /// 2. apply canonical-ID overrides from the ID map, if one is configured;
    /// 3. load the companion metadata of every file that has a `meta_path`;
    /// 4. skip files with a reserved first ID segment or an over-long ID, then
    ///    ask the factory for a module (files it declines are skipped);
    /// 5. reject IDs that collide case-insensitively;
    /// 6. order the survivors with `resolve_dependencies`.
    ///
    /// # Errors
    ///
    /// Propagates any error from scanning, ID-map loading, metadata loading,
    /// entry-point resolution, the factory, or dependency resolution, and
    /// returns `ErrorCode::ModuleIdConflict` when two discovered IDs differ
    /// only by case.
    ///
    /// NOTE(review): every helper is called synchronously (there is no
    /// `.await` in this body); if they perform file-system I/O they run on the
    /// executor thread — confirm that is acceptable for callers.
    async fn discover(&self, roots: &[String]) -> Result<Vec<DiscoveredModule>, ModuleError> {
        let mut files = self.scan_roots(roots)?;
        self.apply_id_overrides(&mut files)?;
        let metadata = load_companion_metadata(&files)?;
        let pending = self.instantiate(files, &metadata)?;
        reject_duplicate_ids(&pending, roots)?;
        order_by_dependencies(pending)
    }
}

impl DefaultDiscoverer {
    /// Scans each root in turn and concatenates the files found, in root order.
    fn scan_roots(&self, roots: &[String]) -> Result<Vec<DiscoveredFile>, ModuleError> {
        let ext_refs: Vec<&str> = self.extensions.iter().map(String::as_str).collect();
        let mut files: Vec<DiscoveredFile> = Vec::new();
        for root in roots {
            files.extend(scan_extensions(
                Path::new(root),
                self.max_depth,
                self.follow_symlinks,
                Some(&ext_refs),
            )?);
        }
        Ok(files)
    }

    /// Replaces canonical IDs with the `id` string of the matching ID-map entry.
    ///
    /// Entries are looked up by the file path rendered with `to_string_lossy`.
    /// Does nothing (and loads nothing) when no ID map is configured.
    fn apply_id_overrides(&self, files: &mut [DiscoveredFile]) -> Result<(), ModuleError> {
        let Some(map_path) = &self.id_map_path else {
            return Ok(());
        };
        let overrides: HashMap<String, HashMap<String, serde_json::Value>> =
            load_id_map(map_path)?;
        for file in files {
            let key = file.file_path.to_string_lossy();
            let new_id = overrides
                .get(key.as_ref())
                .and_then(|entry| entry.get("id"))
                .and_then(|v| v.as_str());
            if let Some(new_id) = new_id {
                file.canonical_id = new_id.to_string();
            }
        }
        Ok(())
    }

    /// Filters out invalid IDs and asks the factory to instantiate each file.
    ///
    /// Files are skipped (not errors) when the first dotted ID segment is a
    /// reserved word, when the ID's byte length exceeds `MAX_MODULE_ID_LENGTH`,
    /// or when the factory returns `Ok(None)`.
    fn instantiate(
        &self,
        files: Vec<DiscoveredFile>,
        metadata: &HashMap<PathBuf, HashMap<String, serde_json::Value>>,
    ) -> Result<Vec<Pending>, ModuleError> {
        let mut pending: Vec<Pending> = Vec::new();
        for file in files {
            let first_seg = file.canonical_id.split('.').next().unwrap_or("");
            if RESERVED_WORDS.contains(&first_seg) {
                tracing::warn!(
                    canonical_id = %file.canonical_id,
                    "Skipping discovered file: first segment is a reserved word"
                );
                continue;
            }
            if file.canonical_id.len() > MAX_MODULE_ID_LENGTH {
                tracing::warn!(
                    canonical_id = %file.canonical_id,
                    "Skipping discovered file: module_id exceeds {MAX_MODULE_ID_LENGTH} chars"
                );
                continue;
            }

            let meta = metadata.get(&file.file_path);
            // An explicitly resolved entry point wins; otherwise the name is
            // inferred from the file path.
            let entry_point_name =
                crate::registry::entry_point::resolve_entry_point_name(&file.file_path, meta)?
                    .unwrap_or_else(|| {
                        crate::registry::entry_point::infer_struct_name(&file.file_path)
                    });

            let Some(module) = (self.factory)(&file, &entry_point_name)? else {
                tracing::debug!(
                    canonical_id = %file.canonical_id,
                    entry_point = %entry_point_name,
                    "DefaultDiscoverer factory returned None — skipping"
                );
                continue;
            };

            let descriptor = build_descriptor(&file, module.as_ref(), meta);
            let deps = meta
                .and_then(|m| m.get("dependencies"))
                .and_then(|v| v.as_array())
                .map(|arr| parse_dependencies(arr))
                .unwrap_or_default();

            pending.push(Pending {
                file,
                module,
                descriptor,
                deps,
            });
        }
        Ok(pending)
    }
}

/// Loads the metadata of every file that has a `meta_path`, keyed by file path.
fn load_companion_metadata(
    files: &[DiscoveredFile],
) -> Result<HashMap<PathBuf, HashMap<String, serde_json::Value>>, ModuleError> {
    let mut by_file = HashMap::new();
    for file in files {
        if let Some(meta_path) = &file.meta_path {
            let meta = load_metadata(meta_path)?;
            by_file.insert(file.file_path.clone(), meta);
        }
    }
    Ok(by_file)
}

/// Fails with `ModuleIdConflict` if two pending IDs are equal ignoring case.
fn reject_duplicate_ids(pending: &[Pending], roots: &[String]) -> Result<(), ModuleError> {
    let mut seen_ids_lower: HashSet<String> = HashSet::with_capacity(pending.len());
    for p in pending {
        // `insert` returns false when the lowercased ID was already present.
        if !seen_ids_lower.insert(p.file.canonical_id.to_lowercase()) {
            return Err(ModuleError::new(
                ErrorCode::ModuleIdConflict,
                format!(
                    "Duplicate module ID '{}' (case-insensitive) discovered in roots {:?}",
                    p.file.canonical_id, roots,
                ),
            ));
        }
    }
    Ok(())
}

/// Resolves the load order and converts the pending entries into results.
///
/// IDs in the resolved order that have no pending entry are ignored.
fn order_by_dependencies(pending: Vec<Pending>) -> Result<Vec<DiscoveredModule>, ModuleError> {
    let modules_with_deps: Vec<(String, Vec<DepInfo>)> = pending
        .iter()
        .map(|p| (p.file.canonical_id.clone(), p.deps.clone()))
        .collect();
    let module_versions: HashMap<String, String> = pending
        .iter()
        .map(|p| (p.file.canonical_id.clone(), p.descriptor.version.clone()))
        .collect();
    let load_order = resolve_dependencies(&modules_with_deps, None, Some(&module_versions))?;

    let mut by_id: HashMap<String, Pending> = pending
        .into_iter()
        .map(|p| (p.file.canonical_id.clone(), p))
        .collect();

    let mut result = Vec::with_capacity(load_order.len());
    for id in load_order {
        // Taking the entry out of the map lets the descriptor and module handle
        // be moved into the result instead of deep-cloned.
        if let Some(p) = by_id.remove(&id) {
            result.push(DiscoveredModule {
                name: p.file.canonical_id.clone(),
                source: p.file.file_path.to_string_lossy().into_owned(),
                descriptor: p.descriptor,
                module: p.module,
            });
        }
    }
    Ok(result)
}
/// Assembles the [`ModuleDescriptor`] for a discovered module.
///
/// Values come from the companion metadata when present, with these fallbacks:
///
/// * `description` — the metadata string if non-empty, else `module.description()`;
/// * `version` — the metadata string, else `"1.0.0"`;
/// * `tags` — the string elements of the metadata array (non-strings are
///   dropped), else empty;
/// * `name` / `documentation` — the metadata string, else `None`.
///
/// The schemas are taken from the module, the full metadata map is stored
/// verbatim, and every remaining field gets a fixed default (`enabled: true`,
/// no annotations, examples, display, sunset date or dependencies).
fn build_descriptor(
    file: &DiscoveredFile,
    module: &dyn Module,
    meta: Option<&HashMap<String, serde_json::Value>>,
) -> ModuleDescriptor {
    /// Returns `map[key]` when it exists and is a JSON string.
    fn text<'a>(map: &'a HashMap<String, serde_json::Value>, key: &str) -> Option<&'a str> {
        map.get(key).and_then(serde_json::Value::as_str)
    }

    let metadata = match meta {
        Some(found) => found.clone(),
        None => HashMap::new(),
    };

    let name = text(&metadata, "name").map(String::from);
    let documentation = text(&metadata, "documentation").map(String::from);
    let version = text(&metadata, "version").unwrap_or("1.0.0").to_string();

    // An empty description in the metadata counts as absent.
    let description = match text(&metadata, "description") {
        Some(found) if !found.is_empty() => found.to_string(),
        _ => module.description().to_string(),
    };

    let tags: Vec<String> = match metadata.get("tags").and_then(serde_json::Value::as_array) {
        Some(items) => items
            .iter()
            .filter_map(|item| item.as_str())
            .map(String::from)
            .collect(),
        None => Vec::new(),
    };

    ModuleDescriptor {
        module_id: file.canonical_id.clone(),
        name,
        description,
        documentation,
        input_schema: module.input_schema(),
        output_schema: module.output_schema(),
        version,
        tags,
        annotations: None,
        examples: vec![],
        metadata,
        display: None,
        sunset_date: None,
        dependencies: vec![],
        enabled: true,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::Context;
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tempfile::tempdir;

    /// Minimal module whose only configurable part is its description.
    #[derive(Debug)]
    struct TestModule {
        desc: String,
    }

    #[async_trait]
    impl Module for TestModule {
        fn description(&self) -> &str {
            &self.desc
        }

        fn input_schema(&self) -> serde_json::Value {
            json!({ "type": "object" })
        }

        fn output_schema(&self) -> serde_json::Value {
            json!({ "type": "object" })
        }

        async fn execute(
            &self,
            _inputs: serde_json::Value,
            _ctx: &Context<serde_json::Value>,
        ) -> Result<serde_json::Value, ModuleError> {
            Ok(json!({}))
        }
    }

    /// Builds a type-erased `TestModule` with the given description.
    fn stub_module(desc: &str) -> Arc<dyn Module> {
        Arc::new(TestModule {
            desc: desc.to_string(),
        })
    }

    /// Factory that accepts every file and returns a stub module.
    fn constant_factory(desc: &'static str) -> ModuleFactory {
        Arc::new(move |_file, _entry| Ok(Some(stub_module(desc))))
    }

    /// Wraps a directory path as the single-element root list `discover` takes.
    fn roots_for(dir: &Path) -> Vec<String> {
        vec![dir.to_string_lossy().into_owned()]
    }

    #[tokio::test]
    async fn missing_root_yields_config_not_found() {
        let roots = vec!["/this/path/does/not/exist".to_string()];
        let err = DefaultDiscoverer::new()
            .discover(&roots)
            .await
            .expect_err("missing root should error");
        assert_eq!(err.code, ErrorCode::ConfigNotFound);
    }

    #[tokio::test]
    async fn empty_root_yields_empty_discovery() {
        let tmp = tempdir().unwrap();
        let roots = roots_for(tmp.path());
        let found = DefaultDiscoverer::new().discover(&roots).await.unwrap();
        assert!(found.is_empty(), "no .rs files → no discovered modules");
    }

    #[tokio::test]
    async fn factory_invoked_for_discovered_file() {
        let tmp = tempdir().unwrap();
        std::fs::write(tmp.path().join("hello.rs"), "// stub").unwrap();

        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let factory: ModuleFactory = Arc::new(move |_file, entry_point| {
            counter.fetch_add(1, Ordering::SeqCst);
            assert_eq!(entry_point, "Hello");
            Ok(Some(stub_module("test")))
        });

        let roots = roots_for(tmp.path());
        let found = DefaultDiscoverer::new()
            .with_factory(factory)
            .discover(&roots)
            .await
            .unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "hello");
    }

    #[tokio::test]
    async fn circular_dependency_yields_dependency_error() {
        let tmp = tempdir().unwrap();
        // Two modules whose metadata files name each other as a dependency.
        let fixtures = [
            ("a.rs", "// a"),
            ("a_meta.yaml", "dependencies:\n - module_id: b\n"),
            ("b.rs", "// b"),
            ("b_meta.yaml", "dependencies:\n - module_id: a\n"),
        ];
        for (file_name, contents) in fixtures {
            std::fs::write(tmp.path().join(file_name), contents).unwrap();
        }

        let roots = roots_for(tmp.path());
        let err = DefaultDiscoverer::new()
            .with_factory(constant_factory("circular"))
            .discover(&roots)
            .await
            .expect_err("cycle should error");

        assert!(
            matches!(
                err.code,
                ErrorCode::CircularDependency | ErrorCode::DependencyNotFound
            ),
            "expected CircularDependency or DependencyNotFound, got {:?}",
            err.code
        );
    }
}