agent_io/memory/backends/
in_memory.rs1use async_trait::async_trait;
4use std::collections::HashMap;
5use tokio::sync::RwLock;
6
7use crate::Result;
8use crate::memory::entry::MemoryEntry;
9use crate::memory::store::MemoryStore;
10
/// A [`MemoryStore`] that keeps every [`MemoryEntry`] in a process-local `HashMap`.
///
/// Nothing is persisted: the contents live only as long as this value. All access
/// goes through an async (`tokio`) `RwLock`, so readers may proceed concurrently
/// while writers (`add`, `update`, `delete`, `clear`) take exclusive access.
pub struct InMemoryStore {
    /// Stored entries, keyed by each entry's `id`.
    memories: RwLock<HashMap<String, MemoryEntry>>,
}
15
16impl InMemoryStore {
17 pub fn new() -> Self {
19 Self {
20 memories: RwLock::new(HashMap::new()),
21 }
22 }
23
24 pub fn with_memories(memories: Vec<MemoryEntry>) -> Self {
26 let map: HashMap<String, MemoryEntry> =
27 memories.into_iter().map(|m| (m.id.clone(), m)).collect();
28 Self {
29 memories: RwLock::new(map),
30 }
31 }
32
33 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
35 if a.len() != b.len() || a.is_empty() {
36 return 0.0;
37 }
38
39 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
40 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
41 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
42
43 if norm_a == 0.0 || norm_b == 0.0 {
44 0.0
45 } else {
46 dot / (norm_a * norm_b)
47 }
48 }
49}
50
51impl Default for InMemoryStore {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57#[async_trait]
58impl MemoryStore for InMemoryStore {
59 async fn add(&self, entry: MemoryEntry) -> Result<String> {
60 let id = entry.id.clone();
61 let mut memories = self.memories.write().await;
62 memories.insert(id.clone(), entry);
63 Ok(id)
64 }
65
66 async fn search(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
67 let memories = self.memories.read().await;
68 let query_lower = query.to_lowercase();
69
70 let mut results: Vec<MemoryEntry> = memories
71 .values()
72 .filter(|m| m.content.to_lowercase().contains(&query_lower))
73 .cloned()
74 .collect();
75
76 results.sort_by(|a, b| {
78 b.relevance_score()
79 .partial_cmp(&a.relevance_score())
80 .unwrap_or(std::cmp::Ordering::Equal)
81 });
82
83 results.truncate(limit);
84 Ok(results)
85 }
86
87 async fn search_by_embedding(
88 &self,
89 embedding: &[f32],
90 limit: usize,
91 threshold: f32,
92 ) -> Result<Vec<MemoryEntry>> {
93 let memories = self.memories.read().await;
94
95 let mut scored: Vec<(f32, MemoryEntry)> = memories
96 .values()
97 .filter_map(|m| {
98 let emb = m.embedding.as_ref()?;
99 let score = Self::cosine_similarity(embedding, emb);
100 if score >= threshold {
101 Some((score, m.clone()))
102 } else {
103 None
104 }
105 })
106 .collect();
107
108 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
110
111 Ok(scored.into_iter().take(limit).map(|(_, m)| m).collect())
112 }
113
114 async fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
115 let memories = self.memories.read().await;
116 Ok(memories.get(id).cloned())
117 }
118
119 async fn update(&self, entry: MemoryEntry) -> Result<()> {
120 let mut memories = self.memories.write().await;
121 memories.insert(entry.id.clone(), entry);
122 Ok(())
123 }
124
125 async fn delete(&self, id: &str) -> Result<()> {
126 let mut memories = self.memories.write().await;
127 memories.remove(id);
128 Ok(())
129 }
130
131 async fn clear(&self) -> Result<()> {
132 let mut memories = self.memories.write().await;
133 memories.clear();
134 Ok(())
135 }
136
137 async fn count(&self) -> Result<usize> {
138 let memories = self.memories.read().await;
139 Ok(memories.len())
140 }
141
142 async fn ids(&self) -> Result<Vec<String>> {
143 let memories = self.memories.read().await;
144 Ok(memories.keys().cloned().collect())
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_add_and_get() {
        let store = InMemoryStore::new();

        let id = store.add(MemoryEntry::new("Test memory")).await.unwrap();
        let retrieved = store.get(&id).await.unwrap();

        assert_eq!(
            retrieved.expect("entry was just added").content,
            "Test memory"
        );
    }

    #[tokio::test]
    async fn test_get_missing_returns_none() {
        let store = InMemoryStore::new();
        assert!(store.get("does-not-exist").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_search() {
        let store = InMemoryStore::new();

        store
            .add(MemoryEntry::new("Rust is a programming language"))
            .await
            .unwrap();
        store
            .add(MemoryEntry::new("Python is also a programming language"))
            .await
            .unwrap();
        store
            .add(MemoryEntry::new("The weather is nice today"))
            .await
            .unwrap();

        let results = store.search("programming", 10).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_search_is_case_insensitive_and_respects_limit() {
        let store = InMemoryStore::new();
        store.add(MemoryEntry::new("Rust is fast")).await.unwrap();
        store.add(MemoryEntry::new("rust is safe")).await.unwrap();

        assert_eq!(store.search("RUST", 10).await.unwrap().len(), 2);
        assert_eq!(store.search("RUST", 1).await.unwrap().len(), 1);
        assert!(store.search("RUST", 0).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_search_by_embedding() {
        let store = InMemoryStore::new();

        let mut entry1 = MemoryEntry::new("Rust programming");
        entry1.embedding = Some(vec![1.0, 0.0, 0.0]);

        let mut entry2 = MemoryEntry::new("Python programming");
        entry2.embedding = Some(vec![0.0, 1.0, 0.0]);

        store.add(entry1).await.unwrap();
        store.add(entry2).await.unwrap();
        // An entry whose embedding is never set must not match.
        store.add(MemoryEntry::new("No embedding")).await.unwrap();

        let results = store
            .search_by_embedding(&[0.9, 0.1, 0.0], 10, 0.5)
            .await
            .unwrap();

        // Similarity with the "Python" vector is ~0.11, below the 0.5 threshold.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].content, "Rust programming");
    }

    #[tokio::test]
    async fn test_update_replaces_entry() {
        let store = InMemoryStore::new();
        let id = store.add(MemoryEntry::new("Before")).await.unwrap();

        let mut entry = store.get(&id).await.unwrap().expect("entry was just added");
        entry.content = "After".into();
        store.update(entry).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);
        assert_eq!(
            store.get(&id).await.unwrap().expect("entry still stored").content,
            "After"
        );
    }

    #[tokio::test]
    async fn test_delete() {
        let store = InMemoryStore::new();
        let entry = MemoryEntry::new("Test");

        let id = store.add(entry).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        store.delete(&id).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);

        // Deleting an id that is no longer present is not an error.
        store.delete(&id).await.unwrap();
    }

    #[tokio::test]
    async fn test_with_memories_ids_and_clear() {
        let first = MemoryEntry::new("first");
        let second = MemoryEntry::new("second");
        let expected = [first.id.clone(), second.id.clone()];

        let store = InMemoryStore::with_memories(vec![first, second]);
        assert_eq!(store.count().await.unwrap(), 2);

        let ids = store.ids().await.unwrap();
        assert_eq!(ids.len(), 2);
        assert!(expected.iter().all(|id| ids.contains(id)));

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
        assert!(store.ids().await.unwrap().is_empty());
    }
}