use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use swarm_engine_core::config::PathResolver;
use super::builtin;
use super::types::{EvalScenario, ScenarioId};
use crate::error::{EvalError, Result};
/// Registry of evaluation scenarios, keyed by [`ScenarioId`].
///
/// Scenarios come from two sources:
/// - built-ins, loaded from `builtin::builtin_scenarios()`;
/// - user scenarios, registered programmatically or loaded from TOML/JSON
///   files (directly or by scanning the configured search paths).
///
/// When both sources contain the same id, `get` returns the user scenario.
#[derive(Debug)]
pub struct ScenarioRegistry {
    /// Scenarios loaded from the built-in set.
    builtin: HashMap<ScenarioId, EvalScenario>,
    /// Scenarios registered at runtime or loaded from files; looked up before `builtin`.
    user: HashMap<ScenarioId, EvalScenario>,
    /// Directories scanned (in order) by `scan_all_paths`.
    search_paths: Vec<PathBuf>,
}
impl ScenarioRegistry {
    /// Creates an empty registry: no built-in scenarios, no user scenarios
    /// and no search paths.
    pub fn new() -> Self {
        Self {
            builtin: HashMap::new(),
            user: HashMap::new(),
            search_paths: Vec::new(),
        }
    }

    /// Creates a registry pre-populated with the scenarios returned by
    /// `builtin::builtin_scenarios()`. No search paths are configured.
    pub fn with_builtin() -> Self {
        let mut registry = Self::new();
        registry.load_builtin();
        registry
    }

    /// Creates a registry with the built-in scenarios and every path from
    /// `PathResolver::eval_scenario_search_paths()` added as a search path.
    ///
    /// The search paths are only recorded, not scanned; call
    /// [`Self::scan_all_paths`] (or use [`Self::discover`]) to load the
    /// scenario files they contain.
    pub fn with_default_paths() -> Self {
        let mut registry = Self::with_builtin();
        for path in PathResolver::eval_scenario_search_paths() {
            registry.add_search_path(path);
        }
        registry
    }

    /// Creates a registry via [`Self::with_default_paths`] and loads every
    /// scenario file found in its search paths.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Self::scan_all_paths`]. Failures
    /// for individual directories and files are logged as warnings there
    /// and skipped rather than returned.
    pub fn discover() -> Result<Self> {
        let mut registry = Self::with_default_paths();
        registry.scan_all_paths()?;
        Ok(registry)
    }

    /// Inserts every built-in scenario, keyed by its id. If two built-ins
    /// share an id, the later one replaces the earlier one.
    fn load_builtin(&mut self) {
        for scenario in builtin::builtin_scenarios() {
            self.builtin.insert(scenario.meta.id.clone(), scenario);
        }
    }

    /// Appends a directory to the list scanned by [`Self::scan_all_paths`].
    /// Duplicates are not filtered out.
    pub fn add_search_path(&mut self, path: impl Into<PathBuf>) {
        self.search_paths.push(path.into());
    }

    /// Looks up a scenario by id. A user scenario takes precedence over a
    /// built-in with the same id.
    pub fn get(&self, id: &ScenarioId) -> Option<&EvalScenario> {
        self.user.get(id).or_else(|| self.builtin.get(id))
    }

    /// Returns `true` if either a user or a built-in scenario has this id.
    pub fn contains(&self, id: &ScenarioId) -> bool {
        self.user.contains_key(id) || self.builtin.contains_key(id)
    }

    /// Returns every visible scenario, sorted by name (ties broken by id).
    ///
    /// A built-in that is shadowed by a user scenario with the same id is
    /// omitted, so the result agrees with [`Self::get`] and contains each
    /// id at most once.
    pub fn list(&self) -> Vec<&EvalScenario> {
        let visible_builtin = self
            .builtin
            .iter()
            .filter(|(id, _)| !self.user.contains_key(*id))
            .map(|(_, scenario)| scenario);
        Self::sorted_by_name(visible_builtin.chain(self.user.values()).collect())
    }

    /// Returns all built-in scenarios, including any shadowed by user
    /// scenarios, sorted by name (ties broken by id).
    pub fn list_builtin(&self) -> Vec<&EvalScenario> {
        Self::sorted_by_name(self.builtin.values().collect())
    }

    /// Returns all user scenarios, sorted by name (ties broken by id).
    pub fn list_user(&self) -> Vec<&EvalScenario> {
        Self::sorted_by_name(self.user.values().collect())
    }

    /// Sorts scenarios by name, breaking ties by id so the order does not
    /// depend on `HashMap` iteration order.
    fn sorted_by_name(mut scenarios: Vec<&EvalScenario>) -> Vec<&EvalScenario> {
        scenarios.sort_by(|a, b| {
            a.meta
                .name
                .cmp(&b.meta.name)
                .then_with(|| a.meta.id.as_str().cmp(b.meta.id.as_str()))
        });
        scenarios
    }

    /// Returns the visible scenarios (see [`Self::list`]) that carry at
    /// least one of `tags`. An empty `tags` slice matches nothing.
    pub fn filter_by_tags(&self, tags: &[String]) -> Vec<&EvalScenario> {
        self.list()
            .into_iter()
            .filter(|s| tags.iter().any(|tag| s.meta.tags.contains(tag)))
            .collect()
    }

    /// Returns the visible scenarios (see [`Self::list`]) whose name
    /// contains `query`, compared case-insensitively.
    pub fn search_by_name(&self, query: &str) -> Vec<&EvalScenario> {
        let query_lower = query.to_lowercase();
        self.list()
            .into_iter()
            .filter(|s| s.meta.name.to_lowercase().contains(&query_lower))
            .collect()
    }

    /// Loads one scenario file and registers it as a user scenario,
    /// replacing any user scenario with the same id.
    ///
    /// The file is parsed as JSON when its extension is exactly `json`, and
    /// as TOML otherwise (including when it has no extension).
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::Config`] if the file cannot be read or parsed.
    pub fn load_from_file(&mut self, path: &Path) -> Result<ScenarioId> {
        let content = fs::read_to_string(path).map_err(|e| {
            EvalError::Config(format!("Failed to read scenario file {:?}: {}", path, e))
        })?;
        let scenario: EvalScenario = match path.extension().and_then(|e| e.to_str()) {
            Some("json") => serde_json::from_str(&content).map_err(|e| {
                EvalError::Config(format!(
                    "Failed to parse JSON scenario file {:?}: {}",
                    path, e
                ))
            })?,
            _ => toml::from_str(&content).map_err(|e| {
                EvalError::Config(format!(
                    "Failed to parse TOML scenario file {:?}: {}",
                    path, e
                ))
            })?,
        };
        let id = scenario.meta.id.clone();
        self.user.insert(id.clone(), scenario);
        Ok(id)
    }

    /// Loads every `.toml` / `.json` file directly inside `dir` (not
    /// recursive) as a user scenario and returns the ids that loaded.
    ///
    /// A missing directory yields an empty result. Files are loaded in
    /// sorted path order, so when two files declare the same id the later
    /// path wins deterministically. Files that fail to load are logged as
    /// warnings and skipped.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::Config`] if the directory or one of its entries
    /// cannot be read. In that case no file from `dir` is loaded.
    pub fn scan_directory(&mut self, dir: &Path) -> Result<Vec<ScenarioId>> {
        if !dir.exists() {
            return Ok(Vec::new());
        }
        let mut candidates = Vec::new();
        for entry in fs::read_dir(dir)
            .map_err(|e| EvalError::Config(format!("Failed to read directory {:?}: {}", dir, e)))?
        {
            let entry = entry
                .map_err(|e| EvalError::Config(format!("Failed to read directory entry: {}", e)))?;
            let path = entry.path();
            // Skip sub-directories that merely happen to end in `.toml`/`.json`.
            if path.is_file()
                && path
                    .extension()
                    .is_some_and(|ext| ext == "toml" || ext == "json")
            {
                candidates.push(path);
            }
        }
        // `read_dir` yields entries in an unspecified, platform-dependent order.
        candidates.sort();

        let mut loaded = Vec::with_capacity(candidates.len());
        for path in candidates {
            match self.load_from_file(&path) {
                Ok(id) => loaded.push(id),
                Err(e) => {
                    tracing::warn!("Failed to load scenario from {:?}: {}", path, e);
                }
            }
        }
        Ok(loaded)
    }

    /// Scans every search path in insertion order and returns all ids
    /// loaded. A scenario from a later path replaces a same-id scenario from
    /// an earlier one. Directories that fail to scan are logged as warnings
    /// and skipped.
    pub fn scan_all_paths(&mut self) -> Result<Vec<ScenarioId>> {
        // Clone so `self` can be borrowed mutably while iterating the paths.
        let paths = self.search_paths.clone();
        let mut loaded = Vec::new();
        for path in paths {
            match self.scan_directory(&path) {
                Ok(ids) => loaded.extend(ids),
                Err(e) => {
                    tracing::warn!("Failed to scan directory {:?}: {}", path, e);
                }
            }
        }
        Ok(loaded)
    }

    /// Writes the scenario visible under `id` (see [`Self::get`]) to `path`.
    ///
    /// Uses pretty JSON when the extension is exactly `json`, and pretty
    /// TOML otherwise.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::Config`] if no scenario has this id, if
    /// serialization fails, or if the file cannot be written.
    pub fn save(&self, id: &ScenarioId, path: &Path) -> Result<()> {
        let scenario = self
            .get(id)
            .ok_or_else(|| EvalError::Config(format!("Scenario not found: {}", id)))?;
        let content = match path.extension().and_then(|e| e.to_str()) {
            Some("json") => serde_json::to_string_pretty(scenario).map_err(|e| {
                EvalError::Config(format!("Failed to serialize scenario to JSON: {}", e))
            })?,
            _ => toml::to_string_pretty(scenario).map_err(|e| {
                EvalError::Config(format!("Failed to serialize scenario to TOML: {}", e))
            })?,
        };
        fs::write(path, content).map_err(|e| {
            EvalError::Config(format!("Failed to write scenario file {:?}: {}", path, e))
        })?;
        Ok(())
    }

    /// Registers `scenario` as a user scenario and returns its id. Any user
    /// scenario with the same id is replaced. A built-in with the same id is
    /// shadowed, not removed.
    pub fn register(&mut self, scenario: EvalScenario) -> ScenarioId {
        let id = scenario.meta.id.clone();
        self.user.insert(id.clone(), scenario);
        id
    }

    /// Removes and returns the user scenario with this id.
    ///
    /// Built-ins cannot be removed. If the removed scenario was shadowing a
    /// built-in, that built-in becomes visible again through [`Self::get`].
    pub fn remove(&mut self, id: &ScenarioId) -> Option<EvalScenario> {
        self.user.remove(id)
    }

    /// Number of distinct scenario ids in the registry. A built-in that is
    /// shadowed by a user scenario is counted once, consistent with
    /// [`Self::list`].
    pub fn len(&self) -> usize {
        let shadowed = self
            .builtin
            .keys()
            .filter(|id| self.user.contains_key(*id))
            .count();
        self.builtin.len() - shadowed + self.user.len()
    }

    /// Returns `true` if there are neither built-in nor user scenarios.
    pub fn is_empty(&self) -> bool {
        self.builtin.is_empty() && self.user.is_empty()
    }
}
impl Default for ScenarioRegistry {
fn default() -> Self {
Self::with_builtin()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a minimal scenario with the given name and id, tagged `"test"`.
    fn create_test_scenario(name: &str, id: &str) -> EvalScenario {
        use super::super::conditions::EvalConditions;
        use super::super::types::*;
        EvalScenario {
            meta: ScenarioMeta {
                name: name.to_string(),
                version: "1.0.0".to_string(),
                id: ScenarioId::new(id),
                description: "Test scenario".to_string(),
                tags: vec!["test".to_string()],
            },
            task: TaskConfig::default(),
            llm: LlmConfig::default(),
            manager: ManagerConfig::default(),
            batch_processor: BatchProcessorConfig::default(),
            dependency_graph: None,
            actions: ScenarioActions::default(),
            app_config: AppConfigTemplate::default(),
            environment: EnvironmentConfig {
                env_type: "test".to_string(),
                params: serde_json::Value::Null,
                initial_state: None,
            },
            agents: AgentsConfig::default(),
            conditions: EvalConditions::default(),
            milestones: vec![],
            variants: Vec::new(),
        }
    }

    /// Returns the id and name of one built-in scenario.
    ///
    /// Panics if there are none; `test_registry_with_builtin` already
    /// requires at least one to exist.
    fn first_builtin(registry: &ScenarioRegistry) -> (ScenarioId, String) {
        let builtin = registry
            .list_builtin()
            .into_iter()
            .next()
            .expect("expected at least one built-in scenario");
        (builtin.meta.id.clone(), builtin.meta.name.clone())
    }

    #[test]
    fn test_registry_new() {
        let registry = ScenarioRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_registry_with_builtin() {
        let registry = ScenarioRegistry::with_builtin();
        assert!(!registry.list_builtin().is_empty());
    }

    #[test]
    fn test_registry_register_and_get() {
        let mut registry = ScenarioRegistry::new();
        let scenario = create_test_scenario("Test Scenario", "test:scenario:v1");
        let id = registry.register(scenario);
        assert!(registry.contains(&id));
        let retrieved = registry.get(&id).unwrap();
        assert_eq!(retrieved.meta.name, "Test Scenario");
    }

    #[test]
    fn test_registry_user_overrides_builtin() {
        let mut registry = ScenarioRegistry::with_builtin();
        // Previously this test silently passed when no built-ins existed.
        let (builtin_id, _) = first_builtin(&registry);
        registry.register(create_test_scenario("User Override", builtin_id.as_str()));
        let retrieved = registry.get(&builtin_id).unwrap();
        assert_eq!(retrieved.meta.name, "User Override");
    }

    #[test]
    fn test_registry_remove_falls_back_to_builtin() {
        let mut registry = ScenarioRegistry::with_builtin();
        let (builtin_id, builtin_name) = first_builtin(&registry);
        registry.register(create_test_scenario("User Override", builtin_id.as_str()));
        assert!(registry.remove(&builtin_id).is_some());
        let retrieved = registry.get(&builtin_id).unwrap();
        assert_eq!(retrieved.meta.name, builtin_name);
        // Built-ins themselves are never removed.
        assert!(registry.remove(&builtin_id).is_none());
        assert!(registry.contains(&builtin_id));
    }

    #[test]
    fn test_registry_filter_by_tags() {
        let mut registry = ScenarioRegistry::new();
        let mut s1 = create_test_scenario("Scenario 1", "s1");
        s1.meta.tags = vec!["coordination".to_string()];
        let mut s2 = create_test_scenario("Scenario 2", "s2");
        s2.meta.tags = vec!["basic".to_string()];
        registry.register(s1);
        registry.register(s2);
        let filtered = registry.filter_by_tags(&["coordination".to_string()]);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].meta.name, "Scenario 1");
    }

    #[test]
    fn test_registry_search_by_name_is_case_insensitive() {
        let mut registry = ScenarioRegistry::new();
        registry.register(create_test_scenario("Alpha Coordination", "alpha"));
        registry.register(create_test_scenario("Beta", "beta"));
        let found = registry.search_by_name("ALPHA");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].meta.name, "Alpha Coordination");
    }

    #[test]
    fn test_registry_save_and_load() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test_scenario.json");
        let mut registry = ScenarioRegistry::new();
        let scenario = create_test_scenario("Saved Scenario", "saved:scenario:v1");
        let id = registry.register(scenario);
        registry.save(&id, &file_path).unwrap();
        let mut new_registry = ScenarioRegistry::new();
        let loaded_id = new_registry.load_from_file(&file_path).unwrap();
        assert_eq!(loaded_id.as_str(), "saved:scenario:v1");
        let loaded = new_registry.get(&loaded_id).unwrap();
        assert_eq!(loaded.meta.name, "Saved Scenario");
    }

    #[test]
    fn test_registry_scan_directory() {
        let temp_dir = TempDir::new().unwrap();
        let mut registry = ScenarioRegistry::new();
        for i in 0..3 {
            let scenario =
                create_test_scenario(&format!("Scenario {}", i), &format!("scan:scenario:{}", i));
            let id = registry.register(scenario);
            let file_path = temp_dir.path().join(format!("scenario_{}.json", i));
            registry.save(&id, &file_path).unwrap();
        }
        let mut new_registry = ScenarioRegistry::new();
        let loaded = new_registry.scan_directory(temp_dir.path()).unwrap();
        assert_eq!(loaded.len(), 3);
        assert_eq!(new_registry.list_user().len(), 3);
    }

    #[test]
    fn test_registry_scan_missing_directory_is_empty() {
        let temp_dir = TempDir::new().unwrap();
        let mut registry = ScenarioRegistry::new();
        let loaded = registry
            .scan_directory(&temp_dir.path().join("does-not-exist"))
            .unwrap();
        assert!(loaded.is_empty());
        assert!(registry.is_empty());
    }

    #[test]
    fn test_registry_scan_skips_unrelated_and_invalid_files() {
        let temp_dir = TempDir::new().unwrap();
        std::fs::write(temp_dir.path().join("notes.txt"), "not a scenario").unwrap();
        std::fs::write(temp_dir.path().join("broken.json"), "{ not json").unwrap();
        let mut registry = ScenarioRegistry::new();
        let loaded = registry.scan_directory(temp_dir.path()).unwrap();
        assert!(loaded.is_empty());
        assert!(registry.list_user().is_empty());
    }

    #[test]
    fn test_with_default_paths() {
        let registry = ScenarioRegistry::with_default_paths();
        assert!(!registry.list_builtin().is_empty());
        assert!(!registry.search_paths.is_empty());
    }

    #[test]
    fn test_discover_creates_registry() {
        let result = ScenarioRegistry::discover();
        assert!(result.is_ok());
        let registry = result.unwrap();
        assert!(!registry.list_builtin().is_empty());
    }
}