Skip to main content

argentor_memory/
store.rs

1use argentor_core::{ArgentorError, ArgentorResult};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9/// A single entry stored in vector memory.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct MemoryEntry {
12    /// Unique identifier for this memory entry.
13    pub id: Uuid,
14    /// The text content stored in memory.
15    pub content: String,
16    /// The embedding vector representation of the content.
17    pub embedding: Vec<f32>,
18    /// Arbitrary key-value metadata associated with this entry.
19    pub metadata: HashMap<String, serde_json::Value>,
20    /// Optional session ID this entry belongs to.
21    pub session_id: Option<Uuid>,
22    /// Timestamp when this entry was created.
23    pub created_at: DateTime<Utc>,
24}
25
26/// Result of a semantic search query.
27#[derive(Debug, Clone)]
28pub struct SearchResult {
29    /// The matching memory entry.
30    pub entry: MemoryEntry,
31    /// Cosine similarity score (0.0 -- 1.0).
32    pub score: f32,
33}
34
35/// Trait for vector storage backends.
36#[async_trait]
37pub trait VectorStore: Send + Sync {
38    /// Insert a memory entry.
39    async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()>;
40
41    /// Search for the top-k most similar entries to a query embedding.
42    async fn search(
43        &self,
44        query_embedding: &[f32],
45        top_k: usize,
46        session_filter: Option<Uuid>,
47    ) -> ArgentorResult<Vec<SearchResult>>;
48
49    /// Delete a memory entry by ID.
50    async fn delete(&self, id: Uuid) -> ArgentorResult<bool>;
51
52    /// List all entries (optionally filtered by session).
53    async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>>;
54
55    /// Count entries.
56    async fn count(&self) -> ArgentorResult<usize>;
57}
58
59/// In-memory vector store using brute-force cosine similarity.
60/// Suitable for MVP and small datasets (<100k entries).
61pub struct InMemoryVectorStore {
62    entries: RwLock<Vec<MemoryEntry>>,
63}
64
65impl InMemoryVectorStore {
66    /// Create a new empty in-memory vector store.
67    pub fn new() -> Self {
68        Self {
69            entries: RwLock::new(Vec::new()),
70        }
71    }
72}
73
74impl Default for InMemoryVectorStore {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80#[async_trait]
81impl VectorStore for InMemoryVectorStore {
82    async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
83        let mut entries = self.entries.write().await;
84        entries.push(entry);
85        Ok(())
86    }
87
88    async fn search(
89        &self,
90        query_embedding: &[f32],
91        top_k: usize,
92        session_filter: Option<Uuid>,
93    ) -> ArgentorResult<Vec<SearchResult>> {
94        if query_embedding.is_empty() {
95            return Err(ArgentorError::Agent("Empty query embedding".to_string()));
96        }
97
98        let entries = self.entries.read().await;
99
100        let mut scored: Vec<SearchResult> = entries
101            .iter()
102            .filter(|e| {
103                if let Some(sid) = session_filter {
104                    e.session_id == Some(sid)
105                } else {
106                    true
107                }
108            })
109            .map(|e| {
110                let score = cosine_similarity(query_embedding, &e.embedding);
111                SearchResult {
112                    entry: e.clone(),
113                    score,
114                }
115            })
116            .collect();
117
118        // Sort by score descending
119        scored.sort_by(|a, b| {
120            b.score
121                .partial_cmp(&a.score)
122                .unwrap_or(std::cmp::Ordering::Equal)
123        });
124        scored.truncate(top_k);
125
126        Ok(scored)
127    }
128
129    async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
130        let mut entries = self.entries.write().await;
131        let before = entries.len();
132        entries.retain(|e| e.id != id);
133        Ok(entries.len() < before)
134    }
135
136    async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
137        let entries = self.entries.read().await;
138        let filtered: Vec<MemoryEntry> = entries
139            .iter()
140            .filter(|e| {
141                if let Some(sid) = session_filter {
142                    e.session_id == Some(sid)
143                } else {
144                    true
145                }
146            })
147            .cloned()
148            .collect();
149        Ok(filtered)
150    }
151
152    async fn count(&self) -> ArgentorResult<usize> {
153        let entries = self.entries.read().await;
154        Ok(entries.len())
155    }
156}
157
158/// File-backed vector store that persists entries as JSONL on disk.
159/// Loads all entries into memory on creation; appends on insert; rewrites on delete.
160pub struct FileVectorStore {
161    path: std::path::PathBuf,
162    inner: InMemoryVectorStore,
163}
164
165impl FileVectorStore {
166    /// Create a new FileVectorStore at the given path.
167    /// If the file exists, loads all entries from it.
168    pub async fn new(path: std::path::PathBuf) -> ArgentorResult<Self> {
169        let inner = InMemoryVectorStore::new();
170
171        if path.exists() {
172            let data = tokio::fs::read_to_string(&path)
173                .await
174                .map_err(|e| ArgentorError::Session(format!("Failed to read vector store: {e}")))?;
175            for line in data.lines() {
176                if line.trim().is_empty() {
177                    continue;
178                }
179                let entry: MemoryEntry = serde_json::from_str(line)
180                    .map_err(|e| ArgentorError::Session(format!("Invalid JSONL entry: {e}")))?;
181                inner.insert(entry).await?;
182            }
183        } else if let Some(parent) = path.parent() {
184            tokio::fs::create_dir_all(parent)
185                .await
186                .map_err(|e| ArgentorError::Session(format!("Failed to create dir: {e}")))?;
187        }
188
189        Ok(Self { path, inner })
190    }
191
192    /// Append a single entry to the JSONL file.
193    async fn append_to_file(&self, entry: &MemoryEntry) -> ArgentorResult<()> {
194        use tokio::io::AsyncWriteExt;
195        let mut file = tokio::fs::OpenOptions::new()
196            .create(true)
197            .append(true)
198            .open(&self.path)
199            .await
200            .map_err(|e| ArgentorError::Session(format!("Failed to open vector store: {e}")))?;
201        let mut line = serde_json::to_string(entry)
202            .map_err(|e| ArgentorError::Session(format!("Failed to serialize entry: {e}")))?;
203        line.push('\n');
204        file.write_all(line.as_bytes())
205            .await
206            .map_err(|e| ArgentorError::Session(format!("Failed to write entry: {e}")))?;
207        Ok(())
208    }
209
210    /// Rewrite the entire file from in-memory entries.
211    async fn rewrite_file(&self) -> ArgentorResult<()> {
212        let entries = self.inner.list(None).await?;
213        let mut data = String::new();
214        for entry in &entries {
215            let line = serde_json::to_string(entry)
216                .map_err(|e| ArgentorError::Session(format!("Failed to serialize entry: {e}")))?;
217            data.push_str(&line);
218            data.push('\n');
219        }
220        tokio::fs::write(&self.path, data.as_bytes())
221            .await
222            .map_err(|e| ArgentorError::Session(format!("Failed to write vector store: {e}")))?;
223        Ok(())
224    }
225}
226
227#[async_trait]
228impl VectorStore for FileVectorStore {
229    async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
230        self.append_to_file(&entry).await?;
231        self.inner.insert(entry).await
232    }
233
234    async fn search(
235        &self,
236        query_embedding: &[f32],
237        top_k: usize,
238        session_filter: Option<Uuid>,
239    ) -> ArgentorResult<Vec<SearchResult>> {
240        self.inner
241            .search(query_embedding, top_k, session_filter)
242            .await
243    }
244
245    async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
246        let deleted = self.inner.delete(id).await?;
247        if deleted {
248            self.rewrite_file().await?;
249        }
250        Ok(deleted)
251    }
252
253    async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
254        self.inner.list(session_filter).await
255    }
256
257    async fn count(&self) -> ArgentorResult<usize> {
258        self.inner.count().await
259    }
260}
261
/// Cosine similarity between two vectors, in the range `-1.0..=1.0`.
///
/// Returns `0.0` when the slices differ in length, when either vector has
/// zero norm (including two empty slices), or when the result is not a finite
/// number (e.g. an input contains NaN or infinity).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate in f64: squaring large f32 components overflows to infinity
    // (which turned the result into NaN), and squaring tiny ones underflows to
    // zero. Every product of two finite f32 values fits comfortably in f64.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let similarity = dot / denom;
    if similarity.is_finite() {
        // Clamp away rounding error; the value is within [-1, 1], so the
        // narrowing cast cannot overflow.
        similarity.clamp(-1.0, 1.0) as f32
    } else {
        // NaN or infinite inputs: report "no similarity" instead of leaking
        // NaN into callers that sort by score.
        0.0
    }
}
276
#[cfg(test)]
// Tests unwrap freely: a panic is the desired failure mode here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Build an entry with a random id, empty metadata and the current time.
    fn make_entry(content: &str, embedding: Vec<f32>, session: Option<Uuid>) -> MemoryEntry {
        MemoryEntry {
            id: Uuid::new_v4(),
            content: content.to_owned(),
            embedding,
            metadata: HashMap::new(),
            session_id: session,
            created_at: Utc::now(),
        }
    }

    // --- InMemoryVectorStore tests ---

    #[tokio::test]
    async fn test_insert_and_count() {
        let store = InMemoryVectorStore::new();
        assert_eq!(store.count().await.unwrap(), 0);

        let entry = make_entry("hello", vec![1.0, 0.0, 0.0], None);
        store.insert(entry).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_search_returns_similar() {
        let store = InMemoryVectorStore::new();

        // One entry nearly parallel to the query, one orthogonal to it.
        let near = make_entry("rust lang", vec![0.9, 0.1, 0.0], None);
        let far = make_entry("cooking", vec![0.0, 0.0, 1.0], None);
        store.insert(near).await.unwrap();
        store.insert(far).await.unwrap();

        let hits = store.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entry.content, "rust lang");
        assert!(hits[0].score > hits[1].score);
    }

    #[tokio::test]
    async fn test_search_top_k() {
        let store = InMemoryVectorStore::new();
        for i in 0..10 {
            // One-hot vector cycling through the three axes.
            let mut one_hot = vec![0.0f32; 3];
            one_hot[i % 3] = 1.0;
            let entry = make_entry(&format!("entry_{i}"), one_hot, None);
            store.insert(entry).await.unwrap();
        }

        let hits = store.search(&[1.0, 0.0, 0.0], 3, None).await.unwrap();
        assert_eq!(hits.len(), 3);
    }

    #[tokio::test]
    async fn test_search_session_filter() {
        let store = InMemoryVectorStore::new();
        let wanted = Uuid::new_v4();
        let other = Uuid::new_v4();

        let in_session = make_entry("a", vec![1.0, 0.0], Some(wanted));
        let out_of_session = make_entry("b", vec![0.9, 0.1], Some(other));
        store.insert(in_session).await.unwrap();
        store.insert(out_of_session).await.unwrap();

        let hits = store.search(&[1.0, 0.0], 10, Some(wanted)).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entry.content, "a");
    }

    #[tokio::test]
    async fn test_delete() {
        let store = InMemoryVectorStore::new();
        let doomed = make_entry("to_delete", vec![1.0], None);
        let doomed_id = doomed.id;

        store.insert(doomed).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        let removed = store.delete(doomed_id).await.unwrap();
        assert!(removed);
        assert_eq!(store.count().await.unwrap(), 0);

        // An id that was never stored reports `false`.
        let removed_unknown = store.delete(Uuid::new_v4()).await.unwrap();
        assert!(!removed_unknown);
    }

    #[tokio::test]
    async fn test_list_all() {
        let store = InMemoryVectorStore::new();
        let first = make_entry("a", vec![1.0], None);
        let second = make_entry("b", vec![0.5], None);
        store.insert(first).await.unwrap();
        store.insert(second).await.unwrap();

        let everything = store.list(None).await.unwrap();
        assert_eq!(everything.len(), 2);
    }

    #[tokio::test]
    async fn test_list_filtered() {
        let store = InMemoryVectorStore::new();
        let session = Uuid::new_v4();

        let with_session = make_entry("a", vec![1.0], Some(session));
        let without_session = make_entry("b", vec![0.5], None);
        store.insert(with_session).await.unwrap();
        store.insert(without_session).await.unwrap();

        let only_session = store.list(Some(session)).await.unwrap();
        assert_eq!(only_session.len(), 1);
        assert_eq!(only_session[0].content, "a");
    }

    #[tokio::test]
    async fn test_search_empty_query() {
        let store = InMemoryVectorStore::new();
        let outcome = store.search(&[], 5, None).await;
        assert!(outcome.is_err());
    }

    // --- cosine_similarity tests ---

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let x_axis = vec![1.0, 0.0];
        let y_axis = vec![0.0, 1.0];
        let sim = cosine_similarity(&x_axis, &y_axis);
        assert!(sim.abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let forward = vec![1.0, 0.0];
        let backward = vec![-1.0, 0.0];
        let sim = cosine_similarity(&forward, &backward);
        assert!((sim + 1.0).abs() < 0.001);
    }

    // --- FileVectorStore tests ---

    #[tokio::test]
    async fn test_file_store_insert_and_persist() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("vectors.jsonl");

        {
            // First store instance: write two entries, then drop it.
            let writer = FileVectorStore::new(path.clone()).await.unwrap();
            let hello = make_entry("hello", vec![1.0, 0.0], None);
            let world = make_entry("world", vec![0.0, 1.0], None);
            writer.insert(hello).await.unwrap();
            writer.insert(world).await.unwrap();
            assert_eq!(writer.count().await.unwrap(), 2);
        }

        // Second instance must see both entries after loading from disk.
        let reader = FileVectorStore::new(path).await.unwrap();
        assert_eq!(reader.count().await.unwrap(), 2);
        let loaded = reader.list(None).await.unwrap();
        let contents: Vec<&str> = loaded.iter().map(|e| e.content.as_str()).collect();
        assert!(contents.contains(&"hello"));
        assert!(contents.contains(&"world"));
    }

    #[tokio::test]
    async fn test_file_store_delete_rewrites() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("vectors.jsonl");

        let store = FileVectorStore::new(path.clone()).await.unwrap();
        let doomed = make_entry("to_delete", vec![1.0], None);
        let doomed_id = doomed.id;
        store.insert(doomed).await.unwrap();
        let survivor = make_entry("keep", vec![0.5], None);
        store.insert(survivor).await.unwrap();

        let removed = store.delete(doomed_id).await.unwrap();
        assert!(removed);
        assert_eq!(store.count().await.unwrap(), 1);

        // The rewritten file must contain only the surviving entry.
        let reloaded = FileVectorStore::new(path).await.unwrap();
        assert_eq!(reloaded.count().await.unwrap(), 1);
        let remaining = reloaded.list(None).await.unwrap();
        assert_eq!(remaining[0].content, "keep");
    }

    #[tokio::test]
    async fn test_file_store_search() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("vectors.jsonl");

        let store = FileVectorStore::new(path).await.unwrap();
        let near = make_entry("close", vec![0.9, 0.1, 0.0], None);
        let distant = make_entry("far", vec![0.0, 0.0, 1.0], None);
        store.insert(near).await.unwrap();
        store.insert(distant).await.unwrap();

        let hits = store.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();
        assert_eq!(hits[0].entry.content, "close");
    }

    #[tokio::test]
    async fn test_file_store_empty_file() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("vectors.jsonl");

        // No file exists yet: the store starts out empty.
        let store = FileVectorStore::new(path).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }
}