hehe_store/local/
memory_vector.rs

1use crate::error::{Result, StoreError};
2use crate::traits::{
3    cosine_similarity, CollectionInfo, SearchResult, VectorFilter, VectorRecord, VectorStore,
4};
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use serde_json::Value;
8use std::collections::HashMap;
9
10struct Collection {
11    dimension: usize,
12    records: HashMap<String, VectorRecord>,
13}
14
15pub struct MemoryVectorStore {
16    collections: RwLock<HashMap<String, Collection>>,
17}
18
19impl MemoryVectorStore {
20    pub fn new() -> Self {
21        Self {
22            collections: RwLock::new(HashMap::new()),
23        }
24    }
25}
26
27impl Default for MemoryVectorStore {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33fn matches_filter(record: &VectorRecord, filter: &VectorFilter) -> bool {
34    use crate::traits::vector::FilterCondition;
35
36    for condition in &filter.conditions {
37        let matched = match condition {
38            FilterCondition::Eq(field, value) => {
39                record.metadata.get(field).map(|v| v == value).unwrap_or(false)
40            }
41            FilterCondition::Ne(field, value) => {
42                record.metadata.get(field).map(|v| v != value).unwrap_or(true)
43            }
44            FilterCondition::Gt(field, value) => match (record.metadata.get(field), value) {
45                (Some(Value::Number(a)), Value::Number(b)) => {
46                    a.as_f64().unwrap_or(0.0) > b.as_f64().unwrap_or(0.0)
47                }
48                _ => false,
49            },
50            FilterCondition::Gte(field, value) => match (record.metadata.get(field), value) {
51                (Some(Value::Number(a)), Value::Number(b)) => {
52                    a.as_f64().unwrap_or(0.0) >= b.as_f64().unwrap_or(0.0)
53                }
54                _ => false,
55            },
56            FilterCondition::Lt(field, value) => match (record.metadata.get(field), value) {
57                (Some(Value::Number(a)), Value::Number(b)) => {
58                    a.as_f64().unwrap_or(0.0) < b.as_f64().unwrap_or(0.0)
59                }
60                _ => false,
61            },
62            FilterCondition::Lte(field, value) => match (record.metadata.get(field), value) {
63                (Some(Value::Number(a)), Value::Number(b)) => {
64                    a.as_f64().unwrap_or(0.0) <= b.as_f64().unwrap_or(0.0)
65                }
66                _ => false,
67            },
68            FilterCondition::In(field, values) => record
69                .metadata
70                .get(field)
71                .map(|v| values.contains(v))
72                .unwrap_or(false),
73            FilterCondition::Contains(field, substr) => record
74                .metadata
75                .get(field)
76                .and_then(|v| v.as_str())
77                .map(|s| s.contains(substr))
78                .unwrap_or(false),
79        };
80
81        if !matched {
82            return false;
83        }
84    }
85
86    true
87}
88
89#[async_trait]
90impl VectorStore for MemoryVectorStore {
91    async fn create_collection(&self, name: &str, dimension: usize) -> Result<()> {
92        let mut collections = self.collections.write();
93        if collections.contains_key(name) {
94            return Err(StoreError::AlreadyExists(format!("Collection '{}'", name)));
95        }
96        collections.insert(
97            name.to_string(),
98            Collection {
99                dimension,
100                records: HashMap::new(),
101            },
102        );
103        Ok(())
104    }
105
106    async fn delete_collection(&self, name: &str) -> Result<()> {
107        let mut collections = self.collections.write();
108        if collections.remove(name).is_none() {
109            return Err(StoreError::not_found(format!("Collection '{}'", name)));
110        }
111        Ok(())
112    }
113
114    async fn list_collections(&self) -> Result<Vec<CollectionInfo>> {
115        let collections = self.collections.read();
116        Ok(collections
117            .iter()
118            .map(|(name, col)| CollectionInfo {
119                name: name.clone(),
120                dimension: col.dimension,
121                count: col.records.len(),
122            })
123            .collect())
124    }
125
126    async fn collection_exists(&self, name: &str) -> Result<bool> {
127        Ok(self.collections.read().contains_key(name))
128    }
129
130    async fn upsert(&self, collection: &str, records: &[VectorRecord]) -> Result<usize> {
131        let mut collections = self.collections.write();
132        let col = collections
133            .get_mut(collection)
134            .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;
135
136        let mut count = 0;
137        for record in records {
138            if record.vector.len() != col.dimension {
139                return Err(StoreError::invalid_input(format!(
140                    "Vector dimension mismatch: expected {}, got {}",
141                    col.dimension,
142                    record.vector.len()
143                )));
144            }
145            col.records.insert(record.id.clone(), record.clone());
146            count += 1;
147        }
148
149        Ok(count)
150    }
151
152    async fn search(
153        &self,
154        collection: &str,
155        query: &[f32],
156        limit: usize,
157    ) -> Result<Vec<SearchResult>> {
158        self.search_with_filter(collection, query, &VectorFilter::default(), limit)
159            .await
160    }
161
162    async fn search_with_filter(
163        &self,
164        collection: &str,
165        query: &[f32],
166        filter: &VectorFilter,
167        limit: usize,
168    ) -> Result<Vec<SearchResult>> {
169        let collections = self.collections.read();
170        let col = collections
171            .get(collection)
172            .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;
173
174        if query.len() != col.dimension {
175            return Err(StoreError::invalid_input(format!(
176                "Query dimension mismatch: expected {}, got {}",
177                col.dimension,
178                query.len()
179            )));
180        }
181
182        let mut scored: Vec<(String, f32, HashMap<String, Value>, Option<String>)> = col
183            .records
184            .values()
185            .filter(|r| filter.is_empty() || matches_filter(r, filter))
186            .map(|r| {
187                let score = cosine_similarity(query, &r.vector);
188                (r.id.clone(), score, r.metadata.clone(), r.content.clone())
189            })
190            .collect();
191
192        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
193
194        Ok(scored
195            .into_iter()
196            .take(limit)
197            .map(|(id, score, metadata, content)| SearchResult {
198                id,
199                score,
200                metadata,
201                content,
202            })
203            .collect())
204    }
205
206    async fn get(&self, collection: &str, id: &str) -> Result<Option<VectorRecord>> {
207        let collections = self.collections.read();
208        let col = collections
209            .get(collection)
210            .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;
211
212        Ok(col.records.get(id).cloned())
213    }
214
215    async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize> {
216        let mut collections = self.collections.write();
217        let col = collections
218            .get_mut(collection)
219            .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;
220
221        let mut count = 0;
222        for id in ids {
223            if col.records.remove(id).is_some() {
224                count += 1;
225            }
226        }
227
228        Ok(count)
229    }
230
231    async fn count(&self, collection: &str) -> Result<usize> {
232        let collections = self.collections.read();
233        let col = collections
234            .get(collection)
235            .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;
236
237        Ok(col.records.len())
238    }
239
240    fn backend_name(&self) -> &'static str {
241        "memory"
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// A collection can be created once, is visible while it exists, and is
    /// gone after deletion.
    #[tokio::test]
    async fn test_collection_lifecycle() {
        let store = MemoryVectorStore::new();

        let exists_before = store.collection_exists("test").await.unwrap();
        assert!(!exists_before);

        store.create_collection("test", 3).await.unwrap();
        let exists_after_create = store.collection_exists("test").await.unwrap();
        assert!(exists_after_create);

        // Creating the same collection twice must be rejected.
        assert!(store.create_collection("test", 3).await.is_err());

        store.delete_collection("test").await.unwrap();
        let exists_after_delete = store.collection_exists("test").await.unwrap();
        assert!(!exists_after_delete);
    }

    /// Upserted records are counted and the closest vector ranks first.
    #[tokio::test]
    async fn test_upsert_and_search() {
        let store = MemoryVectorStore::new();
        store.create_collection("docs", 3).await.unwrap();

        let doc1 = VectorRecord::new("doc1", vec![1.0, 0.0, 0.0])
            .with_metadata("category", "a")
            .with_content("Document one");
        let doc2 = VectorRecord::new("doc2", vec![0.0, 1.0, 0.0])
            .with_metadata("category", "b")
            .with_content("Document two");
        let doc3 = VectorRecord::new("doc3", vec![0.707, 0.707, 0.0])
            .with_metadata("category", "a")
            .with_content("Document three");
        let records = vec![doc1, doc2, doc3];

        let written = store.upsert("docs", &records).await.unwrap();
        assert_eq!(written, 3);
        assert_eq!(store.count("docs").await.unwrap(), 3);

        let hits = store.search("docs", &[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(hits.len(), 2);

        let best = &hits[0];
        assert_eq!(best.id, "doc1");
        assert!((best.score - 1.0).abs() < 0.001);
    }

    /// Only records whose metadata satisfies the filter are returned.
    #[tokio::test]
    async fn test_search_with_filter() {
        let store = MemoryVectorStore::new();
        store.create_collection("items", 2).await.unwrap();

        let records = vec![
            VectorRecord::new("item1", vec![1.0, 0.0]).with_metadata("type", "book"),
            VectorRecord::new("item2", vec![0.9, 0.1]).with_metadata("type", "book"),
            VectorRecord::new("item3", vec![0.8, 0.2]).with_metadata("type", "video"),
        ];
        store.upsert("items", &records).await.unwrap();

        let books_only = VectorFilter::new().eq("type", "book");
        let hits = store
            .search_with_filter("items", &[1.0, 0.0], &books_only, 10)
            .await
            .unwrap();

        assert_eq!(hits.len(), 2);

        let expected_type = Value::String("book".into());
        for hit in &hits {
            assert_eq!(hit.metadata.get("type"), Some(&expected_type));
        }
    }

    /// A stored record can be fetched, deleted, and is absent afterwards.
    #[tokio::test]
    async fn test_get_and_delete() {
        let store = MemoryVectorStore::new();
        store.create_collection("test", 2).await.unwrap();

        let batch = [VectorRecord::new("id1", vec![1.0, 0.0])];
        store.upsert("test", &batch).await.unwrap();

        let fetched = store.get("test", "id1").await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().id, "id1");

        let ids = vec!["id1".to_string()];
        let removed = store.delete("test", &ids).await.unwrap();
        assert_eq!(removed, 1);

        let fetched_again = store.get("test", "id1").await.unwrap();
        assert!(fetched_again.is_none());
    }

    /// Vectors and queries of the wrong length are rejected.
    #[tokio::test]
    async fn test_dimension_validation() {
        let store = MemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let too_short = [VectorRecord::new("id1", vec![1.0, 0.0])];
        let upsert_outcome = store.upsert("test", &too_short).await;
        assert!(upsert_outcome.is_err());

        let search_outcome = store.search("test", &[1.0, 0.0], 10).await;
        assert!(search_outcome.is_err());
    }

    /// Every created collection shows up in the listing.
    #[tokio::test]
    async fn test_list_collections() {
        let store = MemoryVectorStore::new();

        store.create_collection("col1", 10).await.unwrap();
        store.create_collection("col2", 20).await.unwrap();

        let listed = store.list_collections().await.unwrap();
        assert_eq!(listed.len(), 2);

        for wanted in ["col1", "col2"] {
            assert!(listed.iter().any(|info| info.name.as_str() == wanted));
        }
    }

    /// Upserting an existing id replaces the stored record instead of adding
    /// a second one.
    #[tokio::test]
    async fn test_upsert_updates_existing() {
        let store = MemoryVectorStore::new();
        store.create_collection("test", 2).await.unwrap();

        let first = [VectorRecord::new("id1", vec![1.0, 0.0]).with_metadata("version", 1)];
        store.upsert("test", &first).await.unwrap();

        let second = [VectorRecord::new("id1", vec![0.0, 1.0]).with_metadata("version", 2)];
        store.upsert("test", &second).await.unwrap();

        assert_eq!(store.count("test").await.unwrap(), 1);

        let stored = store.get("test", "id1").await.unwrap().unwrap();
        assert_eq!(stored.vector, vec![0.0, 1.0]);

        let expected_version = Value::Number(2.into());
        assert_eq!(stored.metadata.get("version"), Some(&expected_version));
    }
}
399}