Skip to main content

batuta/agent/memory/
in_memory.rs

1//! In-memory substrate — ephemeral, substring-matching memory.
2//!
3//! Phase 1 implementation. Uses `HashMap` for key-value and `Vec` for
4//! fragment storage. Recall uses case-insensitive substring matching
5//! (NOT semantic similarity — that requires `TruenoMemory` in Phase 2).
6
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::Mutex;
10
11use super::{MemoryFilter, MemoryFragment, MemoryId, MemorySource, MemorySubstrate};
12use crate::agent::result::AgentError;
13
14/// In-memory substrate (ephemeral, no persistence).
15pub struct InMemorySubstrate {
16    /// Fragment storage.
17    fragments: Mutex<Vec<StoredFragment>>,
18    /// Key-value storage.
19    kv: Mutex<HashMap<String, serde_json::Value>>,
20    /// Counter for generating unique IDs.
21    next_id: Mutex<u64>,
22}
23
24struct StoredFragment {
25    id: MemoryId,
26    agent_id: String,
27    content: String,
28    source: MemorySource,
29    created_at: chrono::DateTime<chrono::Utc>,
30}
31
32impl InMemorySubstrate {
33    /// Create an empty in-memory substrate.
34    pub fn new() -> Self {
35        Self {
36            fragments: Mutex::new(Vec::new()),
37            kv: Mutex::new(HashMap::new()),
38            next_id: Mutex::new(1),
39        }
40    }
41
42    fn gen_id(&self) -> String {
43        let mut id = self.next_id.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
44        let current = *id;
45        *id += 1;
46        format!("mem-{current}")
47    }
48
49    fn kv_key(agent_id: &str, key: &str) -> String {
50        format!("{agent_id}:{key}")
51    }
52}
53
54impl Default for InMemorySubstrate {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60/// Acquire a mutex lock, mapping poison errors to `AgentError`.
61fn lock<T>(mutex: &Mutex<T>) -> Result<std::sync::MutexGuard<'_, T>, AgentError> {
62    mutex.lock().map_err(|e| AgentError::Memory(format!("lock: {e}")))
63}
64
65/// Check if a stored fragment passes the optional filter.
66fn matches_filter(f: &StoredFragment, filter: Option<&MemoryFilter>) -> bool {
67    let Some(filter) = filter else { return true };
68    if let Some(ref aid) = filter.agent_id {
69        if f.agent_id != *aid {
70            return false;
71        }
72    }
73    if let Some(ref src) = filter.source {
74        if f.source != *src {
75            return false;
76        }
77    }
78    if let Some(since) = filter.since {
79        if f.created_at < since {
80            return false;
81        }
82    }
83    true
84}
85
86/// Score a stored fragment for relevance based on query length ratio.
87fn score_fragment(f: &StoredFragment, query: &str) -> MemoryFragment {
88    let score = if f.content.is_empty() {
89        0.0
90    } else {
91        // Precision loss acceptable: string lengths fit easily in f32
92        #[allow(clippy::cast_precision_loss)]
93        let s = (query.len() as f32 / f.content.len() as f32).min(1.0);
94        s
95    };
96    MemoryFragment {
97        id: f.id.clone(),
98        content: f.content.clone(),
99        source: f.source.clone(),
100        relevance_score: score,
101        created_at: f.created_at,
102    }
103}
104
105#[async_trait]
106impl MemorySubstrate for InMemorySubstrate {
107    async fn remember(
108        &self,
109        agent_id: &str,
110        content: &str,
111        source: MemorySource,
112        _embedding: Option<&[f32]>,
113    ) -> Result<MemoryId, AgentError> {
114        let id = self.gen_id();
115        let fragment = StoredFragment {
116            id: id.clone(),
117            agent_id: agent_id.to_string(),
118            content: content.to_string(),
119            source,
120            created_at: chrono::Utc::now(),
121        };
122        lock(&self.fragments)?.push(fragment);
123        Ok(id)
124    }
125
126    async fn recall(
127        &self,
128        query: &str,
129        limit: usize,
130        filter: Option<MemoryFilter>,
131        _query_embedding: Option<&[f32]>,
132    ) -> Result<Vec<MemoryFragment>, AgentError> {
133        let fragments = lock(&self.fragments)?;
134
135        let query_lower = query.to_lowercase();
136
137        let mut results: Vec<MemoryFragment> = fragments
138            .iter()
139            .filter(|f| {
140                matches_filter(f, filter.as_ref())
141                    && f.content.to_lowercase().contains(&query_lower)
142            })
143            .map(|f| score_fragment(f, query))
144            .collect();
145
146        results.sort_by(|a, b| {
147            b.relevance_score.partial_cmp(&a.relevance_score).unwrap_or(std::cmp::Ordering::Equal)
148        });
149        results.truncate(limit);
150
151        Ok(results)
152    }
153
154    async fn set(
155        &self,
156        agent_id: &str,
157        key: &str,
158        value: serde_json::Value,
159    ) -> Result<(), AgentError> {
160        lock(&self.kv)?.insert(Self::kv_key(agent_id, key), value);
161        Ok(())
162    }
163
164    async fn get(
165        &self,
166        agent_id: &str,
167        key: &str,
168    ) -> Result<Option<serde_json::Value>, AgentError> {
169        let kv = lock(&self.kv)?;
170        Ok(kv.get(&Self::kv_key(agent_id, key)).cloned())
171    }
172
173    async fn forget(&self, id: MemoryId) -> Result<(), AgentError> {
174        lock(&self.fragments)?.retain(|f| f.id != id);
175        Ok(())
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `User` fragment for agent `"a"` with an explicit creation time.
    ///
    /// Lets timestamp-sensitive tests avoid depending on wall-clock resolution.
    fn fragment_at(
        id: &str,
        content: &str,
        created_at: chrono::DateTime<chrono::Utc>,
    ) -> StoredFragment {
        StoredFragment {
            id: id.to_string(),
            agent_id: "a".into(),
            content: content.to_string(),
            source: MemorySource::User,
            created_at,
        }
    }

    #[tokio::test]
    async fn test_remember_and_recall() {
        let substrate = InMemorySubstrate::new();
        substrate
            .remember("agent1", "Rust is fast", MemorySource::User, None)
            .await
            .expect("remember failed");

        let results = substrate.recall("Rust", 10, None, None).await.expect("recall failed");
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("Rust is fast"));
    }

    #[tokio::test]
    async fn test_recall_case_insensitive() {
        let substrate = InMemorySubstrate::new();
        substrate
            .remember("a", "HELLO WORLD", MemorySource::System, None)
            .await
            .expect("remember failed");

        let results = substrate.recall("hello", 10, None, None).await.expect("recall failed");
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_recall_no_match() {
        let substrate = InMemorySubstrate::new();
        substrate.remember("a", "apples", MemorySource::User, None).await.expect("remember failed");

        let results = substrate.recall("oranges", 10, None, None).await.expect("recall failed");
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_recall_limit() {
        let substrate = InMemorySubstrate::new();
        for i in 0..10 {
            substrate
                .remember("a", &format!("item {i} with keyword"), MemorySource::Conversation, None)
                .await
                .expect("remember failed");
        }

        let results = substrate.recall("keyword", 3, None, None).await.expect("recall failed");
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_filter_by_agent_id() {
        let substrate = InMemorySubstrate::new();
        substrate
            .remember("agent1", "secret data", MemorySource::User, None)
            .await
            .expect("remember failed");
        substrate
            .remember("agent2", "other data", MemorySource::User, None)
            .await
            .expect("remember failed");

        let filter = MemoryFilter { agent_id: Some("agent1".into()), ..Default::default() };
        let results =
            substrate.recall("data", 10, Some(filter), None).await.expect("recall failed");
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("secret"));
    }

    #[tokio::test]
    async fn test_kv_set_get() {
        let substrate = InMemorySubstrate::new();
        substrate.set("a", "key1", serde_json::json!(42)).await.expect("set failed");

        let val = substrate.get("a", "key1").await.expect("get failed");
        assert_eq!(val, Some(serde_json::json!(42)));

        let missing = substrate.get("a", "nonexistent").await.expect("get failed");
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_kv_isolation() {
        let substrate = InMemorySubstrate::new();
        substrate.set("agent1", "key", serde_json::json!("one")).await.expect("set failed");
        substrate.set("agent2", "key", serde_json::json!("two")).await.expect("set failed");

        let v1 = substrate.get("agent1", "key").await.expect("get failed");
        let v2 = substrate.get("agent2", "key").await.expect("get failed");
        assert_eq!(v1, Some(serde_json::json!("one")));
        assert_eq!(v2, Some(serde_json::json!("two")));
    }

    #[tokio::test]
    async fn test_forget() {
        let substrate = InMemorySubstrate::new();
        let id = substrate
            .remember("a", "forget me", MemorySource::User, None)
            .await
            .expect("remember failed");

        substrate.forget(id).await.expect("forget failed");

        let results = substrate.recall("forget", 10, None, None).await.expect("recall failed");
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_unique_ids() {
        let substrate = InMemorySubstrate::new();
        let id1 = substrate
            .remember("a", "one", MemorySource::User, None)
            .await
            .expect("remember failed");
        let id2 = substrate
            .remember("a", "two", MemorySource::User, None)
            .await
            .expect("remember failed");
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_default() {
        let substrate = InMemorySubstrate::default();
        assert_eq!(substrate.gen_id(), "mem-1");
    }

    #[tokio::test]
    async fn test_filter_by_source() {
        let substrate = InMemorySubstrate::new();
        substrate
            .remember("a", "user msg", MemorySource::User, None)
            .await
            .expect("remember failed");
        substrate
            .remember("a", "system msg", MemorySource::System, None)
            .await
            .expect("remember failed");

        let filter = MemoryFilter { source: Some(MemorySource::System), ..Default::default() };
        let results = substrate.recall("msg", 10, Some(filter), None).await.expect("recall failed");
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("system"));
    }

    #[tokio::test]
    async fn test_filter_by_since() {
        let substrate = InMemorySubstrate::new();
        let cutoff = chrono::Utc::now();

        // Insert fragments with explicit timestamps on either side of the
        // cut-off. Reading the wall clock between two `remember` calls can
        // yield identical timestamps on coarse clocks, and because the filter
        // is inclusive the "old" fragment would then leak into the results.
        {
            let mut fragments = substrate.fragments.lock().expect("lock failed");
            fragments.push(fragment_at(
                "mem-old",
                "old memory",
                cutoff - chrono::Duration::seconds(60),
            ));
            fragments.push(fragment_at(
                "mem-new",
                "new memory",
                cutoff + chrono::Duration::seconds(60),
            ));
        }

        let filter = MemoryFilter { since: Some(cutoff), ..Default::default() };
        let results =
            substrate.recall("memory", 10, Some(filter), None).await.expect("recall failed");
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("new"));
    }

    #[tokio::test]
    async fn test_filter_since_is_inclusive() {
        let substrate = InMemorySubstrate::new();
        let cutoff = chrono::Utc::now();

        // A fragment created exactly at the cut-off must still be returned.
        {
            let mut fragments = substrate.fragments.lock().expect("lock failed");
            fragments.push(fragment_at("mem-edge", "edge memory", cutoff));
        }

        let filter = MemoryFilter { since: Some(cutoff), ..Default::default() };
        let results =
            substrate.recall("memory", 10, Some(filter), None).await.expect("recall failed");
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("edge"));
    }

    #[test]
    fn test_score_empty_content() {
        let f = StoredFragment {
            id: "mem-1".into(),
            agent_id: "a".into(),
            content: String::new(),
            source: MemorySource::User,
            created_at: chrono::Utc::now(),
        };
        let scored = score_fragment(&f, "query");
        assert_eq!(scored.relevance_score, 0.0);
    }

    #[test]
    fn test_score_long_content() {
        let f = StoredFragment {
            id: "mem-1".into(),
            agent_id: "a".into(),
            content: "a very long content string that is much longer than the query".into(),
            source: MemorySource::User,
            created_at: chrono::Utc::now(),
        };
        let scored = score_fragment(&f, "short");
        assert!(scored.relevance_score > 0.0);
        assert!(scored.relevance_score < 1.0);
    }
}