Skip to main content

mem_vec/
memory_vec.rs

//! In-memory vector store (brute-force KNN).
2
3use mem_types::{VecSearchHit, VecStore, VecStoreError, VecStoreItem};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// Computes the cosine similarity of `a` and `b`.
///
/// Components are widened to `f64` before accumulation, and the dot product and both
/// squared norms are gathered in a single pass over the zipped slices.
///
/// Returns `0.0` for the degenerate cases where no meaningful angle exists: the slices
/// differ in length, both are empty, or either one has a zero norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let (norm_a, norm_b) = (sq_a.sqrt(), sq_b.sqrt());
    // A zero-length vector has no direction; avoid dividing by zero.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
24
25/// In-memory VecStore: stores items in a map, search by brute-force cosine similarity.
26pub struct InMemoryVecStore {
27    /// collection name -> id -> item
28    store: Arc<RwLock<HashMap<String, HashMap<String, VecStoreItem>>>>,
29    default_collection: String,
30}
31
32impl InMemoryVecStore {
33    pub fn new(default_collection: Option<&str>) -> Self {
34        Self {
35            store: Arc::new(RwLock::new(HashMap::new())),
36            default_collection: default_collection.unwrap_or("memos_memories").to_string(),
37        }
38    }
39
40    fn coll(&self, collection: Option<&str>) -> String {
41        collection.unwrap_or(&self.default_collection).to_string()
42    }
43}
44
45#[async_trait::async_trait]
46impl VecStore for InMemoryVecStore {
47    async fn add(
48        &self,
49        items: &[VecStoreItem],
50        collection: Option<&str>,
51    ) -> Result<(), VecStoreError> {
52        let coll = self.coll(collection);
53        let mut guard = self.store.write().await;
54        let map = guard.entry(coll).or_default();
55        for item in items {
56            map.insert(item.id.clone(), item.clone());
57        }
58        Ok(())
59    }
60
61    async fn search(
62        &self,
63        query_vector: &[f32],
64        top_k: usize,
65        filter: Option<&HashMap<String, serde_json::Value>>,
66        collection: Option<&str>,
67    ) -> Result<Vec<VecSearchHit>, VecStoreError> {
68        let coll = self.coll(collection);
69        let guard = self.store.read().await;
70        let map = guard
71            .get(&coll)
72            .map(|m| m.values().cloned().collect::<Vec<_>>());
73        let items = map.unwrap_or_default();
74        let mut candidates: Vec<(VecStoreItem, f64)> = items
75            .into_iter()
76            .filter(|i| {
77                if let Some(f) = filter {
78                    for (k, v) in f.iter() {
79                        if let Some(pv) = i.payload.get(k) {
80                            if pv != v {
81                                return false;
82                            }
83                        } else {
84                            return false;
85                        }
86                    }
87                }
88                true
89            })
90            .map(|i| {
91                let score = cosine_similarity(query_vector, &i.vector);
92                (i, score)
93            })
94            .collect();
95        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
96        let hits = candidates
97            .into_iter()
98            .take(top_k)
99            .map(|(i, score)| VecSearchHit { id: i.id, score })
100            .collect();
101        Ok(hits)
102    }
103
104    async fn get_by_ids(
105        &self,
106        ids: &[String],
107        collection: Option<&str>,
108    ) -> Result<Vec<VecStoreItem>, VecStoreError> {
109        let coll = self.coll(collection);
110        let guard = self.store.read().await;
111        let map = guard.get(&coll);
112        let mut out = Vec::new();
113        if let Some(m) = map {
114            for id in ids {
115                if let Some(item) = m.get(id) {
116                    out.push(item.clone());
117                }
118            }
119        }
120        Ok(out)
121    }
122
123    async fn delete(&self, ids: &[String], collection: Option<&str>) -> Result<(), VecStoreError> {
124        let coll = self.coll(collection);
125        let mut guard = self.store.write().await;
126        if let Some(m) = guard.get_mut(&coll) {
127            for id in ids {
128                m.remove(id);
129            }
130        }
131        Ok(())
132    }
133
134    async fn upsert(
135        &self,
136        items: &[VecStoreItem],
137        collection: Option<&str>,
138    ) -> Result<(), VecStoreError> {
139        let coll = self.coll(collection);
140        let mut guard = self.store.write().await;
141        let map = guard.entry(coll).or_default();
142        for item in items {
143            map.insert(item.id.clone(), item.clone());
144        }
145        Ok(())
146    }
147}