Skip to main content

llm_agent_runtime/
memory.rs

1//! # Module: Memory
2//!
3//! ## Responsibility
4//! Provides episodic, semantic, and working memory stores for agents.
5//! Mirrors the public API of `tokio-agent-memory` and `tokio-memory`.
6//!
7//! ## Guarantees
8//! - Thread-safe: all stores wrap their state in `Arc<Mutex<_>>`
9//! - Bounded: WorkingMemory evicts the oldest entry when capacity is exceeded
10//! - Decaying: DecayPolicy reduces importance scores over time
11//! - Non-panicking: all operations return `Result`
12//! - Lock-poisoning resilient: a panicking thread does not permanently break a store
13//!
14//! ## NOT Responsible For
15//! - Cross-agent shared memory (see runtime.rs coordinator)
16//! - Persistence to disk or external store
17
18use crate::error::AgentRuntimeError;
19use chrono::{DateTime, Utc};
20use serde::{Deserialize, Serialize};
21use std::collections::{HashMap, VecDeque};
22use std::sync::{Arc, Mutex};
23use uuid::Uuid;
24
25// ── Lock-poisoning recovery ───────────────────────────────────────────────────
26
27/// Acquire a mutex guard, recovering from a poisoned mutex rather than
28/// propagating an error.  A panicking thread does not permanently break the
29/// store; we simply take ownership of the inner value and emit a warning.
30fn recover_lock<'a, T>(
31    result: std::sync::LockResult<std::sync::MutexGuard<'a, T>>,
32    ctx: &str,
33) -> std::sync::MutexGuard<'a, T>
34where
35    T: ?Sized,
36{
37    match result {
38        Ok(guard) => guard,
39        Err(poisoned) => {
40            tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
41            poisoned.into_inner()
42        }
43    }
44}
45
46// ── Cosine similarity ─────────────────────────────────────────────────────────
47
/// Cosine similarity of two vectors, clamped to `[-1.0, 1.0]`.
///
/// Yields `0.0` when the slices differ in length, are empty, or when either
/// vector has zero magnitude (the angle is undefined in those cases).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let dot = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));

    if [mag_a, mag_b].contains(&0.0) {
        0.0
    } else {
        (dot / (mag_a * mag_b)).clamp(-1.0, 1.0)
    }
}
60
61// ── Newtype IDs ───────────────────────────────────────────────────────────────
62
63/// Stable identifier for an agent instance.
64#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
65pub struct AgentId(pub String);
66
67impl AgentId {
68    /// Create a new `AgentId` from any string-like value.
69    pub fn new(id: impl Into<String>) -> Self {
70        Self(id.into())
71    }
72
73    /// Generate a random `AgentId` backed by a UUID v4.
74    pub fn random() -> Self {
75        Self(Uuid::new_v4().to_string())
76    }
77}
78
79impl std::fmt::Display for AgentId {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        write!(f, "{}", self.0)
82    }
83}
84
85/// Stable identifier for a memory item.
86#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
87pub struct MemoryId(pub String);
88
89impl MemoryId {
90    /// Create a new `MemoryId` from any string-like value.
91    pub fn new(id: impl Into<String>) -> Self {
92        Self(id.into())
93    }
94
95    /// Generate a random `MemoryId` backed by a UUID v4.
96    pub fn random() -> Self {
97        Self(Uuid::new_v4().to_string())
98    }
99}
100
101impl std::fmt::Display for MemoryId {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        write!(f, "{}", self.0)
104    }
105}
106
107// ── MemoryItem ────────────────────────────────────────────────────────────────
108
109/// A single memory record stored for an agent.
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct MemoryItem {
112    /// Unique identifier for this memory.
113    pub id: MemoryId,
114    /// The agent this memory belongs to.
115    pub agent_id: AgentId,
116    /// Textual content of the memory.
117    pub content: String,
118    /// Importance score in `[0.0, 1.0]`. Higher = more important.
119    pub importance: f32,
120    /// UTC timestamp when this memory was recorded.
121    pub timestamp: DateTime<Utc>,
122    /// Searchable tags attached to this memory.
123    pub tags: Vec<String>,
124    /// Number of times this memory has been recalled. Updated in-place by `recall`.
125    #[serde(default)]
126    pub recall_count: u64,
127}
128
129impl MemoryItem {
130    /// Construct a new `MemoryItem` with the current timestamp and a random ID.
131    pub fn new(
132        agent_id: AgentId,
133        content: impl Into<String>,
134        importance: f32,
135        tags: Vec<String>,
136    ) -> Self {
137        Self {
138            id: MemoryId::random(),
139            agent_id,
140            content: content.into(),
141            importance: importance.clamp(0.0, 1.0),
142            timestamp: Utc::now(),
143            tags,
144            recall_count: 0,
145        }
146    }
147}
148
149// ── DecayPolicy ───────────────────────────────────────────────────────────────
150
151/// Governs how memory importance decays over time.
152#[derive(Debug, Clone)]
153pub struct DecayPolicy {
154    /// The half-life duration in hours. After this many hours, importance is halved.
155    half_life_hours: f64,
156}
157
158impl DecayPolicy {
159    /// Create an exponential decay policy with the given half-life in hours.
160    ///
161    /// # Arguments
162    /// * `half_life_hours` — time after which importance is halved; must be > 0
163    ///
164    /// # Returns
165    /// - `Ok(DecayPolicy)` — on success
166    /// - `Err(AgentRuntimeError::Memory)` — if `half_life_hours <= 0`
167    pub fn exponential(half_life_hours: f64) -> Result<Self, AgentRuntimeError> {
168        if half_life_hours <= 0.0 {
169            return Err(AgentRuntimeError::Memory(
170                "half_life_hours must be positive".into(),
171            ));
172        }
173        Ok(Self { half_life_hours })
174    }
175
176    /// Apply decay to an importance score based on elapsed time.
177    ///
178    /// # Arguments
179    /// * `importance` — original importance in `[0.0, 1.0]`
180    /// * `age_hours` — how many hours have passed since the memory was recorded
181    ///
182    /// # Returns
183    /// Decayed importance clamped to `[0.0, 1.0]`.
184    pub fn apply(&self, importance: f32, age_hours: f64) -> f32 {
185        let decay = (-age_hours * std::f64::consts::LN_2 / self.half_life_hours).exp();
186        (importance as f64 * decay).clamp(0.0, 1.0) as f32
187    }
188
189    /// Apply decay in-place to a mutable `MemoryItem`.
190    pub fn decay_item(&self, item: &mut MemoryItem) {
191        let age_hours = (Utc::now() - item.timestamp).num_seconds().max(0) as f64 / 3600.0;
192        item.importance = self.apply(item.importance, age_hours);
193    }
194}
195
196// ── RecallPolicy ──────────────────────────────────────────────────────────────
197
/// Strategy used to score and order memories when they are recalled.
#[derive(Clone, Debug)]
pub enum RecallPolicy {
    /// Order by importance alone, highest first. This is the policy every
    /// store constructor installs unless a policy is supplied explicitly.
    Importance,
    /// Blend importance with recency and recall frequency.
    ///
    /// `score = importance + recency_score * recency_weight + frequency_score * frequency_weight`,
    /// with `recency_score = exp(-age_hours / 24.0)` and the normalized
    /// `frequency_score = recall_count / (max_recall_count + 1)`.
    Hybrid {
        /// Multiplier applied to the recency term of the hybrid score.
        recency_weight: f32,
        /// Multiplier applied to the recall-frequency term of the hybrid score.
        frequency_weight: f32,
    },
}
215
216// ── Hybrid scoring helper ─────────────────────────────────────────────────────
217
218fn compute_hybrid_score(
219    item: &MemoryItem,
220    recency_weight: f32,
221    frequency_weight: f32,
222    max_recall: u64,
223    now: chrono::DateTime<Utc>,
224) -> f32 {
225    let age_hours = (now - item.timestamp).num_seconds().max(0) as f64 / 3600.0;
226    let recency_score = (-age_hours / 24.0).exp() as f32;
227    let frequency_score = item.recall_count as f32 / (max_recall as f32 + 1.0);
228    item.importance + recency_score * recency_weight + frequency_score * frequency_weight
229}
230
231// ── EpisodicStore ─────────────────────────────────────────────────────────────
232
233/// Stores episodic (event-based) memories for agents, ordered by insertion time.
234///
235/// ## Guarantees
236/// - Thread-safe via `Arc<Mutex<_>>`
237/// - Ordered: recall returns items in descending importance order
238/// - Bounded by optional capacity
239#[derive(Debug, Clone)]
240pub struct EpisodicStore {
241    inner: Arc<Mutex<EpisodicInner>>,
242}
243
244#[derive(Debug)]
245struct EpisodicInner {
246    items: Vec<MemoryItem>,
247    decay: Option<DecayPolicy>,
248    recall_policy: RecallPolicy,
249    /// Maximum items stored per agent. Oldest (lowest-importance) items evicted when exceeded.
250    per_agent_capacity: Option<usize>,
251}
252
253impl EpisodicStore {
254    /// Create a new unbounded episodic store without decay.
255    pub fn new() -> Self {
256        Self {
257            inner: Arc::new(Mutex::new(EpisodicInner {
258                items: Vec::new(),
259                decay: None,
260                recall_policy: RecallPolicy::Importance,
261                per_agent_capacity: None,
262            })),
263        }
264    }
265
266    /// Create a new episodic store with the given decay policy.
267    pub fn with_decay(policy: DecayPolicy) -> Self {
268        Self {
269            inner: Arc::new(Mutex::new(EpisodicInner {
270                items: Vec::new(),
271                decay: Some(policy),
272                recall_policy: RecallPolicy::Importance,
273                per_agent_capacity: None,
274            })),
275        }
276    }
277
278    /// Create a new episodic store with the given recall policy.
279    pub fn with_recall_policy(policy: RecallPolicy) -> Self {
280        Self {
281            inner: Arc::new(Mutex::new(EpisodicInner {
282                items: Vec::new(),
283                decay: None,
284                recall_policy: policy,
285                per_agent_capacity: None,
286            })),
287        }
288    }
289
290    /// Create a new episodic store with the given per-agent capacity limit.
291    ///
292    /// When an agent exceeds this capacity, the lowest-importance item for that
293    /// agent is evicted.
294    pub fn with_per_agent_capacity(capacity: usize) -> Self {
295        Self {
296            inner: Arc::new(Mutex::new(EpisodicInner {
297                items: Vec::new(),
298                decay: None,
299                recall_policy: RecallPolicy::Importance,
300                per_agent_capacity: Some(capacity),
301            })),
302        }
303    }
304
305    /// Record a new episode for the given agent.
306    ///
307    /// # Returns
308    /// The `MemoryId` of the newly created memory item.
309    #[tracing::instrument(skip(self))]
310    pub fn add_episode(
311        &self,
312        agent_id: AgentId,
313        content: impl Into<String> + std::fmt::Debug,
314        importance: f32,
315    ) -> Result<MemoryId, AgentRuntimeError> {
316        let item = MemoryItem::new(agent_id.clone(), content, importance, Vec::new());
317        let id = item.id.clone();
318        let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::add_episode");
319        inner.items.push(item);
320        if let Some(cap) = inner.per_agent_capacity {
321            let agent_count = inner.items.iter().filter(|i| i.agent_id == agent_id).count();
322            if agent_count > cap {
323                if let Some(pos) = inner
324                    .items
325                    .iter()
326                    .enumerate()
327                    .filter(|(_, i)| i.agent_id == agent_id)
328                    .min_by(|(_, a), (_, b)| {
329                        a.importance
330                            .partial_cmp(&b.importance)
331                            .unwrap_or(std::cmp::Ordering::Equal)
332                    })
333                    .map(|(pos, _)| pos)
334                {
335                    inner.items.remove(pos);
336                }
337            }
338        }
339        Ok(id)
340    }
341
342    /// Add an episode with an explicit timestamp.
343    #[tracing::instrument(skip(self))]
344    pub fn add_episode_at(
345        &self,
346        agent_id: AgentId,
347        content: impl Into<String> + std::fmt::Debug,
348        importance: f32,
349        timestamp: chrono::DateTime<chrono::Utc>,
350    ) -> Result<MemoryId, AgentRuntimeError> {
351        let mut item = MemoryItem::new(agent_id.clone(), content, importance, Vec::new());
352        item.timestamp = timestamp;
353        let id = item.id.clone();
354        let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::add_episode_at");
355        inner.items.push(item);
356        if let Some(cap) = inner.per_agent_capacity {
357            let agent_count = inner.items.iter().filter(|i| i.agent_id == agent_id).count();
358            if agent_count > cap {
359                if let Some(pos) = inner
360                    .items
361                    .iter()
362                    .enumerate()
363                    .filter(|(_, i)| i.agent_id == agent_id)
364                    .min_by(|(_, a), (_, b)| {
365                        a.importance
366                            .partial_cmp(&b.importance)
367                            .unwrap_or(std::cmp::Ordering::Equal)
368                    })
369                    .map(|(pos, _)| pos)
370                {
371                    inner.items.remove(pos);
372                }
373            }
374        }
375        Ok(id)
376    }
377
378    /// Recall up to `limit` memories for the given agent.
379    ///
380    /// Applies decay if configured, increments `recall_count` for each recalled
381    /// item in-place, then returns items sorted according to the configured
382    /// `RecallPolicy` (default: descending importance).
383    #[tracing::instrument(skip(self))]
384    pub fn recall(
385        &self,
386        agent_id: &AgentId,
387        limit: usize,
388    ) -> Result<Vec<MemoryItem>, AgentRuntimeError> {
389        let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::recall");
390
391        // Apply decay in-place
392        let decay_clone: Option<DecayPolicy> = inner.decay.clone();
393        if let Some(policy) = decay_clone {
394            for item in inner.items.iter_mut() {
395                policy.decay_item(item);
396            }
397        }
398
399        // Collect IDs of items belonging to this agent
400        let agent_ids_to_update: Vec<MemoryId> = inner
401            .items
402            .iter()
403            .filter(|i| &i.agent_id == agent_id)
404            .map(|i| i.id.clone())
405            .collect();
406
407        // Increment recall_count in-place for all matching items
408        for item in inner.items.iter_mut() {
409            if agent_ids_to_update.contains(&item.id) {
410                item.recall_count += 1;
411            }
412        }
413
414        let mut items: Vec<MemoryItem> = inner
415            .items
416            .iter()
417            .filter(|i| &i.agent_id == agent_id)
418            .cloned()
419            .collect();
420
421        match inner.recall_policy {
422            RecallPolicy::Importance => {
423                items.sort_by(|a, b| {
424                    b.importance
425                        .partial_cmp(&a.importance)
426                        .unwrap_or(std::cmp::Ordering::Equal)
427                });
428            }
429            RecallPolicy::Hybrid {
430                recency_weight,
431                frequency_weight,
432            } => {
433                let max_recall = items.iter().map(|i| i.recall_count).max().unwrap_or(1).max(1);
434                let now = Utc::now();
435                items.sort_by(|a, b| {
436                    let score_a = compute_hybrid_score(
437                        a,
438                        recency_weight,
439                        frequency_weight,
440                        max_recall,
441                        now,
442                    );
443                    let score_b = compute_hybrid_score(
444                        b,
445                        recency_weight,
446                        frequency_weight,
447                        max_recall,
448                        now,
449                    );
450                    score_b
451                        .partial_cmp(&score_a)
452                        .unwrap_or(std::cmp::Ordering::Equal)
453                });
454            }
455        }
456
457        items.truncate(limit);
458        tracing::debug!("recalled {} items", items.len());
459        Ok(items)
460    }
461
462    /// Return the total number of stored episodes across all agents.
463    pub fn len(&self) -> Result<usize, AgentRuntimeError> {
464        let inner = recover_lock(self.inner.lock(), "EpisodicStore::len");
465        Ok(inner.items.len())
466    }
467
468    /// Return `true` if no episodes have been stored.
469    pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
470        Ok(self.len()? == 0)
471    }
472
473    /// Bump the `recall_count` of every item whose content equals `content` by `amount`.
474    ///
475    /// This method exists to support integration tests that need to simulate prior recall
476    /// history without accessing private fields. It is not intended for production use.
477    #[doc(hidden)]
478    pub fn bump_recall_count_by_content(&self, content: &str, amount: u64) {
479        let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::bump_recall_count_by_content");
480        for item in inner.items.iter_mut() {
481            if item.content == content {
482                item.recall_count = item.recall_count.saturating_add(amount);
483            }
484        }
485    }
486}
487
488impl Default for EpisodicStore {
489    fn default() -> Self {
490        Self::new()
491    }
492}
493
494// ── SemanticStore ─────────────────────────────────────────────────────────────
495
496/// Stores semantic (fact-based) knowledge as tagged key-value pairs.
497///
498/// ## Guarantees
499/// - Thread-safe via `Arc<Mutex<_>>`
500/// - Retrieval by tag intersection
501/// - Optional vector-based similarity search via stored embeddings
502#[derive(Debug, Clone)]
503pub struct SemanticStore {
504    inner: Arc<Mutex<Vec<SemanticEntry>>>,
505}
506
507#[derive(Debug, Clone)]
508struct SemanticEntry {
509    key: String,
510    value: String,
511    tags: Vec<String>,
512    embedding: Option<Vec<f32>>,
513}
514
515impl SemanticStore {
516    /// Create a new empty semantic store.
517    pub fn new() -> Self {
518        Self {
519            inner: Arc::new(Mutex::new(Vec::new())),
520        }
521    }
522
523    /// Store a key-value pair with associated tags.
524    #[tracing::instrument(skip(self))]
525    pub fn store(
526        &self,
527        key: impl Into<String> + std::fmt::Debug,
528        value: impl Into<String> + std::fmt::Debug,
529        tags: Vec<String>,
530    ) -> Result<(), AgentRuntimeError> {
531        let mut inner = recover_lock(self.inner.lock(), "SemanticStore::store");
532        inner.push(SemanticEntry {
533            key: key.into(),
534            value: value.into(),
535            tags,
536            embedding: None,
537        });
538        Ok(())
539    }
540
541    /// Store a key-value pair with an embedding vector for similarity search.
542    #[tracing::instrument(skip(self))]
543    pub fn store_with_embedding(
544        &self,
545        key: impl Into<String> + std::fmt::Debug,
546        value: impl Into<String> + std::fmt::Debug,
547        tags: Vec<String>,
548        embedding: Vec<f32>,
549    ) -> Result<(), AgentRuntimeError> {
550        let mut inner = recover_lock(self.inner.lock(), "SemanticStore::store_with_embedding");
551        inner.push(SemanticEntry {
552            key: key.into(),
553            value: value.into(),
554            tags,
555            embedding: Some(embedding),
556        });
557        Ok(())
558    }
559
560    /// Retrieve all entries that contain **all** of the given tags.
561    ///
562    /// If `tags` is empty, returns all entries.
563    #[tracing::instrument(skip(self))]
564    pub fn retrieve(&self, tags: &[&str]) -> Result<Vec<(String, String)>, AgentRuntimeError> {
565        let inner = recover_lock(self.inner.lock(), "SemanticStore::retrieve");
566
567        let results = inner
568            .iter()
569            .filter(|entry| {
570                tags.iter()
571                    .all(|t| entry.tags.iter().any(|et| et.as_str() == *t))
572            })
573            .map(|e| (e.key.clone(), e.value.clone()))
574            .collect();
575
576        Ok(results)
577    }
578
579    /// Retrieve top-k entries by cosine similarity to `query_embedding`.
580    ///
581    /// Only entries that were stored with an embedding (via [`store_with_embedding`])
582    /// are considered.  Returns `(key, value, similarity)` sorted by descending
583    /// similarity.
584    ///
585    /// [`store_with_embedding`]: SemanticStore::store_with_embedding
586    #[tracing::instrument(skip(self, query_embedding))]
587    pub fn retrieve_similar(
588        &self,
589        query_embedding: &[f32],
590        top_k: usize,
591    ) -> Result<Vec<(String, String, f32)>, AgentRuntimeError> {
592        let inner = recover_lock(self.inner.lock(), "SemanticStore::retrieve_similar");
593
594        let mut scored: Vec<(String, String, f32)> = inner
595            .iter()
596            .filter_map(|entry| {
597                entry.embedding.as_ref().map(|emb| {
598                    let sim = cosine_similarity(query_embedding, emb);
599                    (entry.key.clone(), entry.value.clone(), sim)
600                })
601            })
602            .collect();
603
604        scored.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
605        scored.truncate(top_k);
606        Ok(scored)
607    }
608
609    /// Return the total number of stored entries.
610    pub fn len(&self) -> Result<usize, AgentRuntimeError> {
611        let inner = recover_lock(self.inner.lock(), "SemanticStore::len");
612        Ok(inner.len())
613    }
614
615    /// Return `true` if no entries have been stored.
616    pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
617        Ok(self.len()? == 0)
618    }
619}
620
621impl Default for SemanticStore {
622    fn default() -> Self {
623        Self::new()
624    }
625}
626
627// ── WorkingMemory ─────────────────────────────────────────────────────────────
628
629/// A bounded, key-value working memory for transient agent state.
630///
631/// When capacity is exceeded, the oldest entry (by insertion order) is evicted.
632///
633/// ## Guarantees
634/// - Thread-safe via `Arc<Mutex<_>>`
635/// - Bounded: never exceeds `capacity` entries
636/// - Deterministic eviction: LRU (oldest insertion first)
637#[derive(Debug, Clone)]
638pub struct WorkingMemory {
639    capacity: usize,
640    inner: Arc<Mutex<WorkingInner>>,
641}
642
643#[derive(Debug)]
644struct WorkingInner {
645    map: HashMap<String, String>,
646    order: VecDeque<String>,
647}
648
649impl WorkingMemory {
650    /// Create a new `WorkingMemory` with the given capacity.
651    ///
652    /// # Returns
653    /// - `Ok(WorkingMemory)` — on success
654    /// - `Err(AgentRuntimeError::Memory)` — if `capacity == 0`
655    pub fn new(capacity: usize) -> Result<Self, AgentRuntimeError> {
656        if capacity == 0 {
657            return Err(AgentRuntimeError::Memory(
658                "WorkingMemory capacity must be > 0".into(),
659            ));
660        }
661        Ok(Self {
662            capacity,
663            inner: Arc::new(Mutex::new(WorkingInner {
664                map: HashMap::new(),
665                order: VecDeque::new(),
666            })),
667        })
668    }
669
670    /// Insert or update a key-value pair, evicting the oldest entry if over capacity.
671    #[tracing::instrument(skip(self))]
672    pub fn set(
673        &self,
674        key: impl Into<String> + std::fmt::Debug,
675        value: impl Into<String> + std::fmt::Debug,
676    ) -> Result<(), AgentRuntimeError> {
677        let key = key.into();
678        let value = value.into();
679        let mut inner = recover_lock(self.inner.lock(), "WorkingMemory::set");
680
681        // Remove existing key from order tracking if present
682        if inner.map.contains_key(&key) {
683            inner.order.retain(|k| k != &key);
684        } else if inner.map.len() >= self.capacity {
685            // Evict oldest
686            if let Some(oldest) = inner.order.pop_front() {
687                inner.map.remove(&oldest);
688            }
689        }
690
691        inner.order.push_back(key.clone());
692        inner.map.insert(key, value);
693        Ok(())
694    }
695
696    /// Retrieve a value by key.
697    ///
698    /// # Returns
699    /// - `Some(value)` — if the key exists
700    /// - `None` — if not found
701    #[tracing::instrument(skip(self))]
702    pub fn get(&self, key: &str) -> Result<Option<String>, AgentRuntimeError> {
703        let inner = recover_lock(self.inner.lock(), "WorkingMemory::get");
704        Ok(inner.map.get(key).cloned())
705    }
706
707    /// Remove all entries from working memory.
708    pub fn clear(&self) -> Result<(), AgentRuntimeError> {
709        let mut inner = recover_lock(self.inner.lock(), "WorkingMemory::clear");
710        inner.map.clear();
711        inner.order.clear();
712        Ok(())
713    }
714
715    /// Return the current number of entries.
716    pub fn len(&self) -> Result<usize, AgentRuntimeError> {
717        let inner = recover_lock(self.inner.lock(), "WorkingMemory::len");
718        Ok(inner.map.len())
719    }
720
721    /// Return `true` if no entries are stored.
722    pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
723        Ok(self.len()? == 0)
724    }
725
726    /// Return all key-value pairs in insertion order.
727    pub fn entries(&self) -> Result<Vec<(String, String)>, AgentRuntimeError> {
728        let inner = recover_lock(self.inner.lock(), "WorkingMemory::entries");
729        let entries = inner
730            .order
731            .iter()
732            .filter_map(|k| inner.map.get(k).map(|v| (k.clone(), v.clone())))
733            .collect();
734        Ok(entries)
735    }
736}
737
738// ── Tests ─────────────────────────────────────────────────────────────────────
739
740#[cfg(test)]
741mod tests {
742    use super::*;
743
744    // ── AgentId / MemoryId ────────────────────────────────────────────────────
745
746    #[test]
747    fn test_agent_id_new_stores_string() {
748        let id = AgentId::new("agent-1");
749        assert_eq!(id.0, "agent-1");
750    }
751
752    #[test]
753    fn test_agent_id_random_is_unique() {
754        let a = AgentId::random();
755        let b = AgentId::random();
756        assert_ne!(a, b);
757    }
758
759    #[test]
760    fn test_memory_id_new_stores_string() {
761        let id = MemoryId::new("mem-1");
762        assert_eq!(id.0, "mem-1");
763    }
764
765    #[test]
766    fn test_memory_id_random_is_unique() {
767        let a = MemoryId::random();
768        let b = MemoryId::random();
769        assert_ne!(a, b);
770    }
771
772    // ── MemoryItem ────────────────────────────────────────────────────────────
773
774    #[test]
775    fn test_memory_item_new_clamps_importance_above_one() {
776        let item = MemoryItem::new(AgentId::new("a"), "test", 1.5, vec![]);
777        assert_eq!(item.importance, 1.0);
778    }
779
780    #[test]
781    fn test_memory_item_new_clamps_importance_below_zero() {
782        let item = MemoryItem::new(AgentId::new("a"), "test", -0.5, vec![]);
783        assert_eq!(item.importance, 0.0);
784    }
785
786    #[test]
787    fn test_memory_item_new_preserves_valid_importance() {
788        let item = MemoryItem::new(AgentId::new("a"), "test", 0.7, vec![]);
789        assert!((item.importance - 0.7).abs() < 1e-6);
790    }
791
792    // ── DecayPolicy ───────────────────────────────────────────────────────────
793
794    #[test]
795    fn test_decay_policy_rejects_zero_half_life() {
796        assert!(DecayPolicy::exponential(0.0).is_err());
797    }
798
799    #[test]
800    fn test_decay_policy_rejects_negative_half_life() {
801        assert!(DecayPolicy::exponential(-1.0).is_err());
802    }
803
804    #[test]
805    fn test_decay_policy_no_decay_at_age_zero() {
806        let p = DecayPolicy::exponential(24.0).unwrap();
807        let decayed = p.apply(1.0, 0.0);
808        assert!((decayed - 1.0).abs() < 1e-5);
809    }
810
811    #[test]
812    fn test_decay_policy_half_importance_at_half_life() {
813        let p = DecayPolicy::exponential(24.0).unwrap();
814        let decayed = p.apply(1.0, 24.0);
815        assert!((decayed - 0.5).abs() < 1e-5);
816    }
817
818    #[test]
819    fn test_decay_policy_quarter_importance_at_two_half_lives() {
820        let p = DecayPolicy::exponential(24.0).unwrap();
821        let decayed = p.apply(1.0, 48.0);
822        assert!((decayed - 0.25).abs() < 1e-5);
823    }
824
825    #[test]
826    fn test_decay_policy_result_is_clamped_to_zero_one() {
827        let p = DecayPolicy::exponential(1.0).unwrap();
828        let decayed = p.apply(0.0, 1000.0);
829        assert!(decayed >= 0.0 && decayed <= 1.0);
830    }
831
832    // ── EpisodicStore ─────────────────────────────────────────────────────────
833
834    #[test]
835    fn test_episodic_store_add_episode_returns_id() {
836        let store = EpisodicStore::new();
837        let id = store.add_episode(AgentId::new("a"), "event", 0.8).unwrap();
838        assert!(!id.0.is_empty());
839    }
840
841    #[test]
842    fn test_episodic_store_recall_returns_stored_item() {
843        let store = EpisodicStore::new();
844        let agent = AgentId::new("agent-1");
845        store
846            .add_episode(agent.clone(), "hello world", 0.9)
847            .unwrap();
848        let items = store.recall(&agent, 10).unwrap();
849        assert_eq!(items.len(), 1);
850        assert_eq!(items[0].content, "hello world");
851    }
852
853    #[test]
854    fn test_episodic_store_recall_filters_by_agent() {
855        let store = EpisodicStore::new();
856        let a = AgentId::new("agent-a");
857        let b = AgentId::new("agent-b");
858        store.add_episode(a.clone(), "for a", 0.5).unwrap();
859        store.add_episode(b.clone(), "for b", 0.5).unwrap();
860        let items = store.recall(&a, 10).unwrap();
861        assert_eq!(items.len(), 1);
862        assert_eq!(items[0].content, "for a");
863    }
864
865    #[test]
866    fn test_episodic_store_recall_sorted_by_descending_importance() {
867        let store = EpisodicStore::new();
868        let agent = AgentId::new("agent-1");
869        store.add_episode(agent.clone(), "low", 0.1).unwrap();
870        store.add_episode(agent.clone(), "high", 0.9).unwrap();
871        store.add_episode(agent.clone(), "mid", 0.5).unwrap();
872        let items = store.recall(&agent, 10).unwrap();
873        assert_eq!(items[0].content, "high");
874        assert_eq!(items[1].content, "mid");
875        assert_eq!(items[2].content, "low");
876    }
877
878    #[test]
879    fn test_episodic_store_recall_respects_limit() {
880        let store = EpisodicStore::new();
881        let agent = AgentId::new("agent-1");
882        for i in 0..5 {
883            store
884                .add_episode(agent.clone(), format!("item {i}"), 0.5)
885                .unwrap();
886        }
887        let items = store.recall(&agent, 3).unwrap();
888        assert_eq!(items.len(), 3);
889    }
890
891    #[test]
892    fn test_episodic_store_len_tracks_insertions() {
893        let store = EpisodicStore::new();
894        let agent = AgentId::new("a");
895        store.add_episode(agent.clone(), "a", 0.5).unwrap();
896        store.add_episode(agent.clone(), "b", 0.5).unwrap();
897        assert_eq!(store.len().unwrap(), 2);
898    }
899
900    #[test]
901    fn test_episodic_store_is_empty_initially() {
902        let store = EpisodicStore::new();
903        assert!(store.is_empty().unwrap());
904    }
905
906    #[test]
907    fn test_episodic_store_with_decay_reduces_importance() {
908        let policy = DecayPolicy::exponential(0.001).unwrap(); // very fast decay
909        let store = EpisodicStore::with_decay(policy);
910        let agent = AgentId::new("a");
911
912        // Manually insert an old item by directly manipulating timestamps
913        {
914            let mut inner = store.inner.lock().unwrap();
915            let mut item = MemoryItem::new(agent.clone(), "old event", 1.0, vec![]);
916            // Set the timestamp to 1 hour ago
917            item.timestamp = Utc::now() - chrono::Duration::hours(1);
918            inner.items.push(item);
919        }
920
921        let items = store.recall(&agent, 10).unwrap();
922        // With half_life=0.001h and age=1h, importance should be near 0
923        assert_eq!(items.len(), 1);
924        assert!(
925            items[0].importance < 0.01,
926            "expected near-zero importance, got {}",
927            items[0].importance
928        );
929    }
930
931    // ── Item 10: RecallPolicy / per-agent capacity tests ──────────────────────
932
933    #[test]
934    fn test_recall_increments_recall_count() {
935        let store = EpisodicStore::new();
936        let agent = AgentId::new("agent-rc");
937        store.add_episode(agent.clone(), "memory", 0.5).unwrap();
938
939        // First recall — count becomes 1
940        let items = store.recall(&agent, 10).unwrap();
941        assert_eq!(items[0].recall_count, 1);
942
943        // Second recall — count becomes 2
944        let items = store.recall(&agent, 10).unwrap();
945        assert_eq!(items[0].recall_count, 2);
946    }
947
948    #[test]
949    fn test_hybrid_recall_policy_prefers_recently_used() {
950        // "old_frequent": added 48 h ago with importance 0.5, recall_count bumped
951        // manually to simulate frequent use.
952        // "new_never": added just now with importance 0.5, never recalled.
953        // With a large frequency_weight the frequently-recalled item should rank higher.
954        let store = EpisodicStore::with_recall_policy(RecallPolicy::Hybrid {
955            recency_weight: 0.1,
956            frequency_weight: 2.0,
957        });
958        let agent = AgentId::new("agent-hybrid");
959
960        let old_ts = Utc::now() - chrono::Duration::hours(48);
961        store
962            .add_episode_at(agent.clone(), "old_frequent", 0.5, old_ts)
963            .unwrap();
964        store
965            .add_episode(agent.clone(), "new_never", 0.5)
966            .unwrap();
967
968        // Simulate many prior recalls of "old_frequent" by manually bumping its count
969        {
970            let mut inner = store.inner.lock().unwrap();
971            for item in inner.items.iter_mut() {
972                if item.content == "old_frequent" {
973                    item.recall_count = 100;
974                }
975            }
976        }
977
978        let items = store.recall(&agent, 10).unwrap();
979        assert_eq!(items.len(), 2);
980        assert_eq!(
981            items[0].content, "old_frequent",
982            "hybrid policy should rank the frequently-recalled item first"
983        );
984    }
985
986    #[test]
987    fn test_per_agent_capacity_evicts_lowest_importance() {
988        let store = EpisodicStore::with_per_agent_capacity(2);
989        let agent = AgentId::new("agent-cap");
990
991        store.add_episode(agent.clone(), "mid", 0.5).unwrap();
992        store.add_episode(agent.clone(), "high", 0.9).unwrap();
993        // Adding "low" (0.1) should trigger eviction of the lowest-importance item
994        store.add_episode(agent.clone(), "low", 0.1).unwrap();
995
996        assert_eq!(
997            store.len().unwrap(),
998            2,
999            "store should hold exactly 2 items after eviction"
1000        );
1001
1002        let items = store.recall(&agent, 10).unwrap();
1003        let contents: Vec<&str> = items.iter().map(|i| i.content.as_str()).collect();
1004        assert!(
1005            !contents.contains(&"low"),
1006            "the lowest-importance item should have been evicted; remaining: {:?}",
1007            contents
1008        );
1009    }
1010
1011    // ── SemanticStore ─────────────────────────────────────────────────────────
1012
1013    #[test]
1014    fn test_semantic_store_store_and_retrieve_all() {
1015        let store = SemanticStore::new();
1016        store.store("key1", "value1", vec!["tag-a".into()]).unwrap();
1017        store.store("key2", "value2", vec!["tag-b".into()]).unwrap();
1018        let results = store.retrieve(&[]).unwrap();
1019        assert_eq!(results.len(), 2);
1020    }
1021
1022    #[test]
1023    fn test_semantic_store_retrieve_filters_by_tag() {
1024        let store = SemanticStore::new();
1025        store
1026            .store("k1", "v1", vec!["rust".into(), "async".into()])
1027            .unwrap();
1028        store.store("k2", "v2", vec!["rust".into()]).unwrap();
1029        let results = store.retrieve(&["async"]).unwrap();
1030        assert_eq!(results.len(), 1);
1031        assert_eq!(results[0].0, "k1");
1032    }
1033
1034    #[test]
1035    fn test_semantic_store_retrieve_requires_all_tags() {
1036        let store = SemanticStore::new();
1037        store
1038            .store("k1", "v1", vec!["a".into(), "b".into()])
1039            .unwrap();
1040        store.store("k2", "v2", vec!["a".into()]).unwrap();
1041        let results = store.retrieve(&["a", "b"]).unwrap();
1042        assert_eq!(results.len(), 1);
1043    }
1044
1045    #[test]
1046    fn test_semantic_store_is_empty_initially() {
1047        let store = SemanticStore::new();
1048        assert!(store.is_empty().unwrap());
1049    }
1050
1051    #[test]
1052    fn test_semantic_store_len_tracks_insertions() {
1053        let store = SemanticStore::new();
1054        store.store("k", "v", vec![]).unwrap();
1055        assert_eq!(store.len().unwrap(), 1);
1056    }
1057
1058    #[test]
1059    fn test_semantic_store_retrieve_similar_returns_closest() {
1060        let store = SemanticStore::new();
1061        // "close" is in the same direction as the query
1062        store
1063            .store_with_embedding("close", "close value", vec![], vec![1.0, 0.0, 0.0])
1064            .unwrap();
1065        // "far" is orthogonal to the query
1066        store
1067            .store_with_embedding("far", "far value", vec![], vec![0.0, 1.0, 0.0])
1068            .unwrap();
1069
1070        let query = vec![1.0, 0.0, 0.0];
1071        let results = store.retrieve_similar(&query, 2).unwrap();
1072        assert_eq!(results.len(), 2);
1073        // The closest result should be "close"
1074        assert_eq!(results[0].0, "close");
1075        assert!((results[0].2 - 1.0).abs() < 1e-5, "expected similarity ~1.0, got {}", results[0].2);
1076        // The far result should have similarity 0.0
1077        assert!((results[1].2).abs() < 1e-5, "expected similarity ~0.0, got {}", results[1].2);
1078    }
1079
1080    #[test]
1081    fn test_semantic_store_retrieve_similar_ignores_unembedded_entries() {
1082        let store = SemanticStore::new();
1083        // This entry has no embedding — must not appear in similarity results
1084        store.store("no-emb", "no embedding value", vec![]).unwrap();
1085        // This entry has an embedding
1086        store
1087            .store_with_embedding("with-emb", "with embedding value", vec![], vec![1.0, 0.0])
1088            .unwrap();
1089
1090        let query = vec![1.0, 0.0];
1091        let results = store.retrieve_similar(&query, 10).unwrap();
1092        assert_eq!(results.len(), 1, "only the embedded entry should appear");
1093        assert_eq!(results[0].0, "with-emb");
1094    }
1095
1096    #[test]
1097    fn test_cosine_similarity_orthogonal_vectors_return_zero() {
1098        // Exercise cosine_similarity through retrieve_similar with orthogonal vectors
1099        let store = SemanticStore::new();
1100        store
1101            .store_with_embedding("a", "va", vec![], vec![1.0, 0.0])
1102            .unwrap();
1103        store
1104            .store_with_embedding("b", "vb", vec![], vec![0.0, 1.0])
1105            .unwrap();
1106
1107        // query is along [1, 0]; "b" is orthogonal → similarity should be 0
1108        let query = vec![1.0, 0.0];
1109        let results = store.retrieve_similar(&query, 2).unwrap();
1110        assert_eq!(results.len(), 2);
1111        let b_result = results.iter().find(|(k, _, _)| k == "b").unwrap();
1112        assert!(
1113            b_result.2.abs() < 1e-5,
1114            "expected cosine similarity 0.0 for orthogonal vectors, got {}",
1115            b_result.2
1116        );
1117    }
1118
1119    // ── WorkingMemory ─────────────────────────────────────────────────────────
1120
1121    #[test]
1122    fn test_working_memory_new_rejects_zero_capacity() {
1123        assert!(WorkingMemory::new(0).is_err());
1124    }
1125
1126    #[test]
1127    fn test_working_memory_set_and_get() {
1128        let wm = WorkingMemory::new(10).unwrap();
1129        wm.set("foo", "bar").unwrap();
1130        let val = wm.get("foo").unwrap();
1131        assert_eq!(val, Some("bar".into()));
1132    }
1133
1134    #[test]
1135    fn test_working_memory_get_missing_key_returns_none() {
1136        let wm = WorkingMemory::new(10).unwrap();
1137        assert_eq!(wm.get("missing").unwrap(), None);
1138    }
1139
1140    #[test]
1141    fn test_working_memory_bounded_evicts_oldest() {
1142        let wm = WorkingMemory::new(3).unwrap();
1143        wm.set("k1", "v1").unwrap();
1144        wm.set("k2", "v2").unwrap();
1145        wm.set("k3", "v3").unwrap();
1146        wm.set("k4", "v4").unwrap(); // k1 should be evicted
1147        assert_eq!(wm.get("k1").unwrap(), None);
1148        assert_eq!(wm.get("k4").unwrap(), Some("v4".into()));
1149    }
1150
1151    #[test]
1152    fn test_working_memory_update_existing_key_no_eviction() {
1153        let wm = WorkingMemory::new(2).unwrap();
1154        wm.set("k1", "v1").unwrap();
1155        wm.set("k2", "v2").unwrap();
1156        wm.set("k1", "v1-updated").unwrap(); // update, not eviction
1157        assert_eq!(wm.len().unwrap(), 2);
1158        assert_eq!(wm.get("k1").unwrap(), Some("v1-updated".into()));
1159        assert_eq!(wm.get("k2").unwrap(), Some("v2".into()));
1160    }
1161
1162    #[test]
1163    fn test_working_memory_clear_removes_all() {
1164        let wm = WorkingMemory::new(10).unwrap();
1165        wm.set("a", "1").unwrap();
1166        wm.set("b", "2").unwrap();
1167        wm.clear().unwrap();
1168        assert!(wm.is_empty().unwrap());
1169    }
1170
1171    #[test]
1172    fn test_working_memory_is_empty_initially() {
1173        let wm = WorkingMemory::new(5).unwrap();
1174        assert!(wm.is_empty().unwrap());
1175    }
1176
1177    #[test]
1178    fn test_working_memory_len_tracks_entries() {
1179        let wm = WorkingMemory::new(10).unwrap();
1180        wm.set("a", "1").unwrap();
1181        wm.set("b", "2").unwrap();
1182        assert_eq!(wm.len().unwrap(), 2);
1183    }
1184
1185    #[test]
1186    fn test_working_memory_capacity_never_exceeded() {
1187        let cap = 5usize;
1188        let wm = WorkingMemory::new(cap).unwrap();
1189        for i in 0..20 {
1190            wm.set(format!("key-{i}"), format!("val-{i}")).unwrap();
1191            assert!(wm.len().unwrap() <= cap);
1192        }
1193    }
1194}