Skip to main content

a3s_code_core/
memory.rs

1//! Memory and learning system for the agent
2//!
3//! This module provides memory storage, recall, and learning capabilities
4//! to enable the agent to learn from past experiences and improve over time.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12// ============================================================================
13// Configuration
14// ============================================================================
15
/// Configuration for relevance scoring
///
/// `AgentMemory` uses these parameters to rank items as
/// `importance * importance_weight + exp(-age_days / decay_days) * recency_weight`.
/// Serialized with camelCase keys; any missing key falls back to its default.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RelevanceConfig {
    /// Exponential decay time constant in days (default: 30.0)
    ///
    /// The recency factor is `exp(-age_days / decay_days)`, so after
    /// `decay_days` days it has fallen to about 0.37 (1/e). This is an
    /// e-folding time, not a half-life: the factor reaches 0.5 after roughly
    /// `0.693 * decay_days` days.
    #[serde(default = "RelevanceConfig::default_decay_days")]
    pub decay_days: f32,
    /// Weight for importance factor (default: 0.7)
    #[serde(default = "RelevanceConfig::default_importance_weight")]
    pub importance_weight: f32,
    /// Weight for recency factor (default: 0.3)
    #[serde(default = "RelevanceConfig::default_recency_weight")]
    pub recency_weight: f32,
}
30
31impl RelevanceConfig {
32    fn default_decay_days() -> f32 {
33        30.0
34    }
35    fn default_importance_weight() -> f32 {
36        0.7
37    }
38    fn default_recency_weight() -> f32 {
39        0.3
40    }
41}
42
43impl Default for RelevanceConfig {
44    fn default() -> Self {
45        Self {
46            decay_days: 30.0,
47            importance_weight: 0.7,
48            recency_weight: 0.3,
49        }
50    }
51}
52
/// Configuration for the agent memory system
///
/// Serialized with camelCase keys (`relevance`, `maxShortTerm`, `maxWorking`);
/// any missing key falls back to its default.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MemoryConfig {
    /// Relevance scoring parameters
    #[serde(default)]
    pub relevance: RelevanceConfig,
    /// Maximum short-term memory items (default: 100); the oldest entries are
    /// evicted once this is exceeded
    #[serde(default = "MemoryConfig::default_max_short_term")]
    pub max_short_term: usize,
    /// Maximum working memory items (default: 10); the least relevant entries
    /// are dropped once this is exceeded
    #[serde(default = "MemoryConfig::default_max_working")]
    pub max_working: usize,
}
67
68impl MemoryConfig {
69    fn default_max_short_term() -> usize {
70        100
71    }
72    fn default_max_working() -> usize {
73        10
74    }
75}
76
77impl Default for MemoryConfig {
78    fn default() -> Self {
79        Self {
80            relevance: RelevanceConfig::default(),
81            max_short_term: 100,
82            max_working: 10,
83        }
84    }
85}
86
87// ============================================================================
88// Memory Item
89// ============================================================================
90
/// A single memory item
///
/// Create one with [`MemoryItem::new`] and refine it with the `with_*`
/// builder methods. Serialized field names are the Rust field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryItem {
    /// Unique identifier (a UUID v4 string when created via `new`)
    pub id: String,
    /// Memory content
    pub content: String,
    /// When this memory was created
    pub timestamp: DateTime<Utc>,
    /// Importance score (0.0 - 1.0); `with_importance` clamps into this range,
    /// but direct writes to this public field are not checked
    pub importance: f32,
    /// Tags for categorization
    pub tags: Vec<String>,
    /// Memory type
    pub memory_type: MemoryType,
    /// Associated metadata
    pub metadata: HashMap<String, String>,
    /// Number of times this memory was accessed
    pub access_count: u32,
    /// Last access time
    pub last_accessed: Option<DateTime<Utc>>,
    /// Cached lowercase content for fast substring search
    ///
    /// Skipped by serde, so a freshly deserialized item has an empty cache and
    /// will not match non-empty search queries until the cache is rebuilt
    /// (`FileStore` does this when loading). The cache is also not refreshed
    /// automatically if `content` is modified after construction.
    #[serde(skip)]
    pub content_lower: String,
}
116
117impl MemoryItem {
118    /// Create a new memory item
119    pub fn new(content: impl Into<String>) -> Self {
120        let content = content.into();
121        let content_lower = content.to_lowercase();
122        Self {
123            id: uuid::Uuid::new_v4().to_string(),
124            content,
125            timestamp: Utc::now(),
126            importance: 0.5,
127            tags: Vec::new(),
128            memory_type: MemoryType::Episodic,
129            metadata: HashMap::new(),
130            access_count: 0,
131            last_accessed: None,
132            content_lower,
133        }
134    }
135
136    /// Set importance
137    pub fn with_importance(mut self, importance: f32) -> Self {
138        self.importance = importance.clamp(0.0, 1.0);
139        self
140    }
141
142    /// Add tags
143    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
144        self.tags = tags;
145        self
146    }
147
148    /// Add a single tag
149    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
150        self.tags.push(tag.into());
151        self
152    }
153
154    /// Set memory type
155    pub fn with_type(mut self, memory_type: MemoryType) -> Self {
156        self.memory_type = memory_type;
157        self
158    }
159
160    /// Add metadata
161    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
162        self.metadata.insert(key.into(), value.into());
163        self
164    }
165
166    /// Record access
167    pub fn record_access(&mut self) {
168        self.access_count += 1;
169        self.last_accessed = Some(Utc::now());
170    }
171
172    /// Calculate relevance score at a given timestamp
173    ///
174    /// Use this variant in sort comparators to avoid repeated `Utc::now()` syscalls.
175    pub fn relevance_score_at(&self, now: DateTime<Utc>) -> f32 {
176        let age_seconds = (now - self.timestamp).num_seconds() as f32;
177        let age_days = age_seconds / 86400.0;
178
179        // Decay factor: memories lose relevance over time
180        let decay = (-age_days / 30.0).exp(); // 30-day half-life
181
182        // Combine importance and recency
183        self.importance * 0.7 + decay * 0.3
184    }
185
186    /// Calculate relevance score based on recency and importance
187    pub fn relevance_score(&self) -> f32 {
188        self.relevance_score_at(Utc::now())
189    }
190}
191
/// Type of memory
///
/// Serialized in snake_case: `"episodic"`, `"semantic"`, `"procedural"`,
/// `"working"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    /// Episodic memory (specific events); the default for `MemoryItem::new`
    /// and the type used for recorded failures
    Episodic,
    /// Semantic memory (facts and knowledge)
    Semantic,
    /// Procedural memory (how to do things); the type used for recorded successes
    Procedural,
    /// Working memory (temporary, active)
    Working,
}
205
206// ============================================================================
207// Memory Store Trait
208// ============================================================================
209
/// Trait for memory storage backends
///
/// Implementors must be `Send + Sync` so a store can be shared as an
/// `Arc<dyn MemoryStore>` (as `AgentMemory` does). Every method is async via
/// `async_trait`.
#[async_trait::async_trait]
pub trait MemoryStore: Send + Sync {
    /// Store a memory item
    ///
    /// NOTE(review): the in-file implementations append without checking for
    /// an existing item with the same `id`, so duplicates are possible.
    async fn store(&self, item: MemoryItem) -> anyhow::Result<()>;

    /// Retrieve a memory by ID, or `None` if no stored item has that `id`
    async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>>;

    /// Search memories by query, returning at most `limit` items
    ///
    /// The in-file implementations do a case-insensitive substring match on
    /// content and return the most relevant matches first.
    async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;

    /// Search memories carrying any of `tags`, returning at most `limit` items
    async fn search_by_tags(
        &self,
        tags: &[String],
        limit: usize,
    ) -> anyhow::Result<Vec<MemoryItem>>;

    /// Get up to `limit` memories, newest first
    async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;

    /// Get up to `limit` memories with importance `>= threshold`, most important first
    async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;

    /// Delete a memory by ID (no error if the ID is absent in the in-file implementations)
    async fn delete(&self, id: &str) -> anyhow::Result<()>;

    /// Clear all memories
    async fn clear(&self) -> anyhow::Result<()>;

    /// Get total memory count
    async fn count(&self) -> anyhow::Result<usize>;
}
244
245// ============================================================================
246// Shared Search/Sort Helpers (DRY)
247// ============================================================================
248
249/// Search memories by content substring, sorted by relevance
250fn search_memories(memories: &[MemoryItem], query: &str, limit: usize) -> Vec<MemoryItem> {
251    let query_lower = query.to_lowercase();
252    let mut results: Vec<_> = memories
253        .iter()
254        .filter(|m| m.content_lower.contains(&query_lower))
255        .cloned()
256        .collect();
257    sort_by_relevance(&mut results);
258    results.truncate(limit);
259    results
260}
261
262/// Search memories by tags, sorted by relevance
263fn search_memories_by_tags(
264    memories: &[MemoryItem],
265    tags: &[String],
266    limit: usize,
267) -> Vec<MemoryItem> {
268    let mut results: Vec<_> = memories
269        .iter()
270        .filter(|m| tags.iter().any(|tag| m.tags.contains(tag)))
271        .cloned()
272        .collect();
273    sort_by_relevance(&mut results);
274    results.truncate(limit);
275    results
276}
277
278/// Get recent memories sorted by timestamp (newest first)
279fn recent_memories(memories: &[MemoryItem], limit: usize) -> Vec<MemoryItem> {
280    let mut results: Vec<_> = memories.to_vec();
281    results.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
282    results.truncate(limit);
283    results
284}
285
286/// Get important memories above threshold, sorted by importance
287fn important_memories(memories: &[MemoryItem], threshold: f32, limit: usize) -> Vec<MemoryItem> {
288    let mut results: Vec<_> = memories
289        .iter()
290        .filter(|m| m.importance >= threshold)
291        .cloned()
292        .collect();
293    results.sort_by(|a, b| {
294        b.importance
295            .partial_cmp(&a.importance)
296            .unwrap_or(std::cmp::Ordering::Equal)
297    });
298    results.truncate(limit);
299    results
300}
301
302/// Sort memory items by relevance score (highest first)
303fn sort_by_relevance(items: &mut [MemoryItem]) {
304    let now = Utc::now();
305    items.sort_by(|a, b| {
306        b.relevance_score_at(now)
307            .partial_cmp(&a.relevance_score_at(now))
308            .unwrap_or(std::cmp::Ordering::Equal)
309    });
310}
311
312// ============================================================================
313// In-Memory Store
314// ============================================================================
315
/// Simple in-memory storage (for testing and development)
///
/// Items live in a `Vec` behind an `Arc<RwLock<..>>` and are never persisted.
/// Clones share the same underlying vector.
#[derive(Debug, Clone)]
pub struct InMemoryStore {
    /// All stored items, in insertion order.
    memories: Arc<RwLock<Vec<MemoryItem>>>,
}
321
322impl InMemoryStore {
323    /// Create a new in-memory store
324    pub fn new() -> Self {
325        Self {
326            memories: Arc::new(RwLock::new(Vec::new())),
327        }
328    }
329}
330
331impl Default for InMemoryStore {
332    fn default() -> Self {
333        Self::new()
334    }
335}
336
337#[async_trait::async_trait]
338impl MemoryStore for InMemoryStore {
339    async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
340        let mut memories = self.memories.write().await;
341        memories.push(item);
342        Ok(())
343    }
344
345    async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
346        let memories = self.memories.read().await;
347        Ok(memories.iter().find(|m| m.id == id).cloned())
348    }
349
350    async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
351        let memories = self.memories.read().await;
352        Ok(search_memories(&memories, query, limit))
353    }
354
355    async fn search_by_tags(
356        &self,
357        tags: &[String],
358        limit: usize,
359    ) -> anyhow::Result<Vec<MemoryItem>> {
360        let memories = self.memories.read().await;
361        Ok(search_memories_by_tags(&memories, tags, limit))
362    }
363
364    async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
365        let memories = self.memories.read().await;
366        Ok(recent_memories(&memories, limit))
367    }
368
369    async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
370        let memories = self.memories.read().await;
371        Ok(important_memories(&memories, threshold, limit))
372    }
373
374    async fn delete(&self, id: &str) -> anyhow::Result<()> {
375        let mut memories = self.memories.write().await;
376        memories.retain(|m| m.id != id);
377        Ok(())
378    }
379
380    async fn clear(&self) -> anyhow::Result<()> {
381        let mut memories = self.memories.write().await;
382        memories.clear();
383        Ok(())
384    }
385
386    async fn count(&self) -> anyhow::Result<usize> {
387        let memories = self.memories.read().await;
388        Ok(memories.len())
389    }
390}
391
392// ============================================================================
393// File-Based Store
394// ============================================================================
395
/// File-based persistent storage using JSONL format
///
/// Every item is also kept in memory, and reads are served from that copy.
/// Each mutating `MemoryStore` call rewrites the whole file, one JSON-encoded
/// `MemoryItem` per line. Clones share the same in-memory state and file path.
#[derive(Debug, Clone)]
pub struct FileStore {
    /// Path of the JSONL file backing this store.
    file_path: std::path::PathBuf,
    /// In-memory copy of every stored item.
    memories: Arc<RwLock<Vec<MemoryItem>>>,
}
402
403impl FileStore {
404    /// Create a new file-based store
405    ///
406    /// Note: This constructor performs blocking I/O to load existing memories.
407    /// For async contexts, consider using `FileStore::open()` instead.
408    pub fn new(file_path: impl Into<std::path::PathBuf>) -> anyhow::Result<Self> {
409        let file_path = file_path.into();
410
411        // Create parent directory if it doesn't exist
412        if let Some(parent) = file_path.parent() {
413            std::fs::create_dir_all(parent)?;
414        }
415
416        // Load existing memories from file
417        let memories = if file_path.exists() {
418            Self::load_from_file(&file_path)?
419        } else {
420            Vec::new()
421        };
422
423        Ok(Self {
424            file_path,
425            memories: Arc::new(RwLock::new(memories)),
426        })
427    }
428
429    /// Create a new file-based store asynchronously
430    pub async fn open(file_path: impl Into<std::path::PathBuf>) -> anyhow::Result<Self> {
431        let file_path = file_path.into();
432
433        // Create parent directory if it doesn't exist
434        if let Some(parent) = file_path.parent() {
435            tokio::fs::create_dir_all(parent).await?;
436        }
437
438        // Load existing memories from file
439        let memories = if file_path.exists() {
440            let content = tokio::fs::read_to_string(&file_path).await?;
441            Self::parse_jsonl(&content)?
442        } else {
443            Vec::new()
444        };
445
446        Ok(Self {
447            file_path,
448            memories: Arc::new(RwLock::new(memories)),
449        })
450    }
451
452    /// Load memories from JSONL file (blocking)
453    fn load_from_file(path: &std::path::Path) -> anyhow::Result<Vec<MemoryItem>> {
454        let content = std::fs::read_to_string(path)?;
455        Self::parse_jsonl(&content)
456    }
457
458    /// Parse JSONL content into memory items
459    fn parse_jsonl(content: &str) -> anyhow::Result<Vec<MemoryItem>> {
460        let mut memories = Vec::new();
461
462        for line in content.lines() {
463            if line.trim().is_empty() {
464                continue;
465            }
466            let mut item: MemoryItem = serde_json::from_str(line)?;
467            item.content_lower = item.content.to_lowercase();
468            memories.push(item);
469        }
470
471        Ok(memories)
472    }
473
474    /// Save all memories to JSONL file
475    async fn save_to_file(&self) -> anyhow::Result<()> {
476        let memories = self.memories.read().await;
477        let mut content = String::new();
478
479        for memory in memories.iter() {
480            let json = serde_json::to_string(memory)?;
481            content.push_str(&json);
482            content.push('\n');
483        }
484
485        // Write atomically using a temporary file
486        let temp_path = self.file_path.with_extension("tmp");
487        tokio::fs::write(&temp_path, content).await?;
488        tokio::fs::rename(&temp_path, &self.file_path).await?;
489
490        Ok(())
491    }
492}
493
494#[async_trait::async_trait]
495impl MemoryStore for FileStore {
496    async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
497        {
498            let mut memories = self.memories.write().await;
499            memories.push(item);
500        }
501        self.save_to_file().await
502    }
503
504    async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
505        let memories = self.memories.read().await;
506        Ok(memories.iter().find(|m| m.id == id).cloned())
507    }
508
509    async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
510        let memories = self.memories.read().await;
511        Ok(search_memories(&memories, query, limit))
512    }
513
514    async fn search_by_tags(
515        &self,
516        tags: &[String],
517        limit: usize,
518    ) -> anyhow::Result<Vec<MemoryItem>> {
519        let memories = self.memories.read().await;
520        Ok(search_memories_by_tags(&memories, tags, limit))
521    }
522
523    async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
524        let memories = self.memories.read().await;
525        Ok(recent_memories(&memories, limit))
526    }
527
528    async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
529        let memories = self.memories.read().await;
530        Ok(important_memories(&memories, threshold, limit))
531    }
532
533    async fn delete(&self, id: &str) -> anyhow::Result<()> {
534        {
535            let mut memories = self.memories.write().await;
536            memories.retain(|m| m.id != id);
537        }
538        self.save_to_file().await
539    }
540
541    async fn clear(&self) -> anyhow::Result<()> {
542        {
543            let mut memories = self.memories.write().await;
544            memories.clear();
545        }
546        self.save_to_file().await
547    }
548
549    async fn count(&self) -> anyhow::Result<usize> {
550        let memories = self.memories.read().await;
551        Ok(memories.len())
552    }
553}
554
555// ============================================================================
556// Agent Memory
557// ============================================================================
558
/// Agent memory system
///
/// Combines three tiers:
/// - long-term: a shared [`MemoryStore`] that every `remember*` call writes to;
/// - short-term: a bounded FIFO of recently remembered items, oldest evicted first;
/// - working: a small set of active items, trimmed to the most relevant ones
///   using the configured [`RelevanceConfig`].
///
/// Cloning is cheap. Clones share the same store and the same short-term and
/// working buffers, since all three are behind `Arc`.
#[derive(Clone)]
pub struct AgentMemory {
    /// Long-term memory store
    store: Arc<dyn MemoryStore>,
    /// Short-term memory (current session)
    short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
    /// Working memory (active context)
    working: Arc<RwLock<Vec<MemoryItem>>>,
    /// Maximum short-term memory size
    max_short_term: usize,
    /// Maximum working memory size
    max_working: usize,
    /// Relevance scoring configuration
    relevance_config: RelevanceConfig,
}
575
576impl AgentMemory {
577    /// Create a new agent memory system with default configuration
578    pub fn new(store: Arc<dyn MemoryStore>) -> Self {
579        Self::with_config(store, MemoryConfig::default())
580    }
581
582    /// Create a new agent memory system with custom configuration
583    pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
584        Self {
585            store,
586            short_term: Arc::new(RwLock::new(VecDeque::new())),
587            working: Arc::new(RwLock::new(Vec::new())),
588            max_short_term: config.max_short_term,
589            max_working: config.max_working,
590            relevance_config: config.relevance,
591        }
592    }
593
594    /// Create with in-memory store (for testing)
595    pub fn in_memory() -> Self {
596        Self::new(Arc::new(InMemoryStore::new()))
597    }
598
599    /// Calculate relevance score using this memory system's configuration
600    fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
601        let age_seconds = (now - item.timestamp).num_seconds() as f32;
602        let age_days = age_seconds / 86400.0;
603        let decay = (-age_days / self.relevance_config.decay_days).exp();
604        item.importance * self.relevance_config.importance_weight
605            + decay * self.relevance_config.recency_weight
606    }
607
608    /// Store a memory in long-term storage
609    pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
610        // Store in long-term
611        self.store.store(item.clone()).await?;
612
613        // Add to short-term
614        let mut short_term = self.short_term.write().await;
615        short_term.push_back(item);
616
617        // Trim if needed
618        if short_term.len() > self.max_short_term {
619            short_term.pop_front();
620        }
621
622        Ok(())
623    }
624
625    /// Remember a successful pattern
626    pub async fn remember_success(
627        &self,
628        prompt: &str,
629        tools_used: &[String],
630        result: &str,
631    ) -> anyhow::Result<()> {
632        let content = format!(
633            "Success: {}\nTools: {}\nResult: {}",
634            prompt,
635            tools_used.join(", "),
636            result
637        );
638
639        let item = MemoryItem::new(content)
640            .with_importance(0.8)
641            .with_tag("success")
642            .with_tag("pattern")
643            .with_type(MemoryType::Procedural)
644            .with_metadata("prompt", prompt)
645            .with_metadata("tools", tools_used.join(","));
646
647        self.remember(item).await
648    }
649
650    /// Remember a failure to avoid repeating
651    pub async fn remember_failure(
652        &self,
653        prompt: &str,
654        error: &str,
655        attempted_tools: &[String],
656    ) -> anyhow::Result<()> {
657        let content = format!(
658            "Failure: {}\nError: {}\nAttempted tools: {}",
659            prompt,
660            error,
661            attempted_tools.join(", ")
662        );
663
664        let item = MemoryItem::new(content)
665            .with_importance(0.9) // Failures are important to remember
666            .with_tag("failure")
667            .with_tag("avoid")
668            .with_type(MemoryType::Episodic)
669            .with_metadata("prompt", prompt)
670            .with_metadata("error", error);
671
672        self.remember(item).await
673    }
674
675    /// Recall similar past experiences
676    pub async fn recall_similar(
677        &self,
678        prompt: &str,
679        limit: usize,
680    ) -> anyhow::Result<Vec<MemoryItem>> {
681        self.store.search(prompt, limit).await
682    }
683
684    /// Recall by tags
685    pub async fn recall_by_tags(
686        &self,
687        tags: &[String],
688        limit: usize,
689    ) -> anyhow::Result<Vec<MemoryItem>> {
690        self.store.search_by_tags(tags, limit).await
691    }
692
693    /// Get recent memories
694    pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
695        self.store.get_recent(limit).await
696    }
697
698    /// Add to working memory
699    pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
700        let mut working = self.working.write().await;
701        working.push(item);
702
703        // Trim if needed (keep most relevant)
704        if working.len() > self.max_working {
705            let now = Utc::now();
706            working.sort_by(|a, b| {
707                self.score(b, now)
708                    .partial_cmp(&self.score(a, now))
709                    .unwrap_or(std::cmp::Ordering::Equal)
710            });
711            working.truncate(self.max_working);
712        }
713
714        Ok(())
715    }
716
717    /// Get working memory
718    pub async fn get_working(&self) -> Vec<MemoryItem> {
719        self.working.read().await.clone()
720    }
721
722    /// Clear working memory
723    pub async fn clear_working(&self) {
724        self.working.write().await.clear();
725    }
726
727    /// Get short-term memory
728    pub async fn get_short_term(&self) -> Vec<MemoryItem> {
729        self.short_term.read().await.iter().cloned().collect()
730    }
731
732    /// Clear short-term memory
733    pub async fn clear_short_term(&self) {
734        self.short_term.write().await.clear();
735    }
736
737    /// Get memory statistics
738    pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
739        let long_term_count = self.store.count().await?;
740        let short_term_count = self.short_term.read().await.len();
741        let working_count = self.working.read().await.len();
742
743        Ok(MemoryStats {
744            long_term_count,
745            short_term_count,
746            working_count,
747        })
748    }
749
750    /// Get access to the underlying store
751    pub fn store(&self) -> &Arc<dyn MemoryStore> {
752        &self.store
753    }
754
755    /// Get working memory count
756    pub async fn working_count(&self) -> usize {
757        self.working.read().await.len()
758    }
759
760    /// Get short-term memory count
761    pub async fn short_term_count(&self) -> usize {
762        self.short_term.read().await.len()
763    }
764}
765
/// Memory statistics
///
/// A point-in-time snapshot produced by `AgentMemory::stats`. The three counts
/// are read one after another, not atomically.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Number of long-term memories (the backing store's `count`)
    pub long_term_count: usize,
    /// Number of short-term memories (bounded by `max_short_term`)
    pub short_term_count: usize,
    /// Number of working memories (bounded by `max_working` after each insert)
    pub working_count: usize,
}
776
777// ============================================================================
778// Memory Context Provider
779// ============================================================================
780
/// Context provider that surfaces past memories (successes/failures) as context.
///
/// Wraps `AgentMemory` and implements the `ContextProvider` trait so that
/// session memory can be injected into the agent's context.
/// NOTE(review): the actual injection (e.g. into the system prompt) is done
/// by the consumer of `crate::context::ContextProvider`, which is not
/// visible here.
pub struct MemoryContextProvider {
    /// Memory queried for context and updated from `on_turn_complete`.
    memory: AgentMemory,
}
788
789impl MemoryContextProvider {
790    /// Create a new memory context provider
791    pub fn new(memory: AgentMemory) -> Self {
792        Self { memory }
793    }
794}
795
796#[async_trait::async_trait]
797impl crate::context::ContextProvider for MemoryContextProvider {
798    fn name(&self) -> &str {
799        "memory"
800    }
801
802    async fn query(
803        &self,
804        query: &crate::context::ContextQuery,
805    ) -> anyhow::Result<crate::context::ContextResult> {
806        let limit = query.max_results.min(5);
807        let items = self.memory.recall_similar(&query.query, limit).await?;
808
809        let mut result = crate::context::ContextResult::new("memory");
810        for item in items {
811            let relevance = item.relevance_score();
812            let token_count = item.content.len() / 4; // rough estimate
813            let context_item = crate::context::ContextItem::new(
814                &item.id,
815                crate::context::ContextType::Memory,
816                &item.content,
817            )
818            .with_relevance(relevance)
819            .with_token_count(token_count)
820            .with_source("memory");
821            result.add_item(context_item);
822        }
823
824        Ok(result)
825    }
826
827    async fn on_turn_complete(
828        &self,
829        _session_id: &str,
830        prompt: &str,
831        response: &str,
832    ) -> anyhow::Result<()> {
833        // Store the successful interaction as a memory
834        self.memory.remember_success(prompt, &[], response).await
835    }
836}
837
838// ============================================================================
839// Tests
840// ============================================================================
841
842#[cfg(test)]
843mod tests {
844    use super::*;
845
846    #[test]
847    fn test_memory_item_creation() {
848        let item = MemoryItem::new("Test memory")
849            .with_importance(0.8)
850            .with_tag("test")
851            .with_type(MemoryType::Semantic);
852
853        assert_eq!(item.content, "Test memory");
854        assert_eq!(item.importance, 0.8);
855        assert_eq!(item.tags, vec!["test"]);
856        assert_eq!(item.memory_type, MemoryType::Semantic);
857    }
858
859    #[test]
860    fn test_memory_item_relevance() {
861        let item = MemoryItem::new("Test").with_importance(0.9);
862        let score = item.relevance_score();
863
864        // Should be high for recent, important memory
865        assert!(score > 0.6);
866    }
867
868    #[test]
869    fn test_relevance_config_defaults() {
870        let config = RelevanceConfig::default();
871        assert_eq!(config.decay_days, 30.0);
872        assert_eq!(config.importance_weight, 0.7);
873        assert_eq!(config.recency_weight, 0.3);
874    }
875
876    #[test]
877    fn test_memory_config_defaults() {
878        let config = MemoryConfig::default();
879        assert_eq!(config.max_short_term, 100);
880        assert_eq!(config.max_working, 10);
881        assert_eq!(config.relevance.decay_days, 30.0);
882    }
883
884    #[test]
885    fn test_memory_config_serde_roundtrip() {
886        let config = MemoryConfig::default();
887        let json = serde_json::to_string(&config).unwrap();
888        let parsed: MemoryConfig = serde_json::from_str(&json).unwrap();
889        assert_eq!(parsed.max_short_term, config.max_short_term);
890        assert_eq!(parsed.max_working, config.max_working);
891        assert_eq!(parsed.relevance.decay_days, config.relevance.decay_days);
892    }
893
894    #[test]
895    fn test_agent_memory_with_config() {
896        let config = MemoryConfig {
897            relevance: RelevanceConfig {
898                decay_days: 7.0,
899                importance_weight: 0.5,
900                recency_weight: 0.5,
901            },
902            max_short_term: 50,
903            max_working: 5,
904        };
905        let memory = AgentMemory::with_config(
906            Arc::new(InMemoryStore::new()),
907            config,
908        );
909        assert_eq!(memory.max_short_term, 50);
910        assert_eq!(memory.max_working, 5);
911        assert_eq!(memory.relevance_config.decay_days, 7.0);
912    }
913
914    #[test]
915    fn test_agent_memory_score_uses_config() {
916        let config = MemoryConfig {
917            relevance: RelevanceConfig {
918                decay_days: 7.0,
919                importance_weight: 0.9,
920                recency_weight: 0.1,
921            },
922            ..Default::default()
923        };
924        let memory = AgentMemory::with_config(
925            Arc::new(InMemoryStore::new()),
926            config,
927        );
928
929        let item = MemoryItem::new("Test").with_importance(1.0);
930        let now = Utc::now();
931        let score = memory.score(&item, now);
932
933        // With importance_weight=0.9, a brand new item with importance=1.0
934        // should score close to 0.9 + 0.1 = 1.0 (decay ~1.0 for recent items)
935        assert!(score > 0.95, "Score was {}", score);
936    }
937
938    #[tokio::test]
939    async fn test_in_memory_store() {
940        let store = InMemoryStore::new();
941
942        let item = MemoryItem::new("Test memory").with_tag("test");
943        store.store(item.clone()).await.unwrap();
944
945        let retrieved = store.retrieve(&item.id).await.unwrap();
946        assert!(retrieved.is_some());
947        assert_eq!(retrieved.unwrap().content, "Test memory");
948    }
949
950    #[tokio::test]
951    async fn test_memory_search() {
952        let store = InMemoryStore::new();
953
954        store
955            .store(MemoryItem::new("How to create a file").with_tag("file"))
956            .await
957            .unwrap();
958        store
959            .store(MemoryItem::new("How to delete a file").with_tag("file"))
960            .await
961            .unwrap();
962        store
963            .store(MemoryItem::new("How to create a directory").with_tag("dir"))
964            .await
965            .unwrap();
966
967        let results = store.search("create", 10).await.unwrap();
968        assert_eq!(results.len(), 2);
969
970        let results = store
971            .search_by_tags(&["file".to_string()], 10)
972            .await
973            .unwrap();
974        assert_eq!(results.len(), 2);
975    }
976
977    #[tokio::test]
978    async fn test_agent_memory() {
979        let memory = AgentMemory::in_memory();
980
981        // Remember success
982        memory
983            .remember_success("Create a file", &["write".to_string()], "File created")
984            .await
985            .unwrap();
986
987        // Remember failure
988        memory
989            .remember_failure("Delete file", "Permission denied", &["bash".to_string()])
990            .await
991            .unwrap();
992
993        // Recall
994        let results = memory.recall_similar("create", 10).await.unwrap();
995        assert!(!results.is_empty());
996
997        let stats = memory.stats().await.unwrap();
998        assert_eq!(stats.long_term_count, 2);
999    }
1000
1001    #[tokio::test]
1002    async fn test_working_memory() {
1003        let memory = AgentMemory::in_memory();
1004
1005        let item = MemoryItem::new("Active task").with_type(MemoryType::Working);
1006        memory.add_to_working(item).await.unwrap();
1007
1008        let working = memory.get_working().await;
1009        assert_eq!(working.len(), 1);
1010
1011        memory.clear_working().await;
1012        let working = memory.get_working().await;
1013        assert_eq!(working.len(), 0);
1014    }
1015
1016    #[tokio::test]
1017    async fn test_file_store_basic() {
1018        let temp_dir = std::env::temp_dir();
1019        let test_file = temp_dir.join(format!("test_memory_{}.jsonl", uuid::Uuid::new_v4()));
1020
1021        // Create store
1022        let store = FileStore::new(&test_file).unwrap();
1023
1024        // Store items
1025        let item1 = MemoryItem::new("Test memory 1").with_tag("test");
1026        let item2 = MemoryItem::new("Test memory 2").with_tag("test");
1027
1028        store.store(item1.clone()).await.unwrap();
1029        store.store(item2.clone()).await.unwrap();
1030
1031        // Verify count
1032        assert_eq!(store.count().await.unwrap(), 2);
1033
1034        // Retrieve
1035        let retrieved = store.retrieve(&item1.id).await.unwrap();
1036        assert!(retrieved.is_some());
1037        assert_eq!(retrieved.unwrap().content, "Test memory 1");
1038
1039        // Clean up
1040        let _ = std::fs::remove_file(&test_file);
1041    }
1042
1043    #[tokio::test]
1044    async fn test_file_store_persistence() {
1045        let temp_dir = std::env::temp_dir();
1046        let test_file = temp_dir.join(format!(
1047            "test_memory_persist_{}.jsonl",
1048            uuid::Uuid::new_v4()
1049        ));
1050
1051        let item_id = {
1052            // Create store and add item
1053            let store = FileStore::new(&test_file).unwrap();
1054            let item = MemoryItem::new("Persistent memory").with_importance(0.9);
1055            let id = item.id.clone();
1056            store.store(item).await.unwrap();
1057            id
1058        };
1059
1060        // Create new store instance (simulating restart)
1061        let store2 = FileStore::new(&test_file).unwrap();
1062
1063        // Verify data persisted
1064        assert_eq!(store2.count().await.unwrap(), 1);
1065        let retrieved = store2.retrieve(&item_id).await.unwrap();
1066        assert!(retrieved.is_some());
1067        assert_eq!(retrieved.unwrap().content, "Persistent memory");
1068
1069        // Clean up
1070        let _ = std::fs::remove_file(&test_file);
1071    }
1072
1073    #[tokio::test]
1074    async fn test_file_store_search() {
1075        let temp_dir = std::env::temp_dir();
1076        let test_file = temp_dir.join(format!("test_memory_search_{}.jsonl", uuid::Uuid::new_v4()));
1077
1078        let store = FileStore::new(&test_file).unwrap();
1079
1080        // Store multiple items
1081        store
1082            .store(MemoryItem::new("How to create a file").with_tag("file"))
1083            .await
1084            .unwrap();
1085        store
1086            .store(MemoryItem::new("How to delete a file").with_tag("file"))
1087            .await
1088            .unwrap();
1089        store
1090            .store(MemoryItem::new("How to create a directory").with_tag("dir"))
1091            .await
1092            .unwrap();
1093
1094        // Search by content
1095        let results = store.search("create", 10).await.unwrap();
1096        assert_eq!(results.len(), 2);
1097
1098        // Search by tags
1099        let results = store
1100            .search_by_tags(&["file".to_string()], 10)
1101            .await
1102            .unwrap();
1103        assert_eq!(results.len(), 2);
1104
1105        // Clean up
1106        let _ = std::fs::remove_file(&test_file);
1107    }
1108
1109    #[tokio::test]
1110    async fn test_file_store_delete() {
1111        let temp_dir = std::env::temp_dir();
1112        let test_file = temp_dir.join(format!("test_memory_delete_{}.jsonl", uuid::Uuid::new_v4()));
1113
1114        let store = FileStore::new(&test_file).unwrap();
1115
1116        let item = MemoryItem::new("To be deleted");
1117        let item_id = item.id.clone();
1118        store.store(item).await.unwrap();
1119
1120        assert_eq!(store.count().await.unwrap(), 1);
1121
1122        // Delete
1123        store.delete(&item_id).await.unwrap();
1124        assert_eq!(store.count().await.unwrap(), 0);
1125
1126        // Verify persistence
1127        let store2 = FileStore::new(&test_file).unwrap();
1128        assert_eq!(store2.count().await.unwrap(), 0);
1129
1130        // Clean up
1131        let _ = std::fs::remove_file(&test_file);
1132    }
1133
1134    #[tokio::test]
1135    async fn test_file_store_clear() {
1136        let temp_dir = std::env::temp_dir();
1137        let test_file = temp_dir.join(format!("test_memory_clear_{}.jsonl", uuid::Uuid::new_v4()));
1138
1139        let store = FileStore::new(&test_file).unwrap();
1140
1141        // Store multiple items
1142        for i in 0..5 {
1143            store
1144                .store(MemoryItem::new(format!("Memory {}", i)))
1145                .await
1146                .unwrap();
1147        }
1148
1149        assert_eq!(store.count().await.unwrap(), 5);
1150
1151        // Clear
1152        store.clear().await.unwrap();
1153        assert_eq!(store.count().await.unwrap(), 0);
1154
1155        // Verify persistence
1156        let store2 = FileStore::new(&test_file).unwrap();
1157        assert_eq!(store2.count().await.unwrap(), 0);
1158
1159        // Clean up
1160        let _ = std::fs::remove_file(&test_file);
1161    }
1162}
1163
/// Additional unit tests for `MemoryItem` builder methods, `InMemoryStore`
/// operations, and `AgentMemory` accessors.
#[cfg(test)]
mod extra_memory_tests {
    use super::*;

    // ========================================================================
    // MemoryItem builder methods
    // ========================================================================

    #[test]
    fn test_memory_item_with_metadata() {
        let item = MemoryItem::new("test")
            .with_metadata("key1", "value1")
            .with_metadata("key2", "value2");
        assert_eq!(item.metadata.get("key1").unwrap(), "value1");
        assert_eq!(item.metadata.get("key2").unwrap(), "value2");
    }

    #[test]
    fn test_memory_item_with_tags_vec() {
        let item = MemoryItem::new("test").with_tags(vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(item.tags.len(), 3);
    }

    /// Importance outside [0.0, 1.0] is clamped to the nearest bound.
    #[test]
    fn test_memory_item_importance_clamped() {
        let item_high = MemoryItem::new("test").with_importance(1.5);
        assert_eq!(item_high.importance, 1.0);

        let item_low = MemoryItem::new("test").with_importance(-0.5);
        assert_eq!(item_low.importance, 0.0);
    }

    #[test]
    fn test_memory_item_record_access() {
        let mut item = MemoryItem::new("test");
        assert_eq!(item.access_count, 0);
        assert!(item.last_accessed.is_none());

        item.record_access();
        assert_eq!(item.access_count, 1);
        assert!(item.last_accessed.is_some());

        item.record_access();
        assert_eq!(item.access_count, 2);
    }

    #[test]
    fn test_memory_item_all_types() {
        let episodic = MemoryItem::new("e").with_type(MemoryType::Episodic);
        assert_eq!(episodic.memory_type, MemoryType::Episodic);

        let semantic = MemoryItem::new("s").with_type(MemoryType::Semantic);
        assert_eq!(semantic.memory_type, MemoryType::Semantic);

        let procedural = MemoryItem::new("p").with_type(MemoryType::Procedural);
        assert_eq!(procedural.memory_type, MemoryType::Procedural);

        let working = MemoryItem::new("w").with_type(MemoryType::Working);
        assert_eq!(working.memory_type, MemoryType::Working);
    }

    #[test]
    fn test_memory_item_default_type_is_episodic() {
        let item = MemoryItem::new("test");
        assert_eq!(item.memory_type, MemoryType::Episodic);
    }

    // ========================================================================
    // InMemoryStore
    // ========================================================================

    #[tokio::test]
    async fn test_in_memory_store_retrieve_nonexistent() {
        let store = InMemoryStore::new();
        let result = store.retrieve("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_in_memory_store_delete() {
        let store = InMemoryStore::new();
        let item = MemoryItem::new("to delete");
        let id = item.id.clone();
        store.store(item).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        store.delete(&id).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_in_memory_store_clear() {
        let store = InMemoryStore::new();
        for i in 0..5 {
            store
                .store(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 5);

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_recent() {
        let store = InMemoryStore::new();
        for i in 0..5 {
            store
                .store(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        // The requested limit caps the result size.
        let recent = store.get_recent(3).await.unwrap();
        assert_eq!(recent.len(), 3);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_important() {
        let store = InMemoryStore::new();
        store
            .store(MemoryItem::new("low").with_importance(0.2))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("medium").with_importance(0.5))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("high").with_importance(0.9))
            .await
            .unwrap();

        // Only the 0.9 item meets the 0.7 threshold.
        let important = store.get_important(0.7, 10).await.unwrap();
        assert_eq!(important.len(), 1);
        assert_eq!(important[0].content, "high");
    }

    #[tokio::test]
    async fn test_in_memory_store_search_case_insensitive() {
        let store = InMemoryStore::new();
        store
            .store(MemoryItem::new("How to CREATE a file"))
            .await
            .unwrap();
        let results = store.search("create", 10).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    // ========================================================================
    // AgentMemory
    // ========================================================================

    #[tokio::test]
    async fn test_agent_memory_short_term() {
        let memory = AgentMemory::in_memory();
        memory.remember(MemoryItem::new("item 1")).await.unwrap();
        memory.remember(MemoryItem::new("item 2")).await.unwrap();

        let short_term = memory.get_short_term().await;
        assert_eq!(short_term.len(), 2);

        memory.clear_short_term().await;
        let short_term = memory.get_short_term().await;
        assert_eq!(short_term.len(), 0);
    }

    #[tokio::test]
    async fn test_agent_memory_short_term_count() {
        let memory = AgentMemory::in_memory();
        assert_eq!(memory.short_term_count().await, 0);
        memory.remember(MemoryItem::new("item")).await.unwrap();
        assert_eq!(memory.short_term_count().await, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_working_count() {
        let memory = AgentMemory::in_memory();
        assert_eq!(memory.working_count().await, 0);
        memory
            .add_to_working(MemoryItem::new("task"))
            .await
            .unwrap();
        assert_eq!(memory.working_count().await, 1);
    }

    /// Successes and failures are recallable via the "success" / "failure" tags.
    #[tokio::test]
    async fn test_agent_memory_recall_by_tags() {
        let memory = AgentMemory::in_memory();
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let successes = memory
            .recall_by_tags(&["success".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(successes.len(), 1);

        let failures = memory
            .recall_by_tags(&["failure".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(failures.len(), 1);
    }

    #[tokio::test]
    async fn test_agent_memory_get_recent() {
        let memory = AgentMemory::in_memory();
        for i in 0..5 {
            memory
                .remember(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        let recent = memory.get_recent(3).await.unwrap();
        assert_eq!(recent.len(), 3);
    }

    #[tokio::test]
    async fn test_agent_memory_store_accessor() {
        let memory = AgentMemory::in_memory();
        memory.remember(MemoryItem::new("test")).await.unwrap();
        let count = memory.store().count().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_stats_all_fields() {
        let memory = AgentMemory::in_memory();
        memory.remember(MemoryItem::new("long term")).await.unwrap();
        memory
            .add_to_working(MemoryItem::new("working"))
            .await
            .unwrap();

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 1);
        assert_eq!(stats.short_term_count, 1); // remember also adds to short_term
        assert_eq!(stats.working_count, 1);
    }

    /// Adding more items than `max_working` keeps working memory at the limit.
    #[tokio::test]
    async fn test_agent_memory_working_overflow_trims() {
        // Build through the public constructor rather than a struct literal, so
        // the test does not depend on AgentMemory's private field layout.
        // Everything except the small working limit uses the default config.
        let config = MemoryConfig {
            max_working: 3, // Small limit
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);

        for i in 0..5 {
            memory
                .add_to_working(
                    MemoryItem::new(format!("task {}", i)).with_importance(i as f32 * 0.2),
                )
                .await
                .unwrap();
        }

        let working = memory.get_working().await;
        assert_eq!(working.len(), 3); // Trimmed to max_working
    }
}
1441
/// Tests for `FileStore::open` and the `FileStore::parse_jsonl` helper.
#[cfg(test)]
mod extra_memory_tests2 {
    use super::*;

    #[tokio::test]
    async fn test_file_store_open_creates_parent_dirs() {
        // Use a nested path that doesn't exist yet.
        let dir = tempfile::tempdir().unwrap();
        let path = dir
            .path()
            .join("nested")
            .join("deep")
            .join("memories.jsonl");
        let store = FileStore::open(&path).await.unwrap();

        // The behaviour named by this test: `open` must have created the
        // missing parent directories, not merely returned Ok.
        let parent = path.parent().unwrap();
        assert!(
            parent.is_dir(),
            "parent directory was not created: {}",
            parent.display()
        );

        // A fresh store starts with no memories.
        let all = store.search("", 100).await.unwrap();
        assert!(all.is_empty());
    }

    #[tokio::test]
    async fn test_file_store_open_loads_existing() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memories.jsonl");
        // Create a store and add a memory, which saves it to the file; the
        // store goes out of scope at the end of this block.
        {
            let store = FileStore::open(&path).await.unwrap();
            let item = MemoryItem::new("test memory");
            store.store(item).await.unwrap();
        }
        // Re-open and verify the memory persisted.
        let store = FileStore::open(&path).await.unwrap();
        let results = store.search("test", 10).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("test memory"));
    }

    #[tokio::test]
    async fn test_file_store_open_nonexistent_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nonexistent.jsonl");
        let store = FileStore::open(&path).await.unwrap();
        let all = store.search("", 100).await.unwrap();
        assert!(all.is_empty());
    }

    #[test]
    fn test_parse_jsonl_empty_string() {
        let result = FileStore::parse_jsonl("").unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_jsonl_empty_lines_skipped() {
        // Valid JSONL with blank lines before, between and after the records.
        let item = MemoryItem::new("hello");
        let json = serde_json::to_string(&item).unwrap();
        let content = format!("\n{}\n\n{}\n\n", json, json);
        let result = FileStore::parse_jsonl(&content).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_parse_jsonl_invalid_json_returns_error() {
        let result = FileStore::parse_jsonl("not valid json");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_jsonl_valid_single_line() {
        let item = MemoryItem::new("single");
        let json = serde_json::to_string(&item).unwrap();
        let result = FileStore::parse_jsonl(&json).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].content, "single");
    }
}