// cerebro 1.1.8
//
// A blazing-fast AI memory layer that enables teams of specialized agents to collaborate
// through a shared cognitive architecture.
use crate::models::Node;
use crate::traits::{CerebroError, Result, VectorStore};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// An in-memory vector store intended for tests and local development.
///
/// Nodes are kept in a `HashMap` keyed by node id behind an `Arc<RwLock<_>>`. Cloning the
/// store is therefore cheap, and every clone reads and mutates the same set of nodes.
#[derive(Default, Clone)]
pub struct MemoryVectorStore {
    /// Stored nodes, keyed by node id.
    nodes: Arc<RwLock<HashMap<String, Node>>>,
}

impl MemoryVectorStore {
    pub fn new() -> Self {
        Self {
            nodes: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Naive cosine similarity for testing.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|y| y * y).sum::<f32>().sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }
}

#[async_trait]
impl VectorStore for MemoryVectorStore {
    /// Inserts `nodes`, replacing any stored node that has the same id.
    ///
    /// # Errors
    /// Returns `CerebroError::StorageError` if the lock is poisoned.
    async fn upsert(&self, nodes: Vec<Node>) -> Result<()> {
        let mut store = self
            .nodes
            .write()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        store.extend(nodes.into_iter().map(|node| (node.id.clone(), node)));
        Ok(())
    }

    /// Returns clones of the nodes whose ids appear in `node_ids`.
    ///
    /// Results follow the order of `node_ids`. Unknown ids are silently skipped.
    ///
    /// # Errors
    /// Returns `CerebroError::StorageError` if the lock is poisoned.
    async fn get(&self, node_ids: &[&str]) -> Result<Vec<Node>> {
        let store = self
            .nodes
            .read()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        Ok(node_ids
            .iter()
            .filter_map(|id| store.get(*id).cloned())
            .collect())
    }

    /// Brute-force nearest-neighbour search by cosine similarity.
    ///
    /// `_text_query` is ignored: this store ranks by embedding only. Every stored node is
    /// scored against `embedding`, and at most `top_k` `(node, score)` pairs are returned.
    /// Pairs are ordered by descending score, with NaN scores ranked last and ties broken
    /// by node id.
    ///
    /// # Errors
    /// Returns `CerebroError::StorageError` if the lock is poisoned.
    async fn search(
        &self,
        _text_query: &str,
        embedding: &[f32],
        top_k: usize,
    ) -> Result<Vec<(Node, f32)>> {
        let store = self
            .nodes
            .read()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;

        // Score borrowed nodes so that only the `top_k` survivors are cloned.
        let mut scored: Vec<(&str, &Node, f32)> = store
            .iter()
            .map(|(id, node)| {
                let score = Self::cosine_similarity(&node.embedding, embedding);
                (id.as_str(), node, score)
            })
            .collect();

        // Sorting requires a total order: treating NaN as `Equal` to everything breaks
        // transitivity, and since Rust 1.81 the std sorts may panic on such comparators.
        // Ids are unique map keys, so the id tie-break makes results independent of
        // `HashMap` iteration order.
        scored.sort_unstable_by(|(id_a, _, score_a), (id_b, _, score_b)| {
            let by_score = match (score_a.is_nan(), score_b.is_nan()) {
                (false, false) => score_b.total_cmp(score_a),
                (true, true) => std::cmp::Ordering::Equal,
                (true, false) => std::cmp::Ordering::Greater,
                (false, true) => std::cmp::Ordering::Less,
            };
            by_score.then_with(|| id_a.cmp(id_b))
        });

        Ok(scored
            .into_iter()
            .take(top_k)
            .map(|(_, node, score)| (node.clone(), score))
            .collect())
    }

    /// Removes every node whose chunk belongs to document `doc_id`.
    ///
    /// Deleting an unknown document is a no-op.
    ///
    /// # Errors
    /// Returns `CerebroError::StorageError` if the lock is poisoned.
    async fn delete_document(&self, doc_id: &str) -> Result<()> {
        let mut store = self
            .nodes
            .write()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        store.retain(|_, node| node.chunk.document_id != doc_id);
        Ok(())
    }

    /// Returns clones of all stored nodes, in unspecified (`HashMap`) order.
    ///
    /// # Errors
    /// Returns `CerebroError::StorageError` if the lock is poisoned.
    async fn get_all_nodes(&self) -> Result<Vec<Node>> {
        let store = self
            .nodes
            .read()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        Ok(store.values().cloned().collect())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Chunk;

    /// Builds a single-chunk node (chunk index 0, no edges) for the tests below.
    fn make_node(id: &str, doc_id: &str, text: &str, embedding: Vec<f32>) -> Node {
        Node {
            id: id.into(),
            chunk: Chunk {
                document_id: doc_id.into(),
                index: 0,
                text: text.into(),
            },
            embedding,
            edges: vec![],
        }
    }

    #[tokio::test]
    async fn test_upsert_and_get() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![make_node("n1", "d1", "hello", vec![1.0, 0.0, 0.0])])
            .await
            .unwrap();
        let results = store.get(&["n1"]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].chunk.text, "hello");
    }

    #[tokio::test]
    async fn test_get_nonexistent_returns_empty() {
        let store = MemoryVectorStore::new();
        assert!(store.get(&["nonexistent"]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_get_preserves_request_order_and_skips_missing() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![
                make_node("n1", "d1", "a", vec![1.0]),
                make_node("n2", "d1", "b", vec![1.0]),
            ])
            .await
            .unwrap();
        let results = store.get(&["n2", "missing", "n1"]).await.unwrap();
        let ids: Vec<&str> = results.iter().map(|n| n.id.as_str()).collect();
        assert_eq!(ids, ["n2", "n1"]);
    }

    #[tokio::test]
    async fn test_upsert_overwrites_existing() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![make_node("n1", "d1", "original", vec![1.0, 0.0])])
            .await
            .unwrap();
        store
            .upsert(vec![make_node("n1", "d1", "updated", vec![0.0, 1.0])])
            .await
            .unwrap();
        let results = store.get(&["n1"]).await.unwrap();
        assert_eq!(results[0].chunk.text, "updated");
        assert_eq!(store.get_all_nodes().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_clones_share_storage() {
        let store = MemoryVectorStore::new();
        let clone = store.clone();
        clone
            .upsert(vec![make_node("n1", "d1", "shared", vec![1.0])])
            .await
            .unwrap();
        assert_eq!(store.get(&["n1"]).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_get_all_nodes() {
        let store = MemoryVectorStore::new();
        assert!(store.get_all_nodes().await.unwrap().is_empty());
        store
            .upsert(vec![
                make_node("n1", "d1", "a", vec![1.0]),
                make_node("n2", "d2", "b", vec![0.0, 1.0]),
            ])
            .await
            .unwrap();
        let all = store.get_all_nodes().await.unwrap();
        let mut ids: Vec<&str> = all.iter().map(|n| n.id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, ["n1", "n2"]);
    }

    #[tokio::test]
    async fn test_search_returns_ranked_results() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![
                make_node("n1", "d1", "rust", vec![1.0, 0.0, 0.0]),
                make_node("n2", "d1", "python", vec![0.0, 1.0, 0.0]),
                make_node("n3", "d1", "rust mem", vec![0.9, 0.1, 0.0]),
            ])
            .await
            .unwrap();
        let results = store.search("", &[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.id, "n1");
        assert!((results[0].1 - 1.0).abs() < 0.001);
        // The near-parallel node must outrank the orthogonal one.
        assert_eq!(results[1].0.id, "n3");
        assert!(results[0].1 >= results[1].1);
    }

    #[tokio::test]
    async fn test_search_top_k_limits() {
        let store = MemoryVectorStore::new();
        for i in 0..10 {
            store
                .upsert(vec![make_node(
                    &format!("n{}", i),
                    "d1",
                    "t",
                    vec![i as f32, 0.0],
                )])
                .await
                .unwrap();
        }
        assert_eq!(store.search("", &[1.0, 0.0], 3).await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_search_top_k_zero() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![make_node("n1", "d1", "t", vec![1.0])])
            .await
            .unwrap();
        assert!(store.search("", &[1.0], 0).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_search_empty_store() {
        assert!(MemoryVectorStore::new()
            .search("", &[1.0], 5)
            .await
            .unwrap()
            .is_empty());
    }

    #[tokio::test]
    async fn test_delete_document() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![
                make_node("n1", "doc-A", "c1", vec![1.0]),
                make_node("n2", "doc-A", "c2", vec![0.5]),
                make_node("n3", "doc-B", "c3", vec![0.3]),
            ])
            .await
            .unwrap();
        store.delete_document("doc-A").await.unwrap();
        assert!(store.get(&["n1", "n2"]).await.unwrap().is_empty());
        assert_eq!(store.get(&["n3"]).await.unwrap().len(), 1);

        // Deleting an unknown document is a no-op.
        store.delete_document("doc-missing").await.unwrap();
        assert_eq!(store.get(&["n3"]).await.unwrap().len(), 1);
    }

    // The cosine helper is synchronous, so these don't need an async runtime.

    #[test]
    fn test_cosine_identical() {
        assert!(
            (MemoryVectorStore::cosine_similarity(&[1.0, 2.0], &[1.0, 2.0]) - 1.0).abs() < 0.001
        );
    }

    #[test]
    fn test_cosine_orthogonal() {
        assert!(MemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 0.001);
    }

    #[test]
    fn test_cosine_opposite() {
        assert!(
            (MemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 0.001
        );
    }

    #[test]
    fn test_cosine_zero_vector() {
        assert_eq!(
            MemoryVectorStore::cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]),
            0.0
        );
    }

    #[test]
    fn test_cosine_empty() {
        assert_eq!(MemoryVectorStore::cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_cosine_mismatched() {
        assert_eq!(
            MemoryVectorStore::cosine_similarity(&[1.0, 2.0], &[1.0]),
            0.0
        );
    }
}