Skip to main content

llm_agent_runtime/
memory.rs

//! # Module: Memory
//!
//! ## Responsibility
//! Provides episodic, semantic, and working memory stores for agents.
//! Mirrors the public API of `tokio-agent-memory` and `tokio-memory`.
//!
//! ## Guarantees
//! - Thread-safe: all stores wrap their state in `Arc<Mutex<_>>`
//! - Bounded: `WorkingMemory` evicts the oldest entry when capacity is exceeded
//! - Decaying: `DecayPolicy` reduces importance scores over time
//! - Non-panicking: fallible store operations return `Result` instead of panicking
//! - Lock-poisoning resilient: a panicking thread does not permanently break a store
//!
//! ## NOT Responsible For
//! - Cross-agent shared memory (see the coordinator in `runtime.rs`)
//! - Persistence to disk or an external store
17
18use crate::error::AgentRuntimeError;
19use chrono::{DateTime, Utc};
20use serde::{Deserialize, Serialize};
21use std::collections::{HashMap, VecDeque};
22use std::sync::{Arc, Mutex};
23use uuid::Uuid;
24
25// ── Lock-poisoning recovery ───────────────────────────────────────────────────
26
27/// Acquire a mutex guard, recovering from a poisoned mutex rather than
28/// propagating an error.  A panicking thread does not permanently break the
29/// store; we simply take ownership of the inner value and emit a warning.
30fn recover_lock<'a, T>(
31    result: std::sync::LockResult<std::sync::MutexGuard<'a, T>>,
32    ctx: &str,
33) -> std::sync::MutexGuard<'a, T>
34where
35    T: ?Sized,
36{
37    match result {
38        Ok(guard) => guard,
39        Err(poisoned) => {
40            tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
41            poisoned.into_inner()
42        }
43    }
44}
45
46// ── Cosine similarity ─────────────────────────────────────────────────────────
47
/// Cosine similarity of two vectors, always a finite value in `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, either has
/// zero magnitude, or the result is not a finite number (for example when an
/// input contains NaN or infinity).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate in f64: squaring large f32 components overflows f32 to
    // infinity, which would turn the final ratio into `inf / inf = NaN`.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    let similarity = dot / (sq_a.sqrt() * sq_b.sqrt());
    if !similarity.is_finite() {
        // NaN/inf inputs: report "no similarity" rather than leaking NaN into
        // callers that sort by this score.
        return 0.0;
    }
    similarity.clamp(-1.0, 1.0) as f32
}
60
61// ── Newtype IDs ───────────────────────────────────────────────────────────────
62
63/// Stable identifier for an agent instance.
64#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
65pub struct AgentId(pub String);
66
67impl AgentId {
68    /// Create a new `AgentId` from any string-like value.
69    pub fn new(id: impl Into<String>) -> Self {
70        Self(id.into())
71    }
72
73    /// Generate a random `AgentId` backed by a UUID v4.
74    pub fn random() -> Self {
75        Self(Uuid::new_v4().to_string())
76    }
77}
78
79impl std::fmt::Display for AgentId {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        write!(f, "{}", self.0)
82    }
83}
84
85/// Stable identifier for a memory item.
86#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
87pub struct MemoryId(pub String);
88
89impl MemoryId {
90    /// Create a new `MemoryId` from any string-like value.
91    pub fn new(id: impl Into<String>) -> Self {
92        Self(id.into())
93    }
94
95    /// Generate a random `MemoryId` backed by a UUID v4.
96    pub fn random() -> Self {
97        Self(Uuid::new_v4().to_string())
98    }
99}
100
101impl std::fmt::Display for MemoryId {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        write!(f, "{}", self.0)
104    }
105}
106
107// ── MemoryItem ────────────────────────────────────────────────────────────────
108
109/// A single memory record stored for an agent.
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct MemoryItem {
112    /// Unique identifier for this memory.
113    pub id: MemoryId,
114    /// The agent this memory belongs to.
115    pub agent_id: AgentId,
116    /// Textual content of the memory.
117    pub content: String,
118    /// Importance score in `[0.0, 1.0]`. Higher = more important.
119    pub importance: f32,
120    /// UTC timestamp when this memory was recorded.
121    pub timestamp: DateTime<Utc>,
122    /// Searchable tags attached to this memory.
123    pub tags: Vec<String>,
124    /// Number of times this memory has been recalled. Updated in-place by `recall`.
125    #[serde(default)]
126    pub recall_count: u64,
127}
128
129impl MemoryItem {
130    /// Construct a new `MemoryItem` with the current timestamp and a random ID.
131    pub fn new(
132        agent_id: AgentId,
133        content: impl Into<String>,
134        importance: f32,
135        tags: Vec<String>,
136    ) -> Self {
137        Self {
138            id: MemoryId::random(),
139            agent_id,
140            content: content.into(),
141            importance: importance.clamp(0.0, 1.0),
142            timestamp: Utc::now(),
143            tags,
144            recall_count: 0,
145        }
146    }
147}
148
149// ── DecayPolicy ───────────────────────────────────────────────────────────────
150
151/// Governs how memory importance decays over time.
152#[derive(Debug, Clone)]
153pub struct DecayPolicy {
154    /// The half-life duration in hours. After this many hours, importance is halved.
155    half_life_hours: f64,
156}
157
158impl DecayPolicy {
159    /// Create an exponential decay policy with the given half-life in hours.
160    ///
161    /// # Arguments
162    /// * `half_life_hours` — time after which importance is halved; must be > 0
163    ///
164    /// # Returns
165    /// - `Ok(DecayPolicy)` — on success
166    /// - `Err(AgentRuntimeError::Memory)` — if `half_life_hours <= 0`
167    pub fn exponential(half_life_hours: f64) -> Result<Self, AgentRuntimeError> {
168        if half_life_hours <= 0.0 {
169            return Err(AgentRuntimeError::Memory(
170                "half_life_hours must be positive".into(),
171            ));
172        }
173        Ok(Self { half_life_hours })
174    }
175
176    /// Apply decay to an importance score based on elapsed time.
177    ///
178    /// # Arguments
179    /// * `importance` — original importance in `[0.0, 1.0]`
180    /// * `age_hours` — how many hours have passed since the memory was recorded
181    ///
182    /// # Returns
183    /// Decayed importance clamped to `[0.0, 1.0]`.
184    pub fn apply(&self, importance: f32, age_hours: f64) -> f32 {
185        let decay = (-age_hours * std::f64::consts::LN_2 / self.half_life_hours).exp();
186        (importance as f64 * decay).clamp(0.0, 1.0) as f32
187    }
188
189    /// Apply decay in-place to a mutable `MemoryItem`.
190    pub fn decay_item(&self, item: &mut MemoryItem) {
191        let age_hours = (Utc::now() - item.timestamp).num_seconds().max(0) as f64 / 3600.0;
192        item.importance = self.apply(item.importance, age_hours);
193    }
194}
195
196// ── RecallPolicy ──────────────────────────────────────────────────────────────
197
/// Selects the ranking used when memories are recalled.
#[derive(Clone, Debug)]
pub enum RecallPolicy {
    /// Rank by the item's importance score alone, highest first.
    Importance,
    /// Rank by a blend of importance, recency, and recall frequency.
    ///
    /// `score = importance + recency_score * recency_weight + frequency_score * frequency_weight`
    ///
    /// - `recency_score = exp(-age_hours / 24.0)`
    /// - `frequency_score = recall_count / (max_recall_count + 1)`
    Hybrid {
        /// Multiplier for the recency term of the score.
        recency_weight: f32,
        /// Multiplier for the recall-frequency term of the score.
        frequency_weight: f32,
    },
}
215
216// ── Hybrid scoring helper ─────────────────────────────────────────────────────
217
218fn compute_hybrid_score(
219    item: &MemoryItem,
220    recency_weight: f32,
221    frequency_weight: f32,
222    max_recall: u64,
223    now: chrono::DateTime<Utc>,
224) -> f32 {
225    let age_hours = (now - item.timestamp).num_seconds().max(0) as f64 / 3600.0;
226    let recency_score = (-age_hours / 24.0).exp() as f32;
227    let frequency_score = item.recall_count as f32 / (max_recall as f32 + 1.0);
228    item.importance + recency_score * recency_weight + frequency_score * frequency_weight
229}
230
231// ── EpisodicStore ─────────────────────────────────────────────────────────────
232
233/// Stores episodic (event-based) memories for agents, ordered by insertion time.
234///
235/// ## Guarantees
236/// - Thread-safe via `Arc<Mutex<_>>`
237/// - Ordered: recall returns items in descending importance order
238/// - Bounded by optional capacity
239#[derive(Debug, Clone)]
240pub struct EpisodicStore {
241    inner: Arc<Mutex<EpisodicInner>>,
242}
243
244#[derive(Debug)]
245struct EpisodicInner {
246    items: Vec<MemoryItem>,
247    decay: Option<DecayPolicy>,
248    recall_policy: RecallPolicy,
249    /// Maximum items stored per agent. Oldest (lowest-importance) items evicted when exceeded.
250    per_agent_capacity: Option<usize>,
251}
252
253impl EpisodicStore {
254    /// Create a new unbounded episodic store without decay.
255    pub fn new() -> Self {
256        Self {
257            inner: Arc::new(Mutex::new(EpisodicInner {
258                items: Vec::new(),
259                decay: None,
260                recall_policy: RecallPolicy::Importance,
261                per_agent_capacity: None,
262            })),
263        }
264    }
265
266    /// Create a new episodic store with the given decay policy.
267    pub fn with_decay(policy: DecayPolicy) -> Self {
268        Self {
269            inner: Arc::new(Mutex::new(EpisodicInner {
270                items: Vec::new(),
271                decay: Some(policy),
272                recall_policy: RecallPolicy::Importance,
273                per_agent_capacity: None,
274            })),
275        }
276    }
277
278    /// Create a new episodic store with the given recall policy.
279    pub fn with_recall_policy(policy: RecallPolicy) -> Self {
280        Self {
281            inner: Arc::new(Mutex::new(EpisodicInner {
282                items: Vec::new(),
283                decay: None,
284                recall_policy: policy,
285                per_agent_capacity: None,
286            })),
287        }
288    }
289
290    /// Create a new episodic store with the given per-agent capacity limit.
291    ///
292    /// When an agent exceeds this capacity, the lowest-importance item for that
293    /// agent is evicted.
294    pub fn with_per_agent_capacity(capacity: usize) -> Self {
295        Self {
296            inner: Arc::new(Mutex::new(EpisodicInner {
297                items: Vec::new(),
298                decay: None,
299                recall_policy: RecallPolicy::Importance,
300                per_agent_capacity: Some(capacity),
301            })),
302        }
303    }
304
305    /// Record a new episode for the given agent.
306    ///
307    /// # Returns
308    /// The `MemoryId` of the newly created memory item.
309    #[tracing::instrument(skip(self))]
310    pub fn add_episode(
311        &self,
312        agent_id: AgentId,
313        content: impl Into<String> + std::fmt::Debug,
314        importance: f32,
315    ) -> Result<MemoryId, AgentRuntimeError> {
316        let item = MemoryItem::new(agent_id.clone(), content, importance, Vec::new());
317        let id = item.id.clone();
318        let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::add_episode");
319        inner.items.push(item);
320        if let Some(cap) = inner.per_agent_capacity {
321            let agent_count = inner
322                .items
323                .iter()
324                .filter(|i| i.agent_id == agent_id)
325                .count();
326            if agent_count > cap {
327                if let Some(pos) = inner
328                    .items
329                    .iter()
330                    .enumerate()
331                    .filter(|(_, i)| i.agent_id == agent_id)
332                    .min_by(|(_, a), (_, b)| {
333                        a.importance
334                            .partial_cmp(&b.importance)
335                            .unwrap_or(std::cmp::Ordering::Equal)
336                    })
337                    .map(|(pos, _)| pos)
338                {
339                    inner.items.remove(pos);
340                }
341            }
342        }
343        Ok(id)
344    }
345
346    /// Add an episode with an explicit timestamp.
347    #[tracing::instrument(skip(self))]
348    pub fn add_episode_at(
349        &self,
350        agent_id: AgentId,
351        content: impl Into<String> + std::fmt::Debug,
352        importance: f32,
353        timestamp: chrono::DateTime<chrono::Utc>,
354    ) -> Result<MemoryId, AgentRuntimeError> {
355        let mut item = MemoryItem::new(agent_id.clone(), content, importance, Vec::new());
356        item.timestamp = timestamp;
357        let id = item.id.clone();
358        let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::add_episode_at");
359        inner.items.push(item);
360        if let Some(cap) = inner.per_agent_capacity {
361            let agent_count = inner
362                .items
363                .iter()
364                .filter(|i| i.agent_id == agent_id)
365                .count();
366            if agent_count > cap {
367                if let Some(pos) = inner
368                    .items
369                    .iter()
370                    .enumerate()
371                    .filter(|(_, i)| i.agent_id == agent_id)
372                    .min_by(|(_, a), (_, b)| {
373                        a.importance
374                            .partial_cmp(&b.importance)
375                            .unwrap_or(std::cmp::Ordering::Equal)
376                    })
377                    .map(|(pos, _)| pos)
378                {
379                    inner.items.remove(pos);
380                }
381            }
382        }
383        Ok(id)
384    }
385
386    /// Recall up to `limit` memories for the given agent.
387    ///
388    /// Applies decay if configured, increments `recall_count` for each recalled
389    /// item in-place, then returns items sorted according to the configured
390    /// `RecallPolicy` (default: descending importance).
391    #[tracing::instrument(skip(self))]
392    pub fn recall(
393        &self,
394        agent_id: &AgentId,
395        limit: usize,
396    ) -> Result<Vec<MemoryItem>, AgentRuntimeError> {
397        let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::recall");
398
399        // Apply decay in-place
400        let decay_clone: Option<DecayPolicy> = inner.decay.clone();
401        if let Some(policy) = decay_clone {
402            for item in inner.items.iter_mut() {
403                policy.decay_item(item);
404            }
405        }
406
407        // Collect IDs of items belonging to this agent
408        let agent_ids_to_update: Vec<MemoryId> = inner
409            .items
410            .iter()
411            .filter(|i| &i.agent_id == agent_id)
412            .map(|i| i.id.clone())
413            .collect();
414
415        // Increment recall_count in-place for all matching items
416        for item in inner.items.iter_mut() {
417            if agent_ids_to_update.contains(&item.id) {
418                item.recall_count += 1;
419            }
420        }
421
422        let mut items: Vec<MemoryItem> = inner
423            .items
424            .iter()
425            .filter(|i| &i.agent_id == agent_id)
426            .cloned()
427            .collect();
428
429        match inner.recall_policy {
430            RecallPolicy::Importance => {
431                items.sort_by(|a, b| {
432                    b.importance
433                        .partial_cmp(&a.importance)
434                        .unwrap_or(std::cmp::Ordering::Equal)
435                });
436            }
437            RecallPolicy::Hybrid {
438                recency_weight,
439                frequency_weight,
440            } => {
441                let max_recall = items
442                    .iter()
443                    .map(|i| i.recall_count)
444                    .max()
445                    .unwrap_or(1)
446                    .max(1);
447                let now = Utc::now();
448                items.sort_by(|a, b| {
449                    let score_a =
450                        compute_hybrid_score(a, recency_weight, frequency_weight, max_recall, now);
451                    let score_b =
452                        compute_hybrid_score(b, recency_weight, frequency_weight, max_recall, now);
453                    score_b
454                        .partial_cmp(&score_a)
455                        .unwrap_or(std::cmp::Ordering::Equal)
456                });
457            }
458        }
459
460        items.truncate(limit);
461        tracing::debug!("recalled {} items", items.len());
462        Ok(items)
463    }
464
465    /// Return the total number of stored episodes across all agents.
466    pub fn len(&self) -> Result<usize, AgentRuntimeError> {
467        let inner = recover_lock(self.inner.lock(), "EpisodicStore::len");
468        Ok(inner.items.len())
469    }
470
471    /// Return `true` if no episodes have been stored.
472    pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
473        Ok(self.len()? == 0)
474    }
475
476    /// Bump the `recall_count` of every item whose content equals `content` by `amount`.
477    ///
478    /// This method exists to support integration tests that need to simulate prior recall
479    /// history without accessing private fields. It is not intended for production use.
480    #[doc(hidden)]
481    pub fn bump_recall_count_by_content(&self, content: &str, amount: u64) {
482        let mut inner = recover_lock(
483            self.inner.lock(),
484            "EpisodicStore::bump_recall_count_by_content",
485        );
486        for item in inner.items.iter_mut() {
487            if item.content == content {
488                item.recall_count = item.recall_count.saturating_add(amount);
489            }
490        }
491    }
492}
493
494impl Default for EpisodicStore {
495    fn default() -> Self {
496        Self::new()
497    }
498}
499
500// ── SemanticStore ─────────────────────────────────────────────────────────────
501
502/// Stores semantic (fact-based) knowledge as tagged key-value pairs.
503///
504/// ## Guarantees
505/// - Thread-safe via `Arc<Mutex<_>>`
506/// - Retrieval by tag intersection
507/// - Optional vector-based similarity search via stored embeddings
508#[derive(Debug, Clone)]
509pub struct SemanticStore {
510    inner: Arc<Mutex<Vec<SemanticEntry>>>,
511}
512
513#[derive(Debug, Clone)]
514struct SemanticEntry {
515    key: String,
516    value: String,
517    tags: Vec<String>,
518    embedding: Option<Vec<f32>>,
519}
520
521impl SemanticStore {
522    /// Create a new empty semantic store.
523    pub fn new() -> Self {
524        Self {
525            inner: Arc::new(Mutex::new(Vec::new())),
526        }
527    }
528
529    /// Store a key-value pair with associated tags.
530    #[tracing::instrument(skip(self))]
531    pub fn store(
532        &self,
533        key: impl Into<String> + std::fmt::Debug,
534        value: impl Into<String> + std::fmt::Debug,
535        tags: Vec<String>,
536    ) -> Result<(), AgentRuntimeError> {
537        let mut inner = recover_lock(self.inner.lock(), "SemanticStore::store");
538        inner.push(SemanticEntry {
539            key: key.into(),
540            value: value.into(),
541            tags,
542            embedding: None,
543        });
544        Ok(())
545    }
546
547    /// Store a key-value pair with an embedding vector for similarity search.
548    #[tracing::instrument(skip(self))]
549    pub fn store_with_embedding(
550        &self,
551        key: impl Into<String> + std::fmt::Debug,
552        value: impl Into<String> + std::fmt::Debug,
553        tags: Vec<String>,
554        embedding: Vec<f32>,
555    ) -> Result<(), AgentRuntimeError> {
556        let mut inner = recover_lock(self.inner.lock(), "SemanticStore::store_with_embedding");
557        inner.push(SemanticEntry {
558            key: key.into(),
559            value: value.into(),
560            tags,
561            embedding: Some(embedding),
562        });
563        Ok(())
564    }
565
566    /// Retrieve all entries that contain **all** of the given tags.
567    ///
568    /// If `tags` is empty, returns all entries.
569    #[tracing::instrument(skip(self))]
570    pub fn retrieve(&self, tags: &[&str]) -> Result<Vec<(String, String)>, AgentRuntimeError> {
571        let inner = recover_lock(self.inner.lock(), "SemanticStore::retrieve");
572
573        let results = inner
574            .iter()
575            .filter(|entry| {
576                tags.iter()
577                    .all(|t| entry.tags.iter().any(|et| et.as_str() == *t))
578            })
579            .map(|e| (e.key.clone(), e.value.clone()))
580            .collect();
581
582        Ok(results)
583    }
584
585    /// Retrieve top-k entries by cosine similarity to `query_embedding`.
586    ///
587    /// Only entries that were stored with an embedding (via [`store_with_embedding`])
588    /// are considered.  Returns `(key, value, similarity)` sorted by descending
589    /// similarity.
590    ///
591    /// [`store_with_embedding`]: SemanticStore::store_with_embedding
592    #[tracing::instrument(skip(self, query_embedding))]
593    pub fn retrieve_similar(
594        &self,
595        query_embedding: &[f32],
596        top_k: usize,
597    ) -> Result<Vec<(String, String, f32)>, AgentRuntimeError> {
598        let inner = recover_lock(self.inner.lock(), "SemanticStore::retrieve_similar");
599
600        let mut scored: Vec<(String, String, f32)> = inner
601            .iter()
602            .filter_map(|entry| {
603                entry.embedding.as_ref().map(|emb| {
604                    let sim = cosine_similarity(query_embedding, emb);
605                    (entry.key.clone(), entry.value.clone(), sim)
606                })
607            })
608            .collect();
609
610        scored.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
611        scored.truncate(top_k);
612        Ok(scored)
613    }
614
615    /// Return the total number of stored entries.
616    pub fn len(&self) -> Result<usize, AgentRuntimeError> {
617        let inner = recover_lock(self.inner.lock(), "SemanticStore::len");
618        Ok(inner.len())
619    }
620
621    /// Return `true` if no entries have been stored.
622    pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
623        Ok(self.len()? == 0)
624    }
625}
626
627impl Default for SemanticStore {
628    fn default() -> Self {
629        Self::new()
630    }
631}
632
633// ── WorkingMemory ─────────────────────────────────────────────────────────────
634
635/// A bounded, key-value working memory for transient agent state.
636///
637/// When capacity is exceeded, the oldest entry (by insertion order) is evicted.
638///
639/// ## Guarantees
640/// - Thread-safe via `Arc<Mutex<_>>`
641/// - Bounded: never exceeds `capacity` entries
642/// - Deterministic eviction: LRU (oldest insertion first)
643#[derive(Debug, Clone)]
644pub struct WorkingMemory {
645    capacity: usize,
646    inner: Arc<Mutex<WorkingInner>>,
647}
648
649#[derive(Debug)]
650struct WorkingInner {
651    map: HashMap<String, String>,
652    order: VecDeque<String>,
653}
654
655impl WorkingMemory {
656    /// Create a new `WorkingMemory` with the given capacity.
657    ///
658    /// # Returns
659    /// - `Ok(WorkingMemory)` — on success
660    /// - `Err(AgentRuntimeError::Memory)` — if `capacity == 0`
661    pub fn new(capacity: usize) -> Result<Self, AgentRuntimeError> {
662        if capacity == 0 {
663            return Err(AgentRuntimeError::Memory(
664                "WorkingMemory capacity must be > 0".into(),
665            ));
666        }
667        Ok(Self {
668            capacity,
669            inner: Arc::new(Mutex::new(WorkingInner {
670                map: HashMap::new(),
671                order: VecDeque::new(),
672            })),
673        })
674    }
675
676    /// Insert or update a key-value pair, evicting the oldest entry if over capacity.
677    #[tracing::instrument(skip(self))]
678    pub fn set(
679        &self,
680        key: impl Into<String> + std::fmt::Debug,
681        value: impl Into<String> + std::fmt::Debug,
682    ) -> Result<(), AgentRuntimeError> {
683        let key = key.into();
684        let value = value.into();
685        let mut inner = recover_lock(self.inner.lock(), "WorkingMemory::set");
686
687        // Remove existing key from order tracking if present
688        if inner.map.contains_key(&key) {
689            inner.order.retain(|k| k != &key);
690        } else if inner.map.len() >= self.capacity {
691            // Evict oldest
692            if let Some(oldest) = inner.order.pop_front() {
693                inner.map.remove(&oldest);
694            }
695        }
696
697        inner.order.push_back(key.clone());
698        inner.map.insert(key, value);
699        Ok(())
700    }
701
702    /// Retrieve a value by key.
703    ///
704    /// # Returns
705    /// - `Some(value)` — if the key exists
706    /// - `None` — if not found
707    #[tracing::instrument(skip(self))]
708    pub fn get(&self, key: &str) -> Result<Option<String>, AgentRuntimeError> {
709        let inner = recover_lock(self.inner.lock(), "WorkingMemory::get");
710        Ok(inner.map.get(key).cloned())
711    }
712
713    /// Remove all entries from working memory.
714    pub fn clear(&self) -> Result<(), AgentRuntimeError> {
715        let mut inner = recover_lock(self.inner.lock(), "WorkingMemory::clear");
716        inner.map.clear();
717        inner.order.clear();
718        Ok(())
719    }
720
721    /// Return the current number of entries.
722    pub fn len(&self) -> Result<usize, AgentRuntimeError> {
723        let inner = recover_lock(self.inner.lock(), "WorkingMemory::len");
724        Ok(inner.map.len())
725    }
726
727    /// Return `true` if no entries are stored.
728    pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
729        Ok(self.len()? == 0)
730    }
731
732    /// Return all key-value pairs in insertion order.
733    pub fn entries(&self) -> Result<Vec<(String, String)>, AgentRuntimeError> {
734        let inner = recover_lock(self.inner.lock(), "WorkingMemory::entries");
735        let entries = inner
736            .order
737            .iter()
738            .filter_map(|k| inner.map.get(k).map(|v| (k.clone(), v.clone())))
739            .collect();
740        Ok(entries)
741    }
742}
743
744// ── Tests ─────────────────────────────────────────────────────────────────────
745
746#[cfg(test)]
747mod tests {
748    use super::*;
749
750    // ── AgentId / MemoryId ────────────────────────────────────────────────────
751
752    #[test]
753    fn test_agent_id_new_stores_string() {
754        let id = AgentId::new("agent-1");
755        assert_eq!(id.0, "agent-1");
756    }
757
758    #[test]
759    fn test_agent_id_random_is_unique() {
760        let a = AgentId::random();
761        let b = AgentId::random();
762        assert_ne!(a, b);
763    }
764
765    #[test]
766    fn test_memory_id_new_stores_string() {
767        let id = MemoryId::new("mem-1");
768        assert_eq!(id.0, "mem-1");
769    }
770
771    #[test]
772    fn test_memory_id_random_is_unique() {
773        let a = MemoryId::random();
774        let b = MemoryId::random();
775        assert_ne!(a, b);
776    }
777
778    // ── MemoryItem ────────────────────────────────────────────────────────────
779
780    #[test]
781    fn test_memory_item_new_clamps_importance_above_one() {
782        let item = MemoryItem::new(AgentId::new("a"), "test", 1.5, vec![]);
783        assert_eq!(item.importance, 1.0);
784    }
785
786    #[test]
787    fn test_memory_item_new_clamps_importance_below_zero() {
788        let item = MemoryItem::new(AgentId::new("a"), "test", -0.5, vec![]);
789        assert_eq!(item.importance, 0.0);
790    }
791
792    #[test]
793    fn test_memory_item_new_preserves_valid_importance() {
794        let item = MemoryItem::new(AgentId::new("a"), "test", 0.7, vec![]);
795        assert!((item.importance - 0.7).abs() < 1e-6);
796    }
797
798    // ── DecayPolicy ───────────────────────────────────────────────────────────
799
800    #[test]
801    fn test_decay_policy_rejects_zero_half_life() {
802        assert!(DecayPolicy::exponential(0.0).is_err());
803    }
804
805    #[test]
806    fn test_decay_policy_rejects_negative_half_life() {
807        assert!(DecayPolicy::exponential(-1.0).is_err());
808    }
809
810    #[test]
811    fn test_decay_policy_no_decay_at_age_zero() {
812        let p = DecayPolicy::exponential(24.0).unwrap();
813        let decayed = p.apply(1.0, 0.0);
814        assert!((decayed - 1.0).abs() < 1e-5);
815    }
816
817    #[test]
818    fn test_decay_policy_half_importance_at_half_life() {
819        let p = DecayPolicy::exponential(24.0).unwrap();
820        let decayed = p.apply(1.0, 24.0);
821        assert!((decayed - 0.5).abs() < 1e-5);
822    }
823
824    #[test]
825    fn test_decay_policy_quarter_importance_at_two_half_lives() {
826        let p = DecayPolicy::exponential(24.0).unwrap();
827        let decayed = p.apply(1.0, 48.0);
828        assert!((decayed - 0.25).abs() < 1e-5);
829    }
830
831    #[test]
832    fn test_decay_policy_result_is_clamped_to_zero_one() {
833        let p = DecayPolicy::exponential(1.0).unwrap();
834        let decayed = p.apply(0.0, 1000.0);
835        assert!(decayed >= 0.0 && decayed <= 1.0);
836    }
837
838    // ── EpisodicStore ─────────────────────────────────────────────────────────
839
840    #[test]
841    fn test_episodic_store_add_episode_returns_id() {
842        let store = EpisodicStore::new();
843        let id = store.add_episode(AgentId::new("a"), "event", 0.8).unwrap();
844        assert!(!id.0.is_empty());
845    }
846
847    #[test]
848    fn test_episodic_store_recall_returns_stored_item() {
849        let store = EpisodicStore::new();
850        let agent = AgentId::new("agent-1");
851        store
852            .add_episode(agent.clone(), "hello world", 0.9)
853            .unwrap();
854        let items = store.recall(&agent, 10).unwrap();
855        assert_eq!(items.len(), 1);
856        assert_eq!(items[0].content, "hello world");
857    }
858
859    #[test]
860    fn test_episodic_store_recall_filters_by_agent() {
861        let store = EpisodicStore::new();
862        let a = AgentId::new("agent-a");
863        let b = AgentId::new("agent-b");
864        store.add_episode(a.clone(), "for a", 0.5).unwrap();
865        store.add_episode(b.clone(), "for b", 0.5).unwrap();
866        let items = store.recall(&a, 10).unwrap();
867        assert_eq!(items.len(), 1);
868        assert_eq!(items[0].content, "for a");
869    }
870
871    #[test]
872    fn test_episodic_store_recall_sorted_by_descending_importance() {
873        let store = EpisodicStore::new();
874        let agent = AgentId::new("agent-1");
875        store.add_episode(agent.clone(), "low", 0.1).unwrap();
876        store.add_episode(agent.clone(), "high", 0.9).unwrap();
877        store.add_episode(agent.clone(), "mid", 0.5).unwrap();
878        let items = store.recall(&agent, 10).unwrap();
879        assert_eq!(items[0].content, "high");
880        assert_eq!(items[1].content, "mid");
881        assert_eq!(items[2].content, "low");
882    }
883
884    #[test]
885    fn test_episodic_store_recall_respects_limit() {
886        let store = EpisodicStore::new();
887        let agent = AgentId::new("agent-1");
888        for i in 0..5 {
889            store
890                .add_episode(agent.clone(), format!("item {i}"), 0.5)
891                .unwrap();
892        }
893        let items = store.recall(&agent, 3).unwrap();
894        assert_eq!(items.len(), 3);
895    }
896
897    #[test]
898    fn test_episodic_store_len_tracks_insertions() {
899        let store = EpisodicStore::new();
900        let agent = AgentId::new("a");
901        store.add_episode(agent.clone(), "a", 0.5).unwrap();
902        store.add_episode(agent.clone(), "b", 0.5).unwrap();
903        assert_eq!(store.len().unwrap(), 2);
904    }
905
906    #[test]
907    fn test_episodic_store_is_empty_initially() {
908        let store = EpisodicStore::new();
909        assert!(store.is_empty().unwrap());
910    }
911
912    #[test]
913    fn test_episodic_store_with_decay_reduces_importance() {
914        let policy = DecayPolicy::exponential(0.001).unwrap(); // very fast decay
915        let store = EpisodicStore::with_decay(policy);
916        let agent = AgentId::new("a");
917
918        // Manually insert an old item by directly manipulating timestamps
919        {
920            let mut inner = store.inner.lock().unwrap();
921            let mut item = MemoryItem::new(agent.clone(), "old event", 1.0, vec![]);
922            // Set the timestamp to 1 hour ago
923            item.timestamp = Utc::now() - chrono::Duration::hours(1);
924            inner.items.push(item);
925        }
926
927        let items = store.recall(&agent, 10).unwrap();
928        // With half_life=0.001h and age=1h, importance should be near 0
929        assert_eq!(items.len(), 1);
930        assert!(
931            items[0].importance < 0.01,
932            "expected near-zero importance, got {}",
933            items[0].importance
934        );
935    }
936
937    // ── Item 10: RecallPolicy / per-agent capacity tests ──────────────────────
938
939    #[test]
940    fn test_recall_increments_recall_count() {
941        let store = EpisodicStore::new();
942        let agent = AgentId::new("agent-rc");
943        store.add_episode(agent.clone(), "memory", 0.5).unwrap();
944
945        // First recall — count becomes 1
946        let items = store.recall(&agent, 10).unwrap();
947        assert_eq!(items[0].recall_count, 1);
948
949        // Second recall — count becomes 2
950        let items = store.recall(&agent, 10).unwrap();
951        assert_eq!(items[0].recall_count, 2);
952    }
953
954    #[test]
955    fn test_hybrid_recall_policy_prefers_recently_used() {
956        // "old_frequent": added 48 h ago with importance 0.5, recall_count bumped
957        // manually to simulate frequent use.
958        // "new_never": added just now with importance 0.5, never recalled.
959        // With a large frequency_weight the frequently-recalled item should rank higher.
960        let store = EpisodicStore::with_recall_policy(RecallPolicy::Hybrid {
961            recency_weight: 0.1,
962            frequency_weight: 2.0,
963        });
964        let agent = AgentId::new("agent-hybrid");
965
966        let old_ts = Utc::now() - chrono::Duration::hours(48);
967        store
968            .add_episode_at(agent.clone(), "old_frequent", 0.5, old_ts)
969            .unwrap();
970        store.add_episode(agent.clone(), "new_never", 0.5).unwrap();
971
972        // Simulate many prior recalls of "old_frequent" by manually bumping its count
973        {
974            let mut inner = store.inner.lock().unwrap();
975            for item in inner.items.iter_mut() {
976                if item.content == "old_frequent" {
977                    item.recall_count = 100;
978                }
979            }
980        }
981
982        let items = store.recall(&agent, 10).unwrap();
983        assert_eq!(items.len(), 2);
984        assert_eq!(
985            items[0].content, "old_frequent",
986            "hybrid policy should rank the frequently-recalled item first"
987        );
988    }
989
990    #[test]
991    fn test_per_agent_capacity_evicts_lowest_importance() {
992        let store = EpisodicStore::with_per_agent_capacity(2);
993        let agent = AgentId::new("agent-cap");
994
995        store.add_episode(agent.clone(), "mid", 0.5).unwrap();
996        store.add_episode(agent.clone(), "high", 0.9).unwrap();
997        // Adding "low" (0.1) should trigger eviction of the lowest-importance item
998        store.add_episode(agent.clone(), "low", 0.1).unwrap();
999
1000        assert_eq!(
1001            store.len().unwrap(),
1002            2,
1003            "store should hold exactly 2 items after eviction"
1004        );
1005
1006        let items = store.recall(&agent, 10).unwrap();
1007        let contents: Vec<&str> = items.iter().map(|i| i.content.as_str()).collect();
1008        assert!(
1009            !contents.contains(&"low"),
1010            "the lowest-importance item should have been evicted; remaining: {:?}",
1011            contents
1012        );
1013    }
1014
1015    // ── SemanticStore ─────────────────────────────────────────────────────────
1016
1017    #[test]
1018    fn test_semantic_store_store_and_retrieve_all() {
1019        let store = SemanticStore::new();
1020        store.store("key1", "value1", vec!["tag-a".into()]).unwrap();
1021        store.store("key2", "value2", vec!["tag-b".into()]).unwrap();
1022        let results = store.retrieve(&[]).unwrap();
1023        assert_eq!(results.len(), 2);
1024    }
1025
1026    #[test]
1027    fn test_semantic_store_retrieve_filters_by_tag() {
1028        let store = SemanticStore::new();
1029        store
1030            .store("k1", "v1", vec!["rust".into(), "async".into()])
1031            .unwrap();
1032        store.store("k2", "v2", vec!["rust".into()]).unwrap();
1033        let results = store.retrieve(&["async"]).unwrap();
1034        assert_eq!(results.len(), 1);
1035        assert_eq!(results[0].0, "k1");
1036    }
1037
1038    #[test]
1039    fn test_semantic_store_retrieve_requires_all_tags() {
1040        let store = SemanticStore::new();
1041        store
1042            .store("k1", "v1", vec!["a".into(), "b".into()])
1043            .unwrap();
1044        store.store("k2", "v2", vec!["a".into()]).unwrap();
1045        let results = store.retrieve(&["a", "b"]).unwrap();
1046        assert_eq!(results.len(), 1);
1047    }
1048
1049    #[test]
1050    fn test_semantic_store_is_empty_initially() {
1051        let store = SemanticStore::new();
1052        assert!(store.is_empty().unwrap());
1053    }
1054
1055    #[test]
1056    fn test_semantic_store_len_tracks_insertions() {
1057        let store = SemanticStore::new();
1058        store.store("k", "v", vec![]).unwrap();
1059        assert_eq!(store.len().unwrap(), 1);
1060    }
1061
1062    #[test]
1063    fn test_semantic_store_retrieve_similar_returns_closest() {
1064        let store = SemanticStore::new();
1065        // "close" is in the same direction as the query
1066        store
1067            .store_with_embedding("close", "close value", vec![], vec![1.0, 0.0, 0.0])
1068            .unwrap();
1069        // "far" is orthogonal to the query
1070        store
1071            .store_with_embedding("far", "far value", vec![], vec![0.0, 1.0, 0.0])
1072            .unwrap();
1073
1074        let query = vec![1.0, 0.0, 0.0];
1075        let results = store.retrieve_similar(&query, 2).unwrap();
1076        assert_eq!(results.len(), 2);
1077        // The closest result should be "close"
1078        assert_eq!(results[0].0, "close");
1079        assert!(
1080            (results[0].2 - 1.0).abs() < 1e-5,
1081            "expected similarity ~1.0, got {}",
1082            results[0].2
1083        );
1084        // The far result should have similarity 0.0
1085        assert!(
1086            (results[1].2).abs() < 1e-5,
1087            "expected similarity ~0.0, got {}",
1088            results[1].2
1089        );
1090    }
1091
1092    #[test]
1093    fn test_semantic_store_retrieve_similar_ignores_unembedded_entries() {
1094        let store = SemanticStore::new();
1095        // This entry has no embedding — must not appear in similarity results
1096        store.store("no-emb", "no embedding value", vec![]).unwrap();
1097        // This entry has an embedding
1098        store
1099            .store_with_embedding("with-emb", "with embedding value", vec![], vec![1.0, 0.0])
1100            .unwrap();
1101
1102        let query = vec![1.0, 0.0];
1103        let results = store.retrieve_similar(&query, 10).unwrap();
1104        assert_eq!(results.len(), 1, "only the embedded entry should appear");
1105        assert_eq!(results[0].0, "with-emb");
1106    }
1107
1108    #[test]
1109    fn test_cosine_similarity_orthogonal_vectors_return_zero() {
1110        // Exercise cosine_similarity through retrieve_similar with orthogonal vectors
1111        let store = SemanticStore::new();
1112        store
1113            .store_with_embedding("a", "va", vec![], vec![1.0, 0.0])
1114            .unwrap();
1115        store
1116            .store_with_embedding("b", "vb", vec![], vec![0.0, 1.0])
1117            .unwrap();
1118
1119        // query is along [1, 0]; "b" is orthogonal → similarity should be 0
1120        let query = vec![1.0, 0.0];
1121        let results = store.retrieve_similar(&query, 2).unwrap();
1122        assert_eq!(results.len(), 2);
1123        let b_result = results.iter().find(|(k, _, _)| k == "b").unwrap();
1124        assert!(
1125            b_result.2.abs() < 1e-5,
1126            "expected cosine similarity 0.0 for orthogonal vectors, got {}",
1127            b_result.2
1128        );
1129    }
1130
1131    // ── WorkingMemory ─────────────────────────────────────────────────────────
1132
1133    #[test]
1134    fn test_working_memory_new_rejects_zero_capacity() {
1135        assert!(WorkingMemory::new(0).is_err());
1136    }
1137
1138    #[test]
1139    fn test_working_memory_set_and_get() {
1140        let wm = WorkingMemory::new(10).unwrap();
1141        wm.set("foo", "bar").unwrap();
1142        let val = wm.get("foo").unwrap();
1143        assert_eq!(val, Some("bar".into()));
1144    }
1145
1146    #[test]
1147    fn test_working_memory_get_missing_key_returns_none() {
1148        let wm = WorkingMemory::new(10).unwrap();
1149        assert_eq!(wm.get("missing").unwrap(), None);
1150    }
1151
1152    #[test]
1153    fn test_working_memory_bounded_evicts_oldest() {
1154        let wm = WorkingMemory::new(3).unwrap();
1155        wm.set("k1", "v1").unwrap();
1156        wm.set("k2", "v2").unwrap();
1157        wm.set("k3", "v3").unwrap();
1158        wm.set("k4", "v4").unwrap(); // k1 should be evicted
1159        assert_eq!(wm.get("k1").unwrap(), None);
1160        assert_eq!(wm.get("k4").unwrap(), Some("v4".into()));
1161    }
1162
1163    #[test]
1164    fn test_working_memory_update_existing_key_no_eviction() {
1165        let wm = WorkingMemory::new(2).unwrap();
1166        wm.set("k1", "v1").unwrap();
1167        wm.set("k2", "v2").unwrap();
1168        wm.set("k1", "v1-updated").unwrap(); // update, not eviction
1169        assert_eq!(wm.len().unwrap(), 2);
1170        assert_eq!(wm.get("k1").unwrap(), Some("v1-updated".into()));
1171        assert_eq!(wm.get("k2").unwrap(), Some("v2".into()));
1172    }
1173
1174    #[test]
1175    fn test_working_memory_clear_removes_all() {
1176        let wm = WorkingMemory::new(10).unwrap();
1177        wm.set("a", "1").unwrap();
1178        wm.set("b", "2").unwrap();
1179        wm.clear().unwrap();
1180        assert!(wm.is_empty().unwrap());
1181    }
1182
1183    #[test]
1184    fn test_working_memory_is_empty_initially() {
1185        let wm = WorkingMemory::new(5).unwrap();
1186        assert!(wm.is_empty().unwrap());
1187    }
1188
1189    #[test]
1190    fn test_working_memory_len_tracks_entries() {
1191        let wm = WorkingMemory::new(10).unwrap();
1192        wm.set("a", "1").unwrap();
1193        wm.set("b", "2").unwrap();
1194        assert_eq!(wm.len().unwrap(), 2);
1195    }
1196
1197    #[test]
1198    fn test_working_memory_capacity_never_exceeded() {
1199        let cap = 5usize;
1200        let wm = WorkingMemory::new(cap).unwrap();
1201        for i in 0..20 {
1202            wm.set(format!("key-{i}"), format!("val-{i}")).unwrap();
1203            assert!(wm.len().unwrap() <= cap);
1204        }
1205    }
1206}