1use std::collections::HashMap;
2use std::sync::RwLock;
3
4use async_trait::async_trait;
5use mem7_core::MemoryFilter;
6use mem7_error::{Mem7Error, Result};
7use uuid::Uuid;
8
9use crate::distance::DistanceMetric;
10use crate::filter::matches_filter;
11use crate::{VectorIndex, VectorSearchResult};
12
13struct VectorEntry {
14 vector: Vec<f32>,
15 payload: serde_json::Value,
16}
17
18pub struct FlatIndex {
21 entries: RwLock<HashMap<Uuid, VectorEntry>>,
22 metric: DistanceMetric,
23}
24
25impl FlatIndex {
26 pub fn new(metric: DistanceMetric) -> Self {
27 Self {
28 entries: RwLock::new(HashMap::new()),
29 metric,
30 }
31 }
32}
33
34#[async_trait]
35impl VectorIndex for FlatIndex {
36 async fn insert(&self, id: Uuid, vector: &[f32], payload: serde_json::Value) -> Result<()> {
37 let mut entries = self
38 .entries
39 .write()
40 .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
41 entries.insert(
42 id,
43 VectorEntry {
44 vector: vector.to_vec(),
45 payload,
46 },
47 );
48 Ok(())
49 }
50
51 async fn search(
52 &self,
53 query: &[f32],
54 limit: usize,
55 filters: Option<&MemoryFilter>,
56 ) -> Result<Vec<VectorSearchResult>> {
57 let entries = self
58 .entries
59 .read()
60 .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
61
62 let mut scored: Vec<VectorSearchResult> = entries
63 .iter()
64 .filter(|(_, entry)| {
65 filters
66 .map(|f| matches_filter(&entry.payload, f))
67 .unwrap_or(true)
68 })
69 .map(|(id, entry)| VectorSearchResult {
70 id: *id,
71 score: self.metric.similarity(query, &entry.vector),
72 payload: entry.payload.clone(),
73 })
74 .collect();
75
76 scored.sort_by(|a, b| {
77 b.score
78 .partial_cmp(&a.score)
79 .unwrap_or(std::cmp::Ordering::Equal)
80 });
81 scored.truncate(limit);
82 Ok(scored)
83 }
84
85 async fn delete(&self, id: &Uuid) -> Result<()> {
86 let mut entries = self
87 .entries
88 .write()
89 .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
90 entries.remove(id);
91 Ok(())
92 }
93
94 async fn update(
95 &self,
96 id: &Uuid,
97 vector: Option<&[f32]>,
98 payload: Option<serde_json::Value>,
99 ) -> Result<()> {
100 let mut entries = self
101 .entries
102 .write()
103 .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
104
105 if let Some(entry) = entries.get_mut(id) {
106 if let Some(v) = vector {
107 entry.vector = v.to_vec();
108 }
109 if let Some(p) = payload {
110 entry.payload = p;
111 }
112 Ok(())
113 } else {
114 Err(Mem7Error::NotFound(format!("vector entry {id}")))
115 }
116 }
117
118 async fn get(&self, id: &Uuid) -> Result<Option<(Vec<f32>, serde_json::Value)>> {
119 let entries = self
120 .entries
121 .read()
122 .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
123 Ok(entries
124 .get(id)
125 .map(|e| (e.vector.clone(), e.payload.clone())))
126 }
127
128 async fn list(
129 &self,
130 filters: Option<&MemoryFilter>,
131 limit: Option<usize>,
132 ) -> Result<Vec<(Uuid, serde_json::Value)>> {
133 let entries = self
134 .entries
135 .read()
136 .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
137
138 let mut results: Vec<(Uuid, serde_json::Value)> = entries
139 .iter()
140 .filter(|(_, entry)| {
141 filters
142 .map(|f| matches_filter(&entry.payload, f))
143 .unwrap_or(true)
144 })
145 .map(|(id, entry)| (*id, entry.payload.clone()))
146 .collect();
147
148 results.sort_by(|a, b| a.0.cmp(&b.0));
149
150 if let Some(limit) = limit {
151 results.truncate(limit);
152 }
153
154 Ok(results)
155 }
156
157 async fn reset(&self) -> Result<()> {
158 let mut entries = self
159 .entries
160 .write()
161 .map_err(|e| Mem7Error::VectorStore(e.to_string()))?;
162 entries.clear();
163 Ok(())
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn insert_and_search() {
        let index = FlatIndex::new(DistanceMetric::Cosine);
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();

        index
            .insert(
                id1,
                &[1.0, 0.0, 0.0],
                serde_json::json!({"user_id": "alice"}),
            )
            .await
            .unwrap();
        index
            .insert(id2, &[0.0, 1.0, 0.0], serde_json::json!({"user_id": "bob"}))
            .await
            .unwrap();

        let results = index.search(&[1.0, 0.0, 0.0], 1, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, id1);

        let filter = MemoryFilter {
            user_id: Some("bob".into()),
            ..Default::default()
        };
        let results = index
            .search(&[1.0, 0.0, 0.0], 10, Some(&filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, id2);
    }

    #[tokio::test]
    async fn update_replaces_only_supplied_parts() {
        let index = FlatIndex::new(DistanceMetric::Cosine);
        let id = Uuid::now_v7();
        index
            .insert(id, &[1.0, 0.0, 0.0], serde_json::json!({"v": 1}))
            .await
            .unwrap();

        // Vector only: the payload must be left untouched.
        let new_vector: &[f32] = &[0.0, 1.0, 0.0];
        index.update(&id, Some(new_vector), None).await.unwrap();
        let (vector, payload) = index.get(&id).await.unwrap().expect("entry was inserted");
        assert_eq!(vector, new_vector.to_vec());
        assert_eq!(payload, serde_json::json!({"v": 1}));

        // Payload only: the vector must be left untouched.
        index
            .update(&id, None, Some(serde_json::json!({"v": 2})))
            .await
            .unwrap();
        let (vector, payload) = index.get(&id).await.unwrap().expect("entry was inserted");
        assert_eq!(vector, new_vector.to_vec());
        assert_eq!(payload, serde_json::json!({"v": 2}));
    }

    #[tokio::test]
    async fn update_missing_entry_is_not_found() {
        let index = FlatIndex::new(DistanceMetric::Cosine);
        let err = index
            .update(&Uuid::now_v7(), None, None)
            .await
            .unwrap_err();
        assert!(matches!(err, Mem7Error::NotFound(_)));
    }

    #[tokio::test]
    async fn delete_removes_entry_and_tolerates_missing_ids() {
        let index = FlatIndex::new(DistanceMetric::Cosine);
        let id = Uuid::now_v7();
        index
            .insert(id, &[1.0, 0.0, 0.0], serde_json::json!({}))
            .await
            .unwrap();

        index.delete(&id).await.unwrap();
        assert!(index.get(&id).await.unwrap().is_none());

        // Deleting an id that is already gone is a no-op, not an error.
        index.delete(&id).await.unwrap();
    }

    #[tokio::test]
    async fn list_filters_sorts_by_id_and_limits() {
        let index = FlatIndex::new(DistanceMetric::Cosine);
        let mut alice_ids = Vec::new();
        for _ in 0..3 {
            let id = Uuid::now_v7();
            index
                .insert(id, &[1.0, 0.0, 0.0], serde_json::json!({"user_id": "alice"}))
                .await
                .unwrap();
            alice_ids.push(id);
        }
        index
            .insert(
                Uuid::now_v7(),
                &[0.0, 1.0, 0.0],
                serde_json::json!({"user_id": "bob"}),
            )
            .await
            .unwrap();
        alice_ids.sort();

        let filter = MemoryFilter {
            user_id: Some("alice".into()),
            ..Default::default()
        };
        let listed: Vec<Uuid> = index
            .list(Some(&filter), None)
            .await
            .unwrap()
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert_eq!(listed, alice_ids);

        let limited: Vec<Uuid> = index
            .list(Some(&filter), Some(2))
            .await
            .unwrap()
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert_eq!(limited, alice_ids[..2].to_vec());

        assert_eq!(index.list(None, None).await.unwrap().len(), 4);
    }

    #[tokio::test]
    async fn reset_removes_everything() {
        let index = FlatIndex::new(DistanceMetric::Cosine);
        index
            .insert(Uuid::now_v7(), &[1.0, 0.0, 0.0], serde_json::json!({}))
            .await
            .unwrap();

        index.reset().await.unwrap();

        assert!(index.list(None, None).await.unwrap().is_empty());
        assert!(index
            .search(&[1.0, 0.0, 0.0], 10, None)
            .await
            .unwrap()
            .is_empty());
    }
}