1use crate::error::{Result, StoreError};
2use crate::traits::{
3 cosine_similarity, CollectionInfo, SearchResult, VectorFilter, VectorRecord, VectorStore,
4};
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use serde_json::Value;
8use std::collections::HashMap;
9
/// One named collection held by [`MemoryVectorStore`].
struct Collection {
    /// Vector length fixed when the collection is created; `upsert` rejects
    /// records, and `search_with_filter` rejects queries, of any other length.
    dimension: usize,
    /// Stored records, keyed by each record's `id`.
    records: HashMap<String, VectorRecord>,
}
14
/// `VectorStore` implementation that keeps every collection in process memory.
pub struct MemoryVectorStore {
    /// All collections, keyed by collection name, behind a read/write lock.
    collections: RwLock<HashMap<String, Collection>>,
}
18
19impl MemoryVectorStore {
20 pub fn new() -> Self {
21 Self {
22 collections: RwLock::new(HashMap::new()),
23 }
24 }
25}
26
27impl Default for MemoryVectorStore {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33fn matches_filter(record: &VectorRecord, filter: &VectorFilter) -> bool {
34 use crate::traits::vector::FilterCondition;
35
36 for condition in &filter.conditions {
37 let matched = match condition {
38 FilterCondition::Eq(field, value) => {
39 record.metadata.get(field).map(|v| v == value).unwrap_or(false)
40 }
41 FilterCondition::Ne(field, value) => {
42 record.metadata.get(field).map(|v| v != value).unwrap_or(true)
43 }
44 FilterCondition::Gt(field, value) => match (record.metadata.get(field), value) {
45 (Some(Value::Number(a)), Value::Number(b)) => {
46 a.as_f64().unwrap_or(0.0) > b.as_f64().unwrap_or(0.0)
47 }
48 _ => false,
49 },
50 FilterCondition::Gte(field, value) => match (record.metadata.get(field), value) {
51 (Some(Value::Number(a)), Value::Number(b)) => {
52 a.as_f64().unwrap_or(0.0) >= b.as_f64().unwrap_or(0.0)
53 }
54 _ => false,
55 },
56 FilterCondition::Lt(field, value) => match (record.metadata.get(field), value) {
57 (Some(Value::Number(a)), Value::Number(b)) => {
58 a.as_f64().unwrap_or(0.0) < b.as_f64().unwrap_or(0.0)
59 }
60 _ => false,
61 },
62 FilterCondition::Lte(field, value) => match (record.metadata.get(field), value) {
63 (Some(Value::Number(a)), Value::Number(b)) => {
64 a.as_f64().unwrap_or(0.0) <= b.as_f64().unwrap_or(0.0)
65 }
66 _ => false,
67 },
68 FilterCondition::In(field, values) => record
69 .metadata
70 .get(field)
71 .map(|v| values.contains(v))
72 .unwrap_or(false),
73 FilterCondition::Contains(field, substr) => record
74 .metadata
75 .get(field)
76 .and_then(|v| v.as_str())
77 .map(|s| s.contains(substr))
78 .unwrap_or(false),
79 };
80
81 if !matched {
82 return false;
83 }
84 }
85
86 true
87}
88
89#[async_trait]
90impl VectorStore for MemoryVectorStore {
91 async fn create_collection(&self, name: &str, dimension: usize) -> Result<()> {
92 let mut collections = self.collections.write();
93 if collections.contains_key(name) {
94 return Err(StoreError::AlreadyExists(format!("Collection '{}'", name)));
95 }
96 collections.insert(
97 name.to_string(),
98 Collection {
99 dimension,
100 records: HashMap::new(),
101 },
102 );
103 Ok(())
104 }
105
106 async fn delete_collection(&self, name: &str) -> Result<()> {
107 let mut collections = self.collections.write();
108 if collections.remove(name).is_none() {
109 return Err(StoreError::not_found(format!("Collection '{}'", name)));
110 }
111 Ok(())
112 }
113
114 async fn list_collections(&self) -> Result<Vec<CollectionInfo>> {
115 let collections = self.collections.read();
116 Ok(collections
117 .iter()
118 .map(|(name, col)| CollectionInfo {
119 name: name.clone(),
120 dimension: col.dimension,
121 count: col.records.len(),
122 })
123 .collect())
124 }
125
126 async fn collection_exists(&self, name: &str) -> Result<bool> {
127 Ok(self.collections.read().contains_key(name))
128 }
129
130 async fn upsert(&self, collection: &str, records: &[VectorRecord]) -> Result<usize> {
131 let mut collections = self.collections.write();
132 let col = collections
133 .get_mut(collection)
134 .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;
135
136 let mut count = 0;
137 for record in records {
138 if record.vector.len() != col.dimension {
139 return Err(StoreError::invalid_input(format!(
140 "Vector dimension mismatch: expected {}, got {}",
141 col.dimension,
142 record.vector.len()
143 )));
144 }
145 col.records.insert(record.id.clone(), record.clone());
146 count += 1;
147 }
148
149 Ok(count)
150 }
151
152 async fn search(
153 &self,
154 collection: &str,
155 query: &[f32],
156 limit: usize,
157 ) -> Result<Vec<SearchResult>> {
158 self.search_with_filter(collection, query, &VectorFilter::default(), limit)
159 .await
160 }
161
162 async fn search_with_filter(
163 &self,
164 collection: &str,
165 query: &[f32],
166 filter: &VectorFilter,
167 limit: usize,
168 ) -> Result<Vec<SearchResult>> {
169 let collections = self.collections.read();
170 let col = collections
171 .get(collection)
172 .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;
173
174 if query.len() != col.dimension {
175 return Err(StoreError::invalid_input(format!(
176 "Query dimension mismatch: expected {}, got {}",
177 col.dimension,
178 query.len()
179 )));
180 }
181
182 let mut scored: Vec<(String, f32, HashMap<String, Value>, Option<String>)> = col
183 .records
184 .values()
185 .filter(|r| filter.is_empty() || matches_filter(r, filter))
186 .map(|r| {
187 let score = cosine_similarity(query, &r.vector);
188 (r.id.clone(), score, r.metadata.clone(), r.content.clone())
189 })
190 .collect();
191
192 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
193
194 Ok(scored
195 .into_iter()
196 .take(limit)
197 .map(|(id, score, metadata, content)| SearchResult {
198 id,
199 score,
200 metadata,
201 content,
202 })
203 .collect())
204 }
205
206 async fn get(&self, collection: &str, id: &str) -> Result<Option<VectorRecord>> {
207 let collections = self.collections.read();
208 let col = collections
209 .get(collection)
210 .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;
211
212 Ok(col.records.get(id).cloned())
213 }
214
215 async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize> {
216 let mut collections = self.collections.write();
217 let col = collections
218 .get_mut(collection)
219 .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;
220
221 let mut count = 0;
222 for id in ids {
223 if col.records.remove(id).is_some() {
224 count += 1;
225 }
226 }
227
228 Ok(count)
229 }
230
231 async fn count(&self, collection: &str) -> Result<usize> {
232 let collections = self.collections.read();
233 let col = collections
234 .get(collection)
235 .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;
236
237 Ok(col.records.len())
238 }
239
240 fn backend_name(&self) -> &'static str {
241 "memory"
242 }
243}
244
// Tests for `MemoryVectorStore`, driven through the `VectorStore` trait methods.
#[cfg(test)]
mod tests {
    use super::*;

    // Create -> exists -> duplicate create is rejected -> delete -> gone.
    #[tokio::test]
    async fn test_collection_lifecycle() {
        let store = MemoryVectorStore::new();

        assert!(!store.collection_exists("test").await.unwrap());

        store.create_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());

        // Creating the same name twice must fail.
        let err = store.create_collection("test", 3).await;
        assert!(err.is_err());

        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
    }

    // Upserted records are counted and searchable; the best match comes first
    // and `limit` caps the number of results.
    #[tokio::test]
    async fn test_upsert_and_search() {
        let store = MemoryVectorStore::new();
        store.create_collection("docs", 3).await.unwrap();

        let records = vec![
            VectorRecord::new("doc1", vec![1.0, 0.0, 0.0])
                .with_metadata("category", "a")
                .with_content("Document one"),
            VectorRecord::new("doc2", vec![0.0, 1.0, 0.0])
                .with_metadata("category", "b")
                .with_content("Document two"),
            VectorRecord::new("doc3", vec![0.707, 0.707, 0.0])
                .with_metadata("category", "a")
                .with_content("Document three"),
        ];

        let count = store.upsert("docs", &records).await.unwrap();
        assert_eq!(count, 3);
        assert_eq!(store.count("docs").await.unwrap(), 3);

        // The query equals doc1's vector, so doc1 must rank first with score ~1.
        let results = store.search("docs", &[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "doc1");
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    // An equality filter on metadata keeps non-matching records out of the results.
    #[tokio::test]
    async fn test_search_with_filter() {
        let store = MemoryVectorStore::new();
        store.create_collection("items", 2).await.unwrap();

        let records = vec![
            VectorRecord::new("item1", vec![1.0, 0.0]).with_metadata("type", "book"),
            VectorRecord::new("item2", vec![0.9, 0.1]).with_metadata("type", "book"),
            VectorRecord::new("item3", vec![0.8, 0.2]).with_metadata("type", "video"),
        ];
        store.upsert("items", &records).await.unwrap();

        let filter = VectorFilter::new().eq("type", "book");
        let results = store
            .search_with_filter("items", &[1.0, 0.0], &filter, 10)
            .await
            .unwrap();

        // Only the two "book" records may come back.
        assert_eq!(results.len(), 2);
        for r in &results {
            assert_eq!(r.metadata.get("type"), Some(&Value::String("book".into())));
        }
    }

    // A stored record can be fetched by id, and is gone after `delete`.
    #[tokio::test]
    async fn test_get_and_delete() {
        let store = MemoryVectorStore::new();
        store.create_collection("test", 2).await.unwrap();

        store
            .upsert("test", &[VectorRecord::new("id1", vec![1.0, 0.0])])
            .await
            .unwrap();

        let record = store.get("test", "id1").await.unwrap();
        assert!(record.is_some());
        assert_eq!(record.unwrap().id, "id1");

        let deleted = store
            .delete("test", &["id1".to_string()])
            .await
            .unwrap();
        assert_eq!(deleted, 1);

        let record = store.get("test", "id1").await.unwrap();
        assert!(record.is_none());
    }

    // Both upsert and search reject vectors whose length differs from the
    // collection dimension (2 components against a 3-dimensional collection).
    #[tokio::test]
    async fn test_dimension_validation() {
        let store = MemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let result = store
            .upsert("test", &[VectorRecord::new("id1", vec![1.0, 0.0])])
            .await;
        assert!(result.is_err());

        let result = store.search("test", &[1.0, 0.0], 10).await;
        assert!(result.is_err());
    }

    // Every created collection shows up in `list_collections`.
    #[tokio::test]
    async fn test_list_collections() {
        let store = MemoryVectorStore::new();

        store.create_collection("col1", 10).await.unwrap();
        store.create_collection("col2", 20).await.unwrap();

        let list = store.list_collections().await.unwrap();
        assert_eq!(list.len(), 2);

        // Membership checks only: the listing order is not asserted.
        let names: Vec<_> = list.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"col1"));
        assert!(names.contains(&"col2"));
    }

    // Upserting an existing id replaces the record instead of adding a second one.
    #[tokio::test]
    async fn test_upsert_updates_existing() {
        let store = MemoryVectorStore::new();
        store.create_collection("test", 2).await.unwrap();

        store
            .upsert(
                "test",
                &[VectorRecord::new("id1", vec![1.0, 0.0]).with_metadata("version", 1)],
            )
            .await
            .unwrap();

        store
            .upsert(
                "test",
                &[VectorRecord::new("id1", vec![0.0, 1.0]).with_metadata("version", 2)],
            )
            .await
            .unwrap();

        assert_eq!(store.count("test").await.unwrap(), 1);

        // Both the vector and the metadata must come from the second upsert.
        let record = store.get("test", "id1").await.unwrap().unwrap();
        assert_eq!(record.vector, vec![0.0, 1.0]);
        assert_eq!(
            record.metadata.get("version"),
            Some(&Value::Number(2.into()))
        );
    }
}