Skip to main content

mem7_vector/
index.rs

1use std::collections::HashMap;
2use std::sync::RwLock;
3
4use async_trait::async_trait;
5use mem7_core::MemoryFilter;
6use mem7_error::{Mem7Error, Result};
7use uuid::Uuid;
8
9use crate::distance::DistanceMetric;
10use crate::filter::matches_filter;
11use crate::{VectorIndex, VectorSearchResult};
12
13struct VectorEntry {
14    vector: Vec<f32>,
15    payload: serde_json::Value,
16}
17
18/// A brute-force flat vector index. Suitable for small-to-medium datasets.
19/// Can be replaced with HNSW later without changing the public API.
20pub struct FlatIndex {
21    entries: RwLock<HashMap<Uuid, VectorEntry>>,
22    metric: DistanceMetric,
23}
24
25impl FlatIndex {
26    pub fn new(metric: DistanceMetric) -> Self {
27        Self {
28            entries: RwLock::new(HashMap::new()),
29            metric,
30        }
31    }
32}
33
34#[async_trait]
35impl VectorIndex for FlatIndex {
36    async fn insert(&self, id: Uuid, vector: &[f32], payload: serde_json::Value) -> Result<()> {
37        let mut entries = self
38            .entries
39            .write()
40            .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
41        entries.insert(
42            id,
43            VectorEntry {
44                vector: vector.to_vec(),
45                payload,
46            },
47        );
48        Ok(())
49    }
50
51    async fn search(
52        &self,
53        query: &[f32],
54        limit: usize,
55        filters: Option<&MemoryFilter>,
56    ) -> Result<Vec<VectorSearchResult>> {
57        let entries = self
58            .entries
59            .read()
60            .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
61
62        let mut scored: Vec<VectorSearchResult> = entries
63            .iter()
64            .filter(|(_, entry)| {
65                filters
66                    .map(|f| matches_filter(&entry.payload, f))
67                    .unwrap_or(true)
68            })
69            .map(|(id, entry)| VectorSearchResult {
70                id: *id,
71                score: self.metric.similarity(query, &entry.vector),
72                payload: entry.payload.clone(),
73            })
74            .collect();
75
76        scored.sort_by(|a, b| {
77            b.score
78                .partial_cmp(&a.score)
79                .unwrap_or(std::cmp::Ordering::Equal)
80        });
81        scored.truncate(limit);
82        Ok(scored)
83    }
84
85    async fn delete(&self, id: &Uuid) -> Result<()> {
86        let mut entries = self
87            .entries
88            .write()
89            .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
90        entries.remove(id);
91        Ok(())
92    }
93
94    async fn update(
95        &self,
96        id: &Uuid,
97        vector: Option<&[f32]>,
98        payload: Option<serde_json::Value>,
99    ) -> Result<()> {
100        let mut entries = self
101            .entries
102            .write()
103            .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
104
105        if let Some(entry) = entries.get_mut(id) {
106            if let Some(v) = vector {
107                entry.vector = v.to_vec();
108            }
109            if let Some(p) = payload {
110                entry.payload = p;
111            }
112            Ok(())
113        } else {
114            Err(Mem7Error::NotFound(format!("vector entry {id}")))
115        }
116    }
117
118    async fn get(&self, id: &Uuid) -> Result<Option<(Vec<f32>, serde_json::Value)>> {
119        let entries = self
120            .entries
121            .read()
122            .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
123        Ok(entries
124            .get(id)
125            .map(|e| (e.vector.clone(), e.payload.clone())))
126    }
127
128    async fn list(
129        &self,
130        filters: Option<&MemoryFilter>,
131        limit: Option<usize>,
132    ) -> Result<Vec<(Uuid, serde_json::Value)>> {
133        let entries = self
134            .entries
135            .read()
136            .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
137
138        let mut results: Vec<(Uuid, serde_json::Value)> = entries
139            .iter()
140            .filter(|(_, entry)| {
141                filters
142                    .map(|f| matches_filter(&entry.payload, f))
143                    .unwrap_or(true)
144            })
145            .map(|(id, entry)| (*id, entry.payload.clone()))
146            .collect();
147
148        results.sort_by(|a, b| a.0.cmp(&b.0));
149
150        if let Some(limit) = limit {
151            results.truncate(limit);
152        }
153
154        Ok(results)
155    }
156
157    async fn reset(&self) -> Result<()> {
158        let mut entries = self
159            .entries
160            .write()
161            .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
162        entries.clear();
163        Ok(())
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts two orthogonal unit vectors, then checks that an unfiltered
    /// top-1 search returns the vector identical to the query and that a
    /// `user_id` filter restricts the candidates to the other entry.
    #[tokio::test]
    async fn insert_and_search() {
        let index = FlatIndex::new(DistanceMetric::Cosine);
        let alice_id = Uuid::now_v7();
        let bob_id = Uuid::now_v7();

        let fixtures = [
            (alice_id, [1.0_f32, 0.0, 0.0], "alice"),
            (bob_id, [0.0_f32, 1.0, 0.0], "bob"),
        ];
        for (id, vector, user) in fixtures {
            index
                .insert(id, &vector, serde_json::json!({ "user_id": user }))
                .await
                .unwrap();
        }

        let query = [1.0_f32, 0.0, 0.0];

        // Unfiltered: the entry equal to the query must rank first.
        let top = index.search(&query, 1, None).await.unwrap();
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].id, alice_id);

        // Filtered: only bob's entry is eligible, whatever its score.
        let only_bob = MemoryFilter {
            user_id: Some("bob".into()),
            ..Default::default()
        };
        let filtered = index.search(&query, 10, Some(&only_bob)).await.unwrap();
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].id, bob_id);
    }
}