1use crate::error::AgentRuntimeError;
19use chrono::{DateTime, Utc};
20use serde::{Deserialize, Serialize};
21use std::collections::{HashMap, VecDeque};
22use std::sync::{Arc, Mutex};
23use uuid::Uuid;
24
25fn recover_lock<'a, T>(
31 result: std::sync::LockResult<std::sync::MutexGuard<'a, T>>,
32 ctx: &str,
33) -> std::sync::MutexGuard<'a, T>
34where
35 T: ?Sized,
36{
37 match result {
38 Ok(guard) => guard,
39 Err(poisoned) => {
40 tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
41 poisoned.into_inner()
42 }
43 }
44}
45
/// Cosine similarity of two equal-length vectors, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the vectors differ in length, are empty, either has zero
/// magnitude, or the result is not finite (e.g. an input contains NaN or infinity).
/// Sums are accumulated in `f64`, so very large or very small `f32` components
/// neither overflow to infinity nor underflow to zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let (dot, norm_a_sq, norm_b_sq) =
        a.iter()
            .zip(b)
            .fold((0.0f64, 0.0f64, 0.0f64), |(dot, na, nb), (&x, &y)| {
                let (x, y) = (f64::from(x), f64::from(y));
                (dot + x * y, na + x * x, nb + y * y)
            });
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    let sim = dot / denom;
    // A NaN here would poison any sort performed on the scores downstream.
    if sim.is_finite() {
        sim.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
60
61#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
65pub struct AgentId(pub String);
66
67impl AgentId {
68 pub fn new(id: impl Into<String>) -> Self {
70 Self(id.into())
71 }
72
73 pub fn random() -> Self {
75 Self(Uuid::new_v4().to_string())
76 }
77}
78
79impl std::fmt::Display for AgentId {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 write!(f, "{}", self.0)
82 }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
87pub struct MemoryId(pub String);
88
89impl MemoryId {
90 pub fn new(id: impl Into<String>) -> Self {
92 Self(id.into())
93 }
94
95 pub fn random() -> Self {
97 Self(Uuid::new_v4().to_string())
98 }
99}
100
101impl std::fmt::Display for MemoryId {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 write!(f, "{}", self.0)
104 }
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct MemoryItem {
112 pub id: MemoryId,
114 pub agent_id: AgentId,
116 pub content: String,
118 pub importance: f32,
120 pub timestamp: DateTime<Utc>,
122 pub tags: Vec<String>,
124 #[serde(default)]
126 pub recall_count: u64,
127}
128
129impl MemoryItem {
130 pub fn new(
132 agent_id: AgentId,
133 content: impl Into<String>,
134 importance: f32,
135 tags: Vec<String>,
136 ) -> Self {
137 Self {
138 id: MemoryId::random(),
139 agent_id,
140 content: content.into(),
141 importance: importance.clamp(0.0, 1.0),
142 timestamp: Utc::now(),
143 tags,
144 recall_count: 0,
145 }
146 }
147}
148
149#[derive(Debug, Clone)]
153pub struct DecayPolicy {
154 half_life_hours: f64,
156}
157
158impl DecayPolicy {
159 pub fn exponential(half_life_hours: f64) -> Result<Self, AgentRuntimeError> {
168 if half_life_hours <= 0.0 {
169 return Err(AgentRuntimeError::Memory(
170 "half_life_hours must be positive".into(),
171 ));
172 }
173 Ok(Self { half_life_hours })
174 }
175
176 pub fn apply(&self, importance: f32, age_hours: f64) -> f32 {
185 let decay = (-age_hours * std::f64::consts::LN_2 / self.half_life_hours).exp();
186 (importance as f64 * decay).clamp(0.0, 1.0) as f32
187 }
188
189 pub fn decay_item(&self, item: &mut MemoryItem) {
191 let age_hours = (Utc::now() - item.timestamp).num_seconds().max(0) as f64 / 3600.0;
192 item.importance = self.apply(item.importance, age_hours);
193 }
194}
195
/// Strategy used by `EpisodicStore::recall` to order an agent's items (highest first).
#[derive(Debug, Clone)]
pub enum RecallPolicy {
    /// Order by importance (after any configured decay).
    Importance,
    /// Order by `importance + recency * recency_weight + frequency * frequency_weight`,
    /// as computed by `compute_hybrid_score`.
    Hybrid {
        /// Weight applied to the recency score.
        recency_weight: f32,
        /// Weight applied to the frequency (recall count) score.
        frequency_weight: f32,
    },
}
215
216fn compute_hybrid_score(
219 item: &MemoryItem,
220 recency_weight: f32,
221 frequency_weight: f32,
222 max_recall: u64,
223 now: chrono::DateTime<Utc>,
224) -> f32 {
225 let age_hours = (now - item.timestamp).num_seconds().max(0) as f64 / 3600.0;
226 let recency_score = (-age_hours / 24.0).exp() as f32;
227 let frequency_score = item.recall_count as f32 / (max_recall as f32 + 1.0);
228 item.importance + recency_score * recency_weight + frequency_score * frequency_weight
229}
230
/// Thread-safe store of per-agent episodic memories.
///
/// Clones share the same underlying storage (`Arc<Mutex<..>>`).
#[derive(Debug, Clone)]
pub struct EpisodicStore {
    inner: Arc<Mutex<EpisodicInner>>,
}

/// State guarded by the [`EpisodicStore`] mutex.
#[derive(Debug)]
struct EpisodicInner {
    /// Items for every agent, in insertion order.
    items: Vec<MemoryItem>,
    /// Optional decay applied when items are recalled.
    decay: Option<DecayPolicy>,
    /// Ranking strategy used by recall.
    recall_policy: RecallPolicy,
    /// Maximum number of items kept per agent, if limited.
    per_agent_capacity: Option<usize>,
}
252
253impl EpisodicStore {
254 pub fn new() -> Self {
256 Self {
257 inner: Arc::new(Mutex::new(EpisodicInner {
258 items: Vec::new(),
259 decay: None,
260 recall_policy: RecallPolicy::Importance,
261 per_agent_capacity: None,
262 })),
263 }
264 }
265
266 pub fn with_decay(policy: DecayPolicy) -> Self {
268 Self {
269 inner: Arc::new(Mutex::new(EpisodicInner {
270 items: Vec::new(),
271 decay: Some(policy),
272 recall_policy: RecallPolicy::Importance,
273 per_agent_capacity: None,
274 })),
275 }
276 }
277
278 pub fn with_recall_policy(policy: RecallPolicy) -> Self {
280 Self {
281 inner: Arc::new(Mutex::new(EpisodicInner {
282 items: Vec::new(),
283 decay: None,
284 recall_policy: policy,
285 per_agent_capacity: None,
286 })),
287 }
288 }
289
290 pub fn with_per_agent_capacity(capacity: usize) -> Self {
295 Self {
296 inner: Arc::new(Mutex::new(EpisodicInner {
297 items: Vec::new(),
298 decay: None,
299 recall_policy: RecallPolicy::Importance,
300 per_agent_capacity: Some(capacity),
301 })),
302 }
303 }
304
305 #[tracing::instrument(skip(self))]
310 pub fn add_episode(
311 &self,
312 agent_id: AgentId,
313 content: impl Into<String> + std::fmt::Debug,
314 importance: f32,
315 ) -> Result<MemoryId, AgentRuntimeError> {
316 let item = MemoryItem::new(agent_id.clone(), content, importance, Vec::new());
317 let id = item.id.clone();
318 let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::add_episode");
319 inner.items.push(item);
320 if let Some(cap) = inner.per_agent_capacity {
321 let agent_count = inner
322 .items
323 .iter()
324 .filter(|i| i.agent_id == agent_id)
325 .count();
326 if agent_count > cap {
327 if let Some(pos) = inner
328 .items
329 .iter()
330 .enumerate()
331 .filter(|(_, i)| i.agent_id == agent_id)
332 .min_by(|(_, a), (_, b)| {
333 a.importance
334 .partial_cmp(&b.importance)
335 .unwrap_or(std::cmp::Ordering::Equal)
336 })
337 .map(|(pos, _)| pos)
338 {
339 inner.items.remove(pos);
340 }
341 }
342 }
343 Ok(id)
344 }
345
346 #[tracing::instrument(skip(self))]
348 pub fn add_episode_at(
349 &self,
350 agent_id: AgentId,
351 content: impl Into<String> + std::fmt::Debug,
352 importance: f32,
353 timestamp: chrono::DateTime<chrono::Utc>,
354 ) -> Result<MemoryId, AgentRuntimeError> {
355 let mut item = MemoryItem::new(agent_id.clone(), content, importance, Vec::new());
356 item.timestamp = timestamp;
357 let id = item.id.clone();
358 let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::add_episode_at");
359 inner.items.push(item);
360 if let Some(cap) = inner.per_agent_capacity {
361 let agent_count = inner
362 .items
363 .iter()
364 .filter(|i| i.agent_id == agent_id)
365 .count();
366 if agent_count > cap {
367 if let Some(pos) = inner
368 .items
369 .iter()
370 .enumerate()
371 .filter(|(_, i)| i.agent_id == agent_id)
372 .min_by(|(_, a), (_, b)| {
373 a.importance
374 .partial_cmp(&b.importance)
375 .unwrap_or(std::cmp::Ordering::Equal)
376 })
377 .map(|(pos, _)| pos)
378 {
379 inner.items.remove(pos);
380 }
381 }
382 }
383 Ok(id)
384 }
385
386 #[tracing::instrument(skip(self))]
392 pub fn recall(
393 &self,
394 agent_id: &AgentId,
395 limit: usize,
396 ) -> Result<Vec<MemoryItem>, AgentRuntimeError> {
397 let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::recall");
398
399 let decay_clone: Option<DecayPolicy> = inner.decay.clone();
401 if let Some(policy) = decay_clone {
402 for item in inner.items.iter_mut() {
403 policy.decay_item(item);
404 }
405 }
406
407 let agent_ids_to_update: Vec<MemoryId> = inner
409 .items
410 .iter()
411 .filter(|i| &i.agent_id == agent_id)
412 .map(|i| i.id.clone())
413 .collect();
414
415 for item in inner.items.iter_mut() {
417 if agent_ids_to_update.contains(&item.id) {
418 item.recall_count += 1;
419 }
420 }
421
422 let mut items: Vec<MemoryItem> = inner
423 .items
424 .iter()
425 .filter(|i| &i.agent_id == agent_id)
426 .cloned()
427 .collect();
428
429 match inner.recall_policy {
430 RecallPolicy::Importance => {
431 items.sort_by(|a, b| {
432 b.importance
433 .partial_cmp(&a.importance)
434 .unwrap_or(std::cmp::Ordering::Equal)
435 });
436 }
437 RecallPolicy::Hybrid {
438 recency_weight,
439 frequency_weight,
440 } => {
441 let max_recall = items
442 .iter()
443 .map(|i| i.recall_count)
444 .max()
445 .unwrap_or(1)
446 .max(1);
447 let now = Utc::now();
448 items.sort_by(|a, b| {
449 let score_a =
450 compute_hybrid_score(a, recency_weight, frequency_weight, max_recall, now);
451 let score_b =
452 compute_hybrid_score(b, recency_weight, frequency_weight, max_recall, now);
453 score_b
454 .partial_cmp(&score_a)
455 .unwrap_or(std::cmp::Ordering::Equal)
456 });
457 }
458 }
459
460 items.truncate(limit);
461 tracing::debug!("recalled {} items", items.len());
462 Ok(items)
463 }
464
465 pub fn len(&self) -> Result<usize, AgentRuntimeError> {
467 let inner = recover_lock(self.inner.lock(), "EpisodicStore::len");
468 Ok(inner.items.len())
469 }
470
471 pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
473 Ok(self.len()? == 0)
474 }
475
476 #[doc(hidden)]
481 pub fn bump_recall_count_by_content(&self, content: &str, amount: u64) {
482 let mut inner = recover_lock(
483 self.inner.lock(),
484 "EpisodicStore::bump_recall_count_by_content",
485 );
486 for item in inner.items.iter_mut() {
487 if item.content == content {
488 item.recall_count = item.recall_count.saturating_add(amount);
489 }
490 }
491 }
492}
493
494impl Default for EpisodicStore {
495 fn default() -> Self {
496 Self::new()
497 }
498}
499
/// Thread-safe key/value store with tag filtering and optional embedding-based
/// similarity search. Clones share the same underlying storage.
#[derive(Debug, Clone)]
pub struct SemanticStore {
    inner: Arc<Mutex<Vec<SemanticEntry>>>,
}

/// One stored key/value pair with its tags and optional embedding.
#[derive(Debug, Clone)]
struct SemanticEntry {
    key: String,
    value: String,
    /// Tags matched by `SemanticStore::retrieve`.
    tags: Vec<String>,
    /// Vector compared by `SemanticStore::retrieve_similar`; `None` excludes the
    /// entry from similarity search.
    embedding: Option<Vec<f32>>,
}
520
521impl SemanticStore {
522 pub fn new() -> Self {
524 Self {
525 inner: Arc::new(Mutex::new(Vec::new())),
526 }
527 }
528
529 #[tracing::instrument(skip(self))]
531 pub fn store(
532 &self,
533 key: impl Into<String> + std::fmt::Debug,
534 value: impl Into<String> + std::fmt::Debug,
535 tags: Vec<String>,
536 ) -> Result<(), AgentRuntimeError> {
537 let mut inner = recover_lock(self.inner.lock(), "SemanticStore::store");
538 inner.push(SemanticEntry {
539 key: key.into(),
540 value: value.into(),
541 tags,
542 embedding: None,
543 });
544 Ok(())
545 }
546
547 #[tracing::instrument(skip(self))]
549 pub fn store_with_embedding(
550 &self,
551 key: impl Into<String> + std::fmt::Debug,
552 value: impl Into<String> + std::fmt::Debug,
553 tags: Vec<String>,
554 embedding: Vec<f32>,
555 ) -> Result<(), AgentRuntimeError> {
556 let mut inner = recover_lock(self.inner.lock(), "SemanticStore::store_with_embedding");
557 inner.push(SemanticEntry {
558 key: key.into(),
559 value: value.into(),
560 tags,
561 embedding: Some(embedding),
562 });
563 Ok(())
564 }
565
566 #[tracing::instrument(skip(self))]
570 pub fn retrieve(&self, tags: &[&str]) -> Result<Vec<(String, String)>, AgentRuntimeError> {
571 let inner = recover_lock(self.inner.lock(), "SemanticStore::retrieve");
572
573 let results = inner
574 .iter()
575 .filter(|entry| {
576 tags.iter()
577 .all(|t| entry.tags.iter().any(|et| et.as_str() == *t))
578 })
579 .map(|e| (e.key.clone(), e.value.clone()))
580 .collect();
581
582 Ok(results)
583 }
584
585 #[tracing::instrument(skip(self, query_embedding))]
593 pub fn retrieve_similar(
594 &self,
595 query_embedding: &[f32],
596 top_k: usize,
597 ) -> Result<Vec<(String, String, f32)>, AgentRuntimeError> {
598 let inner = recover_lock(self.inner.lock(), "SemanticStore::retrieve_similar");
599
600 let mut scored: Vec<(String, String, f32)> = inner
601 .iter()
602 .filter_map(|entry| {
603 entry.embedding.as_ref().map(|emb| {
604 let sim = cosine_similarity(query_embedding, emb);
605 (entry.key.clone(), entry.value.clone(), sim)
606 })
607 })
608 .collect();
609
610 scored.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
611 scored.truncate(top_k);
612 Ok(scored)
613 }
614
615 pub fn len(&self) -> Result<usize, AgentRuntimeError> {
617 let inner = recover_lock(self.inner.lock(), "SemanticStore::len");
618 Ok(inner.len())
619 }
620
621 pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
623 Ok(self.len()? == 0)
624 }
625}
626
627impl Default for SemanticStore {
628 fn default() -> Self {
629 Self::new()
630 }
631}
632
/// Bounded, thread-safe key/value scratch space that evicts the least recently
/// written key when full. Clones share the same underlying storage.
#[derive(Debug, Clone)]
pub struct WorkingMemory {
    /// Maximum number of entries; `WorkingMemory::new` rejects zero.
    capacity: usize,
    inner: Arc<Mutex<WorkingInner>>,
}

/// State guarded by the [`WorkingMemory`] mutex.
#[derive(Debug)]
struct WorkingInner {
    /// Current key/value pairs.
    map: HashMap<String, String>,
    /// Keys ordered from least to most recently written; kept in sync with `map`.
    order: VecDeque<String>,
}
654
655impl WorkingMemory {
656 pub fn new(capacity: usize) -> Result<Self, AgentRuntimeError> {
662 if capacity == 0 {
663 return Err(AgentRuntimeError::Memory(
664 "WorkingMemory capacity must be > 0".into(),
665 ));
666 }
667 Ok(Self {
668 capacity,
669 inner: Arc::new(Mutex::new(WorkingInner {
670 map: HashMap::new(),
671 order: VecDeque::new(),
672 })),
673 })
674 }
675
676 #[tracing::instrument(skip(self))]
678 pub fn set(
679 &self,
680 key: impl Into<String> + std::fmt::Debug,
681 value: impl Into<String> + std::fmt::Debug,
682 ) -> Result<(), AgentRuntimeError> {
683 let key = key.into();
684 let value = value.into();
685 let mut inner = recover_lock(self.inner.lock(), "WorkingMemory::set");
686
687 if inner.map.contains_key(&key) {
689 inner.order.retain(|k| k != &key);
690 } else if inner.map.len() >= self.capacity {
691 if let Some(oldest) = inner.order.pop_front() {
693 inner.map.remove(&oldest);
694 }
695 }
696
697 inner.order.push_back(key.clone());
698 inner.map.insert(key, value);
699 Ok(())
700 }
701
702 #[tracing::instrument(skip(self))]
708 pub fn get(&self, key: &str) -> Result<Option<String>, AgentRuntimeError> {
709 let inner = recover_lock(self.inner.lock(), "WorkingMemory::get");
710 Ok(inner.map.get(key).cloned())
711 }
712
713 pub fn clear(&self) -> Result<(), AgentRuntimeError> {
715 let mut inner = recover_lock(self.inner.lock(), "WorkingMemory::clear");
716 inner.map.clear();
717 inner.order.clear();
718 Ok(())
719 }
720
721 pub fn len(&self) -> Result<usize, AgentRuntimeError> {
723 let inner = recover_lock(self.inner.lock(), "WorkingMemory::len");
724 Ok(inner.map.len())
725 }
726
727 pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
729 Ok(self.len()? == 0)
730 }
731
732 pub fn entries(&self) -> Result<Vec<(String, String)>, AgentRuntimeError> {
734 let inner = recover_lock(self.inner.lock(), "WorkingMemory::entries");
735 let entries = inner
736 .order
737 .iter()
738 .filter_map(|k| inner.map.get(k).map(|v| (k.clone(), v.clone())))
739 .collect();
740 Ok(entries)
741 }
742}
743
#[cfg(test)]
mod tests {
    //! Unit tests for the memory primitives, exercised only through public APIs
    //! (plus the hidden `bump_recall_count_by_content` test helper).

    use super::*;

    // ── AgentId / MemoryId ────────────────────────────────────────────────

    #[test]
    fn test_agent_id_new_stores_string() {
        let id = AgentId::new("agent-1");
        assert_eq!(id.0, "agent-1");
    }

    #[test]
    fn test_agent_id_random_is_unique() {
        let a = AgentId::random();
        let b = AgentId::random();
        assert_ne!(a, b);
    }

    #[test]
    fn test_memory_id_new_stores_string() {
        let id = MemoryId::new("mem-1");
        assert_eq!(id.0, "mem-1");
    }

    #[test]
    fn test_memory_id_random_is_unique() {
        let a = MemoryId::random();
        let b = MemoryId::random();
        assert_ne!(a, b);
    }

    // ── MemoryItem ────────────────────────────────────────────────────────

    #[test]
    fn test_memory_item_new_clamps_importance_above_one() {
        let item = MemoryItem::new(AgentId::new("a"), "test", 1.5, vec![]);
        assert_eq!(item.importance, 1.0);
    }

    #[test]
    fn test_memory_item_new_clamps_importance_below_zero() {
        let item = MemoryItem::new(AgentId::new("a"), "test", -0.5, vec![]);
        assert_eq!(item.importance, 0.0);
    }

    #[test]
    fn test_memory_item_new_preserves_valid_importance() {
        let item = MemoryItem::new(AgentId::new("a"), "test", 0.7, vec![]);
        assert!((item.importance - 0.7).abs() < 1e-6);
    }

    // ── DecayPolicy ───────────────────────────────────────────────────────

    #[test]
    fn test_decay_policy_rejects_zero_half_life() {
        assert!(DecayPolicy::exponential(0.0).is_err());
    }

    #[test]
    fn test_decay_policy_rejects_negative_half_life() {
        assert!(DecayPolicy::exponential(-1.0).is_err());
    }

    #[test]
    fn test_decay_policy_no_decay_at_age_zero() {
        let p = DecayPolicy::exponential(24.0).unwrap();
        let decayed = p.apply(1.0, 0.0);
        assert!((decayed - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_decay_policy_half_importance_at_half_life() {
        let p = DecayPolicy::exponential(24.0).unwrap();
        let decayed = p.apply(1.0, 24.0);
        assert!((decayed - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_decay_policy_quarter_importance_at_two_half_lives() {
        let p = DecayPolicy::exponential(24.0).unwrap();
        let decayed = p.apply(1.0, 48.0);
        assert!((decayed - 0.25).abs() < 1e-5);
    }

    #[test]
    fn test_decay_policy_result_is_clamped_to_zero_one() {
        let p = DecayPolicy::exponential(1.0).unwrap();
        let fully_decayed = p.apply(0.0, 1000.0);
        assert!((0.0..=1.0).contains(&fully_decayed));
        // A negative age grows the raw value far past 1.0; it must still be clamped.
        let boosted = p.apply(1.0, -1000.0);
        assert_eq!(boosted, 1.0);
    }

    // ── EpisodicStore ─────────────────────────────────────────────────────

    #[test]
    fn test_episodic_store_add_episode_returns_id() {
        let store = EpisodicStore::new();
        let id = store.add_episode(AgentId::new("a"), "event", 0.8).unwrap();
        assert!(!id.0.is_empty());
    }

    #[test]
    fn test_episodic_store_recall_returns_stored_item() {
        let store = EpisodicStore::new();
        let agent = AgentId::new("agent-1");
        store
            .add_episode(agent.clone(), "hello world", 0.9)
            .unwrap();
        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].content, "hello world");
    }

    #[test]
    fn test_episodic_store_recall_filters_by_agent() {
        let store = EpisodicStore::new();
        let a = AgentId::new("agent-a");
        let b = AgentId::new("agent-b");
        store.add_episode(a.clone(), "for a", 0.5).unwrap();
        store.add_episode(b.clone(), "for b", 0.5).unwrap();
        let items = store.recall(&a, 10).unwrap();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].content, "for a");
    }

    #[test]
    fn test_episodic_store_recall_sorted_by_descending_importance() {
        let store = EpisodicStore::new();
        let agent = AgentId::new("agent-1");
        store.add_episode(agent.clone(), "low", 0.1).unwrap();
        store.add_episode(agent.clone(), "high", 0.9).unwrap();
        store.add_episode(agent.clone(), "mid", 0.5).unwrap();
        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items[0].content, "high");
        assert_eq!(items[1].content, "mid");
        assert_eq!(items[2].content, "low");
    }

    #[test]
    fn test_episodic_store_recall_respects_limit() {
        let store = EpisodicStore::new();
        let agent = AgentId::new("agent-1");
        for i in 0..5 {
            store
                .add_episode(agent.clone(), format!("item {i}"), 0.5)
                .unwrap();
        }
        let items = store.recall(&agent, 3).unwrap();
        assert_eq!(items.len(), 3);
    }

    #[test]
    fn test_episodic_store_len_tracks_insertions() {
        let store = EpisodicStore::new();
        let agent = AgentId::new("a");
        store.add_episode(agent.clone(), "a", 0.5).unwrap();
        store.add_episode(agent.clone(), "b", 0.5).unwrap();
        assert_eq!(store.len().unwrap(), 2);
    }

    #[test]
    fn test_episodic_store_is_empty_initially() {
        let store = EpisodicStore::new();
        assert!(store.is_empty().unwrap());
    }

    #[test]
    fn test_episodic_store_with_decay_reduces_importance() {
        // A tiny half-life makes a one-hour-old memory decay to almost nothing.
        let policy = DecayPolicy::exponential(0.001).unwrap();
        let store = EpisodicStore::with_decay(policy);
        let agent = AgentId::new("a");

        let one_hour_ago = Utc::now() - chrono::Duration::hours(1);
        store
            .add_episode_at(agent.clone(), "old event", 1.0, one_hour_ago)
            .unwrap();

        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items.len(), 1);
        assert!(
            items[0].importance < 0.01,
            "expected near-zero importance, got {}",
            items[0].importance
        );
    }

    #[test]
    fn test_recall_increments_recall_count() {
        let store = EpisodicStore::new();
        let agent = AgentId::new("agent-rc");
        store.add_episode(agent.clone(), "memory", 0.5).unwrap();

        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items[0].recall_count, 1);

        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items[0].recall_count, 2);
    }

    #[test]
    fn test_hybrid_recall_policy_prefers_frequently_recalled() {
        // Frequency dominates the weights, so a heavily recalled but old item must
        // outrank a fresh item that has never been recalled.
        let store = EpisodicStore::with_recall_policy(RecallPolicy::Hybrid {
            recency_weight: 0.1,
            frequency_weight: 2.0,
        });
        let agent = AgentId::new("agent-hybrid");

        let old_ts = Utc::now() - chrono::Duration::hours(48);
        store
            .add_episode_at(agent.clone(), "old_frequent", 0.5, old_ts)
            .unwrap();
        store.add_episode(agent.clone(), "new_never", 0.5).unwrap();

        store.bump_recall_count_by_content("old_frequent", 100);

        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items.len(), 2);
        assert_eq!(
            items[0].content, "old_frequent",
            "hybrid policy should rank the frequently-recalled item first"
        );
    }

    #[test]
    fn test_per_agent_capacity_evicts_lowest_importance() {
        let store = EpisodicStore::with_per_agent_capacity(2);
        let agent = AgentId::new("agent-cap");

        store.add_episode(agent.clone(), "mid", 0.5).unwrap();
        store.add_episode(agent.clone(), "high", 0.9).unwrap();
        store.add_episode(agent.clone(), "low", 0.1).unwrap();

        assert_eq!(
            store.len().unwrap(),
            2,
            "store should hold exactly 2 items after eviction"
        );

        let items = store.recall(&agent, 10).unwrap();
        let contents: Vec<&str> = items.iter().map(|i| i.content.as_str()).collect();
        assert!(
            !contents.contains(&"low"),
            "the lowest-importance item should have been evicted; remaining: {:?}",
            contents
        );
    }

    // ── SemanticStore ─────────────────────────────────────────────────────

    #[test]
    fn test_semantic_store_store_and_retrieve_all() {
        let store = SemanticStore::new();
        store.store("key1", "value1", vec!["tag-a".into()]).unwrap();
        store.store("key2", "value2", vec!["tag-b".into()]).unwrap();
        let results = store.retrieve(&[]).unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_semantic_store_retrieve_filters_by_tag() {
        let store = SemanticStore::new();
        store
            .store("k1", "v1", vec!["rust".into(), "async".into()])
            .unwrap();
        store.store("k2", "v2", vec!["rust".into()]).unwrap();
        let results = store.retrieve(&["async"]).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "k1");
    }

    #[test]
    fn test_semantic_store_retrieve_requires_all_tags() {
        let store = SemanticStore::new();
        store
            .store("k1", "v1", vec!["a".into(), "b".into()])
            .unwrap();
        store.store("k2", "v2", vec!["a".into()]).unwrap();
        let results = store.retrieve(&["a", "b"]).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_semantic_store_is_empty_initially() {
        let store = SemanticStore::new();
        assert!(store.is_empty().unwrap());
    }

    #[test]
    fn test_semantic_store_len_tracks_insertions() {
        let store = SemanticStore::new();
        store.store("k", "v", vec![]).unwrap();
        assert_eq!(store.len().unwrap(), 1);
    }

    #[test]
    fn test_semantic_store_retrieve_similar_returns_closest() {
        let store = SemanticStore::new();
        store
            .store_with_embedding("close", "close value", vec![], vec![1.0, 0.0, 0.0])
            .unwrap();
        store
            .store_with_embedding("far", "far value", vec![], vec![0.0, 1.0, 0.0])
            .unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let results = store.retrieve_similar(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "close");
        assert!(
            (results[0].2 - 1.0).abs() < 1e-5,
            "expected similarity ~1.0, got {}",
            results[0].2
        );
        assert!(
            (results[1].2).abs() < 1e-5,
            "expected similarity ~0.0, got {}",
            results[1].2
        );
    }

    #[test]
    fn test_semantic_store_retrieve_similar_ignores_unembedded_entries() {
        let store = SemanticStore::new();
        store.store("no-emb", "no embedding value", vec![]).unwrap();
        store
            .store_with_embedding("with-emb", "with embedding value", vec![], vec![1.0, 0.0])
            .unwrap();

        let query = vec![1.0, 0.0];
        let results = store.retrieve_similar(&query, 10).unwrap();
        assert_eq!(results.len(), 1, "only the embedded entry should appear");
        assert_eq!(results[0].0, "with-emb");
    }

    #[test]
    fn test_cosine_similarity_orthogonal_vectors_return_zero() {
        let store = SemanticStore::new();
        store
            .store_with_embedding("a", "va", vec![], vec![1.0, 0.0])
            .unwrap();
        store
            .store_with_embedding("b", "vb", vec![], vec![0.0, 1.0])
            .unwrap();

        let query = vec![1.0, 0.0];
        let results = store.retrieve_similar(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        let b_result = results.iter().find(|(k, _, _)| k == "b").unwrap();
        assert!(
            b_result.2.abs() < 1e-5,
            "expected cosine similarity 0.0 for orthogonal vectors, got {}",
            b_result.2
        );
    }

    // ── WorkingMemory ─────────────────────────────────────────────────────

    #[test]
    fn test_working_memory_new_rejects_zero_capacity() {
        assert!(WorkingMemory::new(0).is_err());
    }

    #[test]
    fn test_working_memory_set_and_get() {
        let wm = WorkingMemory::new(10).unwrap();
        wm.set("foo", "bar").unwrap();
        let val = wm.get("foo").unwrap();
        assert_eq!(val, Some("bar".into()));
    }

    #[test]
    fn test_working_memory_get_missing_key_returns_none() {
        let wm = WorkingMemory::new(10).unwrap();
        assert_eq!(wm.get("missing").unwrap(), None);
    }

    #[test]
    fn test_working_memory_bounded_evicts_oldest() {
        let wm = WorkingMemory::new(3).unwrap();
        wm.set("k1", "v1").unwrap();
        wm.set("k2", "v2").unwrap();
        wm.set("k3", "v3").unwrap();
        // Inserting a fourth key must evict the oldest one (k1).
        wm.set("k4", "v4").unwrap();
        assert_eq!(wm.get("k1").unwrap(), None);
        assert_eq!(wm.get("k4").unwrap(), Some("v4".into()));
    }

    #[test]
    fn test_working_memory_update_existing_key_no_eviction() {
        let wm = WorkingMemory::new(2).unwrap();
        wm.set("k1", "v1").unwrap();
        wm.set("k2", "v2").unwrap();
        // Overwriting an existing key at capacity must not evict anything.
        wm.set("k1", "v1-updated").unwrap();
        assert_eq!(wm.len().unwrap(), 2);
        assert_eq!(wm.get("k1").unwrap(), Some("v1-updated".into()));
        assert_eq!(wm.get("k2").unwrap(), Some("v2".into()));
    }

    #[test]
    fn test_working_memory_clear_removes_all() {
        let wm = WorkingMemory::new(10).unwrap();
        wm.set("a", "1").unwrap();
        wm.set("b", "2").unwrap();
        wm.clear().unwrap();
        assert!(wm.is_empty().unwrap());
    }

    #[test]
    fn test_working_memory_is_empty_initially() {
        let wm = WorkingMemory::new(5).unwrap();
        assert!(wm.is_empty().unwrap());
    }

    #[test]
    fn test_working_memory_len_tracks_entries() {
        let wm = WorkingMemory::new(10).unwrap();
        wm.set("a", "1").unwrap();
        wm.set("b", "2").unwrap();
        assert_eq!(wm.len().unwrap(), 2);
    }

    #[test]
    fn test_working_memory_capacity_never_exceeded() {
        let cap = 5usize;
        let wm = WorkingMemory::new(cap).unwrap();
        for i in 0..20 {
            wm.set(format!("key-{i}"), format!("val-{i}")).unwrap();
            assert!(wm.len().unwrap() <= cap);
        }
    }
}
1206}