1use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::RwLock;
6
7use super::traits::{VectorSearchResult, VectorStore};
8use crate::errors::VectorStoreError;
9use crate::models::{FilterLogic, FilterOperator, Filters, Payload};
10
11struct Entry {
13 embedding: Vec<f32>,
14 payload: Payload,
15}
16
17pub struct InMemoryStore {
19 entries: RwLock<HashMap<String, Entry>>,
20}
21
22impl InMemoryStore {
23 pub fn new() -> Self {
25 Self {
26 entries: RwLock::new(HashMap::new()),
27 }
28 }
29
30 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
32 if a.len() != b.len() || a.is_empty() {
33 return 0.0;
34 }
35
36 let mut dot = 0.0f32;
37 let mut norm_a = 0.0f32;
38 let mut norm_b = 0.0f32;
39
40 for (va, vb) in a.iter().zip(b.iter()) {
41 dot += va * vb;
42 norm_a += va * va;
43 norm_b += vb * vb;
44 }
45
46 if norm_a == 0.0 || norm_b == 0.0 {
47 return 0.0;
48 }
49
50 dot / (norm_a.sqrt() * norm_b.sqrt())
51 }
52
53 fn matches_filters(payload: &Payload, filters: Option<&Filters>) -> bool {
55 let Some(filters) = filters else {
56 return true;
57 };
58
59 if filters.conditions.is_empty() {
60 return true;
61 }
62
63 let results: Vec<bool> = filters
64 .conditions
65 .iter()
66 .map(|cond| {
67 let value = payload.metadata.get(&cond.field);
68 Self::evaluate_condition(value, &cond.operator, &cond.value)
69 })
70 .collect();
71
72 match filters.logic {
73 FilterLogic::And => results.iter().all(|&r| r),
74 FilterLogic::Or => results.iter().any(|&r| r),
75 }
76 }
77
78 fn evaluate_condition(
80 field_value: Option<&serde_json::Value>,
81 operator: &FilterOperator,
82 filter_value: &serde_json::Value,
83 ) -> bool {
84 match operator {
85 FilterOperator::Eq => field_value == Some(filter_value),
86 FilterOperator::Ne => field_value != Some(filter_value),
87 FilterOperator::Gt => Self::compare_values(field_value, filter_value, |a, b| a > b),
88 FilterOperator::Gte => Self::compare_values(field_value, filter_value, |a, b| a >= b),
89 FilterOperator::Lt => Self::compare_values(field_value, filter_value, |a, b| a < b),
90 FilterOperator::Lte => Self::compare_values(field_value, filter_value, |a, b| a <= b),
91 FilterOperator::In => {
92 if let Some(arr) = filter_value.as_array() {
93 field_value.map(|v| arr.contains(v)).unwrap_or(false)
94 } else {
95 false
96 }
97 }
98 FilterOperator::Nin => {
99 if let Some(arr) = filter_value.as_array() {
100 field_value.map(|v| !arr.contains(v)).unwrap_or(true)
101 } else {
102 true
103 }
104 }
105 FilterOperator::Contains => {
106 if let (Some(field_str), Some(filter_str)) =
107 (field_value.and_then(|v| v.as_str()), filter_value.as_str())
108 {
109 field_str.contains(filter_str)
110 } else {
111 false
112 }
113 }
114 FilterOperator::IContains => {
115 if let (Some(field_str), Some(filter_str)) =
116 (field_value.and_then(|v| v.as_str()), filter_value.as_str())
117 {
118 field_str.to_lowercase().contains(&filter_str.to_lowercase())
119 } else {
120 false
121 }
122 }
123 }
124 }
125
126 fn compare_values<F>(
128 field_value: Option<&serde_json::Value>,
129 filter_value: &serde_json::Value,
130 cmp: F,
131 ) -> bool
132 where
133 F: Fn(f64, f64) -> bool,
134 {
135 let field_num = field_value.and_then(|v| v.as_f64());
136 let filter_num = filter_value.as_f64();
137
138 match (field_num, filter_num) {
139 (Some(a), Some(b)) => cmp(a, b),
140 _ => false,
141 }
142 }
143
144
145}
146
147impl Default for InMemoryStore {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153#[async_trait]
154impl VectorStore for InMemoryStore {
155 async fn insert(
156 &self,
157 id: &str,
158 embedding: Vec<f32>,
159 payload: Payload,
160 ) -> Result<(), VectorStoreError> {
161 let mut entries = self
162 .entries
163 .write()
164 .map_err(|e| VectorStoreError::Insert(e.to_string()))?;
165
166 entries.insert(id.to_string(), Entry { embedding, payload });
167 Ok(())
168 }
169
170 async fn search(
171 &self,
172 embedding: &[f32],
173 limit: usize,
174 filters: Option<&Filters>,
175 ) -> Result<Vec<VectorSearchResult>, VectorStoreError> {
176 let entries = self
177 .entries
178 .read()
179 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
180
181 let mut results: Vec<VectorSearchResult> = entries
182 .iter()
183 .filter(|(_, entry)| Self::matches_filters(&entry.payload, filters))
184 .map(|(id, entry)| VectorSearchResult {
185 id: id.clone(),
186 score: Self::cosine_similarity(embedding, &entry.embedding),
187 payload: entry.payload.clone(),
188 })
189 .collect();
190
191 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
193 results.truncate(limit);
194
195 Ok(results)
196 }
197
198 async fn get(&self, id: &str) -> Result<Option<VectorSearchResult>, VectorStoreError> {
199 let entries = self
200 .entries
201 .read()
202 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
203
204 Ok(entries.get(id).map(|entry| VectorSearchResult {
205 id: id.to_string(),
206 score: 1.0,
207 payload: entry.payload.clone(),
208 }))
209 }
210
211 async fn delete(&self, id: &str) -> Result<(), VectorStoreError> {
212 let mut entries = self
213 .entries
214 .write()
215 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
216
217 entries
218 .remove(id)
219 .ok_or_else(|| VectorStoreError::NotFound(id.to_string()))?;
220
221 Ok(())
222 }
223
224 async fn update(
225 &self,
226 id: &str,
227 embedding: Option<Vec<f32>>,
228 payload: Payload,
229 ) -> Result<(), VectorStoreError> {
230 let mut entries = self
231 .entries
232 .write()
233 .map_err(|e| VectorStoreError::Update(e.to_string()))?;
234
235 let entry = entries
236 .get_mut(id)
237 .ok_or_else(|| VectorStoreError::NotFound(id.to_string()))?;
238
239 if let Some(emb) = embedding {
240 entry.embedding = emb;
241 }
242 entry.payload = payload;
243
244 Ok(())
245 }
246
247 async fn list(
248 &self,
249 filters: Option<&Filters>,
250 limit: usize,
251 ) -> Result<Vec<VectorSearchResult>, VectorStoreError> {
252 let entries = self
253 .entries
254 .read()
255 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
256
257 let mut results: Vec<VectorSearchResult> = entries
258 .iter()
259 .filter(|(_, entry)| Self::matches_filters(&entry.payload, filters))
260 .map(|(id, entry)| VectorSearchResult {
261 id: id.clone(),
262 score: 1.0,
263 payload: entry.payload.clone(),
264 })
265 .collect();
266
267 results.truncate(limit);
268 Ok(results)
269 }
270
271 async fn delete_all(&self, filters: Option<&Filters>) -> Result<usize, VectorStoreError> {
272 let mut entries = self
273 .entries
274 .write()
275 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
276
277 let to_delete: Vec<String> = entries
278 .iter()
279 .filter(|(_, entry)| Self::matches_filters(&entry.payload, filters))
280 .map(|(id, _)| id.clone())
281 .collect();
282
283 let count = to_delete.len();
284 for id in to_delete {
285 entries.remove(&id);
286 }
287
288 Ok(count)
289 }
290
291 async fn collection_exists(&self) -> Result<bool, VectorStoreError> {
292 Ok(true) }
294
295 async fn create_collection(&self) -> Result<(), VectorStoreError> {
296 Ok(()) }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use std::collections::HashMap;

    /// Builds a payload carrying `data`, a fixed hash, the current time, no
    /// owner ids and empty metadata.
    fn create_test_payload(data: &str) -> Payload {
        Payload {
            data: data.to_owned(),
            hash: String::from("test_hash"),
            created_at: Utc::now(),
            user_id: None,
            agent_id: None,
            run_id: None,
            metadata: HashMap::new(),
        }
    }

    /// An inserted entry can be read back by id with its payload intact.
    #[tokio::test]
    async fn test_insert_and_get() {
        let store = InMemoryStore::new();

        store
            .insert(
                "test-id",
                vec![0.1, 0.2, 0.3],
                create_test_payload("test content"),
            )
            .await
            .unwrap();

        let fetched = store.get("test-id").await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.payload.data, "test content");
    }

    /// Search returns every entry, with the vector closest to the query first.
    #[tokio::test]
    async fn test_search() {
        let store = InMemoryStore::new();

        let fixtures = [
            ("id1", vec![1.0, 0.0, 0.0], "doc1"),
            ("id2", vec![0.0, 1.0, 0.0], "doc2"),
        ];
        for (id, embedding, data) in fixtures {
            store
                .insert(id, embedding, create_test_payload(data))
                .await
                .unwrap();
        }

        let hits = store.search(&[1.0, 0.0, 0.0], 10, None).await.unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "id1");
    }

    /// A deleted entry is no longer returned by `get`.
    #[tokio::test]
    async fn test_delete() {
        let store = InMemoryStore::new();
        store
            .insert("id1", vec![1.0], create_test_payload("doc1"))
            .await
            .unwrap();

        store.delete("id1").await.unwrap();

        let after_delete = store.get("id1").await.unwrap();
        assert!(after_delete.is_none());
    }
}