1use argentor_core::{ArgentorError, ArgentorResult};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct MemoryEntry {
12 pub id: Uuid,
14 pub content: String,
16 pub embedding: Vec<f32>,
18 pub metadata: HashMap<String, serde_json::Value>,
20 pub session_id: Option<Uuid>,
22 pub created_at: DateTime<Utc>,
24}
25
26#[derive(Debug, Clone)]
28pub struct SearchResult {
29 pub entry: MemoryEntry,
31 pub score: f32,
33}
34
35#[async_trait]
37pub trait VectorStore: Send + Sync {
38 async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()>;
40
41 async fn search(
43 &self,
44 query_embedding: &[f32],
45 top_k: usize,
46 session_filter: Option<Uuid>,
47 ) -> ArgentorResult<Vec<SearchResult>>;
48
49 async fn delete(&self, id: Uuid) -> ArgentorResult<bool>;
51
52 async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>>;
54
55 async fn count(&self) -> ArgentorResult<usize>;
57}
58
59pub struct InMemoryVectorStore {
62 entries: RwLock<Vec<MemoryEntry>>,
63}
64
65impl InMemoryVectorStore {
66 pub fn new() -> Self {
68 Self {
69 entries: RwLock::new(Vec::new()),
70 }
71 }
72}
73
74impl Default for InMemoryVectorStore {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80#[async_trait]
81impl VectorStore for InMemoryVectorStore {
82 async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
83 let mut entries = self.entries.write().await;
84 entries.push(entry);
85 Ok(())
86 }
87
88 async fn search(
89 &self,
90 query_embedding: &[f32],
91 top_k: usize,
92 session_filter: Option<Uuid>,
93 ) -> ArgentorResult<Vec<SearchResult>> {
94 if query_embedding.is_empty() {
95 return Err(ArgentorError::Agent("Empty query embedding".to_string()));
96 }
97
98 let entries = self.entries.read().await;
99
100 let mut scored: Vec<SearchResult> = entries
101 .iter()
102 .filter(|e| {
103 if let Some(sid) = session_filter {
104 e.session_id == Some(sid)
105 } else {
106 true
107 }
108 })
109 .map(|e| {
110 let score = cosine_similarity(query_embedding, &e.embedding);
111 SearchResult {
112 entry: e.clone(),
113 score,
114 }
115 })
116 .collect();
117
118 scored.sort_by(|a, b| {
120 b.score
121 .partial_cmp(&a.score)
122 .unwrap_or(std::cmp::Ordering::Equal)
123 });
124 scored.truncate(top_k);
125
126 Ok(scored)
127 }
128
129 async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
130 let mut entries = self.entries.write().await;
131 let before = entries.len();
132 entries.retain(|e| e.id != id);
133 Ok(entries.len() < before)
134 }
135
136 async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
137 let entries = self.entries.read().await;
138 let filtered: Vec<MemoryEntry> = entries
139 .iter()
140 .filter(|e| {
141 if let Some(sid) = session_filter {
142 e.session_id == Some(sid)
143 } else {
144 true
145 }
146 })
147 .cloned()
148 .collect();
149 Ok(filtered)
150 }
151
152 async fn count(&self) -> ArgentorResult<usize> {
153 let entries = self.entries.read().await;
154 Ok(entries.len())
155 }
156}
157
158pub struct FileVectorStore {
161 path: std::path::PathBuf,
162 inner: InMemoryVectorStore,
163}
164
165impl FileVectorStore {
166 pub async fn new(path: std::path::PathBuf) -> ArgentorResult<Self> {
169 let inner = InMemoryVectorStore::new();
170
171 if path.exists() {
172 let data = tokio::fs::read_to_string(&path)
173 .await
174 .map_err(|e| ArgentorError::Session(format!("Failed to read vector store: {e}")))?;
175 for line in data.lines() {
176 if line.trim().is_empty() {
177 continue;
178 }
179 let entry: MemoryEntry = serde_json::from_str(line)
180 .map_err(|e| ArgentorError::Session(format!("Invalid JSONL entry: {e}")))?;
181 inner.insert(entry).await?;
182 }
183 } else if let Some(parent) = path.parent() {
184 tokio::fs::create_dir_all(parent)
185 .await
186 .map_err(|e| ArgentorError::Session(format!("Failed to create dir: {e}")))?;
187 }
188
189 Ok(Self { path, inner })
190 }
191
192 async fn append_to_file(&self, entry: &MemoryEntry) -> ArgentorResult<()> {
194 use tokio::io::AsyncWriteExt;
195 let mut file = tokio::fs::OpenOptions::new()
196 .create(true)
197 .append(true)
198 .open(&self.path)
199 .await
200 .map_err(|e| ArgentorError::Session(format!("Failed to open vector store: {e}")))?;
201 let mut line = serde_json::to_string(entry)
202 .map_err(|e| ArgentorError::Session(format!("Failed to serialize entry: {e}")))?;
203 line.push('\n');
204 file.write_all(line.as_bytes())
205 .await
206 .map_err(|e| ArgentorError::Session(format!("Failed to write entry: {e}")))?;
207 Ok(())
208 }
209
210 async fn rewrite_file(&self) -> ArgentorResult<()> {
212 let entries = self.inner.list(None).await?;
213 let mut data = String::new();
214 for entry in &entries {
215 let line = serde_json::to_string(entry)
216 .map_err(|e| ArgentorError::Session(format!("Failed to serialize entry: {e}")))?;
217 data.push_str(&line);
218 data.push('\n');
219 }
220 tokio::fs::write(&self.path, data.as_bytes())
221 .await
222 .map_err(|e| ArgentorError::Session(format!("Failed to write vector store: {e}")))?;
223 Ok(())
224 }
225}
226
227#[async_trait]
228impl VectorStore for FileVectorStore {
229 async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
230 self.append_to_file(&entry).await?;
231 self.inner.insert(entry).await
232 }
233
234 async fn search(
235 &self,
236 query_embedding: &[f32],
237 top_k: usize,
238 session_filter: Option<Uuid>,
239 ) -> ArgentorResult<Vec<SearchResult>> {
240 self.inner
241 .search(query_embedding, top_k, session_filter)
242 .await
243 }
244
245 async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
246 let deleted = self.inner.delete(id).await?;
247 if deleted {
248 self.rewrite_file().await?;
249 }
250 Ok(deleted)
251 }
252
253 async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
254 self.inner.list(session_filter).await
255 }
256
257 async fn count(&self) -> ArgentorResult<usize> {
258 self.inner.count().await
259 }
260}
261
/// Cosine similarity of `a` and `b`, in the range `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, when either vector has
/// zero magnitude (including two empty slices), or when the result is not a
/// finite number (for example because an input contains NaN or infinity).
/// The function therefore never returns NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // One pass, accumulated in f64: squaring large f32 components overflows to
    // infinity (giving inf/inf = NaN) and squaring tiny ones underflows to zero.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    let cosine = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    if cosine.is_finite() {
        // Rounding can push the ratio a hair outside the valid range.
        cosine.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
276
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds an entry with a fresh id, empty metadata and the current time.
    fn make_entry(content: &str, embedding: Vec<f32>, session: Option<Uuid>) -> MemoryEntry {
        MemoryEntry {
            id: Uuid::new_v4(),
            content: content.to_string(),
            embedding,
            metadata: HashMap::new(),
            session_id: session,
            created_at: Utc::now(),
        }
    }

    #[tokio::test]
    async fn test_insert_and_count() {
        let store = InMemoryVectorStore::new();
        assert_eq!(store.count().await.unwrap(), 0);

        store
            .insert(make_entry("hello", vec![1.0, 0.0, 0.0], None))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_search_returns_similar() {
        let store = InMemoryVectorStore::new();

        store
            .insert(make_entry("rust lang", vec![0.9, 0.1, 0.0], None))
            .await
            .unwrap();
        store
            .insert(make_entry("cooking", vec![0.0, 0.0, 1.0], None))
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entry.content, "rust lang");
        assert!(results[0].score > results[1].score);
    }

    #[tokio::test]
    async fn test_search_top_k() {
        let store = InMemoryVectorStore::new();
        for i in 0..10 {
            let mut emb = vec![0.0f32; 3];
            emb[i % 3] = 1.0;
            store
                .insert(make_entry(&format!("entry_{i}"), emb, None))
                .await
                .unwrap();
        }

        let results = store.search(&[1.0, 0.0, 0.0], 3, None).await.unwrap();
        assert_eq!(results.len(), 3);
        // Four entries point exactly along the query axis, so the top three
        // must all be perfect matches.
        assert!(results.iter().all(|r| (r.score - 1.0).abs() < 0.001));
    }

    #[tokio::test]
    async fn test_search_top_k_zero() {
        let store = InMemoryVectorStore::new();
        store
            .insert(make_entry("a", vec![1.0, 0.0], None))
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0], 0, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_search_session_filter() {
        let store = InMemoryVectorStore::new();
        let sid1 = Uuid::new_v4();
        let sid2 = Uuid::new_v4();

        store
            .insert(make_entry("a", vec![1.0, 0.0], Some(sid1)))
            .await
            .unwrap();
        store
            .insert(make_entry("b", vec![0.9, 0.1], Some(sid2)))
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0], 10, Some(sid1)).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entry.content, "a");
    }

    #[tokio::test]
    async fn test_delete() {
        let store = InMemoryVectorStore::new();
        let entry = make_entry("to_delete", vec![1.0], None);
        let id = entry.id;

        store.insert(entry).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        assert!(store.delete(id).await.unwrap());
        assert_eq!(store.count().await.unwrap(), 0);

        // Deleting an unknown id reports that nothing was removed.
        assert!(!store.delete(Uuid::new_v4()).await.unwrap());
    }

    #[tokio::test]
    async fn test_list_all() {
        let store = InMemoryVectorStore::new();
        store
            .insert(make_entry("a", vec![1.0], None))
            .await
            .unwrap();
        store
            .insert(make_entry("b", vec![0.5], None))
            .await
            .unwrap();

        let all = store.list(None).await.unwrap();
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn test_list_filtered() {
        let store = InMemoryVectorStore::new();
        let sid = Uuid::new_v4();

        store
            .insert(make_entry("a", vec![1.0], Some(sid)))
            .await
            .unwrap();
        store
            .insert(make_entry("b", vec![0.5], None))
            .await
            .unwrap();

        let filtered = store.list(Some(sid)).await.unwrap();
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].content, "a");
    }

    #[tokio::test]
    async fn test_search_empty_query() {
        let store = InMemoryVectorStore::new();
        assert!(store.search(&[], 5, None).await.is_err());
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_mismatched_lengths() {
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[1.0]), 0.0);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }

    #[tokio::test]
    async fn test_file_store_insert_and_persist() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("vectors.jsonl");

        {
            let store = FileVectorStore::new(path.clone()).await.unwrap();
            store
                .insert(make_entry("hello", vec![1.0, 0.0], None))
                .await
                .unwrap();
            store
                .insert(make_entry("world", vec![0.0, 1.0], None))
                .await
                .unwrap();
            assert_eq!(store.count().await.unwrap(), 2);
        }

        // Reopening must load what the first instance wrote.
        let store2 = FileVectorStore::new(path).await.unwrap();
        assert_eq!(store2.count().await.unwrap(), 2);
        let all = store2.list(None).await.unwrap();
        let contents: Vec<&str> = all.iter().map(|e| e.content.as_str()).collect();
        assert!(contents.contains(&"hello"));
        assert!(contents.contains(&"world"));
    }

    #[tokio::test]
    async fn test_file_store_delete_rewrites() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("vectors.jsonl");

        let store = FileVectorStore::new(path.clone()).await.unwrap();
        let entry = make_entry("to_delete", vec![1.0], None);
        let id = entry.id;
        store.insert(entry).await.unwrap();
        store
            .insert(make_entry("keep", vec![0.5], None))
            .await
            .unwrap();

        assert!(store.delete(id).await.unwrap());
        assert_eq!(store.count().await.unwrap(), 1);

        // The deletion must be reflected on disk as well.
        let store2 = FileVectorStore::new(path).await.unwrap();
        assert_eq!(store2.count().await.unwrap(), 1);
        let all = store2.list(None).await.unwrap();
        assert_eq!(all[0].content, "keep");
    }

    #[tokio::test]
    async fn test_file_store_search() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("vectors.jsonl");

        let store = FileVectorStore::new(path).await.unwrap();
        store
            .insert(make_entry("close", vec![0.9, 0.1, 0.0], None))
            .await
            .unwrap();
        store
            .insert(make_entry("far", vec![0.0, 0.0, 1.0], None))
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();
        assert_eq!(results[0].entry.content, "close");
    }

    #[tokio::test]
    async fn test_file_store_missing_file() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("vectors.jsonl");

        // No file has been created yet: the store opens empty.
        let store = FileVectorStore::new(path).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_file_store_empty_file() {
        let tmp = tempfile::tempdir().unwrap();

        // A zero-byte file loads as an empty store.
        let empty = tmp.path().join("empty.jsonl");
        tokio::fs::write(&empty, "").await.unwrap();
        let store = FileVectorStore::new(empty).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);

        // Blank and whitespace-only lines are skipped rather than parsed.
        let blank = tmp.path().join("blank.jsonl");
        tokio::fs::write(&blank, "\n   \n\n").await.unwrap();
        let store = FileVectorStore::new(blank).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_file_store_creates_parent_dir() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("nested").join("dir").join("vectors.jsonl");

        let store = FileVectorStore::new(path.clone()).await.unwrap();
        store
            .insert(make_entry("x", vec![1.0], None))
            .await
            .unwrap();

        assert_eq!(store.count().await.unwrap(), 1);
        assert!(path.exists());
    }

    #[tokio::test]
    async fn test_file_store_rejects_invalid_jsonl() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("vectors.jsonl");
        tokio::fs::write(&path, "this is not json\n").await.unwrap();

        assert!(FileVectorStore::new(path).await.is_err());
    }
}