cerebro 1.1.8

A blazing-fast AI memory layer that enables teams of specialized agents to collaborate through a shared cognitive architecture.
Documentation
pub mod consolidation;

#[cfg(feature = "graph")]
pub mod graph;

use crate::models::{Document, Node};
use crate::traits::{Chunker, Embedder, Result, VectorStore};
use std::sync::Arc;

/// The unified Cerebro Memory Engine.
/// Ties together chunking, embedding, and storage into a single API.
///
/// Every collaborator is held behind an `Arc<dyn Trait>`, so the concrete
/// implementations are chosen by whoever constructs the engine.
pub struct MemoryEngine {
    /// Splits an incoming document into chunks during ingestion.
    pub chunker: Arc<dyn Chunker>,
    /// Produces embeddings for chunk texts at ingestion time and for the
    /// query text at search time.
    pub embedder: Arc<dyn Embedder>,
    /// Receives the embedded nodes via `upsert` and answers `search` calls;
    /// a clone of this handle is also given to the consolidation worker.
    pub store: Arc<dyn VectorStore>,
    /// Optional graph layer, only compiled in with the `graph` feature.
    /// When set, ingestion additionally extracts knowledge from each chunk
    /// and upserts the resulting triplets into it.
    #[cfg(feature = "graph")]
    pub graph_layer: Option<Arc<graph::GraphMemoryLayer>>,
}

impl MemoryEngine {
    /// Assembles an engine from a chunker, an embedder and a vector store.
    ///
    /// When the `graph` feature is enabled the graph layer starts out unset;
    /// attach one with `with_graph_layer`.
    pub fn new(
        chunker: Arc<dyn Chunker>,
        embedder: Arc<dyn Embedder>,
        store: Arc<dyn VectorStore>,
    ) -> Self {
        Self {
            chunker,
            embedder,
            store,
            #[cfg(feature = "graph")]
            graph_layer: None,
        }
    }

    /// Builder-style setter: consumes the engine and returns it with `layer`
    /// installed as the graph layer, replacing any layer set previously.
    #[cfg(feature = "graph")]
    pub fn with_graph_layer(self, layer: Arc<graph::GraphMemoryLayer>) -> Self {
        Self {
            graph_layer: Some(layer),
            ..self
        }
    }

    /// Ingests a `Document` into long-term memory.
    ///
    /// The document is chunked, every chunk text is embedded in one batch,
    /// and the resulting nodes are upserted into the vector store. A document
    /// that yields no chunks is a successful no-op.
    ///
    /// With the `graph` feature enabled and a graph layer attached, knowledge
    /// triplets are then extracted from each chunk and upserted under the
    /// document's id. That step is best-effort: its failures are ignored and
    /// never fail the ingestion.
    ///
    /// # Errors
    ///
    /// Propagates any error from the chunker, the embedder, or the store
    /// upsert.
    pub async fn ingest_document(&self, doc: Document) -> Result<()> {
        let chunks = self.chunker.chunk(&doc)?;
        if chunks.is_empty() {
            // Nothing to embed or store.
            return Ok(());
        }

        // The borrowed text slices are only needed for the embedding call.
        let embeddings = {
            let texts: Vec<&str> = chunks.iter().map(|piece| piece.text.as_str()).collect();
            self.embedder.embed(&texts).await?
        };

        // Pair each chunk with its embedding. `zip` stops at the shorter
        // side, so a count mismatch drops the surplus silently.
        let nodes: Vec<_> = chunks
            .iter()
            .zip(embeddings)
            .map(|(piece, vector)| Node::new(piece.clone(), vector))
            .collect();

        self.store.upsert(nodes).await?;

        #[cfg(feature = "graph")]
        {
            if let Some(layer) = self.graph_layer.as_ref() {
                for piece in chunks.iter() {
                    match graph.extract_knowledge_placeholder() {
                        _ => {}
                    }
                }
            }
        }

        Ok(())
    }

    /// Queries long-term memory for nodes relevant to `text`.
    ///
    /// The text is embedded as a query, then both the raw text and the
    /// embedding are handed to the store's `search`, asking for `top_k`
    /// results. The `f32` paired with each node is whatever the store
    /// reports (presumably a relevance score; confirm against the store).
    ///
    /// # Errors
    ///
    /// Propagates any error from the embedder or the store search.
    pub async fn query(&self, text: &str, top_k: usize) -> Result<Vec<(Node, f32)>> {
        let query_vector = self.embedder.embed_query(text).await?;
        let hits = self.store.search(text, &query_vector, top_k).await?;
        Ok(hits)
    }

    /// Starts the memory pruning and consolidation worker.
    ///
    /// Creates a `ConsolidationWorker` over a shared handle to the store with
    /// the given interval in seconds, and starts it. Scheduling and lifetime
    /// are owned by the worker; no handle is returned to the caller.
    pub fn start_consolidation_loop(&self, interval_seconds: u64) {
        let shared_store = self.store.clone();
        consolidation::ConsolidationWorker::new(shared_store, interval_seconds).start();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::RecursiveCharacterChunker;
    use crate::compute::mock::MockEmbedder;
    use crate::storage::memory::MemoryVectorStore;

    /// Engine wired with a small recursive chunker, the mock embedder and
    /// the in-memory store, so tests need no external services.
    fn build_engine() -> MemoryEngine {
        MemoryEngine::new(
            Arc::new(RecursiveCharacterChunker::new(50, 10)),
            Arc::new(MockEmbedder::new(8)),
            Arc::new(MemoryVectorStore::new()),
        )
    }

    /// An ingested document can be found again through `query`.
    #[tokio::test]
    async fn test_ingest_and_query_roundtrip() {
        let engine = build_engine();
        let text = "Rust is a systems programming language focused on safety and performance.";

        engine.ingest_document(Document::new(text)).await.unwrap();

        let results = engine.query("programming language", 5).await.unwrap();
        assert!(!results.is_empty());
    }

    /// Ingesting an empty document succeeds and stores nothing.
    #[tokio::test]
    async fn test_ingest_empty_document() {
        let engine = build_engine();
        let empty = Document::new("");

        engine.ingest_document(empty).await.unwrap();

        let results = engine.query("anything", 5).await.unwrap();
        assert!(results.is_empty());
    }

    /// Querying before anything was ingested yields no results.
    #[tokio::test]
    async fn test_query_empty_store() {
        let results = build_engine().query("no documents", 5).await.unwrap();
        assert!(results.is_empty());
    }

    /// Several ingested documents are all retrievable in one query.
    #[tokio::test]
    async fn test_ingest_multiple_documents() {
        let engine = build_engine();

        for &text in &[
            "First document about Rust",
            "Second document about Python",
            "Third document about JavaScript",
        ] {
            engine.ingest_document(Document::new(text)).await.unwrap();
        }

        let results = engine.query("programming", 10).await.unwrap();
        assert!(results.len() >= 3);
    }

    /// A document longer than the chunk size ends up as more than one node.
    #[tokio::test]
    async fn test_ingest_large_document_chunks_correctly() {
        let engine = build_engine();
        let long_text = "a".repeat(200);

        engine
            .ingest_document(Document::new(long_text))
            .await
            .unwrap();

        let results = engine.query("aaa", 100).await.unwrap();
        assert!(results.len() > 1);
    }

    /// `top_k` caps the number of results even when more nodes are stored.
    #[tokio::test]
    async fn test_query_top_k_limiting() {
        let engine = build_engine();

        for i in 0..20 {
            let text = format!("Document number {} with sufficient length.", i);
            engine.ingest_document(Document::new(text)).await.unwrap();
        }

        let results = engine.query("document", 3).await.unwrap();
        assert!(results.len() <= 3);
    }
}