Skip to main content

a3s_code_core/
memory.rs

1//! Memory and learning system for the agent
2//!
3//! This module provides memory storage, recall, and learning capabilities
4//! to enable the agent to learn from past experiences and improve over time.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12// ============================================================================
13// Configuration
14// ============================================================================
15
/// Configuration for relevance scoring.
///
/// A memory's relevance is `importance * importance_weight + decay * recency_weight`
/// with `decay = exp(-age_days / decay_days)`.
/// Keys are camelCase when (de)serialized, and every field falls back to its
/// default when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RelevanceConfig {
    /// Exponential decay time constant in days (default: 30.0).
    ///
    /// The recency term shrinks by a factor of `e` (not 2) every `decay_days`,
    /// so this is an e-folding time rather than a true half-life. The
    /// half-life is `decay_days * ln 2`, about 20.8 days for the default.
    #[serde(default = "RelevanceConfig::default_decay_days")]
    pub decay_days: f32,
    /// Weight for importance factor (default: 0.7)
    #[serde(default = "RelevanceConfig::default_importance_weight")]
    pub importance_weight: f32,
    /// Weight for recency factor (default: 0.3)
    #[serde(default = "RelevanceConfig::default_recency_weight")]
    pub recency_weight: f32,
}
30
31impl RelevanceConfig {
32    fn default_decay_days() -> f32 {
33        30.0
34    }
35    fn default_importance_weight() -> f32 {
36        0.7
37    }
38    fn default_recency_weight() -> f32 {
39        0.3
40    }
41}
42
43impl Default for RelevanceConfig {
44    fn default() -> Self {
45        Self {
46            decay_days: 30.0,
47            importance_weight: 0.7,
48            recency_weight: 0.3,
49        }
50    }
51}
52
/// Configuration for the agent memory system.
///
/// Keys are camelCase when (de)serialized (`relevance`, `maxShortTerm`,
/// `maxWorking`), and every field falls back to its default when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MemoryConfig {
    /// Relevance scoring parameters used by `AgentMemory` (e.g. when trimming working memory)
    #[serde(default)]
    pub relevance: RelevanceConfig,
    /// Maximum short-term memory items (default: 100); the oldest items are evicted first
    #[serde(default = "MemoryConfig::default_max_short_term")]
    pub max_short_term: usize,
    /// Maximum working memory items (default: 10); the least relevant items are dropped first
    #[serde(default = "MemoryConfig::default_max_working")]
    pub max_working: usize,
}
67
68impl MemoryConfig {
69    fn default_max_short_term() -> usize {
70        100
71    }
72    fn default_max_working() -> usize {
73        10
74    }
75}
76
77impl Default for MemoryConfig {
78    fn default() -> Self {
79        Self {
80            relevance: RelevanceConfig::default(),
81            max_short_term: 100,
82            max_working: 10,
83        }
84    }
85}
86
87// ============================================================================
88// Memory Item
89// ============================================================================
90
/// A single memory item.
///
/// Build one with [`MemoryItem::new`] and the `with_*` builder methods so that
/// the `content_lower` search cache is populated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryItem {
    /// Unique identifier (a v4 UUID string when created via `new`)
    pub id: String,
    /// Memory content
    pub content: String,
    /// When this memory was created; drives the recency part of relevance scoring
    pub timestamp: DateTime<Utc>,
    /// Importance score (0.0 - 1.0); `with_importance` clamps to this range, but the
    /// field is public and is not re-validated if assigned directly
    pub importance: f32,
    /// Tags for categorization (matched exactly by tag search)
    pub tags: Vec<String>,
    /// Memory type
    pub memory_type: MemoryType,
    /// Associated metadata
    pub metadata: HashMap<String, String>,
    /// Number of times this memory was accessed
    pub access_count: u32,
    /// Last access time
    pub last_accessed: Option<DateTime<Utc>>,
    /// Cached lowercase content for fast substring search.
    ///
    /// Not serialized (`#[serde(skip)]`), so it is empty right after
    /// deserialization and must be recomputed from `content` (the JSONL loader
    /// does this). It is also NOT kept in sync if `content` is modified
    /// directly after construction.
    #[serde(skip)]
    pub content_lower: String,
}
116
117impl MemoryItem {
118    /// Create a new memory item
119    pub fn new(content: impl Into<String>) -> Self {
120        let content = content.into();
121        let content_lower = content.to_lowercase();
122        Self {
123            id: uuid::Uuid::new_v4().to_string(),
124            content,
125            timestamp: Utc::now(),
126            importance: 0.5,
127            tags: Vec::new(),
128            memory_type: MemoryType::Episodic,
129            metadata: HashMap::new(),
130            access_count: 0,
131            last_accessed: None,
132            content_lower,
133        }
134    }
135
136    /// Set importance
137    pub fn with_importance(mut self, importance: f32) -> Self {
138        self.importance = importance.clamp(0.0, 1.0);
139        self
140    }
141
142    /// Add tags
143    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
144        self.tags = tags;
145        self
146    }
147
148    /// Add a single tag
149    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
150        self.tags.push(tag.into());
151        self
152    }
153
154    /// Set memory type
155    pub fn with_type(mut self, memory_type: MemoryType) -> Self {
156        self.memory_type = memory_type;
157        self
158    }
159
160    /// Add metadata
161    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
162        self.metadata.insert(key.into(), value.into());
163        self
164    }
165
166    /// Record access
167    pub fn record_access(&mut self) {
168        self.access_count += 1;
169        self.last_accessed = Some(Utc::now());
170    }
171
172    /// Calculate relevance score at a given timestamp
173    ///
174    /// Use this variant in sort comparators to avoid repeated `Utc::now()` syscalls.
175    pub fn relevance_score_at(&self, now: DateTime<Utc>) -> f32 {
176        let age_seconds = (now - self.timestamp).num_seconds() as f32;
177        let age_days = age_seconds / 86400.0;
178
179        // Decay factor: memories lose relevance over time
180        let decay = (-age_days / 30.0).exp(); // 30-day half-life
181
182        // Combine importance and recency
183        self.importance * 0.7 + decay * 0.3
184    }
185
186    /// Calculate relevance score based on recency and importance
187    pub fn relevance_score(&self) -> f32 {
188        self.relevance_score_at(Utc::now())
189    }
190}
191
/// Type of memory.
///
/// Serialized in snake_case (`"episodic"`, `"semantic"`, `"procedural"`, `"working"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    /// Episodic memory (specific events); the default for `MemoryItem::new`
    Episodic,
    /// Semantic memory (facts and knowledge)
    Semantic,
    /// Procedural memory (how to do things)
    Procedural,
    /// Working memory (temporary, active)
    Working,
}
205
206// ============================================================================
207// Memory Store Trait
208// ============================================================================
209
/// Trait for memory storage backends.
///
/// Implementations must be `Send + Sync` so they can be shared behind an
/// `Arc<dyn MemoryStore>`.
///
/// NOTE(review): the per-method behavior notes below describe the two in-crate
/// implementations (`InMemoryStore`, `FileStore`). The trait itself does not
/// enforce them, and other backends may differ.
#[async_trait::async_trait]
pub trait MemoryStore: Send + Sync {
    /// Store a memory item.
    ///
    /// The in-crate stores append without checking whether the `id` already exists.
    async fn store(&self, item: MemoryItem) -> anyhow::Result<()>;

    /// Retrieve a memory by ID; `Ok(None)` when no item has that ID.
    async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>>;

    /// Search memories by query, returning at most `limit` items.
    ///
    /// The in-crate stores do a case-insensitive substring match on the content
    /// and order results by relevance, highest first.
    async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;

    /// Search memories carrying at least one of `tags`, returning at most `limit` items.
    async fn search_by_tags(
        &self,
        tags: &[String],
        limit: usize,
    ) -> anyhow::Result<Vec<MemoryItem>>;

    /// Get recent memories; the in-crate stores return up to `limit` items, newest first.
    async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;

    /// Get important memories.
    ///
    /// The in-crate stores return up to `limit` items with importance >= `threshold`,
    /// most important first.
    async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;

    /// Delete a memory; the in-crate stores treat an unknown ID as a no-op.
    async fn delete(&self, id: &str) -> anyhow::Result<()>;

    /// Clear all memories.
    async fn clear(&self) -> anyhow::Result<()>;

    /// Get total memory count.
    async fn count(&self) -> anyhow::Result<usize>;
}
244
245// ============================================================================
246// Shared Search/Sort Helpers (DRY)
247// ============================================================================
248
249/// Search memories by content substring, sorted by relevance
250fn search_memories(memories: &[MemoryItem], query: &str, limit: usize) -> Vec<MemoryItem> {
251    let query_lower = query.to_lowercase();
252    let mut results: Vec<_> = memories
253        .iter()
254        .filter(|m| m.content_lower.contains(&query_lower))
255        .cloned()
256        .collect();
257    sort_by_relevance(&mut results);
258    results.truncate(limit);
259    results
260}
261
262/// Search memories by tags, sorted by relevance
263fn search_memories_by_tags(
264    memories: &[MemoryItem],
265    tags: &[String],
266    limit: usize,
267) -> Vec<MemoryItem> {
268    let mut results: Vec<_> = memories
269        .iter()
270        .filter(|m| tags.iter().any(|tag| m.tags.contains(tag)))
271        .cloned()
272        .collect();
273    sort_by_relevance(&mut results);
274    results.truncate(limit);
275    results
276}
277
278/// Get recent memories sorted by timestamp (newest first)
279fn recent_memories(memories: &[MemoryItem], limit: usize) -> Vec<MemoryItem> {
280    let mut results: Vec<_> = memories.to_vec();
281    results.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
282    results.truncate(limit);
283    results
284}
285
286/// Get important memories above threshold, sorted by importance
287fn important_memories(memories: &[MemoryItem], threshold: f32, limit: usize) -> Vec<MemoryItem> {
288    let mut results: Vec<_> = memories
289        .iter()
290        .filter(|m| m.importance >= threshold)
291        .cloned()
292        .collect();
293    results.sort_by(|a, b| {
294        b.importance
295            .partial_cmp(&a.importance)
296            .unwrap_or(std::cmp::Ordering::Equal)
297    });
298    results.truncate(limit);
299    results
300}
301
302/// Sort memory items by relevance score (highest first)
303fn sort_by_relevance(items: &mut [MemoryItem]) {
304    let now = Utc::now();
305    items.sort_by(|a, b| {
306        b.relevance_score_at(now)
307            .partial_cmp(&a.relevance_score_at(now))
308            .unwrap_or(std::cmp::Ordering::Equal)
309    });
310}
311
312// ============================================================================
313// In-Memory Store
314// ============================================================================
315
316/// Simple in-memory storage (for testing and development)
317#[derive(Debug, Clone)]
318pub struct InMemoryStore {
319    memories: Arc<RwLock<Vec<MemoryItem>>>,
320}
321
322impl InMemoryStore {
323    /// Create a new in-memory store
324    pub fn new() -> Self {
325        Self {
326            memories: Arc::new(RwLock::new(Vec::new())),
327        }
328    }
329}
330
331impl Default for InMemoryStore {
332    fn default() -> Self {
333        Self::new()
334    }
335}
336
337#[async_trait::async_trait]
338impl MemoryStore for InMemoryStore {
339    async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
340        let mut memories = self.memories.write().await;
341        memories.push(item);
342        Ok(())
343    }
344
345    async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
346        let memories = self.memories.read().await;
347        Ok(memories.iter().find(|m| m.id == id).cloned())
348    }
349
350    async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
351        let memories = self.memories.read().await;
352        Ok(search_memories(&memories, query, limit))
353    }
354
355    async fn search_by_tags(
356        &self,
357        tags: &[String],
358        limit: usize,
359    ) -> anyhow::Result<Vec<MemoryItem>> {
360        let memories = self.memories.read().await;
361        Ok(search_memories_by_tags(&memories, tags, limit))
362    }
363
364    async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
365        let memories = self.memories.read().await;
366        Ok(recent_memories(&memories, limit))
367    }
368
369    async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
370        let memories = self.memories.read().await;
371        Ok(important_memories(&memories, threshold, limit))
372    }
373
374    async fn delete(&self, id: &str) -> anyhow::Result<()> {
375        let mut memories = self.memories.write().await;
376        memories.retain(|m| m.id != id);
377        Ok(())
378    }
379
380    async fn clear(&self) -> anyhow::Result<()> {
381        let mut memories = self.memories.write().await;
382        memories.clear();
383        Ok(())
384    }
385
386    async fn count(&self) -> anyhow::Result<usize> {
387        let memories = self.memories.read().await;
388        Ok(memories.len())
389    }
390}
391
392// ============================================================================
393// File-Based Store
394// ============================================================================
395
/// File-based persistent storage using JSONL format (one JSON `MemoryItem` per line).
///
/// All memories are held in memory. Every mutation (`store`, `delete`, `clear`)
/// rewrites the whole file through a temporary file plus a rename. Clones share
/// the same in-memory state through the `Arc`.
///
/// NOTE(review): nothing coordinates two separate `FileStore` instances (or
/// processes) that point at the same path; each would overwrite the other's writes.
#[derive(Debug, Clone)]
pub struct FileStore {
    file_path: std::path::PathBuf,
    memories: Arc<RwLock<Vec<MemoryItem>>>,
}
402
403impl FileStore {
404    /// Create a new file-based store
405    ///
406    /// Note: This constructor performs blocking I/O to load existing memories.
407    /// For async contexts, consider using `FileStore::open()` instead.
408    pub fn new(file_path: impl Into<std::path::PathBuf>) -> anyhow::Result<Self> {
409        let file_path = file_path.into();
410
411        // Create parent directory if it doesn't exist
412        if let Some(parent) = file_path.parent() {
413            std::fs::create_dir_all(parent)?;
414        }
415
416        // Load existing memories from file
417        let memories = if file_path.exists() {
418            Self::load_from_file(&file_path)?
419        } else {
420            Vec::new()
421        };
422
423        Ok(Self {
424            file_path,
425            memories: Arc::new(RwLock::new(memories)),
426        })
427    }
428
429    /// Create a new file-based store asynchronously
430    pub async fn open(file_path: impl Into<std::path::PathBuf>) -> anyhow::Result<Self> {
431        let file_path = file_path.into();
432
433        // Create parent directory if it doesn't exist
434        if let Some(parent) = file_path.parent() {
435            tokio::fs::create_dir_all(parent).await?;
436        }
437
438        // Load existing memories from file
439        let memories = if file_path.exists() {
440            let content = tokio::fs::read_to_string(&file_path).await?;
441            Self::parse_jsonl(&content)?
442        } else {
443            Vec::new()
444        };
445
446        Ok(Self {
447            file_path,
448            memories: Arc::new(RwLock::new(memories)),
449        })
450    }
451
452    /// Load memories from JSONL file (blocking)
453    fn load_from_file(path: &std::path::Path) -> anyhow::Result<Vec<MemoryItem>> {
454        let content = std::fs::read_to_string(path)?;
455        Self::parse_jsonl(&content)
456    }
457
458    /// Parse JSONL content into memory items
459    fn parse_jsonl(content: &str) -> anyhow::Result<Vec<MemoryItem>> {
460        let mut memories = Vec::new();
461
462        for line in content.lines() {
463            if line.trim().is_empty() {
464                continue;
465            }
466            let mut item: MemoryItem = serde_json::from_str(line)?;
467            item.content_lower = item.content.to_lowercase();
468            memories.push(item);
469        }
470
471        Ok(memories)
472    }
473
474    /// Save all memories to JSONL file
475    async fn save_to_file(&self) -> anyhow::Result<()> {
476        let memories = self.memories.read().await;
477        let mut content = String::new();
478
479        for memory in memories.iter() {
480            let json = serde_json::to_string(memory)?;
481            content.push_str(&json);
482            content.push('\n');
483        }
484
485        // Write atomically using a temporary file
486        let temp_path = self.file_path.with_extension("tmp");
487        tokio::fs::write(&temp_path, content).await?;
488        tokio::fs::rename(&temp_path, &self.file_path).await?;
489
490        Ok(())
491    }
492}
493
494#[async_trait::async_trait]
495impl MemoryStore for FileStore {
496    async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
497        {
498            let mut memories = self.memories.write().await;
499            memories.push(item);
500        }
501        self.save_to_file().await
502    }
503
504    async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
505        let memories = self.memories.read().await;
506        Ok(memories.iter().find(|m| m.id == id).cloned())
507    }
508
509    async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
510        let memories = self.memories.read().await;
511        Ok(search_memories(&memories, query, limit))
512    }
513
514    async fn search_by_tags(
515        &self,
516        tags: &[String],
517        limit: usize,
518    ) -> anyhow::Result<Vec<MemoryItem>> {
519        let memories = self.memories.read().await;
520        Ok(search_memories_by_tags(&memories, tags, limit))
521    }
522
523    async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
524        let memories = self.memories.read().await;
525        Ok(recent_memories(&memories, limit))
526    }
527
528    async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
529        let memories = self.memories.read().await;
530        Ok(important_memories(&memories, threshold, limit))
531    }
532
533    async fn delete(&self, id: &str) -> anyhow::Result<()> {
534        {
535            let mut memories = self.memories.write().await;
536            memories.retain(|m| m.id != id);
537        }
538        self.save_to_file().await
539    }
540
541    async fn clear(&self) -> anyhow::Result<()> {
542        {
543            let mut memories = self.memories.write().await;
544            memories.clear();
545        }
546        self.save_to_file().await
547    }
548
549    async fn count(&self) -> anyhow::Result<usize> {
550        let memories = self.memories.read().await;
551        Ok(memories.len())
552    }
553}
554
555// ============================================================================
556// Agent Memory
557// ============================================================================
558
/// Agent memory system.
///
/// Combines a pluggable long-term [`MemoryStore`] with two session-local tiers:
/// a bounded FIFO of recently remembered items (short-term) and a small,
/// relevance-trimmed set of active items (working). Clones share all three
/// tiers through their `Arc`s.
#[derive(Clone)]
pub struct AgentMemory {
    /// Long-term memory store
    store: Arc<dyn MemoryStore>,
    /// Short-term memory (current session); oldest entries are evicted first
    short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
    /// Working memory (active context); least relevant entries are dropped when full
    working: Arc<RwLock<Vec<MemoryItem>>>,
    /// Maximum short-term memory size
    max_short_term: usize,
    /// Maximum working memory size
    max_working: usize,
    /// Relevance scoring configuration used by this instance's scoring
    relevance_config: RelevanceConfig,
}
575
576impl AgentMemory {
577    /// Create a new agent memory system with default configuration
578    pub fn new(store: Arc<dyn MemoryStore>) -> Self {
579        Self::with_config(store, MemoryConfig::default())
580    }
581
582    /// Create a new agent memory system with custom configuration
583    pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
584        Self {
585            store,
586            short_term: Arc::new(RwLock::new(VecDeque::new())),
587            working: Arc::new(RwLock::new(Vec::new())),
588            max_short_term: config.max_short_term,
589            max_working: config.max_working,
590            relevance_config: config.relevance,
591        }
592    }
593
594    /// Create with in-memory store (for testing)
595    pub fn in_memory() -> Self {
596        Self::new(Arc::new(InMemoryStore::new()))
597    }
598
599    /// Calculate relevance score using this memory system's configuration
600    fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
601        let age_seconds = (now - item.timestamp).num_seconds() as f32;
602        let age_days = age_seconds / 86400.0;
603        let decay = (-age_days / self.relevance_config.decay_days).exp();
604        item.importance * self.relevance_config.importance_weight
605            + decay * self.relevance_config.recency_weight
606    }
607
608    /// Store a memory in long-term storage
609    pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
610        // Store in long-term
611        self.store.store(item.clone()).await?;
612
613        // Add to short-term
614        let mut short_term = self.short_term.write().await;
615        short_term.push_back(item);
616
617        // Trim if needed
618        if short_term.len() > self.max_short_term {
619            short_term.pop_front();
620        }
621
622        Ok(())
623    }
624
625    /// Remember a successful pattern
626    pub async fn remember_success(
627        &self,
628        prompt: &str,
629        tools_used: &[String],
630        result: &str,
631    ) -> anyhow::Result<()> {
632        let content = format!(
633            "Success: {}\nTools: {}\nResult: {}",
634            prompt,
635            tools_used.join(", "),
636            result
637        );
638
639        let item = MemoryItem::new(content)
640            .with_importance(0.8)
641            .with_tag("success")
642            .with_tag("pattern")
643            .with_type(MemoryType::Procedural)
644            .with_metadata("prompt", prompt)
645            .with_metadata("tools", tools_used.join(","));
646
647        self.remember(item).await
648    }
649
650    /// Remember a failure to avoid repeating
651    pub async fn remember_failure(
652        &self,
653        prompt: &str,
654        error: &str,
655        attempted_tools: &[String],
656    ) -> anyhow::Result<()> {
657        let content = format!(
658            "Failure: {}\nError: {}\nAttempted tools: {}",
659            prompt,
660            error,
661            attempted_tools.join(", ")
662        );
663
664        let item = MemoryItem::new(content)
665            .with_importance(0.9) // Failures are important to remember
666            .with_tag("failure")
667            .with_tag("avoid")
668            .with_type(MemoryType::Episodic)
669            .with_metadata("prompt", prompt)
670            .with_metadata("error", error);
671
672        self.remember(item).await
673    }
674
675    /// Recall similar past experiences
676    pub async fn recall_similar(
677        &self,
678        prompt: &str,
679        limit: usize,
680    ) -> anyhow::Result<Vec<MemoryItem>> {
681        self.store.search(prompt, limit).await
682    }
683
684    /// Recall by tags
685    pub async fn recall_by_tags(
686        &self,
687        tags: &[String],
688        limit: usize,
689    ) -> anyhow::Result<Vec<MemoryItem>> {
690        self.store.search_by_tags(tags, limit).await
691    }
692
693    /// Get recent memories
694    pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
695        self.store.get_recent(limit).await
696    }
697
698    /// Add to working memory
699    pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
700        let mut working = self.working.write().await;
701        working.push(item);
702
703        // Trim if needed (keep most relevant)
704        if working.len() > self.max_working {
705            let now = Utc::now();
706            working.sort_by(|a, b| {
707                self.score(b, now)
708                    .partial_cmp(&self.score(a, now))
709                    .unwrap_or(std::cmp::Ordering::Equal)
710            });
711            working.truncate(self.max_working);
712        }
713
714        Ok(())
715    }
716
717    /// Get working memory
718    pub async fn get_working(&self) -> Vec<MemoryItem> {
719        self.working.read().await.clone()
720    }
721
722    /// Clear working memory
723    pub async fn clear_working(&self) {
724        self.working.write().await.clear();
725    }
726
727    /// Get short-term memory
728    pub async fn get_short_term(&self) -> Vec<MemoryItem> {
729        self.short_term.read().await.iter().cloned().collect()
730    }
731
732    /// Clear short-term memory
733    pub async fn clear_short_term(&self) {
734        self.short_term.write().await.clear();
735    }
736
737    /// Get memory statistics
738    pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
739        let long_term_count = self.store.count().await?;
740        let short_term_count = self.short_term.read().await.len();
741        let working_count = self.working.read().await.len();
742
743        Ok(MemoryStats {
744            long_term_count,
745            short_term_count,
746            working_count,
747        })
748    }
749
750    /// Get access to the underlying store
751    pub fn store(&self) -> &Arc<dyn MemoryStore> {
752        &self.store
753    }
754
755    /// Get working memory count
756    pub async fn working_count(&self) -> usize {
757        self.working.read().await.len()
758    }
759
760    /// Get short-term memory count
761    pub async fn short_term_count(&self) -> usize {
762        self.short_term.read().await.len()
763    }
764}
765
/// Memory statistics, as returned by `AgentMemory::stats`.
///
/// Serialized with its Rust (snake_case) field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Number of long-term memories (as reported by the store's `count`)
    pub long_term_count: usize,
    /// Number of short-term memories
    pub short_term_count: usize,
    /// Number of working memories
    pub working_count: usize,
}
776
777// ============================================================================
778// Memory Context Provider
779// ============================================================================
780
/// Context provider that surfaces past memories (successes/failures) as context.
///
/// Wraps `AgentMemory` and implements the `ContextProvider` trait so that
/// session memory is automatically injected into the agent's system prompt.
/// It both reads memories (on `query`) and writes them (on `on_turn_complete`).
pub struct MemoryContextProvider {
    /// Memory system queried for context and updated after each turn
    memory: AgentMemory,
}
788
789impl MemoryContextProvider {
790    /// Create a new memory context provider
791    pub fn new(memory: AgentMemory) -> Self {
792        Self { memory }
793    }
794}
795
796#[async_trait::async_trait]
797impl crate::context::ContextProvider for MemoryContextProvider {
798    fn name(&self) -> &str {
799        "memory"
800    }
801
802    async fn query(
803        &self,
804        query: &crate::context::ContextQuery,
805    ) -> anyhow::Result<crate::context::ContextResult> {
806        let limit = query.max_results.min(5);
807        let items = self.memory.recall_similar(&query.query, limit).await?;
808
809        let mut result = crate::context::ContextResult::new("memory");
810        for item in items {
811            let relevance = item.relevance_score();
812            let token_count = item.content.len() / 4; // rough estimate
813            let context_item = crate::context::ContextItem::new(
814                &item.id,
815                crate::context::ContextType::Memory,
816                &item.content,
817            )
818            .with_relevance(relevance)
819            .with_token_count(token_count)
820            .with_source("memory");
821            result.add_item(context_item);
822        }
823
824        Ok(result)
825    }
826
827    async fn on_turn_complete(
828        &self,
829        _session_id: &str,
830        prompt: &str,
831        response: &str,
832    ) -> anyhow::Result<()> {
833        // Store the successful interaction as a memory
834        self.memory.remember_success(prompt, &[], response).await
835    }
836}
837
838// ============================================================================
839// Tests
840// ============================================================================
841
842#[cfg(test)]
843mod tests {
844    use super::*;
845
846    #[test]
847    fn test_memory_item_creation() {
848        let item = MemoryItem::new("Test memory")
849            .with_importance(0.8)
850            .with_tag("test")
851            .with_type(MemoryType::Semantic);
852
853        assert_eq!(item.content, "Test memory");
854        assert_eq!(item.importance, 0.8);
855        assert_eq!(item.tags, vec!["test"]);
856        assert_eq!(item.memory_type, MemoryType::Semantic);
857    }
858
859    #[test]
860    fn test_memory_item_relevance() {
861        let item = MemoryItem::new("Test").with_importance(0.9);
862        let score = item.relevance_score();
863
864        // Should be high for recent, important memory
865        assert!(score > 0.6);
866    }
867
868    #[test]
869    fn test_relevance_config_defaults() {
870        let config = RelevanceConfig::default();
871        assert_eq!(config.decay_days, 30.0);
872        assert_eq!(config.importance_weight, 0.7);
873        assert_eq!(config.recency_weight, 0.3);
874    }
875
876    #[test]
877    fn test_memory_config_defaults() {
878        let config = MemoryConfig::default();
879        assert_eq!(config.max_short_term, 100);
880        assert_eq!(config.max_working, 10);
881        assert_eq!(config.relevance.decay_days, 30.0);
882    }
883
884    #[test]
885    fn test_memory_config_serde_roundtrip() {
886        let config = MemoryConfig::default();
887        let json = serde_json::to_string(&config).unwrap();
888        let parsed: MemoryConfig = serde_json::from_str(&json).unwrap();
889        assert_eq!(parsed.max_short_term, config.max_short_term);
890        assert_eq!(parsed.max_working, config.max_working);
891        assert_eq!(parsed.relevance.decay_days, config.relevance.decay_days);
892    }
893
894    #[test]
895    fn test_agent_memory_with_config() {
896        let config = MemoryConfig {
897            relevance: RelevanceConfig {
898                decay_days: 7.0,
899                importance_weight: 0.5,
900                recency_weight: 0.5,
901            },
902            max_short_term: 50,
903            max_working: 5,
904        };
905        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);
906        assert_eq!(memory.max_short_term, 50);
907        assert_eq!(memory.max_working, 5);
908        assert_eq!(memory.relevance_config.decay_days, 7.0);
909    }
910
911    #[test]
912    fn test_agent_memory_score_uses_config() {
913        let config = MemoryConfig {
914            relevance: RelevanceConfig {
915                decay_days: 7.0,
916                importance_weight: 0.9,
917                recency_weight: 0.1,
918            },
919            ..Default::default()
920        };
921        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);
922
923        let item = MemoryItem::new("Test").with_importance(1.0);
924        let now = Utc::now();
925        let score = memory.score(&item, now);
926
927        // With importance_weight=0.9, a brand new item with importance=1.0
928        // should score close to 0.9 + 0.1 = 1.0 (decay ~1.0 for recent items)
929        assert!(score > 0.95, "Score was {}", score);
930    }
931
932    #[tokio::test]
933    async fn test_in_memory_store() {
934        let store = InMemoryStore::new();
935
936        let item = MemoryItem::new("Test memory").with_tag("test");
937        store.store(item.clone()).await.unwrap();
938
939        let retrieved = store.retrieve(&item.id).await.unwrap();
940        assert!(retrieved.is_some());
941        assert_eq!(retrieved.unwrap().content, "Test memory");
942    }
943
944    #[tokio::test]
945    async fn test_memory_search() {
946        let store = InMemoryStore::new();
947
948        store
949            .store(MemoryItem::new("How to create a file").with_tag("file"))
950            .await
951            .unwrap();
952        store
953            .store(MemoryItem::new("How to delete a file").with_tag("file"))
954            .await
955            .unwrap();
956        store
957            .store(MemoryItem::new("How to create a directory").with_tag("dir"))
958            .await
959            .unwrap();
960
961        let results = store.search("create", 10).await.unwrap();
962        assert_eq!(results.len(), 2);
963
964        let results = store
965            .search_by_tags(&["file".to_string()], 10)
966            .await
967            .unwrap();
968        assert_eq!(results.len(), 2);
969    }
970
971    #[tokio::test]
972    async fn test_agent_memory() {
973        let memory = AgentMemory::in_memory();
974
975        // Remember success
976        memory
977            .remember_success("Create a file", &["write".to_string()], "File created")
978            .await
979            .unwrap();
980
981        // Remember failure
982        memory
983            .remember_failure("Delete file", "Permission denied", &["bash".to_string()])
984            .await
985            .unwrap();
986
987        // Recall
988        let results = memory.recall_similar("create", 10).await.unwrap();
989        assert!(!results.is_empty());
990
991        let stats = memory.stats().await.unwrap();
992        assert_eq!(stats.long_term_count, 2);
993    }
994
995    #[tokio::test]
996    async fn test_working_memory() {
997        let memory = AgentMemory::in_memory();
998
999        let item = MemoryItem::new("Active task").with_type(MemoryType::Working);
1000        memory.add_to_working(item).await.unwrap();
1001
1002        let working = memory.get_working().await;
1003        assert_eq!(working.len(), 1);
1004
1005        memory.clear_working().await;
1006        let working = memory.get_working().await;
1007        assert_eq!(working.len(), 0);
1008    }
1009
1010    #[tokio::test]
1011    async fn test_file_store_basic() {
1012        let temp_dir = std::env::temp_dir();
1013        let test_file = temp_dir.join(format!("test_memory_{}.jsonl", uuid::Uuid::new_v4()));
1014
1015        // Create store
1016        let store = FileStore::new(&test_file).unwrap();
1017
1018        // Store items
1019        let item1 = MemoryItem::new("Test memory 1").with_tag("test");
1020        let item2 = MemoryItem::new("Test memory 2").with_tag("test");
1021
1022        store.store(item1.clone()).await.unwrap();
1023        store.store(item2.clone()).await.unwrap();
1024
1025        // Verify count
1026        assert_eq!(store.count().await.unwrap(), 2);
1027
1028        // Retrieve
1029        let retrieved = store.retrieve(&item1.id).await.unwrap();
1030        assert!(retrieved.is_some());
1031        assert_eq!(retrieved.unwrap().content, "Test memory 1");
1032
1033        // Clean up
1034        let _ = std::fs::remove_file(&test_file);
1035    }
1036
1037    #[tokio::test]
1038    async fn test_file_store_persistence() {
1039        let temp_dir = std::env::temp_dir();
1040        let test_file = temp_dir.join(format!(
1041            "test_memory_persist_{}.jsonl",
1042            uuid::Uuid::new_v4()
1043        ));
1044
1045        let item_id = {
1046            // Create store and add item
1047            let store = FileStore::new(&test_file).unwrap();
1048            let item = MemoryItem::new("Persistent memory").with_importance(0.9);
1049            let id = item.id.clone();
1050            store.store(item).await.unwrap();
1051            id
1052        };
1053
1054        // Create new store instance (simulating restart)
1055        let store2 = FileStore::new(&test_file).unwrap();
1056
1057        // Verify data persisted
1058        assert_eq!(store2.count().await.unwrap(), 1);
1059        let retrieved = store2.retrieve(&item_id).await.unwrap();
1060        assert!(retrieved.is_some());
1061        assert_eq!(retrieved.unwrap().content, "Persistent memory");
1062
1063        // Clean up
1064        let _ = std::fs::remove_file(&test_file);
1065    }
1066
1067    #[tokio::test]
1068    async fn test_file_store_search() {
1069        let temp_dir = std::env::temp_dir();
1070        let test_file = temp_dir.join(format!("test_memory_search_{}.jsonl", uuid::Uuid::new_v4()));
1071
1072        let store = FileStore::new(&test_file).unwrap();
1073
1074        // Store multiple items
1075        store
1076            .store(MemoryItem::new("How to create a file").with_tag("file"))
1077            .await
1078            .unwrap();
1079        store
1080            .store(MemoryItem::new("How to delete a file").with_tag("file"))
1081            .await
1082            .unwrap();
1083        store
1084            .store(MemoryItem::new("How to create a directory").with_tag("dir"))
1085            .await
1086            .unwrap();
1087
1088        // Search by content
1089        let results = store.search("create", 10).await.unwrap();
1090        assert_eq!(results.len(), 2);
1091
1092        // Search by tags
1093        let results = store
1094            .search_by_tags(&["file".to_string()], 10)
1095            .await
1096            .unwrap();
1097        assert_eq!(results.len(), 2);
1098
1099        // Clean up
1100        let _ = std::fs::remove_file(&test_file);
1101    }
1102
1103    #[tokio::test]
1104    async fn test_file_store_delete() {
1105        let temp_dir = std::env::temp_dir();
1106        let test_file = temp_dir.join(format!("test_memory_delete_{}.jsonl", uuid::Uuid::new_v4()));
1107
1108        let store = FileStore::new(&test_file).unwrap();
1109
1110        let item = MemoryItem::new("To be deleted");
1111        let item_id = item.id.clone();
1112        store.store(item).await.unwrap();
1113
1114        assert_eq!(store.count().await.unwrap(), 1);
1115
1116        // Delete
1117        store.delete(&item_id).await.unwrap();
1118        assert_eq!(store.count().await.unwrap(), 0);
1119
1120        // Verify persistence
1121        let store2 = FileStore::new(&test_file).unwrap();
1122        assert_eq!(store2.count().await.unwrap(), 0);
1123
1124        // Clean up
1125        let _ = std::fs::remove_file(&test_file);
1126    }
1127
1128    #[tokio::test]
1129    async fn test_file_store_clear() {
1130        let temp_dir = std::env::temp_dir();
1131        let test_file = temp_dir.join(format!("test_memory_clear_{}.jsonl", uuid::Uuid::new_v4()));
1132
1133        let store = FileStore::new(&test_file).unwrap();
1134
1135        // Store multiple items
1136        for i in 0..5 {
1137            store
1138                .store(MemoryItem::new(format!("Memory {}", i)))
1139                .await
1140                .unwrap();
1141        }
1142
1143        assert_eq!(store.count().await.unwrap(), 5);
1144
1145        // Clear
1146        store.clear().await.unwrap();
1147        assert_eq!(store.count().await.unwrap(), 0);
1148
1149        // Verify persistence
1150        let store2 = FileStore::new(&test_file).unwrap();
1151        assert_eq!(store2.count().await.unwrap(), 0);
1152
1153        // Clean up
1154        let _ = std::fs::remove_file(&test_file);
1155    }
1156}
1157
#[cfg(test)]
mod extra_memory_tests {
    //! Additional unit tests for the `MemoryItem` builders, `InMemoryStore`,
    //! and `AgentMemory`.

    use super::*;

    // ========================================================================
    // MemoryItem builder methods
    // ========================================================================

    #[test]
    fn test_memory_item_with_metadata() {
        let item = MemoryItem::new("test")
            .with_metadata("key1", "value1")
            .with_metadata("key2", "value2");
        assert_eq!(item.metadata.get("key1").unwrap(), "value1");
        assert_eq!(item.metadata.get("key2").unwrap(), "value2");
    }

    #[test]
    fn test_memory_item_with_tags_vec() {
        let item = MemoryItem::new("test").with_tags(vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(item.tags.len(), 3);
    }

    #[test]
    fn test_memory_item_importance_clamped() {
        // Out-of-range importance values are clamped into [0.0, 1.0].
        let item_high = MemoryItem::new("test").with_importance(1.5);
        assert_eq!(item_high.importance, 1.0);

        let item_low = MemoryItem::new("test").with_importance(-0.5);
        assert_eq!(item_low.importance, 0.0);
    }

    #[test]
    fn test_memory_item_record_access() {
        let mut item = MemoryItem::new("test");
        assert_eq!(item.access_count, 0);
        assert!(item.last_accessed.is_none());

        item.record_access();
        assert_eq!(item.access_count, 1);
        assert!(item.last_accessed.is_some());

        item.record_access();
        assert_eq!(item.access_count, 2);
    }

    #[test]
    fn test_memory_item_all_types() {
        let episodic = MemoryItem::new("e").with_type(MemoryType::Episodic);
        assert_eq!(episodic.memory_type, MemoryType::Episodic);

        let semantic = MemoryItem::new("s").with_type(MemoryType::Semantic);
        assert_eq!(semantic.memory_type, MemoryType::Semantic);

        let procedural = MemoryItem::new("p").with_type(MemoryType::Procedural);
        assert_eq!(procedural.memory_type, MemoryType::Procedural);

        let working = MemoryItem::new("w").with_type(MemoryType::Working);
        assert_eq!(working.memory_type, MemoryType::Working);
    }

    #[test]
    fn test_memory_item_default_type_is_episodic() {
        let item = MemoryItem::new("test");
        assert_eq!(item.memory_type, MemoryType::Episodic);
    }

    // ========================================================================
    // InMemoryStore
    // ========================================================================

    #[tokio::test]
    async fn test_in_memory_store_retrieve_nonexistent() {
        let store = InMemoryStore::new();
        let result = store.retrieve("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_in_memory_store_delete() {
        let store = InMemoryStore::new();
        let item = MemoryItem::new("to delete");
        let id = item.id.clone();
        store.store(item).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        store.delete(&id).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_in_memory_store_clear() {
        let store = InMemoryStore::new();
        for i in 0..5 {
            store
                .store(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 5);

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_recent() {
        let store = InMemoryStore::new();
        for i in 0..5 {
            store
                .store(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        // The limit caps the number of returned items.
        let recent = store.get_recent(3).await.unwrap();
        assert_eq!(recent.len(), 3);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_important() {
        let store = InMemoryStore::new();
        store
            .store(MemoryItem::new("low").with_importance(0.2))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("medium").with_importance(0.5))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("high").with_importance(0.9))
            .await
            .unwrap();

        // Only the item above the 0.7 threshold is returned.
        let important = store.get_important(0.7, 10).await.unwrap();
        assert_eq!(important.len(), 1);
        assert_eq!(important[0].content, "high");
    }

    #[tokio::test]
    async fn test_in_memory_store_search_case_insensitive() {
        let store = InMemoryStore::new();
        store
            .store(MemoryItem::new("How to CREATE a file"))
            .await
            .unwrap();
        let results = store.search("create", 10).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    // ========================================================================
    // AgentMemory
    // ========================================================================

    #[tokio::test]
    async fn test_agent_memory_short_term() {
        let memory = AgentMemory::in_memory();
        memory.remember(MemoryItem::new("item 1")).await.unwrap();
        memory.remember(MemoryItem::new("item 2")).await.unwrap();

        let short_term = memory.get_short_term().await;
        assert_eq!(short_term.len(), 2);

        memory.clear_short_term().await;
        let short_term = memory.get_short_term().await;
        assert_eq!(short_term.len(), 0);
    }

    #[tokio::test]
    async fn test_agent_memory_short_term_count() {
        let memory = AgentMemory::in_memory();
        assert_eq!(memory.short_term_count().await, 0);
        memory.remember(MemoryItem::new("item")).await.unwrap();
        assert_eq!(memory.short_term_count().await, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_working_count() {
        let memory = AgentMemory::in_memory();
        assert_eq!(memory.working_count().await, 0);
        memory
            .add_to_working(MemoryItem::new("task"))
            .await
            .unwrap();
        assert_eq!(memory.working_count().await, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_recall_by_tags() {
        let memory = AgentMemory::in_memory();
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let successes = memory
            .recall_by_tags(&["success".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(successes.len(), 1);

        let failures = memory
            .recall_by_tags(&["failure".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(failures.len(), 1);
    }

    #[tokio::test]
    async fn test_agent_memory_get_recent() {
        let memory = AgentMemory::in_memory();
        for i in 0..5 {
            memory
                .remember(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        let recent = memory.get_recent(3).await.unwrap();
        assert_eq!(recent.len(), 3);
    }

    #[tokio::test]
    async fn test_agent_memory_store_accessor() {
        let memory = AgentMemory::in_memory();
        memory.remember(MemoryItem::new("test")).await.unwrap();
        let count = memory.store().count().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_stats_all_fields() {
        let memory = AgentMemory::in_memory();
        memory.remember(MemoryItem::new("long term")).await.unwrap();
        memory
            .add_to_working(MemoryItem::new("working"))
            .await
            .unwrap();

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 1);
        assert_eq!(stats.short_term_count, 1); // remember also adds to short_term
        assert_eq!(stats.working_count, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_working_overflow_trims() {
        // Build through the public constructor rather than a struct literal,
        // so this test doesn't break whenever `AgentMemory` gains a field.
        let config = MemoryConfig {
            max_working: 3, // Small limit
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);
        assert_eq!(memory.max_working, 3);

        for i in 0..5 {
            memory
                .add_to_working(
                    MemoryItem::new(format!("task {}", i)).with_importance(i as f32 * 0.2),
                )
                .await
                .unwrap();
        }

        let working = memory.get_working().await;
        assert_eq!(working.len(), 3); // Trimmed to max_working
    }
}
1435
#[cfg(test)]
mod extra_memory_tests2 {
    //! Tests for `FileStore::open` and the JSONL parser behind `FileStore`.

    use super::*;

    #[tokio::test]
    async fn test_file_store_open_creates_parent_dirs() {
        // Use a nested path whose parent directories don't exist yet.
        let dir = tempfile::tempdir().unwrap();
        let parent = dir.path().join("nested").join("deep");
        let path = parent.join("memories.jsonl");
        assert!(!parent.exists());

        // Opening succeeds and starts with no memories.
        let store = FileStore::open(&path).await.unwrap();
        let all = store.search("", 100).await.unwrap();
        assert!(all.is_empty());

        // The test's namesake: the nested location must actually be usable.
        // A stored item has to survive a re-open, which can only happen if the
        // missing parent directories were created.
        store
            .store(MemoryItem::new("nested memory"))
            .await
            .unwrap();
        drop(store);
        assert!(parent.is_dir());

        let reopened = FileStore::open(&path).await.unwrap();
        let results = reopened.search("nested", 10).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_file_store_open_loads_existing() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memories.jsonl");
        // Create a store and add a memory, which saves it to the file.
        {
            let store = FileStore::open(&path).await.unwrap();
            let item = MemoryItem::new("test memory");
            store.store(item).await.unwrap();
        }
        // Re-open and verify the memory persisted.
        let store = FileStore::open(&path).await.unwrap();
        let results = store.search("test", 10).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("test memory"));
    }

    #[tokio::test]
    async fn test_file_store_open_nonexistent_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nonexistent.jsonl");
        let store = FileStore::open(&path).await.unwrap();
        let all = store.search("", 100).await.unwrap();
        assert!(all.is_empty());
    }

    #[test]
    fn test_parse_jsonl_empty_string() {
        let result = FileStore::parse_jsonl("").unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_jsonl_empty_lines_skipped() {
        // Valid JSONL with blank lines interspersed; blanks must be ignored.
        let item = MemoryItem::new("hello");
        let json = serde_json::to_string(&item).unwrap();
        let content = format!("\n{}\n\n{}\n\n", json, json);
        let result = FileStore::parse_jsonl(&content).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_parse_jsonl_invalid_json_returns_error() {
        let result = FileStore::parse_jsonl("not valid json");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_jsonl_valid_single_line() {
        // A single line without a trailing newline is still parsed.
        let item = MemoryItem::new("single");
        let json = serde_json::to_string(&item).unwrap();
        let result = FileStore::parse_jsonl(&json).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].content, "single");
    }
}