Skip to main content

a3s_code_core/
memory.rs

1//! Memory and learning system for the agent.
2//!
3//! Core types (`MemoryStore`, `MemoryItem`, `MemoryType`, `RelevanceConfig`,
4//! `FileMemoryStore`, `InMemoryStore`) live in `a3s-memory`.
5//!
6//! This module owns `MemoryConfig`, `MemoryStats`, `AgentMemory` (three-tier
7//! session memory), and `MemoryContextProvider` (context injection bridge).
8
9use a3s_memory::{MemoryItem, MemoryStore, MemoryType, PrunePolicy, RelevanceConfig};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::VecDeque;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16// ============================================================================
17// Configuration
18// ============================================================================
19
/// Configuration for the agent memory system (three-tier: working/short-term/long-term).
///
/// Serialized with camelCase keys (e.g. `maxShortTerm`, `pruneIntervalSecs`). Every field is
/// optional on input; a missing field falls back to the same value `MemoryConfig::default()`
/// uses.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MemoryConfig {
    /// Relevance scoring parameters (decay window plus importance/recency weights).
    /// `AgentMemory` uses them to rank working-memory items when trimming.
    #[serde(default)]
    pub relevance: RelevanceConfig,
    /// Maximum number of short-term (session) items. The oldest item is evicted first
    /// (default: 100).
    #[serde(default = "MemoryConfig::default_max_short_term")]
    pub max_short_term: usize,
    /// Maximum number of working-memory items. The lowest-scoring items are dropped when
    /// over capacity (default: 10).
    #[serde(default = "MemoryConfig::default_max_working")]
    pub max_working: usize,
    /// Automatic pruning policy for long-term storage. `None` (the default) disables
    /// background pruning.
    #[serde(default)]
    pub prune_policy: Option<PrunePolicy>,
    /// How often the background pruning task runs, in seconds (default: 3600).
    /// Only consulted when `prune_policy` is `Some`.
    #[serde(default = "MemoryConfig::default_prune_interval_secs")]
    pub prune_interval_secs: u64,
}
40
41impl MemoryConfig {
42    fn default_max_short_term() -> usize {
43        100
44    }
45    fn default_max_working() -> usize {
46        10
47    }
48    fn default_prune_interval_secs() -> u64 {
49        3600
50    }
51}
52
53impl Default for MemoryConfig {
54    fn default() -> Self {
55        Self {
56            relevance: RelevanceConfig::default(),
57            max_short_term: 100,
58            max_working: 10,
59            prune_policy: None,
60            prune_interval_secs: 3600,
61        }
62    }
63}
64
65// ============================================================================
66// Memory Stats
67// ============================================================================
68
69/// Statistics for the three-tier agent memory system
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct MemoryStats {
72    pub long_term_count: usize,
73    pub short_term_count: usize,
74    pub working_count: usize,
75}
76
77// ============================================================================
78// Agent Memory (three-tier: working / short-term / long-term)
79// ============================================================================
80
81/// Three-tier agent memory: working, short-term (session), and long-term (persisted).
82#[derive(Clone)]
83pub struct AgentMemory {
84    /// Long-term memory store
85    pub(crate) store: Arc<dyn MemoryStore>,
86    /// Short-term memory (current session)
87    short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
88    /// Working memory (active context)
89    working: Arc<RwLock<Vec<MemoryItem>>>,
90    pub(crate) max_short_term: usize,
91    pub(crate) max_working: usize,
92    pub(crate) relevance_config: RelevanceConfig,
93}
94
95impl std::fmt::Debug for AgentMemory {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        f.debug_struct("AgentMemory")
98            .field("max_short_term", &self.max_short_term)
99            .field("max_working", &self.max_working)
100            .finish()
101    }
102}
103
104impl AgentMemory {
105    /// Create a new agent memory system with default configuration
106    pub fn new(store: Arc<dyn MemoryStore>) -> Self {
107        Self::with_config(store, MemoryConfig::default())
108    }
109
110    /// Create a new agent memory system with custom configuration.
111    ///
112    /// If `config.prune_policy` is `Some`, a background Tokio task is spawned
113    /// that periodically calls `store.prune()` at the configured interval.
114    pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
115        if let Some(policy) = config.prune_policy.clone() {
116            let store_for_task = Arc::clone(&store);
117            let interval_secs = config.prune_interval_secs;
118            tokio::spawn(async move {
119                let mut ticker =
120                    tokio::time::interval(std::time::Duration::from_secs(interval_secs));
121                ticker.tick().await; // skip the immediate first tick
122                loop {
123                    ticker.tick().await;
124                    if let Err(e) = store_for_task.prune(&policy).await {
125                        tracing::warn!("memory prune failed: {e}");
126                    }
127                }
128            });
129        }
130
131        Self {
132            store,
133            short_term: Arc::new(RwLock::new(VecDeque::new())),
134            working: Arc::new(RwLock::new(Vec::new())),
135            max_short_term: config.max_short_term,
136            max_working: config.max_working,
137            relevance_config: config.relevance,
138        }
139    }
140
141    pub(crate) fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
142        let age_days = (now - item.timestamp).num_seconds() as f32 / 86400.0;
143        let decay = (-age_days / self.relevance_config.decay_days).exp();
144        item.importance * self.relevance_config.importance_weight
145            + decay * self.relevance_config.recency_weight
146    }
147
148    /// Store a memory in long-term storage and add to short-term
149    pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
150        self.store.store(item.clone()).await?;
151        let mut short_term = self.short_term.write().await;
152        short_term.push_back(item);
153        if short_term.len() > self.max_short_term {
154            short_term.pop_front();
155        }
156        Ok(())
157    }
158
159    /// Remember a successful pattern
160    pub async fn remember_success(
161        &self,
162        prompt: &str,
163        tools_used: &[String],
164        result: &str,
165    ) -> anyhow::Result<()> {
166        let content = format!(
167            "Success: {}\nTools: {}\nResult: {}",
168            prompt,
169            tools_used.join(", "),
170            result
171        );
172        let item = MemoryItem::new(content)
173            .with_importance(0.8)
174            .with_tag("success")
175            .with_tag("pattern")
176            .with_type(MemoryType::Procedural)
177            .with_metadata("prompt", prompt)
178            .with_metadata("tools", tools_used.join(","));
179        self.remember(item).await
180    }
181
182    /// Remember a failure to avoid repeating
183    pub async fn remember_failure(
184        &self,
185        prompt: &str,
186        error: &str,
187        attempted_tools: &[String],
188    ) -> anyhow::Result<()> {
189        let content = format!(
190            "Failure: {}\nError: {}\nAttempted tools: {}",
191            prompt,
192            error,
193            attempted_tools.join(", ")
194        );
195        let item = MemoryItem::new(content)
196            .with_importance(0.9)
197            .with_tag("failure")
198            .with_tag("avoid")
199            .with_type(MemoryType::Episodic)
200            .with_metadata("prompt", prompt)
201            .with_metadata("error", error);
202        self.remember(item).await
203    }
204
205    /// Recall similar past experiences
206    pub async fn recall_similar(
207        &self,
208        prompt: &str,
209        limit: usize,
210    ) -> anyhow::Result<Vec<MemoryItem>> {
211        self.store.search(prompt, limit).await
212    }
213
214    /// Recall by tags
215    pub async fn recall_by_tags(
216        &self,
217        tags: &[String],
218        limit: usize,
219    ) -> anyhow::Result<Vec<MemoryItem>> {
220        self.store.search_by_tags(tags, limit).await
221    }
222
223    /// Get recent memories
224    pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
225        self.store.get_recent(limit).await
226    }
227
228    /// Add to working memory (auto-trims by relevance if over capacity)
229    pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
230        let mut working = self.working.write().await;
231        working.push(item);
232        if working.len() > self.max_working {
233            let now = Utc::now();
234            working.sort_by(|a, b| {
235                self.score(b, now)
236                    .partial_cmp(&self.score(a, now))
237                    .unwrap_or(std::cmp::Ordering::Equal)
238            });
239            working.truncate(self.max_working);
240        }
241        Ok(())
242    }
243
244    /// Get working memory
245    pub async fn get_working(&self) -> Vec<MemoryItem> {
246        self.working.read().await.clone()
247    }
248
249    /// Clear working memory
250    pub async fn clear_working(&self) {
251        self.working.write().await.clear();
252    }
253
254    /// Get short-term memory
255    pub async fn get_short_term(&self) -> Vec<MemoryItem> {
256        self.short_term.read().await.iter().cloned().collect()
257    }
258
259    /// Clear short-term memory
260    pub async fn clear_short_term(&self) {
261        self.short_term.write().await.clear();
262    }
263
264    /// Get memory statistics
265    pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
266        Ok(MemoryStats {
267            long_term_count: self.store.count().await?,
268            short_term_count: self.short_term.read().await.len(),
269            working_count: self.working.read().await.len(),
270        })
271    }
272
273    /// Get access to the underlying store
274    pub fn store(&self) -> &Arc<dyn MemoryStore> {
275        &self.store
276    }
277
278    /// Get working memory count
279    pub async fn working_count(&self) -> usize {
280        self.working.read().await.len()
281    }
282
283    /// Get short-term memory count
284    pub async fn short_term_count(&self) -> usize {
285        self.short_term.read().await.len()
286    }
287}
288
289// ============================================================================
290// Memory Context Provider
291// ============================================================================
292
293/// Context provider that surfaces past memories as agent context.
294pub struct MemoryContextProvider {
295    memory: AgentMemory,
296}
297
298impl MemoryContextProvider {
299    pub fn new(memory: AgentMemory) -> Self {
300        Self { memory }
301    }
302}
303
304pub(crate) fn memory_items_to_context_result(
305    provider: impl Into<String>,
306    items: Vec<MemoryItem>,
307) -> crate::context::ContextResult {
308    let mut result = crate::context::ContextResult::new(provider);
309    for item in items {
310        let token_count = (item.content.len() / 4).max(1);
311        let context_item = crate::context::ContextItem::new(
312            &item.id,
313            crate::context::ContextType::Memory,
314            &item.content,
315        )
316        .with_relevance(item.relevance_score())
317        .with_token_count(token_count)
318        .with_source(format!("memory://{}", item.id))
319        .with_provenance("long_term_memory")
320        .with_priority(0.35)
321        .with_trust(0.7)
322        .with_freshness(0.5);
323        result.add_item(context_item);
324    }
325    result
326}
327
328#[async_trait::async_trait]
329impl crate::context::ContextProvider for MemoryContextProvider {
330    fn name(&self) -> &str {
331        "memory"
332    }
333
334    async fn query(
335        &self,
336        query: &crate::context::ContextQuery,
337    ) -> anyhow::Result<crate::context::ContextResult> {
338        let limit = query.max_results.min(5);
339        let items = self.memory.recall_similar(&query.query, limit).await?;
340
341        Ok(memory_items_to_context_result("memory", items))
342    }
343
344    async fn on_turn_complete(
345        &self,
346        _session_id: &str,
347        prompt: &str,
348        response: &str,
349    ) -> anyhow::Result<()> {
350        self.memory.remember_success(prompt, &[], response).await
351    }
352}
353
354// ============================================================================
355// Tests
356// ============================================================================
357
#[cfg(test)]
mod tests {
    use super::*;
    use a3s_memory::InMemoryStore;
    use std::sync::Arc;

    /// Default-configured memory over a fresh in-memory store.
    fn default_memory() -> AgentMemory {
        AgentMemory::new(Arc::new(InMemoryStore::new()))
    }

    /// Memory over a fresh in-memory store with explicit tier limits.
    ///
    /// Built with a struct literal, bypassing `with_config`.
    fn memory_with_limits(max_short_term: usize, max_working: usize) -> AgentMemory {
        AgentMemory {
            store: Arc::new(InMemoryStore::new()),
            short_term: Arc::new(RwLock::new(VecDeque::new())),
            working: Arc::new(RwLock::new(Vec::new())),
            max_short_term,
            max_working,
            relevance_config: RelevanceConfig::default(),
        }
    }

    /// Record one success ("create file") and one failure ("delete file").
    async fn seed_success_and_failure(memory: &AgentMemory) {
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_agent_memory_remember_and_recall() {
        let memory = default_memory();
        seed_success_and_failure(&memory).await;

        let hits = memory.recall_similar("create", 10).await.unwrap();
        assert!(!hits.is_empty());

        let MemoryStats {
            long_term_count,
            short_term_count,
            ..
        } = memory.stats().await.unwrap();
        assert_eq!(long_term_count, 2);
        assert_eq!(short_term_count, 2);
    }

    #[tokio::test]
    async fn test_agent_memory_working() {
        let memory = default_memory();
        let task = MemoryItem::new("task").with_type(MemoryType::Working);
        memory.add_to_working(task).await.unwrap();
        assert_eq!(memory.working_count().await, 1);

        memory.clear_working().await;
        assert_eq!(memory.working_count().await, 0);
    }

    #[tokio::test]
    async fn test_agent_memory_working_overflow_trims() {
        let memory = memory_with_limits(100, 3);
        for i in 0..5 {
            let item = MemoryItem::new(format!("task {i}")).with_importance(i as f32 * 0.2);
            memory.add_to_working(item).await.unwrap();
        }
        assert_eq!(memory.get_working().await.len(), 3);
    }

    #[tokio::test]
    async fn test_agent_memory_recall_by_tags() {
        let memory = default_memory();
        seed_success_and_failure(&memory).await;

        for tag in ["success", "failure"] {
            let tagged = memory
                .recall_by_tags(&[tag.to_string()], 10)
                .await
                .unwrap();
            assert_eq!(tagged.len(), 1, "expected exactly one memory tagged {tag}");
        }
    }

    #[tokio::test]
    async fn test_agent_memory_short_term_trim() {
        let memory = memory_with_limits(3, 10);
        for i in 0..5 {
            memory
                .remember(MemoryItem::new(format!("item {i}")))
                .await
                .unwrap();
        }
        assert_eq!(memory.short_term_count().await, 3);
    }

    #[tokio::test]
    async fn test_agent_memory_prune_delegates() {
        let store = Arc::new(InMemoryStore::new());
        let memory = AgentMemory::new(store.clone());

        // Seed one 100-day-old, low-importance item straight into the store.
        let mut stale = MemoryItem::new("stale").with_importance(0.2);
        stale.timestamp = Utc::now() - chrono::Duration::days(100);
        store.store(stale).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        // Pruning through the public store accessor reaches the same backend.
        let policy = PrunePolicy {
            max_age_days: 90,
            min_importance_to_keep: 0.5,
            max_items: 0,
        };
        let removed = memory.store().prune(&policy).await.unwrap();
        assert_eq!(removed, 1);
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[test]
    fn test_agent_memory_score_uses_config() {
        let relevance = RelevanceConfig {
            decay_days: 7.0,
            importance_weight: 0.9,
            recency_weight: 0.1,
        };
        let config = MemoryConfig {
            relevance,
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);
        let fresh = MemoryItem::new("Test").with_importance(1.0);
        let score = memory.score(&fresh, Utc::now());
        assert!(score > 0.95, "Score was {score}");
    }
}