1use crate::error::AgentRuntimeError;
19use chrono::{DateTime, Utc};
20use serde::{Deserialize, Serialize};
21use std::collections::{HashMap, VecDeque};
22use std::sync::{Arc, Mutex};
23use uuid::Uuid;
24
25fn recover_lock<'a, T>(
31 result: std::sync::LockResult<std::sync::MutexGuard<'a, T>>,
32 ctx: &str,
33) -> std::sync::MutexGuard<'a, T>
34where
35 T: ?Sized,
36{
37 match result {
38 Ok(guard) => guard,
39 Err(poisoned) => {
40 tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
41 poisoned.into_inner()
42 }
43 }
44}
45
/// Cosine similarity of two equal-length vectors, in `[-1.0, 1.0]`.
///
/// Returns `0.0` when the lengths differ, either slice is empty, either vector
/// has zero magnitude, or the result is not finite (e.g. NaN or infinite
/// components). Sums are accumulated in `f64` so large components do not
/// overflow `f32` (e.g. `1e20 * 1e20`) and tiny ones do not underflow to zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Single pass: (dot product, |a|^2, |b|^2).
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }
    let sim = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    // `clamp` propagates NaN, so non-finite results must be filtered explicitly;
    // a NaN similarity would otherwise break the ranking sort downstream.
    if sim.is_finite() {
        sim.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
60
61#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
65pub struct AgentId(pub String);
66
67impl AgentId {
68 pub fn new(id: impl Into<String>) -> Self {
70 Self(id.into())
71 }
72
73 pub fn random() -> Self {
75 Self(Uuid::new_v4().to_string())
76 }
77}
78
79impl std::fmt::Display for AgentId {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 write!(f, "{}", self.0)
82 }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
87pub struct MemoryId(pub String);
88
89impl MemoryId {
90 pub fn new(id: impl Into<String>) -> Self {
92 Self(id.into())
93 }
94
95 pub fn random() -> Self {
97 Self(Uuid::new_v4().to_string())
98 }
99}
100
101impl std::fmt::Display for MemoryId {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 write!(f, "{}", self.0)
104 }
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct MemoryItem {
112 pub id: MemoryId,
114 pub agent_id: AgentId,
116 pub content: String,
118 pub importance: f32,
120 pub timestamp: DateTime<Utc>,
122 pub tags: Vec<String>,
124 #[serde(default)]
126 pub recall_count: u64,
127}
128
129impl MemoryItem {
130 pub fn new(
132 agent_id: AgentId,
133 content: impl Into<String>,
134 importance: f32,
135 tags: Vec<String>,
136 ) -> Self {
137 Self {
138 id: MemoryId::random(),
139 agent_id,
140 content: content.into(),
141 importance: importance.clamp(0.0, 1.0),
142 timestamp: Utc::now(),
143 tags,
144 recall_count: 0,
145 }
146 }
147}
148
149#[derive(Debug, Clone)]
153pub struct DecayPolicy {
154 half_life_hours: f64,
156}
157
158impl DecayPolicy {
159 pub fn exponential(half_life_hours: f64) -> Result<Self, AgentRuntimeError> {
168 if half_life_hours <= 0.0 {
169 return Err(AgentRuntimeError::Memory(
170 "half_life_hours must be positive".into(),
171 ));
172 }
173 Ok(Self { half_life_hours })
174 }
175
176 pub fn apply(&self, importance: f32, age_hours: f64) -> f32 {
185 let decay = (-age_hours * std::f64::consts::LN_2 / self.half_life_hours).exp();
186 (importance as f64 * decay).clamp(0.0, 1.0) as f32
187 }
188
189 pub fn decay_item(&self, item: &mut MemoryItem) {
191 let age_hours = (Utc::now() - item.timestamp).num_seconds().max(0) as f64 / 3600.0;
192 item.importance = self.apply(item.importance, age_hours);
193 }
194}
195
/// Strategy used by the episodic store to order recalled memories.
#[derive(Debug, Clone)]
pub enum RecallPolicy {
    /// Order by importance alone, highest first.
    Importance,
    /// Order by a combined score of importance, recency and recall count
    /// (see `compute_hybrid_score`), highest first.
    Hybrid {
        /// Weight of the recency term, `exp(-age_hours / 24)`.
        recency_weight: f32,
        /// Weight of the frequency term, `recall_count / (max_recall + 1)`.
        frequency_weight: f32,
    },
}
215
216fn compute_hybrid_score(
219 item: &MemoryItem,
220 recency_weight: f32,
221 frequency_weight: f32,
222 max_recall: u64,
223 now: chrono::DateTime<Utc>,
224) -> f32 {
225 let age_hours = (now - item.timestamp).num_seconds().max(0) as f64 / 3600.0;
226 let recency_score = (-age_hours / 24.0).exp() as f32;
227 let frequency_score = item.recall_count as f32 / (max_recall as f32 + 1.0);
228 item.importance + recency_score * recency_weight + frequency_score * frequency_weight
229}
230
231#[derive(Debug, Clone)]
240pub struct EpisodicStore {
241 inner: Arc<Mutex<EpisodicInner>>,
242}
243
244#[derive(Debug)]
245struct EpisodicInner {
246 items: Vec<MemoryItem>,
247 decay: Option<DecayPolicy>,
248 recall_policy: RecallPolicy,
249 per_agent_capacity: Option<usize>,
251}
252
253impl EpisodicStore {
254 pub fn new() -> Self {
256 Self {
257 inner: Arc::new(Mutex::new(EpisodicInner {
258 items: Vec::new(),
259 decay: None,
260 recall_policy: RecallPolicy::Importance,
261 per_agent_capacity: None,
262 })),
263 }
264 }
265
266 pub fn with_decay(policy: DecayPolicy) -> Self {
268 Self {
269 inner: Arc::new(Mutex::new(EpisodicInner {
270 items: Vec::new(),
271 decay: Some(policy),
272 recall_policy: RecallPolicy::Importance,
273 per_agent_capacity: None,
274 })),
275 }
276 }
277
278 pub fn with_recall_policy(policy: RecallPolicy) -> Self {
280 Self {
281 inner: Arc::new(Mutex::new(EpisodicInner {
282 items: Vec::new(),
283 decay: None,
284 recall_policy: policy,
285 per_agent_capacity: None,
286 })),
287 }
288 }
289
290 pub fn with_per_agent_capacity(capacity: usize) -> Self {
295 Self {
296 inner: Arc::new(Mutex::new(EpisodicInner {
297 items: Vec::new(),
298 decay: None,
299 recall_policy: RecallPolicy::Importance,
300 per_agent_capacity: Some(capacity),
301 })),
302 }
303 }
304
305 #[tracing::instrument(skip(self))]
310 pub fn add_episode(
311 &self,
312 agent_id: AgentId,
313 content: impl Into<String> + std::fmt::Debug,
314 importance: f32,
315 ) -> Result<MemoryId, AgentRuntimeError> {
316 let item = MemoryItem::new(agent_id.clone(), content, importance, Vec::new());
317 let id = item.id.clone();
318 let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::add_episode");
319 inner.items.push(item);
320 if let Some(cap) = inner.per_agent_capacity {
321 let agent_count = inner.items.iter().filter(|i| i.agent_id == agent_id).count();
322 if agent_count > cap {
323 if let Some(pos) = inner
324 .items
325 .iter()
326 .enumerate()
327 .filter(|(_, i)| i.agent_id == agent_id)
328 .min_by(|(_, a), (_, b)| {
329 a.importance
330 .partial_cmp(&b.importance)
331 .unwrap_or(std::cmp::Ordering::Equal)
332 })
333 .map(|(pos, _)| pos)
334 {
335 inner.items.remove(pos);
336 }
337 }
338 }
339 Ok(id)
340 }
341
342 #[tracing::instrument(skip(self))]
344 pub fn add_episode_at(
345 &self,
346 agent_id: AgentId,
347 content: impl Into<String> + std::fmt::Debug,
348 importance: f32,
349 timestamp: chrono::DateTime<chrono::Utc>,
350 ) -> Result<MemoryId, AgentRuntimeError> {
351 let mut item = MemoryItem::new(agent_id.clone(), content, importance, Vec::new());
352 item.timestamp = timestamp;
353 let id = item.id.clone();
354 let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::add_episode_at");
355 inner.items.push(item);
356 if let Some(cap) = inner.per_agent_capacity {
357 let agent_count = inner.items.iter().filter(|i| i.agent_id == agent_id).count();
358 if agent_count > cap {
359 if let Some(pos) = inner
360 .items
361 .iter()
362 .enumerate()
363 .filter(|(_, i)| i.agent_id == agent_id)
364 .min_by(|(_, a), (_, b)| {
365 a.importance
366 .partial_cmp(&b.importance)
367 .unwrap_or(std::cmp::Ordering::Equal)
368 })
369 .map(|(pos, _)| pos)
370 {
371 inner.items.remove(pos);
372 }
373 }
374 }
375 Ok(id)
376 }
377
378 #[tracing::instrument(skip(self))]
384 pub fn recall(
385 &self,
386 agent_id: &AgentId,
387 limit: usize,
388 ) -> Result<Vec<MemoryItem>, AgentRuntimeError> {
389 let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::recall");
390
391 let decay_clone: Option<DecayPolicy> = inner.decay.clone();
393 if let Some(policy) = decay_clone {
394 for item in inner.items.iter_mut() {
395 policy.decay_item(item);
396 }
397 }
398
399 let agent_ids_to_update: Vec<MemoryId> = inner
401 .items
402 .iter()
403 .filter(|i| &i.agent_id == agent_id)
404 .map(|i| i.id.clone())
405 .collect();
406
407 for item in inner.items.iter_mut() {
409 if agent_ids_to_update.contains(&item.id) {
410 item.recall_count += 1;
411 }
412 }
413
414 let mut items: Vec<MemoryItem> = inner
415 .items
416 .iter()
417 .filter(|i| &i.agent_id == agent_id)
418 .cloned()
419 .collect();
420
421 match inner.recall_policy {
422 RecallPolicy::Importance => {
423 items.sort_by(|a, b| {
424 b.importance
425 .partial_cmp(&a.importance)
426 .unwrap_or(std::cmp::Ordering::Equal)
427 });
428 }
429 RecallPolicy::Hybrid {
430 recency_weight,
431 frequency_weight,
432 } => {
433 let max_recall = items.iter().map(|i| i.recall_count).max().unwrap_or(1).max(1);
434 let now = Utc::now();
435 items.sort_by(|a, b| {
436 let score_a = compute_hybrid_score(
437 a,
438 recency_weight,
439 frequency_weight,
440 max_recall,
441 now,
442 );
443 let score_b = compute_hybrid_score(
444 b,
445 recency_weight,
446 frequency_weight,
447 max_recall,
448 now,
449 );
450 score_b
451 .partial_cmp(&score_a)
452 .unwrap_or(std::cmp::Ordering::Equal)
453 });
454 }
455 }
456
457 items.truncate(limit);
458 tracing::debug!("recalled {} items", items.len());
459 Ok(items)
460 }
461
462 pub fn len(&self) -> Result<usize, AgentRuntimeError> {
464 let inner = recover_lock(self.inner.lock(), "EpisodicStore::len");
465 Ok(inner.items.len())
466 }
467
468 pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
470 Ok(self.len()? == 0)
471 }
472
473 #[doc(hidden)]
478 pub fn bump_recall_count_by_content(&self, content: &str, amount: u64) {
479 let mut inner = recover_lock(self.inner.lock(), "EpisodicStore::bump_recall_count_by_content");
480 for item in inner.items.iter_mut() {
481 if item.content == content {
482 item.recall_count = item.recall_count.saturating_add(amount);
483 }
484 }
485 }
486}
487
488impl Default for EpisodicStore {
489 fn default() -> Self {
490 Self::new()
491 }
492}
493
494#[derive(Debug, Clone)]
503pub struct SemanticStore {
504 inner: Arc<Mutex<Vec<SemanticEntry>>>,
505}
506
507#[derive(Debug, Clone)]
508struct SemanticEntry {
509 key: String,
510 value: String,
511 tags: Vec<String>,
512 embedding: Option<Vec<f32>>,
513}
514
515impl SemanticStore {
516 pub fn new() -> Self {
518 Self {
519 inner: Arc::new(Mutex::new(Vec::new())),
520 }
521 }
522
523 #[tracing::instrument(skip(self))]
525 pub fn store(
526 &self,
527 key: impl Into<String> + std::fmt::Debug,
528 value: impl Into<String> + std::fmt::Debug,
529 tags: Vec<String>,
530 ) -> Result<(), AgentRuntimeError> {
531 let mut inner = recover_lock(self.inner.lock(), "SemanticStore::store");
532 inner.push(SemanticEntry {
533 key: key.into(),
534 value: value.into(),
535 tags,
536 embedding: None,
537 });
538 Ok(())
539 }
540
541 #[tracing::instrument(skip(self))]
543 pub fn store_with_embedding(
544 &self,
545 key: impl Into<String> + std::fmt::Debug,
546 value: impl Into<String> + std::fmt::Debug,
547 tags: Vec<String>,
548 embedding: Vec<f32>,
549 ) -> Result<(), AgentRuntimeError> {
550 let mut inner = recover_lock(self.inner.lock(), "SemanticStore::store_with_embedding");
551 inner.push(SemanticEntry {
552 key: key.into(),
553 value: value.into(),
554 tags,
555 embedding: Some(embedding),
556 });
557 Ok(())
558 }
559
560 #[tracing::instrument(skip(self))]
564 pub fn retrieve(&self, tags: &[&str]) -> Result<Vec<(String, String)>, AgentRuntimeError> {
565 let inner = recover_lock(self.inner.lock(), "SemanticStore::retrieve");
566
567 let results = inner
568 .iter()
569 .filter(|entry| {
570 tags.iter()
571 .all(|t| entry.tags.iter().any(|et| et.as_str() == *t))
572 })
573 .map(|e| (e.key.clone(), e.value.clone()))
574 .collect();
575
576 Ok(results)
577 }
578
579 #[tracing::instrument(skip(self, query_embedding))]
587 pub fn retrieve_similar(
588 &self,
589 query_embedding: &[f32],
590 top_k: usize,
591 ) -> Result<Vec<(String, String, f32)>, AgentRuntimeError> {
592 let inner = recover_lock(self.inner.lock(), "SemanticStore::retrieve_similar");
593
594 let mut scored: Vec<(String, String, f32)> = inner
595 .iter()
596 .filter_map(|entry| {
597 entry.embedding.as_ref().map(|emb| {
598 let sim = cosine_similarity(query_embedding, emb);
599 (entry.key.clone(), entry.value.clone(), sim)
600 })
601 })
602 .collect();
603
604 scored.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
605 scored.truncate(top_k);
606 Ok(scored)
607 }
608
609 pub fn len(&self) -> Result<usize, AgentRuntimeError> {
611 let inner = recover_lock(self.inner.lock(), "SemanticStore::len");
612 Ok(inner.len())
613 }
614
615 pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
617 Ok(self.len()? == 0)
618 }
619}
620
621impl Default for SemanticStore {
622 fn default() -> Self {
623 Self::new()
624 }
625}
626
627#[derive(Debug, Clone)]
638pub struct WorkingMemory {
639 capacity: usize,
640 inner: Arc<Mutex<WorkingInner>>,
641}
642
643#[derive(Debug)]
644struct WorkingInner {
645 map: HashMap<String, String>,
646 order: VecDeque<String>,
647}
648
649impl WorkingMemory {
650 pub fn new(capacity: usize) -> Result<Self, AgentRuntimeError> {
656 if capacity == 0 {
657 return Err(AgentRuntimeError::Memory(
658 "WorkingMemory capacity must be > 0".into(),
659 ));
660 }
661 Ok(Self {
662 capacity,
663 inner: Arc::new(Mutex::new(WorkingInner {
664 map: HashMap::new(),
665 order: VecDeque::new(),
666 })),
667 })
668 }
669
670 #[tracing::instrument(skip(self))]
672 pub fn set(
673 &self,
674 key: impl Into<String> + std::fmt::Debug,
675 value: impl Into<String> + std::fmt::Debug,
676 ) -> Result<(), AgentRuntimeError> {
677 let key = key.into();
678 let value = value.into();
679 let mut inner = recover_lock(self.inner.lock(), "WorkingMemory::set");
680
681 if inner.map.contains_key(&key) {
683 inner.order.retain(|k| k != &key);
684 } else if inner.map.len() >= self.capacity {
685 if let Some(oldest) = inner.order.pop_front() {
687 inner.map.remove(&oldest);
688 }
689 }
690
691 inner.order.push_back(key.clone());
692 inner.map.insert(key, value);
693 Ok(())
694 }
695
696 #[tracing::instrument(skip(self))]
702 pub fn get(&self, key: &str) -> Result<Option<String>, AgentRuntimeError> {
703 let inner = recover_lock(self.inner.lock(), "WorkingMemory::get");
704 Ok(inner.map.get(key).cloned())
705 }
706
707 pub fn clear(&self) -> Result<(), AgentRuntimeError> {
709 let mut inner = recover_lock(self.inner.lock(), "WorkingMemory::clear");
710 inner.map.clear();
711 inner.order.clear();
712 Ok(())
713 }
714
715 pub fn len(&self) -> Result<usize, AgentRuntimeError> {
717 let inner = recover_lock(self.inner.lock(), "WorkingMemory::len");
718 Ok(inner.map.len())
719 }
720
721 pub fn is_empty(&self) -> Result<bool, AgentRuntimeError> {
723 Ok(self.len()? == 0)
724 }
725
726 pub fn entries(&self) -> Result<Vec<(String, String)>, AgentRuntimeError> {
728 let inner = recover_lock(self.inner.lock(), "WorkingMemory::entries");
729 let entries = inner
730 .order
731 .iter()
732 .filter_map(|k| inner.map.get(k).map(|v| (k.clone(), v.clone())))
733 .collect();
734 Ok(entries)
735 }
736}
737
// Unit tests for the identifier types, decay policy, and the three memory stores.
#[cfg(test)]
mod tests {
    use super::*;

    // ── AgentId / MemoryId ──────────────────────────────────────────────

    #[test]
    fn test_agent_id_new_stores_string() {
        let id = AgentId::new("agent-1");
        assert_eq!(id.0, "agent-1");
    }

    #[test]
    fn test_agent_id_random_is_unique() {
        let a = AgentId::random();
        let b = AgentId::random();
        assert_ne!(a, b);
    }

    #[test]
    fn test_memory_id_new_stores_string() {
        let id = MemoryId::new("mem-1");
        assert_eq!(id.0, "mem-1");
    }

    #[test]
    fn test_memory_id_random_is_unique() {
        let a = MemoryId::random();
        let b = MemoryId::random();
        assert_ne!(a, b);
    }

    // ── MemoryItem: importance is clamped into [0, 1] ───────────────────

    #[test]
    fn test_memory_item_new_clamps_importance_above_one() {
        let item = MemoryItem::new(AgentId::new("a"), "test", 1.5, vec![]);
        assert_eq!(item.importance, 1.0);
    }

    #[test]
    fn test_memory_item_new_clamps_importance_below_zero() {
        let item = MemoryItem::new(AgentId::new("a"), "test", -0.5, vec![]);
        assert_eq!(item.importance, 0.0);
    }

    #[test]
    fn test_memory_item_new_preserves_valid_importance() {
        let item = MemoryItem::new(AgentId::new("a"), "test", 0.7, vec![]);
        assert!((item.importance - 0.7).abs() < 1e-6);
    }

    // ── DecayPolicy ─────────────────────────────────────────────────────

    #[test]
    fn test_decay_policy_rejects_zero_half_life() {
        assert!(DecayPolicy::exponential(0.0).is_err());
    }

    #[test]
    fn test_decay_policy_rejects_negative_half_life() {
        assert!(DecayPolicy::exponential(-1.0).is_err());
    }

    #[test]
    fn test_decay_policy_no_decay_at_age_zero() {
        let p = DecayPolicy::exponential(24.0).unwrap();
        let decayed = p.apply(1.0, 0.0);
        assert!((decayed - 1.0).abs() < 1e-5);
    }

    // One half-life halves the importance.
    #[test]
    fn test_decay_policy_half_importance_at_half_life() {
        let p = DecayPolicy::exponential(24.0).unwrap();
        let decayed = p.apply(1.0, 24.0);
        assert!((decayed - 0.5).abs() < 1e-5);
    }

    // Two half-lives quarter it.
    #[test]
    fn test_decay_policy_quarter_importance_at_two_half_lives() {
        let p = DecayPolicy::exponential(24.0).unwrap();
        let decayed = p.apply(1.0, 48.0);
        assert!((decayed - 0.25).abs() < 1e-5);
    }

    #[test]
    fn test_decay_policy_result_is_clamped_to_zero_one() {
        let p = DecayPolicy::exponential(1.0).unwrap();
        let decayed = p.apply(0.0, 1000.0);
        assert!(decayed >= 0.0 && decayed <= 1.0);
    }

    // ── EpisodicStore ───────────────────────────────────────────────────

    #[test]
    fn test_episodic_store_add_episode_returns_id() {
        let store = EpisodicStore::new();
        let id = store.add_episode(AgentId::new("a"), "event", 0.8).unwrap();
        assert!(!id.0.is_empty());
    }

    #[test]
    fn test_episodic_store_recall_returns_stored_item() {
        let store = EpisodicStore::new();
        let agent = AgentId::new("agent-1");
        store
            .add_episode(agent.clone(), "hello world", 0.9)
            .unwrap();
        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].content, "hello world");
    }

    // Episodes of other agents must not leak into a recall.
    #[test]
    fn test_episodic_store_recall_filters_by_agent() {
        let store = EpisodicStore::new();
        let a = AgentId::new("agent-a");
        let b = AgentId::new("agent-b");
        store.add_episode(a.clone(), "for a", 0.5).unwrap();
        store.add_episode(b.clone(), "for b", 0.5).unwrap();
        let items = store.recall(&a, 10).unwrap();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].content, "for a");
    }

    #[test]
    fn test_episodic_store_recall_sorted_by_descending_importance() {
        let store = EpisodicStore::new();
        let agent = AgentId::new("agent-1");
        store.add_episode(agent.clone(), "low", 0.1).unwrap();
        store.add_episode(agent.clone(), "high", 0.9).unwrap();
        store.add_episode(agent.clone(), "mid", 0.5).unwrap();
        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items[0].content, "high");
        assert_eq!(items[1].content, "mid");
        assert_eq!(items[2].content, "low");
    }

    #[test]
    fn test_episodic_store_recall_respects_limit() {
        let store = EpisodicStore::new();
        let agent = AgentId::new("agent-1");
        for i in 0..5 {
            store
                .add_episode(agent.clone(), format!("item {i}"), 0.5)
                .unwrap();
        }
        let items = store.recall(&agent, 3).unwrap();
        assert_eq!(items.len(), 3);
    }

    #[test]
    fn test_episodic_store_len_tracks_insertions() {
        let store = EpisodicStore::new();
        let agent = AgentId::new("a");
        store.add_episode(agent.clone(), "a", 0.5).unwrap();
        store.add_episode(agent.clone(), "b", 0.5).unwrap();
        assert_eq!(store.len().unwrap(), 2);
    }

    #[test]
    fn test_episodic_store_is_empty_initially() {
        let store = EpisodicStore::new();
        assert!(store.is_empty().unwrap());
    }

    // A one-hour-old item under a 0.001h half-life should come back with
    // near-zero importance.
    #[test]
    fn test_episodic_store_with_decay_reduces_importance() {
        let policy = DecayPolicy::exponential(0.001).unwrap();
        let store = EpisodicStore::with_decay(policy);
        let agent = AgentId::new("a");

        {
            // Insert directly so the timestamp can be back-dated.
            let mut inner = store.inner.lock().unwrap();
            let mut item = MemoryItem::new(agent.clone(), "old event", 1.0, vec![]);
            item.timestamp = Utc::now() - chrono::Duration::hours(1);
            inner.items.push(item);
        }

        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items.len(), 1);
        assert!(
            items[0].importance < 0.01,
            "expected near-zero importance, got {}",
            items[0].importance
        );
    }

    // Each recall of a single-item store bumps that item's recall_count by one.
    #[test]
    fn test_recall_increments_recall_count() {
        let store = EpisodicStore::new();
        let agent = AgentId::new("agent-rc");
        store.add_episode(agent.clone(), "memory", 0.5).unwrap();

        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items[0].recall_count, 1);

        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items[0].recall_count, 2);
    }

    // With a heavy frequency weight, an old but often-recalled item outranks a
    // fresh one of equal importance.
    #[test]
    fn test_hybrid_recall_policy_prefers_recently_used() {
        let store = EpisodicStore::with_recall_policy(RecallPolicy::Hybrid {
            recency_weight: 0.1,
            frequency_weight: 2.0,
        });
        let agent = AgentId::new("agent-hybrid");

        let old_ts = Utc::now() - chrono::Duration::hours(48);
        store
            .add_episode_at(agent.clone(), "old_frequent", 0.5, old_ts)
            .unwrap();
        store
            .add_episode(agent.clone(), "new_never", 0.5)
            .unwrap();

        {
            // Seed a large recall count directly on the stored item.
            let mut inner = store.inner.lock().unwrap();
            for item in inner.items.iter_mut() {
                if item.content == "old_frequent" {
                    item.recall_count = 100;
                }
            }
        }

        let items = store.recall(&agent, 10).unwrap();
        assert_eq!(items.len(), 2);
        assert_eq!(
            items[0].content, "old_frequent",
            "hybrid policy should rank the frequently-recalled item first"
        );
    }

    #[test]
    fn test_per_agent_capacity_evicts_lowest_importance() {
        let store = EpisodicStore::with_per_agent_capacity(2);
        let agent = AgentId::new("agent-cap");

        store.add_episode(agent.clone(), "mid", 0.5).unwrap();
        store.add_episode(agent.clone(), "high", 0.9).unwrap();
        // Third insert exceeds the cap of 2.
        store.add_episode(agent.clone(), "low", 0.1).unwrap();

        assert_eq!(
            store.len().unwrap(),
            2,
            "store should hold exactly 2 items after eviction"
        );

        let items = store.recall(&agent, 10).unwrap();
        let contents: Vec<&str> = items.iter().map(|i| i.content.as_str()).collect();
        assert!(
            !contents.contains(&"low"),
            "the lowest-importance item should have been evicted; remaining: {:?}",
            contents
        );
    }

    // ── SemanticStore ───────────────────────────────────────────────────

    // An empty tag filter returns every entry.
    #[test]
    fn test_semantic_store_store_and_retrieve_all() {
        let store = SemanticStore::new();
        store.store("key1", "value1", vec!["tag-a".into()]).unwrap();
        store.store("key2", "value2", vec!["tag-b".into()]).unwrap();
        let results = store.retrieve(&[]).unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_semantic_store_retrieve_filters_by_tag() {
        let store = SemanticStore::new();
        store
            .store("k1", "v1", vec!["rust".into(), "async".into()])
            .unwrap();
        store.store("k2", "v2", vec!["rust".into()]).unwrap();
        let results = store.retrieve(&["async"]).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "k1");
    }

    #[test]
    fn test_semantic_store_retrieve_requires_all_tags() {
        let store = SemanticStore::new();
        store
            .store("k1", "v1", vec!["a".into(), "b".into()])
            .unwrap();
        store.store("k2", "v2", vec!["a".into()]).unwrap();
        let results = store.retrieve(&["a", "b"]).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_semantic_store_is_empty_initially() {
        let store = SemanticStore::new();
        assert!(store.is_empty().unwrap());
    }

    #[test]
    fn test_semantic_store_len_tracks_insertions() {
        let store = SemanticStore::new();
        store.store("k", "v", vec![]).unwrap();
        assert_eq!(store.len().unwrap(), 1);
    }

    #[test]
    fn test_semantic_store_retrieve_similar_returns_closest() {
        let store = SemanticStore::new();
        store
            // Identical direction to the query below.
            .store_with_embedding("close", "close value", vec![], vec![1.0, 0.0, 0.0])
            .unwrap();
        store
            // Orthogonal to the query below.
            .store_with_embedding("far", "far value", vec![], vec![0.0, 1.0, 0.0])
            .unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let results = store.retrieve_similar(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "close");
        assert!((results[0].2 - 1.0).abs() < 1e-5, "expected similarity ~1.0, got {}", results[0].2);
        assert!((results[1].2).abs() < 1e-5, "expected similarity ~0.0, got {}", results[1].2);
    }

    #[test]
    fn test_semantic_store_retrieve_similar_ignores_unembedded_entries() {
        let store = SemanticStore::new();
        store.store("no-emb", "no embedding value", vec![]).unwrap();
        store
            // Only this entry has an embedding.
            .store_with_embedding("with-emb", "with embedding value", vec![], vec![1.0, 0.0])
            .unwrap();

        let query = vec![1.0, 0.0];
        let results = store.retrieve_similar(&query, 10).unwrap();
        assert_eq!(results.len(), 1, "only the embedded entry should appear");
        assert_eq!(results[0].0, "with-emb");
    }

    // Exercises cosine_similarity through the public API.
    #[test]
    fn test_cosine_similarity_orthogonal_vectors_return_zero() {
        let store = SemanticStore::new();
        store
            .store_with_embedding("a", "va", vec![], vec![1.0, 0.0])
            .unwrap();
        store
            .store_with_embedding("b", "vb", vec![], vec![0.0, 1.0])
            .unwrap();

        let query = vec![1.0, 0.0];
        let results = store.retrieve_similar(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        let b_result = results.iter().find(|(k, _, _)| k == "b").unwrap();
        assert!(
            b_result.2.abs() < 1e-5,
            "expected cosine similarity 0.0 for orthogonal vectors, got {}",
            b_result.2
        );
    }

    // ── WorkingMemory ───────────────────────────────────────────────────

    #[test]
    fn test_working_memory_new_rejects_zero_capacity() {
        assert!(WorkingMemory::new(0).is_err());
    }

    #[test]
    fn test_working_memory_set_and_get() {
        let wm = WorkingMemory::new(10).unwrap();
        wm.set("foo", "bar").unwrap();
        let val = wm.get("foo").unwrap();
        assert_eq!(val, Some("bar".into()));
    }

    #[test]
    fn test_working_memory_get_missing_key_returns_none() {
        let wm = WorkingMemory::new(10).unwrap();
        assert_eq!(wm.get("missing").unwrap(), None);
    }

    #[test]
    fn test_working_memory_bounded_evicts_oldest() {
        let wm = WorkingMemory::new(3).unwrap();
        wm.set("k1", "v1").unwrap();
        wm.set("k2", "v2").unwrap();
        wm.set("k3", "v3").unwrap();
        // Fourth distinct key at capacity 3 evicts the first one set.
        wm.set("k4", "v4").unwrap();
        assert_eq!(wm.get("k1").unwrap(), None);
        assert_eq!(wm.get("k4").unwrap(), Some("v4".into()));
    }

    #[test]
    fn test_working_memory_update_existing_key_no_eviction() {
        let wm = WorkingMemory::new(2).unwrap();
        wm.set("k1", "v1").unwrap();
        wm.set("k2", "v2").unwrap();
        // Overwriting an existing key at capacity must not evict anything.
        wm.set("k1", "v1-updated").unwrap();
        assert_eq!(wm.len().unwrap(), 2);
        assert_eq!(wm.get("k1").unwrap(), Some("v1-updated".into()));
        assert_eq!(wm.get("k2").unwrap(), Some("v2".into()));
    }

    #[test]
    fn test_working_memory_clear_removes_all() {
        let wm = WorkingMemory::new(10).unwrap();
        wm.set("a", "1").unwrap();
        wm.set("b", "2").unwrap();
        wm.clear().unwrap();
        assert!(wm.is_empty().unwrap());
    }

    #[test]
    fn test_working_memory_is_empty_initially() {
        let wm = WorkingMemory::new(5).unwrap();
        assert!(wm.is_empty().unwrap());
    }

    #[test]
    fn test_working_memory_len_tracks_entries() {
        let wm = WorkingMemory::new(10).unwrap();
        wm.set("a", "1").unwrap();
        wm.set("b", "2").unwrap();
        assert_eq!(wm.len().unwrap(), 2);
    }

    // Capacity invariant holds after every insertion, not just at the end.
    #[test]
    fn test_working_memory_capacity_never_exceeded() {
        let cap = 5usize;
        let wm = WorkingMemory::new(cap).unwrap();
        for i in 0..20 {
            wm.set(format!("key-{i}"), format!("val-{i}")).unwrap();
            assert!(wm.len().unwrap() <= cap);
        }
    }
}