1use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::Mutex;
10
11use super::{MemoryFilter, MemoryFragment, MemoryId, MemorySource, MemorySubstrate};
12use crate::agent::result::AgentError;
13
14pub struct InMemorySubstrate {
16 fragments: Mutex<Vec<StoredFragment>>,
18 kv: Mutex<HashMap<String, serde_json::Value>>,
20 next_id: Mutex<u64>,
22}
23
24struct StoredFragment {
25 id: MemoryId,
26 agent_id: String,
27 content: String,
28 source: MemorySource,
29 created_at: chrono::DateTime<chrono::Utc>,
30}
31
32impl InMemorySubstrate {
33 pub fn new() -> Self {
35 Self {
36 fragments: Mutex::new(Vec::new()),
37 kv: Mutex::new(HashMap::new()),
38 next_id: Mutex::new(1),
39 }
40 }
41
42 fn gen_id(&self) -> String {
43 let mut id = self.next_id.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
44 let current = *id;
45 *id += 1;
46 format!("mem-{current}")
47 }
48
49 fn kv_key(agent_id: &str, key: &str) -> String {
50 format!("{agent_id}:{key}")
51 }
52}
53
54impl Default for InMemorySubstrate {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60fn lock<T>(mutex: &Mutex<T>) -> Result<std::sync::MutexGuard<'_, T>, AgentError> {
62 mutex.lock().map_err(|e| AgentError::Memory(format!("lock: {e}")))
63}
64
65fn matches_filter(f: &StoredFragment, filter: Option<&MemoryFilter>) -> bool {
67 let Some(filter) = filter else { return true };
68 if let Some(ref aid) = filter.agent_id {
69 if f.agent_id != *aid {
70 return false;
71 }
72 }
73 if let Some(ref src) = filter.source {
74 if f.source != *src {
75 return false;
76 }
77 }
78 if let Some(since) = filter.since {
79 if f.created_at < since {
80 return false;
81 }
82 }
83 true
84}
85
86fn score_fragment(f: &StoredFragment, query: &str) -> MemoryFragment {
88 let score = if f.content.is_empty() {
89 0.0
90 } else {
91 #[allow(clippy::cast_precision_loss)]
93 let s = (query.len() as f32 / f.content.len() as f32).min(1.0);
94 s
95 };
96 MemoryFragment {
97 id: f.id.clone(),
98 content: f.content.clone(),
99 source: f.source.clone(),
100 relevance_score: score,
101 created_at: f.created_at,
102 }
103}
104
105#[async_trait]
106impl MemorySubstrate for InMemorySubstrate {
107 async fn remember(
108 &self,
109 agent_id: &str,
110 content: &str,
111 source: MemorySource,
112 _embedding: Option<&[f32]>,
113 ) -> Result<MemoryId, AgentError> {
114 let id = self.gen_id();
115 let fragment = StoredFragment {
116 id: id.clone(),
117 agent_id: agent_id.to_string(),
118 content: content.to_string(),
119 source,
120 created_at: chrono::Utc::now(),
121 };
122 lock(&self.fragments)?.push(fragment);
123 Ok(id)
124 }
125
126 async fn recall(
127 &self,
128 query: &str,
129 limit: usize,
130 filter: Option<MemoryFilter>,
131 _query_embedding: Option<&[f32]>,
132 ) -> Result<Vec<MemoryFragment>, AgentError> {
133 let fragments = lock(&self.fragments)?;
134
135 let query_lower = query.to_lowercase();
136
137 let mut results: Vec<MemoryFragment> = fragments
138 .iter()
139 .filter(|f| {
140 matches_filter(f, filter.as_ref())
141 && f.content.to_lowercase().contains(&query_lower)
142 })
143 .map(|f| score_fragment(f, query))
144 .collect();
145
146 results.sort_by(|a, b| {
147 b.relevance_score.partial_cmp(&a.relevance_score).unwrap_or(std::cmp::Ordering::Equal)
148 });
149 results.truncate(limit);
150
151 Ok(results)
152 }
153
154 async fn set(
155 &self,
156 agent_id: &str,
157 key: &str,
158 value: serde_json::Value,
159 ) -> Result<(), AgentError> {
160 lock(&self.kv)?.insert(Self::kv_key(agent_id, key), value);
161 Ok(())
162 }
163
164 async fn get(
165 &self,
166 agent_id: &str,
167 key: &str,
168 ) -> Result<Option<serde_json::Value>, AgentError> {
169 let kv = lock(&self.kv)?;
170 Ok(kv.get(&Self::kv_key(agent_id, key)).cloned())
171 }
172
173 async fn forget(&self, id: MemoryId) -> Result<(), AgentError> {
174 lock(&self.fragments)?.retain(|f| f.id != id);
175 Ok(())
176 }
177}
178
#[cfg(test)]
mod tests {
    //! Unit tests for `InMemorySubstrate`: substring recall, filtering, scoring,
    //! the per-agent key/value store, and id generation.

    use super::*;

    #[tokio::test]
    async fn test_remember_and_recall() {
        let substrate = InMemorySubstrate::new();
        substrate
            .remember("agent1", "Rust is fast", MemorySource::User, None)
            .await
            .expect("remember failed");

        let results = substrate.recall("Rust", 10, None, None).await.expect("recall failed");
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("Rust is fast"));
    }

    #[tokio::test]
    async fn test_recall_case_insensitive() {
        let substrate = InMemorySubstrate::new();
        substrate
            .remember("a", "HELLO WORLD", MemorySource::System, None)
            .await
            .expect("remember failed");

        let results = substrate.recall("hello", 10, None, None).await.expect("recall failed");
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_recall_no_match() {
        let substrate = InMemorySubstrate::new();
        substrate.remember("a", "apples", MemorySource::User, None).await.expect("remember failed");

        let results = substrate.recall("oranges", 10, None, None).await.expect("recall failed");
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_recall_limit() {
        let substrate = InMemorySubstrate::new();
        for i in 0..10 {
            substrate
                .remember("a", &format!("item {i} with keyword"), MemorySource::Conversation, None)
                .await
                .expect("remember failed");
        }

        let results = substrate.recall("keyword", 3, None, None).await.expect("recall failed");
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_filter_by_agent_id() {
        let substrate = InMemorySubstrate::new();
        substrate
            .remember("agent1", "secret data", MemorySource::User, None)
            .await
            .expect("remember failed");
        substrate
            .remember("agent2", "other data", MemorySource::User, None)
            .await
            .expect("remember failed");

        let filter = MemoryFilter { agent_id: Some("agent1".into()), ..Default::default() };
        let results =
            substrate.recall("data", 10, Some(filter), None).await.expect("recall failed");
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("secret"));
    }

    #[tokio::test]
    async fn test_kv_set_get() {
        let substrate = InMemorySubstrate::new();
        substrate.set("a", "key1", serde_json::json!(42)).await.expect("set failed");

        let val = substrate.get("a", "key1").await.expect("get failed");
        assert_eq!(val, Some(serde_json::json!(42)));

        let missing = substrate.get("a", "nonexistent").await.expect("get failed");
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_kv_isolation() {
        let substrate = InMemorySubstrate::new();
        substrate.set("agent1", "key", serde_json::json!("one")).await.expect("set failed");
        substrate.set("agent2", "key", serde_json::json!("two")).await.expect("set failed");

        let v1 = substrate.get("agent1", "key").await.expect("get failed");
        let v2 = substrate.get("agent2", "key").await.expect("get failed");
        assert_eq!(v1, Some(serde_json::json!("one")));
        assert_eq!(v2, Some(serde_json::json!("two")));
    }

    #[tokio::test]
    async fn test_forget() {
        let substrate = InMemorySubstrate::new();
        let id = substrate
            .remember("a", "forget me", MemorySource::User, None)
            .await
            .expect("remember failed");

        substrate.forget(id).await.expect("forget failed");

        let results = substrate.recall("forget", 10, None, None).await.expect("recall failed");
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_unique_ids() {
        let substrate = InMemorySubstrate::new();
        let id1 = substrate
            .remember("a", "one", MemorySource::User, None)
            .await
            .expect("remember failed");
        let id2 = substrate
            .remember("a", "two", MemorySource::User, None)
            .await
            .expect("remember failed");
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_default() {
        let substrate = InMemorySubstrate::default();
        assert_eq!(substrate.gen_id(), "mem-1");
    }

    #[tokio::test]
    async fn test_filter_by_source() {
        let substrate = InMemorySubstrate::new();
        substrate
            .remember("a", "user msg", MemorySource::User, None)
            .await
            .expect("remember failed");
        substrate
            .remember("a", "system msg", MemorySource::System, None)
            .await
            .expect("remember failed");

        let filter = MemoryFilter { source: Some(MemorySource::System), ..Default::default() };
        let results = substrate.recall("msg", 10, Some(filter), None).await.expect("recall failed");
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("system"));
    }

    #[tokio::test]
    async fn test_filter_by_since() {
        let substrate = InMemorySubstrate::new();
        substrate
            .remember("a", "old memory", MemorySource::User, None)
            .await
            .expect("remember failed");

        // The `since` filter is inclusive, so the cutoff must be strictly later
        // than the old fragment's timestamp. A single `Utc::now()` call is not
        // enough on platforms with a coarse system clock, because it can return
        // the very instant the fragment was stamped with. Read the recorded
        // timestamp back and wait until the clock has moved past it.
        let old = substrate.recall("old", 1, None, None).await.expect("recall failed");
        assert_eq!(old.len(), 1);
        let old_created_at = old[0].created_at;
        let mut after_first = chrono::Utc::now();
        while after_first <= old_created_at {
            std::thread::yield_now();
            after_first = chrono::Utc::now();
        }

        substrate
            .remember("a", "new memory", MemorySource::User, None)
            .await
            .expect("remember failed");

        let filter = MemoryFilter { since: Some(after_first), ..Default::default() };
        let results =
            substrate.recall("memory", 10, Some(filter), None).await.expect("recall failed");
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("new"));
    }

    #[test]
    fn test_score_empty_content() {
        let f = StoredFragment {
            id: "mem-1".into(),
            agent_id: "a".into(),
            content: String::new(),
            source: MemorySource::User,
            created_at: chrono::Utc::now(),
        };
        let scored = score_fragment(&f, "query");
        assert_eq!(scored.relevance_score, 0.0);
    }

    #[test]
    fn test_score_long_content() {
        let f = StoredFragment {
            id: "mem-1".into(),
            agent_id: "a".into(),
            content: "a very long content string that is much longer than the query".into(),
            source: MemorySource::User,
            created_at: chrono::Utc::now(),
        };
        let scored = score_fragment(&f, "short");
        assert!(scored.relevance_score > 0.0);
        assert!(scored.relevance_score < 1.0);
    }
}