mem0_rust/vector_stores/
memory.rs

1//! In-memory vector store for testing and development.
2
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::RwLock;
6
7use super::traits::{VectorSearchResult, VectorStore};
8use crate::errors::VectorStoreError;
9use crate::models::{FilterLogic, FilterOperator, Filters, Payload};
10
11/// In-memory vector store entry
12struct Entry {
13    embedding: Vec<f32>,
14    payload: Payload,
15}
16
17/// In-memory vector store
18pub struct InMemoryStore {
19    entries: RwLock<HashMap<String, Entry>>,
20}
21
22impl InMemoryStore {
23    /// Create a new in-memory store
24    pub fn new() -> Self {
25        Self {
26            entries: RwLock::new(HashMap::new()),
27        }
28    }
29
30    /// Compute cosine similarity between two vectors
31    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
32        if a.len() != b.len() || a.is_empty() {
33            return 0.0;
34        }
35
36        let mut dot = 0.0f32;
37        let mut norm_a = 0.0f32;
38        let mut norm_b = 0.0f32;
39
40        for (va, vb) in a.iter().zip(b.iter()) {
41            dot += va * vb;
42            norm_a += va * va;
43            norm_b += vb * vb;
44        }
45
46        if norm_a == 0.0 || norm_b == 0.0 {
47            return 0.0;
48        }
49
50        dot / (norm_a.sqrt() * norm_b.sqrt())
51    }
52
53    /// Check if a payload matches the given filters
54    fn matches_filters(payload: &Payload, filters: Option<&Filters>) -> bool {
55        let Some(filters) = filters else {
56            return true;
57        };
58
59        if filters.conditions.is_empty() {
60            return true;
61        }
62
63        let results: Vec<bool> = filters
64            .conditions
65            .iter()
66            .map(|cond| {
67                let value = payload.metadata.get(&cond.field);
68                Self::evaluate_condition(value, &cond.operator, &cond.value)
69            })
70            .collect();
71
72        match filters.logic {
73            FilterLogic::And => results.iter().all(|&r| r),
74            FilterLogic::Or => results.iter().any(|&r| r),
75        }
76    }
77
78    /// Evaluate a single filter condition
79    fn evaluate_condition(
80        field_value: Option<&serde_json::Value>,
81        operator: &FilterOperator,
82        filter_value: &serde_json::Value,
83    ) -> bool {
84        match operator {
85            FilterOperator::Eq => field_value == Some(filter_value),
86            FilterOperator::Ne => field_value != Some(filter_value),
87            FilterOperator::Gt => Self::compare_values(field_value, filter_value, |a, b| a > b),
88            FilterOperator::Gte => Self::compare_values(field_value, filter_value, |a, b| a >= b),
89            FilterOperator::Lt => Self::compare_values(field_value, filter_value, |a, b| a < b),
90            FilterOperator::Lte => Self::compare_values(field_value, filter_value, |a, b| a <= b),
91            FilterOperator::In => {
92                if let Some(arr) = filter_value.as_array() {
93                    field_value.map(|v| arr.contains(v)).unwrap_or(false)
94                } else {
95                    false
96                }
97            }
98            FilterOperator::Nin => {
99                if let Some(arr) = filter_value.as_array() {
100                    field_value.map(|v| !arr.contains(v)).unwrap_or(true)
101                } else {
102                    true
103                }
104            }
105            FilterOperator::Contains => {
106                if let (Some(field_str), Some(filter_str)) =
107                    (field_value.and_then(|v| v.as_str()), filter_value.as_str())
108                {
109                    field_str.contains(filter_str)
110                } else {
111                    false
112                }
113            }
114            FilterOperator::IContains => {
115                if let (Some(field_str), Some(filter_str)) =
116                    (field_value.and_then(|v| v.as_str()), filter_value.as_str())
117                {
118                    field_str.to_lowercase().contains(&filter_str.to_lowercase())
119                } else {
120                    false
121                }
122            }
123        }
124    }
125
126    /// Compare numeric values
127    fn compare_values<F>(
128        field_value: Option<&serde_json::Value>,
129        filter_value: &serde_json::Value,
130        cmp: F,
131    ) -> bool
132    where
133        F: Fn(f64, f64) -> bool,
134    {
135        let field_num = field_value.and_then(|v| v.as_f64());
136        let filter_num = filter_value.as_f64();
137
138        match (field_num, filter_num) {
139            (Some(a), Some(b)) => cmp(a, b),
140            _ => false,
141        }
142    }
143
144
145}
146
147impl Default for InMemoryStore {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
153#[async_trait]
154impl VectorStore for InMemoryStore {
155    async fn insert(
156        &self,
157        id: &str,
158        embedding: Vec<f32>,
159        payload: Payload,
160    ) -> Result<(), VectorStoreError> {
161        let mut entries = self
162            .entries
163            .write()
164            .map_err(|e| VectorStoreError::Insert(e.to_string()))?;
165
166        entries.insert(id.to_string(), Entry { embedding, payload });
167        Ok(())
168    }
169
170    async fn search(
171        &self,
172        embedding: &[f32],
173        limit: usize,
174        filters: Option<&Filters>,
175    ) -> Result<Vec<VectorSearchResult>, VectorStoreError> {
176        let entries = self
177            .entries
178            .read()
179            .map_err(|e| VectorStoreError::Search(e.to_string()))?;
180
181        let mut results: Vec<VectorSearchResult> = entries
182            .iter()
183            .filter(|(_, entry)| Self::matches_filters(&entry.payload, filters))
184            .map(|(id, entry)| VectorSearchResult {
185                id: id.clone(),
186                score: Self::cosine_similarity(embedding, &entry.embedding),
187                payload: entry.payload.clone(),
188            })
189            .collect();
190
191        // Sort by score descending
192        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
193        results.truncate(limit);
194
195        Ok(results)
196    }
197
198    async fn get(&self, id: &str) -> Result<Option<VectorSearchResult>, VectorStoreError> {
199        let entries = self
200            .entries
201            .read()
202            .map_err(|e| VectorStoreError::Search(e.to_string()))?;
203
204        Ok(entries.get(id).map(|entry| VectorSearchResult {
205            id: id.to_string(),
206            score: 1.0,
207            payload: entry.payload.clone(),
208        }))
209    }
210
211    async fn delete(&self, id: &str) -> Result<(), VectorStoreError> {
212        let mut entries = self
213            .entries
214            .write()
215            .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
216
217        entries
218            .remove(id)
219            .ok_or_else(|| VectorStoreError::NotFound(id.to_string()))?;
220
221        Ok(())
222    }
223
224    async fn update(
225        &self,
226        id: &str,
227        embedding: Option<Vec<f32>>,
228        payload: Payload,
229    ) -> Result<(), VectorStoreError> {
230        let mut entries = self
231            .entries
232            .write()
233            .map_err(|e| VectorStoreError::Update(e.to_string()))?;
234
235        let entry = entries
236            .get_mut(id)
237            .ok_or_else(|| VectorStoreError::NotFound(id.to_string()))?;
238
239        if let Some(emb) = embedding {
240            entry.embedding = emb;
241        }
242        entry.payload = payload;
243
244        Ok(())
245    }
246
247    async fn list(
248        &self,
249        filters: Option<&Filters>,
250        limit: usize,
251    ) -> Result<Vec<VectorSearchResult>, VectorStoreError> {
252        let entries = self
253            .entries
254            .read()
255            .map_err(|e| VectorStoreError::Search(e.to_string()))?;
256
257        let mut results: Vec<VectorSearchResult> = entries
258            .iter()
259            .filter(|(_, entry)| Self::matches_filters(&entry.payload, filters))
260            .map(|(id, entry)| VectorSearchResult {
261                id: id.clone(),
262                score: 1.0,
263                payload: entry.payload.clone(),
264            })
265            .collect();
266
267        results.truncate(limit);
268        Ok(results)
269    }
270
271    async fn delete_all(&self, filters: Option<&Filters>) -> Result<usize, VectorStoreError> {
272        let mut entries = self
273            .entries
274            .write()
275            .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
276
277        let to_delete: Vec<String> = entries
278            .iter()
279            .filter(|(_, entry)| Self::matches_filters(&entry.payload, filters))
280            .map(|(id, _)| id.clone())
281            .collect();
282
283        let count = to_delete.len();
284        for id in to_delete {
285            entries.remove(&id);
286        }
287
288        Ok(count)
289    }
290
291    async fn collection_exists(&self) -> Result<bool, VectorStoreError> {
292        Ok(true) // In-memory store always "exists"
293    }
294
295    async fn create_collection(&self) -> Result<(), VectorStoreError> {
296        Ok(()) // No-op for in-memory store
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use std::collections::HashMap;

    /// Build a payload with the given `data` and fixed/empty other fields.
    fn create_test_payload(data: &str) -> Payload {
        Payload {
            data: data.to_string(),
            hash: "test_hash".to_string(),
            created_at: Utc::now(),
            user_id: None,
            agent_id: None,
            run_id: None,
            metadata: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let store = InMemoryStore::new();
        let payload = create_test_payload("test content");
        let embedding = vec![0.1, 0.2, 0.3];

        store.insert("test-id", embedding, payload).await.unwrap();

        let result = store.get("test-id").await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().payload.data, "test content");
    }

    #[tokio::test]
    async fn test_search() {
        let store = InMemoryStore::new();

        store
            .insert("id1", vec![1.0, 0.0, 0.0], create_test_payload("doc1"))
            .await
            .unwrap();
        store
            .insert("id2", vec![0.0, 1.0, 0.0], create_test_payload("doc2"))
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0], 10, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "id1"); // Most similar
    }

    #[tokio::test]
    async fn test_search_respects_limit() {
        let store = InMemoryStore::new();
        store
            .insert("id1", vec![1.0, 0.0], create_test_payload("doc1"))
            .await
            .unwrap();
        store
            .insert("id2", vec![0.0, 1.0], create_test_payload("doc2"))
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0], 1, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "id1");
    }

    #[tokio::test]
    async fn test_delete() {
        let store = InMemoryStore::new();
        store
            .insert("id1", vec![1.0], create_test_payload("doc1"))
            .await
            .unwrap();

        store.delete("id1").await.unwrap();
        assert!(store.get("id1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_delete_missing_is_not_found() {
        let store = InMemoryStore::new();
        let err = store.delete("missing").await.unwrap_err();
        assert!(matches!(err, VectorStoreError::NotFound(ref id) if id == "missing"));
    }

    #[tokio::test]
    async fn test_update_replaces_payload_and_embedding() {
        let store = InMemoryStore::new();
        store
            .insert("id1", vec![1.0, 0.0], create_test_payload("before"))
            .await
            .unwrap();
        store
            .update("id1", Some(vec![0.0, 1.0]), create_test_payload("after"))
            .await
            .unwrap();

        let fetched = store.get("id1").await.unwrap().unwrap();
        assert_eq!(fetched.payload.data, "after");

        let results = store.search(&[0.0, 1.0], 1, None).await.unwrap();
        assert_eq!(results[0].id, "id1");
        assert!((results[0].score - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_update_missing_is_not_found() {
        let store = InMemoryStore::new();
        let err = store
            .update("missing", None, create_test_payload("x"))
            .await
            .unwrap_err();
        assert!(matches!(err, VectorStoreError::NotFound(_)));
    }

    #[tokio::test]
    async fn test_list_and_delete_all() {
        let store = InMemoryStore::new();
        for id in &["a", "b", "c"] {
            store
                .insert(id, vec![1.0], create_test_payload(id))
                .await
                .unwrap();
        }

        assert_eq!(store.list(None, 2).await.unwrap().len(), 2);
        assert_eq!(store.list(None, 10).await.unwrap().len(), 3);

        assert_eq!(store.delete_all(None).await.unwrap(), 3);
        assert!(store.list(None, 10).await.unwrap().is_empty());
    }
}