1use std::collections::HashMap;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use uuid::Uuid;
10
11use cognis_core::{CognisError, Result};
12
13use crate::distance::Distance;
14use crate::embeddings::Embeddings;
15
16use super::{SearchResult, VectorStore};
17
18#[derive(Clone)]
19struct StoredDoc {
20 id: String,
21 text: String,
22 vector: Vec<f32>,
23 metadata: HashMap<String, serde_json::Value>,
24}
25
26pub struct InMemoryVectorStore {
28 embedder: Arc<dyn Embeddings>,
29 distance: Distance,
30 docs: Vec<StoredDoc>,
31}
32
33impl InMemoryVectorStore {
34 pub fn new(embedder: Arc<dyn Embeddings>) -> Self {
36 Self::with_distance(embedder, Distance::Cosine)
37 }
38
39 pub fn with_distance(embedder: Arc<dyn Embeddings>, distance: Distance) -> Self {
41 Self {
42 embedder,
43 distance,
44 docs: Vec::new(),
45 }
46 }
47
48 pub fn distance(&self) -> Distance {
50 self.distance
51 }
52}
53
54#[async_trait]
55impl VectorStore for InMemoryVectorStore {
56 async fn add_texts(
57 &mut self,
58 texts: Vec<String>,
59 metadata: Option<Vec<HashMap<String, serde_json::Value>>>,
60 ) -> Result<Vec<String>> {
61 if texts.is_empty() {
62 return Ok(Vec::new());
63 }
64 let vectors = self.embedder.embed_documents(texts.clone()).await?;
65 self.add_vectors(vectors, texts, metadata).await
66 }
67
68 async fn add_vectors(
69 &mut self,
70 vectors: Vec<Vec<f32>>,
71 texts: Vec<String>,
72 metadata: Option<Vec<HashMap<String, serde_json::Value>>>,
73 ) -> Result<Vec<String>> {
74 if vectors.len() != texts.len() {
75 return Err(CognisError::Configuration(format!(
76 "vectors.len() ({}) must equal texts.len() ({})",
77 vectors.len(),
78 texts.len()
79 )));
80 }
81 if let Some(m) = &metadata {
82 if m.len() != texts.len() {
83 return Err(CognisError::Configuration(format!(
84 "metadata.len() ({}) must equal texts.len() ({})",
85 m.len(),
86 texts.len()
87 )));
88 }
89 }
90
91 let mut ids = Vec::with_capacity(texts.len());
92 for (i, (text, vector)) in texts.into_iter().zip(vectors).enumerate() {
93 let id = Uuid::new_v4().to_string();
94 let md = metadata.as_ref().map(|m| m[i].clone()).unwrap_or_default();
95 ids.push(id.clone());
96 self.docs.push(StoredDoc {
97 id,
98 text,
99 vector,
100 metadata: md,
101 });
102 }
103 Ok(ids)
104 }
105
106 async fn similarity_search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
107 let qv = self.embedder.embed_query(query.to_string()).await?;
108 self.similarity_search_by_vector(qv, k).await
109 }
110
111 async fn similarity_search_by_vector(
112 &self,
113 query_vector: Vec<f32>,
114 k: usize,
115 ) -> Result<Vec<SearchResult>> {
116 if self.docs.is_empty() || k == 0 {
117 return Ok(Vec::new());
118 }
119
120 let mut scored: Vec<(f32, &StoredDoc)> = self
122 .docs
123 .iter()
124 .map(|d| (self.distance.similarity(&query_vector, &d.vector), d))
125 .collect();
126
127 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
129
130 Ok(scored
131 .into_iter()
132 .take(k)
133 .map(|(score, d)| SearchResult {
134 id: d.id.clone(),
135 text: d.text.clone(),
136 score,
137 metadata: d.metadata.clone(),
138 })
139 .collect())
140 }
141
142 async fn delete(&mut self, ids: Vec<String>) -> Result<()> {
143 let to_delete: std::collections::HashSet<String> = ids.into_iter().collect();
144 self.docs.retain(|d| !to_delete.contains(&d.id));
145 Ok(())
146 }
147
148 fn len(&self) -> usize {
149 self.docs.len()
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::FakeEmbeddings;

    /// Builds a `FakeEmbeddings` embedder of dimension `dim` as a trait object.
    fn fake_embedder(dim: usize) -> Arc<dyn Embeddings> {
        Arc::new(FakeEmbeddings::new(dim))
    }

    #[tokio::test]
    async fn new_defaults_to_cosine() {
        let store = InMemoryVectorStore::new(fake_embedder(8));
        assert!(matches!(store.distance(), Distance::Cosine));
    }

    #[tokio::test]
    async fn add_texts_assigns_ids() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        let ids = store
            .add_texts(vec!["a".into(), "b".into(), "c".into()], None)
            .await
            .unwrap();
        assert_eq!(ids.len(), 3);
        assert_eq!(store.len(), 3);
        let unique: std::collections::HashSet<_> = ids.iter().collect();
        assert_eq!(unique.len(), 3);
    }

    #[tokio::test]
    async fn add_texts_empty_is_noop() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        let ids = store.add_texts(Vec::new(), None).await.unwrap();
        assert!(ids.is_empty());
        assert_eq!(store.len(), 0);
    }

    #[tokio::test]
    async fn search_returns_matches_in_order() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        store
            .add_texts(vec!["dog".into(), "cat".into(), "fish".into()], None)
            .await
            .unwrap();

        let results = store.similarity_search("dog", 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].text, "dog");
        // Results must be ranked by non-increasing score.
        assert!(results.windows(2).all(|w| w[0].score >= w[1].score));
    }

    #[tokio::test]
    async fn search_respects_k() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        store
            .add_texts((0..10).map(|i| format!("doc {i}")).collect(), None)
            .await
            .unwrap();
        let r1 = store.similarity_search("doc 5", 1).await.unwrap();
        let r5 = store.similarity_search("doc 5", 5).await.unwrap();
        assert_eq!(r1.len(), 1);
        assert_eq!(r5.len(), 5);
    }

    #[tokio::test]
    async fn metadata_roundtrip() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        let mut md = HashMap::new();
        md.insert("source".into(), serde_json::json!("wiki"));
        md.insert("year".into(), serde_json::json!(2024));
        store
            .add_texts(vec!["hello".into()], Some(vec![md.clone()]))
            .await
            .unwrap();
        let r = store.similarity_search("hello", 1).await.unwrap();
        assert_eq!(r[0].metadata.get("source").unwrap(), "wiki");
        assert_eq!(r[0].metadata.get("year").unwrap(), 2024);
    }

    /// Two vectors but only one text: the counts must match.
    #[tokio::test]
    async fn add_vectors_length_mismatch_errors() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        let err = store
            .add_vectors(vec![vec![0.1; 8], vec![0.2; 8]], vec!["one".into()], None)
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("must equal"));
        assert_eq!(store.len(), 0);
    }

    #[tokio::test]
    async fn add_vectors_metadata_length_mismatch_errors() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        let err = store
            .add_vectors(
                vec![vec![0.1; 8]],
                vec!["one".into()],
                Some(vec![HashMap::new(), HashMap::new()]),
            )
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("must equal"));
        assert_eq!(store.len(), 0);
    }

    #[tokio::test]
    async fn delete_removes_docs() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        let ids = store
            .add_texts(vec!["a".into(), "b".into(), "c".into()], None)
            .await
            .unwrap();
        store.delete(vec![ids[1].clone()]).await.unwrap();
        assert_eq!(store.len(), 2);
        let r = store.similarity_search("b", 5).await.unwrap();
        assert!(!r.iter().any(|s| s.text == "b"));
    }

    #[tokio::test]
    async fn delete_unknown_ids_silent() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        store.add_texts(vec!["a".into()], None).await.unwrap();
        store.delete(vec!["nonexistent".into()]).await.unwrap();
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn empty_store_search_returns_empty() {
        let store = InMemoryVectorStore::new(fake_embedder(8));
        let r = store.similarity_search("anything", 5).await.unwrap();
        assert!(r.is_empty());
    }

    #[tokio::test]
    async fn k_zero_returns_empty() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        store.add_texts(vec!["a".into()], None).await.unwrap();
        let r = store.similarity_search("a", 0).await.unwrap();
        assert!(r.is_empty());
    }
}