1use async_trait::async_trait;
2use chrono::{NaiveDate, Utc};
3
4use crate::{
5 config::MemoryExpiryConfig,
6 error::{AgentError, ToolError},
7};
8
/// Separator placed between serialized entries in the memory file.
///
/// `MemoryContent::parse` splits the raw file on this exact sequence, so an
/// entry whose text contains it would be read back as several entries.
pub const ENTRY_DELIMITER: &str = "\n§\n";

/// Maximum length of a single entry's content, counted in `char`s (Unicode
/// scalar values, not bytes), as enforced by `validate_memory_content`.
pub const MAX_ENTRY_CHARS: usize = 500;
13
14pub fn validate_memory_content(content: &str) -> Result<(), ToolError> {
19 if content.chars().count() > MAX_ENTRY_CHARS {
20 return Err(ToolError::InvalidArgs(
21 "content exceeds 500 character limit".into(),
22 ));
23 }
24 let lower = content.to_lowercase();
25 if lower.contains("<untrusted_memory")
26 || lower.contains("</untrusted_memory>")
27 || lower.contains("</untrusted")
28 || lower.contains("<untrusted_external_content")
29 {
30 return Err(ToolError::InvalidArgs(
31 "content contains disallowed XML tags".into(),
32 ));
33 }
34 Ok(())
35}
36
/// Returns a copy of `s` with every `<` and `>` removed, so stored text can
/// neither open nor close a tag once it is embedded in a prompt.
fn strip_angle_brackets(s: &str) -> String {
    s.chars().filter(|ch| !matches!(ch, '<' | '>')).collect()
}
44
/// Coarse classification of a memory entry.
///
/// Unrecognised names map to [`MemoryCategory::Other`], which is also the
/// `Default` variant.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum MemoryCategory {
    Fact,
    Preference,
    Skill,
    Project,
    #[default]
    Other,
}

impl MemoryCategory {
    /// Every variant, in declaration order.
    const ALL: [Self; 5] = [
        Self::Fact,
        Self::Preference,
        Self::Skill,
        Self::Project,
        Self::Other,
    ];

    /// Lowercase identifier used in the serialized `[category|date]` prefix.
    pub fn as_str(&self) -> &'static str {
        match *self {
            MemoryCategory::Fact => "fact",
            MemoryCategory::Preference => "preference",
            MemoryCategory::Skill => "skill",
            MemoryCategory::Project => "project",
            MemoryCategory::Other => "other",
        }
    }

    /// Inverse of [`MemoryCategory::as_str`].
    ///
    /// Matching is exact and case-sensitive; any unknown name yields the
    /// default variant, [`MemoryCategory::Other`].
    pub fn from_name(s: &str) -> Self {
        Self::ALL
            .iter()
            .find(|cat| cat.as_str() == s)
            .cloned()
            .unwrap_or_default()
    }

    /// Human-readable section heading used when rendering memory for a prompt.
    pub fn display_name(&self) -> &'static str {
        match *self {
            MemoryCategory::Fact => "Facts",
            MemoryCategory::Preference => "Preferences",
            MemoryCategory::Skill => "Skills",
            MemoryCategory::Project => "Project",
            MemoryCategory::Other => "Other",
        }
    }
}
86
87#[derive(Debug, Clone)]
88pub struct MemoryEntry {
89 pub category: MemoryCategory,
90 pub content: String,
91 pub created_at: String, }
93
94impl MemoryEntry {
95 pub fn new(category: MemoryCategory, content: String) -> Self {
96 let created_at = Utc::now().format("%Y-%m-%d").to_string();
97 Self {
98 category,
99 content,
100 created_at,
101 }
102 }
103
104 fn parse(raw: &str) -> Self {
106 let raw = raw.trim();
107 if let Some(rest) = raw.strip_prefix('[') {
108 if let Some(bracket_end) = rest.find(']') {
109 let meta = &rest[..bracket_end];
110 let content = rest[bracket_end + 1..].trim().to_string();
111 let mut parts = meta.splitn(2, '|');
112 let cat_str = parts.next().unwrap_or("other");
113 let date = parts.next().unwrap_or("").to_string();
114 return Self {
115 category: MemoryCategory::from_name(cat_str),
116 content,
117 created_at: date,
118 };
119 }
120 }
121 Self {
122 category: MemoryCategory::Other,
123 content: raw.to_string(),
124 created_at: String::new(),
125 }
126 }
127
128 fn serialize(&self) -> String {
129 format!(
130 "[{}|{}] {}",
131 self.category.as_str(),
132 self.created_at,
133 self.content
134 )
135 }
136}
137
138#[derive(Debug, Clone, Default)]
139pub struct MemoryContent {
140 pub entries: Vec<MemoryEntry>,
141}
142
143impl MemoryContent {
144 pub fn parse(raw: &str) -> Self {
145 let entries = raw
146 .split(ENTRY_DELIMITER)
147 .map(str::trim)
148 .filter(|s| !s.is_empty())
149 .map(MemoryEntry::parse)
150 .collect();
151 Self { entries }
152 }
153
154 pub fn serialize(&self) -> String {
155 self.entries
156 .iter()
157 .map(MemoryEntry::serialize)
158 .collect::<Vec<_>>()
159 .join(ENTRY_DELIMITER)
160 }
161
162 pub fn expire(&mut self, config: &MemoryExpiryConfig) -> usize {
166 let today = Utc::now().date_naive();
167 let before = self.entries.len();
168
169 self.entries.retain(|e| {
170 let max_days = match e.category {
171 MemoryCategory::Fact => config.fact_days,
172 MemoryCategory::Project => config.project_days,
173 MemoryCategory::Other => config.other_days,
174 MemoryCategory::Preference => config.preference_days,
175 MemoryCategory::Skill => config.skill_days,
176 };
177
178 let Some(days) = max_days else {
179 return true; };
181
182 let Ok(date) = NaiveDate::parse_from_str(&e.created_at, "%Y-%m-%d") else {
183 return true; };
185
186 let age = (today - date).num_days();
187 age <= i64::from(days)
188 });
189
190 before - self.entries.len()
191 }
192
193 const PREFETCH_LIMIT: usize = 8;
195
196 const STOP_WORDS: &'static [&'static str] = &[
199 "there", "about", "which", "where", "their", "those", "these", "every", "after", "other",
200 "never", "still", "under", "again", "being", "since", "while", "shall", "might", "until",
201 "above", "below", "maybe", "often", "quite", "would", "could", "whose", "whether",
202 "however", "although", "because", "without", "within", "around", "before", "should",
203 "through", "always", "almost", "already",
204 ];
205
206 pub fn prefetch(&self, query: &str) -> Vec<&MemoryEntry> {
211 let words: Vec<String> = query
212 .split_whitespace()
213 .filter_map(|w| {
214 let alpha: String = w.chars().filter(|c| c.is_alphabetic()).collect();
215 if alpha.len() < 5 {
216 return None;
217 }
218 let lower = alpha.to_lowercase();
219 if Self::STOP_WORDS.contains(&lower.as_str()) {
220 return None;
221 }
222 Some(lower)
223 })
224 .collect();
225 if words.is_empty() {
226 return vec![];
227 }
228 let mut hits: Vec<&MemoryEntry> = self
229 .entries
230 .iter()
231 .filter(|e| {
232 let lower = e.content.to_lowercase();
233 words.iter().any(|w| lower.contains(w.as_str()))
234 })
235 .collect();
236 hits.sort_by(|a, b| b.created_at.cmp(&a.created_at));
238 hits.truncate(Self::PREFETCH_LIMIT);
239 hits
240 }
241
242 pub fn prefetch_for_prompt(&self, query: &str) -> String {
245 let hits = self.prefetch(query);
246 if hits.is_empty() {
247 return String::new();
248 }
249 hits.iter()
250 .map(|e| {
251 let safe = strip_angle_brackets(&e.content);
252 if e.created_at.is_empty() {
253 format!("- {safe} [{}]", e.category.display_name())
254 } else {
255 format!(
256 "- {safe} [{}] ({})",
257 e.category.display_name(),
258 e.created_at
259 )
260 }
261 })
262 .collect::<Vec<_>>()
263 .join("\n")
264 }
265
266 pub fn serialize_for_prompt(&self) -> String {
268 if self.entries.is_empty() {
269 return String::new();
270 }
271
272 let order = [
273 MemoryCategory::Fact,
274 MemoryCategory::Preference,
275 MemoryCategory::Skill,
276 MemoryCategory::Project,
277 MemoryCategory::Other,
278 ];
279
280 let mut sections = Vec::new();
281 for cat in &order {
282 let items: Vec<&MemoryEntry> =
283 self.entries.iter().filter(|e| &e.category == cat).collect();
284 if items.is_empty() {
285 continue;
286 }
287 let lines: Vec<String> = items
288 .iter()
289 .map(|e| {
290 let safe = strip_angle_brackets(&e.content);
291 if e.created_at.is_empty() {
292 format!("- {safe}")
293 } else {
294 format!("- {safe} ({})", e.created_at)
295 }
296 })
297 .collect();
298 sections.push(format!("## {}\n{}", cat.display_name(), lines.join("\n")));
299 }
300
301 let inner = sections.join("\n\n");
302 format!("<untrusted_memory>\n{inner}\n</untrusted_memory>")
303 }
304}
305
/// Storage backend for memory entries and the user profile.
///
/// The `Send + Sync + 'static` bounds allow a single store to be shared
/// across threads and async tasks. Where and how data is persisted is left
/// to the implementation.
#[async_trait]
pub trait MemoryStore: Send + Sync + 'static {
    /// Reads the stored memory entries.
    async fn read_memory(&self) -> Result<MemoryContent, AgentError>;
    /// Writes `content` to the store. Presumably this replaces the previously
    /// stored entries — NOTE(review): confirm against implementations.
    async fn write_memory(&self, content: &MemoryContent) -> Result<(), AgentError>;

    /// Reads the free-text user profile.
    async fn read_user_profile(&self) -> Result<String, AgentError>;
    /// Writes the free-text user profile.
    async fn write_user_profile(&self, content: &str) -> Result<(), AgentError>;
}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::MemoryExpiryConfig;
    use chrono::{Duration, Utc};

    // ---- helpers ----------------------------------------------------------

    /// The date `n` days before today (UTC), formatted like stored entries.
    fn days_ago(n: i64) -> String {
        (Utc::now().date_naive() - Duration::days(n))
            .format("%Y-%m-%d")
            .to_string()
    }

    /// The project's default expiry configuration.
    fn default_expiry() -> MemoryExpiryConfig {
        MemoryExpiryConfig::default()
    }

    // ---- MemoryCategory ---------------------------------------------------

    #[test]
    fn category_roundtrip() {
        for (s, cat) in [
            ("fact", MemoryCategory::Fact),
            ("preference", MemoryCategory::Preference),
            ("skill", MemoryCategory::Skill),
            ("project", MemoryCategory::Project),
            ("other", MemoryCategory::Other),
        ] {
            assert_eq!(MemoryCategory::from_name(s), cat);
            assert_eq!(cat.as_str(), s);
        }
    }

    #[test]
    fn unknown_category_becomes_other() {
        assert_eq!(MemoryCategory::from_name("unknown"), MemoryCategory::Other);
        assert_eq!(MemoryCategory::from_name(""), MemoryCategory::Other);
    }

    // ---- MemoryEntry / MemoryContent parsing ------------------------------

    #[test]
    fn parse_structured_entry() {
        let e = MemoryEntry::parse("[fact|2026-04-27] user likes Rust");
        assert_eq!(e.category, MemoryCategory::Fact);
        assert_eq!(e.content, "user likes Rust");
        assert_eq!(e.created_at, "2026-04-27");
    }

    #[test]
    fn parse_all_category_prefixes() {
        for cat in ["fact", "preference", "skill", "project", "other"] {
            let raw = format!("[{cat}|2026-01-01] some content");
            let e = MemoryEntry::parse(&raw);
            assert_eq!(e.category, MemoryCategory::from_name(cat));
            assert_eq!(e.content, "some content");
        }
    }

    #[test]
    fn parse_plain_entry_backward_compat() {
        let e = MemoryEntry::parse("old plain memory entry");
        assert_eq!(e.category, MemoryCategory::Other);
        assert_eq!(e.content, "old plain memory entry");
        assert!(e.created_at.is_empty());
    }

    #[test]
    fn parse_trims_whitespace() {
        let e = MemoryEntry::parse("  [fact|2026-04-27]  trimmed content  ");
        assert_eq!(e.content, "trimmed content");
    }

    #[test]
    fn parse_empty_string() {
        let mc = MemoryContent::parse("");
        assert!(mc.entries.is_empty());
    }

    #[test]
    fn parse_whitespace_only() {
        let mc = MemoryContent::parse("   \n  ");
        assert!(mc.entries.is_empty());
    }

    // ---- serialization ----------------------------------------------------

    #[test]
    fn serialize_roundtrip() {
        let raw = "[fact|2026-04-27] entry one\n§\n[preference|2026-04-27] entry two";
        let mc = MemoryContent::parse(raw);
        assert_eq!(mc.entries.len(), 2);
        let serialized = mc.serialize();
        let mc2 = MemoryContent::parse(&serialized);
        assert_eq!(mc2.entries.len(), 2);
        assert_eq!(mc2.entries[0].content, "entry one");
        assert_eq!(mc2.entries[1].content, "entry two");
        assert_eq!(mc2.entries[0].category, MemoryCategory::Fact);
        assert_eq!(mc2.entries[1].category, MemoryCategory::Preference);
    }

    #[test]
    fn serialize_roundtrip_preserves_created_at() {
        // Both a dated and an explicitly undated structured entry must come
        // back byte-for-byte.
        let raw = "[fact|2026-04-27] a\n§\n[skill|] b";
        let mc = MemoryContent::parse(raw);
        assert_eq!(mc.entries[0].created_at, "2026-04-27");
        assert!(mc.entries[1].created_at.is_empty());
        assert_eq!(mc.serialize(), raw);
    }

    #[test]
    fn serialize_roundtrip_backward_compat() {
        // A legacy plain entry mixed with a structured one.
        let raw = "plain old entry\n§\n[fact|2026-01-01] new entry";
        let mc = MemoryContent::parse(raw);
        assert_eq!(mc.entries.len(), 2);
        assert_eq!(mc.entries[0].category, MemoryCategory::Other);
        assert_eq!(mc.entries[0].content, "plain old entry");
    }

    // ---- serialize_for_prompt ---------------------------------------------

    #[test]
    fn prompt_empty_when_no_entries() {
        let mc = MemoryContent::default();
        assert!(mc.serialize_for_prompt().is_empty());
    }

    #[test]
    fn prompt_groups_by_category() {
        let raw = "[preference|2026-04-27] short answers\n§\n[fact|2026-04-27] Rust is fast\n§\n[preference|2026-04-27] no emojis";
        let mc = MemoryContent::parse(raw);
        let prompt = mc.serialize_for_prompt();

        // Facts are rendered before Preferences regardless of file order.
        let fact_pos = prompt.find("## Facts").unwrap();
        let pref_pos = prompt.find("## Preferences").unwrap();
        assert!(fact_pos < pref_pos);

        assert!(prompt.contains("- Rust is fast"));
        assert!(prompt.contains("- short answers"));
        assert!(prompt.contains("- no emojis"));
    }

    #[test]
    fn prompt_skips_empty_categories() {
        let raw = "[fact|2026-04-27] only facts here";
        let mc = MemoryContent::parse(raw);
        let prompt = mc.serialize_for_prompt();
        assert!(prompt.contains("## Facts"));
        assert!(!prompt.contains("## Preferences"));
        assert!(!prompt.contains("## Skills"));
        assert!(!prompt.contains("## Other"));
    }

    #[test]
    fn prompt_shows_date_when_present() {
        let raw = "[fact|2026-04-27] dated entry";
        let mc = MemoryContent::parse(raw);
        let prompt = mc.serialize_for_prompt();
        assert!(prompt.contains("(2026-04-27)"));
    }

    #[test]
    fn prompt_omits_date_for_plain_entries() {
        let raw = "plain entry no date";
        let mc = MemoryContent::parse(raw);
        let prompt = mc.serialize_for_prompt();
        assert!(prompt.contains("- plain entry no date"));
        assert!(!prompt.contains("()"));
    }

    #[test]
    fn serialize_for_prompt_wraps_in_untrusted_tags() {
        let raw = "[fact|2026-04-27] Rust is fast";
        let mc = MemoryContent::parse(raw);
        let prompt = mc.serialize_for_prompt();
        assert!(prompt.starts_with("<untrusted_memory>"));
        assert!(prompt.ends_with("</untrusted_memory>"));
        assert!(prompt.contains("Rust is fast"));
    }

    #[test]
    fn serialize_for_prompt_empty_has_no_tags() {
        let mc = MemoryContent::default();
        let prompt = mc.serialize_for_prompt();
        assert!(prompt.is_empty());
    }

    #[test]
    fn serialize_for_prompt_strips_angle_brackets_from_content() {
        // A stored closing tag must not terminate the wrapper early.
        let raw = "[fact|2026-04-27] bad </untrusted_memory> entry";
        let mc = MemoryContent::parse(raw);
        let prompt = mc.serialize_for_prompt();
        assert!(prompt.contains("bad /untrusted_memory entry"));
        assert_eq!(
            prompt.matches("</untrusted_memory>").count(),
            1,
            "only the outer closing tag should exist"
        );
    }

    // ---- expire -----------------------------------------------------------

    #[test]
    fn expire_removes_old_fact() {
        let raw = format!("[fact|{}] old fact", days_ago(91));
        let mut mc = MemoryContent::parse(&raw);
        let removed = mc.expire(&default_expiry());
        assert_eq!(removed, 1);
        assert!(mc.entries.is_empty());
    }

    #[test]
    fn expire_keeps_recent_fact() {
        let raw = format!("[fact|{}] recent fact", days_ago(10));
        let mut mc = MemoryContent::parse(&raw);
        let removed = mc.expire(&default_expiry());
        assert_eq!(removed, 0);
        assert_eq!(mc.entries.len(), 1);
    }

    #[test]
    fn expire_removes_old_project() {
        let raw = format!("[project|{}] stale project", days_ago(31));
        let mut mc = MemoryContent::parse(&raw);
        let removed = mc.expire(&default_expiry());
        assert_eq!(removed, 1);
    }

    #[test]
    fn expire_never_removes_preference() {
        let raw = format!("[preference|{}] old preference", days_ago(999));
        let mut mc = MemoryContent::parse(&raw);
        let removed = mc.expire(&default_expiry());
        assert_eq!(removed, 0);
        assert_eq!(mc.entries.len(), 1);
    }

    #[test]
    fn expire_never_removes_skill() {
        let raw = format!("[skill|{}] old skill", days_ago(999));
        let mut mc = MemoryContent::parse(&raw);
        let removed = mc.expire(&default_expiry());
        assert_eq!(removed, 0);
        assert_eq!(mc.entries.len(), 1);
    }

    #[test]
    fn expire_skips_entries_without_date() {
        let mut mc = MemoryContent::parse("plain old entry with no date");
        let removed = mc.expire(&default_expiry());
        assert_eq!(removed, 0);
        assert_eq!(mc.entries.len(), 1);
    }

    #[test]
    fn expire_returns_correct_count() {
        let raw = format!(
            "[fact|{}] keep\n§\n[fact|{}] drop one\n§\n[project|{}] drop two",
            days_ago(10),
            days_ago(91),
            days_ago(31),
        );
        let mut mc = MemoryContent::parse(&raw);
        let removed = mc.expire(&default_expiry());
        assert_eq!(removed, 2);
        assert_eq!(mc.entries.len(), 1);
        assert_eq!(mc.entries[0].content, "keep");
    }

    #[test]
    fn expire_disabled_when_days_is_none() {
        let config = MemoryExpiryConfig {
            fact_days: None,
            project_days: None,
            other_days: None,
            preference_days: None,
            skill_days: None,
        };
        let raw = format!("[fact|{}] very old", days_ago(9999));
        let mut mc = MemoryContent::parse(&raw);
        let removed = mc.expire(&config);
        assert_eq!(removed, 0);
    }

    // ---- prefetch ---------------------------------------------------------

    #[test]
    fn prefetch_returns_matching_entries() {
        let raw = "[preference|2026-04-29] user drinks black coffee\n§\n[fact|2026-04-29] user lives in Bangkok";
        let mc = MemoryContent::parse(raw);
        // "about" is a stop word; only "coffee" is used for matching.
        let hits = mc.prefetch("about coffee");
        assert_eq!(hits.len(), 1);
        assert!(hits[0].content.contains("coffee"));
    }

    #[test]
    fn prefetch_returns_empty_when_no_match() {
        let raw = "[preference|2026-04-29] user likes black coffee";
        let mc = MemoryContent::parse(raw);
        let hits = mc.prefetch("current weather");
        assert!(hits.is_empty());
    }

    #[test]
    fn prefetch_is_case_insensitive() {
        let raw = "[preference|2026-04-29] user likes Black Coffee";
        let mc = MemoryContent::parse(raw);
        let hits = mc.prefetch("COFFEE");
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn prefetch_strips_punctuation_from_query_words() {
        let raw = "[preference|2026-04-29] user drinks black coffee";
        let mc = MemoryContent::parse(raw);
        let hits = mc.prefetch("coffee?");
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn prefetch_ignores_short_words() {
        let raw = "[preference|2026-04-29] user likes tea";
        let mc = MemoryContent::parse(raw);
        // Every query word is shorter than five characters.
        let hits = mc.prefetch("is he ok now");
        assert!(hits.is_empty());
    }

    #[test]
    fn prefetch_ignores_stop_words() {
        let raw = "[preference|2026-04-29] user changed the API endpoint";
        let mc = MemoryContent::parse(raw);
        // Both words are long enough but are stop words.
        let hits = mc.prefetch("about there");
        assert!(hits.is_empty());
    }

    #[test]
    fn prefetch_caps_results_at_limit() {
        let entries: String = (0..20)
            .map(|i| {
                format!(
                    "[fact|2026-04-{:02}] keyword entry number {i}",
                    (i % 28) + 1
                )
            })
            .collect::<Vec<_>>()
            .join("\n§\n");
        let mc = MemoryContent::parse(&entries);
        let hits = mc.prefetch("keyword entry");
        // All 20 entries match, so the result must be exactly the cap.
        assert_eq!(hits.len(), MemoryContent::PREFETCH_LIMIT);
    }

    #[test]
    fn prefetch_returns_newest_first() {
        let raw = "[fact|2026-01-01] keyword old entry\n§\n[fact|2026-04-29] keyword new entry";
        let mc = MemoryContent::parse(raw);
        let hits = mc.prefetch("keyword entry");
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].created_at, "2026-04-29");
        assert_eq!(hits[1].created_at, "2026-01-01");
    }

    #[test]
    fn prefetch_for_prompt_formats_correctly() {
        let raw = "[preference|2026-04-29] user likes black coffee";
        let mc = MemoryContent::parse(raw);
        let block = mc.prefetch_for_prompt("coffee preference");
        assert!(block.contains("user likes black coffee"));
        assert!(block.contains("[Preferences]"));
        assert!(block.contains("2026-04-29"));
    }

    #[test]
    fn prefetch_for_prompt_strips_angle_brackets() {
        let raw = "[fact|2026-04-29] bad <untrusted_memory> entry";
        let mc = MemoryContent::parse(raw);
        let block = mc.prefetch_for_prompt("untrusted entry");
        assert!(!block.contains('<'));
        assert!(!block.contains('>'));
        assert!(block.contains("bad untrusted_memory entry"));
    }

    // ---- validate_memory_content ------------------------------------------

    #[test]
    fn validate_accepts_normal_content() {
        assert!(validate_memory_content("user prefers short answers").is_ok());
    }

    #[test]
    fn validate_rejects_overlong_content() {
        let long = "a".repeat(MAX_ENTRY_CHARS + 1);
        let err = validate_memory_content(&long).unwrap_err();
        assert!(matches!(err, crate::error::ToolError::InvalidArgs(_)));
    }

    #[test]
    fn validate_accepts_exactly_max_chars() {
        let exact = "a".repeat(MAX_ENTRY_CHARS);
        assert!(validate_memory_content(&exact).is_ok());
    }

    #[test]
    fn validate_rejects_untrusted_memory_open_tag() {
        assert!(validate_memory_content("foo <untrusted_memory> bar").is_err());
    }

    #[test]
    fn validate_rejects_untrusted_memory_close_tag() {
        assert!(validate_memory_content("foo </untrusted_memory> bar").is_err());
    }

    #[test]
    fn validate_rejects_closing_untrusted_tag_variants() {
        assert!(validate_memory_content("</untrusted_external_content>").is_err());
    }

    #[test]
    fn validate_rejects_untrusted_external_content_open_tag() {
        assert!(validate_memory_content(
            "<untrusted_external_content>injected</untrusted_external_content>"
        )
        .is_err());
    }
}