Skip to main content

a3s_code_core/
memory.rs

1//! Memory and learning system for the agent.
2//!
3//! Core types (`MemoryStore`, `MemoryItem`, `MemoryType`, `RelevanceConfig`,
4//! `FileMemoryStore`, `InMemoryStore`) live in `a3s-memory`.
5//!
6//! This module owns `MemoryConfig`, `MemoryStats`, `AgentMemory` (three-tier
7//! session memory), and `MemoryContextProvider` (context injection bridge).
8
9use a3s_memory::{MemoryItem, MemoryStore, MemoryType, RelevanceConfig};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::VecDeque;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16// ============================================================================
17// Configuration
18// ============================================================================
19
20/// Configuration for the agent memory system (three-tier: working/short-term/long-term)
21#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(rename_all = "camelCase")]
23pub struct MemoryConfig {
24    /// Relevance scoring parameters
25    #[serde(default)]
26    pub relevance: RelevanceConfig,
27    /// Maximum short-term memory items (default: 100)
28    #[serde(default = "MemoryConfig::default_max_short_term")]
29    pub max_short_term: usize,
30    /// Maximum working memory items (default: 10)
31    #[serde(default = "MemoryConfig::default_max_working")]
32    pub max_working: usize,
33}
34
35impl MemoryConfig {
36    fn default_max_short_term() -> usize {
37        100
38    }
39    fn default_max_working() -> usize {
40        10
41    }
42}
43
44impl Default for MemoryConfig {
45    fn default() -> Self {
46        Self {
47            relevance: RelevanceConfig::default(),
48            max_short_term: 100,
49            max_working: 10,
50        }
51    }
52}
53
54// ============================================================================
55// Memory Stats
56// ============================================================================
57
58/// Statistics for the three-tier agent memory system
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct MemoryStats {
61    pub long_term_count: usize,
62    pub short_term_count: usize,
63    pub working_count: usize,
64}
65
66// ============================================================================
67// Agent Memory (three-tier: working / short-term / long-term)
68// ============================================================================
69
70/// Three-tier agent memory: working, short-term (session), and long-term (persisted).
71#[derive(Clone)]
72pub struct AgentMemory {
73    /// Long-term memory store
74    pub(crate) store: Arc<dyn MemoryStore>,
75    /// Short-term memory (current session)
76    short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
77    /// Working memory (active context)
78    working: Arc<RwLock<Vec<MemoryItem>>>,
79    pub(crate) max_short_term: usize,
80    pub(crate) max_working: usize,
81    pub(crate) relevance_config: RelevanceConfig,
82}
83
84impl std::fmt::Debug for AgentMemory {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        f.debug_struct("AgentMemory")
87            .field("max_short_term", &self.max_short_term)
88            .field("max_working", &self.max_working)
89            .finish()
90    }
91}
92
93impl AgentMemory {
94    /// Create a new agent memory system with default configuration
95    pub fn new(store: Arc<dyn MemoryStore>) -> Self {
96        Self::with_config(store, MemoryConfig::default())
97    }
98
99    /// Create a new agent memory system with custom configuration
100    pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
101        Self {
102            store,
103            short_term: Arc::new(RwLock::new(VecDeque::new())),
104            working: Arc::new(RwLock::new(Vec::new())),
105            max_short_term: config.max_short_term,
106            max_working: config.max_working,
107            relevance_config: config.relevance,
108        }
109    }
110
111    pub(crate) fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
112        let age_days = (now - item.timestamp).num_seconds() as f32 / 86400.0;
113        let decay = (-age_days / self.relevance_config.decay_days).exp();
114        item.importance * self.relevance_config.importance_weight
115            + decay * self.relevance_config.recency_weight
116    }
117
118    /// Store a memory in long-term storage and add to short-term
119    pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
120        self.store.store(item.clone()).await?;
121        let mut short_term = self.short_term.write().await;
122        short_term.push_back(item);
123        if short_term.len() > self.max_short_term {
124            short_term.pop_front();
125        }
126        Ok(())
127    }
128
129    /// Remember a successful pattern
130    pub async fn remember_success(
131        &self,
132        prompt: &str,
133        tools_used: &[String],
134        result: &str,
135    ) -> anyhow::Result<()> {
136        let content = format!(
137            "Success: {}\nTools: {}\nResult: {}",
138            prompt,
139            tools_used.join(", "),
140            result
141        );
142        let item = MemoryItem::new(content)
143            .with_importance(0.8)
144            .with_tag("success")
145            .with_tag("pattern")
146            .with_type(MemoryType::Procedural)
147            .with_metadata("prompt", prompt)
148            .with_metadata("tools", tools_used.join(","));
149        self.remember(item).await
150    }
151
152    /// Remember a failure to avoid repeating
153    pub async fn remember_failure(
154        &self,
155        prompt: &str,
156        error: &str,
157        attempted_tools: &[String],
158    ) -> anyhow::Result<()> {
159        let content = format!(
160            "Failure: {}\nError: {}\nAttempted tools: {}",
161            prompt,
162            error,
163            attempted_tools.join(", ")
164        );
165        let item = MemoryItem::new(content)
166            .with_importance(0.9)
167            .with_tag("failure")
168            .with_tag("avoid")
169            .with_type(MemoryType::Episodic)
170            .with_metadata("prompt", prompt)
171            .with_metadata("error", error);
172        self.remember(item).await
173    }
174
175    /// Recall similar past experiences
176    pub async fn recall_similar(
177        &self,
178        prompt: &str,
179        limit: usize,
180    ) -> anyhow::Result<Vec<MemoryItem>> {
181        self.store.search(prompt, limit).await
182    }
183
184    /// Recall by tags
185    pub async fn recall_by_tags(
186        &self,
187        tags: &[String],
188        limit: usize,
189    ) -> anyhow::Result<Vec<MemoryItem>> {
190        self.store.search_by_tags(tags, limit).await
191    }
192
193    /// Get recent memories
194    pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
195        self.store.get_recent(limit).await
196    }
197
198    /// Add to working memory (auto-trims by relevance if over capacity)
199    pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
200        let mut working = self.working.write().await;
201        working.push(item);
202        if working.len() > self.max_working {
203            let now = Utc::now();
204            working.sort_by(|a, b| {
205                self.score(b, now)
206                    .partial_cmp(&self.score(a, now))
207                    .unwrap_or(std::cmp::Ordering::Equal)
208            });
209            working.truncate(self.max_working);
210        }
211        Ok(())
212    }
213
214    /// Get working memory
215    pub async fn get_working(&self) -> Vec<MemoryItem> {
216        self.working.read().await.clone()
217    }
218
219    /// Clear working memory
220    pub async fn clear_working(&self) {
221        self.working.write().await.clear();
222    }
223
224    /// Get short-term memory
225    pub async fn get_short_term(&self) -> Vec<MemoryItem> {
226        self.short_term.read().await.iter().cloned().collect()
227    }
228
229    /// Clear short-term memory
230    pub async fn clear_short_term(&self) {
231        self.short_term.write().await.clear();
232    }
233
234    /// Get memory statistics
235    pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
236        Ok(MemoryStats {
237            long_term_count: self.store.count().await?,
238            short_term_count: self.short_term.read().await.len(),
239            working_count: self.working.read().await.len(),
240        })
241    }
242
243    /// Get access to the underlying store
244    pub fn store(&self) -> &Arc<dyn MemoryStore> {
245        &self.store
246    }
247
248    /// Get working memory count
249    pub async fn working_count(&self) -> usize {
250        self.working.read().await.len()
251    }
252
253    /// Get short-term memory count
254    pub async fn short_term_count(&self) -> usize {
255        self.short_term.read().await.len()
256    }
257}
258
259// ============================================================================
260// Memory Context Provider
261// ============================================================================
262
263/// Context provider that surfaces past memories as agent context.
264pub struct MemoryContextProvider {
265    memory: AgentMemory,
266}
267
268impl MemoryContextProvider {
269    pub fn new(memory: AgentMemory) -> Self {
270        Self { memory }
271    }
272}
273
274#[async_trait::async_trait]
275impl crate::context::ContextProvider for MemoryContextProvider {
276    fn name(&self) -> &str {
277        "memory"
278    }
279
280    async fn query(
281        &self,
282        query: &crate::context::ContextQuery,
283    ) -> anyhow::Result<crate::context::ContextResult> {
284        let limit = query.max_results.min(5);
285        let items = self.memory.recall_similar(&query.query, limit).await?;
286
287        let mut result = crate::context::ContextResult::new("memory");
288        for item in items {
289            let relevance = item.relevance_score();
290            let token_count = item.content.len() / 4;
291            let context_item = crate::context::ContextItem::new(
292                &item.id,
293                crate::context::ContextType::Memory,
294                &item.content,
295            )
296            .with_relevance(relevance)
297            .with_token_count(token_count)
298            .with_source("memory");
299            result.add_item(context_item);
300        }
301        Ok(result)
302    }
303
304    async fn on_turn_complete(
305        &self,
306        _session_id: &str,
307        prompt: &str,
308        response: &str,
309    ) -> anyhow::Result<()> {
310        self.memory.remember_success(prompt, &[], response).await
311    }
312}
313
314// ============================================================================
315// Tests
316// ============================================================================
317
#[cfg(test)]
mod tests {
    use super::*;
    use a3s_memory::InMemoryStore;
    use std::sync::Arc;

    /// Build an `AgentMemory` over a fresh in-memory store with the given config.
    fn memory_with(config: MemoryConfig) -> AgentMemory {
        AgentMemory::with_config(Arc::new(InMemoryStore::new()), config)
    }

    #[test]
    fn test_memory_config_defaults() {
        let config = MemoryConfig::default();
        assert_eq!(config.max_short_term, 100);
        assert_eq!(config.max_working, 10);
    }

    #[tokio::test]
    async fn test_agent_memory_remember_and_recall() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let results = memory.recall_similar("create", 10).await.unwrap();
        assert!(!results.is_empty());

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 2);
        assert_eq!(stats.short_term_count, 2);
    }

    #[tokio::test]
    async fn test_agent_memory_working() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .add_to_working(MemoryItem::new("task").with_type(MemoryType::Working))
            .await
            .unwrap();
        assert_eq!(memory.working_count().await, 1);

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.working_count, 1);
        // Working memory is never persisted to the long-term store.
        assert_eq!(stats.long_term_count, 0);

        memory.clear_working().await;
        assert_eq!(memory.working_count().await, 0);
    }

    #[tokio::test]
    async fn test_agent_memory_working_overflow_trims() {
        // Importance-only scoring makes the surviving set deterministic.
        let memory = memory_with(MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 30.0,
                importance_weight: 1.0,
                recency_weight: 0.0,
            },
            max_short_term: 100,
            max_working: 3,
        });
        for i in 0..5 {
            memory
                .add_to_working(
                    MemoryItem::new(format!("task {i}")).with_importance(i as f32 * 0.2),
                )
                .await
                .unwrap();
        }

        let working = memory.get_working().await;
        assert_eq!(working.len(), 3);
        let mut kept: Vec<String> = working.into_iter().map(|item| item.content).collect();
        kept.sort();
        assert_eq!(kept, vec!["task 2", "task 3", "task 4"]);
    }

    #[tokio::test]
    async fn test_agent_memory_recall_by_tags() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let successes = memory
            .recall_by_tags(&["success".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(successes.len(), 1);
        let failures = memory
            .recall_by_tags(&["failure".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(failures.len(), 1);
    }

    #[tokio::test]
    async fn test_agent_memory_short_term_trim() {
        let memory = memory_with(MemoryConfig {
            max_short_term: 3,
            max_working: 10,
            ..Default::default()
        });
        for i in 0..5 {
            memory
                .remember(MemoryItem::new(format!("item {i}")))
                .await
                .unwrap();
        }
        assert_eq!(memory.short_term_count().await, 3);

        // Oldest items are evicted first; survivors stay in insertion order.
        let contents: Vec<String> = memory
            .get_short_term()
            .await
            .into_iter()
            .map(|item| item.content)
            .collect();
        assert_eq!(contents, vec!["item 2", "item 3", "item 4"]);
    }

    #[tokio::test]
    async fn test_agent_memory_clear_short_term_keeps_long_term() {
        let memory = AgentMemory::new(Arc::new(InMemoryStore::new()));
        memory.remember(MemoryItem::new("keep me")).await.unwrap();
        memory.clear_short_term().await;

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.short_term_count, 0);
        assert_eq!(stats.long_term_count, 1);
    }

    #[test]
    fn test_agent_memory_score_uses_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.9,
                recency_weight: 0.1,
            },
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);
        let item = MemoryItem::new("Test").with_importance(1.0);
        let score = memory.score(&item, Utc::now());
        assert!(score > 0.95, "Score was {score}");
    }
}