1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(rename_all = "camelCase")]
19pub struct RelevanceConfig {
20 #[serde(default = "RelevanceConfig::default_decay_days")]
22 pub decay_days: f32,
23 #[serde(default = "RelevanceConfig::default_importance_weight")]
25 pub importance_weight: f32,
26 #[serde(default = "RelevanceConfig::default_recency_weight")]
28 pub recency_weight: f32,
29}
30
31impl RelevanceConfig {
32 fn default_decay_days() -> f32 {
33 30.0
34 }
35 fn default_importance_weight() -> f32 {
36 0.7
37 }
38 fn default_recency_weight() -> f32 {
39 0.3
40 }
41}
42
43impl Default for RelevanceConfig {
44 fn default() -> Self {
45 Self {
46 decay_days: 30.0,
47 importance_weight: 0.7,
48 recency_weight: 0.3,
49 }
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(rename_all = "camelCase")]
56pub struct MemoryConfig {
57 #[serde(default)]
59 pub relevance: RelevanceConfig,
60 #[serde(default = "MemoryConfig::default_max_short_term")]
62 pub max_short_term: usize,
63 #[serde(default = "MemoryConfig::default_max_working")]
65 pub max_working: usize,
66}
67
68impl MemoryConfig {
69 fn default_max_short_term() -> usize {
70 100
71 }
72 fn default_max_working() -> usize {
73 10
74 }
75}
76
77impl Default for MemoryConfig {
78 fn default() -> Self {
79 Self {
80 relevance: RelevanceConfig::default(),
81 max_short_term: 100,
82 max_working: 10,
83 }
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct MemoryItem {
94 pub id: String,
96 pub content: String,
98 pub timestamp: DateTime<Utc>,
100 pub importance: f32,
102 pub tags: Vec<String>,
104 pub memory_type: MemoryType,
106 pub metadata: HashMap<String, String>,
108 pub access_count: u32,
110 pub last_accessed: Option<DateTime<Utc>>,
112 #[serde(skip)]
114 pub content_lower: String,
115}
116
117impl MemoryItem {
118 pub fn new(content: impl Into<String>) -> Self {
120 let content = content.into();
121 let content_lower = content.to_lowercase();
122 Self {
123 id: uuid::Uuid::new_v4().to_string(),
124 content,
125 timestamp: Utc::now(),
126 importance: 0.5,
127 tags: Vec::new(),
128 memory_type: MemoryType::Episodic,
129 metadata: HashMap::new(),
130 access_count: 0,
131 last_accessed: None,
132 content_lower,
133 }
134 }
135
136 pub fn with_importance(mut self, importance: f32) -> Self {
138 self.importance = importance.clamp(0.0, 1.0);
139 self
140 }
141
142 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
144 self.tags = tags;
145 self
146 }
147
148 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
150 self.tags.push(tag.into());
151 self
152 }
153
154 pub fn with_type(mut self, memory_type: MemoryType) -> Self {
156 self.memory_type = memory_type;
157 self
158 }
159
160 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
162 self.metadata.insert(key.into(), value.into());
163 self
164 }
165
166 pub fn record_access(&mut self) {
168 self.access_count += 1;
169 self.last_accessed = Some(Utc::now());
170 }
171
172 pub fn relevance_score_at(&self, now: DateTime<Utc>) -> f32 {
176 let age_seconds = (now - self.timestamp).num_seconds() as f32;
177 let age_days = age_seconds / 86400.0;
178
179 let decay = (-age_days / 30.0).exp(); self.importance * 0.7 + decay * 0.3
184 }
185
186 pub fn relevance_score(&self) -> f32 {
188 self.relevance_score_at(Utc::now())
189 }
190}
191
192#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
194#[serde(rename_all = "snake_case")]
195pub enum MemoryType {
196 Episodic,
198 Semantic,
200 Procedural,
202 Working,
204}
205
206#[async_trait::async_trait]
212pub trait MemoryStore: Send + Sync {
213 async fn store(&self, item: MemoryItem) -> anyhow::Result<()>;
215
216 async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>>;
218
219 async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
221
222 async fn search_by_tags(
224 &self,
225 tags: &[String],
226 limit: usize,
227 ) -> anyhow::Result<Vec<MemoryItem>>;
228
229 async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
231
232 async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>>;
234
235 async fn delete(&self, id: &str) -> anyhow::Result<()>;
237
238 async fn clear(&self) -> anyhow::Result<()>;
240
241 async fn count(&self) -> anyhow::Result<usize>;
243}
244
245fn search_memories(memories: &[MemoryItem], query: &str, limit: usize) -> Vec<MemoryItem> {
251 let query_lower = query.to_lowercase();
252 let mut results: Vec<_> = memories
253 .iter()
254 .filter(|m| m.content_lower.contains(&query_lower))
255 .cloned()
256 .collect();
257 sort_by_relevance(&mut results);
258 results.truncate(limit);
259 results
260}
261
262fn search_memories_by_tags(
264 memories: &[MemoryItem],
265 tags: &[String],
266 limit: usize,
267) -> Vec<MemoryItem> {
268 let mut results: Vec<_> = memories
269 .iter()
270 .filter(|m| tags.iter().any(|tag| m.tags.contains(tag)))
271 .cloned()
272 .collect();
273 sort_by_relevance(&mut results);
274 results.truncate(limit);
275 results
276}
277
278fn recent_memories(memories: &[MemoryItem], limit: usize) -> Vec<MemoryItem> {
280 let mut results: Vec<_> = memories.to_vec();
281 results.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
282 results.truncate(limit);
283 results
284}
285
286fn important_memories(memories: &[MemoryItem], threshold: f32, limit: usize) -> Vec<MemoryItem> {
288 let mut results: Vec<_> = memories
289 .iter()
290 .filter(|m| m.importance >= threshold)
291 .cloned()
292 .collect();
293 results.sort_by(|a, b| {
294 b.importance
295 .partial_cmp(&a.importance)
296 .unwrap_or(std::cmp::Ordering::Equal)
297 });
298 results.truncate(limit);
299 results
300}
301
302fn sort_by_relevance(items: &mut [MemoryItem]) {
304 let now = Utc::now();
305 items.sort_by(|a, b| {
306 b.relevance_score_at(now)
307 .partial_cmp(&a.relevance_score_at(now))
308 .unwrap_or(std::cmp::Ordering::Equal)
309 });
310}
311
312#[derive(Debug, Clone)]
318pub struct InMemoryStore {
319 memories: Arc<RwLock<Vec<MemoryItem>>>,
320}
321
322impl InMemoryStore {
323 pub fn new() -> Self {
325 Self {
326 memories: Arc::new(RwLock::new(Vec::new())),
327 }
328 }
329}
330
331impl Default for InMemoryStore {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
337#[async_trait::async_trait]
338impl MemoryStore for InMemoryStore {
339 async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
340 let mut memories = self.memories.write().await;
341 memories.push(item);
342 Ok(())
343 }
344
345 async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
346 let memories = self.memories.read().await;
347 Ok(memories.iter().find(|m| m.id == id).cloned())
348 }
349
350 async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
351 let memories = self.memories.read().await;
352 Ok(search_memories(&memories, query, limit))
353 }
354
355 async fn search_by_tags(
356 &self,
357 tags: &[String],
358 limit: usize,
359 ) -> anyhow::Result<Vec<MemoryItem>> {
360 let memories = self.memories.read().await;
361 Ok(search_memories_by_tags(&memories, tags, limit))
362 }
363
364 async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
365 let memories = self.memories.read().await;
366 Ok(recent_memories(&memories, limit))
367 }
368
369 async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
370 let memories = self.memories.read().await;
371 Ok(important_memories(&memories, threshold, limit))
372 }
373
374 async fn delete(&self, id: &str) -> anyhow::Result<()> {
375 let mut memories = self.memories.write().await;
376 memories.retain(|m| m.id != id);
377 Ok(())
378 }
379
380 async fn clear(&self) -> anyhow::Result<()> {
381 let mut memories = self.memories.write().await;
382 memories.clear();
383 Ok(())
384 }
385
386 async fn count(&self) -> anyhow::Result<usize> {
387 let memories = self.memories.read().await;
388 Ok(memories.len())
389 }
390}
391
392#[derive(Debug, Clone)]
398pub struct FileStore {
399 file_path: std::path::PathBuf,
400 memories: Arc<RwLock<Vec<MemoryItem>>>,
401}
402
403impl FileStore {
404 pub fn new(file_path: impl Into<std::path::PathBuf>) -> anyhow::Result<Self> {
409 let file_path = file_path.into();
410
411 if let Some(parent) = file_path.parent() {
413 std::fs::create_dir_all(parent)?;
414 }
415
416 let memories = if file_path.exists() {
418 Self::load_from_file(&file_path)?
419 } else {
420 Vec::new()
421 };
422
423 Ok(Self {
424 file_path,
425 memories: Arc::new(RwLock::new(memories)),
426 })
427 }
428
429 pub async fn open(file_path: impl Into<std::path::PathBuf>) -> anyhow::Result<Self> {
431 let file_path = file_path.into();
432
433 if let Some(parent) = file_path.parent() {
435 tokio::fs::create_dir_all(parent).await?;
436 }
437
438 let memories = if file_path.exists() {
440 let content = tokio::fs::read_to_string(&file_path).await?;
441 Self::parse_jsonl(&content)?
442 } else {
443 Vec::new()
444 };
445
446 Ok(Self {
447 file_path,
448 memories: Arc::new(RwLock::new(memories)),
449 })
450 }
451
452 fn load_from_file(path: &std::path::Path) -> anyhow::Result<Vec<MemoryItem>> {
454 let content = std::fs::read_to_string(path)?;
455 Self::parse_jsonl(&content)
456 }
457
458 fn parse_jsonl(content: &str) -> anyhow::Result<Vec<MemoryItem>> {
460 let mut memories = Vec::new();
461
462 for line in content.lines() {
463 if line.trim().is_empty() {
464 continue;
465 }
466 let mut item: MemoryItem = serde_json::from_str(line)?;
467 item.content_lower = item.content.to_lowercase();
468 memories.push(item);
469 }
470
471 Ok(memories)
472 }
473
474 async fn save_to_file(&self) -> anyhow::Result<()> {
476 let memories = self.memories.read().await;
477 let mut content = String::new();
478
479 for memory in memories.iter() {
480 let json = serde_json::to_string(memory)?;
481 content.push_str(&json);
482 content.push('\n');
483 }
484
485 let temp_path = self.file_path.with_extension("tmp");
487 tokio::fs::write(&temp_path, content).await?;
488 tokio::fs::rename(&temp_path, &self.file_path).await?;
489
490 Ok(())
491 }
492}
493
494#[async_trait::async_trait]
495impl MemoryStore for FileStore {
496 async fn store(&self, item: MemoryItem) -> anyhow::Result<()> {
497 {
498 let mut memories = self.memories.write().await;
499 memories.push(item);
500 }
501 self.save_to_file().await
502 }
503
504 async fn retrieve(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
505 let memories = self.memories.read().await;
506 Ok(memories.iter().find(|m| m.id == id).cloned())
507 }
508
509 async fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
510 let memories = self.memories.read().await;
511 Ok(search_memories(&memories, query, limit))
512 }
513
514 async fn search_by_tags(
515 &self,
516 tags: &[String],
517 limit: usize,
518 ) -> anyhow::Result<Vec<MemoryItem>> {
519 let memories = self.memories.read().await;
520 Ok(search_memories_by_tags(&memories, tags, limit))
521 }
522
523 async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
524 let memories = self.memories.read().await;
525 Ok(recent_memories(&memories, limit))
526 }
527
528 async fn get_important(&self, threshold: f32, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
529 let memories = self.memories.read().await;
530 Ok(important_memories(&memories, threshold, limit))
531 }
532
533 async fn delete(&self, id: &str) -> anyhow::Result<()> {
534 {
535 let mut memories = self.memories.write().await;
536 memories.retain(|m| m.id != id);
537 }
538 self.save_to_file().await
539 }
540
541 async fn clear(&self) -> anyhow::Result<()> {
542 {
543 let mut memories = self.memories.write().await;
544 memories.clear();
545 }
546 self.save_to_file().await
547 }
548
549 async fn count(&self) -> anyhow::Result<usize> {
550 let memories = self.memories.read().await;
551 Ok(memories.len())
552 }
553}
554
555#[derive(Clone)]
561pub struct AgentMemory {
562 store: Arc<dyn MemoryStore>,
564 short_term: Arc<RwLock<VecDeque<MemoryItem>>>,
566 working: Arc<RwLock<Vec<MemoryItem>>>,
568 max_short_term: usize,
570 max_working: usize,
572 relevance_config: RelevanceConfig,
574}
575
576impl AgentMemory {
577 pub fn new(store: Arc<dyn MemoryStore>) -> Self {
579 Self::with_config(store, MemoryConfig::default())
580 }
581
582 pub fn with_config(store: Arc<dyn MemoryStore>, config: MemoryConfig) -> Self {
584 Self {
585 store,
586 short_term: Arc::new(RwLock::new(VecDeque::new())),
587 working: Arc::new(RwLock::new(Vec::new())),
588 max_short_term: config.max_short_term,
589 max_working: config.max_working,
590 relevance_config: config.relevance,
591 }
592 }
593
594 pub fn in_memory() -> Self {
596 Self::new(Arc::new(InMemoryStore::new()))
597 }
598
599 fn score(&self, item: &MemoryItem, now: DateTime<Utc>) -> f32 {
601 let age_seconds = (now - item.timestamp).num_seconds() as f32;
602 let age_days = age_seconds / 86400.0;
603 let decay = (-age_days / self.relevance_config.decay_days).exp();
604 item.importance * self.relevance_config.importance_weight
605 + decay * self.relevance_config.recency_weight
606 }
607
608 pub async fn remember(&self, item: MemoryItem) -> anyhow::Result<()> {
610 self.store.store(item.clone()).await?;
612
613 let mut short_term = self.short_term.write().await;
615 short_term.push_back(item);
616
617 if short_term.len() > self.max_short_term {
619 short_term.pop_front();
620 }
621
622 Ok(())
623 }
624
625 pub async fn remember_success(
627 &self,
628 prompt: &str,
629 tools_used: &[String],
630 result: &str,
631 ) -> anyhow::Result<()> {
632 let content = format!(
633 "Success: {}\nTools: {}\nResult: {}",
634 prompt,
635 tools_used.join(", "),
636 result
637 );
638
639 let item = MemoryItem::new(content)
640 .with_importance(0.8)
641 .with_tag("success")
642 .with_tag("pattern")
643 .with_type(MemoryType::Procedural)
644 .with_metadata("prompt", prompt)
645 .with_metadata("tools", tools_used.join(","));
646
647 self.remember(item).await
648 }
649
650 pub async fn remember_failure(
652 &self,
653 prompt: &str,
654 error: &str,
655 attempted_tools: &[String],
656 ) -> anyhow::Result<()> {
657 let content = format!(
658 "Failure: {}\nError: {}\nAttempted tools: {}",
659 prompt,
660 error,
661 attempted_tools.join(", ")
662 );
663
664 let item = MemoryItem::new(content)
665 .with_importance(0.9) .with_tag("failure")
667 .with_tag("avoid")
668 .with_type(MemoryType::Episodic)
669 .with_metadata("prompt", prompt)
670 .with_metadata("error", error);
671
672 self.remember(item).await
673 }
674
675 pub async fn recall_similar(
677 &self,
678 prompt: &str,
679 limit: usize,
680 ) -> anyhow::Result<Vec<MemoryItem>> {
681 self.store.search(prompt, limit).await
682 }
683
684 pub async fn recall_by_tags(
686 &self,
687 tags: &[String],
688 limit: usize,
689 ) -> anyhow::Result<Vec<MemoryItem>> {
690 self.store.search_by_tags(tags, limit).await
691 }
692
693 pub async fn get_recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryItem>> {
695 self.store.get_recent(limit).await
696 }
697
698 pub async fn add_to_working(&self, item: MemoryItem) -> anyhow::Result<()> {
700 let mut working = self.working.write().await;
701 working.push(item);
702
703 if working.len() > self.max_working {
705 let now = Utc::now();
706 working.sort_by(|a, b| {
707 self.score(b, now)
708 .partial_cmp(&self.score(a, now))
709 .unwrap_or(std::cmp::Ordering::Equal)
710 });
711 working.truncate(self.max_working);
712 }
713
714 Ok(())
715 }
716
717 pub async fn get_working(&self) -> Vec<MemoryItem> {
719 self.working.read().await.clone()
720 }
721
722 pub async fn clear_working(&self) {
724 self.working.write().await.clear();
725 }
726
727 pub async fn get_short_term(&self) -> Vec<MemoryItem> {
729 self.short_term.read().await.iter().cloned().collect()
730 }
731
732 pub async fn clear_short_term(&self) {
734 self.short_term.write().await.clear();
735 }
736
737 pub async fn stats(&self) -> anyhow::Result<MemoryStats> {
739 let long_term_count = self.store.count().await?;
740 let short_term_count = self.short_term.read().await.len();
741 let working_count = self.working.read().await.len();
742
743 Ok(MemoryStats {
744 long_term_count,
745 short_term_count,
746 working_count,
747 })
748 }
749
750 pub fn store(&self) -> &Arc<dyn MemoryStore> {
752 &self.store
753 }
754
755 pub async fn working_count(&self) -> usize {
757 self.working.read().await.len()
758 }
759
760 pub async fn short_term_count(&self) -> usize {
762 self.short_term.read().await.len()
763 }
764}
765
766#[derive(Debug, Clone, Serialize, Deserialize)]
768pub struct MemoryStats {
769 pub long_term_count: usize,
771 pub short_term_count: usize,
773 pub working_count: usize,
775}
776
777pub struct MemoryContextProvider {
786 memory: AgentMemory,
787}
788
789impl MemoryContextProvider {
790 pub fn new(memory: AgentMemory) -> Self {
792 Self { memory }
793 }
794}
795
796#[async_trait::async_trait]
797impl crate::context::ContextProvider for MemoryContextProvider {
798 fn name(&self) -> &str {
799 "memory"
800 }
801
802 async fn query(
803 &self,
804 query: &crate::context::ContextQuery,
805 ) -> anyhow::Result<crate::context::ContextResult> {
806 let limit = query.max_results.min(5);
807 let items = self.memory.recall_similar(&query.query, limit).await?;
808
809 let mut result = crate::context::ContextResult::new("memory");
810 for item in items {
811 let relevance = item.relevance_score();
812 let token_count = item.content.len() / 4; let context_item = crate::context::ContextItem::new(
814 &item.id,
815 crate::context::ContextType::Memory,
816 &item.content,
817 )
818 .with_relevance(relevance)
819 .with_token_count(token_count)
820 .with_source("memory");
821 result.add_item(context_item);
822 }
823
824 Ok(result)
825 }
826
827 async fn on_turn_complete(
828 &self,
829 _session_id: &str,
830 prompt: &str,
831 response: &str,
832 ) -> anyhow::Result<()> {
833 self.memory.remember_success(prompt, &[], response).await
835 }
836}
837
#[cfg(test)]
mod tests {
    use super::*;

    /// Unique JSONL path in the system temp dir that is deleted on drop.
    ///
    /// File-store tests therefore clean up even when an assertion fails mid-test.
    /// Manual cleanup at the end of a test is skipped when the test panics.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        fn new(prefix: &str) -> Self {
            let name = format!("{}_{}.jsonl", prefix, uuid::Uuid::new_v4());
            Self(std::env::temp_dir().join(name))
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_memory_item_creation() {
        let item = MemoryItem::new("Test memory")
            .with_importance(0.8)
            .with_tag("test")
            .with_type(MemoryType::Semantic);

        assert_eq!(item.content, "Test memory");
        assert_eq!(item.importance, 0.8);
        assert_eq!(item.tags, vec!["test"]);
        assert_eq!(item.memory_type, MemoryType::Semantic);
    }

    #[test]
    fn test_memory_item_relevance() {
        let item = MemoryItem::new("Test").with_importance(0.9);
        let score = item.relevance_score();

        assert!(score > 0.6);
    }

    #[test]
    fn test_relevance_config_defaults() {
        let config = RelevanceConfig::default();
        assert_eq!(config.decay_days, 30.0);
        assert_eq!(config.importance_weight, 0.7);
        assert_eq!(config.recency_weight, 0.3);
    }

    #[test]
    fn test_memory_config_defaults() {
        let config = MemoryConfig::default();
        assert_eq!(config.max_short_term, 100);
        assert_eq!(config.max_working, 10);
        assert_eq!(config.relevance.decay_days, 30.0);
    }

    #[test]
    fn test_memory_config_serde_roundtrip() {
        let config = MemoryConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: MemoryConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.max_short_term, config.max_short_term);
        assert_eq!(parsed.max_working, config.max_working);
        assert_eq!(parsed.relevance.decay_days, config.relevance.decay_days);
    }

    #[test]
    fn test_agent_memory_with_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.5,
                recency_weight: 0.5,
            },
            max_short_term: 50,
            max_working: 5,
        };
        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);
        assert_eq!(memory.max_short_term, 50);
        assert_eq!(memory.max_working, 5);
        assert_eq!(memory.relevance_config.decay_days, 7.0);
    }

    #[test]
    fn test_agent_memory_score_uses_config() {
        let config = MemoryConfig {
            relevance: RelevanceConfig {
                decay_days: 7.0,
                importance_weight: 0.9,
                recency_weight: 0.1,
            },
            ..Default::default()
        };
        let memory = AgentMemory::with_config(Arc::new(InMemoryStore::new()), config);

        let item = MemoryItem::new("Test").with_importance(1.0);
        let now = Utc::now();
        let score = memory.score(&item, now);

        assert!(score > 0.95, "Score was {}", score);
    }

    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStore::new();

        let item = MemoryItem::new("Test memory").with_tag("test");
        store.store(item.clone()).await.unwrap();

        let retrieved = store.retrieve(&item.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test memory");
    }

    #[tokio::test]
    async fn test_memory_search() {
        let store = InMemoryStore::new();

        store
            .store(MemoryItem::new("How to create a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to delete a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to create a directory").with_tag("dir"))
            .await
            .unwrap();

        let results = store.search("create", 10).await.unwrap();
        assert_eq!(results.len(), 2);

        let results = store
            .search_by_tags(&["file".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_agent_memory() {
        let memory = AgentMemory::in_memory();

        memory
            .remember_success("Create a file", &["write".to_string()], "File created")
            .await
            .unwrap();

        memory
            .remember_failure("Delete file", "Permission denied", &["bash".to_string()])
            .await
            .unwrap();

        let results = memory.recall_similar("create", 10).await.unwrap();
        assert!(!results.is_empty());

        let stats = memory.stats().await.unwrap();
        assert_eq!(stats.long_term_count, 2);
    }

    #[tokio::test]
    async fn test_working_memory() {
        let memory = AgentMemory::in_memory();

        let item = MemoryItem::new("Active task").with_type(MemoryType::Working);
        memory.add_to_working(item).await.unwrap();

        let working = memory.get_working().await;
        assert_eq!(working.len(), 1);

        memory.clear_working().await;
        let working = memory.get_working().await;
        assert_eq!(working.len(), 0);
    }

    #[tokio::test]
    async fn test_file_store_basic() {
        let tmp = TempFile::new("test_memory");
        let store = FileStore::new(tmp.path()).unwrap();

        let item1 = MemoryItem::new("Test memory 1").with_tag("test");
        let item2 = MemoryItem::new("Test memory 2").with_tag("test");

        store.store(item1.clone()).await.unwrap();
        store.store(item2.clone()).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 2);

        let retrieved = store.retrieve(&item1.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test memory 1");
    }

    #[tokio::test]
    async fn test_file_store_persistence() {
        let tmp = TempFile::new("test_memory_persist");

        let item_id = {
            let store = FileStore::new(tmp.path()).unwrap();
            let item = MemoryItem::new("Persistent memory").with_importance(0.9);
            let id = item.id.clone();
            store.store(item).await.unwrap();
            id
        };

        // A fresh store must see what the first one persisted.
        let store2 = FileStore::new(tmp.path()).unwrap();

        assert_eq!(store2.count().await.unwrap(), 1);
        let retrieved = store2.retrieve(&item_id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Persistent memory");
    }

    #[tokio::test]
    async fn test_file_store_search() {
        let tmp = TempFile::new("test_memory_search");
        let store = FileStore::new(tmp.path()).unwrap();

        store
            .store(MemoryItem::new("How to create a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to delete a file").with_tag("file"))
            .await
            .unwrap();
        store
            .store(MemoryItem::new("How to create a directory").with_tag("dir"))
            .await
            .unwrap();

        let results = store.search("create", 10).await.unwrap();
        assert_eq!(results.len(), 2);

        let results = store
            .search_by_tags(&["file".to_string()], 10)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_file_store_delete() {
        let tmp = TempFile::new("test_memory_delete");
        let store = FileStore::new(tmp.path()).unwrap();

        let item = MemoryItem::new("To be deleted");
        let item_id = item.id.clone();
        store.store(item).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);

        store.delete(&item_id).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);

        // The deletion must also be persisted.
        let store2 = FileStore::new(tmp.path()).unwrap();
        assert_eq!(store2.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_file_store_clear() {
        let tmp = TempFile::new("test_memory_clear");
        let store = FileStore::new(tmp.path()).unwrap();

        for i in 0..5 {
            store
                .store(MemoryItem::new(format!("Memory {}", i)))
                .await
                .unwrap();
        }

        assert_eq!(store.count().await.unwrap(), 5);

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);

        // Clearing must also be persisted.
        let store2 = FileStore::new(tmp.path()).unwrap();
        assert_eq!(store2.count().await.unwrap(), 0);
    }
}
1163
1164#[cfg(test)]
1165mod extra_memory_tests {
1166 use super::*;
1167
1168 #[test]
1173 fn test_memory_item_with_metadata() {
1174 let item = MemoryItem::new("test")
1175 .with_metadata("key1", "value1")
1176 .with_metadata("key2", "value2");
1177 assert_eq!(item.metadata.get("key1").unwrap(), "value1");
1178 assert_eq!(item.metadata.get("key2").unwrap(), "value2");
1179 }
1180
1181 #[test]
1182 fn test_memory_item_with_tags_vec() {
1183 let item = MemoryItem::new("test").with_tags(vec![
1184 "a".to_string(),
1185 "b".to_string(),
1186 "c".to_string(),
1187 ]);
1188 assert_eq!(item.tags.len(), 3);
1189 }
1190
1191 #[test]
1192 fn test_memory_item_importance_clamped() {
1193 let item_high = MemoryItem::new("test").with_importance(1.5);
1194 assert_eq!(item_high.importance, 1.0);
1195
1196 let item_low = MemoryItem::new("test").with_importance(-0.5);
1197 assert_eq!(item_low.importance, 0.0);
1198 }
1199
1200 #[test]
1201 fn test_memory_item_record_access() {
1202 let mut item = MemoryItem::new("test");
1203 assert_eq!(item.access_count, 0);
1204 assert!(item.last_accessed.is_none());
1205
1206 item.record_access();
1207 assert_eq!(item.access_count, 1);
1208 assert!(item.last_accessed.is_some());
1209
1210 item.record_access();
1211 assert_eq!(item.access_count, 2);
1212 }
1213
1214 #[test]
1215 fn test_memory_item_all_types() {
1216 let episodic = MemoryItem::new("e").with_type(MemoryType::Episodic);
1217 assert_eq!(episodic.memory_type, MemoryType::Episodic);
1218
1219 let semantic = MemoryItem::new("s").with_type(MemoryType::Semantic);
1220 assert_eq!(semantic.memory_type, MemoryType::Semantic);
1221
1222 let procedural = MemoryItem::new("p").with_type(MemoryType::Procedural);
1223 assert_eq!(procedural.memory_type, MemoryType::Procedural);
1224
1225 let working = MemoryItem::new("w").with_type(MemoryType::Working);
1226 assert_eq!(working.memory_type, MemoryType::Working);
1227 }
1228
1229 #[test]
1230 fn test_memory_item_default_type_is_episodic() {
1231 let item = MemoryItem::new("test");
1232 assert_eq!(item.memory_type, MemoryType::Episodic);
1233 }
1234
1235 #[tokio::test]
1240 async fn test_in_memory_store_retrieve_nonexistent() {
1241 let store = InMemoryStore::new();
1242 let result = store.retrieve("nonexistent").await.unwrap();
1243 assert!(result.is_none());
1244 }
1245
1246 #[tokio::test]
1247 async fn test_in_memory_store_delete() {
1248 let store = InMemoryStore::new();
1249 let item = MemoryItem::new("to delete");
1250 let id = item.id.clone();
1251 store.store(item).await.unwrap();
1252 assert_eq!(store.count().await.unwrap(), 1);
1253
1254 store.delete(&id).await.unwrap();
1255 assert_eq!(store.count().await.unwrap(), 0);
1256 }
1257
1258 #[tokio::test]
1259 async fn test_in_memory_store_clear() {
1260 let store = InMemoryStore::new();
1261 for i in 0..5 {
1262 store
1263 .store(MemoryItem::new(format!("item {}", i)))
1264 .await
1265 .unwrap();
1266 }
1267 assert_eq!(store.count().await.unwrap(), 5);
1268
1269 store.clear().await.unwrap();
1270 assert_eq!(store.count().await.unwrap(), 0);
1271 }
1272
1273 #[tokio::test]
1274 async fn test_in_memory_store_get_recent() {
1275 let store = InMemoryStore::new();
1276 for i in 0..5 {
1277 store
1278 .store(MemoryItem::new(format!("item {}", i)))
1279 .await
1280 .unwrap();
1281 }
1282 let recent = store.get_recent(3).await.unwrap();
1283 assert_eq!(recent.len(), 3);
1284 }
1285
1286 #[tokio::test]
1287 async fn test_in_memory_store_get_important() {
1288 let store = InMemoryStore::new();
1289 store
1290 .store(MemoryItem::new("low").with_importance(0.2))
1291 .await
1292 .unwrap();
1293 store
1294 .store(MemoryItem::new("medium").with_importance(0.5))
1295 .await
1296 .unwrap();
1297 store
1298 .store(MemoryItem::new("high").with_importance(0.9))
1299 .await
1300 .unwrap();
1301
1302 let important = store.get_important(0.7, 10).await.unwrap();
1303 assert_eq!(important.len(), 1);
1304 assert_eq!(important[0].content, "high");
1305 }
1306
1307 #[tokio::test]
1308 async fn test_in_memory_store_search_case_insensitive() {
1309 let store = InMemoryStore::new();
1310 store
1311 .store(MemoryItem::new("How to CREATE a file"))
1312 .await
1313 .unwrap();
1314 let results = store.search("create", 10).await.unwrap();
1315 assert_eq!(results.len(), 1);
1316 }
1317
1318 #[tokio::test]
1323 async fn test_agent_memory_short_term() {
1324 let memory = AgentMemory::in_memory();
1325 memory.remember(MemoryItem::new("item 1")).await.unwrap();
1326 memory.remember(MemoryItem::new("item 2")).await.unwrap();
1327
1328 let short_term = memory.get_short_term().await;
1329 assert_eq!(short_term.len(), 2);
1330
1331 memory.clear_short_term().await;
1332 let short_term = memory.get_short_term().await;
1333 assert_eq!(short_term.len(), 0);
1334 }
1335
1336 #[tokio::test]
1337 async fn test_agent_memory_short_term_count() {
1338 let memory = AgentMemory::in_memory();
1339 assert_eq!(memory.short_term_count().await, 0);
1340 memory.remember(MemoryItem::new("item")).await.unwrap();
1341 assert_eq!(memory.short_term_count().await, 1);
1342 }
1343
1344 #[tokio::test]
1345 async fn test_agent_memory_working_count() {
1346 let memory = AgentMemory::in_memory();
1347 assert_eq!(memory.working_count().await, 0);
1348 memory
1349 .add_to_working(MemoryItem::new("task"))
1350 .await
1351 .unwrap();
1352 assert_eq!(memory.working_count().await, 1);
1353 }
1354
1355 #[tokio::test]
1356 async fn test_agent_memory_recall_by_tags() {
1357 let memory = AgentMemory::in_memory();
1358 memory
1359 .remember_success("create file", &["write".to_string()], "ok")
1360 .await
1361 .unwrap();
1362 memory
1363 .remember_failure("delete file", "denied", &["bash".to_string()])
1364 .await
1365 .unwrap();
1366
1367 let successes = memory
1368 .recall_by_tags(&["success".to_string()], 10)
1369 .await
1370 .unwrap();
1371 assert_eq!(successes.len(), 1);
1372
1373 let failures = memory
1374 .recall_by_tags(&["failure".to_string()], 10)
1375 .await
1376 .unwrap();
1377 assert_eq!(failures.len(), 1);
1378 }
1379
1380 #[tokio::test]
1381 async fn test_agent_memory_get_recent() {
1382 let memory = AgentMemory::in_memory();
1383 for i in 0..5 {
1384 memory
1385 .remember(MemoryItem::new(format!("item {}", i)))
1386 .await
1387 .unwrap();
1388 }
1389 let recent = memory.get_recent(3).await.unwrap();
1390 assert_eq!(recent.len(), 3);
1391 }
1392
1393 #[tokio::test]
1394 async fn test_agent_memory_store_accessor() {
1395 let memory = AgentMemory::in_memory();
1396 memory.remember(MemoryItem::new("test")).await.unwrap();
1397 let count = memory.store().count().await.unwrap();
1398 assert_eq!(count, 1);
1399 }
1400
1401 #[tokio::test]
1402 async fn test_agent_memory_stats_all_fields() {
1403 let memory = AgentMemory::in_memory();
1404 memory.remember(MemoryItem::new("long term")).await.unwrap();
1405 memory
1406 .add_to_working(MemoryItem::new("working"))
1407 .await
1408 .unwrap();
1409
1410 let stats = memory.stats().await.unwrap();
1411 assert_eq!(stats.long_term_count, 1);
1412 assert_eq!(stats.short_term_count, 1); assert_eq!(stats.working_count, 1);
1414 }
1415
1416 #[tokio::test]
1417 async fn test_agent_memory_working_overflow_trims() {
1418 let store = Arc::new(InMemoryStore::new());
1419 let memory = AgentMemory {
1420 store,
1421 short_term: Arc::new(RwLock::new(VecDeque::new())),
1422 working: Arc::new(RwLock::new(Vec::new())),
1423 max_short_term: 100,
1424 max_working: 3, relevance_config: RelevanceConfig::default(),
1426 };
1427
1428 for i in 0..5 {
1429 memory
1430 .add_to_working(
1431 MemoryItem::new(format!("task {}", i)).with_importance(i as f32 * 0.2),
1432 )
1433 .await
1434 .unwrap();
1435 }
1436
1437 let working = memory.get_working().await;
1438 assert_eq!(working.len(), 3); }
1440}
1441
#[cfg(test)]
mod extra_memory_tests2 {
    use super::*;

    /// Opening a store at a path under directories that do not exist yet must
    /// succeed, and the resulting store must contain nothing.
    #[tokio::test]
    async fn test_file_store_open_creates_parent_dirs() {
        let tmp = tempfile::tempdir().unwrap();
        let nested_path = tmp
            .path()
            .join("nested")
            .join("deep")
            .join("memories.jsonl");
        let store = FileStore::open(&nested_path).await.unwrap();

        let everything = store.search("", 100).await.unwrap();
        assert!(everything.is_empty());
    }

    /// Items written by one `FileStore` are visible after reopening the same path.
    #[tokio::test]
    async fn test_file_store_open_loads_existing() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("memories.jsonl");

        // Write through a first handle, then release it before reopening.
        let writer = FileStore::open(&path).await.unwrap();
        writer.store(MemoryItem::new("test memory")).await.unwrap();
        drop(writer);

        let reopened = FileStore::open(&path).await.unwrap();
        let found = reopened.search("test", 10).await.unwrap();
        assert_eq!(found.len(), 1);
        assert!(found[0].content.contains("test memory"));
    }

    /// Opening a path with no file behind it yields an empty store, not an error.
    #[tokio::test]
    async fn test_file_store_open_nonexistent_file() {
        let tmp = tempfile::tempdir().unwrap();
        let missing = tmp.path().join("nonexistent.jsonl");
        let store = FileStore::open(&missing).await.unwrap();

        let everything = store.search("", 100).await.unwrap();
        assert!(everything.is_empty());
    }

    /// Empty input parses to an empty list.
    #[test]
    fn test_parse_jsonl_empty_string() {
        let parsed = FileStore::parse_jsonl("").unwrap();
        assert!(parsed.is_empty());
    }

    /// Blank lines around and between records are ignored.
    #[test]
    fn test_parse_jsonl_empty_lines_skipped() {
        let line = serde_json::to_string(&MemoryItem::new("hello")).unwrap();
        let content = format!("\n{}\n\n{}\n\n", line, line);

        let parsed = FileStore::parse_jsonl(&content).unwrap();
        assert_eq!(parsed.len(), 2);
    }

    /// A line that is not JSON makes the whole parse fail.
    #[test]
    fn test_parse_jsonl_invalid_json_returns_error() {
        assert!(FileStore::parse_jsonl("not valid json").is_err());
    }

    /// A single serialized item round-trips through `parse_jsonl`.
    #[test]
    fn test_parse_jsonl_valid_single_line() {
        let line = serde_json::to_string(&MemoryItem::new("single")).unwrap();

        let parsed = FileStore::parse_jsonl(&line).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].content, "single");
    }
}