1use chrono::Utc;
16use serde::{Deserialize, Serialize};
17use tracing::{debug, info};
18
19use punch_types::{FighterId, PunchError, PunchResult};
20
21use crate::MemorySubstrate;
22use crate::memories::MemoryEntry;
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ConsolidationConfig {
28 pub max_memories_per_fighter: usize,
30 pub consolidation_threshold: usize,
32 pub min_confidence: f64,
34 pub decay_rate: f64,
36 pub merge_similarity_threshold: f64,
38 pub max_age_days: u64,
40}
41
42impl Default for ConsolidationConfig {
43 fn default() -> Self {
44 Self {
45 max_memories_per_fighter: 1000,
46 consolidation_threshold: 800,
47 min_confidence: 0.3,
48 decay_rate: 0.01,
49 merge_similarity_threshold: 0.8,
50 max_age_days: 90,
51 }
52 }
53}
54
/// Summary statistics for a single `MemoryConsolidator::consolidate` run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsolidationResult {
    /// Number of memories the fighter had when consolidation started.
    pub memories_before: usize,
    /// Number of memories remaining after all consolidation phases.
    pub memories_after: usize,
    /// Number of memories deleted because they were folded into a similar-keyed memory.
    pub merged: usize,
    /// Number of memories deleted for low confidence or to enforce the per-fighter cap.
    pub pruned: usize,
    /// Number of memories rewritten with a decayed confidence.
    pub decayed: usize,
    /// Wall-clock duration of the run, in milliseconds.
    pub duration_ms: u64,
}
71
/// Runs decay, pruning, merging, and capacity enforcement over a fighter's stored memories.
#[derive(Debug, Clone)]
pub struct MemoryConsolidator {
    /// Thresholds and rates used by the consolidation phases.
    pub config: ConsolidationConfig,
}
83
84impl MemoryConsolidator {
85 pub fn new(config: ConsolidationConfig) -> Self {
87 Self { config }
88 }
89
90 pub fn with_defaults() -> Self {
93 Self {
94 config: ConsolidationConfig::default(),
95 }
96 }
97
98 pub async fn consolidate(
107 &self,
108 memory: &MemorySubstrate,
109 fighter_id: &FighterId,
110 ) -> PunchResult<ConsolidationResult> {
111 let start = std::time::Instant::now();
112 let mut merged_count = 0usize;
113 let mut pruned_count = 0usize;
114 let mut decayed_count = 0usize;
115
116 let all_memories = self.fetch_all_memories(memory, fighter_id).await?;
118 let memories_before = all_memories.len();
119
120 info!(
121 fighter_id = %fighter_id,
122 memory_count = memories_before,
123 "beginning memory consolidation — entering recovery phase"
124 );
125
126 let now = Utc::now();
128 for entry in &all_memories {
129 let age_days = (now - entry.created_at).num_seconds() as f64 / 86400.0;
130 if age_days > 0.0 {
131 let new_confidence = self.apply_decay(entry.confidence, age_days);
132 if (new_confidence - entry.confidence).abs() > f64::EPSILON {
133 memory
134 .store_memory(fighter_id, &entry.key, &entry.value, new_confidence)
135 .await?;
136 decayed_count += 1;
137 }
138 }
139 }
140
141 let after_decay = self.fetch_all_memories(memory, fighter_id).await?;
143 for entry in &after_decay {
144 if entry.confidence < self.config.min_confidence {
145 memory.delete_memory(fighter_id, &entry.key).await?;
146 pruned_count += 1;
147 }
148 }
149
150 let after_prune = self.fetch_all_memories(memory, fighter_id).await?;
152 let mut consumed: Vec<bool> = vec![false; after_prune.len()];
153
154 for i in 0..after_prune.len() {
155 if consumed[i] {
156 continue;
157 }
158 let mut group: Vec<(&str, f64)> =
159 vec![(after_prune[i].value.as_str(), after_prune[i].confidence)];
160 let mut merge_keys: Vec<usize> = Vec::new();
161
162 for j in (i + 1)..after_prune.len() {
163 if consumed[j] {
164 continue;
165 }
166 if Self::keys_are_similar(&after_prune[i].key, &after_prune[j].key) {
167 group.push((after_prune[j].value.as_str(), after_prune[j].confidence));
168 merge_keys.push(j);
169 }
170 }
171
172 if !merge_keys.is_empty() {
173 let (merged_value, merged_confidence) = Self::merge_values(&group);
175
176 for &idx in &merge_keys {
178 memory
179 .delete_memory(fighter_id, &after_prune[idx].key)
180 .await?;
181 consumed[idx] = true;
182 merged_count += 1;
183 }
184
185 memory
187 .store_memory(
188 fighter_id,
189 &after_prune[i].key,
190 &merged_value,
191 merged_confidence,
192 )
193 .await?;
194 }
195 }
196
197 let mut current = self.fetch_all_memories(memory, fighter_id).await?;
199 if current.len() > self.config.max_memories_per_fighter {
200 current.sort_by(|a, b| {
202 a.confidence
203 .partial_cmp(&b.confidence)
204 .unwrap_or(std::cmp::Ordering::Equal)
205 .then_with(|| a.created_at.cmp(&b.created_at))
206 });
207
208 let excess = current.len() - self.config.max_memories_per_fighter;
209 for entry in current.iter().take(excess) {
210 memory.delete_memory(fighter_id, &entry.key).await?;
211 pruned_count += 1;
212 }
213 }
214
215 let memories_after = self.fetch_all_memories(memory, fighter_id).await?.len();
216
217 let duration_ms = start.elapsed().as_millis() as u64;
218
219 let result = ConsolidationResult {
220 memories_before,
221 memories_after,
222 merged: merged_count,
223 pruned: pruned_count,
224 decayed: decayed_count,
225 duration_ms,
226 };
227
228 info!(
229 fighter_id = %fighter_id,
230 before = memories_before,
231 after = memories_after,
232 merged = merged_count,
233 pruned = pruned_count,
234 decayed = decayed_count,
235 duration_ms = duration_ms,
236 "memory consolidation complete — fighter is battle-ready"
237 );
238
239 Ok(result)
240 }
241
242 pub fn apply_decay(&self, confidence: f64, age_days: f64) -> f64 {
247 let decayed = confidence * (1.0 - self.config.decay_rate).powf(age_days);
248 decayed.max(0.0)
250 }
251
252 pub fn should_consolidate(&self, memory_count: usize) -> bool {
255 memory_count > self.config.consolidation_threshold
256 }
257
258 pub fn keys_are_similar(a: &str, b: &str) -> bool {
265 if a == b {
266 return true;
267 }
268 let similarity = normalized_similarity(a, b);
269 similarity >= 0.8
272 }
273
274 pub fn merge_values(values: &[(&str, f64)]) -> (String, f64) {
280 if values.is_empty() {
281 return (String::new(), 0.0);
282 }
283
284 let best = values
286 .iter()
287 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
288 .map(|(v, _)| v.to_string())
289 .unwrap_or_default();
290
291 let avg_confidence = values.iter().map(|(_, c)| c).sum::<f64>() / values.len() as f64;
293
294 (best, avg_confidence)
295 }
296
297 async fn fetch_all_memories(
302 &self,
303 memory: &MemorySubstrate,
304 fighter_id: &FighterId,
305 ) -> PunchResult<Vec<MemoryEntry>> {
306 let fighter_str = fighter_id.to_string();
307 let conn = memory.conn().await;
308
309 let mut stmt = conn
310 .prepare(
311 "SELECT key, value, confidence, created_at, accessed_at
312 FROM memories
313 WHERE fighter_id = ?1
314 ORDER BY confidence DESC",
315 )
316 .map_err(|e| PunchError::Memory(format!("failed to fetch all memories: {e}")))?;
317
318 let rows = stmt
319 .query_map(rusqlite::params![fighter_str], |row| {
320 let key: String = row.get(0)?;
321 let value: String = row.get(1)?;
322 let confidence: f64 = row.get(2)?;
323 let created_at: String = row.get(3)?;
324 let accessed_at: String = row.get(4)?;
325 Ok((key, value, confidence, created_at, accessed_at))
326 })
327 .map_err(|e| PunchError::Memory(format!("failed to fetch all memories: {e}")))?;
328
329 let mut entries = Vec::new();
330 for row in rows {
331 let (key, value, confidence, created_at, accessed_at) =
332 row.map_err(|e| PunchError::Memory(format!("failed to read memory row: {e}")))?;
333
334 let created_at = parse_ts(&created_at)?;
335 let accessed_at = parse_ts(&accessed_at)?;
336
337 entries.push(MemoryEntry {
338 key,
339 value,
340 confidence,
341 created_at,
342 accessed_at,
343 });
344 }
345
346 debug!(
347 fighter_id = %fighter_id,
348 count = entries.len(),
349 "fetched all memories for consolidation"
350 );
351
352 Ok(entries)
353 }
354}
355
/// Levenshtein (edit) distance between `a` and `b`, counted in `char`s (Unicode scalar values).
///
/// Classic two-row dynamic programming: O(|a|·|b|) time, O(|b|) extra space.
fn levenshtein_distance(a: &str, b: &str) -> usize {
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();
    let a_len = a_chars.len();
    let b_len = b_chars.len();

    if a_len == 0 {
        return b_len;
    }
    if b_len == 0 {
        return a_len;
    }

    // prev_row[j] = distance between the first (i - 1) chars of `a` and the first j of `b`.
    let mut prev_row: Vec<usize> = (0..=b_len).collect();
    let mut curr_row: Vec<usize> = vec![0; b_len + 1];

    for i in 1..=a_len {
        curr_row[0] = i;
        for j in 1..=b_len {
            let cost = usize::from(a_chars[i - 1] != b_chars[j - 1]);
            curr_row[j] = (prev_row[j] + 1) // deletion
                .min(curr_row[j - 1] + 1) // insertion
                .min(prev_row[j - 1] + cost); // substitution or match
        }
        std::mem::swap(&mut prev_row, &mut curr_row);
    }

    // After the final swap the last computed row lives in `prev_row`.
    prev_row[b_len]
}

/// Similarity in `[0.0, 1.0]`: `1 - distance / max_len`, where `max_len` is the longer
/// string's length in `char`s. Two empty strings are identical (`1.0`).
fn normalized_similarity(a: &str, b: &str) -> f64 {
    // Measure lengths in chars, the same unit `levenshtein_distance` counts in. Byte lengths
    // inflate similarity for non-ASCII keys (e.g. "日本" vs "中国" would score 0.667, not 0),
    // which could cause unrelated memories to be merged.
    let max_len = a.chars().count().max(b.chars().count());
    if max_len == 0 {
        return 1.0;
    }
    let dist = levenshtein_distance(a, b);
    1.0 - (dist as f64 / max_len as f64)
}
402
403fn parse_ts(s: &str) -> PunchResult<chrono::DateTime<Utc>> {
405 chrono::DateTime::parse_from_rfc3339(s)
406 .map(|dt| dt.with_timezone(&Utc))
407 .or_else(|_| {
408 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ").map(|ndt| ndt.and_utc())
409 })
410 .map_err(|e| PunchError::Memory(format!("invalid timestamp '{s}': {e}")))
411}
412
// Unit tests for decay, thresholds, key similarity, and value merging.
// Integration tests use an in-memory `MemorySubstrate`.
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::{FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass};

    // Minimal manifest so `save_fighter` can register the fighter that owns the test memories.
    fn test_manifest() -> FighterManifest {
        FighterManifest {
            name: "Consolidation Fighter".into(),
            description: "memory consolidation test".into(),
            model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".into(),
                api_key_env: None,
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Featherweight,
            tenant_id: None,
        }
    }

    // Consolidator with `ConsolidationConfig::default()` (decay_rate 0.01, threshold 800, ...).
    fn default_consolidator() -> MemoryConsolidator {
        MemoryConsolidator::with_defaults()
    }

    // ---- Decay ----

    #[test]
    fn test_decay_zero_days_no_change() {
        let c = default_consolidator();
        let result = c.apply_decay(0.9, 0.0);
        assert!(
            (result - 0.9).abs() < f64::EPSILON,
            "0 days should produce no decay"
        );
    }

    #[test]
    fn test_decay_30_days_significant() {
        let c = default_consolidator();
        let result = c.apply_decay(1.0, 30.0);
        // With the default 1%/day rate: 0.99^30 ≈ 0.7397.
        assert!(
            result < 0.75,
            "30 days should produce significant decay, got {result}"
        );
        assert!(
            result > 0.70,
            "30 days decay should not be too aggressive, got {result}"
        );
    }

    #[test]
    fn test_decay_does_not_go_below_zero() {
        let c = default_consolidator();
        let result = c.apply_decay(0.01, 100_000.0);
        // Very large ages underflow toward zero; the result must still be non-negative.
        assert!(result >= 0.0, "decayed confidence must never be negative");
    }

    // ---- Consolidation trigger (strictly greater than the threshold) ----

    #[test]
    fn test_should_consolidate_triggers_at_threshold() {
        let c = default_consolidator();
        assert!(c.should_consolidate(801), "should trigger above threshold");
        assert!(
            c.should_consolidate(1000),
            "should trigger well above threshold"
        );
    }

    #[test]
    fn test_should_consolidate_false_below_threshold() {
        let c = default_consolidator();
        assert!(
            !c.should_consolidate(800),
            "should not trigger at exactly threshold"
        );
        assert!(
            !c.should_consolidate(500),
            "should not trigger below threshold"
        );
        assert!(
            !c.should_consolidate(0),
            "should not trigger with no memories"
        );
    }

    // ---- Key similarity ----

    #[test]
    fn test_keys_identical_match() {
        assert!(
            MemoryConsolidator::keys_are_similar("user_preference", "user_preference"),
            "identical keys must match"
        );
    }

    #[test]
    fn test_keys_very_different_no_match() {
        assert!(
            !MemoryConsolidator::keys_are_similar("user_preference", "system_config_debug_level"),
            "very different keys must not match"
        );
    }

    #[test]
    fn test_keys_similar_match() {
        // One inserted char over 16 chars: similarity 0.9375, above the 0.8 cutoff.
        assert!(
            MemoryConsolidator::keys_are_similar("user_preference", "user_preferences"),
            "similar keys (singular vs plural) should match"
        );
    }

    // ---- Value merging ----

    #[test]
    fn test_merge_values_picks_highest_confidence() {
        let values = vec![("low_val", 0.3), ("high_val", 0.9), ("mid_val", 0.6)];
        let (value, _) = MemoryConsolidator::merge_values(&values);
        assert_eq!(
            value, "high_val",
            "should pick the value with highest confidence"
        );
    }

    #[test]
    fn test_merge_values_averages_confidences() {
        let values = vec![("a", 0.3), ("b", 0.9), ("c", 0.6)];
        let (_, avg) = MemoryConsolidator::merge_values(&values);
        let expected = (0.3 + 0.9 + 0.6) / 3.0;
        assert!(
            (avg - expected).abs() < f64::EPSILON,
            "should average confidences: expected {expected}, got {avg}"
        );
    }

    // ---- End-to-end consolidation against an in-memory substrate ----

    #[tokio::test]
    async fn test_full_consolidation_reduces_count() {
        let substrate = MemorySubstrate::in_memory().expect("in-memory substrate");
        let fid = FighterId::new();
        substrate
            .save_fighter(&fid, &test_manifest(), FighterStatus::Idle)
            .await
            .expect("save fighter");

        // Five memories below min_confidence (0.1 < 0.3) and fifteen above it.
        for i in 0..20 {
            let confidence = if i < 5 { 0.1 } else { 0.8 };
            substrate
                .store_memory(
                    &fid,
                    &format!("memory_{i}"),
                    &format!("value_{i}"),
                    confidence,
                )
                .await
                .expect("store memory");
        }

        // decay_rate 0.0 keeps confidences fixed so only pruning/merging change the count.
        let consolidator = MemoryConsolidator::new(ConsolidationConfig {
            max_memories_per_fighter: 100,
            consolidation_threshold: 10,
            min_confidence: 0.3,
            decay_rate: 0.0,
            merge_similarity_threshold: 0.8,
            max_age_days: 90,
        });

        let result = consolidator
            .consolidate(&substrate, &fid)
            .await
            .expect("consolidation");

        assert_eq!(result.memories_before, 20);
        assert!(
            result.memories_after < result.memories_before,
            "consolidation should reduce memory count"
        );
        assert!(
            result.pruned > 0,
            "should have pruned some low-confidence memories"
        );
    }

    #[tokio::test]
    async fn test_pruning_removes_low_confidence() {
        let substrate = MemorySubstrate::in_memory().expect("in-memory substrate");
        let fid = FighterId::new();
        substrate
            .save_fighter(&fid, &test_manifest(), FighterStatus::Idle)
            .await
            .expect("save fighter");

        // One memory on each side of min_confidence (0.3), plus one in between.
        substrate
            .store_memory(&fid, "strong_memory", "important", 0.9)
            .await
            .expect("store");
        substrate
            .store_memory(&fid, "weak_memory", "forgettable", 0.1)
            .await
            .expect("store");
        substrate
            .store_memory(&fid, "medium_memory", "moderate", 0.5)
            .await
            .expect("store");

        let consolidator = MemoryConsolidator::new(ConsolidationConfig {
            min_confidence: 0.3,
            decay_rate: 0.0,
            ..ConsolidationConfig::default()
        });

        let result = consolidator
            .consolidate(&substrate, &fid)
            .await
            .expect("consolidation");

        assert!(result.pruned >= 1, "should prune at least the weak memory");

        // NOTE(review): assumes `recall_memories` matches on the key text; confirm its
        // query semantics.
        let remaining = substrate
            .recall_memories(&fid, "weak_memory", 10)
            .await
            .expect("recall");
        assert!(remaining.is_empty(), "weak memory should be pruned");

        let strong = substrate
            .recall_memories(&fid, "strong_memory", 10)
            .await
            .expect("recall");
        assert_eq!(
            strong.len(),
            1,
            "strong memory should survive consolidation"
        );
    }

    // ---- Config ----

    #[test]
    fn test_config_defaults_are_sensible() {
        let config = ConsolidationConfig::default();
        assert_eq!(config.max_memories_per_fighter, 1000);
        assert_eq!(config.consolidation_threshold, 800);
        assert!((config.min_confidence - 0.3).abs() < f64::EPSILON);
        assert!((config.decay_rate - 0.01).abs() < f64::EPSILON);
        assert!((config.merge_similarity_threshold - 0.8).abs() < f64::EPSILON);
        assert_eq!(config.max_age_days, 90);
        // Consolidation should kick in before the hard cap is reached.
        assert!(
            config.consolidation_threshold < config.max_memories_per_fighter,
            "threshold should be below max"
        );
    }

    // ---- Levenshtein / normalized similarity helpers ----

    #[test]
    fn test_levenshtein_identical() {
        assert_eq!(levenshtein_distance("hello", "hello"), 0);
    }

    #[test]
    fn test_levenshtein_empty() {
        assert_eq!(levenshtein_distance("", "abc"), 3);
        assert_eq!(levenshtein_distance("abc", ""), 3);
    }

    #[test]
    fn test_levenshtein_one_edit() {
        assert_eq!(levenshtein_distance("cat", "cats"), 1);
        assert_eq!(levenshtein_distance("cat", "car"), 1);
    }

    #[test]
    fn test_normalized_similarity_identical() {
        let sim = normalized_similarity("test", "test");
        assert!((sim - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_normalized_similarity_completely_different() {
        let sim = normalized_similarity("abc", "xyz");
        assert!(
            sim < 0.5,
            "completely different strings should have low similarity"
        );
    }

    #[test]
    fn test_levenshtein_both_empty() {
        assert_eq!(levenshtein_distance("", ""), 0);
    }

    #[test]
    fn test_levenshtein_substitutions() {
        assert_eq!(levenshtein_distance("abc", "xyz"), 3);
    }

    #[test]
    fn test_normalized_similarity_both_empty() {
        let sim = normalized_similarity("", "");
        assert!((sim - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_normalized_similarity_one_empty() {
        let sim = normalized_similarity("hello", "");
        assert!((sim - 0.0).abs() < f64::EPSILON);
    }

    // ---- merge_values edge cases ----

    #[test]
    fn test_merge_values_empty() {
        let (value, confidence) = MemoryConsolidator::merge_values(&[]);
        assert!(value.is_empty());
        assert!((confidence - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_merge_values_single() {
        let (value, confidence) = MemoryConsolidator::merge_values(&[("only", 0.5)]);
        assert_eq!(value, "only");
        assert!((confidence - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_decay_custom_rate() {
        let config = ConsolidationConfig {
            decay_rate: 0.1,
            ..ConsolidationConfig::default()
        };
        let c = MemoryConsolidator::new(config);
        let result = c.apply_decay(1.0, 10.0);
        // 0.9^10 ≈ 0.3487.
        assert!(result < 0.4);
        assert!(result > 0.3);
    }

    // ---- Serde round-trips ----

    #[test]
    fn test_config_serde_roundtrip() {
        let config = ConsolidationConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let restored: ConsolidationConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(
            restored.max_memories_per_fighter,
            config.max_memories_per_fighter
        );
        assert_eq!(restored.max_age_days, config.max_age_days);
    }

    #[test]
    fn test_result_serde_roundtrip() {
        let result = ConsolidationResult {
            memories_before: 100,
            memories_after: 80,
            merged: 5,
            pruned: 15,
            decayed: 90,
            duration_ms: 42,
        };
        let json = serde_json::to_string(&result).unwrap();
        let restored: ConsolidationResult = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.memories_before, 100);
        assert_eq!(restored.pruned, 15);
    }

    // Two empty keys are identical, so they must be treated as similar.
    #[test]
    fn test_keys_similar_empty_strings() {
        assert!(MemoryConsolidator::keys_are_similar("", ""));
    }
}