use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::models::Node;
use crate::traits::{CerebroError, Result, VectorStore};
/// In-memory [`VectorStore`] backed by a `HashMap` from node id to [`Node`].
///
/// The map sits behind `Arc<RwLock<..>>`, so values produced by the derived
/// `Clone` share the same underlying map instead of copying it.
#[derive(Clone, Default)]
pub struct MemoryVectorStore {
// Keyed by `Node::id` (see `upsert`, which inserts each node under its own id).
nodes: Arc<RwLock<HashMap<String, Node>>>,
}
impl MemoryVectorStore {
    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        Self { nodes: Arc::new(RwLock::new(HashMap::new())) }
    }

    /// Returns the cosine similarity of `a` and `b`, in `[-1.0, 1.0]`.
    ///
    /// Returns `0.0` whenever the similarity is undefined: the slices differ
    /// in length, are empty, either one has zero norm, or the inputs contain
    /// NaN/infinite components.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }

        // Accumulate in f64: squaring f32 components overflows to `inf` for
        // magnitudes around 1e20 and underflows to zero around 1e-30, which
        // would turn perfectly valid vectors into NaN or 0.0 scores.
        let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(dot, sq_a, sq_b), (&x, &y)| {
                let (x, y) = (f64::from(x), f64::from(y));
                (dot + x * y, sq_a + x * x, sq_b + y * y)
            },
        );

        let denom = sq_a.sqrt() * sq_b.sqrt();
        if denom == 0.0 {
            return 0.0;
        }

        let cosine = dot / denom;
        if cosine.is_finite() {
            // Rounding can push the ratio a hair outside the valid range.
            cosine.clamp(-1.0, 1.0) as f32
        } else {
            // Only reachable with NaN/infinite components in the inputs.
            0.0
        }
    }
}
#[async_trait]
impl VectorStore for MemoryVectorStore {
    /// Inserts `nodes`, replacing any stored node that has the same id.
    ///
    /// # Errors
    /// Returns [`CerebroError::StorageError`] if the lock is poisoned.
    async fn upsert(&self, nodes: Vec<Node>) -> Result<()> {
        let mut store = self
            .nodes
            .write()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        store.extend(nodes.into_iter().map(|node| (node.id.clone(), node)));
        Ok(())
    }

    /// Returns clones of the stored nodes whose ids appear in `node_ids`,
    /// in request order. Unknown ids are skipped silently.
    ///
    /// # Errors
    /// Returns [`CerebroError::StorageError`] if the lock is poisoned.
    async fn get(&self, node_ids: &[&str]) -> Result<Vec<Node>> {
        let store = self
            .nodes
            .read()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        Ok(node_ids.iter().filter_map(|id| store.get(*id).cloned()).collect())
    }

    /// Returns up to `top_k` nodes ranked by descending cosine similarity
    /// between their embedding and `embedding`, paired with that score.
    ///
    /// `_text_query` is ignored by this implementation. Equal scores are
    /// ordered by node id so results are deterministic, and NaN scores (if
    /// any) rank last.
    ///
    /// # Errors
    /// Returns [`CerebroError::StorageError`] if the lock is poisoned.
    async fn search(&self, _text_query: &str, embedding: &[f32], top_k: usize) -> Result<Vec<(Node, f32)>> {
        let store = self
            .nodes
            .read()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;

        // Rank borrowed nodes first; only the final `top_k` hits are cloned.
        let mut scored: Vec<(&Node, f32)> = store
            .values()
            .map(|node| (node, Self::cosine_similarity(&node.embedding, embedding)))
            .collect();

        // Descending by score. `partial_cmp` is `None` only when a NaN is
        // involved; in that case the NaN side sorts after the other, which
        // keeps the comparator a total order as the sort requires.
        let by_score_desc = |a: f32, b: f32| {
            b.partial_cmp(&a)
                .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
        };
        // Ids are unique map keys, so the id tie-break makes the order strict
        // and an unstable sort is safe.
        scored.sort_unstable_by(|(node_a, score_a), (node_b, score_b)| {
            by_score_desc(*score_a, *score_b).then_with(|| node_a.id.cmp(&node_b.id))
        });

        Ok(scored
            .into_iter()
            .take(top_k)
            .map(|(node, score)| (node.clone(), score))
            .collect())
    }

    /// Removes every node whose chunk belongs to document `doc_id`.
    /// Removing an unknown document is a no-op.
    ///
    /// # Errors
    /// Returns [`CerebroError::StorageError`] if the lock is poisoned.
    async fn delete_document(&self, doc_id: &str) -> Result<()> {
        let mut store = self
            .nodes
            .write()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        store.retain(|_, node| node.chunk.document_id != doc_id);
        Ok(())
    }

    /// Returns clones of all stored nodes, in unspecified order.
    ///
    /// # Errors
    /// Returns [`CerebroError::StorageError`] if the lock is poisoned.
    async fn get_all_nodes(&self) -> Result<Vec<Node>> {
        let store = self
            .nodes
            .read()
            .map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
        Ok(store.values().cloned().collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Chunk;

    /// Builds a node with chunk index 0 and no edges for use in the tests below.
    fn make_node(id: &str, doc_id: &str, text: &str, embedding: Vec<f32>) -> Node {
        Node {
            id: id.into(),
            chunk: Chunk { document_id: doc_id.into(), index: 0, text: text.into() },
            embedding,
            edges: vec![],
        }
    }

    #[tokio::test]
    async fn test_upsert_and_get() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![make_node("n1", "d1", "hello", vec![1.0, 0.0, 0.0])]).await.unwrap();
        let results = store.get(&["n1"]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].chunk.text, "hello");
    }

    #[tokio::test]
    async fn test_get_nonexistent_returns_empty() {
        let store = MemoryVectorStore::new();
        assert!(store.get(&["nonexistent"]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_get_preserves_request_order_and_skips_missing() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![
            make_node("n1", "d1", "a", vec![1.0]),
            make_node("n2", "d1", "b", vec![1.0]),
        ]).await.unwrap();
        let results = store.get(&["n2", "missing", "n1"]).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "n2");
        assert_eq!(results[1].id, "n1");
    }

    #[tokio::test]
    async fn test_upsert_overwrites_existing() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![make_node("n1", "d1", "original", vec![1.0, 0.0])]).await.unwrap();
        store.upsert(vec![make_node("n1", "d1", "updated", vec![0.0, 1.0])]).await.unwrap();
        let results = store.get(&["n1"]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].chunk.text, "updated");
    }

    #[tokio::test]
    async fn test_search_returns_ranked_results() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![
            make_node("n1", "d1", "rust", vec![1.0, 0.0, 0.0]),
            make_node("n2", "d1", "python", vec![0.0, 1.0, 0.0]),
            make_node("n3", "d1", "rust mem", vec![0.9, 0.1, 0.0]),
        ]).await.unwrap();
        let results = store.search("", &[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.id, "n1");
        assert!((results[0].1 - 1.0).abs() < 0.001);
        // The runner-up must be the near-parallel vector, not the orthogonal one.
        assert_eq!(results[1].0.id, "n3");
        assert!(results[0].1 >= results[1].1);
    }

    #[tokio::test]
    async fn test_search_top_k_limits() {
        let store = MemoryVectorStore::new();
        for i in 0..10 {
            store.upsert(vec![make_node(&format!("n{}", i), "d1", "t", vec![i as f32, 0.0])]).await.unwrap();
        }
        assert_eq!(store.search("", &[1.0, 0.0], 3).await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_search_top_k_zero_returns_empty() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![make_node("n1", "d1", "t", vec![1.0])]).await.unwrap();
        assert!(store.search("", &[1.0], 0).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_search_empty_store() {
        assert!(MemoryVectorStore::new().search("", &[1.0], 5).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_delete_document() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![
            make_node("n1", "doc-A", "c1", vec![1.0]),
            make_node("n2", "doc-A", "c2", vec![0.5]),
            make_node("n3", "doc-B", "c3", vec![0.3]),
        ]).await.unwrap();
        store.delete_document("doc-A").await.unwrap();
        assert!(store.get(&["n1", "n2"]).await.unwrap().is_empty());
        assert_eq!(store.get(&["n3"]).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_get_all_nodes() {
        let store = MemoryVectorStore::new();
        store.upsert(vec![
            make_node("n1", "d1", "a", vec![1.0]),
            make_node("n2", "d2", "b", vec![0.5]),
        ]).await.unwrap();
        let all = store.get_all_nodes().await.unwrap();
        let mut ids: Vec<String> = all.iter().map(|node| node.id.clone()).collect();
        ids.sort();
        assert_eq!(ids, vec!["n1", "n2"]);
    }

    // `cosine_similarity` is synchronous, so these need no async runtime.
    #[test]
    fn test_cosine_identical() {
        assert!((MemoryVectorStore::cosine_similarity(&[1.0, 2.0], &[1.0, 2.0]) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_orthogonal() {
        assert!(MemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 0.001);
    }

    #[test]
    fn test_cosine_empty() {
        assert_eq!(MemoryVectorStore::cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_cosine_mismatched() {
        assert_eq!(MemoryVectorStore::cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
    }

    #[test]
    fn test_cosine_zero_vector() {
        assert_eq!(MemoryVectorStore::cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }
}