1use crate::{
13 embedding::{EmbeddingProvider, LocalEmbedding},
14 store::{MemoryEntry, SearchResult, VectorStore},
15};
16use argentor_core::{ArgentorError, ArgentorResult};
17use chrono::Utc;
18use regex::Regex;
19use serde::{Deserialize, Serialize};
20use std::{
21 collections::{HashMap, VecDeque},
22 path::Path,
23 sync::Arc,
24};
25use uuid::Uuid;
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TieredMemoryConfig {
34 pub short_term_window: usize,
36 pub long_term_threshold: f32,
38 pub entity_extraction: bool,
40 pub summarize_on_evict: bool,
42 pub long_term_top_k: usize,
44}
45
46impl Default for TieredMemoryConfig {
47 fn default() -> Self {
48 Self {
49 short_term_window: 20,
50 long_term_threshold: 0.7,
51 entity_extraction: true,
52 summarize_on_evict: true,
53 long_term_top_k: 5,
54 }
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct TieredTurn {
65 pub role: String,
67 pub content: String,
69 pub timestamp: chrono::DateTime<Utc>,
71}
72
73#[derive(Debug, Clone)]
75pub struct ScoredMemory {
76 pub entry: MemoryEntry,
78 pub score: f32,
80}
81
82#[derive(Debug, Clone)]
84pub struct MemoryContext {
85 pub short_term: Vec<TieredTurn>,
87 pub relevant_long_term: Vec<ScoredMemory>,
89 pub entity_facts: Vec<String>,
91 pub total_tokens_estimate: usize,
93}
94
95#[derive(Debug, Serialize, Deserialize)]
100struct TieredMemorySnapshot {
101 short_term: Vec<TieredTurn>,
102 entities: HashMap<String, Vec<String>>,
103 config: TieredMemoryConfig,
104}
105
106struct EntityPatterns {
112 capitalized: Regex,
113 at_mention: Regex,
114 quoted: Regex,
115}
116
117impl EntityPatterns {
118 fn new() -> Self {
119 Self {
120 capitalized: compile_entity_regex(r"\b([A-Z][a-z]{2,})\b"),
122 at_mention: compile_entity_regex(r"@([A-Za-z][A-Za-z0-9_]{1,})"),
123 quoted: compile_entity_regex(r#""([^"]{2,32})""#),
124 }
125 }
126
127 fn extract(&self, text: &str) -> Vec<String> {
129 let mut entities: Vec<String> = Vec::new();
130
131 for cap in self.capitalized.captures_iter(text) {
132 entities.push(cap[1].to_string());
133 }
134 for cap in self.at_mention.captures_iter(text) {
135 entities.push(cap[1].to_string());
136 }
137 for cap in self.quoted.captures_iter(text) {
138 entities.push(cap[1].to_string());
139 }
140
141 entities.dedup();
142 entities
143 }
144}
145
146fn compile_entity_regex(pattern: &str) -> Regex {
147 match Regex::new(pattern) {
148 Ok(regex) => regex,
149 Err(err) => panic!("invalid built-in entity regex `{pattern}`: {err}"),
150 }
151}
152
153pub struct TieredMemory {
160 short_term: VecDeque<TieredTurn>,
161 pending_evictions: Vec<TieredTurn>,
164 long_term: Arc<dyn VectorStore>,
165 entities: HashMap<String, Vec<String>>,
166 config: TieredMemoryConfig,
167 embedder: Arc<dyn EmbeddingProvider>,
168 entity_patterns: EntityPatterns,
169}
170
171impl TieredMemory {
172 pub fn new(config: TieredMemoryConfig, store: Arc<dyn VectorStore>) -> Self {
174 Self::with_embedder(config, store, Arc::new(LocalEmbedding::default()))
175 }
176
177 pub fn with_embedder(
179 config: TieredMemoryConfig,
180 store: Arc<dyn VectorStore>,
181 embedder: Arc<dyn EmbeddingProvider>,
182 ) -> Self {
183 Self {
184 short_term: VecDeque::with_capacity(config.short_term_window + 1),
185 pending_evictions: Vec::new(),
186 long_term: store,
187 entities: HashMap::new(),
188 config,
189 embedder,
190 entity_patterns: EntityPatterns::new(),
191 }
192 }
193
194 pub fn add_turn(&mut self, role: &str, content: &str) {
203 if self.config.entity_extraction {
204 self.update_entities(role, content);
205 }
206
207 if self.short_term.len() >= self.config.short_term_window {
208 if let Some(evicted) = self.short_term.pop_front() {
209 if self.config.summarize_on_evict {
210 self.pending_evictions.push(evicted);
211 }
212 }
213 }
214
215 self.short_term.push_back(TieredTurn {
216 role: role.to_string(),
217 content: content.to_string(),
218 timestamp: Utc::now(),
219 });
220 }
221
222 pub async fn flush_evicted(&mut self) -> ArgentorResult<()> {
224 let pending = std::mem::take(&mut self.pending_evictions);
225 for turn in pending {
226 self.store_to_long_term(&turn).await?;
227 }
228 Ok(())
229 }
230
231 pub async fn add_turn_async(&mut self, role: &str, content: &str) -> ArgentorResult<()> {
235 if self.config.entity_extraction {
236 self.update_entities(role, content);
237 }
238
239 if self.short_term.len() >= self.config.short_term_window {
240 if let Some(evicted) = self.short_term.pop_front() {
241 if self.config.summarize_on_evict {
242 self.store_to_long_term(&evicted).await?;
243 }
244 }
245 }
246
247 self.short_term.push_back(TieredTurn {
248 role: role.to_string(),
249 content: content.to_string(),
250 timestamp: Utc::now(),
251 });
252 Ok(())
253 }
254
255 pub async fn get_context(&self, current_query: &str) -> ArgentorResult<MemoryContext> {
261 let short_term: Vec<TieredTurn> = self.short_term.iter().cloned().collect();
262
263 let relevant_long_term = if !current_query.is_empty() {
265 let embedding = self.embedder.embed(current_query).await?;
266 let results = self
267 .long_term
268 .search(&embedding, self.config.long_term_top_k, None)
269 .await?;
270 results
271 .into_iter()
272 .filter(|r| r.score >= self.config.long_term_threshold)
273 .map(|SearchResult { entry, score }| ScoredMemory { entry, score })
274 .collect()
275 } else {
276 Vec::new()
277 };
278
279 let detected = self.entity_patterns.extract(current_query);
281 let mut entity_facts: Vec<String> = Vec::new();
282 for entity in &detected {
283 if let Some(facts) = self.entities.get(entity.as_str()) {
284 for fact in facts {
285 entity_facts.push(format!("[{entity}] {fact}"));
286 }
287 }
288 }
289
290 let char_total: usize = short_term.iter().map(|t| t.content.len()).sum::<usize>()
292 + relevant_long_term
293 .iter()
294 .map(|m| m.entry.content.len())
295 .sum::<usize>()
296 + entity_facts.iter().map(String::len).sum::<usize>();
297 let total_tokens_estimate = char_total / 4;
298
299 Ok(MemoryContext {
300 short_term,
301 relevant_long_term,
302 entity_facts,
303 total_tokens_estimate,
304 })
305 }
306
307 pub fn get_entities(&self) -> &HashMap<String, Vec<String>> {
309 &self.entities
310 }
311
312 pub fn short_term_len(&self) -> usize {
314 self.short_term.len()
315 }
316
317 pub fn entity_count(&self) -> usize {
319 self.entities.len()
320 }
321
322 pub async fn persist(&self, path: &Path) -> ArgentorResult<()> {
324 let snapshot = TieredMemorySnapshot {
325 short_term: self.short_term.iter().cloned().collect(),
326 entities: self.entities.clone(),
327 config: self.config.clone(),
328 };
329 let json = serde_json::to_string_pretty(&snapshot)
330 .map_err(|e| ArgentorError::Session(format!("Failed to serialize snapshot: {e}")))?;
331 if let Some(parent) = path.parent() {
332 tokio::fs::create_dir_all(parent)
333 .await
334 .map_err(|e| ArgentorError::Session(format!("Failed to create dir: {e}")))?;
335 }
336 tokio::fs::write(path, json.as_bytes())
337 .await
338 .map_err(|e| ArgentorError::Session(format!("Failed to write snapshot: {e}")))?;
339 Ok(())
340 }
341
342 pub async fn load(path: &Path, store: Arc<dyn VectorStore>) -> ArgentorResult<Self> {
345 let data = tokio::fs::read_to_string(path)
346 .await
347 .map_err(|e| ArgentorError::Session(format!("Failed to read snapshot: {e}")))?;
348 let snapshot: TieredMemorySnapshot = serde_json::from_str(&data)
349 .map_err(|e| ArgentorError::Session(format!("Failed to parse snapshot: {e}")))?;
350
351 let mut mem = Self::new(snapshot.config, store);
352 for turn in snapshot.short_term {
353 mem.short_term.push_back(turn);
354 }
355 mem.entities = snapshot.entities;
356 Ok(mem)
357 }
358
359 async fn store_to_long_term(&self, turn: &TieredTurn) -> ArgentorResult<()> {
365 let text = format!(
366 "[{}] {}: {}",
367 turn.timestamp.format("%Y-%m-%dT%H:%M"),
368 turn.role,
369 &turn.content[..turn.content.len().min(500)],
370 );
371
372 let embedding = self.embedder.embed(&text).await?;
373 let entry = MemoryEntry {
374 id: Uuid::new_v4(),
375 content: text,
376 embedding,
377 metadata: {
378 let mut m = std::collections::HashMap::new();
379 m.insert(
380 "role".to_string(),
381 serde_json::Value::String(turn.role.clone()),
382 );
383 m.insert(
384 "tier".to_string(),
385 serde_json::Value::String("long_term".to_string()),
386 );
387 m
388 },
389 session_id: None,
390 created_at: turn.timestamp,
391 };
392 self.long_term.insert(entry).await
393 }
394
395 fn update_entities(&mut self, role: &str, content: &str) {
400 if role == "tool" {
401 return;
402 }
403 let entities = self.entity_patterns.extract(content);
404 if entities.is_empty() {
405 return;
406 }
407 let fact = format!("[{}] {}", role, &content[..content.len().min(200)]);
408 for entity in entities {
409 let facts = self.entities.entry(entity).or_default();
410 if facts.len() < 10 {
411 facts.push(fact.clone());
412 }
413 }
414 }
415}
416
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::store::InMemoryVectorStore;

    /// Fresh, empty in-memory vector store.
    fn make_store() -> Arc<dyn VectorStore> {
        Arc::new(InMemoryVectorStore::new())
    }

    /// Memory with the given window, a permissive 0.5 threshold and all features enabled.
    fn make_mem(window: usize) -> TieredMemory {
        let config = TieredMemoryConfig {
            short_term_window: window,
            long_term_threshold: 0.5,
            entity_extraction: true,
            summarize_on_evict: true,
            long_term_top_k: 5,
        };
        TieredMemory::new(config, make_store())
    }

    // ---- short-term tier -------------------------------------------------

    /// The short-term buffer never grows past the configured window.
    #[tokio::test]
    async fn test_short_term_window_enforced() {
        let mut mem = make_mem(20);
        for i in 0..25 {
            mem.add_turn_async("user", &format!("turn {i}"))
                .await
                .unwrap();
        }
        assert_eq!(mem.short_term_len(), 20, "window must cap at 20");
    }

    #[tokio::test]
    async fn test_short_term_retains_latest() {
        let mut mem = make_mem(3);
        mem.add_turn_async("user", "first").await.unwrap();
        mem.add_turn_async("user", "second").await.unwrap();
        mem.add_turn_async("user", "third").await.unwrap();
        mem.add_turn_async("user", "fourth").await.unwrap();

        let st: Vec<_> = mem.short_term.iter().map(|t| t.content.as_str()).collect();
        assert!(!st.contains(&"first"), "oldest must be evicted");
        assert!(st.contains(&"fourth"), "newest must be present");
    }

    #[tokio::test]
    async fn test_short_term_order_preserved() {
        let mut mem = make_mem(10);
        for i in 0..5 {
            mem.add_turn_async("user", &format!("msg{i}"))
                .await
                .unwrap();
        }
        let ctx = mem.get_context("anything").await.unwrap();
        assert_eq!(ctx.short_term[0].content, "msg0");
        assert_eq!(ctx.short_term[4].content, "msg4");
    }

    // ---- long-term tier --------------------------------------------------

    /// Overflowing the window by one turn writes exactly one episode to the store.
    #[tokio::test]
    async fn test_evicted_turns_reach_long_term() {
        let mut mem = make_mem(3);
        mem.add_turn_async("user", "alpha rust programming")
            .await
            .unwrap();
        mem.add_turn_async("user", "beta topic").await.unwrap();
        mem.add_turn_async("user", "gamma topic").await.unwrap();
        mem.add_turn_async("user", "delta topic").await.unwrap();

        let count = mem.long_term.count().await.unwrap();
        assert_eq!(count, 1, "one evicted turn must land in long-term store");
    }

    #[tokio::test]
    async fn test_long_term_retrieved_by_query() {
        let mut mem = make_mem(2);
        mem.add_turn_async("user", "rust programming language systems")
            .await
            .unwrap();
        mem.add_turn_async("user", "cooking recipes dinner")
            .await
            .unwrap();
        mem.add_turn_async("user", "another unrelated turn")
            .await
            .unwrap();

        let ctx = mem.get_context("rust systems programming").await.unwrap();
        assert!(
            !ctx.relevant_long_term.is_empty(),
            "should retrieve relevant long-term episode"
        );
    }

    #[tokio::test]
    async fn test_long_term_threshold_filters_irrelevant() {
        let store = make_store();
        let config = TieredMemoryConfig {
            short_term_window: 2,
            long_term_threshold: 0.99,
            entity_extraction: false,
            summarize_on_evict: true,
            long_term_top_k: 5,
        };
        let mut mem = TieredMemory::new(config, store);
        mem.add_turn_async("user", "cooking is great")
            .await
            .unwrap();
        mem.add_turn_async("user", "baking bread").await.unwrap();
        mem.add_turn_async("user", "dessert cake").await.unwrap();

        let ctx = mem.get_context("rust programming").await.unwrap();
        assert!(
            ctx.relevant_long_term.is_empty(),
            "threshold 0.99 should filter unrelated episode"
        );
    }

    /// An empty query skips the vector search and detects no entities.
    #[tokio::test]
    async fn test_empty_query_skips_long_term_and_entities() {
        let mut mem = make_mem(2);
        mem.add_turn_async("user", "Alice likes rust").await.unwrap();
        mem.add_turn_async("user", "Bob likes go").await.unwrap();
        mem.add_turn_async("user", "Carol likes zig").await.unwrap();

        let ctx = mem.get_context("").await.unwrap();
        assert!(
            ctx.relevant_long_term.is_empty(),
            "empty query must not retrieve long-term episodes"
        );
        assert!(
            ctx.entity_facts.is_empty(),
            "empty query must not inject entity facts"
        );
        assert_eq!(ctx.short_term.len(), 2, "short-term tier is still returned");
    }

    // ---- entity tier -----------------------------------------------------

    #[tokio::test]
    async fn test_entity_facts_stored() {
        let mut mem = make_mem(20);
        mem.add_turn_async("user", "John is the lead developer")
            .await
            .unwrap();
        mem.add_turn_async("assistant", "John works on the backend")
            .await
            .unwrap();

        let entities = mem.get_entities();
        assert!(entities.contains_key("John"), "John must be tracked");
        assert!(!entities["John"].is_empty(), "at least one fact for John");
    }

    #[tokio::test]
    async fn test_entity_facts_injected_in_context() {
        let mut mem = make_mem(20);
        mem.add_turn_async("user", "Alice manages the project")
            .await
            .unwrap();

        let ctx = mem.get_context("what does Alice do?").await.unwrap();
        assert!(
            ctx.entity_facts.iter().any(|f| f.contains("Alice")),
            "Alice facts must appear in context"
        );
    }

    /// Each entity keeps at most ten facts; later mentions are not recorded.
    #[tokio::test]
    async fn test_entity_facts_capped_per_entity() {
        let mut mem = make_mem(20);
        for i in 0..12 {
            mem.add_turn_async("user", &format!("Zed finished task {i}"))
                .await
                .unwrap();
        }
        assert_eq!(
            mem.get_entities()["Zed"].len(),
            10,
            "fact list must be capped at 10 per entity"
        );
    }

    #[tokio::test]
    async fn test_entity_at_mention() {
        let mut mem = make_mem(20);
        mem.add_turn_async("user", "ping @backend team please")
            .await
            .unwrap();

        assert!(
            mem.get_entities().contains_key("backend"),
            "@mention must extract entity"
        );
    }

    #[tokio::test]
    async fn test_entity_quoted_term() {
        let mut mem = make_mem(20);
        mem.add_turn_async("user", r#"the "auth module" is broken"#)
            .await
            .unwrap();

        assert!(
            mem.get_entities().contains_key("auth module"),
            "quoted entity must be tracked"
        );
    }

    #[tokio::test]
    async fn test_entity_tool_role_skipped() {
        let mut mem = make_mem(20);
        mem.add_turn_async("tool", "Output from John's processing")
            .await
            .unwrap();

        assert!(
            !mem.get_entities().contains_key("John"),
            "tool turns must not contribute entity facts"
        );
    }

    // ---- persistence -----------------------------------------------------

    /// Short-term turns survive a persist/load cycle.
    #[tokio::test]
    async fn test_persist_and_load_round_trip() {
        let tmp = tempfile::tempdir().unwrap();
        let snap_path = tmp.path().join("tiered.json");

        let store: Arc<dyn VectorStore> = make_store();
        let mut mem = TieredMemory::new(TieredMemoryConfig::default(), store.clone());
        mem.add_turn_async("user", "hello world").await.unwrap();
        mem.add_turn_async("assistant", "hi there").await.unwrap();
        mem.persist(&snap_path).await.unwrap();

        let loaded = TieredMemory::load(&snap_path, store).await.unwrap();
        assert_eq!(loaded.short_term_len(), 2, "turns survive round-trip");
    }

    #[tokio::test]
    async fn test_persist_entities_round_trip() {
        let tmp = tempfile::tempdir().unwrap();
        let snap_path = tmp.path().join("tiered_ent.json");

        let store: Arc<dyn VectorStore> = make_store();
        let mut mem = TieredMemory::new(TieredMemoryConfig::default(), store.clone());
        mem.add_turn_async("user", "Maria leads the team")
            .await
            .unwrap();
        mem.persist(&snap_path).await.unwrap();

        let loaded = TieredMemory::load(&snap_path, store).await.unwrap();
        assert!(
            loaded.get_entities().contains_key("Maria"),
            "entities survive round-trip"
        );
    }

    // ---- configuration switches ------------------------------------------

    /// Disabling extraction leaves the entity map empty.
    #[tokio::test]
    async fn test_entity_extraction_disabled() {
        let store = make_store();
        let config = TieredMemoryConfig {
            entity_extraction: false,
            ..Default::default()
        };
        let mut mem = TieredMemory::new(config, store);
        mem.add_turn_async("user", "Alice and Bob discussed Rust")
            .await
            .unwrap();
        assert!(
            mem.get_entities().is_empty(),
            "entities must be empty when extraction is disabled"
        );
    }

    #[tokio::test]
    async fn test_no_summarize_on_evict() {
        let store = make_store();
        let config = TieredMemoryConfig {
            short_term_window: 2,
            summarize_on_evict: false,
            entity_extraction: false,
            long_term_threshold: 0.5,
            long_term_top_k: 5,
        };
        let mut mem = TieredMemory::new(config, store);
        mem.add_turn_async("user", "first").await.unwrap();
        mem.add_turn_async("user", "second").await.unwrap();
        mem.add_turn_async("user", "third").await.unwrap();

        let count = mem.long_term.count().await.unwrap();
        assert_eq!(
            count, 0,
            "no long-term writes when summarize_on_evict=false"
        );
    }

    /// Synchronous `add_turn` defers long-term writes until `flush_evicted`.
    #[tokio::test]
    async fn test_sync_flush_evicted() {
        let mut mem = make_mem(2);
        mem.add_turn("user", "first");
        mem.add_turn("user", "second");
        mem.add_turn("user", "third");

        let before = mem.long_term.count().await.unwrap();
        assert_eq!(before, 0);

        mem.flush_evicted().await.unwrap();

        let after = mem.long_term.count().await.unwrap();
        assert_eq!(after, 1, "flushed eviction must reach long-term store");
    }

    /// A non-trivial context yields a positive token estimate.
    #[tokio::test]
    async fn test_token_estimate_non_zero() {
        let mut mem = make_mem(20);
        mem.add_turn_async("user", "hello this is a test message for token estimate")
            .await
            .unwrap();
        let ctx = mem.get_context("test").await.unwrap();
        assert!(ctx.total_tokens_estimate > 0);
    }
}