Skip to main content

aagt_core/agent/
memory.rs

1//! Memory system for agents
2//!
3//! Provides short-term (conversation) and long-term (persistent) memory.
4
5use std::collections::VecDeque;
6use std::sync::Arc;
7
8use dashmap::DashMap;
9
10use crate::agent::message::Message;
11use std::collections::HashMap;
12use std::path::PathBuf;
13use std::sync::Weak;
14use async_trait::async_trait;
15
16use crate::agent::scheduler::Scheduler;
17
/// Trait for memory implementations
///
/// Implementations store conversation history keyed by `user_id` plus an
/// optional `agent_id`. Most methods have no-op / empty defaults so that a
/// minimal backend only needs `store`, `retrieve`, `clear`, and `undo`.
#[async_trait]
pub trait Memory: Send + Sync {
    /// Store a message
    async fn store(&self, user_id: &str, agent_id: Option<&str>, message: Message) -> crate::error::Result<()>;

    /// Store multiple messages efficiently
    ///
    /// Default stores one message at a time and stops at the first error;
    /// backends with true batch support should override this.
    async fn store_batch(&self, user_id: &str, agent_id: Option<&str>, messages: Vec<Message>) -> crate::error::Result<()> {
        for msg in messages {
            self.store(user_id, agent_id, msg).await?;
        }
        Ok(())
    }

    /// Retrieve recent messages
    async fn retrieve(&self, user_id: &str, agent_id: Option<&str>, limit: usize) -> Vec<Message>;

    /// Search the memory for relevant content
    ///
    /// Default returns no results; backends without search remain usable.
    async fn search(&self, user_id: &str, agent_id: Option<&str>, query: &str, limit: usize) -> crate::error::Result<Vec<crate::knowledge::rag::Document>> {
        let _ = (user_id, agent_id, query, limit);
        Ok(Vec::new())
    }

    /// Store a specific piece of knowledge (not just a message)
    ///
    /// Default is a silent no-op — backends that persist knowledge override.
    async fn store_knowledge(&self, user_id: &str, agent_id: Option<&str>, title: &str, content: &str, collection: &str) -> crate::error::Result<()> {
        let _ = (user_id, agent_id, title, content, collection);
        Ok(())
    }

    /// Clear memory for a user
    async fn clear(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<()>;

    /// Undo last message
    ///
    /// Returns the removed message, or `None` if there was nothing to undo.
    async fn undo(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<Option<Message>>;

    /// Update summary for a piece of knowledge (no-op by default)
    async fn update_summary(&self, collection: &str, path: &str, summary: &str) -> crate::error::Result<()> {
        let _ = (collection, path, summary);
        Ok(())
    }

    /// Link a scheduler for background tasks
    ///
    /// `Weak` so the memory does not keep the scheduler alive. No-op default.
    fn link_scheduler(&self, _scheduler: Weak<Scheduler>) {}

    /// Fetch a full document by path (default: not found)
    async fn fetch_document(&self, collection: &str, path: &str) -> crate::error::Result<Option<crate::knowledge::rag::Document>> {
        let _ = (collection, path);
        Ok(None)
    }

    /// Store an agent session state (default: silently discarded)
    async fn store_session(&self, _session: crate::agent::session::AgentSession) -> crate::error::Result<()> {
        Ok(())
    }

    /// Retrieve an agent session state (default: not found)
    async fn retrieve_session(&self, _session_id: &str) -> crate::error::Result<Option<crate::agent::session::AgentSession>> {
        Ok(None)
    }
}
78
/// Short-term memory - stores recent conversation history
/// Uses a fixed-size ring buffer per user for memory efficiency
/// Persists to disk (JSON) to allow restarts without losing context.
pub struct ShortTermMemory {
    /// Max messages to keep per user
    max_messages: usize,
    /// Max active users/contexts to keep in memory (DoS protection)
    max_users: usize,
    /// Storage: composite_key -> message ring buffer
    /// (composite key is "user_id" or "user_id:agent_id" — see `key()`)
    store: DashMap<String, VecDeque<Message>>,
    /// Track last access time for cleanup and LRU eviction
    /// (kept in lockstep with `store`: inserted/removed together)
    last_access: DashMap<String, std::time::Instant>,
    /// Persistence path (JSON snapshot of `store`; access times not persisted)
    path: PathBuf,
}
94
95impl ShortTermMemory {
96    /// Create with custom capacity and persistence path
97    pub async fn new(max_messages: usize, max_users: usize, path: impl Into<PathBuf>) -> Self {
98        let path = path.into();
99        let store = DashMap::new();
100        let last_access = DashMap::new();
101        
102        let mem = Self {
103            max_messages,
104            max_users,
105            store,
106            last_access,
107            path,
108        };
109        
110        // Try to load existing state
111        if let Err(e) = mem.load().await {
112            tracing::warn!("Failed to load short-term memory from {:?}: {}", mem.path, e);
113        }
114        
115        mem
116    }
117
118    /// Create with default capacity (100 messages per user, 1000 active users)
119    pub async fn default_capacity() -> Self {
120        Self::new(100, 1000, "data/short_term_memory.json").await
121    }
122
123    /// Load state from disk
124    async fn load(&self) -> crate::error::Result<()> {
125        if !self.path.exists() {
126            return Ok(());
127        }
128        
129        let content = tokio::fs::read_to_string(&self.path).await
130            .map_err(|e| crate::error::Error::Internal(format!("Failed to read memory file: {}", e)))?;
131            
132        if content.trim().is_empty() {
133            return Ok(());
134        }
135
136        let data: HashMap<String, VecDeque<Message>> = serde_json::from_str(&content)
137            .map_err(|e| crate::error::Error::Internal(format!("Failed to parse memory file: {}", e)))?;
138            
139        self.store.clear();
140        for (k, v) in data {
141            self.store.insert(k.clone(), v);
142            self.last_access.insert(k, std::time::Instant::now());
143        }
144        
145        tracing::info!("Loaded short-term memory for {} users", self.store.len());
146        Ok(())
147    }
148
149    /// Save state to disk
150    async fn save(&self) -> crate::error::Result<()> {
151        if let Some(parent) = self.path.parent() {
152            tokio::fs::create_dir_all(parent).await.ok();
153        }
154        
155        // Convert DashMap to HashMap for serialization
156        let data: HashMap<_, _> = self.store.iter().map(|r| (r.key().clone(), r.value().clone())).collect();
157        
158        let json = serde_json::to_string_pretty(&data)
159             .map_err(|e| crate::error::Error::Internal(format!("Failed to serialize memory: {}", e)))?;
160             
161        // Atomic save: write to tmp then rename
162        let tmp_path = self.path.with_extension("tmp");
163        tokio::fs::write(&tmp_path, json).await
164             .map_err(|e| crate::error::Error::Internal(format!("Failed to write temporary memory file: {}", e)))?;
165             
166        tokio::fs::rename(tmp_path, &self.path).await
167             .map_err(|e| crate::error::Error::Internal(format!("Failed to rename memory file: {}", e)))?;
168             
169        Ok(())
170    }
171
172    /// Get current message count for a user/agent pair
173    pub fn message_count(&self, user_id: &str, agent_id: Option<&str>) -> usize {
174        let key = self.key(user_id, agent_id);
175        self.store.get(&key).map(|v| v.len()).unwrap_or(0)
176    }
177    
178    /// Generate composite key
179    fn key(&self, user_id: &str, agent_id: Option<&str>) -> String {
180        if let Some(agent) = agent_id {
181            format!("{}:{}", user_id, agent)
182        } else {
183            user_id.to_string()
184        }
185    }
186    
187    /// Prune inactive users (older than duration) - Useful for manual cleanup
188    pub fn prune_inactive(&self, duration: std::time::Duration) {
189        let now = std::time::Instant::now();
190        // DashMap retain is efficient
191        self.last_access.retain(|key, last_time| {
192            let keep = now.duration_since(*last_time) < duration;
193            if !keep {
194                self.store.remove(key);
195            }
196            keep
197        });
198    }
199
200    /// Check and enforce total user capacity (LRU eviction)
201    fn enforce_user_capacity(&self) {
202        if self.store.len() < self.max_users {
203            return;
204        }
205
206        let mut oldest_key = None;
207        let mut oldest_time = std::time::Instant::now();
208
209        for r in self.last_access.iter() {
210            if *r.value() < oldest_time {
211                oldest_time = *r.value();
212                oldest_key = Some(r.key().clone());
213            }
214        }
215
216        if let Some(key) = oldest_key {
217            self.store.remove(&key);
218            self.last_access.remove(&key);
219        }
220    }
221
222    /// Pop the oldest N messages for a user
223    pub async fn pop_oldest(&self, user_id: &str, agent_id: Option<&str>, count: usize) -> Vec<Message> {
224        let key = self.key(user_id, agent_id);
225        let mut popped = Vec::new();
226        
227        if let Some(mut entry) = self.store.get_mut(&key) {
228             for _ in 0..count {
229                 if let Some(msg) = entry.pop_front() {
230                     popped.push(msg);
231                 } else {
232                     break;
233                 }
234             }
235        }
236        
237        if !popped.is_empty() {
238             // Save change immediately
239             let _ = self.save().await;
240        }
241        
242        popped
243    }
244}
245
246#[async_trait]
247impl Memory for ShortTermMemory {
248    async fn store(&self, user_id: &str, agent_id: Option<&str>, message: Message) -> crate::error::Result<()> {
249        let key = self.key(user_id, agent_id);
250        
251        // Enforce capacity before inserting new user
252        if !self.store.contains_key(&key) {
253             self.enforce_user_capacity();
254        }
255
256        {
257            let mut entry = self.store.entry(key.clone()).or_default();
258
259            // Ring buffer behavior: remove oldest if at capacity
260            // NOTE: With Tiered Storage, MemoryManager should handle archiving BEFORE this limit is hit commonly.
261            // But as a safety net, we still keep the hard limit.
262            if entry.len() >= self.max_messages {
263                entry.pop_front();
264            }
265            entry.push_back(message);
266        } // Lock on DashMap bucket dropped here
267        
268        // Update access time
269        self.last_access.insert(key, std::time::Instant::now());
270        
271        // Save immediately for safety (Async I/O)
272        // With Tiered storage, this file stays small (KB), so atomic write is fast enough.
273        if let Err(e) = self.save().await {
274            tracing::error!("Failed to persist short-term memory: {}", e);
275        }
276
277        Ok(())
278    }
279
280
281
282    async fn retrieve(&self, user_id: &str, agent_id: Option<&str>, limit: usize) -> Vec<Message> {
283        let key = self.key(user_id, agent_id);
284        self.store
285            .get(&key)
286            .map(|v| {
287                // Update access time on retrieval too
288                self.last_access.insert(key, std::time::Instant::now());
289                
290                let skip = v.len().saturating_sub(limit);
291                v.iter().skip(skip).cloned().collect()
292            })
293            .unwrap_or_default()
294    }
295
296    async fn store_knowledge(&self, user_id: &str, agent_id: Option<&str>, title: &str, content: &str, collection: &str) -> crate::error::Result<()> {
297        let text = format!("[{}] {}: {}", collection, title, content);
298        self.store(user_id, agent_id, Message::assistant(text)).await
299    }
300
301    async fn clear(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<()> {
302        let key = self.key(user_id, agent_id);
303        self.store.remove(&key);
304        self.last_access.remove(&key);
305        
306        self.save().await
307    }
308
309    async fn undo(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<Option<Message>> {
310        let key = self.key(user_id, agent_id);
311        let msg = {
312            let mut entry = self.store.entry(key.clone()).or_default();
313            entry.pop_back()
314        };
315        
316        if msg.is_some() {
317            self.save().await?;
318        }
319        
320        Ok(msg)
321    }
322
323    async fn search(&self, user_id: &str, agent_id: Option<&str>, query: &str, limit: usize) -> crate::error::Result<Vec<crate::knowledge::rag::Document>> {
324        let query_lower = query.to_lowercase();
325        let messages = self.retrieve(user_id, agent_id, 1000).await; // Search through all STM for this user
326        
327        let mut results = Vec::new();
328        for (i, msg) in messages.iter().enumerate() {
329            let content = msg.text();
330            if content.to_lowercase().contains(&query_lower) {
331                results.push(crate::knowledge::rag::Document {
332                    id: format!("stm_{}_{}", self.key(user_id, agent_id), i),
333                    title: format!("Recent conversation ({})", msg.role.as_str()),
334                    content: content.to_string(),
335                    summary: None,
336                    collection: None,
337                    path: None,
338                    metadata: HashMap::new(),
339                    score: 0.9, // STM matches are highly relevant but given a fixed sub-1.0 score to prioritize exact LTM matches if needed
340                });
341            }
342            if results.len() >= limit {
343                break;
344            }
345        }
346        
347        Ok(results)
348    }
349}
350
351/// Simple in-memory storage for testing or fast ephemeral context
352pub struct InMemoryMemory {
353    store: DashMap<String, VecDeque<Message>>,
354}
355
356impl InMemoryMemory {
357    /// Create a new in-memory storage
358    pub fn new() -> Self {
359        Self { store: DashMap::new() }
360    }
361    
362    fn key(&self, user_id: &str, agent_id: Option<&str>) -> String {
363        if let Some(agent) = agent_id {
364            format!("{}:{}", user_id, agent)
365        } else {
366            user_id.to_string()
367        }
368    }
369}
370
371#[async_trait]
372impl Memory for InMemoryMemory {
373    async fn store(&self, user_id: &str, agent_id: Option<&str>, message: Message) -> crate::error::Result<()> {
374        let key = self.key(user_id, agent_id);
375        self.store.entry(key).or_default().push_back(message);
376        Ok(())
377    }
378
379    async fn retrieve(&self, user_id: &str, agent_id: Option<&str>, limit: usize) -> Vec<Message> {
380        let key = self.key(user_id, agent_id);
381        self.store.get(&key).map(|v| {
382            let skip = v.len().saturating_sub(limit);
383            v.iter().skip(skip).cloned().collect()
384        }).unwrap_or_default()
385    }
386
387    async fn clear(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<()> {
388        let key = self.key(user_id, agent_id);
389        self.store.remove(&key);
390        Ok(())
391    }
392
393    async fn undo(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<Option<Message>> {
394        let key = self.key(user_id, agent_id);
395        Ok(self.store.get_mut(&key).and_then(|mut v| v.pop_back()))
396    }
397}
398
/// Combined memory manager for tiered storage
///
/// Holds two `Memory` backends behind trait objects and routes operations
/// between them (see the impl blocks for the routing policy).
pub struct MemoryManager {
    /// Hot Storage Layer (e.g. In-memory or fast local cache)
    pub hot_tier: Arc<dyn Memory>,
    /// Cold Storage Layer (e.g. SQLite, Vector DB)
    pub cold_tier: Arc<dyn Memory>,
}
406
407impl MemoryManager {
408    /// Create a new MemoryManager with specific backends
409    pub fn new(hot_tier: Arc<dyn Memory>, cold_tier: Arc<dyn Memory>) -> Self {
410        Self { hot_tier, cold_tier }
411    }
412
413    /// Tiered Storage Store
414    /// Stores in Hot Tier, then auto-archives to Cold Tier if capacity exceeded
415    pub async fn store(&self, user_id: &str, agent_id: Option<&str>, message: Message) -> crate::error::Result<()> {
416        // 1. Write to Hot Storage - Fast
417        self.hot_tier.store(user_id, agent_id, message).await?;
418        
419        // 2. Archive older messages if needed
420        // Note: The specific logic for "when to archive" could be moved to a TieringPolicy
421        // For now, we use a simple heuristic if the Hot Tier supports counting.
422        // Since we are now using dyn Memory, we might need to add a 'count' method to the trait 
423        // if we want generic tiering logic here, or let the Hot Tier handle its own overflow.
424        
425        Ok(())
426    }
427    
428    /// Unified Retrieve
429    /// Fetches from Hot + Cold seamlessly
430    pub async fn retrieve_unified(&self, user_id: &str, agent_id: Option<&str>, limit: usize) -> Vec<Message> {
431        let mut messages = self.hot_tier.retrieve(user_id, agent_id, limit).await;
432        
433        if messages.len() < limit {
434             let needed = limit - messages.len();
435             let cold_messages = self.cold_tier.retrieve(user_id, agent_id, needed).await;
436             
437             let mut combined = cold_messages;
438             combined.extend(messages);
439             messages = combined;
440        }
441        
442        messages
443    }
444
445    /// Union Search - searches both Hot and Cold tiers
446    pub async fn search_unified(&self, user_id: &str, agent_id: Option<&str>, query: &str, limit: usize) -> crate::error::Result<Vec<crate::knowledge::rag::Document>> {
447        let hot_results = self.hot_tier.search(user_id, agent_id, query, limit).await?;
448        let cold_results = self.cold_tier.search(user_id, agent_id, query, limit).await?;
449        
450        let mut combined = hot_results;
451        for cold_res in cold_results {
452            if !combined.iter().any(|r| r.content == cold_res.content) {
453                combined.push(cold_res);
454            }
455        }
456        
457        combined.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
458        combined.truncate(limit);
459        
460        Ok(combined)
461    }
462
463    /// Undo last message
464    pub async fn undo(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<Option<Message>> {
465        let hot_msg = self.hot_tier.undo(user_id, agent_id).await?;
466        let _ = self.cold_tier.undo(user_id, agent_id).await?;
467        Ok(hot_msg)
468    }
469}
470
471#[async_trait]
472impl Memory for MemoryManager {
473    async fn store(&self, user_id: &str, agent_id: Option<&str>, message: Message) -> crate::error::Result<()> {
474        self.store(user_id, agent_id, message).await
475    }
476
477    async fn retrieve(&self, user_id: &str, agent_id: Option<&str>, limit: usize) -> Vec<Message> {
478        self.retrieve_unified(user_id, agent_id, limit).await
479    }
480
481    async fn search(&self, user_id: &str, agent_id: Option<&str>, query: &str, limit: usize) -> crate::error::Result<Vec<crate::knowledge::rag::Document>> {
482        self.search_unified(user_id, agent_id, query, limit).await
483    }
484
485    async fn store_knowledge(&self, user_id: &str, agent_id: Option<&str>, title: &str, content: &str, collection: &str) -> crate::error::Result<()> {
486        // Knowledge usually goes directly to Cold tier for permanence
487        self.cold_tier.store_knowledge(user_id, agent_id, title, content, collection).await
488    }
489
490    async fn clear(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<()> {
491        self.hot_tier.clear(user_id, agent_id).await?;
492        self.cold_tier.clear(user_id, agent_id).await?;
493        Ok(())
494    }
495
496    async fn undo(&self, user_id: &str, agent_id: Option<&str>) -> crate::error::Result<Option<Message>> {
497        self.undo(user_id, agent_id).await
498    }
499
500    async fn store_session(&self, session: crate::agent::session::AgentSession) -> crate::error::Result<()> {
501        self.cold_tier.store_session(session).await
502    }
503
504    async fn retrieve_session(&self, session_id: &str) -> crate::error::Result<Option<crate::agent::session::AgentSession>> {
505        self.cold_tier.retrieve_session(session_id).await
506    }
507
508    async fn fetch_document(&self, collection: &str, path: &str) -> crate::error::Result<Option<crate::knowledge::rag::Document>> {
509        // Cold tier usually holds the full documents
510        self.cold_tier.fetch_document(collection, path).await
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_short_term_memory() {
        // Unique path under the OS temp dir: a fixed relative path like
        // "test_stm.json" would be loaded by `ShortTermMemory::load()` if a
        // stale file survived an earlier failed run, and would collide when
        // tests run concurrently in the same working directory.
        let path = std::env::temp_dir().join(format!("aagt_test_stm_{}.json", std::process::id()));
        let _ = std::fs::remove_file(&path);

        let memory = ShortTermMemory::new(3, 10, &path).await;

        memory.store("user1", None, Message::user("Hello")).await.unwrap();
        memory.store("user1", None, Message::assistant("Hi there")).await.unwrap();
        memory.store("user1", None, Message::user("How are you?")).await.unwrap();
        // Capacity is 3, so this should evict "Hello"
        memory.store("user1", None, Message::assistant("I'm good!")).await.unwrap();

        let messages = memory.retrieve("user1", None, 10).await;
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].text(), "Hi there");

        let _ = std::fs::remove_file(&path);
    }
}
534}