Skip to main content

agent_io/memory/backends/in_memory.rs

1//! In-memory memory store implementation
2
3use async_trait::async_trait;
4use std::collections::HashMap;
5use tokio::sync::RwLock;
6
7use crate::Result;
8use crate::memory::entry::MemoryEntry;
9use crate::memory::store::MemoryStore;
10
11/// In-memory vector store for development and testing
12pub struct InMemoryStore {
13    memories: RwLock<HashMap<String, MemoryEntry>>,
14}
15
16impl InMemoryStore {
17    /// Create a new in-memory store
18    pub fn new() -> Self {
19        Self {
20            memories: RwLock::new(HashMap::new()),
21        }
22    }
23
24    /// Create a new in-memory store with pre-seeded memories
25    pub fn with_memories(memories: Vec<MemoryEntry>) -> Self {
26        let map: HashMap<String, MemoryEntry> =
27            memories.into_iter().map(|m| (m.id.clone(), m)).collect();
28        Self {
29            memories: RwLock::new(map),
30        }
31    }
32
33    /// Calculate cosine similarity between two vectors
34    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
35        if a.len() != b.len() || a.is_empty() {
36            return 0.0;
37        }
38
39        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
40        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
41        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
42
43        if norm_a == 0.0 || norm_b == 0.0 {
44            0.0
45        } else {
46            dot / (norm_a * norm_b)
47        }
48    }
49}
50
51impl Default for InMemoryStore {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57#[async_trait]
58impl MemoryStore for InMemoryStore {
59    async fn add(&self, entry: MemoryEntry) -> Result<String> {
60        let id = entry.id.clone();
61        let mut memories = self.memories.write().await;
62        memories.insert(id.clone(), entry);
63        Ok(id)
64    }
65
66    async fn search(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
67        let memories = self.memories.read().await;
68        let query_lower = query.to_lowercase();
69
70        let mut results: Vec<MemoryEntry> = memories
71            .values()
72            .filter(|m| m.content.to_lowercase().contains(&query_lower))
73            .cloned()
74            .collect();
75
76        // Sort by relevance score
77        results.sort_by(|a, b| {
78            b.relevance_score()
79                .partial_cmp(&a.relevance_score())
80                .unwrap_or(std::cmp::Ordering::Equal)
81        });
82
83        results.truncate(limit);
84        Ok(results)
85    }
86
87    async fn search_by_embedding(
88        &self,
89        embedding: &[f32],
90        limit: usize,
91        threshold: f32,
92    ) -> Result<Vec<MemoryEntry>> {
93        let memories = self.memories.read().await;
94
95        let mut scored: Vec<(f32, MemoryEntry)> = memories
96            .values()
97            .filter_map(|m| {
98                let emb = m.embedding.as_ref()?;
99                let score = Self::cosine_similarity(embedding, emb);
100                if score >= threshold {
101                    Some((score, m.clone()))
102                } else {
103                    None
104                }
105            })
106            .collect();
107
108        // Sort by similarity score (descending)
109        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
110
111        Ok(scored.into_iter().take(limit).map(|(_, m)| m).collect())
112    }
113
114    async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
115        let memories = self.memories.read().await;
116        Ok(memories.get(id).cloned())
117    }
118
119    async fn update(&self, entry: MemoryEntry) -> Result<()> {
120        let mut memories = self.memories.write().await;
121        memories.insert(entry.id.clone(), entry);
122        Ok(())
123    }
124
125    async fn delete(&self, id: &str) -> Result<()> {
126        let mut memories = self.memories.write().await;
127        memories.remove(id);
128        Ok(())
129    }
130
131    async fn clear(&self) -> Result<()> {
132        let mut memories = self.memories.write().await;
133        memories.clear();
134        Ok(())
135    }
136
137    async fn count(&self) -> Result<usize> {
138        let memories = self.memories.read().await;
139        Ok(memories.len())
140    }
141
142    async fn ids(&self) -> Result<Vec<String>> {
143        let memories = self.memories.read().await;
144        Ok(memories.keys().cloned().collect())
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_add_and_get() {
        let store = InMemoryStore::new();
        let id = store.add(MemoryEntry::new("Test memory")).await.unwrap();

        let retrieved = store
            .get(&id)
            .await
            .unwrap()
            .expect("entry was just added");
        assert_eq!(retrieved.content, "Test memory");
        assert!(store.get("missing-id").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_default_is_empty() {
        let store = InMemoryStore::default();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_search() {
        let store = InMemoryStore::new();
        for content in [
            "Rust is a programming language",
            "Python is also a programming language",
            "The weather is nice today",
        ] {
            store.add(MemoryEntry::new(content)).await.unwrap();
        }

        let results = store.search("programming", 10).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|m| m.content.contains("programming")));
    }

    #[tokio::test]
    async fn test_search_is_case_insensitive() {
        let store = InMemoryStore::new();
        store.add(MemoryEntry::new("Rust is fast")).await.unwrap();
        store.add(MemoryEntry::new("Go is simple")).await.unwrap();

        let results = store.search("RUST", 10).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].content, "Rust is fast");
    }

    #[tokio::test]
    async fn test_search_respects_limit() {
        let store = InMemoryStore::new();
        for content in ["note one", "note two", "note three"] {
            store.add(MemoryEntry::new(content)).await.unwrap();
        }

        assert_eq!(store.search("note", 2).await.unwrap().len(), 2);
        assert!(store.search("note", 0).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_search_by_embedding() {
        let store = InMemoryStore::new();

        let mut entry1 = MemoryEntry::new("Rust programming");
        entry1.embedding = Some(vec![1.0, 0.0, 0.0]);

        let mut entry2 = MemoryEntry::new("Python programming");
        entry2.embedding = Some(vec![0.0, 1.0, 0.0]);

        // Neither of these can match: no embedding / mismatched dimensions.
        let no_embedding = MemoryEntry::new("No embedding");
        let mut wrong_dims = MemoryEntry::new("Wrong dimensions");
        wrong_dims.embedding = Some(vec![1.0, 0.0]);

        store.add(entry1).await.unwrap();
        store.add(entry2).await.unwrap();
        store.add(no_embedding).await.unwrap();
        store.add(wrong_dims).await.unwrap();

        // cos([0.9, 0.1, 0], [1, 0, 0]) ≈ 0.99 passes the 0.5 threshold;
        // cos([0.9, 0.1, 0], [0, 1, 0]) ≈ 0.11 does not.
        let results = store
            .search_by_embedding(&[0.9, 0.1, 0.0], 10, 0.5)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].content, "Rust programming");
    }

    #[tokio::test]
    async fn test_search_by_embedding_orders_and_limits() {
        let store = InMemoryStore::new();

        let mut best = MemoryEntry::new("best");
        best.embedding = Some(vec![1.0, 0.0, 0.0]);
        let mut good = MemoryEntry::new("good");
        good.embedding = Some(vec![0.8, 0.6, 0.0]);

        store.add(good).await.unwrap();
        store.add(best).await.unwrap();

        let results = store
            .search_by_embedding(&[1.0, 0.0, 0.0], 10, 0.5)
            .await
            .unwrap();
        let contents: Vec<_> = results.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(contents, ["best", "good"]);

        let top = store
            .search_by_embedding(&[1.0, 0.0, 0.0], 1, 0.5)
            .await
            .unwrap();
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].content, "best");
    }

    #[tokio::test]
    async fn test_update() {
        let store = InMemoryStore::new();
        let mut entry = MemoryEntry::new("Updatable");
        let id = store.add(entry.clone()).await.unwrap();

        entry.embedding = Some(vec![0.5, 0.5]);
        store.update(entry).await.unwrap();

        let stored = store.get(&id).await.unwrap().expect("entry still present");
        assert_eq!(stored.embedding, Some(vec![0.5, 0.5]));
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_delete() {
        let store = InMemoryStore::new();
        let id = store.add(MemoryEntry::new("Test")).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        store.delete(&id).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);

        // Deleting an id that is no longer present is not an error.
        store.delete(&id).await.unwrap();
    }

    #[tokio::test]
    async fn test_clear() {
        let store = InMemoryStore::new();
        store.add(MemoryEntry::new("a")).await.unwrap();
        store.add(MemoryEntry::new("b")).await.unwrap();

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_with_memories_and_ids() {
        let first = MemoryEntry::new("first");
        let second = MemoryEntry::new("second");
        let mut expected = vec![first.id.clone(), second.id.clone()];

        let store = InMemoryStore::with_memories(vec![first, second]);
        assert_eq!(store.count().await.unwrap(), 2);

        let mut ids = store.ids().await.unwrap();
        ids.sort();
        expected.sort();
        assert_eq!(ids, expected);
    }
}