// cerebro 1.1.3
//
// A high-performance semantic memory engine for AI agents, now featuring SwarmForge for
// built-in multi-agent orchestration.
pub mod consolidation;

#[cfg(feature = "graph")]
pub mod graph;

use std::sync::Arc;
use crate::models::{Document, Node};
use crate::traits::{Chunker, Embedder, Result, VectorStore};

/// The unified Cerebro memory engine.
///
/// Ties together chunking ([`Chunker`]), embedding ([`Embedder`]) and storage
/// ([`VectorStore`]) behind a single API. Each component is held behind an `Arc`, so the
/// same implementation can also be handed to other owners, such as the consolidation
/// worker built by [`MemoryEngine::start_consolidation_loop`].
pub struct MemoryEngine {
    /// Splits an incoming [`Document`] into chunks before embedding.
    pub chunker: Arc<dyn Chunker>,
    /// Embeds batches of chunk texts during ingestion and single query strings during search.
    pub embedder: Arc<dyn Embedder>,
    /// Receives embedded [`Node`]s via `upsert` and answers `search` requests.
    pub store: Arc<dyn VectorStore>,
}

impl MemoryEngine {
    /// Creates an engine from its chunker, embedder and vector store.
    pub fn new(
        chunker: Arc<dyn Chunker>,
        embedder: Arc<dyn Embedder>,
        store: Arc<dyn VectorStore>,
    ) -> Self {
        Self { chunker, embedder, store }
    }

    /// Ingests a [`Document`] into long-term memory.
    ///
    /// The document is chunked. All chunk texts are then embedded in a single batch call,
    /// and each chunk is paired with its embedding to build a [`Node`]. The resulting nodes
    /// are upserted into the store in one call. If the chunker produces no chunks, this is a
    /// no-op: neither the embedder nor the store is called.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the chunker, the embedder or the store.
    ///
    /// # Panics
    ///
    /// Panics if the embedder returns a different number of embeddings than the number of
    /// texts it was given. That breaks the embedder's contract, and pairing the two lists
    /// anyway would silently drop chunks from memory.
    pub async fn ingest_document(&self, doc: Document) -> Result<()> {
        let chunks = self.chunker.chunk(&doc)?;
        if chunks.is_empty() {
            return Ok(());
        }

        let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();
        let embeddings = self.embedder.embed(&texts).await?;

        // `zip` below stops at the shorter side, so a mismatched batch would lose chunks
        // without any signal. Fail loudly instead.
        assert_eq!(
            embeddings.len(),
            chunks.len(),
            "embedder returned {} embeddings for {} chunks",
            embeddings.len(),
            chunks.len()
        );

        let nodes: Vec<_> = chunks
            .into_iter()
            .zip(embeddings)
            .map(|(chunk, embedding)| Node::new(chunk, embedding))
            .collect();

        self.store.upsert(nodes).await?;
        Ok(())
    }

    /// Queries long-term memory for the chunks most relevant to `text`.
    ///
    /// `text` is embedded with [`Embedder::embed_query`]. Both the raw text and its
    /// embedding are then passed to [`VectorStore::search`], with `top_k` forwarded as the
    /// result limit. Each returned [`Node`] is paired with the `f32` score reported by the
    /// store.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the embedder or the store.
    pub async fn query(&self, text: &str, top_k: usize) -> Result<Vec<(Node, f32)>> {
        let embedding = self.embedder.embed_query(text).await?;
        self.store.search(text, &embedding, top_k).await
    }

    /// Starts the background memory pruning and consolidation worker.
    ///
    /// Builds a [`consolidation::ConsolidationWorker`] over a shared handle to this engine's
    /// store and calls its `start` method. `interval_seconds` is passed to the worker
    /// unchanged; presumably it is the delay between consolidation passes (confirm in
    /// `consolidation`).
    pub fn start_consolidation_loop(&self, interval_seconds: u64) {
        let worker =
            consolidation::ConsolidationWorker::new(Arc::clone(&self.store), interval_seconds);
        worker.start();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::RecursiveCharacterChunker;
    use crate::compute::mock::MockEmbedder;
    use crate::storage::memory::MemoryVectorStore;

    /// Builds an engine over in-memory test doubles.
    ///
    /// The chunker is configured with `(50, 10)`, presumably chunk size and overlap, and the
    /// mock embedder with `8`, presumably the embedding dimension (confirm both). The store
    /// starts empty.
    fn build_engine() -> MemoryEngine {
        let chunker = Arc::new(RecursiveCharacterChunker::new(50, 10));
        let embedder = Arc::new(MockEmbedder::new(8));
        let store = Arc::new(MemoryVectorStore::new());
        MemoryEngine::new(chunker, embedder, store)
    }

    #[tokio::test]
    async fn test_ingest_and_query_roundtrip() {
        let engine = build_engine();
        let doc = Document::new(
            "Rust is a systems programming language focused on safety and performance.",
        );
        engine.ingest_document(doc).await.unwrap();
        let results = engine.query("programming language", 5).await.unwrap();
        assert!(!results.is_empty(), "ingested content should be retrievable");
    }

    #[tokio::test]
    async fn test_ingest_empty_document() {
        let engine = build_engine();
        engine.ingest_document(Document::new("")).await.unwrap();
        let results = engine.query("anything", 5).await.unwrap();
        assert!(results.is_empty(), "an empty document must not create nodes");
    }

    #[tokio::test]
    async fn test_query_empty_store() {
        let engine = build_engine();
        let results = engine.query("no documents", 5).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_ingest_multiple_documents() {
        let engine = build_engine();
        engine.ingest_document(Document::new("First document about Rust")).await.unwrap();
        engine.ingest_document(Document::new("Second document about Python")).await.unwrap();
        engine.ingest_document(Document::new("Third document about JavaScript")).await.unwrap();
        // Each document is shorter than the 50-char chunk size, so exactly three nodes exist.
        // `top_k` exceeds that, so all of them (and nothing else) should come back.
        let results = engine.query("programming", 10).await.unwrap();
        assert_eq!(results.len(), 3, "expected one result per single-chunk document");
    }

    #[tokio::test]
    async fn test_ingest_large_document_chunks_correctly() {
        let engine = build_engine();
        let doc = Document::new("a".repeat(200));
        engine.ingest_document(doc).await.unwrap();
        let results = engine.query("aaa", 100).await.unwrap();
        assert!(results.len() > 1, "a 200-char document should split into several chunks");
    }

    #[tokio::test]
    async fn test_query_top_k_limiting() {
        let engine = build_engine();
        for i in 0..20 {
            let text = format!("Document number {} with sufficient length.", i);
            engine.ingest_document(Document::new(text)).await.unwrap();
        }
        // 20 single-chunk nodes are stored, so a limit of 3 must be hit exactly. A bare
        // `<= 3` check would also pass for a store that returns nothing.
        let results = engine.query("document", 3).await.unwrap();
        assert_eq!(results.len(), 3, "query must return exactly top_k results");
    }
}