use std::collections::HashMap;
use std::sync::Arc;
use crate::traits::Result;
use super::agent::AgentConfig;
use super::llm::LlmClient;
use super::memory_bus::CerebroMemoryBus;
use super::trace::{ExecutionTracer, RunStatus};
use super::patterns::{SwarmPatternExecutor, SwarmResult};
use super::patterns::sequential::SequentialPattern;
use super::patterns::parallel::ParallelPattern;
use super::patterns::hierarchical::HierarchicalPattern;
/// Topology used by `SwarmOrchestrator::execute` to run registered agents.
///
/// Serialized as an internally tagged object: the `"pattern"` field carries the
/// snake_case variant name (`"sequential"`, `"parallel"`, `"hierarchical"`) and the
/// variant's fields sit alongside it, e.g.
/// `{"pattern": "sequential", "agent_order": ["a", "b"]}`.
///
/// All string fields are agent ids that are looked up in the orchestrator's registry.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "pattern", rename_all = "snake_case")]
pub enum SwarmPattern {
    /// Executed by `SequentialPattern` over the agents in `agent_order`.
    Sequential {
        /// Agent ids, in the order they are handed to the sequential executor.
        agent_order: Vec<String>,
    },
    /// Executed by `ParallelPattern`: `agents` are fanned out, plus a `merger`.
    Parallel {
        /// Agent ids run in the fan-out.
        agents: Vec<String>,
        /// Id of the merging agent (presumably combines the fan-out outputs — verify
        /// against `ParallelPattern`).
        merger: String,
    },
    /// Executed by `HierarchicalPattern` with one supervisor over several workers.
    Hierarchical {
        /// Id of the supervising agent.
        supervisor: String,
        /// Ids of the worker agents.
        workers: Vec<String>,
    },
}
/// Registry of agents plus the shared memory bus and LLM client used to run a
/// [`SwarmPattern`].
pub struct SwarmOrchestrator {
    /// Registered agents, keyed by their `id`.
    agents: HashMap<String, AgentConfig>,
    /// Memory bus handed to every pattern executor; `execute` also writes the run
    /// metadata globals (`run_id`, `status`, `input`) here.
    memory: Arc<CerebroMemoryBus>,
    /// LLM client handed to every pattern executor.
    llm: LlmClient,
}

impl SwarmOrchestrator {
    /// Creates an orchestrator with no registered agents, backed by `memory`.
    ///
    /// The LLM client is built with `LlmClient::new()`.
    pub fn new(memory: Arc<CerebroMemoryBus>) -> Self {
        Self {
            agents: HashMap::new(),
            memory,
            llm: LlmClient::new(),
        }
    }

    /// Registers `config` under its `id`.
    ///
    /// An agent already registered with the same `id` is replaced.
    pub fn register_agent(&mut self, config: AgentConfig) {
        self.agents.insert(config.id.clone(), config);
    }

    /// Registers every config in `configs`, in order.
    ///
    /// Because of the replace-on-duplicate rule of [`Self::register_agent`], a later
    /// config wins over an earlier one with the same `id`.
    pub fn register_agents(&mut self, configs: Vec<AgentConfig>) {
        for config in configs {
            self.register_agent(config);
        }
    }

    /// Looks up a registered agent by `id`.
    pub fn get_agent(&self, id: &str) -> Option<&AgentConfig> {
        self.agents.get(id)
    }

    /// Returns `(id, name)` for every registered agent, sorted by `id`.
    ///
    /// The sort makes the listing deterministic; plain `HashMap` iteration order
    /// varies between runs.
    pub fn list_agents(&self) -> Vec<(&str, &str)> {
        let mut listing: Vec<(&str, &str)> = self
            .agents
            .values()
            .map(|a| (a.id.as_str(), a.name.as_str()))
            .collect();
        // Ids are map keys, hence unique, so an unstable sort is deterministic.
        listing.sort_unstable_by(|a, b| a.0.cmp(b.0));
        listing
    }

    /// Runs `input` through the registered agents according to `pattern`.
    ///
    /// Each call generates a fresh UUID v4 run id and its own [`ExecutionTracer`].
    /// Before dispatching, it calls `clear_episodic` on the memory bus and writes the
    /// globals `run_id`, `status` (`"running"`) and `input`. Afterwards `status` is set
    /// to `"completed"` or `"failed"`. On failure the run's tracer is finalized with
    /// [`RunStatus::Failed`]. A banner is printed to stdout at the start and on
    /// success; an error line goes to stderr on failure.
    ///
    /// # Errors
    ///
    /// Returns whatever error the selected pattern executor returns. Failures while
    /// writing the memory globals are ignored and never fail the run.
    pub async fn execute(&self, pattern: SwarmPattern, input: &str) -> Result<SwarmResult> {
        let run_id = uuid::Uuid::new_v4().to_string();
        let tracer = ExecutionTracer::new(&run_id);

        Self::print_run_header(&run_id);

        self.memory.clear_episodic();
        // Run bookkeeping is best-effort: a failed write must not abort the run.
        self.memory.set_global("run_id", &run_id).await.ok();
        self.memory.set_global("status", "running").await.ok();
        self.memory.set_global("input", input).await.ok();

        let result = match pattern {
            SwarmPattern::Sequential { agent_order } => {
                println!(" Pattern: Sequential Pipeline ({} agents)\n", agent_order.len());
                SequentialPattern::new(agent_order)
                    .execute(&self.agents, &self.memory, &tracer, &self.llm, input)
                    .await
            }
            SwarmPattern::Parallel { agents, merger } => {
                println!(
                    " Pattern: Parallel Fan-Out ({} agents + merger)\n",
                    agents.len()
                );
                ParallelPattern::new(agents, merger)
                    .execute(&self.agents, &self.memory, &tracer, &self.llm, input)
                    .await
            }
            SwarmPattern::Hierarchical {
                supervisor,
                workers,
            } => {
                println!(
                    " Pattern: Hierarchical Supervisor ({} workers)\n",
                    workers.len()
                );
                HierarchicalPattern::new(supervisor, workers)
                    .execute(&self.agents, &self.memory, &tracer, &self.llm, input)
                    .await
            }
        };

        match &result {
            Ok(swarm) => {
                self.memory.set_global("status", "completed").await.ok();
                Self::print_run_summary(swarm);
            }
            Err(e) => {
                self.memory.set_global("status", "failed").await.ok();
                // Finalize the run's own tracer (not a fresh, empty one) so the steps
                // recorded before the failure are kept under the Failed status.
                let _trace = tracer.finalize(RunStatus::Failed);
                eprintln!("\n ❌ Swarm execution failed: {}\n", e);
            }
        }
        result
    }

    /// Prints the opening banner for run `run_id`.
    fn print_run_header(run_id: &str) {
        println!("\n╔══════════════════════════════════════════════════╗");
        println!("║ 🧠 CEREBRO SWARM ENGINE ║");
        println!("╠══════════════════════════════════════════════════╣");
        println!("║ Run ID: {} ║", run_id);
        println!("╚══════════════════════════════════════════════════╝\n");
    }

    /// Prints the completion banner with token, duration, and trace-step totals.
    fn print_run_summary(swarm: &SwarmResult) {
        println!("\n╔══════════════════════════════════════════════════╗");
        println!("║ ✅ Swarm Complete ║");
        println!("║ Total Tokens: {:<34}║", swarm.total_tokens);
        println!("║ Duration: {}ms{:>31}║", swarm.total_duration_ms, "");
        println!("║ Trace Steps: {:<35}║", swarm.trace.steps.len());
        println!("╚══════════════════════════════════════════════════╝\n");
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::swarm::agent::LlmProvider;

    /// Builds a memory bus over in-memory vector/KV stores and a mock embedder.
    fn test_memory() -> Arc<CerebroMemoryBus> {
        Arc::new(CerebroMemoryBus::new(
            Arc::new(crate::engine::MemoryEngine::new(
                Arc::new(crate::chunker::RecursiveCharacterChunker::new(512, 50)),
                Arc::new(crate::compute::mock::MockEmbedder::new(8)),
                Arc::new(crate::storage::memory::MemoryVectorStore::new()),
            )),
            Arc::new(crate::storage::kv::MemoryKVStore::new()),
        ))
    }

    /// Builds an agent config on a local Ollama `llama3` model with no tools or handoffs.
    fn ollama_agent(
        id: &'static str,
        name: &'static str,
        system_prompt: &'static str,
    ) -> AgentConfig {
        AgentConfig {
            id: id.into(),
            name: name.into(),
            system_prompt: system_prompt.into(),
            model: LlmProvider::Ollama {
                model: "llama3".into(),
                base_url: "http://localhost:11434".into(),
            },
            tools: vec![],
            handoff_targets: vec![],
            max_steps: 5,
        }
    }

    #[test]
    fn test_register_and_list_agents() {
        let mut orch = SwarmOrchestrator::new(test_memory());
        orch.register_agent(ollama_agent(
            "sec",
            "Security Agent",
            "You review code security.",
        ));
        orch.register_agent(ollama_agent(
            "perf",
            "Performance Agent",
            "You review code performance.",
        ));

        assert_eq!(orch.list_agents().len(), 2);
        assert!(orch.get_agent("sec").is_some());
        assert!(orch.get_agent("nonexistent").is_none());
    }

    #[test]
    fn test_register_agents_replaces_duplicate_ids() {
        let mut orch = SwarmOrchestrator::new(test_memory());
        orch.register_agents(vec![
            ollama_agent("sec", "Security Agent", "You review code security."),
            ollama_agent("sec", "Security Agent v2", "You review code security."),
        ]);

        assert_eq!(orch.list_agents().len(), 1);
        assert_eq!(
            orch.get_agent("sec").map(|a| a.name.as_str()),
            Some("Security Agent v2")
        );
    }

    #[test]
    fn test_swarm_pattern_serialization() {
        let pattern = SwarmPattern::Sequential {
            agent_order: vec!["a".into(), "b".into(), "c".into()],
        };
        let json = serde_json::to_string_pretty(&pattern).unwrap();
        let deser: SwarmPattern = serde_json::from_str(&json).unwrap();
        match deser {
            SwarmPattern::Sequential { agent_order } => {
                assert_eq!(agent_order.len(), 3);
            }
            _ => panic!("Expected Sequential"),
        }
    }

    #[test]
    fn test_parallel_and_hierarchical_round_trip() {
        let parallel = SwarmPattern::Parallel {
            agents: vec!["a".into(), "b".into()],
            merger: "m".into(),
        };
        let json = serde_json::to_string(&parallel).unwrap();
        match serde_json::from_str::<SwarmPattern>(&json).unwrap() {
            SwarmPattern::Parallel { agents, merger } => {
                assert_eq!(agents, vec!["a", "b"]);
                assert_eq!(merger, "m");
            }
            other => panic!("Expected Parallel, got {:?}", other),
        }

        let hierarchical = SwarmPattern::Hierarchical {
            supervisor: "boss".into(),
            workers: vec!["w1".into(), "w2".into()],
        };
        let json = serde_json::to_string(&hierarchical).unwrap();
        match serde_json::from_str::<SwarmPattern>(&json).unwrap() {
            SwarmPattern::Hierarchical {
                supervisor,
                workers,
            } => {
                assert_eq!(supervisor, "boss");
                assert_eq!(workers, vec!["w1", "w2"]);
            }
            other => panic!("Expected Hierarchical, got {:?}", other),
        }
    }

    #[test]
    fn test_swarm_pattern_wire_tags() {
        let cases = [
            (
                SwarmPattern::Sequential {
                    agent_order: vec!["a".into()],
                },
                "sequential",
            ),
            (
                SwarmPattern::Parallel {
                    agents: vec!["a".into()],
                    merger: "m".into(),
                },
                "parallel",
            ),
            (
                SwarmPattern::Hierarchical {
                    supervisor: "s".into(),
                    workers: vec!["w".into()],
                },
                "hierarchical",
            ),
        ];
        for (pattern, tag) in cases.iter() {
            let value = serde_json::to_value(pattern).unwrap();
            assert_eq!(value["pattern"], *tag);
        }

        // A hand-written payload in the documented wire format must deserialize.
        let parsed: SwarmPattern = serde_json::from_str(
            r#"{"pattern":"hierarchical","supervisor":"s","workers":["w1","w2"]}"#,
        )
        .unwrap();
        match parsed {
            SwarmPattern::Hierarchical { workers, .. } => assert_eq!(workers.len(), 2),
            other => panic!("Expected Hierarchical, got {:?}", other),
        }
    }
}