1use std::sync::Arc;
20
21use anyhow::Result;
22use chrono::Utc;
23use uuid::Uuid;
24
25use crate::{
26 EmbeddingProvider, FactStore, LanceDatabase, MessageMetadata, MessageStore, SummaryStore,
27 TierMetadataStore,
28};
29
/// Seconds per hour, used to convert Unix-timestamp deltas into hours.
const SECS_PER_HOUR: f32 = 3_600.0;

// Weights of the multi-factor search score. They sum to 1.0, so all-one
// inputs produce a combined score of 1.0. `TierMetadata::retention_score`
// reuses the same three weights.
/// Weight of the similarity component.
const SIMILARITY_WEIGHT: f32 = 0.50;
/// Weight of the recency component.
const RECENCY_WEIGHT: f32 = 0.30;
/// Weight of the importance component.
const IMPORTANCE_WEIGHT: f32 = 0.20;

// Default values for `TieredMemoryConfig`.
/// Default hot-tier retention window in hours (one day).
const DEFAULT_HOT_RETENTION_HOURS: u64 = 24;
/// Default warm-tier retention window in hours (one week).
const DEFAULT_WARM_RETENTION_HOURS: u64 = 168;
/// Default importance threshold for the hot tier.
const DEFAULT_HOT_IMPORTANCE_THRESHOLD: f32 = 0.3;
/// Default importance threshold for the warm tier.
const DEFAULT_WARM_IMPORTANCE_THRESHOLD: f32 = 0.1;
/// Default capacity of the hot tier, in messages.
const DEFAULT_MAX_HOT_MESSAGES: usize = 1000;
/// Default capacity of the warm tier, in summaries.
const DEFAULT_MAX_WARM_SUMMARIES: usize = 5000;
40
41#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
49#[serde(rename_all = "snake_case")]
50#[derive(Default)]
51pub enum MemoryAuthority {
52 Ephemeral,
54 #[default]
56 Session,
57 Canonical,
62}
63
64impl MemoryAuthority {
65 pub fn as_str(&self) -> &'static str {
67 match self {
68 Self::Ephemeral => "ephemeral",
69 Self::Session => "session",
70 Self::Canonical => "canonical",
71 }
72 }
73
74 pub fn parse(s: &str) -> Self {
76 match s {
77 "ephemeral" => Self::Ephemeral,
78 "canonical" => Self::Canonical,
79 _ => Self::Session,
80 }
81 }
82}
83
/// Capability token required by `TieredMemory::add_canonical_message`.
///
/// The private unit field means the value cannot be built with a struct
/// literal outside this module; elsewhere in the crate it comes from
/// [`CanonicalWriteToken::new`].
#[derive(Debug)]
pub struct CanonicalWriteToken(());

impl CanonicalWriteToken {
    /// Mints a token. Crate-private so code outside the crate cannot perform
    /// canonical writes.
    // No non-test caller of `new` is visible in this module, so silence the lint.
    #[allow(dead_code)]
    pub(crate) fn new() -> Self {
        CanonicalWriteToken(())
    }
}
107
/// Storage tier a memory lives in, from most to least detailed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryTier {
    /// Full messages in the hot message store.
    Hot,
    /// Message summaries in the warm summary store.
    Warm,
    /// Key facts in the cold fact store.
    Cold,
}

impl MemoryTier {
    /// The next colder tier, or `None` when already at `Cold`.
    pub fn demote(&self) -> Option<MemoryTier> {
        let colder = match self {
            Self::Hot => Self::Warm,
            Self::Warm => Self::Cold,
            Self::Cold => return None,
        };
        Some(colder)
    }

    /// The next hotter tier, or `None` when already at `Hot`.
    pub fn promote(&self) -> Option<MemoryTier> {
        let hotter = match self {
            Self::Cold => Self::Warm,
            Self::Warm => Self::Hot,
            Self::Hot => return None,
        };
        Some(hotter)
    }
}
140
141#[derive(Debug, Clone)]
143pub struct TierMetadata {
144 pub message_id: String,
146 pub tier: MemoryTier,
148 pub importance: f32,
150 pub last_accessed: i64,
152 pub access_count: u32,
154 pub created_at: i64,
156 pub authority: MemoryAuthority,
162}
163
164impl TierMetadata {
165 pub fn new(message_id: String, importance: f32) -> Self {
167 let now = Utc::now().timestamp();
168 Self {
169 message_id,
170 tier: MemoryTier::Hot,
171 importance,
172 last_accessed: now,
173 access_count: 0,
174 created_at: now,
175 authority: MemoryAuthority::Session,
176 }
177 }
178
179 pub fn with_authority(message_id: String, importance: f32, authority: MemoryAuthority) -> Self {
181 Self {
182 authority,
183 ..Self::new(message_id, importance)
184 }
185 }
186
187 pub fn record_access(&mut self) {
189 self.last_accessed = Utc::now().timestamp();
190 self.access_count += 1;
191 }
192
193 pub fn retention_score(&self) -> f32 {
195 let age_hours = (Utc::now().timestamp() - self.last_accessed) as f32 / SECS_PER_HOUR;
196 let recency_factor = (-0.01 * age_hours).exp(); let access_factor = (self.access_count as f32).ln_1p() * 0.1; self.importance * SIMILARITY_WEIGHT
200 + recency_factor * RECENCY_WEIGHT
201 + access_factor * IMPORTANCE_WEIGHT
202 }
203}
204
/// Warm-tier representation of a message: a condensed summary of its text.
#[derive(Clone, Debug)]
pub struct MessageSummary {
    /// Identifier of this summary in the warm store.
    pub summary_id: String,
    /// Id of the hot-tier message the summary was derived from.
    pub original_message_id: String,
    /// Conversation the original message belongs to.
    pub conversation_id: String,
    /// Role recorded for the original message's author.
    pub role: String,
    /// Summary text; returned as the `content` of warm-tier search hits.
    pub summary: String,
    /// Key entities associated with the summary.
    pub key_entities: Vec<String>,
    /// Creation timestamp (presumably Unix seconds, like the other timestamps here — confirm).
    pub created_at: i64,
}
223
224#[derive(Debug, Clone)]
226pub struct KeyFact {
227 pub fact_id: String,
229 pub original_message_ids: Vec<String>,
231 pub conversation_id: String,
233 pub fact: String,
235 pub fact_type: FactType,
237 pub created_at: i64,
239}
240
/// Category of a cold-tier [`KeyFact`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FactType {
    /// A decision.
    Decision,
    /// A definition.
    Definition,
    /// A requirement.
    Requirement,
    /// A code change.
    CodeChange,
    /// A configuration detail.
    Configuration,
    /// Anything else; used by `TieredMemory::fallback_fact`.
    Other,
}
257
258#[derive(Debug, Clone)]
262pub struct MultiFactorScore {
263 pub similarity: f32,
265 pub recency: f32,
268 pub importance: f32,
270 pub combined: f32,
272}
273
274impl MultiFactorScore {
275 pub fn compute(similarity: f32, recency: f32, importance: f32) -> Self {
277 let combined = similarity * SIMILARITY_WEIGHT
278 + recency * RECENCY_WEIGHT
279 + importance * IMPORTANCE_WEIGHT;
280 Self {
281 similarity,
282 recency,
283 importance,
284 combined,
285 }
286 }
287
288 const DECAY_RATE: f32 = 0.01;
290
291 pub fn recency_from_hours(hours_since_access: f32) -> f32 {
293 (-Self::DECAY_RATE * hours_since_access).exp()
294 }
295}
296
297#[derive(Debug, Clone)]
299pub struct TieredSearchResult {
300 pub content: String,
302 pub score: f32,
304 pub tier: MemoryTier,
306 pub original_message_id: Option<String>,
308 pub metadata: Option<MessageMetadata>,
310 pub multi_factor_score: Option<MultiFactorScore>,
314}
315
316#[derive(Debug, Clone)]
318pub struct TieredMemoryConfig {
319 pub hot_retention_hours: u64,
321 pub warm_retention_hours: u64,
323 pub hot_importance_threshold: f32,
325 pub warm_importance_threshold: f32,
327 pub max_hot_messages: usize,
329 pub max_warm_summaries: usize,
331 pub session_ttl_secs: Option<u64>,
341}
342
343impl Default for TieredMemoryConfig {
344 fn default() -> Self {
345 Self {
346 hot_retention_hours: DEFAULT_HOT_RETENTION_HOURS,
347 warm_retention_hours: DEFAULT_WARM_RETENTION_HOURS,
348 hot_importance_threshold: DEFAULT_HOT_IMPORTANCE_THRESHOLD,
349 warm_importance_threshold: DEFAULT_WARM_IMPORTANCE_THRESHOLD,
350 max_hot_messages: DEFAULT_MAX_HOT_MESSAGES,
351 max_warm_summaries: DEFAULT_MAX_WARM_SUMMARIES,
352 session_ttl_secs: None,
353 }
354 }
355}
356
357pub struct TieredMemory {
359 pub hot: Arc<MessageStore>,
361
362 warm: SummaryStore,
364
365 cold: FactStore,
367
368 tier_metadata: TierMetadataStore,
370
371 config: TieredMemoryConfig,
373
374 #[allow(dead_code)]
376 embeddings: Arc<EmbeddingProvider>,
377}
378
379impl TieredMemory {
380 pub async fn new(
382 hot_store: Arc<MessageStore>,
383 db: Arc<LanceDatabase>,
384 embeddings: Arc<EmbeddingProvider>,
385 config: TieredMemoryConfig,
386 ) -> Self {
387 Self {
388 hot: hot_store,
389 warm: SummaryStore::new(Arc::clone(&db), Arc::clone(&embeddings)),
390 cold: FactStore::new(Arc::clone(&db), Arc::clone(&embeddings)),
391 tier_metadata: TierMetadataStore::new(db),
392 config,
393 embeddings,
394 }
395 }
396
397 pub async fn with_defaults(
399 hot_store: Arc<MessageStore>,
400 db: Arc<LanceDatabase>,
401 embeddings: Arc<EmbeddingProvider>,
402 ) -> Self {
403 Self::new(hot_store, db, embeddings, TieredMemoryConfig::default()).await
404 }
405
406 pub async fn add_message(
412 &mut self,
413 mut message: MessageMetadata,
414 importance: f32,
415 ) -> Result<()> {
416 if let Some(ttl_secs) = self.config.session_ttl_secs {
418 message.expires_at = Some(Utc::now().timestamp() + ttl_secs as i64);
419 }
420 let metadata = TierMetadata::new(message.message_id.clone(), importance);
421 self.tier_metadata.add(metadata).await?;
422 self.hot.add(message).await
423 }
424
425 pub async fn add_canonical_message(
431 &mut self,
432 message: MessageMetadata,
433 importance: f32,
434 _token: CanonicalWriteToken,
435 ) -> Result<()> {
436 let metadata = TierMetadata::with_authority(
438 message.message_id.clone(),
439 importance,
440 MemoryAuthority::Canonical,
441 );
442 self.tier_metadata.add(metadata).await?;
443 self.hot.add(message).await
444 }
445
446 pub async fn evict_expired(&self) -> Result<usize> {
454 let evicted = self.hot.delete_expired().await?;
455 if evicted > 0 {
456 tracing::info!(
457 evicted,
458 "TieredMemory: evicted {} expired message(s)",
459 evicted
460 );
461 }
462 Ok(evicted)
463 }
464
465 pub async fn record_access(&mut self, message_id: &str) -> Result<()> {
467 if let Some(mut meta) = self.tier_metadata.get(message_id).await? {
468 meta.record_access();
469 self.tier_metadata.update(meta).await?;
470 }
471 Ok(())
472 }
473
474 pub async fn search_adaptive(
476 &mut self,
477 query: &str,
478 conversation_id: Option<&str>,
479 ) -> Result<Vec<TieredSearchResult>> {
480 let mut results = Vec::new();
481
482 let hot_results = if let Some(conv_id) = conversation_id {
484 self.hot.search_conversation(conv_id, query, 5, 0.6).await?
485 } else {
486 self.hot.search(query, 5, 0.6).await?
487 };
488
489 for (msg, score) in hot_results {
490 if let Some(exp) = msg.expires_at
492 && exp <= Utc::now().timestamp()
493 {
494 continue;
495 }
496
497 let _ = self.record_access(&msg.message_id).await;
499
500 results.push(TieredSearchResult {
501 content: msg.content.clone(),
502 score,
503 tier: MemoryTier::Hot,
504 original_message_id: Some(msg.message_id.clone()),
505 metadata: Some(msg),
506 multi_factor_score: None,
507 });
508 }
509
510 if results.iter().any(|r| r.score > 0.85) {
512 return Ok(results);
513 }
514
515 let warm_results = if let Some(conv_id) = conversation_id {
517 self.warm
518 .search_conversation(conv_id, query, 3, 0.5)
519 .await?
520 } else {
521 self.warm.search(query, 3, 0.5).await?
522 };
523
524 for (summary, score) in warm_results {
525 results.push(TieredSearchResult {
526 content: summary.summary.clone(),
527 score,
528 tier: MemoryTier::Warm,
529 original_message_id: Some(summary.original_message_id.clone()),
530 metadata: None,
531 multi_factor_score: None,
532 });
533 }
534
535 if results.iter().all(|r| r.score < 0.7) {
537 let cold_results = if let Some(conv_id) = conversation_id {
538 self.cold
539 .search_conversation(conv_id, query, 3, 0.4)
540 .await?
541 } else {
542 self.cold.search(query, 3, 0.4).await?
543 };
544
545 for (fact, score) in cold_results {
546 results.push(TieredSearchResult {
547 content: fact.fact.clone(),
548 score,
549 tier: MemoryTier::Cold,
550 original_message_id: fact.original_message_ids.first().cloned(),
551 metadata: None,
552 multi_factor_score: None,
553 });
554 }
555 }
556
557 results.sort_by(|a, b| {
559 b.score
560 .partial_cmp(&a.score)
561 .unwrap_or(std::cmp::Ordering::Equal)
562 });
563
564 Ok(results)
565 }
566
567 pub async fn search_adaptive_multi_factor(
576 &mut self,
577 query: &str,
578 conversation_id: Option<&str>,
579 ) -> Result<Vec<TieredSearchResult>> {
580 let mut results = self.search_adaptive(query, conversation_id).await?;
582
583 let ids: Vec<&str> = results
585 .iter()
586 .filter_map(|r| r.original_message_id.as_deref())
587 .collect();
588
589 let meta_map = self.tier_metadata.get_many(&ids).await.unwrap_or_default();
590
591 let now_secs = chrono::Utc::now().timestamp();
592
593 for result in &mut results {
594 let similarity = result.score;
595
596 let (recency, importance) = if let Some(id) = &result.original_message_id {
597 if let Some(meta) = meta_map.get(id.as_str()) {
598 let hours_since = (now_secs - meta.last_accessed).max(0) as f32 / 3600.0;
599 (
600 MultiFactorScore::recency_from_hours(hours_since),
601 meta.importance,
602 )
603 } else {
604 (1.0_f32, 0.5_f32) }
606 } else {
607 (1.0_f32, 0.5_f32)
608 };
609
610 result.multi_factor_score =
611 Some(MultiFactorScore::compute(similarity, recency, importance));
612 }
613
614 results.sort_by(|a, b| {
616 let sa = a
617 .multi_factor_score
618 .as_ref()
619 .map_or(a.score, |s| s.combined);
620 let sb = b
621 .multi_factor_score
622 .as_ref()
623 .map_or(b.score, |s| s.combined);
624 sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
625 });
626
627 Ok(results)
628 }
629
630 pub async fn demote_to_warm(
632 &mut self,
633 message_id: &str,
634 summary: MessageSummary,
635 ) -> Result<()> {
636 if let Some(mut meta) = self.tier_metadata.get(message_id).await? {
638 meta.tier = MemoryTier::Warm;
639 self.tier_metadata.update(meta).await?;
640 }
641
642 self.warm.add(summary).await
644 }
645
646 pub async fn demote_to_cold(&mut self, summary_id: &str, fact: KeyFact) -> Result<()> {
648 self.warm.delete(summary_id).await?;
650
651 self.cold.add(fact).await
653 }
654
655 pub async fn promote_to_hot(&mut self, message_id: &str) -> Result<Option<MessageMetadata>> {
657 if let Some(mut meta) = self.tier_metadata.get(message_id).await? {
659 meta.tier = MemoryTier::Hot;
660 meta.record_access();
661 self.tier_metadata.update(meta).await?;
662 }
663
664 Ok(None)
667 }
668
669 pub async fn get_demotion_candidates(
671 &self,
672 tier: MemoryTier,
673 count: usize,
674 ) -> Result<Vec<String>> {
675 let all_metadata = self.tier_metadata.get_by_tier(tier).await?;
676
677 let mut candidates: Vec<_> = all_metadata
678 .into_iter()
679 .map(|m| (m.message_id.clone(), m.retention_score()))
680 .collect();
681
682 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
684
685 Ok(candidates
686 .into_iter()
687 .take(count)
688 .map(|(id, _)| id)
689 .collect())
690 }
691
692 pub async fn get_stats(&self) -> Result<TieredMemoryStats> {
694 let hot_count = self.tier_metadata.count_by_tier(MemoryTier::Hot).await?;
695 let warm_count = self.warm.count().await?;
696 let cold_count = self.cold.count().await?;
697 let total_tracked = self.tier_metadata.count().await?;
698
699 Ok(TieredMemoryStats {
700 hot_count,
701 warm_count,
702 cold_count,
703 total_tracked,
704 })
705 }
706
707 pub fn fallback_summarize(&self, content: &str) -> String {
709 let words: Vec<&str> = content.split_whitespace().collect();
710 if words.len() <= 75 {
711 content.to_string()
712 } else {
713 format!("{}...", words[..75].join(" "))
714 }
715 }
716
717 pub fn fallback_fact(&self, summary: &MessageSummary) -> KeyFact {
719 KeyFact {
720 fact_id: Uuid::new_v4().to_string(),
721 original_message_ids: vec![summary.original_message_id.clone()],
722 conversation_id: summary.conversation_id.clone(),
723 fact: summary.summary.clone(),
724 fact_type: FactType::Other,
725 created_at: Utc::now().timestamp(),
726 }
727 }
728}
729
/// Snapshot of per-tier counts returned by `TieredMemory::get_stats`.
#[derive(Clone, Debug)]
pub struct TieredMemoryStats {
    /// Tracked messages whose tier metadata says `Hot`.
    pub hot_count: usize,
    /// Summaries currently held by the warm store.
    pub warm_count: usize,
    /// Facts currently held by the cold store.
    pub cold_count: usize,
    /// Total tier-metadata records, across all tiers.
    pub total_tracked: usize,
}
742
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for the floating-point comparisons below.
    const EPSILON: f32 = 1e-6;

    /// Returns whether `actual` is within [`EPSILON`] of `expected`.
    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPSILON
    }

    // ---- MultiFactorScore ----------------------------------------------

    /// The three weights must sum to one.
    #[test]
    fn test_multi_factor_score_weights_sum_to_one() {
        let combined = MultiFactorScore::compute(1.0, 1.0, 1.0).combined;
        assert!(
            approx_eq(combined, 1.0),
            "all-one inputs should yield combined=1"
        );
    }

    #[test]
    fn test_multi_factor_score_zero_inputs() {
        assert_eq!(MultiFactorScore::compute(0.0, 0.0, 0.0).combined, 0.0);
    }

    /// An entry accessed zero hours ago has full recency.
    #[test]
    fn test_recency_factor_fresh_entry() {
        assert!(approx_eq(MultiFactorScore::recency_from_hours(0.0), 1.0));
    }

    #[test]
    fn test_recency_factor_decays_over_time() {
        let [now, day, week] = [0.0, 24.0, 168.0].map(MultiFactorScore::recency_from_hours);
        assert!(now > day, "fresh entry must score higher than 1-day-old");
        assert!(day > week, "1-day-old must score higher than 1-week-old");
        assert!(week > 0.0, "recency factor must remain positive");
    }

    /// Recency and importance can outweigh a similarity advantage.
    #[test]
    fn test_high_similarity_low_recency_can_be_beaten_by_balanced_entry() {
        let week_old = MultiFactorScore::recency_from_hours(168.0);
        let hour_old = MultiFactorScore::recency_from_hours(1.0);
        let stale = MultiFactorScore::compute(0.95, week_old, 0.0);
        let fresh = MultiFactorScore::compute(0.70, hour_old, 0.9);
        assert!(
            fresh.combined > stale.combined,
            "fresh important entry ({:.3}) should beat stale high-similarity entry ({:.3})",
            fresh.combined,
            stale.combined
        );
    }

    // ---- MemoryTier ----------------------------------------------------

    #[test]
    fn test_tier_demotion() {
        let expectations = [
            (MemoryTier::Hot, Some(MemoryTier::Warm)),
            (MemoryTier::Warm, Some(MemoryTier::Cold)),
            (MemoryTier::Cold, None),
        ];
        for (tier, expected) in expectations {
            assert_eq!(tier.demote(), expected);
        }
    }

    #[test]
    fn test_tier_promotion() {
        let expectations = [
            (MemoryTier::Hot, None),
            (MemoryTier::Warm, Some(MemoryTier::Hot)),
            (MemoryTier::Cold, Some(MemoryTier::Warm)),
        ];
        for (tier, expected) in expectations {
            assert_eq!(tier.promote(), expected);
        }
    }

    // ---- TierMetadata --------------------------------------------------

    #[test]
    fn test_tier_metadata_retention_score() {
        let mut meta = TierMetadata::new("test-1".to_string(), 0.8);

        let before = meta.retention_score();
        assert!(before > 0.0);

        meta.record_access();
        let after = meta.retention_score();
        // Recording an access must not noticeably lower the score.
        assert!(after >= before * 0.9);
    }

    #[test]
    fn test_tier_metadata_default_authority() {
        let meta = TierMetadata::new("m-1".to_string(), 0.5);
        assert_eq!(meta.authority, MemoryAuthority::Session);
    }

    #[test]
    fn test_tier_metadata_with_authority() {
        let meta = TierMetadata::with_authority("m-2".to_string(), 0.9, MemoryAuthority::Canonical);
        assert_eq!(meta.authority, MemoryAuthority::Canonical);
        assert_eq!(meta.importance, 0.9);
    }

    // ---- TieredMemoryConfig --------------------------------------------

    #[test]
    fn test_default_config() {
        let config = TieredMemoryConfig::default();
        assert_eq!(config.hot_retention_hours, 24);
        assert_eq!(config.warm_retention_hours, 168);
        assert!(config.hot_importance_threshold > 0.0);
        assert!(config.session_ttl_secs.is_none());
    }

    #[test]
    fn test_config_with_session_ttl() {
        let config = TieredMemoryConfig {
            session_ttl_secs: Some(3600),
            ..TieredMemoryConfig::default()
        };
        assert_eq!(config.session_ttl_secs, Some(3600));
    }

    // ---- MemoryAuthority / CanonicalWriteToken -------------------------

    #[test]
    fn test_memory_authority_default() {
        assert_eq!(MemoryAuthority::default(), MemoryAuthority::Session);
    }

    #[test]
    fn test_memory_authority_round_trip() {
        [
            MemoryAuthority::Ephemeral,
            MemoryAuthority::Session,
            MemoryAuthority::Canonical,
        ]
        .into_iter()
        .for_each(|auth| assert_eq!(MemoryAuthority::parse(auth.as_str()), auth));
    }

    #[test]
    fn test_memory_authority_unknown_defaults_to_session() {
        assert_eq!(MemoryAuthority::parse("bogus"), MemoryAuthority::Session);
    }

    /// `new` is `pub(crate)`, so this compiles only because the test lives inside the crate.
    #[test]
    fn test_canonical_write_token_is_crate_private() {
        let _token = CanonicalWriteToken::new();
    }
}