cerebro 1.1.8

A blazing-fast AI memory layer that enables teams of specialized agents to collaborate through a shared cognitive architecture.
Documentation
//! # Swarm Orchestrator
//!
//! The top-level coordinator for multi-agent swarm execution.
//! Manages agent registries, selects patterns, and drives execution
//! through the CerebroMemoryBus.

use std::collections::HashMap;
use std::sync::Arc;

use super::agent::AgentConfig;
use super::immunology::ImmunologySupervisor;
use super::llm::LlmClient;
use super::memory_bus::CerebroMemoryBus;
use super::patterns::hierarchical::HierarchicalPattern;
use super::patterns::parallel::ParallelPattern;
use super::patterns::sequential::SequentialPattern;
use super::patterns::{SwarmPatternExecutor, SwarmResult};
use super::trace::{ExecutionTracer, RunStatus};
use crate::traits::Result;

/// Defines how agents should be orchestrated.
///
/// Serialized as an internally tagged object: the variant name, converted to
/// snake_case, is stored under the `"pattern"` key alongside the variant's
/// own fields, e.g. `{"pattern": "sequential", "agent_order": ["a", "b"]}`.
///
/// All IDs are agent IDs; presumably the pattern executors resolve them
/// against the orchestrator's registry (keyed by `AgentConfig::id`) — confirm
/// against the executor implementations.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(tag = "pattern", rename_all = "snake_case")]
pub enum SwarmPattern {
    /// Agents execute one after another.
    Sequential {
        /// Ordered list of agent IDs.
        agent_order: Vec<String>,
    },
    /// Multiple agents run in parallel, results merged.
    Parallel {
        /// Agent IDs for parallel execution.
        agents: Vec<String>,
        /// Agent ID responsible for merging results.
        merger: String,
    },
    /// Supervisor decomposes, delegates, and synthesizes.
    Hierarchical {
        /// The supervisor agent ID.
        supervisor: String,
        /// Worker agent IDs.
        workers: Vec<String>,
    },
}

/// The SwarmOrchestrator coordinates agent execution across patterns.
///
/// It owns the agent and tool registries, a shared handle to the memory bus,
/// the LLM client, and an optional immunology supervisor. The registries,
/// memory bus and LLM client are lent to whichever pattern executor
/// `execute` selects for a run.
pub struct SwarmOrchestrator {
    /// Registered agent configurations, keyed by agent ID (`AgentConfig::id`).
    agents: HashMap<String, AgentConfig>,
    /// Global registry of available tools, keyed by each tool's definition name.
    tools: HashMap<String, Arc<dyn super::tools::AgentTool>>,
    /// The Cerebro-powered shared memory bus (shared via `Arc`).
    memory: Arc<CerebroMemoryBus>,
    /// Unified LLM client for all providers.
    llm: LlmClient,
    /// Optional immunology supervisor for healing agents on failure.
    /// `None` until `with_immunology` is called; healed configs are only
    /// reported, never written back into `agents`.
    immunology: Option<ImmunologySupervisor>,
}

impl SwarmOrchestrator {
    /// Create a new orchestrator backed by `memory`.
    ///
    /// The orchestrator starts with no agents, no tools, a default
    /// [`LlmClient`], and immunology disabled; register agents and tools
    /// separately before calling [`execute`](Self::execute).
    pub fn new(memory: Arc<CerebroMemoryBus>) -> Self {
        Self {
            agents: HashMap::new(),
            tools: HashMap::new(),
            memory,
            llm: LlmClient::new(),
            immunology: None,
        }
    }

    /// Enable the Immunology Supervisor, using `doctor` as the LLM provider
    /// consulted when an agent is healed after a failed run.
    ///
    /// Calling this again replaces any previously configured supervisor.
    pub fn with_immunology(&mut self, doctor: super::agent::LlmProvider) {
        self.immunology = Some(ImmunologySupervisor::new(Some(doctor)));
    }

    /// Register an executable tool in the global registry.
    ///
    /// Tools are keyed by the name in their definition; registering a tool
    /// with an existing name replaces the previous one.
    pub fn register_tool(&mut self, tool: Arc<dyn super::tools::AgentTool>) {
        let name = tool.definition().name.clone();
        self.tools.insert(name, tool);
    }

    /// Register an agent configuration under its `id`.
    ///
    /// An agent previously registered with the same ID is replaced.
    pub fn register_agent(&mut self, config: AgentConfig) {
        self.agents.insert(config.id.clone(), config);
    }

    /// Register several agents at once; see [`register_agent`](Self::register_agent).
    ///
    /// Accepts any iterable of configs (a `Vec`, an array, an iterator, ...).
    pub fn register_agents(&mut self, configs: impl IntoIterator<Item = AgentConfig>) {
        for config in configs {
            self.register_agent(config);
        }
    }

    /// Get a reference to a registered agent's config, if `id` is registered.
    pub fn get_agent(&self, id: &str) -> Option<&AgentConfig> {
        self.agents.get(id)
    }

    /// List `(id, name)` for every registered agent, sorted by ID.
    pub fn list_agents(&self) -> Vec<(&str, &str)> {
        let mut listed: Vec<(&str, &str)> = self
            .agents
            .values()
            .map(|a| (a.id.as_str(), a.name.as_str()))
            .collect();
        // HashMap iteration order is unspecified; sort so callers get a
        // stable, reproducible listing. IDs are unique map keys, so an
        // unstable sort is sufficient.
        listed.sort_unstable_by_key(|&(id, _)| id);
        listed
    }

    /// Execute a swarm with the specified pattern and input.
    ///
    /// This is the main entry point for running a swarm. It:
    /// 1. creates a fresh run ID and [`ExecutionTracer`],
    /// 2. clears episodic memory and records the `run_id`, `status` and
    ///    `input` globals on the memory bus,
    /// 3. runs the executor matching `pattern`,
    /// 4. sets the `status` global to `"completed"` or `"failed"`.
    ///
    /// Memory-bus global writes are best-effort: their errors are ignored so
    /// that bookkeeping failures never abort a run.
    ///
    /// On failure, if immunology is enabled, one agent named by the pattern
    /// is sent to the supervisor for healing. The healed config is printed
    /// but not applied, because `agents` is not behind a lock.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the selected pattern executor.
    pub async fn execute(&self, pattern: SwarmPattern, input: &str) -> Result<SwarmResult> {
        let run_id = uuid::Uuid::new_v4().to_string();
        let tracer = ExecutionTracer::new(&run_id);

        println!("\n╔══════════════════════════════════════════════════╗");
        println!("║           🧠 CEREBRO SWARM ENGINE                ║");
        println!("╠══════════════════════════════════════════════════╣");
        println!("║  Run ID: {:<40}║", run_id);
        println!("╚══════════════════════════════════════════════════╝\n");

        // Clear episodic memory for a fresh run
        self.memory.clear_episodic();

        // Store run metadata in working memory (best-effort).
        self.memory.set_global("run_id", &run_id).await.ok();
        self.memory.set_global("status", "running").await.ok();
        self.memory.set_global("input", input).await.ok();

        // `pattern` is consumed by the match below, so decide up front which
        // agent the immunology supervisor should look at if the run fails.
        let heal_target = Self::heal_candidate(&pattern);

        let result = match pattern {
            SwarmPattern::Sequential { agent_order } => {
                println!(
                    "  Pattern: Sequential Pipeline ({} agents)\n",
                    agent_order.len()
                );
                let executor = SequentialPattern::new(agent_order);
                executor
                    .execute(
                        &self.agents,
                        &self.tools,
                        &self.memory,
                        &tracer,
                        &self.llm,
                        input,
                    )
                    .await
            }
            SwarmPattern::Parallel { agents, merger } => {
                println!(
                    "  Pattern: Parallel Fan-Out ({} agents + merger)\n",
                    agents.len()
                );
                let executor = ParallelPattern::new(agents, merger);
                executor
                    .execute(
                        &self.agents,
                        &self.tools,
                        &self.memory,
                        &tracer,
                        &self.llm,
                        input,
                    )
                    .await
            }
            SwarmPattern::Hierarchical {
                supervisor,
                workers,
            } => {
                println!(
                    "  Pattern: Hierarchical Supervisor ({} workers)\n",
                    workers.len()
                );
                let executor = HierarchicalPattern::new(supervisor, workers);
                executor
                    .execute(
                        &self.agents,
                        &self.tools,
                        &self.memory,
                        &tracer,
                        &self.llm,
                        input,
                    )
                    .await
            }
        };

        match &result {
            Ok(r) => {
                self.memory.set_global("status", "completed").await.ok();
                Self::print_summary(r);
            }
            Err(e) => {
                self.memory.set_global("status", "failed").await.ok();
                // Finalize the tracer that actually recorded this run, so the
                // failed status is attached to its (partial) trace.
                let _trace = tracer.finalize(RunStatus::Failed);
                eprintln!("\n  ❌ Swarm execution failed: {}\n", e);

                // Try to heal an agent from this run if immunology is enabled.
                let healable = heal_target
                    .as_deref()
                    .and_then(|id| self.agents.get(id));
                if let (Some(immu), Some(agent)) = (&self.immunology, healable) {
                    match immu
                        .heal_agent(&self.llm, agent.clone(), &e.to_string())
                        .await
                    {
                        Ok(healed) => {
                            println!("💉 [Immunology] Agent {} healed. To auto-apply, SwarmOrchestrator must use RwLock. New prompt:\n{}", healed.id, healed.system_prompt);
                        }
                        Err(heal_err) => {
                            eprintln!(
                                "💉 [Immunology] Failed to heal agent {}: {}",
                                agent.id, heal_err
                            );
                        }
                    }
                }
            }
        }

        result
    }

    /// Choose the agent to offer for healing after a failed run.
    ///
    /// The executor's error does not identify which agent failed, so this is
    /// a heuristic that at least restricts healing to an agent named by the
    /// pattern: the first agent of a sequential or parallel run (falling back
    /// to the merger when the parallel list is empty) or the hierarchical
    /// supervisor. Returns `None` when the pattern names no agent.
    fn heal_candidate(pattern: &SwarmPattern) -> Option<String> {
        match pattern {
            SwarmPattern::Sequential { agent_order } => agent_order.first().cloned(),
            SwarmPattern::Parallel { agents, merger } => {
                Some(agents.first().unwrap_or(merger).clone())
            }
            SwarmPattern::Hierarchical { supervisor, .. } => Some(supervisor.clone()),
        }
    }

    /// Print the boxed completion banner for a successful run.
    fn print_summary(r: &SwarmResult) {
        // Each row is left-aligned and padded to the banner's 50-column
        // interior so the right border lines up regardless of value width.
        let duration = format!("{}ms", r.total_duration_ms);
        println!("\n╔══════════════════════════════════════════════════╗");
        println!("║  ✅ Swarm Complete                               ║");
        println!("║  Total Tokens: {:<34}║", r.total_tokens);
        println!("║  Duration: {:<38}║", duration);
        println!("║  Trace Steps: {:<35}║", r.trace.steps.len());
        println!("╚══════════════════════════════════════════════════╝\n");
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::swarm::agent::LlmProvider;

    /// Build a memory bus backed entirely by in-memory / mock components.
    fn test_memory_bus() -> Arc<CerebroMemoryBus> {
        Arc::new(CerebroMemoryBus::new(
            Arc::new(crate::engine::MemoryEngine::new(
                Arc::new(crate::chunker::RecursiveCharacterChunker::new(512, 50)),
                Arc::new(crate::compute::mock::MockEmbedder::new(8)),
                Arc::new(crate::storage::memory::MemoryVectorStore::new()),
            )),
            Arc::new(crate::storage::kv::MemoryKVStore::new()),
        ))
    }

    /// Build a local Ollama-backed agent config with no tools or handoffs.
    fn ollama_agent(id: &str, name: &str, system_prompt: &str) -> AgentConfig {
        AgentConfig {
            id: id.into(),
            name: name.into(),
            system_prompt: system_prompt.into(),
            model: LlmProvider::Ollama {
                model: "llama3".into(),
                base_url: "http://localhost:11434".into(),
            },
            tools: vec![],
            handoff_targets: vec![],
            max_steps: 5,
        }
    }

    /// Serialize `pattern` to pretty JSON and parse it back, returning both
    /// the generic JSON value (to inspect the wire format) and the decoded
    /// pattern.
    fn round_trip(pattern: &SwarmPattern) -> (serde_json::Value, SwarmPattern) {
        let text = serde_json::to_string_pretty(pattern).unwrap();
        let value: serde_json::Value = serde_json::from_str(&text).unwrap();
        let decoded: SwarmPattern = serde_json::from_str(&text).unwrap();
        (value, decoded)
    }

    #[test]
    fn test_register_and_list_agents() {
        let mut orch = SwarmOrchestrator::new(test_memory_bus());

        orch.register_agent(ollama_agent(
            "sec",
            "Security Agent",
            "You review code security.",
        ));
        orch.register_agent(ollama_agent(
            "perf",
            "Performance Agent",
            "You review code performance.",
        ));

        // Sort locally so the assertion does not depend on listing order.
        let mut agents = orch.list_agents();
        agents.sort_unstable();
        assert_eq!(
            agents,
            vec![("perf", "Performance Agent"), ("sec", "Security Agent")]
        );
        assert!(orch.get_agent("sec").is_some());
        assert!(orch.get_agent("nonexistent").is_none());
    }

    #[test]
    fn test_swarm_pattern_serialization() {
        let pattern = SwarmPattern::Sequential {
            agent_order: vec!["a".into(), "b".into(), "c".into()],
        };
        let (json, deser) = round_trip(&pattern);
        // Internally tagged under "pattern", variant name in snake_case.
        assert_eq!(json["pattern"], "sequential");
        match deser {
            SwarmPattern::Sequential { agent_order } => {
                assert_eq!(agent_order, ["a", "b", "c"]);
            }
            other => panic!("Expected Sequential, got {:?}", other),
        }
    }

    #[test]
    fn test_parallel_pattern_serialization() {
        let pattern = SwarmPattern::Parallel {
            agents: vec!["x".into(), "y".into()],
            merger: "m".into(),
        };
        let (json, deser) = round_trip(&pattern);
        assert_eq!(json["pattern"], "parallel");
        match deser {
            SwarmPattern::Parallel { agents, merger } => {
                assert_eq!(agents, ["x", "y"]);
                assert_eq!(merger, "m");
            }
            other => panic!("Expected Parallel, got {:?}", other),
        }
    }

    #[test]
    fn test_hierarchical_pattern_serialization() {
        let pattern = SwarmPattern::Hierarchical {
            supervisor: "boss".into(),
            workers: vec!["w1".into(), "w2".into()],
        };
        let (json, deser) = round_trip(&pattern);
        assert_eq!(json["pattern"], "hierarchical");
        match deser {
            SwarmPattern::Hierarchical {
                supervisor,
                workers,
            } => {
                assert_eq!(supervisor, "boss");
                assert_eq!(workers, ["w1", "w2"]);
            }
            other => panic!("Expected Hierarchical, got {:?}", other),
        }
    }
}