//! Memory and learning system for the agent.
//!
//! This module provides memory storage, recall, and learning capabilities
//! to enable the agent to learn from past experiences and improve over time.

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use tokio::sync::RwLock;

12// ============================================================================
13// Configuration
14// ============================================================================
15
16/// Configuration for relevance scoring
17#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(rename_all = "camelCase")]
19pub struct RelevanceConfig {
20    /// Exponential decay half-life in days (default: 30.0)
21    #[serde(default = "RelevanceConfig::default_decay_days")]
22    pub decay_days: f32,
23    /// Weight for importance factor (default: 0.7)
24    #[serde(default = "RelevanceConfig::default_importance_weight")]
25    pub importance_weight: f32,
26    /// Weight for recency factor (default: 0.3)
27    #[serde(default = "RelevanceConfig::default_recency_weight")]
28    pub recency_weight: f32,
29}
30
31impl RelevanceConfig {
32    fn default_decay_days() -> f32 {
33        30.0
34    }
35    fn default_importance_weight() -> f32 {
36        0.7
37    }
38    fn default_recency_weight() -> f32 {
39        0.3
40    }
41}
42
43impl Default for RelevanceConfig {
44    fn default() -> Self {
45        Self {
46            decay_days: 30.0,
47            importance_weight: 0.7,
48            recency_weight: 0.3,
49        }
50    }
51}
52
53/// Configuration for the agent memory system
54#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(rename_all = "camelCase")]
56pub struct MemoryConfig {
57    /// Relevance scoring parameters
58    #[serde(default)]
59    pub relevance: RelevanceConfig,
60    /// Maximum short-term memory items (default: 100)
61    #[serde(default = "MemoryConfig::default_max_short_term")]
62    pub max_short_term: usize,
63    /// Maximum working memory items (default: 10)
64    #[serde(default = "MemoryConfig::default_max_working")]
65    pub max_working: usize,
66}
67
68impl MemoryConfig {
69    fn default_max_short_term() -> usize {
70        100
71    }
72    fn default_max_working() -> usize {
73        10
74    }
75}
76
77impl Default for MemoryConfig {
78    fn default() -> Self {
79        Self {
80            relevance: RelevanceConfig::default(),
81            max_short_term: 100,
82            max_working: 10,
83        }
84    }
85}
86
87// ============================================================================
88// Memory Item
89// ============================================================================
90
91/// A single memory item
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct MemoryItem {
94    /// Unique identifier
95    pub id: String,
96    /// Memory content
97    pub content: String,
98    /// When this memory was created
99    pub timestamp: DateTime<Utc>,
100    /// Importance score (0.0 - 1.0)
101    pub importance: f32,
102    /// Tags for categorization
103    pub tags: Vec<String>,
104    /// Memory type
105    pub memory_type: MemoryType,
106    /// Associated metadata
107    pub metadata: HashMap<String, String>,
108    /// Number of times this memory was accessed
109    pub access_count: u32,
110    /// Last access time
111    pub last_accessed: Option<DateTime<Utc>>,
112    /// Cached lowercase content for fast substring search
113    #[serde(skip)]
114    pub content_lower: String,
115}
116
117impl MemoryItem {
118    /// Create a new memory item
119    pub fn new(content: impl Into<String>) -> Self {
120        let content = content.into();
121        let content_lower = content.to_lowercase();
122        Self {
123            id: uuid::Uuid::new_v4().to_string(),
124            content,
125            timestamp: Utc::now(),
126            importance: 0.5,
127            tags: Vec::new(),
128            memory_type: MemoryType::Episodic,
129            metadata: HashMap::new(),
130            access_count: 0,
131            last_accessed: None,
132            content_lower,
133        }
134    }
135
136    /// Set importance
137    pub fn with_importance(mut self, importance: f32) -> Self {
138        self.importance = importance.clamp(0.0, 1.0);
139        self
140    }
141
142    /// Add tags
143    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
144        self.tags = tags;
145        self
146    }
147
148    /// Add a single tag
149    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
150        self.tags.push(tag.into());
151        self
152    }
153
154    /// Set memory type
155    pub fn with_type(mut self, memory_type: MemoryType) -> Self {
156        self.memory_type = memory_type;
157        self
158    }
159
160    /// Add metadata
161    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
162        self.metadata.insert(key.into(), value.into());
163        self
164    }
165
166    /// Record access
167    pub fn record_access(&mut self) {
168        self.access_count += 1;
169        self.last_accessed = Some(Utc::now());
170    }
171
172    /// Calculate relevance score at a given timestamp
173    ///
174    /// Use this variant in sort comparators to avoid repeated `Utc::now()` syscalls.
175    pub fn relevance_score_at(&self, now: DateTime<Utc>) -> f32 {
176        let age_seconds = (now - self.timestamp).num_seconds() as f32;
177        let age_days = age_seconds / 86400.0;
178
179        // Decay factor: memories lose relevance over time
180        let decay = (-age_days / 30.0).exp(); // 30-day half-life
181
182        // Combine importance and recency
183        self.importance * 0.7 + decay * 0.3
184    }
185
186    /// Calculate relevance score based on recency and importance
187    pub fn relevance_score(&self) -> f32 {
188        self.relevance_score_at(Utc::now())
189    }
190}
191
192/// Type of memory
193#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
194#[serde(rename_all = "snake_case")]
195pub enum MemoryType {
196    /// Episodic memory (specific events)
197    Episodic,
198    /// Semantic memory (facts and knowledge)
199    Semantic,
200    /// Procedural memory (how to do things)
201    Procedural,
202    /// Working memory (temporary, active)
203    Working,
204}
205
206// ============================================================================
207// Memory Store Trait
208// ============================================================================
209
210/// Trait for memory storage backends
211#[async_trait::async_trait]
212pub trait MemoryStore: Send + Sync {
213    /// Store a memory item
214    async fn store(&self, item: MemoryItem) -> anyhow::Result<()>;
215
216    /// Retrieve a memory by ID
217    async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>>;
218
219    /// Search memories by query
220    async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
221
222    /// Search memories by tags
223    async fn search_by_tags(
224        &self,
225        tags: &[String],
226        limit: usize,
227    ) -> anyhow::Result<Vec<MemoryItem>>;
228
229    /// Get recent memories
230    async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
231
232    /// Get important memories
233    async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
234
235    /// Delete a memory
236    async fn delete(&self, id: &str) -> anyhow::Result<()>;
237
238    /// Clear all memories
239    async fn clear(&self) -> anyhow::Result<()>;
240
241    /// Get total memory count
242    async fn count(&self) -> anyhow::Result<usize>;
243}
244
245// ============================================================================
246// Shared Search/Sort Helpers (DRY)
247// ============================================================================
248
249/// Sort memory items by relevance score (highest first)
250fn sort_by_relevance(items: &mut [MemoryItem]) {
251    let now = Utc::now();
252    items.sort_by(|a, b| {
253        b.relevance_score_at(now)
254            .partial_cmp(&a.relevance_score_at(now))
255            .unwrap_or(std::cmp::Ordering::Equal)
256    });
257}
258
259// ============================================================================
260// File-Based Memory Store
261// ============================================================================
262
263/// Compact index entry for fast in-memory search
264#[derive(Debug, Clone, Serialize, Deserialize)]
265struct IndexEntry {
266    id: String,
267    content_lower: String,
268    tags: Vec<String>,
269    importance: f32,
270    timestamp: DateTime<Utc>,
271    memory_type: MemoryType,
272}
273
274impl From<&MemoryItem> for IndexEntry {
275    fn from(item: &MemoryItem) -> Self {
276        Self {
277            id: item.id.clone(),
278            content_lower: item.content.to_lowercase(),
279            tags: item.tags.clone(),
280            importance: item.importance,
281            timestamp: item.timestamp,
282            memory_type: item.memory_type,
283        }
284    }
285}
286
287/// File-based memory store.
288///
289/// Stores each memory item as a JSON file with an in-memory index for fast search.
290///
291/// ```text
292/// memory_dir/
293///   index.json           # Compact index for fast search
294///   items/
295///     {id}.json          # Individual memory items
296/// ```
297///
298/// Follows the same atomic-write pattern as `FileSessionStore`:
299/// write to `.tmp`, then rename.
300pub struct FileMemoryStore {
301    items_dir: std::path::PathBuf,
302    index_path: std::path::PathBuf,
303    index: tokio::sync::RwLock<Vec<IndexEntry>>,
304}
305
306impl FileMemoryStore {
307    /// Create a new file memory store, loading the existing index if present.
308    pub async fn new(dir: impl AsRef<std::path::Path>) -> anyhow::Result<Self> {
309        let dir = dir.as_ref().to_path_buf();
310        let items_dir = dir.join("items");
311        let index_path = dir.join("index.json");
312
313        tokio::fs::create_dir_all(&items_dir)
314            .await
315            .with_context(|| {
316                format!("Failed to create memory directory: {}", items_dir.display())
317            })?;
318
319        // Load existing index or start empty
320        let index = if index_path.exists() {
321            let data = tokio::fs::read_to_string(&index_path)
322                .await
323                .with_context(|| {
324                    format!("Failed to read memory index: {}", index_path.display())
325                })?;
326            serde_json::from_str(&data).unwrap_or_default()
327        } else {
328            Vec::new()
329        };
330
331        Ok(Self {
332            items_dir,
333            index_path,
334            index: tokio::sync::RwLock::new(index),
335        })
336    }
337
338    /// Sanitize ID to prevent path traversal
339    fn safe_id(id: &str) -> String {
340        id.replace(['/', '\\'], "_").replace("..", "_")
341    }
342
343    /// Get the file path for a memory item
344    fn item_path(&self, id: &str) -> std::path::PathBuf {
345        self.items_dir.join(format!("{}.json", Self::safe_id(id)))
346    }
347
348    /// Persist the index to disk (atomic write)
349    async fn save_index(&self) -> anyhow::Result<()> {
350        let index = self.index.read().await;
351        let json = serde_json::to_string(&*index).context("Failed to serialize memory index")?;
352        drop(index);
353
354        let tmp = self.index_path.with_extension("json.tmp");
355        tokio::fs::write(&tmp, json.as_bytes())
356            .await
357            .context("Failed to write memory index temp file")?;
358        tokio::fs::rename(&tmp, &self.index_path)
359            .await
360            .context("Failed to rename memory index")?;
361        Ok(())
362    }
363
364    /// Write a single memory item to disk (atomic write)
365    async fn save_item(&self, item: &MemoryItem) -> anyhow::Result<()> {
366        let path = self.item_path(&item.id);
367        let json = serde_json::to_string_pretty(item)
368            .with_context(|| format!("Failed to serialize memory item: {}", item.id))?;
369
370        let tmp = path.with_extension("json.tmp");
371        tokio::fs::write(&tmp, json.as_bytes())
372            .await
373            .with_context(|| format!("Failed to write memory item: {}", item.id))?;
374        tokio::fs::rename(&tmp, &path)
375            .await
376            .with_context(|| format!("Failed to rename memory item: {}", item.id))?;
377        Ok(())
378    }
379
380    /// Rebuild the index from item files on disk.
381    ///
382    /// Useful for recovery if the index file is corrupted.
383    pub async fn rebuild_index(&self) -> anyhow::Result<usize> {
384        let mut entries = tokio::fs::read_dir(&self.items_dir).await?;
385        let mut new_index = Vec::new();
386
387        while let Some(entry) = entries.next_entry().await? {
388            let path = entry.path();
389            if path.extension().is_some_and(|ext| ext == "json") {
390                if let Ok(data) = tokio::fs::read_to_string(&path).await {
391                    if let Ok(item) = serde_json::from_str::<MemoryItem>(&data) {
392                        new_index.push(IndexEntry::from(&item));
393                    }
394                }
395            }
396        }
397
398        let count = new_index.len();
399        *self.index.write().await = new_index;
400        self.save_index().await?;
401        Ok(count)
402    }
403}
404
405use anyhow::Context as _;
406
407#[async_trait::async_trait]
408impl MemoryStore for FileMemoryStore {
409    async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
410        // Sanitize ID to prevent path traversal
411        let mut item = item;
412        item.id = Self::safe_id(&item.id);
413
414        // Write item file
415        self.save_item(&item).await?;
416
417        // Update index
418        let entry = IndexEntry::from(&item);
419        let mut index = self.index.write().await;
420        // Replace if exists, otherwise push
421        if let Some(pos) = index.iter().position(|e| e.id == item.id) {
422            index[pos] = entry;
423        } else {
424            index.push(entry);
425        }
426        drop(index);
427
428        self.save_index().await
429    }
430
431    async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
432        let path = self.item_path(id);
433        if !path.exists() {
434            return Ok(None);
435        }
436        let data = tokio::fs::read_to_string(&path).await?;
437        let mut item: MemoryItem = serde_json::from_str(&data)?;
438        item.content_lower = item.content.to_lowercase();
439        Ok(Some(item))
440    }
441
442    async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
443        let query_lower = query.to_lowercase();
444        let index = self.index.read().await;
445
446        // Find matching IDs from index
447        let mut matches: Vec<&IndexEntry> = index
448            .iter()
449            .filter(|e| e.content_lower.contains(&query_lower))
450            .collect();
451
452        // Sort by relevance
453        let now = Utc::now();
454        matches.sort_by(|a, b| {
455            let score_a = a.importance * 0.7
456                + (-((now - a.timestamp).num_seconds() as f32) / 2592000.0).exp() * 0.3;
457            let score_b = b.importance * 0.7
458                + (-((now - b.timestamp).num_seconds() as f32) / 2592000.0).exp() * 0.3;
459            score_b
460                .partial_cmp(&score_a)
461                .unwrap_or(std::cmp::Ordering::Equal)
462        });
463
464        let ids: Vec<String> = matches.iter().take(limit).map(|e| e.id.clone()).collect();
465        drop(index);
466
467        // Load full items from disk
468        let mut items = Vec::with_capacity(ids.len());
469        for id in ids {
470            if let Some(item) = self.retrieve(&id).await? {
471                items.push(item);
472            }
473        }
474        sort_by_relevance(&mut items);
475        Ok(items)
476    }
477
478    async fn search_by_tags(
479        &self,
480        tags: &[String],
481        limit: usize,
482    ) -> anyhow::Result<Vec<MemoryItem>> {
483        let index = self.index.read().await;
484
485        let mut matches: Vec<&IndexEntry> = index
486            .iter()
487            .filter(|e| tags.iter().any(|t| e.tags.contains(t)))
488            .collect();
489
490        let now = Utc::now();
491        matches.sort_by(|a, b| {
492            let score_a = a.importance * 0.7
493                + (-((now - a.timestamp).num_seconds() as f32) / 2592000.0).exp() * 0.3;
494            let score_b = b.importance * 0.7
495                + (-((now - b.timestamp).num_seconds() as f32) / 2592000.0).exp() * 0.3;
496            score_b
497                .partial_cmp(&score_a)
498                .unwrap_or(std::cmp::Ordering::Equal)
499        });
500
501        let ids: Vec<String> = matches.iter().take(limit).map(|e| e.id.clone()).collect();
502        drop(index);
503
504        let mut items = Vec::with_capacity(ids.len());
505        for id in ids {
506            if let Some(item) = self.retrieve(&id).await? {
507                items.push(item);
508            }
509        }
510        sort_by_relevance(&mut items);
511        Ok(items)
512    }
513
514    async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
515        let index = self.index.read().await;
516        let mut sorted: Vec<&IndexEntry> = index.iter().collect();
517        sorted.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
518
519        let ids: Vec<String> = sorted.iter().take(limit).map(|e| e.id.clone()).collect();
520        drop(index);
521
522        let mut items = Vec::with_capacity(ids.len());
523        for id in ids {
524            if let Some(item) = self.retrieve(&id).await? {
525                items.push(item);
526            }
527        }
528        // Preserve recency order
529        items.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
530        Ok(items)
531    }
532
533    async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
534        let index = self.index.read().await;
535        let mut matches: Vec<&IndexEntry> =
536            index.iter().filter(|e| e.importance >= threshold).collect();
537        matches.sort_by(|a, b| {
538            b.importance
539                .partial_cmp(&a.importance)
540                .unwrap_or(std::cmp::Ordering::Equal)
541        });
542
543        let ids: Vec<String> = matches.iter().take(limit).map(|e| e.id.clone()).collect();
544        drop(index);
545
546        let mut items = Vec::with_capacity(ids.len());
547        for id in ids {
548            if let Some(item) = self.retrieve(&id).await? {
549                items.push(item);
550            }
551        }
552        items.sort_by(|a, b| {
553            b.importance
554                .partial_cmp(&a.importance)
555                .unwrap_or(std::cmp::Ordering::Equal)
556        });
557        Ok(items)
558    }
559
560    async fn delete(&self, id: &str) -> anyhow::Result<()> {
561        let path = self.item_path(id);
562        if path.exists() {
563            tokio::fs::remove_file(&path).await?;
564        }
565
566        let mut index = self.index.write().await;
567        index.retain(|e| e.id != id);
568        drop(index);
569
570        self.save_index().await
571    }
572
573    async fn clear(&self) -> anyhow::Result<()> {
574        // Remove all item files
575        let mut entries = tokio::fs::read_dir(&self.items_dir).await?;
576        while let Some(entry) = entries.next_entry().await? {
577            let path = entry.path();
578            if path.extension().is_some_and(|ext| ext == "json") {
579                let _ = tokio::fs::remove_file(&path).await;
580            }
581        }
582
583        // Clear index
584        self.index.write().await.clear();
585        self.save_index().await
586    }
587
588    async fn count(&self) -> anyhow::Result<usize> {
589        Ok(self.index.read().await.len())
590    }
591}
592
593// ============================================================================
594// In-Memory Store
595// ============================================================================
596
597/// Agent memory system
598#[derive(Clone)]
599pub struct AgentMemory {
600    /// Long-term memory store
601    store: Arc<dyn MemoryStore>,
602    /// Short-term memory (current session)
603    short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
604    /// Working memory (active context)
605    working: Arc<RwLock<Vec<MemoryItem>>>,
606    /// Maximum short-term memory size
607    max_short_term: usize,
608    /// Maximum working memory size
609    max_working: usize,
610    /// Relevance scoring configuration
611    relevance_config: RelevanceConfig,
612}
613
614impl std::fmt::Debug for AgentMemory {
615    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
616        f.debug_struct("AgentMemory")
617            .field("max_short_term", &self.max_short_term)
618            .field("max_working", &self.max_working)
619            .finish()
620    }
621}
622
623impl AgentMemory {
624    /// Create a new agent memory system with default configuration
625    pub fn new(store: Arc<dyn MemoryStore>) -> Self {
626        Self::with_config(store, MemoryConfig::default())
627    }
628
629    /// Create a new agent memory system with custom configuration
630    pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
631        Self {
632            store,
633            short_term: Arc::new(RwLock::new(VecDeque::new())),
634            working: Arc::new(RwLock::new(Vec::new())),
635            max_short_term: config.max_short_term,
636            max_working: config.max_working,
637            relevance_config: config.relevance,
638        }
639    }
640
641    /// Calculate relevance score using this memory system's configuration
642    fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
643        let age_seconds = (now - item.timestamp).num_seconds() as f32;
644        let age_days = age_seconds / 86400.0;
645        let decay = (-age_days / self.relevance_config.decay_days).exp();
646        item.importance * self.relevance_config.importance_weight
647            + decay * self.relevance_config.recency_weight
648    }
649
650    /// Store a memory in long-term storage
651    pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
652        // Store in long-term
653        self.store.store(item.clone()).await?;
654
655        // Add to short-term
656        let mut short_term = self.short_term.write().await;
657        short_term.push_back(item);
658
659        // Trim if needed
660        if short_term.len() > self.max_short_term {
661            short_term.pop_front();
662        }
663
664        Ok(())
665    }
666
667    /// Remember a successful pattern
668    pub async fn remember_success(
669        &self,
670        prompt: &str,
671        tools_used: &[String],
672        result: &str,
673    ) -> anyhow::Result<()> {
674        let content = format!(
675            "Success: {}\nTools: {}\nResult: {}",
676            prompt,
677            tools_used.join(", "),
678            result
679        );
680
681        let item = MemoryItem::new(content)
682            .with_importance(0.8)
683            .with_tag("success")
684            .with_tag("pattern")
685            .with_type(MemoryType::Procedural)
686            .with_metadata("prompt", prompt)
687            .with_metadata("tools", tools_used.join(","));
688
689        self.remember(item).await
690    }
691
692    /// Remember a failure to avoid repeating
693    pub async fn remember_failure(
694        &self,
695        prompt: &str,
696        error: &str,
697        attempted_tools: &[String],
698    ) -> anyhow::Result<()> {
699        let content = format!(
700            "Failure: {}\nError: {}\nAttempted tools: {}",
701            prompt,
702            error,
703            attempted_tools.join(", ")
704        );
705
706        let item = MemoryItem::new(content)
707            .with_importance(0.9) // Failures are important to remember
708            .with_tag("failure")
709            .with_tag("avoid")
710            .with_type(MemoryType::Episodic)
711            .with_metadata("prompt", prompt)
712            .with_metadata("error", error);
713
714        self.remember(item).await
715    }
716
717    /// Recall similar past experiences
718    pub async fn recall_similar(
719        &self,
720        prompt: &str,
721        limit: usize,
722    ) -> anyhow::Result<Vec<MemoryItem>> {
723        self.store.search(prompt, limit).await
724    }
725
726    /// Recall by tags
727    pub async fn recall_by_tags(
728        &self,
729        tags: &[String],
730        limit: usize,
731    ) -> anyhow::Result<Vec<MemoryItem>> {
732        self.store.search_by_tags(tags, limit).await
733    }
734
735    /// Get recent memories
736    pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
737        self.store.get_recent(limit).await
738    }
739
740    /// Add to working memory
741    pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
742        let mut working = self.working.write().await;
743        working.push(item);
744
745        // Trim if needed (keep most relevant)
746        if working.len() > self.max_working {
747            let now = Utc::now();
748            working.sort_by(|a, b| {
749                self.score(b, now)
750                    .partial_cmp(&self.score(a, now))
751                    .unwrap_or(std::cmp::Ordering::Equal)
752            });
753            working.truncate(self.max_working);
754        }
755
756        Ok(())
757    }
758
759    /// Get working memory
760    pub async fn get_working(&self) -> Vec<MemoryItem> {
761        self.working.read().await.clone()
762    }
763
764    /// Clear working memory
765    pub async fn clear_working(&self) {
766        self.working.write().await.clear();
767    }
768
769    /// Get short-term memory
770    pub async fn get_short_term(&self) -> Vec<MemoryItem> {
771        self.short_term.read().await.iter().cloned().collect()
772    }
773
774    /// Clear short-term memory
775    pub async fn clear_short_term(&self) {
776        self.short_term.write().await.clear();
777    }
778
779    /// Get memory statistics
780    pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
781        let long_term_count = self.store.count().await?;
782        let short_term_count = self.short_term.read().await.len();
783        let working_count = self.working.read().await.len();
784
785        Ok(MemoryStats {
786            long_term_count,
787            short_term_count,
788            working_count,
789        })
790    }
791
792    /// Get access to the underlying store
793    pub fn store(&self) -> &Arc<dyn MemoryStore> {
794        &self.store
795    }
796
797    /// Get working memory count
798    pub async fn working_count(&self) -> usize {
799        self.working.read().await.len()
800    }
801
802    /// Get short-term memory count
803    pub async fn short_term_count(&self) -> usize {
804        self.short_term.read().await.len()
805    }
806}
807
808/// Memory statistics
809#[derive(Debug, Clone, Serialize, Deserialize)]
810pub struct MemoryStats {
811    /// Number of long-term memories
812    pub long_term_count: usize,
813    /// Number of short-term memories
814    pub short_term_count: usize,
815    /// Number of working memories
816    pub working_count: usize,
817}
818
819// ============================================================================
820// Memory Context Provider
821// ============================================================================
822
823/// Context provider that surfaces past memories (successes/failures) as context.
824///
825/// Wraps `AgentMemory` and implements the `ContextProvider` trait so that
826/// session memory is automatically injected into the agent's system prompt.
827pub struct MemoryContextProvider {
828    memory: AgentMemory,
829}
830
831impl MemoryContextProvider {
832    /// Create a new memory context provider
833    pub fn new(memory: AgentMemory) -> Self {
834        Self { memory }
835    }
836}
837
838#[async_trait::async_trait]
839impl crate::context::ContextProvider for MemoryContextProvider {
840    fn name(&self) -> &str {
841        "memory"
842    }
843
844    async fn query(
845        &self,
846        query: &crate::context::ContextQuery,
847    ) -> anyhow::Result<crate::context::ContextResult> {
848        let limit = query.max_results.min(5);
849        let items = self.memory.recall_similar(&query.query, limit).await?;
850
851        let mut result = crate::context::ContextResult::new("memory");
852        for item in items {
853            let relevance = item.relevance_score();
854            let token_count = item.content.len() / 4; // rough estimate
855            let context_item = crate::context::ContextItem::new(
856                &item.id,
857                crate::context::ContextType::Memory,
858                &item.content,
859            )
860            .with_relevance(relevance)
861            .with_token_count(token_count)
862            .with_source("memory");
863            result.add_item(context_item);
864        }
865
866        Ok(result)
867    }
868
869    async fn on_turn_complete(
870        &self,
871        _session_id: &str,
872        prompt: &str,
873        response: &str,
874    ) -> anyhow::Result<()> {
875        // Store the successful interaction as a memory
876        self.memory.remember_success(prompt, &[], response).await
877    }
878}
879
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Simple in-memory [`MemoryStore`] for testing.
    ///
    /// Items are kept in a `Vec` in insertion order. A blocking
    /// `std::sync::Mutex` is fine here because no guard is ever held across
    /// an `.await`.
    struct TestMemoryStore {
        items: std::sync::Mutex<Vec<MemoryItem>>,
    }

    impl TestMemoryStore {
        fn new() -> Self {
            Self {
                items: std::sync::Mutex::new(Vec::new()),
            }
        }
    }

    #[async_trait::async_trait]
    impl MemoryStore for TestMemoryStore {
        async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
            self.items.lock().unwrap().push(item);
            Ok(())
        }

        async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
            Ok(self
                .items
                .lock()
                .unwrap()
                .iter()
                .find(|i| i.id == id)
                .cloned())
        }

        /// Case-insensitive substring match on `content`, in insertion order.
        async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            let query_lower = query.to_lowercase();
            Ok(items
                .iter()
                .filter(|i| i.content.to_lowercase().contains(&query_lower))
                .take(limit)
                .cloned()
                .collect())
        }

        /// Items carrying at least one of `tags`, in insertion order.
        async fn search_by_tags(
            &self,
            tags: &[String],
            limit: usize,
        ) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            Ok(items
                .iter()
                .filter(|i| tags.iter().any(|t| i.tags.contains(t)))
                .take(limit)
                .cloned()
                .collect())
        }

        /// Newest items first (descending `timestamp`), at most `limit`.
        async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            let mut sorted: Vec<_> = items.iter().cloned().collect();
            sorted.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
            sorted.truncate(limit);
            Ok(sorted)
        }

        /// Items with `importance >= threshold`, most important first, at most `limit`.
        async fn get_important(
            &self,
            threshold: f32,
            limit: usize,
        ) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            let mut important: Vec<MemoryItem> = items
                .iter()
                .filter(|i| i.importance >= threshold)
                .cloned()
                .collect();
            // Sort before truncating so `limit` keeps the highest-importance
            // items rather than the earliest-inserted ones. This matches the
            // ordering `file_memory_store_tests::test_get_important` expects
            // from `FileMemoryStore`. `sort_by` is stable, so ties keep
            // insertion order.
            important.sort_by(|a, b| b.importance.total_cmp(&a.importance));
            important.truncate(limit);
            Ok(important)
        }

        async fn delete(&self, id: &str) -> anyhow::Result<()> {
            self.items.lock().unwrap().retain(|i| i.id != id);
            Ok(())
        }

        async fn clear(&self) -> anyhow::Result<()> {
            self.items.lock().unwrap().clear();
            Ok(())
        }

        async fn count(&self) -> anyhow::Result<usize> {
            Ok(self.items.lock().unwrap().len())
        }
    }

    #[test]
    fn test_memory_item_creation() {
        let item = MemoryItem::new("Test memory")
            .with_importance(0.8)
            .with_tag("test")
            .with_type(MemoryType::Semantic);

        assert_eq!(item.content, "Test memory");
        assert_eq!(item.importance, 0.8);
        assert_eq!(item.tags, vec!["test"]);
        assert_eq!(item.memory_type, MemoryType::Semantic);
    }

    #[test]
    fn test_memory_item_relevance() {
        let item = MemoryItem::new("Test").with_importance(0.9);
        let score = item.relevance_score();

        // Should be high for recent, important memory
        assert!(score > 0.6);
    }

    #[test]
    fn test_relevance_config_defaults() {
        let config = RelevanceConfig::default();
        assert_eq!(config.decay_days, 30.0);
        assert_eq!(config.importance_weight, 0.7);
        assert_eq!(config.recency_weight, 0.3);
    }

    #[test]
    fn test_memory_config_defaults() {
        let config = MemoryConfig::default();
        assert_eq!(config.max_short_term, 100);
        assert_eq!(config.max_working, 10);
        assert_eq!(config.relevance.decay_days, 30.0);
    }

    #[test]
    fn test_memory_config_serde_roundtrip() {
        let config = MemoryConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: MemoryConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.max_short_term, config.max_short_term);
        assert_eq!(parsed.max_working, config.max_working);
        assert_eq!(parsed.relevance.decay_days, config.relevance.decay_days);
    }

    #[test]
    fn test_agent_memory_with_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.5,
                recency_weight: 0.5,
            },
            max_short_term: 50,
            max_working: 5,
        };
        let memory = AgentMemory::with_config(Arc::new(TestMemoryStore::new()), config);
        assert_eq!(memory.max_short_term, 50);
        assert_eq!(memory.max_working, 5);
        assert_eq!(memory.relevance_config.decay_days, 7.0);
    }

    #[test]
    fn test_agent_memory_score_uses_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.9,
                recency_weight: 0.1,
            },
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(TestMemoryStore::new()), config);

        let item = MemoryItem::new("Test").with_importance(1.0);
        let now = Utc::now();
        let score = memory.score(&item, now);

        // With importance_weight=0.9, a brand new item with importance=1.0
        // should score close to 0.9 + 0.1 = 1.0 (decay ~1.0 for recent items)
        assert!(score > 0.95, "Score was {}", score);
    }

    #[tokio::test]
    async fn test_in_memory_store() {
        let store = TestMemoryStore::new();

        let item = MemoryItem::new("Test memory").with_tag("test");
        store.store(item.clone()).await.unwrap();

        let retrieved = store.retrieve(&item.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test memory");
    }

    #[tokio::test]
    async fn test_memory_search() {
        let store = TestMemoryStore::new();

        store
            .store(MemoryItem::new("How to create a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to delete a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to create a directory").with_tag("dir"))
            .await
            .unwrap();

        let results = store.search("create", 10).await.unwrap();
        assert_eq!(results.len(), 2);

        let results = store
            .search_by_tags(&["file".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_agent_memory() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));

        // Remember success
        memory
            .remember_success("Create a file", &["write".to_string()], "File created")
            .await
            .unwrap();

        // Remember failure
        memory
            .remember_failure("Delete file", "Permission denied", &["bash".to_string()])
            .await
            .unwrap();

        // Recall
        let results = memory.recall_similar("create", 10).await.unwrap();
        assert!(!results.is_empty());

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 2);
    }

    #[tokio::test]
    async fn test_working_memory() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));

        let item = MemoryItem::new("Active task").with_type(MemoryType::Working);
        memory.add_to_working(item).await.unwrap();

        let working = memory.get_working().await;
        assert_eq!(working.len(), 1);

        memory.clear_working().await;
        let working = memory.get_working().await;
        assert_eq!(working.len(), 0);
    }
}
1136
#[cfg(test)]
mod extra_memory_tests {
    use super::*;

    /// Simple in-memory [`MemoryStore`] for testing.
    ///
    /// Items are kept in a `Vec` in insertion order. A blocking
    /// `std::sync::Mutex` is fine here because no guard is ever held across
    /// an `.await`.
    struct TestMemoryStore {
        items: std::sync::Mutex<Vec<MemoryItem>>,
    }

    impl TestMemoryStore {
        fn new() -> Self {
            Self {
                items: std::sync::Mutex::new(Vec::new()),
            }
        }
    }

    #[async_trait::async_trait]
    impl MemoryStore for TestMemoryStore {
        async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
            self.items.lock().unwrap().push(item);
            Ok(())
        }

        async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
            Ok(self
                .items
                .lock()
                .unwrap()
                .iter()
                .find(|i| i.id == id)
                .cloned())
        }

        /// Case-insensitive substring match on `content`, in insertion order.
        async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            let query_lower = query.to_lowercase();
            Ok(items
                .iter()
                .filter(|i| i.content.to_lowercase().contains(&query_lower))
                .take(limit)
                .cloned()
                .collect())
        }

        /// Items carrying at least one of `tags`, in insertion order.
        async fn search_by_tags(
            &self,
            tags: &[String],
            limit: usize,
        ) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            Ok(items
                .iter()
                .filter(|i| tags.iter().any(|t| i.tags.contains(t)))
                .take(limit)
                .cloned()
                .collect())
        }

        /// Newest items first (descending `timestamp`), at most `limit`.
        async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            let mut sorted: Vec<_> = items.iter().cloned().collect();
            sorted.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
            sorted.truncate(limit);
            Ok(sorted)
        }

        /// Items with `importance >= threshold`, most important first, at most `limit`.
        async fn get_important(
            &self,
            threshold: f32,
            limit: usize,
        ) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            let mut important: Vec<MemoryItem> = items
                .iter()
                .filter(|i| i.importance >= threshold)
                .cloned()
                .collect();
            // Sort before truncating so `limit` keeps the highest-importance
            // items rather than the earliest-inserted ones. This matches the
            // ordering `file_memory_store_tests::test_get_important` expects
            // from `FileMemoryStore`. `sort_by` is stable, so ties keep
            // insertion order.
            important.sort_by(|a, b| b.importance.total_cmp(&a.importance));
            important.truncate(limit);
            Ok(important)
        }

        async fn delete(&self, id: &str) -> anyhow::Result<()> {
            self.items.lock().unwrap().retain(|i| i.id != id);
            Ok(())
        }

        async fn clear(&self) -> anyhow::Result<()> {
            self.items.lock().unwrap().clear();
            Ok(())
        }

        async fn count(&self) -> anyhow::Result<usize> {
            Ok(self.items.lock().unwrap().len())
        }
    }

    // ========================================================================
    // MemoryItem builder methods
    // ========================================================================

    #[test]
    fn test_memory_item_with_metadata() {
        let item = MemoryItem::new("test")
            .with_metadata("key1", "value1")
            .with_metadata("key2", "value2");
        assert_eq!(item.metadata.get("key1").unwrap(), "value1");
        assert_eq!(item.metadata.get("key2").unwrap(), "value2");
    }

    #[test]
    fn test_memory_item_with_tags_vec() {
        let item = MemoryItem::new("test").with_tags(vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(item.tags.len(), 3);
    }

    #[test]
    fn test_memory_item_importance_clamped() {
        let item_high = MemoryItem::new("test").with_importance(1.5);
        assert_eq!(item_high.importance, 1.0);

        let item_low = MemoryItem::new("test").with_importance(-0.5);
        assert_eq!(item_low.importance, 0.0);
    }

    #[test]
    fn test_memory_item_record_access() {
        let mut item = MemoryItem::new("test");
        assert_eq!(item.access_count, 0);
        assert!(item.last_accessed.is_none());

        item.record_access();
        assert_eq!(item.access_count, 1);
        assert!(item.last_accessed.is_some());

        item.record_access();
        assert_eq!(item.access_count, 2);
    }

    #[test]
    fn test_memory_item_all_types() {
        let episodic = MemoryItem::new("e").with_type(MemoryType::Episodic);
        assert_eq!(episodic.memory_type, MemoryType::Episodic);

        let semantic = MemoryItem::new("s").with_type(MemoryType::Semantic);
        assert_eq!(semantic.memory_type, MemoryType::Semantic);

        let procedural = MemoryItem::new("p").with_type(MemoryType::Procedural);
        assert_eq!(procedural.memory_type, MemoryType::Procedural);

        let working = MemoryItem::new("w").with_type(MemoryType::Working);
        assert_eq!(working.memory_type, MemoryType::Working);
    }

    #[test]
    fn test_memory_item_default_type_is_episodic() {
        let item = MemoryItem::new("test");
        assert_eq!(item.memory_type, MemoryType::Episodic);
    }

    // ========================================================================
    // TestMemoryStore
    // ========================================================================

    #[tokio::test]
    async fn test_in_memory_store_retrieve_nonexistent() {
        let store = TestMemoryStore::new();
        let result = store.retrieve("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_in_memory_store_delete() {
        let store = TestMemoryStore::new();
        let item = MemoryItem::new("to delete");
        let id = item.id.clone();
        store.store(item).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        store.delete(&id).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_in_memory_store_clear() {
        let store = TestMemoryStore::new();
        for i in 0..5 {
            store
                .store(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 5);

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_recent() {
        let store = TestMemoryStore::new();
        for i in 0..5 {
            store
                .store(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        let recent = store.get_recent(3).await.unwrap();
        assert_eq!(recent.len(), 3);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_important() {
        let store = TestMemoryStore::new();
        store
            .store(MemoryItem::new("low").with_importance(0.2))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("medium").with_importance(0.5))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("high").with_importance(0.9))
            .await
            .unwrap();

        let important = store.get_important(0.7, 10).await.unwrap();
        assert_eq!(important.len(), 1);
        assert_eq!(important[0].content, "high");
    }

    #[tokio::test]
    async fn test_in_memory_store_get_important_orders_by_importance() {
        // Inserted out of importance order: the limit must keep the top items,
        // not the first-inserted ones.
        let store = TestMemoryStore::new();
        for (content, importance) in [("low", 0.2), ("high", 0.9), ("medium", 0.5)] {
            store
                .store(MemoryItem::new(content).with_importance(importance))
                .await
                .unwrap();
        }

        let top = store.get_important(0.0, 2).await.unwrap();
        let contents: Vec<&str> = top.iter().map(|i| i.content.as_str()).collect();
        assert_eq!(contents, vec!["high", "medium"]);
    }

    #[tokio::test]
    async fn test_in_memory_store_search_case_insensitive() {
        let store = TestMemoryStore::new();
        store
            .store(MemoryItem::new("How to CREATE a file"))
            .await
            .unwrap();
        let results = store.search("create", 10).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    // ========================================================================
    // AgentMemory
    // ========================================================================

    #[tokio::test]
    async fn test_agent_memory_short_term() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        memory.remember(MemoryItem::new("item 1")).await.unwrap();
        memory.remember(MemoryItem::new("item 2")).await.unwrap();

        let short_term = memory.get_short_term().await;
        assert_eq!(short_term.len(), 2);

        memory.clear_short_term().await;
        let short_term = memory.get_short_term().await;
        assert_eq!(short_term.len(), 0);
    }

    #[tokio::test]
    async fn test_agent_memory_short_term_count() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        assert_eq!(memory.short_term_count().await, 0);
        memory.remember(MemoryItem::new("item")).await.unwrap();
        assert_eq!(memory.short_term_count().await, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_working_count() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        assert_eq!(memory.working_count().await, 0);
        memory
            .add_to_working(MemoryItem::new("task"))
            .await
            .unwrap();
        assert_eq!(memory.working_count().await, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_recall_by_tags() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let successes = memory
            .recall_by_tags(&["success".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(successes.len(), 1);

        let failures = memory
            .recall_by_tags(&["failure".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(failures.len(), 1);
    }

    #[tokio::test]
    async fn test_agent_memory_get_recent() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        for i in 0..5 {
            memory
                .remember(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        let recent = memory.get_recent(3).await.unwrap();
        assert_eq!(recent.len(), 3);
    }

    #[tokio::test]
    async fn test_agent_memory_store_accessor() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        memory.remember(MemoryItem::new("test")).await.unwrap();
        let count = memory.store().count().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_stats_all_fields() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        memory.remember(MemoryItem::new("long term")).await.unwrap();
        memory
            .add_to_working(MemoryItem::new("working"))
            .await
            .unwrap();

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 1);
        assert_eq!(stats.short_term_count, 1); // remember also adds to short_term
        assert_eq!(stats.working_count, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_working_overflow_trims() {
        // Build through the public config path rather than a struct literal
        // over AgentMemory's private fields, so this test keeps compiling when
        // those internals change.
        let config = MemoryConfig {
            max_working: 3, // Small limit
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(TestMemoryStore::new()), config);

        for i in 0..5 {
            memory
                .add_to_working(
                    MemoryItem::new(format!("task {}", i)).with_importance(i as f32 * 0.2),
                )
                .await
                .unwrap();
        }

        let working = memory.get_working().await;
        assert_eq!(working.len(), 3); // Trimmed to max_working
    }
}
1498
#[cfg(test)]
mod file_memory_store_tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a `FileMemoryStore` rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the store and must be kept alive
    /// for the duration of the test; dropping it removes the directory.
    async fn temp_store() -> (TempDir, FileMemoryStore) {
        let root = TempDir::new().unwrap();
        let fresh = FileMemoryStore::new(root.path()).await.unwrap();
        (root, fresh)
    }

    /// Builds a default `MemoryItem` holding `text`.
    fn item_with(text: &str) -> MemoryItem {
        MemoryItem::new(text.to_owned())
    }

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let (_guard, store) = temp_store().await;
        let original = item_with("hello world");
        let key = original.id.clone();

        store.store(original).await.unwrap();

        let loaded = store
            .retrieve(&key)
            .await
            .unwrap()
            .expect("stored item should be retrievable by id");
        assert_eq!(loaded.content, "hello world");
    }

    #[tokio::test]
    async fn test_retrieve_nonexistent() {
        let (_guard, store) = temp_store().await;
        let missing = store.retrieve("nonexistent").await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_search_by_content() {
        let (_guard, store) = temp_store().await;
        for text in ["rust programming", "python scripting", "rust async patterns"] {
            store.store(item_with(text)).await.unwrap();
        }

        let hits = store.search("rust", 10).await.unwrap();
        assert_eq!(hits.len(), 2);
        assert!(hits.iter().all(|hit| hit.content.contains("rust")));
    }

    #[tokio::test]
    async fn test_search_limit() {
        let (_guard, store) = temp_store().await;
        for n in 0..10 {
            let text = format!("item {}", n);
            store.store(item_with(&text)).await.unwrap();
        }

        let capped = store.search("item", 3).await.unwrap();
        assert_eq!(capped.len(), 3);
    }

    #[tokio::test]
    async fn test_search_by_tags() {
        let (_guard, store) = temp_store().await;
        let fixtures = [
            ("tagged one", vec!["rust", "async"]),
            ("tagged two", vec!["python"]),
            ("tagged three", vec!["rust", "web"]),
        ];
        for (text, labels) in fixtures {
            let labels: Vec<String> = labels.into_iter().map(String::from).collect();
            store.store(item_with(text).with_tags(labels)).await.unwrap();
        }

        let wanted = ["rust".to_string()];
        let tagged = store.search_by_tags(&wanted, 10).await.unwrap();
        assert_eq!(tagged.len(), 2);
    }

    #[tokio::test]
    async fn test_get_recent() {
        let (_guard, store) = temp_store().await;
        for offset in 0..5i64 {
            let mut entry = item_with(&format!("item {}", offset));
            // Spread timestamps one second apart so the ordering is unambiguous.
            entry.timestamp = Utc::now() + chrono::Duration::seconds(offset);
            store.store(entry).await.unwrap();
        }

        let newest = store.get_recent(3).await.unwrap();
        assert_eq!(newest.len(), 3);
        // Most recent first
        assert!(newest
            .windows(2)
            .all(|pair| pair[0].timestamp >= pair[1].timestamp));
    }

    #[tokio::test]
    async fn test_get_important() {
        let (_guard, store) = temp_store().await;
        for (text, weight) in [("low", 0.1), ("high", 0.9), ("medium", 0.5)] {
            store
                .store(item_with(text).with_importance(weight))
                .await
                .unwrap();
        }

        let ranked = store.get_important(0.0, 2).await.unwrap();
        assert_eq!(ranked.len(), 2);
        assert!(ranked[0].importance >= ranked[1].importance);
        assert_eq!(ranked[0].content, "high");
    }

    #[tokio::test]
    async fn test_delete() {
        let (_guard, store) = temp_store().await;
        let doomed = item_with("to delete");
        let key = doomed.id.clone();

        store.store(doomed).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        store.delete(&key).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
        assert!(store.retrieve(&key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_delete_nonexistent() {
        let (_guard, store) = temp_store().await;
        // Deleting an unknown id is expected to succeed as a no-op.
        store.delete("nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn test_clear() {
        let (_guard, store) = temp_store().await;
        for n in 0..5 {
            let text = format!("item {}", n);
            store.store(item_with(&text)).await.unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 5);

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_count() {
        let (_guard, store) = temp_store().await;
        assert_eq!(store.count().await.unwrap(), 0);

        for (expected, text) in [(1, "one"), (2, "two")] {
            store.store(item_with(text)).await.unwrap();
            assert_eq!(store.count().await.unwrap(), expected);
        }
    }

    #[tokio::test]
    async fn test_persistence_across_instances() {
        let root = TempDir::new().unwrap();
        let path = root.path();

        // First instance writes the item, then is dropped.
        {
            let writer = FileMemoryStore::new(path).await.unwrap();
            let tagged = item_with("persistent data").with_tags(vec!["test".into()]);
            writer.store(tagged).await.unwrap();
        }

        // Second instance must see it after reloading from the same directory.
        {
            let reader = FileMemoryStore::new(path).await.unwrap();
            assert_eq!(reader.count().await.unwrap(), 1);

            let found = reader.search("persistent", 10).await.unwrap();
            assert_eq!(found.len(), 1);
            assert_eq!(found[0].content, "persistent data");
        }
    }

    #[tokio::test]
    async fn test_rebuild_index() {
        let root = TempDir::new().unwrap();
        let path = root.path();

        {
            let writer = FileMemoryStore::new(path).await.unwrap();
            for text in ["alpha", "beta"] {
                writer.store(item_with(text)).await.unwrap();
            }
        }

        // Simulate a lost/corrupted index by removing it from disk.
        tokio::fs::remove_file(path.join("index.json"))
            .await
            .unwrap();

        {
            let reopened = FileMemoryStore::new(path).await.unwrap();
            // With the index file gone, the freshly loaded index is empty...
            assert_eq!(reopened.count().await.unwrap(), 0);

            // ...until it is reconstructed from the per-item files.
            reopened.rebuild_index().await.unwrap();
            assert_eq!(reopened.count().await.unwrap(), 2);
        }
    }

    #[tokio::test]
    async fn test_path_traversal_prevention() {
        let (_guard, store) = temp_store().await;
        let mut hostile = item_with("sneaky");
        hostile.id = "../../../etc/passwd".to_string();

        store.store(hostile).await.unwrap();

        // The stored id must come back with no path separators or parent refs.
        let found = store.search("sneaky", 10).await.unwrap();
        assert_eq!(found.len(), 1);
        let sanitized = &found[0].id;
        assert!(!sanitized.contains('/'));
        assert!(!sanitized.contains(".."));
    }

    #[tokio::test]
    async fn test_importance_threshold() {
        let (_guard, store) = temp_store().await;
        for (text, weight) in [("low", 0.2), ("high", 0.8)] {
            store
                .store(item_with(text).with_importance(weight))
                .await
                .unwrap();
        }

        let above = store.get_important(0.5, 10).await.unwrap();
        assert_eq!(above.len(), 1);
        assert_eq!(above[0].content, "high");
    }
}