1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(rename_all = "camelCase")]
19pub struct RelevanceConfig {
20 #[serde(default = "RelevanceConfig::default_decay_days")]
22 pub decay_days: f32,
23 #[serde(default = "RelevanceConfig::default_importance_weight")]
25 pub importance_weight: f32,
26 #[serde(default = "RelevanceConfig::default_recency_weight")]
28 pub recency_weight: f32,
29}
30
31impl RelevanceConfig {
32 fn default_decay_days() -> f32 {
33 30.0
34 }
35 fn default_importance_weight() -> f32 {
36 0.7
37 }
38 fn default_recency_weight() -> f32 {
39 0.3
40 }
41}
42
43impl Default for RelevanceConfig {
44 fn default() -> Self {
45 Self {
46 decay_days: 30.0,
47 importance_weight: 0.7,
48 recency_weight: 0.3,
49 }
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(rename_all = "camelCase")]
56pub struct MemoryConfig {
57 #[serde(default)]
59 pub relevance: RelevanceConfig,
60 #[serde(default = "MemoryConfig::default_max_short_term")]
62 pub max_short_term: usize,
63 #[serde(default = "MemoryConfig::default_max_working")]
65 pub max_working: usize,
66}
67
68impl MemoryConfig {
69 fn default_max_short_term() -> usize {
70 100
71 }
72 fn default_max_working() -> usize {
73 10
74 }
75}
76
77impl Default for MemoryConfig {
78 fn default() -> Self {
79 Self {
80 relevance: RelevanceConfig::default(),
81 max_short_term: 100,
82 max_working: 10,
83 }
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct MemoryItem {
94 pub id: String,
96 pub content: String,
98 pub timestamp: DateTime<Utc>,
100 pub importance: f32,
102 pub tags: Vec<String>,
104 pub memory_type: MemoryType,
106 pub metadata: HashMap<String, String>,
108 pub access_count: u32,
110 pub last_accessed: Option<DateTime<Utc>>,
112 #[serde(skip)]
114 pub content_lower: String,
115}
116
117impl MemoryItem {
118 pub fn new(content: impl Into<String>) -> Self {
120 let content = content.into();
121 let content_lower = content.to_lowercase();
122 Self {
123 id: uuid::Uuid::new_v4().to_string(),
124 content,
125 timestamp: Utc::now(),
126 importance: 0.5,
127 tags: Vec::new(),
128 memory_type: MemoryType::Episodic,
129 metadata: HashMap::new(),
130 access_count: 0,
131 last_accessed: None,
132 content_lower,
133 }
134 }
135
136 pub fn with_importance(mut self, importance: f32) -> Self {
138 self.importance = importance.clamp(0.0, 1.0);
139 self
140 }
141
142 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
144 self.tags = tags;
145 self
146 }
147
148 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
150 self.tags.push(tag.into());
151 self
152 }
153
154 pub fn with_type(mut self, memory_type: MemoryType) -> Self {
156 self.memory_type = memory_type;
157 self
158 }
159
160 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
162 self.metadata.insert(key.into(), value.into());
163 self
164 }
165
166 pub fn record_access(&mut self) {
168 self.access_count += 1;
169 self.last_accessed = Some(Utc::now());
170 }
171
172 pub fn relevance_score_at(&self, now: DateTime<Utc>) -> f32 {
176 let age_seconds = (now - self.timestamp).num_seconds() as f32;
177 let age_days = age_seconds / 86400.0;
178
179 let decay = (-age_days / 30.0).exp(); self.importance * 0.7 + decay * 0.3
184 }
185
186 pub fn relevance_score(&self) -> f32 {
188 self.relevance_score_at(Utc::now())
189 }
190}
191
192#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
194#[serde(rename_all = "snake_case")]
195pub enum MemoryType {
196 Episodic,
198 Semantic,
200 Procedural,
202 Working,
204}
205
206#[async_trait::async_trait]
212pub trait MemoryStore: Send + Sync {
213 async fn store(&self, item: MemoryItem) -> anyhow::Result<()>;
215
216 async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>>;
218
219 async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
221
222 async fn search_by_tags(
224 &self,
225 tags: &[String],
226 limit: usize,
227 ) -> anyhow::Result<Vec<MemoryItem>>;
228
229 async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
231
232 async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
234
235 async fn delete(&self, id: &str) -> anyhow::Result<()>;
237
238 async fn clear(&self) -> anyhow::Result<()>;
240
241 async fn count(&self) -> anyhow::Result<usize>;
243}
244
245fn sort_by_relevance(items: &mut [MemoryItem]) {
251 let now = Utc::now();
252 items.sort_by(|a, b| {
253 b.relevance_score_at(now)
254 .partial_cmp(&a.relevance_score_at(now))
255 .unwrap_or(std::cmp::Ordering::Equal)
256 });
257}
258
259#[derive(Clone)]
265pub struct AgentMemory {
266 store: Arc<dyn MemoryStore>,
268 short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
270 working: Arc<RwLock<Vec<MemoryItem>>>,
272 max_short_term: usize,
274 max_working: usize,
276 relevance_config: RelevanceConfig,
278}
279
280impl std::fmt::Debug for AgentMemory {
281 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282 f.debug_struct("AgentMemory")
283 .field("max_short_term", &self.max_short_term)
284 .field("max_working", &self.max_working)
285 .finish()
286 }
287}
288
289impl AgentMemory {
290 pub fn new(store: Arc<dyn MemoryStore>) -> Self {
292 Self::with_config(store, MemoryConfig::default())
293 }
294
295 pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
297 Self {
298 store,
299 short_term: Arc::new(RwLock::new(VecDeque::new())),
300 working: Arc::new(RwLock::new(Vec::new())),
301 max_short_term: config.max_short_term,
302 max_working: config.max_working,
303 relevance_config: config.relevance,
304 }
305 }
306
307 fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
309 let age_seconds = (now - item.timestamp).num_seconds() as f32;
310 let age_days = age_seconds / 86400.0;
311 let decay = (-age_days / self.relevance_config.decay_days).exp();
312 item.importance * self.relevance_config.importance_weight
313 + decay * self.relevance_config.recency_weight
314 }
315
316 pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
318 self.store.store(item.clone()).await?;
320
321 let mut short_term = self.short_term.write().await;
323 short_term.push_back(item);
324
325 if short_term.len() > self.max_short_term {
327 short_term.pop_front();
328 }
329
330 Ok(())
331 }
332
333 pub async fn remember_success(
335 &self,
336 prompt: &str,
337 tools_used: &[String],
338 result: &str,
339 ) -> anyhow::Result<()> {
340 let content = format!(
341 "Success: {}\nTools: {}\nResult: {}",
342 prompt,
343 tools_used.join(", "),
344 result
345 );
346
347 let item = MemoryItem::new(content)
348 .with_importance(0.8)
349 .with_tag("success")
350 .with_tag("pattern")
351 .with_type(MemoryType::Procedural)
352 .with_metadata("prompt", prompt)
353 .with_metadata("tools", tools_used.join(","));
354
355 self.remember(item).await
356 }
357
358 pub async fn remember_failure(
360 &self,
361 prompt: &str,
362 error: &str,
363 attempted_tools: &[String],
364 ) -> anyhow::Result<()> {
365 let content = format!(
366 "Failure: {}\nError: {}\nAttempted tools: {}",
367 prompt,
368 error,
369 attempted_tools.join(", ")
370 );
371
372 let item = MemoryItem::new(content)
373 .with_importance(0.9) .with_tag("failure")
375 .with_tag("avoid")
376 .with_type(MemoryType::Episodic)
377 .with_metadata("prompt", prompt)
378 .with_metadata("error", error);
379
380 self.remember(item).await
381 }
382
383 pub async fn recall_similar(
385 &self,
386 prompt: &str,
387 limit: usize,
388 ) -> anyhow::Result<Vec<MemoryItem>> {
389 self.store.search(prompt, limit).await
390 }
391
392 pub async fn recall_by_tags(
394 &self,
395 tags: &[String],
396 limit: usize,
397 ) -> anyhow::Result<Vec<MemoryItem>> {
398 self.store.search_by_tags(tags, limit).await
399 }
400
401 pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
403 self.store.get_recent(limit).await
404 }
405
406 pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
408 let mut working = self.working.write().await;
409 working.push(item);
410
411 if working.len() > self.max_working {
413 let now = Utc::now();
414 working.sort_by(|a, b| {
415 self.score(b, now)
416 .partial_cmp(&self.score(a, now))
417 .unwrap_or(std::cmp::Ordering::Equal)
418 });
419 working.truncate(self.max_working);
420 }
421
422 Ok(())
423 }
424
425 pub async fn get_working(&self) -> Vec<MemoryItem> {
427 self.working.read().await.clone()
428 }
429
430 pub async fn clear_working(&self) {
432 self.working.write().await.clear();
433 }
434
435 pub async fn get_short_term(&self) -> Vec<MemoryItem> {
437 self.short_term.read().await.iter().cloned().collect()
438 }
439
440 pub async fn clear_short_term(&self) {
442 self.short_term.write().await.clear();
443 }
444
445 pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
447 let long_term_count = self.store.count().await?;
448 let short_term_count = self.short_term.read().await.len();
449 let working_count = self.working.read().await.len();
450
451 Ok(MemoryStats {
452 long_term_count,
453 short_term_count,
454 working_count,
455 })
456 }
457
458 pub fn store(&self) -> &Arc<dyn MemoryStore> {
460 &self.store
461 }
462
463 pub async fn working_count(&self) -> usize {
465 self.working.read().await.len()
466 }
467
468 pub async fn short_term_count(&self) -> usize {
470 self.short_term.read().await.len()
471 }
472}
473
474#[derive(Debug, Clone, Serialize, Deserialize)]
476pub struct MemoryStats {
477 pub long_term_count: usize,
479 pub short_term_count: usize,
481 pub working_count: usize,
483}
484
485pub struct MemoryContextProvider {
494 memory: AgentMemory,
495}
496
497impl MemoryContextProvider {
498 pub fn new(memory: AgentMemory) -> Self {
500 Self { memory }
501 }
502}
503
504#[async_trait::async_trait]
505impl crate::context::ContextProvider for MemoryContextProvider {
506 fn name(&self) -> &str {
507 "memory"
508 }
509
510 async fn query(
511 &self,
512 query: &crate::context::ContextQuery,
513 ) -> anyhow::Result<crate::context::ContextResult> {
514 let limit = query.max_results.min(5);
515 let items = self.memory.recall_similar(&query.query, limit).await?;
516
517 let mut result = crate::context::ContextResult::new("memory");
518 for item in items {
519 let relevance = item.relevance_score();
520 let token_count = item.content.len() / 4; let context_item = crate::context::ContextItem::new(
522 &item.id,
523 crate::context::ContextType::Memory,
524 &item.content,
525 )
526 .with_relevance(relevance)
527 .with_token_count(token_count)
528 .with_source("memory");
529 result.add_item(context_item);
530 }
531
532 Ok(result)
533 }
534
535 async fn on_turn_complete(
536 &self,
537 _session_id: &str,
538 prompt: &str,
539 response: &str,
540 ) -> anyhow::Result<()> {
541 self.memory.remember_success(prompt, &[], response).await
543 }
544}
545
#[cfg(test)]
mod tests {
    use super::*;

    /// `Vec`-backed `MemoryStore` used only by the tests in this module.
    struct TestMemoryStore {
        items: std::sync::Mutex<Vec<MemoryItem>>,
    }

    impl TestMemoryStore {
        fn new() -> Self {
            TestMemoryStore {
                items: std::sync::Mutex::new(Vec::new()),
            }
        }
    }

    #[async_trait::async_trait]
    impl MemoryStore for TestMemoryStore {
        async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
            let mut guard = self.items.lock().unwrap();
            guard.push(item);
            Ok(())
        }

        async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
            let guard = self.items.lock().unwrap();
            let found = guard.iter().find(|entry| entry.id == id).cloned();
            Ok(found)
        }

        // Case-insensitive substring match on the content.
        async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
            let guard = self.items.lock().unwrap();
            let needle = query.to_lowercase();
            let mut hits = Vec::new();
            for entry in guard.iter() {
                if hits.len() >= limit {
                    break;
                }
                if entry.content.to_lowercase().contains(&needle) {
                    hits.push(entry.clone());
                }
            }
            Ok(hits)
        }

        // An item matches when it shares at least one tag with `tags`.
        async fn search_by_tags(
            &self,
            tags: &[String],
            limit: usize,
        ) -> anyhow::Result<Vec<MemoryItem>> {
            let guard = self.items.lock().unwrap();
            let matching: Vec<MemoryItem> = guard
                .iter()
                .filter(|entry| entry.tags.iter().any(|tag| tags.contains(tag)))
                .take(limit)
                .cloned()
                .collect();
            Ok(matching)
        }

        async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
            let mut newest_first = self.items.lock().unwrap().clone();
            newest_first.sort_by_key(|entry| std::cmp::Reverse(entry.timestamp));
            newest_first.truncate(limit);
            Ok(newest_first)
        }

        async fn get_important(
            &self,
            threshold: f32,
            limit: usize,
        ) -> anyhow::Result<Vec<MemoryItem>> {
            let guard = self.items.lock().unwrap();
            let important: Vec<MemoryItem> = guard
                .iter()
                .filter(|entry| entry.importance >= threshold)
                .take(limit)
                .cloned()
                .collect();
            Ok(important)
        }

        async fn delete(&self, id: &str) -> anyhow::Result<()> {
            let mut guard = self.items.lock().unwrap();
            guard.retain(|entry| entry.id != id);
            Ok(())
        }

        async fn clear(&self) -> anyhow::Result<()> {
            let mut guard = self.items.lock().unwrap();
            guard.clear();
            Ok(())
        }

        async fn count(&self) -> anyhow::Result<usize> {
            let guard = self.items.lock().unwrap();
            Ok(guard.len())
        }
    }

    #[test]
    fn test_memory_item_creation() {
        let built = MemoryItem::new("Test memory")
            .with_importance(0.8)
            .with_tag("test")
            .with_type(MemoryType::Semantic);

        assert_eq!(built.content, "Test memory");
        assert_eq!(built.importance, 0.8);
        assert_eq!(built.tags, vec!["test"]);
        assert_eq!(built.memory_type, MemoryType::Semantic);
    }

    #[test]
    fn test_memory_item_relevance() {
        let score = MemoryItem::new("Test")
            .with_importance(0.9)
            .relevance_score();

        assert!(score > 0.6);
    }

    #[test]
    fn test_relevance_config_defaults() {
        let defaults = RelevanceConfig::default();
        assert_eq!(defaults.decay_days, 30.0);
        assert_eq!(defaults.importance_weight, 0.7);
        assert_eq!(defaults.recency_weight, 0.3);
    }

    #[test]
    fn test_memory_config_defaults() {
        let defaults = MemoryConfig::default();
        assert_eq!(defaults.max_short_term, 100);
        assert_eq!(defaults.max_working, 10);
        assert_eq!(defaults.relevance.decay_days, 30.0);
    }

    #[test]
    fn test_memory_config_serde_roundtrip() {
        let original = MemoryConfig::default();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: MemoryConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.max_short_term, original.max_short_term);
        assert_eq!(decoded.max_working, original.max_working);
        assert_eq!(decoded.relevance.decay_days, original.relevance.decay_days);
    }

    #[test]
    fn test_agent_memory_with_config() {
        let relevance = RelevanceConfig {
            decay_days: 7.0,
            importance_weight: 0.5,
            recency_weight: 0.5,
        };
        let config = MemoryConfig {
            relevance,
            max_short_term: 50,
            max_working: 5,
        };

        let memory = AgentMemory::with_config(Arc::new(TestMemoryStore::new()), config);

        assert_eq!(memory.max_short_term, 50);
        assert_eq!(memory.max_working, 5);
        assert_eq!(memory.relevance_config.decay_days, 7.0);
    }

    #[test]
    fn test_agent_memory_score_uses_config() {
        let relevance = RelevanceConfig {
            decay_days: 7.0,
            importance_weight: 0.9,
            recency_weight: 0.1,
        };
        let config = MemoryConfig {
            relevance,
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(TestMemoryStore::new()), config);

        let fresh = MemoryItem::new("Test").with_importance(1.0);
        let score = memory.score(&fresh, Utc::now());

        assert!(score > 0.95, "Score was {}", score);
    }

    #[tokio::test]
    async fn test_in_memory_store() {
        let store = TestMemoryStore::new();
        let original = MemoryItem::new("Test memory").with_tag("test");
        store.store(original.clone()).await.unwrap();

        let fetched = store.retrieve(&original.id).await.unwrap();

        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().content, "Test memory");
    }

    #[tokio::test]
    async fn test_memory_search() {
        let store = TestMemoryStore::new();
        let fixtures = [
            ("How to create a file", "file"),
            ("How to delete a file", "file"),
            ("How to create a directory", "dir"),
        ];
        for (text, tag) in fixtures {
            store
                .store(MemoryItem::new(text).with_tag(tag))
                .await
                .unwrap();
        }

        let by_text = store.search("create", 10).await.unwrap();
        assert_eq!(by_text.len(), 2);

        let by_tag = store
            .search_by_tags(&["file".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(by_tag.len(), 2);
    }

    #[tokio::test]
    async fn test_agent_memory() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));

        memory
            .remember_success("Create a file", &["write".to_string()], "File created")
            .await
            .unwrap();
        memory
            .remember_failure("Delete file", "Permission denied", &["bash".to_string()])
            .await
            .unwrap();

        let recalled = memory.recall_similar("create", 10).await.unwrap();
        assert!(!recalled.is_empty());

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 2);
    }

    #[tokio::test]
    async fn test_working_memory() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));

        let task = MemoryItem::new("Active task").with_type(MemoryType::Working);
        memory.add_to_working(task).await.unwrap();
        assert_eq!(memory.get_working().await.len(), 1);

        memory.clear_working().await;
        assert_eq!(memory.get_working().await.len(), 0);
    }
}
772
#[cfg(test)]
mod extra_memory_tests {
    use super::*;

    /// `Vec`-backed `MemoryStore` used only by the tests in this module.
    struct TestMemoryStore {
        items: std::sync::Mutex<Vec<MemoryItem>>,
    }

    impl TestMemoryStore {
        fn new() -> Self {
            Self {
                items: std::sync::Mutex::new(Vec::new()),
            }
        }
    }

    #[async_trait::async_trait]
    impl MemoryStore for TestMemoryStore {
        async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
            self.items.lock().unwrap().push(item);
            Ok(())
        }

        async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
            Ok(self
                .items
                .lock()
                .unwrap()
                .iter()
                .find(|i| i.id == id)
                .cloned())
        }

        // Case-insensitive substring match on the content.
        async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            let query_lower = query.to_lowercase();
            Ok(items
                .iter()
                .filter(|i| i.content.to_lowercase().contains(&query_lower))
                .take(limit)
                .cloned()
                .collect())
        }

        // An item matches when it carries at least one of `tags`.
        async fn search_by_tags(
            &self,
            tags: &[String],
            limit: usize,
        ) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            Ok(items
                .iter()
                .filter(|i| tags.iter().any(|t| i.tags.contains(t)))
                .take(limit)
                .cloned()
                .collect())
        }

        async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            let mut sorted: Vec<_> = items.iter().cloned().collect();
            sorted.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
            sorted.truncate(limit);
            Ok(sorted)
        }

        async fn get_important(
            &self,
            threshold: f32,
            limit: usize,
        ) -> anyhow::Result<Vec<MemoryItem>> {
            let items = self.items.lock().unwrap();
            Ok(items
                .iter()
                .filter(|i| i.importance >= threshold)
                .take(limit)
                .cloned()
                .collect())
        }

        async fn delete(&self, id: &str) -> anyhow::Result<()> {
            self.items.lock().unwrap().retain(|i| i.id != id);
            Ok(())
        }

        async fn clear(&self) -> anyhow::Result<()> {
            self.items.lock().unwrap().clear();
            Ok(())
        }

        async fn count(&self) -> anyhow::Result<usize> {
            Ok(self.items.lock().unwrap().len())
        }
    }

    // ---- MemoryItem builders -------------------------------------------

    #[test]
    fn test_memory_item_with_metadata() {
        let item = MemoryItem::new("test")
            .with_metadata("key1", "value1")
            .with_metadata("key2", "value2");
        assert_eq!(item.metadata.get("key1").unwrap(), "value1");
        assert_eq!(item.metadata.get("key2").unwrap(), "value2");
    }

    #[test]
    fn test_memory_item_with_tags_vec() {
        let item = MemoryItem::new("test").with_tags(vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(item.tags.len(), 3);
    }

    #[test]
    fn test_memory_item_importance_clamped() {
        let item_high = MemoryItem::new("test").with_importance(1.5);
        assert_eq!(item_high.importance, 1.0);

        let item_low = MemoryItem::new("test").with_importance(-0.5);
        assert_eq!(item_low.importance, 0.0);
    }

    #[test]
    fn test_memory_item_record_access() {
        let mut item = MemoryItem::new("test");
        assert_eq!(item.access_count, 0);
        assert!(item.last_accessed.is_none());

        item.record_access();
        assert_eq!(item.access_count, 1);
        assert!(item.last_accessed.is_some());

        item.record_access();
        assert_eq!(item.access_count, 2);
    }

    #[test]
    fn test_memory_item_all_types() {
        let episodic = MemoryItem::new("e").with_type(MemoryType::Episodic);
        assert_eq!(episodic.memory_type, MemoryType::Episodic);

        let semantic = MemoryItem::new("s").with_type(MemoryType::Semantic);
        assert_eq!(semantic.memory_type, MemoryType::Semantic);

        let procedural = MemoryItem::new("p").with_type(MemoryType::Procedural);
        assert_eq!(procedural.memory_type, MemoryType::Procedural);

        let working = MemoryItem::new("w").with_type(MemoryType::Working);
        assert_eq!(working.memory_type, MemoryType::Working);
    }

    #[test]
    fn test_memory_item_default_type_is_episodic() {
        let item = MemoryItem::new("test");
        assert_eq!(item.memory_type, MemoryType::Episodic);
    }

    // ---- TestMemoryStore -----------------------------------------------

    #[tokio::test]
    async fn test_in_memory_store_retrieve_nonexistent() {
        let store = TestMemoryStore::new();
        let result = store.retrieve("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_in_memory_store_delete() {
        let store = TestMemoryStore::new();
        let item = MemoryItem::new("to delete");
        let id = item.id.clone();
        store.store(item).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        store.delete(&id).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_in_memory_store_clear() {
        let store = TestMemoryStore::new();
        for i in 0..5 {
            store
                .store(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 5);

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_recent() {
        let store = TestMemoryStore::new();
        for i in 0..5 {
            store
                .store(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        let recent = store.get_recent(3).await.unwrap();
        assert_eq!(recent.len(), 3);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_important() {
        let store = TestMemoryStore::new();
        store
            .store(MemoryItem::new("low").with_importance(0.2))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("medium").with_importance(0.5))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("high").with_importance(0.9))
            .await
            .unwrap();

        let important = store.get_important(0.7, 10).await.unwrap();
        assert_eq!(important.len(), 1);
        assert_eq!(important[0].content, "high");
    }

    #[tokio::test]
    async fn test_in_memory_store_search_case_insensitive() {
        let store = TestMemoryStore::new();
        store
            .store(MemoryItem::new("How to CREATE a file"))
            .await
            .unwrap();
        let results = store.search("create", 10).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    // ---- AgentMemory ---------------------------------------------------

    #[tokio::test]
    async fn test_agent_memory_short_term() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        memory.remember(MemoryItem::new("item 1")).await.unwrap();
        memory.remember(MemoryItem::new("item 2")).await.unwrap();

        let short_term = memory.get_short_term().await;
        assert_eq!(short_term.len(), 2);

        memory.clear_short_term().await;
        let short_term = memory.get_short_term().await;
        assert_eq!(short_term.len(), 0);
    }

    #[tokio::test]
    async fn test_agent_memory_short_term_count() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        assert_eq!(memory.short_term_count().await, 0);
        memory.remember(MemoryItem::new("item")).await.unwrap();
        assert_eq!(memory.short_term_count().await, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_working_count() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        assert_eq!(memory.working_count().await, 0);
        memory
            .add_to_working(MemoryItem::new("task"))
            .await
            .unwrap();
        assert_eq!(memory.working_count().await, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_recall_by_tags() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        memory
            .remember_success("create file", &["write".to_string()], "ok")
            .await
            .unwrap();
        memory
            .remember_failure("delete file", "denied", &["bash".to_string()])
            .await
            .unwrap();

        let successes = memory
            .recall_by_tags(&["success".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(successes.len(), 1);

        let failures = memory
            .recall_by_tags(&["failure".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(failures.len(), 1);
    }

    #[tokio::test]
    async fn test_agent_memory_get_recent() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        for i in 0..5 {
            memory
                .remember(MemoryItem::new(format!("item {}", i)))
                .await
                .unwrap();
        }
        let recent = memory.get_recent(3).await.unwrap();
        assert_eq!(recent.len(), 3);
    }

    #[tokio::test]
    async fn test_agent_memory_store_accessor() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        memory.remember(MemoryItem::new("test")).await.unwrap();
        let count = memory.store().count().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_stats_all_fields() {
        let memory = AgentMemory::new(Arc::new(TestMemoryStore::new()));
        memory.remember(MemoryItem::new("long term")).await.unwrap();
        memory
            .add_to_working(MemoryItem::new("working"))
            .await
            .unwrap();

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 1);
        // `remember` also places the item in the short-term queue.
        assert_eq!(stats.short_term_count, 1);
        assert_eq!(stats.working_count, 1);
    }

    #[tokio::test]
    async fn test_agent_memory_working_overflow_trims() {
        // Build through the public constructor rather than a struct literal so
        // the test keeps compiling if `AgentMemory` gains fields.
        let config = MemoryConfig {
            max_working: 3,
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(TestMemoryStore::new()), config);

        // Importances 0.0, 0.2, 0.4, 0.6, 0.8 for tasks 0..=4.
        for i in 0..5 {
            memory
                .add_to_working(
                    MemoryItem::new(format!("task {}", i)).with_importance(i as f32 * 0.2),
                )
                .await
                .unwrap();
        }

        let working = memory.get_working().await;
        assert_eq!(working.len(), 3);

        // Eviction must drop the least relevant items: all items are equally
        // fresh, so the three most important ones have to survive.
        let mut kept: Vec<&str> = working.iter().map(|m| m.content.as_str()).collect();
        kept.sort_unstable();
        assert_eq!(kept, vec!["task 2", "task 3", "task 4"]);
    }
}