Skip to main content

a3s_code_core/
memory.rs

1//! Memory and learning system for the agent.
2//!
3//! Core types (`MemoryStore`, `MemoryItem`, `MemoryType`, `RelevanceConfig`,
4//! `FileMemoryStore`, `InMemoryStore`) live in `a3s-memory`.
5//!
6//! This module owns `MemoryConfig`, `MemoryStats`, `AgentMemory` (three-tier
7//! session memory), and `MemoryContextProvider` (context injection bridge).
8
9use a3s_memory::{MemoryItem, MemoryStore, MemoryType, PrunePolicy, RelevanceConfig};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::VecDeque;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16// ============================================================================
17// Configuration
18// ============================================================================
19
/// Configuration for the agent memory system (three-tier: working/short-term/long-term).
///
/// Serialized with camelCase keys (`relevance`, `maxShortTerm`, `maxWorking`,
/// `prunePolicy`, `pruneIntervalSecs`). Every field carries a serde default,
/// so any subset of keys may be omitted on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MemoryConfig {
    /// Relevance scoring parameters (weights and decay) used to rank
    /// working-memory items when working memory overflows.
    #[serde(default)]
    pub relevance: RelevanceConfig,
    /// Maximum short-term memory items (default: 100). When exceeded, the
    /// oldest short-term item is evicted; long-term storage is unaffected.
    #[serde(default = "MemoryConfig::default_max_short_term")]
    pub max_short_term: usize,
    /// Maximum working memory items (default: 10). When exceeded, the
    /// lowest-scoring items are dropped.
    #[serde(default = "MemoryConfig::default_max_working")]
    pub max_working: usize,
    /// Automatic pruning policy for long-term storage. `None` disables background pruning.
    #[serde(default)]
    pub prune_policy: Option<PrunePolicy>,
    /// How often the background pruning task runs, in seconds (default: 3600).
    /// Only consulted when `prune_policy` is `Some`; should be greater than zero.
    #[serde(default = "MemoryConfig::default_prune_interval_secs")]
    pub prune_interval_secs: u64,
}
40
41impl MemoryConfig {
42    fn default_max_short_term() -> usize {
43        100
44    }
45    fn default_max_working() -> usize {
46        10
47    }
48    fn default_prune_interval_secs() -> u64 {
49        3600
50    }
51}
52
53impl Default for MemoryConfig {
54    fn default() -> Self {
55        Self {
56            relevance: RelevanceConfig::default(),
57            max_short_term: 100,
58            max_working: 10,
59            prune_policy: None,
60            prune_interval_secs: 3600,
61        }
62    }
63}
64
65// ============================================================================
66// Memory Stats
67// ============================================================================
68
/// Statistics for the three-tier agent memory system, as produced by
/// `AgentMemory::stats`.
///
/// NOTE(review): unlike `MemoryConfig`, this type has no
/// `#[serde(rename_all = "camelCase")]`, so it serializes with snake_case keys.
/// Confirm that this is what consumers expect.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Item count reported by the long-term store's `count()`.
    pub long_term_count: usize,
    /// Number of items currently held in short-term (session) memory.
    pub short_term_count: usize,
    /// Number of items currently held in working memory.
    pub working_count: usize,
}
76
77// ============================================================================
78// Agent Memory (three-tier: working / short-term / long-term)
79// ============================================================================
80
/// Three-tier agent memory: working, short-term (session), and long-term (persisted).
///
/// Every stateful field is behind an `Arc`, so clones are cheap and share
/// the same store and the same short-term/working tiers.
#[derive(Clone)]
pub struct AgentMemory {
    /// Long-term memory store; `remember` writes every item here first.
    pub(crate) store: Arc<dyn MemoryStore>,
    /// Short-term memory (current session): a FIFO capped at `max_short_term`,
    /// evicting the oldest entry first.
    short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
    /// Working memory (active context), capped at `max_working` and trimmed by
    /// relevance score.
    working: Arc<RwLock<Vec<MemoryItem>>>,
    /// Capacity of short-term memory.
    pub(crate) max_short_term: usize,
    /// Capacity of working memory.
    pub(crate) max_working: usize,
    /// Weights and decay used to score working-memory items.
    pub(crate) relevance_config: RelevanceConfig,
}
94
95impl std::fmt::Debug for AgentMemory {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        f.debug_struct("AgentMemory")
98            .field("max_short_term", &self.max_short_term)
99            .field("max_working", &self.max_working)
100            .finish()
101    }
102}
103
104impl AgentMemory {
105    /// Create a new agent memory system with default configuration
106    pub fn new(store: Arc<dyn MemoryStore>) -> Self {
107        Self::with_config(store, MemoryConfig::default())
108    }
109
110    /// Create a new agent memory system with custom configuration.
111    ///
112    /// If `config.prune_policy` is `Some`, a background Tokio task is spawned
113    /// that periodically calls `store.prune()` at the configured interval.
114    pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
115        if let Some(policy) = config.prune_policy.clone() {
116            let store_for_task = Arc::clone(&store);
117            let interval_secs = config.prune_interval_secs;
118            tokio::spawn(async move {
119                let mut ticker =
120                    tokio::time::interval(std::time::Duration::from_secs(interval_secs));
121                ticker.tick().await; // skip the immediate first tick
122                loop {
123                    ticker.tick().await;
124                    if let Err(e) = store_for_task.prune(&policy).await {
125                        tracing::warn!("memory prune failed: {e}");
126                    }
127                }
128            });
129        }
130
131        Self {
132            store,
133            short_term: Arc::new(RwLock::new(VecDeque::new())),
134            working: Arc::new(RwLock::new(Vec::new())),
135            max_short_term: config.max_short_term,
136            max_working: config.max_working,
137            relevance_config: config.relevance,
138        }
139    }
140
141    pub(crate) fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
142        let age_days = (now - item.timestamp).num_seconds() as f32 / 86400.0;
143        let decay = (-age_days / self.relevance_config.decay_days).exp();
144        item.importance * self.relevance_config.importance_weight
145            + decay * self.relevance_config.recency_weight
146    }
147
148    /// Store a memory in long-term storage and add to short-term
149    pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
150        self.store.store(item.clone()).await?;
151        let mut short_term = self.short_term.write().await;
152        short_term.push_back(item);
153        if short_term.len() > self.max_short_term {
154            short_term.pop_front();
155        }
156        Ok(())
157    }
158
159    /// Remember a successful pattern
160    pub async fn remember_success(
161        &self,
162        prompt: &str,
163        tools_used: &[String],
164        result: &str,
165    ) -> anyhow::Result<()> {
166        let content = format!(
167            "Success: {}\nTools: {}\nResult: {}",
168            prompt,
169            tools_used.join(", "),
170            result
171        );
172        let item = MemoryItem::new(content)
173            .with_importance(0.8)
174            .with_tag("success")
175            .with_tag("pattern")
176            .with_type(MemoryType::Procedural)
177            .with_metadata("prompt", prompt)
178            .with_metadata("tools", tools_used.join(","));
179        self.remember(item).await
180    }
181
182    /// Remember a failure to avoid repeating
183    pub async fn remember_failure(
184        &self,
185        prompt: &str,
186        error: &str,
187        attempted_tools: &[String],
188    ) -> anyhow::Result<()> {
189        let content = format!(
190            "Failure: {}\nError: {}\nAttempted tools: {}",
191            prompt,
192            error,
193            attempted_tools.join(", ")
194        );
195        let item = MemoryItem::new(content)
196            .with_importance(0.9)
197            .with_tag("failure")
198            .with_tag("avoid")
199            .with_type(MemoryType::Episodic)
200            .with_metadata("prompt", prompt)
201            .with_metadata("error", error);
202        self.remember(item).await
203    }
204
205    /// Recall similar past experiences
206    pub async fn recall_similar(
207        &self,
208        prompt: &str,
209        limit: usize,
210    ) -> anyhow::Result<Vec<MemoryItem>> {
211        self.store.search(prompt, limit).await
212    }
213
214    /// Recall by tags
215    pub async fn recall_by_tags(
216        &self,
217        tags: &[String],
218        limit: usize,
219    ) -> anyhow::Result<Vec<MemoryItem>> {
220        self.store.search_by_tags(tags, limit).await
221    }
222
223    /// Get recent memories
224    pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
225        self.store.get_recent(limit).await
226    }
227
228    /// Add to working memory (auto-trims by relevance if over capacity)
229    pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
230        let mut working = self.working.write().await;
231        working.push(item);
232        if working.len() > self.max_working {
233            let now = Utc::now();
234            working.sort_by(|a, b| {
235                self.score(b, now)
236                    .partial_cmp(&self.score(a, now))
237                    .unwrap_or(std::cmp::Ordering::Equal)
238            });
239            working.truncate(self.max_working);
240        }
241        Ok(())
242    }
243
244    /// Get working memory
245    pub async fn get_working(&self) -> Vec<MemoryItem> {
246        self.working.read().await.clone()
247    }
248
249    /// Clear working memory
250    pub async fn clear_working(&self) {
251        self.working.write().await.clear();
252    }
253
254    /// Get short-term memory
255    pub async fn get_short_term(&self) -> Vec<MemoryItem> {
256        self.short_term.read().await.iter().cloned().collect()
257    }
258
259    /// Clear short-term memory
260    pub async fn clear_short_term(&self) {
261        self.short_term.write().await.clear();
262    }
263
264    /// Get memory statistics
265    pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
266        Ok(MemoryStats {
267            long_term_count: self.store.count().await?,
268            short_term_count: self.short_term.read().await.len(),
269            working_count: self.working.read().await.len(),
270        })
271    }
272
273    /// Get access to the underlying store
274    pub fn store(&self) -> &Arc<dyn MemoryStore> {
275        &self.store
276    }
277
278    /// Get working memory count
279    pub async fn working_count(&self) -> usize {
280        self.working.read().await.len()
281    }
282
283    /// Get short-term memory count
284    pub async fn short_term_count(&self) -> usize {
285        self.short_term.read().await.len()
286    }
287}
288
289// ============================================================================
290// Memory Context Provider
291// ============================================================================
292
293/// Context provider that surfaces past memories as agent context.
294pub struct MemoryContextProvider {
295    memory: AgentMemory,
296}
297
298impl MemoryContextProvider {
299    pub fn new(memory: AgentMemory) -> Self {
300        Self { memory }
301    }
302}
303
304#[async_trait::async_trait]
305impl crate::context::ContextProvider for MemoryContextProvider {
306    fn name(&self) -> &str {
307        "memory"
308    }
309
310    async fn query(
311        &self,
312        query: &crate::context::ContextQuery,
313    ) -> anyhow::Result<crate::context::ContextResult> {
314        let limit = query.max_results.min(5);
315        let items = self.memory.recall_similar(&query.query, limit).await?;
316
317        let mut result = crate::context::ContextResult::new("memory");
318        for item in items {
319            let relevance = item.relevance_score();
320            let token_count = item.content.len() / 4;
321            let context_item = crate::context::ContextItem::new(
322                &item.id,
323                crate::context::ContextType::Memory,
324                &item.content,
325            )
326            .with_relevance(relevance)
327            .with_token_count(token_count)
328            .with_source("memory");
329            result.add_item(context_item);
330        }
331        Ok(result)
332    }
333
334    async fn on_turn_complete(
335        &self,
336        _session_id: &str,
337        prompt: &str,
338        response: &str,
339    ) -> anyhow::Result<()> {
340        self.memory.remember_success(prompt, &[], response).await
341    }
342}
343
344// ============================================================================
345// Tests
346// ============================================================================
347
#[cfg(test)]
mod tests {
    use super::*;
    use a3s_memory::InMemoryStore;
    use std::sync::Arc;

    /// Build an `AgentMemory` over a fresh in-memory store with explicit limits
    /// and relevance weights. The struct is built directly, so no pruning task
    /// is ever spawned.
    fn memory_with(
        max_short_term: usize,
        max_working: usize,
        relevance_config: RelevanceConfig,
    ) -> AgentMemory {
        AgentMemory {
            store: Arc::new(InMemoryStore::new()),
            short_term: Arc::new(RwLock::new(VecDeque::new())),
            working: Arc::new(RwLock::new(Vec::new())),
            max_short_term,
            max_working,
            relevance_config,
        }
    }

    #[tokio::test]
    async fn test_agent_memory_remember_and_recall() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let results = memory.recall_similar("create", 10).await.unwrap();
        assert!(!results.is_empty());

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 2);
        assert_eq!(stats.short_term_count, 2);
    }

    #[tokio::test]
    async fn test_agent_memory_working() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .add_to_working(MemoryItem::new("task").with_type(MemoryType::Working))
            .await
            .unwrap();
        assert_eq!(memory.working_count().await, 1);
        memory.clear_working().await;
        assert_eq!(memory.working_count().await, 0);
    }

    #[tokio::test]
    async fn test_agent_memory_working_overflow_trims() {
        // Rank purely by importance so the surviving items are deterministic.
        let relevance = RelevanceConfig {
            decay_days: 30.0,
            importance_weight: 1.0,
            recency_weight: 0.0,
        };
        let memory = memory_with(100, 3, relevance);
        for i in 0..5 {
            memory
                .add_to_working(
                    MemoryItem::new(format!("task {i}")).with_importance(i as f32 * 0.2),
                )
                .await
                .unwrap();
        }
        let working = memory.get_working().await;
        assert_eq!(working.len(), 3);
        // The three most important items survive the trim.
        let mut kept: Vec<String> = working.iter().map(|m| m.content.to_string()).collect();
        kept.sort();
        assert_eq!(kept, vec!["task 2", "task 3", "task 4"]);
    }

    #[tokio::test]
    async fn test_agent_memory_recall_by_tags() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let successes = memory
            .recall_by_tags(&["success".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(successes.len(), 1);
        let failures = memory
            .recall_by_tags(&["failure".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(failures.len(), 1);
    }

    #[tokio::test]
    async fn test_agent_memory_short_term_trim() {
        let memory = memory_with(3, 10, RelevanceConfig::default());
        for i in 0..5 {
            memory
                .remember(MemoryItem::new(format!("item {i}")))
                .await
                .unwrap();
        }
        assert_eq!(memory.short_term_count().await, 3);
        // Oldest entries are evicted first; the newest three remain in insertion order.
        let kept: Vec<String> = memory
            .get_short_term()
            .await
            .iter()
            .map(|m| m.content.to_string())
            .collect();
        assert_eq!(kept, vec!["item 2", "item 3", "item 4"]);
        // Short-term trimming never removes anything from long-term storage.
        assert_eq!(memory.stats().await.unwrap().long_term_count, 5);
    }

    #[tokio::test]
    async fn test_agent_memory_prune_delegates() {
        use a3s_memory::PrunePolicy;

        let store = Arc::new(InMemoryStore::new());
        let memory = AgentMemory::new(store.clone());

        // Insert one old low-importance item directly into the store.
        let mut old_item = a3s_memory::MemoryItem::new("stale").with_importance(0.2);
        old_item.timestamp = chrono::Utc::now() - chrono::Duration::days(100);
        store.store(old_item).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);

        // Calling prune on the underlying store via the public accessor works.
        let policy = PrunePolicy {
            max_age_days: 90,
            min_importance_to_keep: 0.5,
            max_items: 0,
        };
        let deleted = memory.store().prune(&policy).await.unwrap();
        assert_eq!(deleted, 1);
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[test]
    fn test_agent_memory_score_uses_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.9,
                recency_weight: 0.1,
            },
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);
        let item = MemoryItem::new("Test").with_importance(1.0);
        let score = memory.score(&item, Utc::now());
        assert!(score > 0.95, "Score was {score}");
    }
}