cerebro 1.1.0

A high-performance semantic memory engine + multi-agent swarm orchestrator for AI, written in pure Rust.
Documentation
//! # Cerebro Memory Bus
//!
//! The bridge between Cerebro's three memory tiers and the swarm orchestrator.
//! Provides agents with access to:
//! - **Working Memory** (`KVStore`) — fast key-value state for active execution
//! - **Semantic Memory** (`MemoryEngine`) — vector search over past agent outputs
//! - **Episodic Memory** — per-agent conversation histories within a run

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use crate::engine::MemoryEngine;
use crate::models::{Document, Node};
use crate::traits::{KVStore, Result};
use super::agent::ChatMessage;

/// Facade that exposes Cerebro's three memory tiers to the swarm.
///
/// The bus does not implement storage itself for the semantic and working
/// tiers — it holds shared handles to an existing [`MemoryEngine`] and an
/// existing [`KVStore`] implementation. Only the episodic tier (per-agent
/// chat logs) is owned by the bus, as an in-process map behind a lock.
pub struct CerebroMemoryBus {
    /// Semantic memory — the shared engine that agent outputs are ingested
    /// into and that similarity queries are run against.
    pub semantic: Arc<MemoryEngine>,

    /// Working memory — a shared key-value store for live execution state.
    /// Any `KVStore` implementation can be plugged in here.
    pub working: Arc<dyn KVStore>,

    /// Episodic memory — conversation logs keyed by agent id. Each entry is
    /// the list of `ChatMessage`s for that agent in insertion order.
    /// Private: all access goes through the methods on the bus.
    episodic: Arc<RwLock<HashMap<String, Vec<ChatMessage>>>>,
}

impl CerebroMemoryBus {
    /// Create a new memory bus from existing Cerebro components.
    ///
    /// The semantic engine and working store are used as given; episodic
    /// memory starts out empty.
    pub fn new(semantic: Arc<MemoryEngine>, working: Arc<dyn KVStore>) -> Self {
        Self {
            semantic,
            working,
            episodic: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    // ─── Internal helpers ────────────────────────────────────────────

    /// Build the working-memory key for a per-agent value:
    /// `agent:{agent_id}:{key}`.
    ///
    /// NOTE(review): an `agent_id` or `key` that itself contains `:` can
    /// produce the same full key as a different (agent, key) pair — confirm
    /// that callers only use colon-free agent ids.
    fn agent_key(agent_id: &str, key: &str) -> String {
        format!("agent:{}:{}", agent_id, key)
    }

    /// Build the working-memory key for a swarm-level value: `swarm:{key}`.
    fn global_key(key: &str) -> String {
        format!("swarm:{}", key)
    }

    /// Acquire the episodic map for reading.
    ///
    /// A poisoned lock is recovered instead of being treated as "no data":
    /// every critical section in this type performs a single map operation,
    /// so the map is still structurally valid after a panic elsewhere.
    fn episodic_read(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<String, Vec<ChatMessage>>> {
        self.episodic
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Acquire the episodic map for writing, recovering from poisoning for
    /// the same reason as [`Self::episodic_read`]. Without this, a poisoned
    /// lock would make every later push/clear a silent no-op.
    fn episodic_write(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<String, Vec<ChatMessage>>> {
        self.episodic
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    // ─── Semantic Memory (Long-Term Knowledge) ───────────────────────

    /// Store an agent's output in semantic memory.
    ///
    /// The output is wrapped in a `Document` tagged with `agent_id`,
    /// `run_id`, a fixed `source` of `"swarm_agent"`, and the current UTC
    /// time (RFC 3339), then handed to the engine for ingestion.
    ///
    /// Returns the id of the document that was ingested.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the engine's `ingest_document`.
    pub async fn commit_to_semantic(
        &self,
        agent_id: &str,
        run_id: &str,
        output: &str,
    ) -> Result<String> {
        let doc = Document::new(output)
            .with_metadata("agent_id", agent_id)
            .with_metadata("run_id", run_id)
            .with_metadata("source", "swarm_agent")
            .with_metadata("timestamp", &chrono::Utc::now().to_rfc3339());

        // Capture the id before the document is moved into the engine.
        let doc_id = doc.id.clone();
        self.semantic.ingest_document(doc).await?;
        Ok(doc_id)
    }

    /// Query semantic memory for the `top_k` entries most relevant to `query`.
    ///
    /// Each hit is returned together with its score. This method never
    /// fails: if the underlying query returns an error, the error is
    /// discarded and an empty list is returned, so callers cannot tell
    /// "no matches" apart from "query failed".
    pub async fn recall_semantic(&self, query: &str, top_k: usize) -> Vec<(Node, f32)> {
        self.semantic.query(query, top_k).await.unwrap_or_default()
    }

    /// Format semantic recall results as a context string for LLM prompts.
    ///
    /// Returns an empty string when there are no results. Otherwise the
    /// result is a header line, one numbered entry per hit (1-based index,
    /// score to three decimals, chunk text), and a footer line.
    pub async fn recall_as_context(&self, query: &str, top_k: usize) -> String {
        let results = self.recall_semantic(query, top_k).await;
        if results.is_empty() {
            return String::new();
        }

        let mut context = String::from("=== Relevant Prior Knowledge ===\n");
        for (i, (node, score)) in results.iter().enumerate() {
            context.push_str(&format!(
                "\n[{}] (relevance: {:.3})\n{}\n",
                i + 1,
                score,
                node.chunk.text
            ));
        }
        context.push_str("\n=== End Prior Knowledge ===\n");
        context
    }

    // ─── Working Memory (Fast State) ─────────────────────────────────

    /// Set a namespaced key-value pair for an agent's working state.
    /// Keys are automatically prefixed: `agent:{agent_id}:{key}`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying `KVStore::set`.
    pub async fn set_state(&self, agent_id: &str, key: &str, value: &str) -> Result<()> {
        let full_key = Self::agent_key(agent_id, key);
        self.working.set(&full_key, value).await
    }

    /// Get an agent's working state value, using the same
    /// `agent:{agent_id}:{key}` namespacing as [`Self::set_state`].
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying `KVStore::get`.
    pub async fn get_state(&self, agent_id: &str, key: &str) -> Result<Option<String>> {
        let full_key = Self::agent_key(agent_id, key);
        self.working.get(&full_key).await
    }

    /// Set a global (non-agent-specific) working memory value.
    /// Keys are automatically prefixed: `swarm:{key}`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying `KVStore::set`.
    pub async fn set_global(&self, key: &str, value: &str) -> Result<()> {
        let full_key = Self::global_key(key);
        self.working.set(&full_key, value).await
    }

    /// Get a global working memory value stored under `swarm:{key}`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying `KVStore::get`.
    pub async fn get_global(&self, key: &str) -> Result<Option<String>> {
        let full_key = Self::global_key(key);
        self.working.get(&full_key).await
    }

    // ─── Episodic Memory (Conversation History) ──────────────────────

    /// Append a message to an agent's episodic conversation log, creating
    /// the log if this is the agent's first message.
    pub fn push_message(&self, agent_id: &str, message: ChatMessage) {
        self.episodic_write()
            .entry(agent_id.to_string())
            .or_default()
            .push(message);
    }

    /// Get a copy of the full conversation history for an agent.
    /// Returns an empty list for an agent with no recorded messages.
    pub fn get_history(&self, agent_id: &str) -> Vec<ChatMessage> {
        self.episodic_read()
            .get(agent_id)
            .cloned()
            .unwrap_or_default()
    }

    /// Get a copy of the last `n` messages from an agent's history, in
    /// their original order. If the agent has `n` or fewer messages, all of
    /// them are returned; an unknown agent yields an empty list.
    pub fn get_recent_history(&self, agent_id: &str, n: usize) -> Vec<ChatMessage> {
        let store = self.episodic_read();
        match store.get(agent_id) {
            // Clone only the requested tail rather than the whole log.
            // `saturating_sub` clamps the start to 0 when n >= len.
            Some(history) => history[history.len().saturating_sub(n)..].to_vec(),
            None => Vec::new(),
        }
    }

    /// Clear all episodic memory for every agent.
    pub fn clear_episodic(&self) {
        self.episodic_write().clear();
    }

    /// Get all agent IDs that currently have an episodic log.
    /// The order of the returned ids is unspecified (hash-map iteration order).
    pub fn active_agents(&self) -> Vec<String> {
        self.episodic_read().keys().cloned().collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::RecursiveCharacterChunker;
    use crate::compute::mock::MockEmbedder;
    use crate::storage::kv::MemoryKVStore;
    use crate::storage::memory::MemoryVectorStore;
    use crate::swarm::agent::Role;

    /// Assemble a bus whose engine and KV store are the in-memory / mock
    /// implementations, so tests need no external services.
    fn build_memory_bus() -> CerebroMemoryBus {
        let engine = MemoryEngine::new(
            Arc::new(RecursiveCharacterChunker::new(512, 50)),
            Arc::new(MockEmbedder::new(8)),
            Arc::new(MemoryVectorStore::new()),
        );
        CerebroMemoryBus::new(Arc::new(engine), Arc::new(MemoryKVStore::new()))
    }

    // A value written for an agent can be read back for the same agent/key.
    #[tokio::test]
    async fn test_working_memory_set_get() {
        let bus = build_memory_bus();
        bus.set_state("agent-1", "current_task", "Review security")
            .await
            .unwrap();

        let stored = bus.get_state("agent-1", "current_task").await.unwrap();
        assert_eq!(stored.as_deref(), Some("Review security"));
    }

    // The same key under two agents must hold two independent values.
    #[tokio::test]
    async fn test_working_memory_isolation() {
        let bus = build_memory_bus();
        for (agent, step) in [("agent-1", "3"), ("agent-2", "7")] {
            bus.set_state(agent, "step", step).await.unwrap();
        }

        let first = bus.get_state("agent-1", "step").await.unwrap();
        let second = bus.get_state("agent-2", "step").await.unwrap();
        assert_eq!(first.as_deref(), Some("3"));
        assert_eq!(second.as_deref(), Some("7"));
    }

    // Swarm-level values round-trip through set_global / get_global.
    #[tokio::test]
    async fn test_global_state() {
        let bus = build_memory_bus();
        bus.set_global("run_status", "running").await.unwrap();

        let status = bus.get_global("run_status").await.unwrap();
        assert_eq!(status.as_deref(), Some("running"));
    }

    // Committing an output yields a document id and makes it recallable.
    #[tokio::test]
    async fn test_semantic_commit_and_recall() {
        let bus = build_memory_bus();
        let committed = bus
            .commit_to_semantic("agent-1", "run-001", "Rust ensures memory safety without a garbage collector through its ownership system.")
            .await;
        let doc_id = committed.unwrap();
        assert!(!doc_id.is_empty());

        // MockEmbedder makes the ranking approximate, so only check that
        // something comes back.
        let hits = bus.recall_semantic("memory safety", 5).await;
        assert!(!hits.is_empty());
    }

    // With nothing ingested, the formatted context is the empty string.
    #[tokio::test]
    async fn test_recall_as_context_empty() {
        let bus = build_memory_bus();
        let context = bus.recall_as_context("anything", 5).await;
        assert!(context.is_empty());
    }

    // Messages are stored per agent and returned in push order.
    #[test]
    fn test_episodic_push_and_get() {
        let bus = build_memory_bus();
        bus.push_message("agent-1", ChatMessage::new(Role::User, "Hello"));
        bus.push_message("agent-1", ChatMessage::new(Role::Assistant, "Hi there!"));
        bus.push_message("agent-2", ChatMessage::new(Role::User, "Different agent"));

        let first_log = bus.get_history("agent-1");
        assert_eq!(first_log.len(), 2);
        assert_eq!(first_log[0].content, "Hello");
        assert_eq!(first_log[1].content, "Hi there!");

        let second_log = bus.get_history("agent-2");
        assert_eq!(second_log.len(), 1);
    }

    // Asking for the last 3 of 10 messages yields msg-7 ..= msg-9.
    #[test]
    fn test_episodic_recent_history() {
        let bus = build_memory_bus();
        (0..10).for_each(|i| {
            bus.push_message("agent-1", ChatMessage::new(Role::User, format!("msg-{}", i)));
        });

        let tail = bus.get_recent_history("agent-1", 3);
        assert_eq!(tail.len(), 3);
        assert_eq!(tail[0].content, "msg-7");
        assert_eq!(tail[2].content, "msg-9");
    }

    // clear_episodic removes previously pushed history.
    #[test]
    fn test_episodic_clear() {
        let bus = build_memory_bus();
        bus.push_message("agent-1", ChatMessage::new(Role::User, "Hello"));
        assert_eq!(bus.get_history("agent-1").len(), 1);

        bus.clear_episodic();
        assert!(bus.get_history("agent-1").is_empty());
    }

    // Every agent that has pushed a message is listed exactly once.
    #[test]
    fn test_active_agents() {
        let bus = build_memory_bus();
        bus.push_message("agent-a", ChatMessage::new(Role::User, "hi"));
        bus.push_message("agent-b", ChatMessage::new(Role::User, "hey"));

        // Map iteration order is unspecified, so sort before comparing.
        let mut agents = bus.active_agents();
        agents.sort();
        assert_eq!(agents, vec!["agent-a".to_string(), "agent-b".to_string()]);
    }
}