1use crate::{Document, Vector, VectorStore};
8use async_trait::async_trait;
9use std::sync::{Arc, Mutex};
10
11#[derive(Debug, Clone, Default)]
16pub struct MemoryVectorStore {
17 documents: Arc<Mutex<Vec<Document>>>,
18}
19
20impl MemoryVectorStore {
21 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn len(&self) -> usize {
28 self.documents.lock().unwrap().len()
29 }
30
31 pub fn is_empty(&self) -> bool {
33 self.len() == 0
34 }
35}
36
37#[async_trait]
38impl VectorStore for MemoryVectorStore {
39 async fn add_documents(
40 &self,
41 docs: Vec<Document>,
42 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
43 self.documents.lock().unwrap().extend(docs);
44 Ok(())
45 }
46
47 async fn search(
48 &self,
49 query_vector: Vector,
50 limit: usize,
51 ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
52 let documents = self.documents.lock().unwrap();
53
54 let mut scored: Vec<(f32, &Document)> = documents
55 .iter()
56 .filter_map(|doc| {
57 doc.embedding
58 .as_ref()
59 .map(|emb| (cosine_similarity(&query_vector, emb), doc))
60 })
61 .collect();
62
63 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Less));
65
66 Ok(scored
67 .into_iter()
68 .take(limit)
69 .map(|(_, doc)| doc.clone())
70 .collect())
71 }
72}
73
/// Computes the cosine similarity between `v1` and `v2`.
///
/// Returns `0.0` when either vector has zero magnitude (including when it is
/// empty). If the slices differ in length, the dot product only covers the
/// common prefix while each magnitude uses the whole slice, which is the same
/// as zero-padding the shorter vector.
fn cosine_similarity(v1: &[f32], v2: &[f32]) -> f32 {
    // Euclidean (L2) length of a vector.
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let (len1, len2) = (magnitude(v1), magnitude(v2));
    if len1 == 0.0 || len2 == 0.0 {
        // Direction is undefined for a zero-length vector; report "no similarity".
        return 0.0;
    }

    let dot_product: f32 = v1.iter().zip(v2).map(|(a, b)| a * b).sum();
    dot_product / (len1 * len2)
}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::VectorStore;

    /// Builds a document with null metadata and the given embedding.
    fn embedded(id: &str, content: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_owned(),
            content: content.to_owned(),
            metadata: serde_json::Value::Null,
            embedding: Some(embedding),
        }
    }

    /// The single best match for a query is the document pointing the same way.
    #[tokio::test]
    async fn test_search_returns_closest() {
        let corpus = vec![
            embedded("1", "close", vec![1.0, 0.0, 0.0]),
            embedded("2", "far", vec![0.0, 1.0, 0.0]),
            embedded("3", "medium", vec![0.7, 0.7, 0.0]),
        ];
        let store = MemoryVectorStore::new();
        store.add_documents(corpus).await.unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let hits = store.search(query, 1).await.unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "close");
    }

    /// `search` never returns more documents than the requested limit.
    #[tokio::test]
    async fn test_search_respects_limit() {
        let corpus = vec![
            embedded("a", "a", vec![1.0, 0.0]),
            embedded("b", "b", vec![0.8, 0.6]),
            embedded("c", "c", vec![0.0, 1.0]),
        ];
        let store = MemoryVectorStore::new();
        store.add_documents(corpus).await.unwrap();

        let query = vec![1.0, 0.0];
        let hits = store.search(query, 2).await.unwrap();

        assert_eq!(hits.len(), 2);
    }

    /// A zero vector has no direction, so its similarity to anything is 0.
    #[test]
    fn test_cosine_zero_vector() {
        let score = cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]);
        assert_eq!(score, 0.0);
    }
}