1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(rename_all = "camelCase")]
19pub struct RelevanceConfig {
20 #[serde(default = "RelevanceConfig::default_decay_days")]
22 pub decay_days: f32,
23 #[serde(default = "RelevanceConfig::default_importance_weight")]
25 pub importance_weight: f32,
26 #[serde(default = "RelevanceConfig::default_recency_weight")]
28 pub recency_weight: f32,
29}
30
31impl RelevanceConfig {
32 fn default_decay_days() -> f32 {
33 30.0
34 }
35 fn default_importance_weight() -> f32 {
36 0.7
37 }
38 fn default_recency_weight() -> f32 {
39 0.3
40 }
41}
42
43impl Default for RelevanceConfig {
44 fn default() -> Self {
45 Self {
46 decay_days: 30.0,
47 importance_weight: 0.7,
48 recency_weight: 0.3,
49 }
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(rename_all = "camelCase")]
56pub struct MemoryConfig {
57 #[serde(default)]
59 pub relevance: RelevanceConfig,
60 #[serde(default = "MemoryConfig::default_max_short_term")]
62 pub max_short_term: usize,
63 #[serde(default = "MemoryConfig::default_max_working")]
65 pub max_working: usize,
66}
67
68impl MemoryConfig {
69 fn default_max_short_term() -> usize {
70 100
71 }
72 fn default_max_working() -> usize {
73 10
74 }
75}
76
77impl Default for MemoryConfig {
78 fn default() -> Self {
79 Self {
80 relevance: RelevanceConfig::default(),
81 max_short_term: 100,
82 max_working: 10,
83 }
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct MemoryItem {
94 pub id: String,
96 pub content: String,
98 pub timestamp: DateTime<Utc>,
100 pub importance: f32,
102 pub tags: Vec<String>,
104 pub memory_type: MemoryType,
106 pub metadata: HashMap<String, String>,
108 pub access_count: u32,
110 pub last_accessed: Option<DateTime<Utc>>,
112 #[serde(skip)]
114 pub content_lower: String,
115}
116
117impl MemoryItem {
118 pub fn new(content: impl Into<String>) -> Self {
120 let content = content.into();
121 let content_lower = content.to_lowercase();
122 Self {
123 id: uuid::Uuid::new_v4().to_string(),
124 content,
125 timestamp: Utc::now(),
126 importance: 0.5,
127 tags: Vec::new(),
128 memory_type: MemoryType::Episodic,
129 metadata: HashMap::new(),
130 access_count: 0,
131 last_accessed: None,
132 content_lower,
133 }
134 }
135
136 pub fn with_importance(mut self, importance: f32) -> Self {
138 self.importance = importance.clamp(0.0, 1.0);
139 self
140 }
141
142 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
144 self.tags = tags;
145 self
146 }
147
148 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
150 self.tags.push(tag.into());
151 self
152 }
153
154 pub fn with_type(mut self, memory_type: MemoryType) -> Self {
156 self.memory_type = memory_type;
157 self
158 }
159
160 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
162 self.metadata.insert(key.into(), value.into());
163 self
164 }
165
166 pub fn record_access(&mut self) {
168 self.access_count += 1;
169 self.last_accessed = Some(Utc::now());
170 }
171
172 pub fn relevance_score_at(&self, now: DateTime<Utc>) -> f32 {
176 let age_seconds = (now - self.timestamp).num_seconds() as f32;
177 let age_days = age_seconds / 86400.0;
178
179 let decay = (-age_days / 30.0).exp(); self.importance * 0.7 + decay * 0.3
184 }
185
186 pub fn relevance_score(&self) -> f32 {
188 self.relevance_score_at(Utc::now())
189 }
190}
191
/// Category of a [`MemoryItem`]. Serialized in `snake_case` (e.g. `"episodic"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    /// Default for `MemoryItem::new`; also used by `AgentMemory::remember_failure`.
    Episodic,
    /// Not assigned by any helper in this file; set explicitly via `with_type`.
    Semantic,
    /// Used by `AgentMemory::remember_success` for successful prompt/tool patterns.
    Procedural,
    /// Not assigned automatically; presumably meant for working-set items — confirm with callers.
    Working,
}
205
/// Asynchronous backend for long-term memory storage.
///
/// Implementors must be `Send + Sync` so they can be shared as `Arc<dyn MemoryStore>`,
/// as `AgentMemory` does. This file provides `InMemoryStore` and `FileStore`.
#[async_trait::async_trait]
pub trait MemoryStore: Send + Sync {
    /// Adds `item` to the store.
    async fn store(&self, item: MemoryItem) -> anyhow::Result<()>;

    /// Returns the item whose `id` matches, or `Ok(None)` if there is none.
    async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>>;

    /// Returns at most `limit` items matching `query`. The in-file implementations do a
    /// case-insensitive substring match on content and rank by relevance.
    async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;

    /// Returns at most `limit` items carrying at least one of `tags`. The in-file
    /// implementations rank by relevance.
    async fn search_by_tags(
        &self,
        tags: &[String],
        limit: usize,
    ) -> anyhow::Result<Vec<MemoryItem>>;

    /// Returns at most `limit` items; the in-file implementations order them newest first.
    async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;

    /// Returns at most `limit` items with `importance >= threshold`; the in-file
    /// implementations order them by importance, highest first.
    async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;

    /// Removes the item with the given `id`. The in-file implementations also return
    /// `Ok(())` when no such item exists.
    async fn delete(&self, id: &str) -> anyhow::Result<()>;

    /// Removes every item.
    async fn clear(&self) -> anyhow::Result<()>;

    /// Returns the number of stored items.
    async fn count(&self) -> anyhow::Result<usize>;
}
244
245fn search_memories(memories: &[MemoryItem], query: &str, limit: usize) -> Vec<MemoryItem> {
251 let query_lower = query.to_lowercase();
252 let mut results: Vec<_> = memories
253 .iter()
254 .filter(|m| m.content_lower.contains(&query_lower))
255 .cloned()
256 .collect();
257 sort_by_relevance(&mut results);
258 results.truncate(limit);
259 results
260}
261
262fn search_memories_by_tags(
264 memories: &[MemoryItem],
265 tags: &[String],
266 limit: usize,
267) -> Vec<MemoryItem> {
268 let mut results: Vec<_> = memories
269 .iter()
270 .filter(|m| tags.iter().any(|tag| m.tags.contains(tag)))
271 .cloned()
272 .collect();
273 sort_by_relevance(&mut results);
274 results.truncate(limit);
275 results
276}
277
278fn recent_memories(memories: &[MemoryItem], limit: usize) -> Vec<MemoryItem> {
280 let mut results: Vec<_> = memories.to_vec();
281 results.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
282 results.truncate(limit);
283 results
284}
285
286fn important_memories(memories: &[MemoryItem], threshold: f32, limit: usize) -> Vec<MemoryItem> {
288 let mut results: Vec<_> = memories
289 .iter()
290 .filter(|m| m.importance >= threshold)
291 .cloned()
292 .collect();
293 results.sort_by(|a, b| {
294 b.importance
295 .partial_cmp(&a.importance)
296 .unwrap_or(std::cmp::Ordering::Equal)
297 });
298 results.truncate(limit);
299 results
300}
301
302fn sort_by_relevance(items: &mut [MemoryItem]) {
304 let now = Utc::now();
305 items.sort_by(|a, b| {
306 b.relevance_score_at(now)
307 .partial_cmp(&a.relevance_score_at(now))
308 .unwrap_or(std::cmp::Ordering::Equal)
309 });
310}
311
312#[derive(Debug, Clone)]
318pub struct InMemoryStore {
319 memories: Arc<RwLock<Vec<MemoryItem>>>,
320}
321
322impl InMemoryStore {
323 pub fn new() -> Self {
325 Self {
326 memories: Arc::new(RwLock::new(Vec::new())),
327 }
328 }
329}
330
331impl Default for InMemoryStore {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
337#[async_trait::async_trait]
338impl MemoryStore for InMemoryStore {
339 async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
340 let mut memories = self.memories.write().await;
341 memories.push(item);
342 Ok(())
343 }
344
345 async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
346 let memories = self.memories.read().await;
347 Ok(memories.iter().find(|m| m.id == id).cloned())
348 }
349
350 async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
351 let memories = self.memories.read().await;
352 Ok(search_memories(&memories, query, limit))
353 }
354
355 async fn search_by_tags(
356 &self,
357 tags: &[String],
358 limit: usize,
359 ) -> anyhow::Result<Vec<MemoryItem>> {
360 let memories = self.memories.read().await;
361 Ok(search_memories_by_tags(&memories, tags, limit))
362 }
363
364 async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
365 let memories = self.memories.read().await;
366 Ok(recent_memories(&memories, limit))
367 }
368
369 async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
370 let memories = self.memories.read().await;
371 Ok(important_memories(&memories, threshold, limit))
372 }
373
374 async fn delete(&self, id: &str) -> anyhow::Result<()> {
375 let mut memories = self.memories.write().await;
376 memories.retain(|m| m.id != id);
377 Ok(())
378 }
379
380 async fn clear(&self) -> anyhow::Result<()> {
381 let mut memories = self.memories.write().await;
382 memories.clear();
383 Ok(())
384 }
385
386 async fn count(&self) -> anyhow::Result<usize> {
387 let memories = self.memories.read().await;
388 Ok(memories.len())
389 }
390}
391
392#[derive(Debug, Clone)]
398pub struct FileStore {
399 file_path: std::path::PathBuf,
400 memories: Arc<RwLock<Vec<MemoryItem>>>,
401}
402
403impl FileStore {
404 pub fn new(file_path: impl Into<std::path::PathBuf>) -> anyhow::Result<Self> {
409 let file_path = file_path.into();
410
411 if let Some(parent) = file_path.parent() {
413 std::fs::create_dir_all(parent)?;
414 }
415
416 let memories = if file_path.exists() {
418 Self::load_from_file(&file_path)?
419 } else {
420 Vec::new()
421 };
422
423 Ok(Self {
424 file_path,
425 memories: Arc::new(RwLock::new(memories)),
426 })
427 }
428
429 pub async fn open(file_path: impl Into<std::path::PathBuf>) -> anyhow::Result<Self> {
431 let file_path = file_path.into();
432
433 if let Some(parent) = file_path.parent() {
435 tokio::fs::create_dir_all(parent).await?;
436 }
437
438 let memories = if file_path.exists() {
440 let content = tokio::fs::read_to_string(&file_path).await?;
441 Self::parse_jsonl(&content)?
442 } else {
443 Vec::new()
444 };
445
446 Ok(Self {
447 file_path,
448 memories: Arc::new(RwLock::new(memories)),
449 })
450 }
451
452 fn load_from_file(path: &std::path::Path) -> anyhow::Result<Vec<MemoryItem>> {
454 let content = std::fs::read_to_string(path)?;
455 Self::parse_jsonl(&content)
456 }
457
458 fn parse_jsonl(content: &str) -> anyhow::Result<Vec<MemoryItem>> {
460 let mut memories = Vec::new();
461
462 for line in content.lines() {
463 if line.trim().is_empty() {
464 continue;
465 }
466 let mut item: MemoryItem = serde_json::from_str(line)?;
467 item.content_lower = item.content.to_lowercase();
468 memories.push(item);
469 }
470
471 Ok(memories)
472 }
473
474 async fn save_to_file(&self) -> anyhow::Result<()> {
476 let memories = self.memories.read().await;
477 let mut content = String::new();
478
479 for memory in memories.iter() {
480 let json = serde_json::to_string(memory)?;
481 content.push_str(&json);
482 content.push('\n');
483 }
484
485 let temp_path = self.file_path.with_extension("tmp");
487 tokio::fs::write(&temp_path, content).await?;
488 tokio::fs::rename(&temp_path, &self.file_path).await?;
489
490 Ok(())
491 }
492}
493
494#[async_trait::async_trait]
495impl MemoryStore for FileStore {
496 async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
497 {
498 let mut memories = self.memories.write().await;
499 memories.push(item);
500 }
501 self.save_to_file().await
502 }
503
504 async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
505 let memories = self.memories.read().await;
506 Ok(memories.iter().find(|m| m.id == id).cloned())
507 }
508
509 async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
510 let memories = self.memories.read().await;
511 Ok(search_memories(&memories, query, limit))
512 }
513
514 async fn search_by_tags(
515 &self,
516 tags: &[String],
517 limit: usize,
518 ) -> anyhow::Result<Vec<MemoryItem>> {
519 let memories = self.memories.read().await;
520 Ok(search_memories_by_tags(&memories, tags, limit))
521 }
522
523 async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
524 let memories = self.memories.read().await;
525 Ok(recent_memories(&memories, limit))
526 }
527
528 async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
529 let memories = self.memories.read().await;
530 Ok(important_memories(&memories, threshold, limit))
531 }
532
533 async fn delete(&self, id: &str) -> anyhow::Result<()> {
534 {
535 let mut memories = self.memories.write().await;
536 memories.retain(|m| m.id != id);
537 }
538 self.save_to_file().await
539 }
540
541 async fn clear(&self) -> anyhow::Result<()> {
542 {
543 let mut memories = self.memories.write().await;
544 memories.clear();
545 }
546 self.save_to_file().await
547 }
548
549 async fn count(&self) -> anyhow::Result<usize> {
550 let memories = self.memories.read().await;
551 Ok(memories.len())
552 }
553}
554
555#[derive(Clone)]
561pub struct AgentMemory {
562 store: Arc<dyn MemoryStore>,
564 short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
566 working: Arc<RwLock<Vec<MemoryItem>>>,
568 max_short_term: usize,
570 max_working: usize,
572 relevance_config: RelevanceConfig,
574}
575
576impl AgentMemory {
577 pub fn new(store: Arc<dyn MemoryStore>) -> Self {
579 Self::with_config(store, MemoryConfig::default())
580 }
581
582 pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
584 Self {
585 store,
586 short_term: Arc::new(RwLock::new(VecDeque::new())),
587 working: Arc::new(RwLock::new(Vec::new())),
588 max_short_term: config.max_short_term,
589 max_working: config.max_working,
590 relevance_config: config.relevance,
591 }
592 }
593
594 pub fn in_memory() -> Self {
596 Self::new(Arc::new(InMemoryStore::new()))
597 }
598
599 fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
601 let age_seconds = (now - item.timestamp).num_seconds() as f32;
602 let age_days = age_seconds / 86400.0;
603 let decay = (-age_days / self.relevance_config.decay_days).exp();
604 item.importance * self.relevance_config.importance_weight
605 + decay * self.relevance_config.recency_weight
606 }
607
608 pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
610 self.store.store(item.clone()).await?;
612
613 let mut short_term = self.short_term.write().await;
615 short_term.push_back(item);
616
617 if short_term.len() > self.max_short_term {
619 short_term.pop_front();
620 }
621
622 Ok(())
623 }
624
625 pub async fn remember_success(
627 &self,
628 prompt: &str,
629 tools_used: &[String],
630 result: &str,
631 ) -> anyhow::Result<()> {
632 let content = format!(
633 "Success: {}\nTools: {}\nResult: {}",
634 prompt,
635 tools_used.join(", "),
636 result
637 );
638
639 let item = MemoryItem::new(content)
640 .with_importance(0.8)
641 .with_tag("success")
642 .with_tag("pattern")
643 .with_type(MemoryType::Procedural)
644 .with_metadata("prompt", prompt)
645 .with_metadata("tools", tools_used.join(","));
646
647 self.remember(item).await
648 }
649
650 pub async fn remember_failure(
652 &self,
653 prompt: &str,
654 error: &str,
655 attempted_tools: &[String],
656 ) -> anyhow::Result<()> {
657 let content = format!(
658 "Failure: {}\nError: {}\nAttempted tools: {}",
659 prompt,
660 error,
661 attempted_tools.join(", ")
662 );
663
664 let item = MemoryItem::new(content)
665 .with_importance(0.9) .with_tag("failure")
667 .with_tag("avoid")
668 .with_type(MemoryType::Episodic)
669 .with_metadata("prompt", prompt)
670 .with_metadata("error", error);
671
672 self.remember(item).await
673 }
674
675 pub async fn recall_similar(
677 &self,
678 prompt: &str,
679 limit: usize,
680 ) -> anyhow::Result<Vec<MemoryItem>> {
681 self.store.search(prompt, limit).await
682 }
683
684 pub async fn recall_by_tags(
686 &self,
687 tags: &[String],
688 limit: usize,
689 ) -> anyhow::Result<Vec<MemoryItem>> {
690 self.store.search_by_tags(tags, limit).await
691 }
692
693 pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
695 self.store.get_recent(limit).await
696 }
697
698 pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
700 let mut working = self.working.write().await;
701 working.push(item);
702
703 if working.len() > self.max_working {
705 let now = Utc::now();
706 working.sort_by(|a, b| {
707 self.score(b, now)
708 .partial_cmp(&self.score(a, now))
709 .unwrap_or(std::cmp::Ordering::Equal)
710 });
711 working.truncate(self.max_working);
712 }
713
714 Ok(())
715 }
716
717 pub async fn get_working(&self) -> Vec<MemoryItem> {
719 self.working.read().await.clone()
720 }
721
722 pub async fn clear_working(&self) {
724 self.working.write().await.clear();
725 }
726
727 pub async fn get_short_term(&self) -> Vec<MemoryItem> {
729 self.short_term.read().await.iter().cloned().collect()
730 }
731
732 pub async fn clear_short_term(&self) {
734 self.short_term.write().await.clear();
735 }
736
737 pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
739 let long_term_count = self.store.count().await?;
740 let short_term_count = self.short_term.read().await.len();
741 let working_count = self.working.read().await.len();
742
743 Ok(MemoryStats {
744 long_term_count,
745 short_term_count,
746 working_count,
747 })
748 }
749
750 pub fn store(&self) -> &Arc<dyn MemoryStore> {
752 &self.store
753 }
754
755 pub async fn working_count(&self) -> usize {
757 self.working.read().await.len()
758 }
759
760 pub async fn short_term_count(&self) -> usize {
762 self.short_term.read().await.len()
763 }
764}
765
/// Item counts per memory tier, as returned by `AgentMemory::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Items in the long-term store (its `MemoryStore::count`).
    pub long_term_count: usize,
    /// Items currently in the short-term FIFO.
    pub short_term_count: usize,
    /// Items currently in the working set.
    pub working_count: usize,
}
776
777pub struct MemoryContextProvider {
786 memory: AgentMemory,
787}
788
789impl MemoryContextProvider {
790 pub fn new(memory: AgentMemory) -> Self {
792 Self { memory }
793 }
794}
795
796#[async_trait::async_trait]
797impl crate::context::ContextProvider for MemoryContextProvider {
798 fn name(&self) -> &str {
799 "memory"
800 }
801
802 async fn query(
803 &self,
804 query: &crate::context::ContextQuery,
805 ) -> anyhow::Result<crate::context::ContextResult> {
806 let limit = query.max_results.min(5);
807 let items = self.memory.recall_similar(&query.query, limit).await?;
808
809 let mut result = crate::context::ContextResult::new("memory");
810 for item in items {
811 let relevance = item.relevance_score();
812 let token_count = item.content.len() / 4; let context_item = crate::context::ContextItem::new(
814 &item.id,
815 crate::context::ContextType::Memory,
816 &item.content,
817 )
818 .with_relevance(relevance)
819 .with_token_count(token_count)
820 .with_source("memory");
821 result.add_item(context_item);
822 }
823
824 Ok(result)
825 }
826
827 async fn on_turn_complete(
828 &self,
829 _session_id: &str,
830 prompt: &str,
831 response: &str,
832 ) -> anyhow::Result<()> {
833 self.memory.remember_success(prompt, &[], response).await
835 }
836}
837
#[cfg(test)]
mod tests {
    use super::*;

    /// Unique `.jsonl` path in the system temp directory that is deleted on drop.
    ///
    /// Cleanup also runs when an assertion fails and unwinds, unlike a trailing
    /// `remove_file` call, so failed runs do not leave files behind.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        fn new(prefix: &str) -> Self {
            let name = format!("{}_{}.jsonl", prefix, uuid::Uuid::new_v4());
            Self(std::env::temp_dir().join(name))
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: some tests never create the file.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_memory_item_creation() {
        let item = MemoryItem::new("Test memory")
            .with_importance(0.8)
            .with_tag("test")
            .with_type(MemoryType::Semantic);

        assert_eq!(item.content, "Test memory");
        assert_eq!(item.importance, 0.8);
        assert_eq!(item.tags, vec!["test"]);
        assert_eq!(item.memory_type, MemoryType::Semantic);
    }

    #[test]
    fn test_memory_item_relevance() {
        let item = MemoryItem::new("Test").with_importance(0.9);
        let score = item.relevance_score();

        assert!(score > 0.6);
    }

    #[test]
    fn test_relevance_config_defaults() {
        let config = RelevanceConfig::default();
        assert_eq!(config.decay_days, 30.0);
        assert_eq!(config.importance_weight, 0.7);
        assert_eq!(config.recency_weight, 0.3);
    }

    #[test]
    fn test_memory_config_defaults() {
        let config = MemoryConfig::default();
        assert_eq!(config.max_short_term, 100);
        assert_eq!(config.max_working, 10);
        assert_eq!(config.relevance.decay_days, 30.0);
    }

    #[test]
    fn test_memory_config_serde_roundtrip() {
        let config = MemoryConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: MemoryConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.max_short_term, config.max_short_term);
        assert_eq!(parsed.max_working, config.max_working);
        assert_eq!(parsed.relevance.decay_days, config.relevance.decay_days);
    }

    #[test]
    fn test_agent_memory_with_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.5,
                recency_weight: 0.5,
            },
            max_short_term: 50,
            max_working: 5,
        };
        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);
        assert_eq!(memory.max_short_term, 50);
        assert_eq!(memory.max_working, 5);
        assert_eq!(memory.relevance_config.decay_days, 7.0);
    }

    #[test]
    fn test_agent_memory_score_uses_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.9,
                recency_weight: 0.1,
            },
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);

        let item = MemoryItem::new("Test").with_importance(1.0);
        let now = Utc::now();
        let score = memory.score(&item, now);

        assert!(score > 0.95, "Score was {}", score);
    }

    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStore::new();

        let item = MemoryItem::new("Test memory").with_tag("test");
        store.store(item.clone()).await.unwrap();

        let retrieved = store.retrieve(&item.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test memory");
    }

    #[tokio::test]
    async fn test_memory_search() {
        let store = InMemoryStore::new();

        store
            .store(MemoryItem::new("How to create a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to delete a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to create a directory").with_tag("dir"))
            .await
            .unwrap();

        let results = store.search("create", 10).await.unwrap();
        assert_eq!(results.len(), 2);

        let results = store
            .search_by_tags(&["file".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_agent_memory() {
        let memory = AgentMemory::in_memory();

        memory
            .remember_success("Create a file", &["write".to_string()], "File created")
            .await
            .unwrap();

        memory
            .remember_failure("Delete file", "Permission denied", &["bash".to_string()])
            .await
            .unwrap();

        let results = memory.recall_similar("create", 10).await.unwrap();
        assert!(!results.is_empty());

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 2);
    }

    #[tokio::test]
    async fn test_working_memory() {
        let memory = AgentMemory::in_memory();

        let item = MemoryItem::new("Active task").with_type(MemoryType::Working);
        memory.add_to_working(item).await.unwrap();

        let working = memory.get_working().await;
        assert_eq!(working.len(), 1);

        memory.clear_working().await;
        let working = memory.get_working().await;
        assert_eq!(working.len(), 0);
    }

    #[tokio::test]
    async fn test_file_store_basic() {
        let test_file = TempFile::new("test_memory");

        let store = FileStore::new(test_file.path()).unwrap();

        let item1 = MemoryItem::new("Test memory 1").with_tag("test");
        let item2 = MemoryItem::new("Test memory 2").with_tag("test");

        store.store(item1.clone()).await.unwrap();
        store.store(item2.clone()).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 2);

        let retrieved = store.retrieve(&item1.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test memory 1");
    }

    #[tokio::test]
    async fn test_file_store_persistence() {
        let test_file = TempFile::new("test_memory_persist");

        let item_id = {
            let store = FileStore::new(test_file.path()).unwrap();
            let item = MemoryItem::new("Persistent memory").with_importance(0.9);
            let id = item.id.clone();
            store.store(item).await.unwrap();
            id
        };

        let store2 = FileStore::new(test_file.path()).unwrap();

        assert_eq!(store2.count().await.unwrap(), 1);
        let retrieved = store2.retrieve(&item_id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Persistent memory");
    }

    #[tokio::test]
    async fn test_file_store_search() {
        let test_file = TempFile::new("test_memory_search");

        let store = FileStore::new(test_file.path()).unwrap();

        store
            .store(MemoryItem::new("How to create a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to delete a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to create a directory").with_tag("dir"))
            .await
            .unwrap();

        let results = store.search("create", 10).await.unwrap();
        assert_eq!(results.len(), 2);

        let results = store
            .search_by_tags(&["file".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_file_store_delete() {
        let test_file = TempFile::new("test_memory_delete");

        let store = FileStore::new(test_file.path()).unwrap();

        let item = MemoryItem::new("To be deleted");
        let item_id = item.id.clone();
        store.store(item).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);

        store.delete(&item_id).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);

        let store2 = FileStore::new(test_file.path()).unwrap();
        assert_eq!(store2.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_file_store_clear() {
        let test_file = TempFile::new("test_memory_clear");

        let store = FileStore::new(test_file.path()).unwrap();

        for i in 0..5 {
            store
                .store(MemoryItem::new(format!("Memory {}", i)))
                .await
                .unwrap();
        }

        assert_eq!(store.count().await.unwrap(), 5);

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);

        let store2 = FileStore::new(test_file.path()).unwrap();
        assert_eq!(store2.count().await.unwrap(), 0);
    }
}
1157
1158#[cfg(test)]
1159mod extra_memory_tests {
1160 use super::*;
1161
1162 #[test]
1167 fn test_memory_item_with_metadata() {
1168 let item = MemoryItem::new("test")
1169 .with_metadata("key1", "value1")
1170 .with_metadata("key2", "value2");
1171 assert_eq!(item.metadata.get("key1").unwrap(), "value1");
1172 assert_eq!(item.metadata.get("key2").unwrap(), "value2");
1173 }
1174
1175 #[test]
1176 fn test_memory_item_with_tags_vec() {
1177 let item = MemoryItem::new("test").with_tags(vec![
1178 "a".to_string(),
1179 "b".to_string(),
1180 "c".to_string(),
1181 ]);
1182 assert_eq!(item.tags.len(), 3);
1183 }
1184
1185 #[test]
1186 fn test_memory_item_importance_clamped() {
1187 let item_high = MemoryItem::new("test").with_importance(1.5);
1188 assert_eq!(item_high.importance, 1.0);
1189
1190 let item_low = MemoryItem::new("test").with_importance(-0.5);
1191 assert_eq!(item_low.importance, 0.0);
1192 }
1193
1194 #[test]
1195 fn test_memory_item_record_access() {
1196 let mut item = MemoryItem::new("test");
1197 assert_eq!(item.access_count, 0);
1198 assert!(item.last_accessed.is_none());
1199
1200 item.record_access();
1201 assert_eq!(item.access_count, 1);
1202 assert!(item.last_accessed.is_some());
1203
1204 item.record_access();
1205 assert_eq!(item.access_count, 2);
1206 }
1207
1208 #[test]
1209 fn test_memory_item_all_types() {
1210 let episodic = MemoryItem::new("e").with_type(MemoryType::Episodic);
1211 assert_eq!(episodic.memory_type, MemoryType::Episodic);
1212
1213 let semantic = MemoryItem::new("s").with_type(MemoryType::Semantic);
1214 assert_eq!(semantic.memory_type, MemoryType::Semantic);
1215
1216 let procedural = MemoryItem::new("p").with_type(MemoryType::Procedural);
1217 assert_eq!(procedural.memory_type, MemoryType::Procedural);
1218
1219 let working = MemoryItem::new("w").with_type(MemoryType::Working);
1220 assert_eq!(working.memory_type, MemoryType::Working);
1221 }
1222
1223 #[test]
1224 fn test_memory_item_default_type_is_episodic() {
1225 let item = MemoryItem::new("test");
1226 assert_eq!(item.memory_type, MemoryType::Episodic);
1227 }
1228
1229 #[tokio::test]
1234 async fn test_in_memory_store_retrieve_nonexistent() {
1235 let store = InMemoryStore::new();
1236 let result = store.retrieve("nonexistent").await.unwrap();
1237 assert!(result.is_none());
1238 }
1239
1240 #[tokio::test]
1241 async fn test_in_memory_store_delete() {
1242 let store = InMemoryStore::new();
1243 let item = MemoryItem::new("to delete");
1244 let id = item.id.clone();
1245 store.store(item).await.unwrap();
1246 assert_eq!(store.count().await.unwrap(), 1);
1247
1248 store.delete(&id).await.unwrap();
1249 assert_eq!(store.count().await.unwrap(), 0);
1250 }
1251
1252 #[tokio::test]
1253 async fn test_in_memory_store_clear() {
1254 let store = InMemoryStore::new();
1255 for i in 0..5 {
1256 store
1257 .store(MemoryItem::new(format!("item {}", i)))
1258 .await
1259 .unwrap();
1260 }
1261 assert_eq!(store.count().await.unwrap(), 5);
1262
1263 store.clear().await.unwrap();
1264 assert_eq!(store.count().await.unwrap(), 0);
1265 }
1266
1267 #[tokio::test]
1268 async fn test_in_memory_store_get_recent() {
1269 let store = InMemoryStore::new();
1270 for i in 0..5 {
1271 store
1272 .store(MemoryItem::new(format!("item {}", i)))
1273 .await
1274 .unwrap();
1275 }
1276 let recent = store.get_recent(3).await.unwrap();
1277 assert_eq!(recent.len(), 3);
1278 }
1279
1280 #[tokio::test]
1281 async fn test_in_memory_store_get_important() {
1282 let store = InMemoryStore::new();
1283 store
1284 .store(MemoryItem::new("low").with_importance(0.2))
1285 .await
1286 .unwrap();
1287 store
1288 .store(MemoryItem::new("medium").with_importance(0.5))
1289 .await
1290 .unwrap();
1291 store
1292 .store(MemoryItem::new("high").with_importance(0.9))
1293 .await
1294 .unwrap();
1295
1296 let important = store.get_important(0.7, 10).await.unwrap();
1297 assert_eq!(important.len(), 1);
1298 assert_eq!(important[0].content, "high");
1299 }
1300
1301 #[tokio::test]
1302 async fn test_in_memory_store_search_case_insensitive() {
1303 let store = InMemoryStore::new();
1304 store
1305 .store(MemoryItem::new("How to CREATE a file"))
1306 .await
1307 .unwrap();
1308 let results = store.search("create", 10).await.unwrap();
1309 assert_eq!(results.len(), 1);
1310 }
1311
1312 #[tokio::test]
1317 async fn test_agent_memory_short_term() {
1318 let memory = AgentMemory::in_memory();
1319 memory.remember(MemoryItem::new("item 1")).await.unwrap();
1320 memory.remember(MemoryItem::new("item 2")).await.unwrap();
1321
1322 let short_term = memory.get_short_term().await;
1323 assert_eq!(short_term.len(), 2);
1324
1325 memory.clear_short_term().await;
1326 let short_term = memory.get_short_term().await;
1327 assert_eq!(short_term.len(), 0);
1328 }
1329
1330 #[tokio::test]
1331 async fn test_agent_memory_short_term_count() {
1332 let memory = AgentMemory::in_memory();
1333 assert_eq!(memory.short_term_count().await, 0);
1334 memory.remember(MemoryItem::new("item")).await.unwrap();
1335 assert_eq!(memory.short_term_count().await, 1);
1336 }
1337
1338 #[tokio::test]
1339 async fn test_agent_memory_working_count() {
1340 let memory = AgentMemory::in_memory();
1341 assert_eq!(memory.working_count().await, 0);
1342 memory
1343 .add_to_working(MemoryItem::new("task"))
1344 .await
1345 .unwrap();
1346 assert_eq!(memory.working_count().await, 1);
1347 }
1348
1349 #[tokio::test]
1350 async fn test_agent_memory_recall_by_tags() {
1351 let memory = AgentMemory::in_memory();
1352 memory
1353 .remember_success("create file", &["write".to_string()], "ok")
1354 .await
1355 .unwrap();
1356 memory
1357 .remember_failure("delete file", "denied", &["bash".to_string()])
1358 .await
1359 .unwrap();
1360
1361 let successes = memory
1362 .recall_by_tags(&["success".to_string()], 10)
1363 .await
1364 .unwrap();
1365 assert_eq!(successes.len(), 1);
1366
1367 let failures = memory
1368 .recall_by_tags(&["failure".to_string()], 10)
1369 .await
1370 .unwrap();
1371 assert_eq!(failures.len(), 1);
1372 }
1373
1374 #[tokio::test]
1375 async fn test_agent_memory_get_recent() {
1376 let memory = AgentMemory::in_memory();
1377 for i in 0..5 {
1378 memory
1379 .remember(MemoryItem::new(format!("item {}", i)))
1380 .await
1381 .unwrap();
1382 }
1383 let recent = memory.get_recent(3).await.unwrap();
1384 assert_eq!(recent.len(), 3);
1385 }
1386
1387 #[tokio::test]
1388 async fn test_agent_memory_store_accessor() {
1389 let memory = AgentMemory::in_memory();
1390 memory.remember(MemoryItem::new("test")).await.unwrap();
1391 let count = memory.store().count().await.unwrap();
1392 assert_eq!(count, 1);
1393 }
1394
1395 #[tokio::test]
1396 async fn test_agent_memory_stats_all_fields() {
1397 let memory = AgentMemory::in_memory();
1398 memory.remember(MemoryItem::new("long term")).await.unwrap();
1399 memory
1400 .add_to_working(MemoryItem::new("working"))
1401 .await
1402 .unwrap();
1403
1404 let stats = memory.stats().await.unwrap();
1405 assert_eq!(stats.long_term_count, 1);
1406 assert_eq!(stats.short_term_count, 1); assert_eq!(stats.working_count, 1);
1408 }
1409
1410 #[tokio::test]
1411 async fn test_agent_memory_working_overflow_trims() {
1412 let store = Arc::new(InMemoryStore::new());
1413 let memory = AgentMemory {
1414 store,
1415 short_term: Arc::new(RwLock::new(VecDeque::new())),
1416 working: Arc::new(RwLock::new(Vec::new())),
1417 max_short_term: 100,
1418 max_working: 3, relevance_config: RelevanceConfig::default(),
1420 };
1421
1422 for i in 0..5 {
1423 memory
1424 .add_to_working(
1425 MemoryItem::new(format!("task {}", i)).with_importance(i as f32 * 0.2),
1426 )
1427 .await
1428 .unwrap();
1429 }
1430
1431 let working = memory.get_working().await;
1432 assert_eq!(working.len(), 3); }
1434}
1435
#[cfg(test)]
mod extra_memory_tests2 {
    use super::*;

    /// Opening a store below missing directories succeeds, creates those
    /// directories, and yields an empty store.
    #[tokio::test]
    async fn test_file_store_open_creates_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir
            .path()
            .join("nested")
            .join("deep")
            .join("memories.jsonl");
        let store = FileStore::open(&path).await.unwrap();

        // The test name promises the directories exist after `open`; check that directly.
        assert!(dir.path().join("nested").join("deep").is_dir());

        let all = store.search("", 100).await.unwrap();
        assert!(all.is_empty());
    }

    /// Items written through one `FileStore` handle are visible after reopening the path.
    #[tokio::test]
    async fn test_file_store_open_loads_existing() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memories.jsonl");
        {
            // Scoped so the first handle is dropped before the file is reopened.
            let store = FileStore::open(&path).await.unwrap();
            store.store(MemoryItem::new("test memory")).await.unwrap();
        }

        let reopened = FileStore::open(&path).await.unwrap();
        let results = reopened.search("test", 10).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].content, "test memory");
    }

    /// Opening a path that does not exist yet yields an empty store, not an error.
    #[tokio::test]
    async fn test_file_store_open_nonexistent_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nonexistent.jsonl");
        let store = FileStore::open(&path).await.unwrap();
        let all = store.search("", 100).await.unwrap();
        assert!(all.is_empty());
    }

    /// Empty input parses to no records.
    #[test]
    fn test_parse_jsonl_empty_string() {
        let result = FileStore::parse_jsonl("").unwrap();
        assert!(result.is_empty());
    }

    /// Blank lines (leading, between records, and trailing) are skipped, and the
    /// surviving records round-trip intact.
    #[test]
    fn test_parse_jsonl_empty_lines_skipped() {
        let item = MemoryItem::new("hello");
        let json = serde_json::to_string(&item).unwrap();
        let content = format!("\n{}\n\n{}\n\n", json, json);
        let result = FileStore::parse_jsonl(&content).unwrap();
        assert_eq!(result.len(), 2);
        assert!(result
            .iter()
            .all(|parsed| parsed.content == "hello" && parsed.id == item.id));
    }

    /// Malformed input is reported as an error rather than silently dropped.
    #[test]
    fn test_parse_jsonl_invalid_json_returns_error() {
        let result = FileStore::parse_jsonl("not valid json");
        assert!(result.is_err());
    }

    /// A single record with no trailing newline still parses.
    #[test]
    fn test_parse_jsonl_valid_single_line() {
        let item = MemoryItem::new("single");
        let json = serde_json::to_string(&item).unwrap();
        let result = FileStore::parse_jsonl(&json).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].content, "single");
    }
}