cerebro 1.1.5

A blazing-fast AI memory layer that enables teams of specialized agents to collaborate through a shared cognitive architecture.
Documentation
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::models::Node;
use crate::traits::{CerebroError, Result, VectorStore};

/// In-memory [`VectorStore`] that keeps every node in a `HashMap` keyed by the node's id.
///
/// Meant for tests and local development. Because the map lives behind an
/// `Arc<RwLock<..>>`, cloning a `MemoryVectorStore` produces another handle to
/// the same underlying map rather than an independent copy.
#[derive(Default, Clone)]
pub struct MemoryVectorStore {
    /// All stored nodes, keyed by `Node::id`.
    nodes: Arc<RwLock<HashMap<String, Node>>>,
}

impl MemoryVectorStore {
    pub fn new() -> Self {
        Self { nodes: Arc::new(RwLock::new(HashMap::new())) }
    }

    /// Naive cosine similarity for testing.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() { return 0.0; }
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|y| y * y).sum::<f32>().sqrt();
        if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a * norm_b) }
    }
}

#[async_trait]
impl VectorStore for MemoryVectorStore {
    async fn upsert(&self, nodes: Vec<Node>) -> Result<()> {
        let mut store = self.nodes.write().map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        for n in nodes { store.insert(n.id.clone(), n); }
        Ok(())
    }

    async fn get(&self, node_ids: &[&str]) -> Result<Vec<Node>> {
        let store = self.nodes.read().map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        Ok(node_ids.iter().filter_map(|id| store.get(*id).cloned()).collect())
    }

    async fn search(&self, _text_query: &str, embedding: &[f32], top_k: usize) -> Result<Vec<(Node, f32)>> {
        let store = self.nodes.read().map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        let mut scored: Vec<(Node, f32)> = store.values()
            .map(|node| {
                let score = Self::cosine_similarity(&node.embedding, embedding);
                (node.clone(), score)
            })
            .collect();
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(top_k);
        Ok(scored)
    }

    async fn delete_document(&self, doc_id: &str) -> Result<()> {
        let mut store = self.nodes.write().map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        store.retain(|_, node| node.chunk.document_id != doc_id);
        Ok(())
    }

    async fn get_all_nodes(&self) -> Result<Vec<Node>> {
        let store = self.nodes.read().map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        Ok(store.values().cloned().collect())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Chunk;

    /// Builds a node holding a single chunk (index 0) with no edges.
    fn make_node(id: &str, doc_id: &str, text: &str, embedding: Vec<f32>) -> Node {
        Node { id: id.into(), chunk: Chunk { document_id: doc_id.into(), index: 0, text: text.into() }, embedding, edges: vec![] }
    }

    #[tokio::test]
    async fn test_upsert_and_get() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![make_node("n1", "d1", "hello", vec![1.0, 0.0, 0.0])]).await.unwrap();
        let results = store.get(&["n1"]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].chunk.text, "hello");
    }

    #[tokio::test]
    async fn test_get_nonexistent_returns_empty() {
        let store = MemoryVectorStore::new();
        assert!(store.get(&["nonexistent"]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_get_preserves_request_order_and_skips_missing() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![
            make_node("n1", "d1", "first", vec![1.0]),
            make_node("n2", "d1", "second", vec![0.5]),
        ]).await.unwrap();
        let results = store.get(&["n2", "missing", "n1"]).await.unwrap();
        let ids: Vec<&str> = results.iter().map(|n| n.id.as_str()).collect();
        assert_eq!(ids, ["n2", "n1"]);
    }

    #[tokio::test]
    async fn test_upsert_overwrites_existing() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![make_node("n1", "d1", "original", vec![1.0, 0.0])]).await.unwrap();
        store.upsert(vec![make_node("n1", "d1", "updated", vec![0.0, 1.0])]).await.unwrap();
        let results = store.get(&["n1"]).await.unwrap();
        assert_eq!(results[0].chunk.text, "updated");
    }

    #[tokio::test]
    async fn test_get_all_nodes() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![
            make_node("n1", "d1", "a", vec![1.0]),
            make_node("n2", "d2", "b", vec![0.5]),
        ]).await.unwrap();
        let mut ids: Vec<String> = store.get_all_nodes().await.unwrap().into_iter().map(|n| n.id).collect();
        ids.sort();
        assert_eq!(ids, ["n1", "n2"]);
    }

    #[tokio::test]
    async fn test_search_returns_ranked_results() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![
            make_node("n1", "d1", "rust", vec![1.0, 0.0, 0.0]),
            make_node("n2", "d1", "python", vec![0.0, 1.0, 0.0]),
            make_node("n3", "d1", "rust mem", vec![0.9, 0.1, 0.0]),
        ]).await.unwrap();
        let results = store.search("", &[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.id, "n1");
        assert!((results[0].1 - 1.0).abs() < 0.001);
        assert_eq!(results[1].0.id, "n3");
    }

    #[tokio::test]
    async fn test_search_top_k_limits() {
        let store = MemoryVectorStore::new();
        for i in 0..10 {
            store.upsert(vec![make_node(&format!("n{}", i), "d1", "t", vec![i as f32, 0.0])]).await.unwrap();
        }
        assert_eq!(store.search("", &[1.0, 0.0], 3).await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_search_top_k_larger_than_store() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![
            make_node("n1", "d1", "a", vec![1.0, 0.0]),
            make_node("n2", "d1", "b", vec![0.5, 0.5]),
        ]).await.unwrap();
        let results = store.search("", &[1.0, 0.0], 10).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].1 >= results[1].1);
    }

    #[tokio::test]
    async fn test_search_zero_top_k() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![make_node("n1", "d1", "a", vec![1.0])]).await.unwrap();
        assert!(store.search("", &[1.0], 0).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_search_empty_store() {
        assert!(MemoryVectorStore::new().search("", &[1.0], 5).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_delete_document() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![
            make_node("n1", "doc-A", "c1", vec![1.0]),
            make_node("n2", "doc-A", "c2", vec![0.5]),
            make_node("n3", "doc-B", "c3", vec![0.3]),
        ]).await.unwrap();
        store.delete_document("doc-A").await.unwrap();
        assert!(store.get(&["n1", "n2"]).await.unwrap().is_empty());
        assert_eq!(store.get(&["n3"]).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_delete_unknown_document_is_noop() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![make_node("n1", "doc-A", "c1", vec![1.0])]).await.unwrap();
        store.delete_document("doc-Z").await.unwrap();
        assert_eq!(store.get(&["n1"]).await.unwrap().len(), 1);
    }

    // `cosine_similarity` is synchronous, so these use the plain test harness
    // instead of starting a Tokio runtime for each case.
    #[test]
    fn test_cosine_identical() { assert!((MemoryVectorStore::cosine_similarity(&[1.0, 2.0], &[1.0, 2.0]) - 1.0).abs() < 0.001); }

    #[test]
    fn test_cosine_orthogonal() { assert!(MemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 0.001); }

    #[test]
    fn test_cosine_empty() { assert_eq!(MemoryVectorStore::cosine_similarity(&[], &[]), 0.0); }

    #[test]
    fn test_cosine_mismatched() { assert_eq!(MemoryVectorStore::cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0); }
}