Skip to main content

a3s_code_core/
memory.rs

1//! Memory and learning system for the agent
2//!
3//! This module provides memory storage, recall, and learning capabilities
4//! to enable the agent to learn from past experiences and improve over time.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12// ============================================================================
13// Configuration
14// ============================================================================
15
16/// Configuration for relevance scoring
17#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(rename_all = "camelCase")]
19pub struct RelevanceConfig {
20    /// Exponential decay half-life in days (default: 30.0)
21    #[serde(default = "RelevanceConfig::default_decay_days")]
22    pub decay_days: f32,
23    /// Weight for importance factor (default: 0.7)
24    #[serde(default = "RelevanceConfig::default_importance_weight")]
25    pub importance_weight: f32,
26    /// Weight for recency factor (default: 0.3)
27    #[serde(default = "RelevanceConfig::default_recency_weight")]
28    pub recency_weight: f32,
29}
30
31impl RelevanceConfig {
32    fn default_decay_days() -> f32 {
33        30.0
34    }
35    fn default_importance_weight() -> f32 {
36        0.7
37    }
38    fn default_recency_weight() -> f32 {
39        0.3
40    }
41}
42
43impl Default for RelevanceConfig {
44    fn default() -> Self {
45        Self {
46            decay_days: 30.0,
47            importance_weight: 0.7,
48            recency_weight: 0.3,
49        }
50    }
51}
52
53/// Configuration for the agent memory system
54#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(rename_all = "camelCase")]
56pub struct MemoryConfig {
57    /// Relevance scoring parameters
58    #[serde(default)]
59    pub relevance: RelevanceConfig,
60    /// Maximum short-term memory items (default: 100)
61    #[serde(default = "MemoryConfig::default_max_short_term")]
62    pub max_short_term: usize,
63    /// Maximum working memory items (default: 10)
64    #[serde(default = "MemoryConfig::default_max_working")]
65    pub max_working: usize,
66}
67
68impl MemoryConfig {
69    fn default_max_short_term() -> usize {
70        100
71    }
72    fn default_max_working() -> usize {
73        10
74    }
75}
76
77impl Default for MemoryConfig {
78    fn default() -> Self {
79        Self {
80            relevance: RelevanceConfig::default(),
81            max_short_term: 100,
82            max_working: 10,
83        }
84    }
85}
86
87// ============================================================================
88// Memory Item
89// ============================================================================
90
91/// A single memory item
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct MemoryItem {
94    /// Unique identifier
95    pub id: String,
96    /// Memory content
97    pub content: String,
98    /// When this memory was created
99    pub timestamp: DateTime<Utc>,
100    /// Importance score (0.0 - 1.0)
101    pub importance: f32,
102    /// Tags for categorization
103    pub tags: Vec<String>,
104    /// Memory type
105    pub memory_type: MemoryType,
106    /// Associated metadata
107    pub metadata: HashMap<String, String>,
108    /// Number of times this memory was accessed
109    pub access_count: u32,
110    /// Last access time
111    pub last_accessed: Option<DateTime<Utc>>,
112    /// Cached lowercase content for fast substring search
113    #[serde(skip)]
114    pub content_lower: String,
115}
116
117impl MemoryItem {
118    /// Create a new memory item
119    pub fn new(content: impl Into<String>) -> Self {
120        let content = content.into();
121        let content_lower = content.to_lowercase();
122        Self {
123            id: uuid::Uuid::new_v4().to_string(),
124            content,
125            timestamp: Utc::now(),
126            importance: 0.5,
127            tags: Vec::new(),
128            memory_type: MemoryType::Episodic,
129            metadata: HashMap::new(),
130            access_count: 0,
131            last_accessed: None,
132            content_lower,
133        }
134    }
135
136    /// Set importance
137    pub fn with_importance(mut self, importance: f32) -> Self {
138        self.importance = importance.clamp(0.0, 1.0);
139        self
140    }
141
142    /// Add tags
143    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
144        self.tags = tags;
145        self
146    }
147
148    /// Add a single tag
149    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
150        self.tags.push(tag.into());
151        self
152    }
153
154    /// Set memory type
155    pub fn with_type(mut self, memory_type: MemoryType) -> Self {
156        self.memory_type = memory_type;
157        self
158    }
159
160    /// Add metadata
161    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
162        self.metadata.insert(key.into(), value.into());
163        self
164    }
165
166    /// Record access
167    pub fn record_access(&mut self) {
168        self.access_count += 1;
169        self.last_accessed = Some(Utc::now());
170    }
171
172    /// Calculate relevance score at a given timestamp
173    ///
174    /// Use this variant in sort comparators to avoid repeated `Utc::now()` syscalls.
175    pub fn relevance_score_at(&self, now: DateTime<Utc>) -> f32 {
176        let age_seconds = (now - self.timestamp).num_seconds() as f32;
177        let age_days = age_seconds / 86400.0;
178
179        // Decay factor: memories lose relevance over time
180        let decay = (-age_days / 30.0).exp(); // 30-day half-life
181
182        // Combine importance and recency
183        self.importance * 0.7 + decay * 0.3
184    }
185
186    /// Calculate relevance score based on recency and importance
187    pub fn relevance_score(&self) -> f32 {
188        self.relevance_score_at(Utc::now())
189    }
190}
191
192/// Type of memory
193#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
194#[serde(rename_all = "snake_case")]
195pub enum MemoryType {
196    /// Episodic memory (specific events)
197    Episodic,
198    /// Semantic memory (facts and knowledge)
199    Semantic,
200    /// Procedural memory (how to do things)
201    Procedural,
202    /// Working memory (temporary, active)
203    Working,
204}
205
206// ============================================================================
207// Memory Store Trait
208// ============================================================================
209
210/// Trait for memory storage backends
211#[async_trait::async_trait]
212pub trait MemoryStore: Send + Sync {
213    /// Store a memory item
214    async fn store(&self, item: MemoryItem) -> anyhow::Result<()>;
215
216    /// Retrieve a memory by ID
217    async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>>;
218
219    /// Search memories by query
220    async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
221
222    /// Search memories by tags
223    async fn search_by_tags(
224        &self,
225        tags: &[String],
226        limit: usize,
227    ) -> anyhow::Result<Vec<MemoryItem>>;
228
229    /// Get recent memories
230    async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
231
232    /// Get important memories
233    async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
234
235    /// Delete a memory
236    async fn delete(&self, id: &str) -> anyhow::Result<()>;
237
238    /// Clear all memories
239    async fn clear(&self) -> anyhow::Result<()>;
240
241    /// Get total memory count
242    async fn count(&self) -> anyhow::Result<usize>;
243}
244
245// ============================================================================
246// Shared Search/Sort Helpers (DRY)
247// ============================================================================
248
249/// Sort memory items by relevance score (highest first)
250fn sort_by_relevance(items: &mut [MemoryItem]) {
251    let now = Utc::now();
252    items.sort_by(|a, b| {
253        b.relevance_score_at(now)
254            .partial_cmp(&a.relevance_score_at(now))
255            .unwrap_or(std::cmp::Ordering::Equal)
256    });
257}
258
259// ============================================================================
260// In-Memory Store
261// ============================================================================
262
263/// Agent memory system
264#[derive(Clone)]
265pub struct AgentMemory {
266    /// Long-term memory store
267    store: Arc<dyn MemoryStore>,
268    /// Short-term memory (current session)
269    short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
270    /// Working memory (active context)
271    working: Arc<RwLock<Vec<MemoryItem>>>,
272    /// Maximum short-term memory size
273    max_short_term: usize,
274    /// Maximum working memory size
275    max_working: usize,
276    /// Relevance scoring configuration
277    relevance_config: RelevanceConfig,
278}
279
280impl std::fmt::Debug for AgentMemory {
281    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282        f.debug_struct("AgentMemory")
283            .field("max_short_term", &self.max_short_term)
284            .field("max_working", &self.max_working)
285            .finish()
286    }
287}
288
289impl AgentMemory {
290    /// Create a new agent memory system with default configuration
291    pub fn new(store: Arc<dyn MemoryStore>) -> Self {
292        Self::with_config(store, MemoryConfig::default())
293    }
294
295    /// Create a new agent memory system with custom configuration
296    pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
297        Self {
298            store,
299            short_term: Arc::new(RwLock::new(VecDeque::new())),
300            working: Arc::new(RwLock::new(Vec::new())),
301            max_short_term: config.max_short_term,
302            max_working: config.max_working,
303            relevance_config: config.relevance,
304        }
305    }
306
307    /// Calculate relevance score using this memory system's configuration
308    fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
309        let age_seconds = (now - item.timestamp).num_seconds() as f32;
310        let age_days = age_seconds / 86400.0;
311        let decay = (-age_days / self.relevance_config.decay_days).exp();
312        item.importance * self.relevance_config.importance_weight
313            + decay * self.relevance_config.recency_weight
314    }
315
316    /// Store a memory in long-term storage
317    pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
318        // Store in long-term
319        self.store.store(item.clone()).await?;
320
321        // Add to short-term
322        let mut short_term = self.short_term.write().await;
323        short_term.push_back(item);
324
325        // Trim if needed
326        if short_term.len() > self.max_short_term {
327            short_term.pop_front();
328        }
329
330        Ok(())
331    }
332
333    /// Remember a successful pattern
334    pub async fn remember_success(
335        &self,
336        prompt: &str,
337        tools_used: &[String],
338        result: &str,
339    ) -> anyhow::Result<()> {
340        let content = format!(
341            "Success: {}\nTools: {}\nResult: {}",
342            prompt,
343            tools_used.join(", "),
344            result
345        );
346
347        let item = MemoryItem::new(content)
348            .with_importance(0.8)
349            .with_tag("success")
350            .with_tag("pattern")
351            .with_type(MemoryType::Procedural)
352            .with_metadata("prompt", prompt)
353            .with_metadata("tools", tools_used.join(","));
354
355        self.remember(item).await
356    }
357
358    /// Remember a failure to avoid repeating
359    pub async fn remember_failure(
360        &self,
361        prompt: &str,
362        error: &str,
363        attempted_tools: &[String],
364    ) -> anyhow::Result<()> {
365        let content = format!(
366            "Failure: {}\nError: {}\nAttempted tools: {}",
367            prompt,
368            error,
369            attempted_tools.join(", ")
370        );
371
372        let item = MemoryItem::new(content)
373            .with_importance(0.9) // Failures are important to remember
374            .with_tag("failure")
375            .with_tag("avoid")
376            .with_type(MemoryType::Episodic)
377            .with_metadata("prompt", prompt)
378            .with_metadata("error", error);
379
380        self.remember(item).await
381    }
382
383    /// Recall similar past experiences
384    pub async fn recall_similar(
385        &self,
386        prompt: &str,
387        limit: usize,
388    ) -> anyhow::Result<Vec<MemoryItem>> {
389        self.store.search(prompt, limit).await
390    }
391
392    /// Recall by tags
393    pub async fn recall_by_tags(
394        &self,
395        tags: &[String],
396        limit: usize,
397    ) -> anyhow::Result<Vec<MemoryItem>> {
398        self.store.search_by_tags(tags, limit).await
399    }
400
401    /// Get recent memories
402    pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
403        self.store.get_recent(limit).await
404    }
405
406    /// Add to working memory
407    pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
408        let mut working = self.working.write().await;
409        working.push(item);
410
411        // Trim if needed (keep most relevant)
412        if working.len() > self.max_working {
413            let now = Utc::now();
414            working.sort_by(|a, b| {
415                self.score(b, now)
416                    .partial_cmp(&self.score(a, now))
417                    .unwrap_or(std::cmp::Ordering::Equal)
418            });
419            working.truncate(self.max_working);
420        }
421
422        Ok(())
423    }
424
425    /// Get working memory
426    pub async fn get_working(&self) -> Vec<MemoryItem> {
427        self.working.read().await.clone()
428    }
429
430    /// Clear working memory
431    pub async fn clear_working(&self) {
432        self.working.write().await.clear();
433    }
434
435    /// Get short-term memory
436    pub async fn get_short_term(&self) -> Vec<MemoryItem> {
437        self.short_term.read().await.iter().cloned().collect()
438    }
439
440    /// Clear short-term memory
441    pub async fn clear_short_term(&self) {
442        self.short_term.write().await.clear();
443    }
444
445    /// Get memory statistics
446    pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
447        let long_term_count = self.store.count().await?;
448        let short_term_count = self.short_term.read().await.len();
449        let working_count = self.working.read().await.len();
450
451        Ok(MemoryStats {
452            long_term_count,
453            short_term_count,
454            working_count,
455        })
456    }
457
458    /// Get access to the underlying store
459    pub fn store(&self) -> &Arc<dyn MemoryStore> {
460        &self.store
461    }
462
463    /// Get working memory count
464    pub async fn working_count(&self) -> usize {
465        self.working.read().await.len()
466    }
467
468    /// Get short-term memory count
469    pub async fn short_term_count(&self) -> usize {
470        self.short_term.read().await.len()
471    }
472}
473
474/// Memory statistics
475#[derive(Debug, Clone, Serialize, Deserialize)]
476pub struct MemoryStats {
477    /// Number of long-term memories
478    pub long_term_count: usize,
479    /// Number of short-term memories
480    pub short_term_count: usize,
481    /// Number of working memories
482    pub working_count: usize,
483}
484
485// ============================================================================
486// Memory Context Provider
487// ============================================================================
488
489/// Context provider that surfaces past memories (successes/failures) as context.
490///
491/// Wraps `AgentMemory` and implements the `ContextProvider` trait so that
492/// session memory is automatically injected into the agent's system prompt.
493pub struct MemoryContextProvider {
494    memory: AgentMemory,
495}
496
497impl MemoryContextProvider {
498    /// Create a new memory context provider
499    pub fn new(memory: AgentMemory) -> Self {
500        Self { memory }
501    }
502}
503
504#[async_trait::async_trait]
505impl crate::context::ContextProvider for MemoryContextProvider {
506    fn name(&self) -> &str {
507        "memory"
508    }
509
510    async fn query(
511        &self,
512        query: &crate::context::ContextQuery,
513    ) -> anyhow::Result<crate::context::ContextResult> {
514        let limit = query.max_results.min(5);
515        let items = self.memory.recall_similar(&query.query, limit).await?;
516
517        let mut result = crate::context::ContextResult::new("memory");
518        for item in items {
519            let relevance = item.relevance_score();
520            let token_count = item.content.len() / 4; // rough estimate
521            let context_item = crate::context::ContextItem::new(
522                &item.id,
523                crate::context::ContextType::Memory,
524                &item.content,
525            )
526            .with_relevance(relevance)
527            .with_token_count(token_count)
528            .with_source("memory");
529            result.add_item(context_item);
530        }
531
532        Ok(result)
533    }
534
535    async fn on_turn_complete(
536        &self,
537        _session_id: &str,
538        prompt: &str,
539        response: &str,
540    ) -> anyhow::Result<()> {
541        // Store the successful interaction as a memory
542        self.memory.remember_success(prompt, &[], response).await
543    }
544}
545
546// ============================================================================
547// Tests
548// ============================================================================
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Simple in-memory store for testing
    struct TestMemoryStore {
        items: std::sync::Mutex<Vec<MemoryItem>>,
    }

    impl TestMemoryStore {
        fn new() -> Self {
            Self {
                items: std::sync::Mutex::new(Vec::new()),
            }
        }
    }

    #[async_trait::async_trait]
    impl MemoryStore for TestMemoryStore {
        async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
            self.items.lock().unwrap().push(item);
            Ok(())
        }
        async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
            Ok(self.items.lock().unwrap().iter().find(|i| i.id == id).cloned())
        }
        async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            let query_lower = query.to_lowercase();
            Ok(items
                .iter()
                .filter(|i| i.content.to_lowercase().contains(&query_lower))
                .take(limit)
                .cloned()
                .collect())
        }
        async fn search_by_tags(
            &self,
            tags: &[String],
            limit: usize,
        ) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            Ok(items
                .iter()
                .filter(|i| tags.iter().any(|t| i.tags.contains(t)))
                .take(limit)
                .cloned()
                .collect())
        }
        async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            let mut sorted: Vec<_> = items.iter().cloned().collect();
            sorted.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
            sorted.truncate(limit);
            Ok(sorted)
        }
        async fn get_important(
            &self,
            threshold: f32,
            limit: usize,
        ) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            Ok(items
                .iter()
                .filter(|i| i.importance >= threshold)
                .take(limit)
                .cloned()
                .collect())
        }
        async fn delete(&self, id: &str) -> anyhow::Result<()> {
            self.items.lock().unwrap().retain(|i| i.id != id);
            Ok(())
        }
        async fn clear(&self) -> anyhow::Result<()> {
            self.items.lock().unwrap().clear();
            Ok(())
        }
        async fn count(&self) -> anyhow::Result<usize> {
            Ok(self.items.lock().unwrap().len())
        }
    }

    #[test]
    fn test_memory_item_creation() {
        let item = MemoryItem::new("Test memory")
            .with_importance(0.8)
            .with_tag("test")
            .with_type(MemoryType::Semantic);

        assert_eq!(item.content, "Test memory");
        assert_eq!(item.importance, 0.8);
        assert_eq!(item.tags, vec!["test"]);
        assert_eq!(item.memory_type, MemoryType::Semantic);
    }

    #[test]
    fn test_memory_item_relevance() {
        let item = MemoryItem::new("Test").with_importance(0.9);
        let score = item.relevance_score();

        // Should be high for recent, important memory
        assert!(score > 0.6);
    }

    #[test]
    fn test_relevance_config_defaults() {
        let config = RelevanceConfig::default();
        assert_eq!(config.decay_days, 30.0);
        assert_eq!(config.importance_weight, 0.7);
        assert_eq!(config.recency_weight, 0.3);
    }

    #[test]
    fn test_memory_config_defaults() {
        let config = MemoryConfig::default();
        assert_eq!(config.max_short_term, 100);
        assert_eq!(config.max_working, 10);
        assert_eq!(config.relevance.decay_days, 30.0);
    }

    #[test]
    fn test_memory_config_serde_roundtrip() {
        let config = MemoryConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: MemoryConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.max_short_term, config.max_short_term);
        assert_eq!(parsed.max_working, config.max_working);
        assert_eq!(parsed.relevance.decay_days, config.relevance.decay_days);
    }

    #[test]
    fn test_agent_memory_with_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.5,
                recency_weight: 0.5,
            },
            max_short_term: 50,
            max_working: 5,
        };
        let memory = AgentMemory::with_config(Arc::new(TestMemoryStore::new()), config);
        assert_eq!(memory.max_short_term, 50);
        assert_eq!(memory.max_working, 5);
        assert_eq!(memory.relevance_config.decay_days, 7.0);
    }

    #[test]
    fn test_agent_memory_score_uses_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.9,
                recency_weight: 0.1,
            },
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(TestMemoryStore::new()), config);

        let item = MemoryItem::new("Test").with_importance(1.0);
        let now = Utc::now();
        let score = memory.score(&item, now);

        // With importance_weight=0.9, a brand new item with importance=1.0
        // should score close to 0.9 + 0.1 = 1.0 (decay ~1.0 for recent items)
        assert!(score > 0.95, "Score was {}", score);
    }

    #[tokio::test]
    async fn test_in_memory_store() {
        let store = TestMemoryStore::new();

        let item = MemoryItem::new("Test memory").with_tag("test");
        store.store(item.clone()).await.unwrap();

        let retrieved = store.retrieve(&item.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test memory");
    }

    #[tokio::test]
    async fn test_memory_search() {
        let store = TestMemoryStore::new();

        store
            .store(MemoryItem::new("How to create a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to delete a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to create a directory").with_tag("dir"))
            .await
            .unwrap();

        let results = store.search("create", 10).await.unwrap();
        assert_eq!(results.len(), 2);

        let results = store
            .search_by_tags(&["file".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_agent_memory() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));

        // Remember success
        memory
            .remember_success("Create a file", &["write".to_string()], "File created")
            .await
            .unwrap();

        // Remember failure
        memory
            .remember_failure("Delete file", "Permission denied", &["bash".to_string()])
            .await
            .unwrap();

        // Recall
        let results = memory.recall_similar("create", 10).await.unwrap();
        assert!(!results.is_empty());

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 2);
    }

    #[tokio::test]
    async fn test_working_memory() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));

        let item = MemoryItem::new("Active task").with_type(MemoryType::Working);
        memory.add_to_working(item).await.unwrap();

        let working = memory.get_working().await;
        assert_eq!(working.len(), 1);

        memory.clear_working().await;
        let working = memory.get_working().await;
        assert_eq!(working.len(), 0);
    }

    /// Short-term memory is a bounded FIFO: the oldest item is evicted once
    /// `max_short_term` is exceeded, while long-term storage keeps everything.
    #[tokio::test]
    async fn test_short_term_evicts_oldest() {
        let config = MemoryConfig {
            max_short_term: 2,
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(TestMemoryStore::new()), config);

        for i in 0..3 {
            memory
                .remember(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }

        let short_term = memory.get_short_term().await;
        let contents: Vec<&str> = short_term.iter().map(|i| i.content.as_str()).collect();
        assert_eq!(contents, vec!["item 1", "item 2"]);
        assert_eq!(memory.short_term_count().await, 2);
        assert_eq!(memory.stats().await.unwrap().long_term_count, 3);
    }

    /// When working memory overflows, the lowest-scoring item is dropped and
    /// the survivors are ordered by descending relevance.
    #[tokio::test]
    async fn test_working_memory_keeps_most_relevant() {
        let config = MemoryConfig {
            max_working: 2,
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(TestMemoryStore::new()), config);

        for (content, importance) in [("low", 0.1), ("high", 0.9), ("mid", 0.5)] {
            memory
                .add_to_working(MemoryItem::new(content).with_importance(importance))
                .await
                .unwrap();
        }

        let working = memory.get_working().await;
        let contents: Vec<&str> = working.iter().map(|i| i.content.as_str()).collect();
        assert_eq!(contents, vec!["high", "mid"]);
    }
}
772
773#[cfg(test)]
774mod extra_memory_tests {
775    use super::*;
776
777    /// Simple in-memory store for testing
778    struct TestMemoryStore {
779        items: std::sync::Mutex<Vec<MemoryItem>>,
780    }
781
782    impl TestMemoryStore {
783        fn new() -> Self {
784            Self { items: std::sync::Mutex::new(Vec::new()) }
785        }
786    }
787
788    #[async_trait::async_trait]
789    impl MemoryStore for TestMemoryStore {
790        async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
791            self.items.lock().unwrap().push(item);
792            Ok(())
793        }
794        async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
795            Ok(self.items.lock().unwrap().iter().find(|i| i.id == id).cloned())
796        }
797        async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
798            let items = self.items.lock().unwrap();
799            let query_lower = query.to_lowercase();
800            Ok(items.iter().filter(|i| i.content.to_lowercase().contains(&query_lower)).take(limit).cloned().collect())
801        }
802        async fn search_by_tags(&self, tags: &[String], limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
803            let items = self.items.lock().unwrap();
804            Ok(items.iter().filter(|i| tags.iter().any(|t| i.tags.contains(t))).take(limit).cloned().collect())
805        }
806        async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
807            let items = self.items.lock().unwrap();
808            let mut sorted: Vec<_> = items.iter().cloned().collect();
809            sorted.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
810            sorted.truncate(limit);
811            Ok(sorted)
812        }
813        async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
814            let items = self.items.lock().unwrap();
815            Ok(items.iter().filter(|i| i.importance >= threshold).take(limit).cloned().collect())
816        }
817        async fn delete(&self, id: &str) -> anyhow::Result<()> {
818            self.items.lock().unwrap().retain(|i| i.id != id);
819            Ok(())
820        }
821        async fn clear(&self) -> anyhow::Result<()> {
822            self.items.lock().unwrap().clear();
823            Ok(())
824        }
825        async fn count(&self) -> anyhow::Result<usize> {
826            Ok(self.items.lock().unwrap().len())
827        }
828    }
829
830
831    // ========================================================================
832    // MemoryItem builder methods
833    // ========================================================================
834
835    #[test]
836    fn test_memory_item_with_metadata() {
837        let item = MemoryItem::new("test")
838            .with_metadata("key1", "value1")
839            .with_metadata("key2", "value2");
840        assert_eq!(item.metadata.get("key1").unwrap(), "value1");
841        assert_eq!(item.metadata.get("key2").unwrap(), "value2");
842    }
843
844    #[test]
845    fn test_memory_item_with_tags_vec() {
846        let item = MemoryItem::new("test").with_tags(vec![
847            "a".to_string(),
848            "b".to_string(),
849            "c".to_string(),
850        ]);
851        assert_eq!(item.tags.len(), 3);
852    }
853
854    #[test]
855    fn test_memory_item_importance_clamped() {
856        let item_high = MemoryItem::new("test").with_importance(1.5);
857        assert_eq!(item_high.importance, 1.0);
858
859        let item_low = MemoryItem::new("test").with_importance(-0.5);
860        assert_eq!(item_low.importance, 0.0);
861    }
862
863    #[test]
864    fn test_memory_item_record_access() {
865        let mut item = MemoryItem::new("test");
866        assert_eq!(item.access_count, 0);
867        assert!(item.last_accessed.is_none());
868
869        item.record_access();
870        assert_eq!(item.access_count, 1);
871        assert!(item.last_accessed.is_some());
872
873        item.record_access();
874        assert_eq!(item.access_count, 2);
875    }
876
877    #[test]
878    fn test_memory_item_all_types() {
879        let episodic = MemoryItem::new("e").with_type(MemoryType::Episodic);
880        assert_eq!(episodic.memory_type, MemoryType::Episodic);
881
882        let semantic = MemoryItem::new("s").with_type(MemoryType::Semantic);
883        assert_eq!(semantic.memory_type, MemoryType::Semantic);
884
885        let procedural = MemoryItem::new("p").with_type(MemoryType::Procedural);
886        assert_eq!(procedural.memory_type, MemoryType::Procedural);
887
888        let working = MemoryItem::new("w").with_type(MemoryType::Working);
889        assert_eq!(working.memory_type, MemoryType::Working);
890    }
891
892    #[test]
893    fn test_memory_item_default_type_is_episodic() {
894        let item = MemoryItem::new("test");
895        assert_eq!(item.memory_type, MemoryType::Episodic);
896    }
897
898    // ========================================================================
899    // TestMemoryStore
900    // ========================================================================
901
902    #[tokio::test]
903    async fn test_in_memory_store_retrieve_nonexistent() {
904        let store = TestMemoryStore::new();
905        let result = store.retrieve("nonexistent").await.unwrap();
906        assert!(result.is_none());
907    }
908
909    #[tokio::test]
910    async fn test_in_memory_store_delete() {
911        let store = TestMemoryStore::new();
912        let item = MemoryItem::new("to delete");
913        let id = item.id.clone();
914        store.store(item).await.unwrap();
915        assert_eq!(store.count().await.unwrap(), 1);
916
917        store.delete(&id).await.unwrap();
918        assert_eq!(store.count().await.unwrap(), 0);
919    }
920
921    #[tokio::test]
922    async fn test_in_memory_store_clear() {
923        let store = TestMemoryStore::new();
924        for i in 0..5 {
925            store
926                .store(MemoryItem::new(format!("item {}", i)))
927                .await
928                .unwrap();
929        }
930        assert_eq!(store.count().await.unwrap(), 5);
931
932        store.clear().await.unwrap();
933        assert_eq!(store.count().await.unwrap(), 0);
934    }
935
936    #[tokio::test]
937    async fn test_in_memory_store_get_recent() {
938        let store = TestMemoryStore::new();
939        for i in 0..5 {
940            store
941                .store(MemoryItem::new(format!("item {}", i)))
942                .await
943                .unwrap();
944        }
945        let recent = store.get_recent(3).await.unwrap();
946        assert_eq!(recent.len(), 3);
947    }
948
949    #[tokio::test]
950    async fn test_in_memory_store_get_important() {
951        let store = TestMemoryStore::new();
952        store
953            .store(MemoryItem::new("low").with_importance(0.2))
954            .await
955            .unwrap();
956        store
957            .store(MemoryItem::new("medium").with_importance(0.5))
958            .await
959            .unwrap();
960        store
961            .store(MemoryItem::new("high").with_importance(0.9))
962            .await
963            .unwrap();
964
965        let important = store.get_important(0.7, 10).await.unwrap();
966        assert_eq!(important.len(), 1);
967        assert_eq!(important[0].content, "high");
968    }
969
970    #[tokio::test]
971    async fn test_in_memory_store_search_case_insensitive() {
972        let store = TestMemoryStore::new();
973        store
974            .store(MemoryItem::new("How to CREATE a file"))
975            .await
976            .unwrap();
977        let results = store.search("create", 10).await.unwrap();
978        assert_eq!(results.len(), 1);
979    }
980
981    // ========================================================================
982    // AgentMemory
983    // ========================================================================
984
985    #[tokio::test]
986    async fn test_agent_memory_short_term() {
987        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
988        memory.remember(MemoryItem::new("item 1")).await.unwrap();
989        memory.remember(MemoryItem::new("item 2")).await.unwrap();
990
991        let short_term = memory.get_short_term().await;
992        assert_eq!(short_term.len(), 2);
993
994        memory.clear_short_term().await;
995        let short_term = memory.get_short_term().await;
996        assert_eq!(short_term.len(), 0);
997    }
998
999    #[tokio::test]
1000    async fn test_agent_memory_short_term_count() {
1001        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
1002        assert_eq!(memory.short_term_count().await, 0);
1003        memory.remember(MemoryItem::new("item")).await.unwrap();
1004        assert_eq!(memory.short_term_count().await, 1);
1005    }
1006
1007    #[tokio::test]
1008    async fn test_agent_memory_working_count() {
1009        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
1010        assert_eq!(memory.working_count().await, 0);
1011        memory
1012            .add_to_working(MemoryItem::new("task"))
1013            .await
1014            .unwrap();
1015        assert_eq!(memory.working_count().await, 1);
1016    }
1017
1018    #[tokio::test]
1019    async fn test_agent_memory_recall_by_tags() {
1020        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
1021        memory
1022            .remember_success("create file", &["write".to_string()], "ok")
1023            .await
1024            .unwrap();
1025        memory
1026            .remember_failure("delete file", "denied", &["bash".to_string()])
1027            .await
1028            .unwrap();
1029
1030        let successes = memory
1031            .recall_by_tags(&["success".to_string()], 10)
1032            .await
1033            .unwrap();
1034        assert_eq!(successes.len(), 1);
1035
1036        let failures = memory
1037            .recall_by_tags(&["failure".to_string()], 10)
1038            .await
1039            .unwrap();
1040        assert_eq!(failures.len(), 1);
1041    }
1042
1043    #[tokio::test]
1044    async fn test_agent_memory_get_recent() {
1045        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
1046        for i in 0..5 {
1047            memory
1048                .remember(MemoryItem::new(format!("item {}", i)))
1049                .await
1050                .unwrap();
1051        }
1052        let recent = memory.get_recent(3).await.unwrap();
1053        assert_eq!(recent.len(), 3);
1054    }
1055
1056    #[tokio::test]
1057    async fn test_agent_memory_store_accessor() {
1058        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
1059        memory.remember(MemoryItem::new("test")).await.unwrap();
1060        let count = memory.store().count().await.unwrap();
1061        assert_eq!(count, 1);
1062    }
1063
1064    #[tokio::test]
1065    async fn test_agent_memory_stats_all_fields() {
1066        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
1067        memory.remember(MemoryItem::new("long term")).await.unwrap();
1068        memory
1069            .add_to_working(MemoryItem::new("working"))
1070            .await
1071            .unwrap();
1072
1073        let stats = memory.stats().await.unwrap();
1074        assert_eq!(stats.long_term_count, 1);
1075        assert_eq!(stats.short_term_count, 1); // remember also adds to short_term
1076        assert_eq!(stats.working_count, 1);
1077    }
1078
1079    #[tokio::test]
1080    async fn test_agent_memory_working_overflow_trims() {
1081        let store = Arc::new(TestMemoryStore::new());
1082        let memory = AgentMemory {
1083            store,
1084            short_term: Arc::new(RwLock::new(VecDeque::new())),
1085            working: Arc::new(RwLock::new(Vec::new())),
1086            max_short_term: 100,
1087            max_working: 3, // Small limit
1088            relevance_config: RelevanceConfig::default(),
1089        };
1090
1091        for i in 0..5 {
1092            memory
1093                .add_to_working(
1094                    MemoryItem::new(format!("task {}", i)).with_importance(i as f32 * 0.2),
1095                )
1096                .await
1097                .unwrap();
1098        }
1099
1100        let working = memory.get_working().await;
1101        assert_eq!(working.len(), 3); // Trimmed to max_working
1102    }
1103}