use crate::models::Node;
use crate::traits::{CerebroError, Result, VectorStore};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// In-process [`VectorStore`] that keeps every node in a `HashMap` keyed by node id.
///
/// The map sits behind an `Arc`, so cloning a `MemoryVectorStore` is cheap and yields
/// another handle onto the *same* underlying map rather than an independent copy.
#[derive(Default, Clone)]
pub struct MemoryVectorStore {
    /// Node id -> node, guarded by a `std::sync::RwLock` (shared reads, exclusive writes).
    nodes: Arc<RwLock<HashMap<String, Node>>>,
}
impl MemoryVectorStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        let empty = HashMap::new();
        Self {
            nodes: Arc::new(RwLock::new(empty)),
        }
    }

    /// Cosine similarity between `a` and `b`.
    ///
    /// Returns `0.0` when the slices are empty, differ in length, or either one has zero
    /// magnitude. Non-finite components are not filtered, so the result can be NaN.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let comparable = !a.is_empty() && a.len() == b.len();
        if !comparable {
            return 0.0;
        }

        // One pass accumulates the dot product and both squared magnitudes; each running
        // sum is still added left-to-right, exactly like three separate `sum()` calls.
        let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
            (0.0_f32, 0.0_f32, 0.0_f32),
            |(acc_dot, acc_a, acc_b), (&x, &y)| (acc_dot + x * y, acc_a + x * x, acc_b + y * y),
        );
        let (norm_a, norm_b) = (sq_a.sqrt(), sq_b.sqrt());

        // Each norm is checked on its own (not their product) so an underflowing
        // product still takes the division path, as before.
        match (norm_a == 0.0, norm_b == 0.0) {
            (false, false) => dot / (norm_a * norm_b),
            _ => 0.0,
        }
    }
}
#[async_trait]
impl VectorStore for MemoryVectorStore {
async fn upsert(&self, nodes: Vec<Node>) -> Result<()> {
let mut store = self
.nodes
.write()
.map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
for n in nodes {
store.insert(n.id.clone(), n);
}
Ok(())
}
async fn get(&self, node_ids: &[&str]) -> Result<Vec<Node>> {
let store = self
.nodes
.read()
.map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
Ok(node_ids
.iter()
.filter_map(|id| store.get(*id).cloned())
.collect())
}
async fn search(
&self,
_text_query: &str,
embedding: &[f32],
top_k: usize,
) -> Result<Vec<(Node, f32)>> {
let store = self
.nodes
.read()
.map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
let mut scored: Vec<(Node, f32)> = store
.values()
.map(|node| {
let score = Self::cosine_similarity(&node.embedding, embedding);
(node.clone(), score)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(top_k);
Ok(scored)
}
async fn delete_document(&self, doc_id: &str) -> Result<()> {
let mut store = self
.nodes
.write()
.map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
store.retain(|_, node| node.chunk.document_id != doc_id);
Ok(())
}
async fn get_all_nodes(&self) -> Result<Vec<Node>> {
let store = self
.nodes
.read()
.map_err(|_| CerebroError::StorageError("Lock poisoned".into()))?;
Ok(store.values().cloned().collect())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Chunk;

    /// Builds a single-chunk node (chunk index 0, no edges) for tests.
    fn make_node(id: &str, doc_id: &str, text: &str, embedding: Vec<f32>) -> Node {
        Node {
            id: id.into(),
            chunk: Chunk {
                document_id: doc_id.into(),
                index: 0,
                text: text.into(),
            },
            embedding,
            edges: vec![],
        }
    }

    #[tokio::test]
    async fn test_upsert_and_get() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![make_node("n1", "d1", "hello", vec![1.0, 0.0, 0.0])])
            .await
            .unwrap();
        let results = store.get(&["n1"]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].chunk.text, "hello");
    }

    #[tokio::test]
    async fn test_get_nonexistent_returns_empty() {
        let store = MemoryVectorStore::new();
        assert!(store.get(&["nonexistent"]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_get_preserves_request_order_and_skips_missing() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![
                make_node("n1", "d1", "a", vec![1.0]),
                make_node("n2", "d1", "b", vec![1.0]),
            ])
            .await
            .unwrap();
        let ids: Vec<String> = store
            .get(&["n2", "missing", "n1"])
            .await
            .unwrap()
            .into_iter()
            .map(|n| n.id)
            .collect();
        assert_eq!(ids, vec!["n2".to_string(), "n1".to_string()]);
    }

    #[tokio::test]
    async fn test_upsert_overwrites_existing() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![make_node("n1", "d1", "original", vec![1.0, 0.0])])
            .await
            .unwrap();
        store
            .upsert(vec![make_node("n1", "d1", "updated", vec![0.0, 1.0])])
            .await
            .unwrap();
        let results = store.get(&["n1"]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].chunk.text, "updated");
    }

    #[tokio::test]
    async fn test_search_returns_ranked_results() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![
                make_node("n1", "d1", "rust", vec![1.0, 0.0, 0.0]),
                make_node("n2", "d1", "python", vec![0.0, 1.0, 0.0]),
                make_node("n3", "d1", "rust mem", vec![0.9, 0.1, 0.0]),
            ])
            .await
            .unwrap();
        let results = store.search("", &[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.id, "n1");
        assert!((results[0].1 - 1.0).abs() < 0.001);
        // n3 is nearly parallel to the query, n2 is orthogonal, so n3 must be second.
        assert_eq!(results[1].0.id, "n3");
        assert!(results[0].1 >= results[1].1);
    }

    #[tokio::test]
    async fn test_search_top_k_limits() {
        let store = MemoryVectorStore::new();
        for i in 0..10 {
            store
                .upsert(vec![make_node(
                    &format!("n{}", i),
                    "d1",
                    "t",
                    vec![i as f32, 0.0],
                )])
                .await
                .unwrap();
        }
        assert_eq!(store.search("", &[1.0, 0.0], 3).await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_search_top_k_larger_than_store_returns_all() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![
                make_node("n1", "d1", "a", vec![1.0, 0.0]),
                make_node("n2", "d1", "b", vec![0.0, 1.0]),
            ])
            .await
            .unwrap();
        assert_eq!(store.search("", &[1.0, 0.0], 10).await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_search_zero_top_k_is_empty() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![make_node("n1", "d1", "a", vec![1.0, 0.0])])
            .await
            .unwrap();
        assert!(store.search("", &[1.0, 0.0], 0).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_search_empty_store() {
        assert!(MemoryVectorStore::new()
            .search("", &[1.0], 5)
            .await
            .unwrap()
            .is_empty());
    }

    #[tokio::test]
    async fn test_delete_document() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![
                make_node("n1", "doc-A", "c1", vec![1.0]),
                make_node("n2", "doc-A", "c2", vec![0.5]),
                make_node("n3", "doc-B", "c3", vec![0.3]),
            ])
            .await
            .unwrap();
        store.delete_document("doc-A").await.unwrap();
        assert!(store.get(&["n1", "n2"]).await.unwrap().is_empty());
        assert_eq!(store.get(&["n3"]).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_delete_unknown_document_is_noop() {
        let store = MemoryVectorStore::new();
        store
            .upsert(vec![make_node("n1", "doc-A", "c1", vec![1.0])])
            .await
            .unwrap();
        store.delete_document("doc-Z").await.unwrap();
        assert_eq!(store.get_all_nodes().await.unwrap().len(), 1);
    }

    // `cosine_similarity` is synchronous, so these need no async runtime.

    #[test]
    fn test_cosine_identical() {
        assert!(
            (MemoryVectorStore::cosine_similarity(&[1.0, 2.0], &[1.0, 2.0]) - 1.0).abs() < 0.001
        );
    }

    #[test]
    fn test_cosine_orthogonal() {
        assert!(MemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 0.001);
    }

    #[test]
    fn test_cosine_opposite() {
        assert!(
            (MemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 0.001
        );
    }

    #[test]
    fn test_cosine_zero_norm() {
        assert_eq!(
            MemoryVectorStore::cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]),
            0.0
        );
    }

    #[test]
    fn test_cosine_empty() {
        assert_eq!(MemoryVectorStore::cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_cosine_mismatched() {
        assert_eq!(
            MemoryVectorStore::cosine_similarity(&[1.0, 2.0], &[1.0]),
            0.0
        );
    }
}