1use crate::{
13 embedding::{EmbeddingProvider, LocalEmbedding},
14 store::{MemoryEntry, SearchResult, VectorStore},
15};
16use argentor_core::{ArgentorError, ArgentorResult};
17use chrono::Utc;
18use regex::Regex;
19use serde::{Deserialize, Serialize};
20use std::{
21 collections::{HashMap, VecDeque},
22 path::Path,
23 sync::Arc,
24};
25use uuid::Uuid;
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TieredMemoryConfig {
34 pub short_term_window: usize,
36 pub long_term_threshold: f32,
38 pub entity_extraction: bool,
40 pub summarize_on_evict: bool,
42 pub long_term_top_k: usize,
44}
45
46impl Default for TieredMemoryConfig {
47 fn default() -> Self {
48 Self {
49 short_term_window: 20,
50 long_term_threshold: 0.7,
51 entity_extraction: true,
52 summarize_on_evict: true,
53 long_term_top_k: 5,
54 }
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct TieredTurn {
65 pub role: String,
67 pub content: String,
69 pub timestamp: chrono::DateTime<Utc>,
71}
72
73#[derive(Debug, Clone)]
75pub struct ScoredMemory {
76 pub entry: MemoryEntry,
78 pub score: f32,
80}
81
82#[derive(Debug, Clone)]
84pub struct MemoryContext {
85 pub short_term: Vec<TieredTurn>,
87 pub relevant_long_term: Vec<ScoredMemory>,
89 pub entity_facts: Vec<String>,
91 pub total_tokens_estimate: usize,
93}
94
95#[derive(Debug, Serialize, Deserialize)]
100struct TieredMemorySnapshot {
101 short_term: Vec<TieredTurn>,
102 entities: HashMap<String, Vec<String>>,
103 config: TieredMemoryConfig,
104}
105
106struct EntityPatterns {
112 capitalized: Regex,
113 at_mention: Regex,
114 quoted: Regex,
115}
116
117impl EntityPatterns {
118 fn new() -> Self {
119 Self {
120 capitalized: Regex::new(r"\b([A-Z][a-z]{2,})\b").unwrap(),
122 at_mention: Regex::new(r"@([A-Za-z][A-Za-z0-9_]{1,})").unwrap(),
123 quoted: Regex::new(r#""([^"]{2,32})""#).unwrap(),
124 }
125 }
126
127 fn extract(&self, text: &str) -> Vec<String> {
129 let mut entities: Vec<String> = Vec::new();
130
131 for cap in self.capitalized.captures_iter(text) {
132 entities.push(cap[1].to_string());
133 }
134 for cap in self.at_mention.captures_iter(text) {
135 entities.push(cap[1].to_string());
136 }
137 for cap in self.quoted.captures_iter(text) {
138 entities.push(cap[1].to_string());
139 }
140
141 entities.dedup();
142 entities
143 }
144}
145
146pub struct TieredMemory {
153 short_term: VecDeque<TieredTurn>,
154 pending_evictions: Vec<TieredTurn>,
157 long_term: Arc<dyn VectorStore>,
158 entities: HashMap<String, Vec<String>>,
159 config: TieredMemoryConfig,
160 embedder: Arc<dyn EmbeddingProvider>,
161 entity_patterns: EntityPatterns,
162}
163
164impl TieredMemory {
165 pub fn new(config: TieredMemoryConfig, store: Arc<dyn VectorStore>) -> Self {
167 Self::with_embedder(config, store, Arc::new(LocalEmbedding::default()))
168 }
169
170 pub fn with_embedder(
172 config: TieredMemoryConfig,
173 store: Arc<dyn VectorStore>,
174 embedder: Arc<dyn EmbeddingProvider>,
175 ) -> Self {
176 Self {
177 short_term: VecDeque::with_capacity(config.short_term_window + 1),
178 pending_evictions: Vec::new(),
179 long_term: store,
180 entities: HashMap::new(),
181 config,
182 embedder,
183 entity_patterns: EntityPatterns::new(),
184 }
185 }
186
187 pub fn add_turn(&mut self, role: &str, content: &str) {
196 if self.config.entity_extraction {
197 self.update_entities(role, content);
198 }
199
200 if self.short_term.len() >= self.config.short_term_window {
201 if let Some(evicted) = self.short_term.pop_front() {
202 if self.config.summarize_on_evict {
203 self.pending_evictions.push(evicted);
204 }
205 }
206 }
207
208 self.short_term.push_back(TieredTurn {
209 role: role.to_string(),
210 content: content.to_string(),
211 timestamp: Utc::now(),
212 });
213 }
214
215 pub async fn flush_evicted(&mut self) -> ArgentorResult<()> {
217 let pending = std::mem::take(&mut self.pending_evictions);
218 for turn in pending {
219 self.store_to_long_term(&turn).await?;
220 }
221 Ok(())
222 }
223
224 pub async fn add_turn_async(&mut self, role: &str, content: &str) -> ArgentorResult<()> {
228 if self.config.entity_extraction {
229 self.update_entities(role, content);
230 }
231
232 if self.short_term.len() >= self.config.short_term_window {
233 if let Some(evicted) = self.short_term.pop_front() {
234 if self.config.summarize_on_evict {
235 self.store_to_long_term(&evicted).await?;
236 }
237 }
238 }
239
240 self.short_term.push_back(TieredTurn {
241 role: role.to_string(),
242 content: content.to_string(),
243 timestamp: Utc::now(),
244 });
245 Ok(())
246 }
247
248 pub async fn get_context(&self, current_query: &str) -> ArgentorResult<MemoryContext> {
254 let short_term: Vec<TieredTurn> = self.short_term.iter().cloned().collect();
255
256 let relevant_long_term = if !current_query.is_empty() {
258 let embedding = self.embedder.embed(current_query).await?;
259 let results = self
260 .long_term
261 .search(&embedding, self.config.long_term_top_k, None)
262 .await?;
263 results
264 .into_iter()
265 .filter(|r| r.score >= self.config.long_term_threshold)
266 .map(|SearchResult { entry, score }| ScoredMemory { entry, score })
267 .collect()
268 } else {
269 Vec::new()
270 };
271
272 let detected = self.entity_patterns.extract(current_query);
274 let mut entity_facts: Vec<String> = Vec::new();
275 for entity in &detected {
276 if let Some(facts) = self.entities.get(entity.as_str()) {
277 for fact in facts {
278 entity_facts.push(format!("[{entity}] {fact}"));
279 }
280 }
281 }
282
283 let char_total: usize = short_term.iter().map(|t| t.content.len()).sum::<usize>()
285 + relevant_long_term
286 .iter()
287 .map(|m| m.entry.content.len())
288 .sum::<usize>()
289 + entity_facts.iter().map(|f| f.len()).sum::<usize>();
290 let total_tokens_estimate = char_total / 4;
291
292 Ok(MemoryContext {
293 short_term,
294 relevant_long_term,
295 entity_facts,
296 total_tokens_estimate,
297 })
298 }
299
300 pub fn get_entities(&self) -> &HashMap<String, Vec<String>> {
302 &self.entities
303 }
304
305 pub fn short_term_len(&self) -> usize {
307 self.short_term.len()
308 }
309
310 pub fn entity_count(&self) -> usize {
312 self.entities.len()
313 }
314
315 pub async fn persist(&self, path: &Path) -> ArgentorResult<()> {
317 let snapshot = TieredMemorySnapshot {
318 short_term: self.short_term.iter().cloned().collect(),
319 entities: self.entities.clone(),
320 config: self.config.clone(),
321 };
322 let json = serde_json::to_string_pretty(&snapshot)
323 .map_err(|e| ArgentorError::Session(format!("Failed to serialize snapshot: {e}")))?;
324 if let Some(parent) = path.parent() {
325 tokio::fs::create_dir_all(parent)
326 .await
327 .map_err(|e| ArgentorError::Session(format!("Failed to create dir: {e}")))?;
328 }
329 tokio::fs::write(path, json.as_bytes())
330 .await
331 .map_err(|e| ArgentorError::Session(format!("Failed to write snapshot: {e}")))?;
332 Ok(())
333 }
334
335 pub async fn load(path: &Path, store: Arc<dyn VectorStore>) -> ArgentorResult<Self> {
338 let data = tokio::fs::read_to_string(path)
339 .await
340 .map_err(|e| ArgentorError::Session(format!("Failed to read snapshot: {e}")))?;
341 let snapshot: TieredMemorySnapshot = serde_json::from_str(&data)
342 .map_err(|e| ArgentorError::Session(format!("Failed to parse snapshot: {e}")))?;
343
344 let mut mem = Self::new(snapshot.config, store);
345 for turn in snapshot.short_term {
346 mem.short_term.push_back(turn);
347 }
348 mem.entities = snapshot.entities;
349 Ok(mem)
350 }
351
352 async fn store_to_long_term(&self, turn: &TieredTurn) -> ArgentorResult<()> {
358 let text = format!(
359 "[{}] {}: {}",
360 turn.timestamp.format("%Y-%m-%dT%H:%M"),
361 turn.role,
362 &turn.content[..turn.content.len().min(500)],
363 );
364
365 let embedding = self.embedder.embed(&text).await?;
366 let entry = MemoryEntry {
367 id: Uuid::new_v4(),
368 content: text,
369 embedding,
370 metadata: {
371 let mut m = std::collections::HashMap::new();
372 m.insert(
373 "role".to_string(),
374 serde_json::Value::String(turn.role.clone()),
375 );
376 m.insert(
377 "tier".to_string(),
378 serde_json::Value::String("long_term".to_string()),
379 );
380 m
381 },
382 session_id: None,
383 created_at: turn.timestamp,
384 };
385 self.long_term.insert(entry).await
386 }
387
388 fn update_entities(&mut self, role: &str, content: &str) {
393 if role == "tool" {
394 return;
395 }
396 let entities = self.entity_patterns.extract(content);
397 if entities.is_empty() {
398 return;
399 }
400 let fact = format!("[{}] {}", role, &content[..content.len().min(200)]);
401 for entity in entities {
402 let facts = self.entities.entry(entity).or_default();
403 if facts.len() < 10 {
404 facts.push(fact.clone());
405 }
406 }
407 }
408}
409
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::store::InMemoryVectorStore;

    /// Builds an empty in-memory vector store behind the trait object.
    fn make_store() -> Arc<dyn VectorStore> {
        Arc::new(InMemoryVectorStore::new())
    }

    /// Builds a memory with the given window, a 0.5 threshold, top-5
    /// retrieval, and both entity extraction and eviction storage enabled.
    fn make_mem(window: usize) -> TieredMemory {
        let cfg = TieredMemoryConfig {
            short_term_window: window,
            long_term_threshold: 0.5,
            entity_extraction: true,
            summarize_on_evict: true,
            long_term_top_k: 5,
        };
        TieredMemory::new(cfg, make_store())
    }

    /// Adds each `(role, content)` pair, in order, via `add_turn_async`.
    async fn feed(mem: &mut TieredMemory, turns: &[(&str, &str)]) {
        for &(role, content) in turns {
            mem.add_turn_async(role, content).await.unwrap();
        }
    }

    /// The short-term buffer never grows past its configured window.
    #[tokio::test]
    async fn test_short_term_window_enforced() {
        let mut mem = make_mem(20);
        for n in 0..25 {
            let content = format!("turn {n}");
            mem.add_turn_async("user", &content).await.unwrap();
        }
        assert_eq!(mem.short_term_len(), 20, "window must cap at 20");
    }

    /// Eviction drops the oldest turn and keeps the newest.
    #[tokio::test]
    async fn test_short_term_retains_latest() {
        let mut mem = make_mem(3);
        feed(
            &mut mem,
            &[
                ("user", "first"),
                ("user", "second"),
                ("user", "third"),
                ("user", "fourth"),
            ],
        )
        .await;

        let kept: Vec<&str> = mem
            .short_term
            .iter()
            .map(|turn| turn.content.as_str())
            .collect();
        assert!(!kept.contains(&"first"), "oldest must be evicted");
        assert!(kept.contains(&"fourth"), "newest must be present");
    }

    /// Short-term turns come back in insertion order.
    #[tokio::test]
    async fn test_short_term_order_preserved() {
        let mut mem = make_mem(10);
        for n in 0..5 {
            let content = format!("msg{n}");
            mem.add_turn_async("user", &content).await.unwrap();
        }

        let ctx = mem.get_context("anything").await.unwrap();
        assert_eq!(ctx.short_term[0].content, "msg0");
        assert_eq!(ctx.short_term[4].content, "msg4");
    }

    /// A turn pushed out of the window is written to the long-term store.
    #[tokio::test]
    async fn test_evicted_turns_reach_long_term() {
        let mut mem = make_mem(3);
        feed(
            &mut mem,
            &[
                ("user", "alpha rust programming"),
                ("user", "beta topic"),
                ("user", "gamma topic"),
                ("user", "delta topic"),
            ],
        )
        .await;

        let stored = mem.long_term.count().await.unwrap();
        assert_eq!(stored, 1, "one evicted turn must land in long-term store");
    }

    /// A query similar to an evicted turn retrieves it from long-term memory.
    #[tokio::test]
    async fn test_long_term_retrieved_by_query() {
        let mut mem = make_mem(2);
        feed(
            &mut mem,
            &[
                ("user", "rust programming language systems"),
                ("user", "cooking recipes dinner"),
                ("user", "another unrelated turn"),
            ],
        )
        .await;

        let ctx = mem.get_context("rust systems programming").await.unwrap();
        assert!(
            !ctx.relevant_long_term.is_empty(),
            "should retrieve relevant long-term episode"
        );
    }

    /// A very high threshold filters out unrelated long-term entries.
    #[tokio::test]
    async fn test_long_term_threshold_filters_irrelevant() {
        let cfg = TieredMemoryConfig {
            short_term_window: 2,
            long_term_threshold: 0.99,
            entity_extraction: false,
            summarize_on_evict: true,
            long_term_top_k: 5,
        };
        let mut mem = TieredMemory::new(cfg, make_store());
        feed(
            &mut mem,
            &[
                ("user", "cooking is great"),
                ("user", "baking bread"),
                ("user", "dessert cake"),
            ],
        )
        .await;

        let ctx = mem.get_context("rust programming").await.unwrap();
        assert!(
            ctx.relevant_long_term.is_empty(),
            "threshold 0.99 should filter unrelated episode"
        );
    }

    /// Capitalized names become tracked entities with facts attached.
    #[tokio::test]
    async fn test_entity_facts_stored() {
        let mut mem = make_mem(20);
        feed(
            &mut mem,
            &[
                ("user", "John is the lead developer"),
                ("assistant", "John works on the backend"),
            ],
        )
        .await;

        let entities = mem.get_entities();
        assert!(entities.contains_key("John"), "John must be tracked");
        assert!(!entities["John"].is_empty(), "at least one fact for John");
    }

    /// Facts for an entity named in the query are injected into the context.
    #[tokio::test]
    async fn test_entity_facts_injected_in_context() {
        let mut mem = make_mem(20);
        feed(&mut mem, &[("user", "Alice manages the project")]).await;

        let ctx = mem.get_context("what does Alice do?").await.unwrap();
        let mentions_alice = ctx.entity_facts.iter().any(|fact| fact.contains("Alice"));
        assert!(mentions_alice, "Alice facts must appear in context");
    }

    /// `@handle` mentions are extracted without the `@`.
    #[tokio::test]
    async fn test_entity_at_mention() {
        let mut mem = make_mem(20);
        feed(&mut mem, &[("user", "ping @backend team please")]).await;

        assert!(
            mem.get_entities().contains_key("backend"),
            "@mention must extract entity"
        );
    }

    /// Double-quoted phrases are tracked as entities.
    #[tokio::test]
    async fn test_entity_quoted_term() {
        let mut mem = make_mem(20);
        feed(&mut mem, &[("user", r#"the "auth module" is broken"#)]).await;

        assert!(
            mem.get_entities().contains_key("auth module"),
            "quoted entity must be tracked"
        );
    }

    /// Tool output never contributes entity facts.
    #[tokio::test]
    async fn test_entity_tool_role_skipped() {
        let mut mem = make_mem(20);
        feed(&mut mem, &[("tool", "Output from John's processing")]).await;

        assert!(
            !mem.get_entities().contains_key("John"),
            "tool turns must not contribute entity facts"
        );
    }

    /// Short-term turns survive a persist/load cycle.
    #[tokio::test]
    async fn test_persist_and_load_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let snapshot_path = dir.path().join("tiered.json");

        let store: Arc<dyn VectorStore> = make_store();
        let mut mem = TieredMemory::new(TieredMemoryConfig::default(), Arc::clone(&store));
        feed(&mut mem, &[("user", "hello world"), ("assistant", "hi there")]).await;
        mem.persist(&snapshot_path).await.unwrap();

        let restored = TieredMemory::load(&snapshot_path, store).await.unwrap();
        assert_eq!(restored.short_term_len(), 2, "turns survive round-trip");
    }

    /// Entity facts survive a persist/load cycle.
    #[tokio::test]
    async fn test_persist_entities_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let snapshot_path = dir.path().join("tiered_ent.json");

        let store: Arc<dyn VectorStore> = make_store();
        let mut mem = TieredMemory::new(TieredMemoryConfig::default(), Arc::clone(&store));
        feed(&mut mem, &[("user", "Maria leads the team")]).await;
        mem.persist(&snapshot_path).await.unwrap();

        let restored = TieredMemory::load(&snapshot_path, store).await.unwrap();
        assert!(
            restored.get_entities().contains_key("Maria"),
            "entities survive round-trip"
        );
    }

    /// With extraction switched off, no entities are recorded.
    #[tokio::test]
    async fn test_entity_extraction_disabled() {
        let cfg = TieredMemoryConfig {
            entity_extraction: false,
            ..Default::default()
        };
        let mut mem = TieredMemory::new(cfg, make_store());
        feed(&mut mem, &[("user", "Alice and Bob discussed Rust")]).await;

        assert!(
            mem.get_entities().is_empty(),
            "entities must be empty when extraction is disabled"
        );
    }

    /// With `summarize_on_evict` off, evicted turns are discarded.
    #[tokio::test]
    async fn test_no_summarize_on_evict() {
        let cfg = TieredMemoryConfig {
            short_term_window: 2,
            summarize_on_evict: false,
            entity_extraction: false,
            long_term_threshold: 0.5,
            long_term_top_k: 5,
        };
        let mut mem = TieredMemory::new(cfg, make_store());
        feed(
            &mut mem,
            &[("user", "first"), ("user", "second"), ("user", "third")],
        )
        .await;

        let stored = mem.long_term.count().await.unwrap();
        assert_eq!(
            stored, 0,
            "no long-term writes when summarize_on_evict=false"
        );
    }

    /// Synchronous evictions are queued until `flush_evicted` is awaited.
    #[tokio::test]
    async fn test_sync_flush_evicted() {
        let mut mem = make_mem(2);
        for content in ["first", "second", "third"] {
            mem.add_turn("user", content);
        }

        let before = mem.long_term.count().await.unwrap();
        assert_eq!(before, 0);

        mem.flush_evicted().await.unwrap();

        let after = mem.long_term.count().await.unwrap();
        assert_eq!(after, 1, "flushed eviction must reach long-term store");
    }

    /// A non-empty context yields a non-zero token estimate.
    #[tokio::test]
    async fn test_token_estimate_non_zero() {
        let mut mem = make_mem(20);
        feed(
            &mut mem,
            &[("user", "hello this is a test message for token estimate")],
        )
        .await;

        let ctx = mem.get_context("test").await.unwrap();
        assert!(ctx.total_tokens_estimate > 0);
    }
}