use std::sync::Arc;
use anyhow::Result;
use chrono::Utc;
use uuid::Uuid;
use crate::databases::StorageBackend;
use crate::stores::mental_model_store::{MentalModel, MentalModelStore, ModelType};
use crate::{
EmbeddingProvider, FactStore, LanceDatabase, MessageMetadata, MessageStore, SummaryStore,
TierMetadataStore,
};
/// Seconds per hour, for converting epoch-second deltas to fractional hours.
const SECS_PER_HOUR: f32 = 3600.0;
/// Default multi-factor ranking weights; the three sum to 1.0 so a result
/// with perfect similarity, recency, and importance scores exactly 1.0.
const SIMILARITY_WEIGHT: f32 = 0.50;
const RECENCY_WEIGHT: f32 = 0.30;
const IMPORTANCE_WEIGHT: f32 = 0.20;
/// Default retention windows: 24 hours hot, 168 hours (7 days) warm.
const DEFAULT_HOT_RETENTION_HOURS: u64 = 24;
const DEFAULT_WARM_RETENTION_HOURS: u64 = 168;
// Importance thresholds and capacity caps feed `TieredMemoryConfig::default()`.
// NOTE(review): their enforcement is not visible in this file — presumably the
// maintenance/demotion job consumes them; verify against the caller.
const DEFAULT_HOT_IMPORTANCE_THRESHOLD: f32 = 0.3;
const DEFAULT_WARM_IMPORTANCE_THRESHOLD: f32 = 0.1;
const DEFAULT_MAX_HOT_MESSAGES: usize = 1000;
const DEFAULT_MAX_WARM_SUMMARIES: usize = 5000;
/// Steeper per-hour decay used when a query shows temporal intent
/// (vs. `MultiFactorScore::DECAY_RATE` = 0.01 for the standard curve).
const FAST_DECAY_RATE: f32 = 0.05;
/// Keywords signalling that a query is about recent events; used to boost
/// recency weighting in adaptive search. Single-word entries are matched as
/// whole tokens; multi-word entries ("this week") are matched as phrases.
const TEMPORAL_KEYWORDS: &[&str] = &[
    "recent",
    "recently",
    "latest",
    "last",
    "current",
    "currently",
    "today",
    "yesterday",
    "this week",
    "now",
    "just",
    "new",
    "newest",
];

/// Returns a temporal-intent factor in `[0.0, 1.0]` for `query`.
///
/// Counts distinct `TEMPORAL_KEYWORDS` present in the (lowercased) query and
/// saturates at three hits. Single-word keywords are matched against whole,
/// punctuation-trimmed tokens rather than raw substrings, so "newsletter"
/// does not count as "new", "adjust" does not count as "just", and "recently"
/// no longer double-counts as both "recent" and "recently". Multi-word
/// keywords are matched as substrings of the full query.
fn detect_temporal_query(query: &str) -> f32 {
    let lower = query.to_lowercase();
    // Tokenize once; trim surrounding punctuation so "today?" matches "today".
    let tokens: Vec<&str> = lower
        .split_whitespace()
        .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
        .collect();
    let hits = TEMPORAL_KEYWORDS
        .iter()
        .filter(|&&kw| {
            if kw.contains(' ') {
                // Phrase keyword: substring match over the whole query.
                lower.contains(kw)
            } else {
                // Word keyword: exact token match avoids false positives
                // such as "blast" -> "last".
                tokens.iter().any(|&t| t == kw)
            }
        })
        .count();
    (hits as f32 / 3.0).min(1.0)
}
/// Provenance/trust level attached to a stored memory entry.
///
/// Serialized in snake_case; `Session` is the default for ordinary writes,
/// while `Canonical` writes are gated by [`CanonicalWriteToken`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryAuthority {
    /// Short-lived, lowest-trust entries.
    Ephemeral,
    /// Normal session-scoped entries (the default).
    #[default]
    Session,
    /// Highest-trust, durable entries; writes require a capability token.
    Canonical,
}

impl MemoryAuthority {
    /// Stable string form, matching the serde snake_case encoding.
    pub fn as_str(&self) -> &'static str {
        match self {
            MemoryAuthority::Ephemeral => "ephemeral",
            MemoryAuthority::Session => "session",
            MemoryAuthority::Canonical => "canonical",
        }
    }

    /// Inverse of [`Self::as_str`]; any unrecognized string falls back to
    /// the `Session` default rather than failing.
    pub fn parse(s: &str) -> Self {
        match s {
            "ephemeral" => MemoryAuthority::Ephemeral,
            "canonical" => MemoryAuthority::Canonical,
            _ => MemoryAuthority::Session,
        }
    }
}
/// Capability token required to write `Canonical`-authority memories.
///
/// The private unit field makes the type unconstructible outside this crate,
/// so only crate-internal code can mint permission for canonical writes.
#[derive(Debug)]
pub struct CanonicalWriteToken(());
impl CanonicalWriteToken {
    #[allow(dead_code)]
    pub(crate) fn new() -> Self {
        Self(())
    }
}
/// The storage tier a memory entry currently lives in, ordered from most
/// detailed (`Hot`: full messages) to most distilled (`MentalModel`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryTier {
    /// Full-fidelity recent messages.
    Hot,
    /// Summarized messages.
    Warm,
    /// Distilled key facts.
    Cold,
    /// Synthesized long-term models.
    MentalModel,
}

impl MemoryTier {
    /// The next tier down, or `None` when already at the bottom.
    pub fn demote(&self) -> Option<MemoryTier> {
        use MemoryTier::*;
        match self {
            Hot => Some(Warm),
            Warm => Some(Cold),
            Cold => Some(MentalModel),
            MentalModel => None,
        }
    }

    /// The next tier up, or `None` when already at the top.
    pub fn promote(&self) -> Option<MemoryTier> {
        use MemoryTier::*;
        match self {
            Hot => None,
            Warm => Some(Hot),
            Cold => Some(Warm),
            MentalModel => Some(Cold),
        }
    }
}
/// Per-message bookkeeping used to decide tier placement and demotion order.
#[derive(Debug, Clone)]
pub struct TierMetadata {
    /// Id of the message this metadata tracks.
    pub message_id: String,
    /// Tier the message currently lives in.
    pub tier: MemoryTier,
    /// Caller-assigned importance; weighted into the retention score.
    pub importance: f32,
    /// Unix timestamp (seconds) of the last read access.
    pub last_accessed: i64,
    /// Number of times the message has been accessed.
    pub access_count: u32,
    /// Unix timestamp (seconds) when the entry was created.
    pub created_at: i64,
    /// Provenance/trust level of the entry.
    pub authority: MemoryAuthority,
}
impl TierMetadata {
    /// Creates tracking metadata for a freshly ingested message.
    ///
    /// New entries start in the `Hot` tier with `Session` authority, zero
    /// accesses, and creation/last-access stamped to the current time.
    pub fn new(message_id: String, importance: f32) -> Self {
        let stamp = Utc::now().timestamp();
        Self {
            message_id,
            tier: MemoryTier::Hot,
            importance,
            last_accessed: stamp,
            access_count: 0,
            created_at: stamp,
            authority: MemoryAuthority::Session,
        }
    }

    /// Like [`Self::new`], but with an explicit authority level.
    pub fn with_authority(message_id: String, importance: f32, authority: MemoryAuthority) -> Self {
        Self {
            authority,
            ..Self::new(message_id, importance)
        }
    }

    /// Stamps the entry as accessed right now and bumps its access counter.
    pub fn record_access(&mut self) {
        self.last_accessed = Utc::now().timestamp();
        self.access_count += 1;
    }

    /// Heuristic score used to rank demotion candidates (lower scores are
    /// demoted first).
    ///
    /// Blends stored importance, an exponential decay on hours since last
    /// access, and a log-damped access count, reusing the module's default
    /// ranking weights.
    pub fn retention_score(&self) -> f32 {
        let idle_hours = (Utc::now().timestamp() - self.last_accessed) as f32 / SECS_PER_HOUR;
        let recency = (-0.01 * idle_hours).exp();
        let usage = (self.access_count as f32).ln_1p() * 0.1;
        self.importance * SIMILARITY_WEIGHT + recency * RECENCY_WEIGHT + usage * IMPORTANCE_WEIGHT
    }
}
/// Warm-tier record: a condensed version of one original hot-tier message.
#[derive(Debug, Clone)]
pub struct MessageSummary {
    /// Unique id of this summary.
    pub summary_id: String,
    /// Id of the hot-tier message this summarizes.
    pub original_message_id: String,
    /// Conversation the original message belonged to.
    pub conversation_id: String,
    /// Role of the original message — presumably "user"/"assistant"; confirm
    /// against the message schema.
    pub role: String,
    /// The condensed text.
    pub summary: String,
    /// Entities extracted alongside the summary.
    pub key_entities: Vec<String>,
    /// Unix timestamp (seconds) when the summary was created.
    pub created_at: i64,
}
/// Cold-tier record: a distilled fact, possibly drawn from several messages.
#[derive(Debug, Clone)]
pub struct KeyFact {
    /// Unique id of this fact.
    pub fact_id: String,
    /// Ids of the messages this fact was distilled from.
    pub original_message_ids: Vec<String>,
    /// Conversation the source messages belonged to.
    pub conversation_id: String,
    /// The fact text itself.
    pub fact: String,
    /// Coarse category of the fact.
    pub fact_type: FactType,
    /// Unix timestamp (seconds) when the fact was created.
    pub created_at: i64,
}
/// Coarse category assigned to a cold-tier [`KeyFact`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FactType {
    /// A decision that was made.
    Decision,
    /// A definition of a term or concept.
    Definition,
    /// A stated requirement.
    Requirement,
    /// A change made to code.
    CodeChange,
    /// A configuration detail.
    Configuration,
    /// Anything that fits no other category (also the fallback; see
    /// `TieredMemory::fallback_fact`).
    Other,
}
/// Breakdown of a ranked result's score into its contributing factors,
/// plus the weighted combination used for final ordering.
#[derive(Debug, Clone)]
pub struct MultiFactorScore {
    /// Raw vector-similarity score.
    pub similarity: f32,
    /// Recency factor in `(0, 1]` from an exponential decay curve.
    pub recency: f32,
    /// Stored importance of the entry.
    pub importance: f32,
    /// Weighted sum of the three, clamped into `[0.0, 1.0]`.
    pub combined: f32,
}

impl MultiFactorScore {
    /// Per-hour exponential decay rate for the standard recency curve.
    const DECAY_RATE: f32 = 0.01;

    /// Combines the three factors using the module's default weights.
    pub fn compute(similarity: f32, recency: f32, importance: f32) -> Self {
        Self::compute_with_weights(
            similarity,
            recency,
            importance,
            SIMILARITY_WEIGHT,
            RECENCY_WEIGHT,
            IMPORTANCE_WEIGHT,
        )
    }

    /// Combines the factors with caller-supplied weights, clamping the
    /// weighted sum into `[0.0, 1.0]`.
    pub fn compute_with_weights(
        similarity: f32,
        recency: f32,
        importance: f32,
        sim_w: f32,
        rec_w: f32,
        imp_w: f32,
    ) -> Self {
        let weighted = similarity * sim_w + recency * rec_w + importance * imp_w;
        Self {
            similarity,
            recency,
            importance,
            combined: weighted.clamp(0.0, 1.0),
        }
    }

    /// Standard recency curve: 1.0 when fresh, decaying exponentially with
    /// hours since last access.
    pub fn recency_from_hours(hours_since_access: f32) -> f32 {
        (-Self::DECAY_RATE * hours_since_access).exp()
    }

    /// Steeper decay curve, used when a query shows temporal intent.
    pub fn recency_from_hours_fast(hours_since_access: f32) -> f32 {
        (-FAST_DECAY_RATE * hours_since_access).exp()
    }
}
/// A single hit returned by tiered search, tagged with the tier it came from.
#[derive(Debug, Clone)]
pub struct TieredSearchResult {
    /// The retrieved text (message content, summary, fact, or model text).
    pub content: String,
    /// Raw similarity score from the underlying store.
    pub score: f32,
    /// Which tier produced this hit.
    pub tier: MemoryTier,
    /// Id of the originating message, when one can be traced.
    pub original_message_id: Option<String>,
    /// Full message metadata; only populated for hot-tier hits.
    pub metadata: Option<MessageMetadata>,
    /// Multi-factor re-score; populated by `search_adaptive_multi_factor`.
    pub multi_factor_score: Option<MultiFactorScore>,
}
/// Tunables for tier retention, capacity, and temporal-query behavior.
#[derive(Debug, Clone)]
pub struct TieredMemoryConfig {
    /// Hours a message stays hot before becoming a demotion candidate.
    pub hot_retention_hours: u64,
    /// Hours a summary stays warm before becoming a demotion candidate.
    pub warm_retention_hours: u64,
    /// Minimum importance to remain hot — enforcement not visible in this
    /// file; verify against the maintenance job.
    pub hot_importance_threshold: f32,
    /// Minimum importance to remain warm (same caveat as above).
    pub warm_importance_threshold: f32,
    /// Capacity cap for the hot tier.
    pub max_hot_messages: usize,
    /// Capacity cap for the warm tier.
    pub max_warm_summaries: usize,
    /// When set, new messages get `expires_at = now + ttl` on ingest.
    pub session_ttl_secs: Option<u64>,
    /// Extra recency weight added when a query looks temporal (scaled by the
    /// detected temporal factor).
    pub temporal_boost: f32,
    /// Use the steeper `FAST_DECAY_RATE` recency curve for temporal queries.
    pub fast_decay: bool,
    /// Capacity cap for mental models.
    pub max_mental_models: usize,
}
impl Default for TieredMemoryConfig {
    /// Defaults: 24 h hot / 7-day warm retention, no session TTL, a 0.3
    /// temporal boost, standard decay, and module-level capacity caps.
    fn default() -> Self {
        Self {
            hot_retention_hours: DEFAULT_HOT_RETENTION_HOURS,
            warm_retention_hours: DEFAULT_WARM_RETENTION_HOURS,
            hot_importance_threshold: DEFAULT_HOT_IMPORTANCE_THRESHOLD,
            warm_importance_threshold: DEFAULT_WARM_IMPORTANCE_THRESHOLD,
            max_hot_messages: DEFAULT_MAX_HOT_MESSAGES,
            max_warm_summaries: DEFAULT_MAX_WARM_SUMMARIES,
            session_ttl_secs: None,
            temporal_boost: 0.3,
            fast_decay: false,
            max_mental_models: 500,
        }
    }
}
/// Orchestrates the four memory tiers — hot messages, warm summaries, cold
/// facts, and synthesized mental models — plus per-message tier metadata.
pub struct TieredMemory {
    /// Hot tier: full-fidelity recent messages (shared with other owners).
    pub hot: Arc<MessageStore>,
    /// Warm tier: message summaries.
    warm: SummaryStore,
    /// Cold tier: distilled key facts.
    cold: FactStore,
    /// Per-message tier/importance/access bookkeeping.
    tier_metadata: TierMetadataStore,
    /// Long-term synthesized models built from cold facts.
    mental_model: MentalModelStore,
    config: TieredMemoryConfig,
    // Held for future use; all stores receive their own clone at construction.
    #[allow(dead_code)]
    embeddings: Arc<EmbeddingProvider>,
}
impl TieredMemory {
    /// Builds a tiered memory over an existing hot store and Lance database.
    ///
    /// Warm, cold, mental-model, and tier-metadata stores are all created on
    /// top of the shared `db`, each with its own clone of `embeddings`.
    ///
    /// NOTE(review): declared `async` but performs no awaits — presumably for
    /// API symmetry with other constructors; confirm before simplifying.
    pub async fn new(
        hot_store: Arc<MessageStore>,
        db: Arc<LanceDatabase>,
        embeddings: Arc<EmbeddingProvider>,
        config: TieredMemoryConfig,
    ) -> Self {
        let mental_model = MentalModelStore::new(
            Arc::clone(&db) as Arc<dyn StorageBackend>,
            Arc::clone(&embeddings),
        );
        Self {
            hot: hot_store,
            warm: SummaryStore::new(Arc::clone(&db), Arc::clone(&embeddings)),
            cold: FactStore::new(Arc::clone(&db), Arc::clone(&embeddings)),
            tier_metadata: TierMetadataStore::new(db),
            mental_model,
            config,
            embeddings,
        }
    }

    /// Convenience constructor using [`TieredMemoryConfig::default`].
    pub async fn with_defaults(
        hot_store: Arc<MessageStore>,
        db: Arc<LanceDatabase>,
        embeddings: Arc<EmbeddingProvider>,
    ) -> Self {
        Self::new(hot_store, db, embeddings, TieredMemoryConfig::default()).await
    }

    /// Ingests a message into the hot tier with `Session` authority.
    ///
    /// If `session_ttl_secs` is configured, the message is stamped with an
    /// expiry; tier metadata is written before the message itself.
    pub async fn add_message(
        &mut self,
        mut message: MessageMetadata,
        importance: f32,
    ) -> Result<()> {
        if let Some(ttl_secs) = self.config.session_ttl_secs {
            message.expires_at = Some(Utc::now().timestamp() + ttl_secs as i64);
        }
        let metadata = TierMetadata::new(message.message_id.clone(), importance);
        self.tier_metadata.add(metadata).await?;
        self.hot.add(message).await
    }

    /// Ingests a message with `Canonical` authority.
    ///
    /// Requires a [`CanonicalWriteToken`], which only crate code can mint —
    /// the token is consumed purely as a capability and never inspected.
    /// Note: unlike [`Self::add_message`], no session TTL is applied here.
    pub async fn add_canonical_message(
        &mut self,
        message: MessageMetadata,
        importance: f32,
        _token: CanonicalWriteToken,
    ) -> Result<()> {
        let metadata = TierMetadata::with_authority(
            message.message_id.clone(),
            importance,
            MemoryAuthority::Canonical,
        );
        self.tier_metadata.add(metadata).await?;
        self.hot.add(message).await
    }

    /// Deletes expired hot-tier messages; returns how many were removed.
    pub async fn evict_expired(&self) -> Result<usize> {
        let evicted = self.hot.delete_expired().await?;
        if evicted > 0 {
            tracing::info!(
                evicted,
                "TieredMemory: evicted {} expired message(s)",
                evicted
            );
        }
        Ok(evicted)
    }

    /// Updates last-access time and access count for a message, if tracked.
    /// Silently a no-op when the id has no tier metadata.
    pub async fn record_access(&mut self, message_id: &str) -> Result<()> {
        if let Some(mut meta) = self.tier_metadata.get(message_id).await? {
            meta.record_access();
            self.tier_metadata.update(meta).await?;
        }
        Ok(())
    }

    /// Searches tiers hot-first, drilling down only when matches are weak.
    ///
    /// Hot hits above similarity 0.85 short-circuit the search; warm is
    /// always consulted otherwise; cold is consulted only when every result
    /// so far scores below 0.7. Expired hot hits are skipped, surviving hot
    /// hits get their access recorded (errors ignored as best-effort), and
    /// the merged results are sorted by raw score descending.
    pub async fn search_adaptive(
        &mut self,
        query: &str,
        conversation_id: Option<&str>,
    ) -> Result<Vec<TieredSearchResult>> {
        let mut results = Vec::new();
        let hot_results = if let Some(conv_id) = conversation_id {
            self.hot.search_conversation(conv_id, query, 5, 0.6).await?
        } else {
            self.hot.search(query, 5, 0.6).await?
        };
        for (msg, score) in hot_results {
            // Skip messages whose session TTL has lapsed but which have not
            // been physically evicted yet.
            if let Some(exp) = msg.expires_at
                && exp <= Utc::now().timestamp()
            {
                continue;
            }
            let _ = self.record_access(&msg.message_id).await;
            results.push(TieredSearchResult {
                content: msg.content.clone(),
                score,
                tier: MemoryTier::Hot,
                original_message_id: Some(msg.message_id.clone()),
                metadata: Some(msg),
                multi_factor_score: None,
            });
        }
        // A confident hot hit means deeper tiers are unnecessary.
        if results.iter().any(|r| r.score > 0.85) {
            return Ok(results);
        }
        let warm_results = if let Some(conv_id) = conversation_id {
            self.warm
                .search_conversation(conv_id, query, 3, 0.5)
                .await?
        } else {
            self.warm.search(query, 3, 0.5).await?
        };
        for (summary, score) in warm_results {
            results.push(TieredSearchResult {
                content: summary.summary.clone(),
                score,
                tier: MemoryTier::Warm,
                original_message_id: Some(summary.original_message_id.clone()),
                metadata: None,
                multi_factor_score: None,
            });
        }
        // Only dig into cold facts when nothing so far is a solid match.
        if results.iter().all(|r| r.score < 0.7) {
            let cold_results = if let Some(conv_id) = conversation_id {
                self.cold
                    .search_conversation(conv_id, query, 3, 0.4)
                    .await?
            } else {
                self.cold.search(query, 3, 0.4).await?
            };
            for (fact, score) in cold_results {
                results.push(TieredSearchResult {
                    content: fact.fact.clone(),
                    score,
                    tier: MemoryTier::Cold,
                    original_message_id: fact.original_message_ids.first().cloned(),
                    metadata: None,
                    multi_factor_score: None,
                });
            }
        }
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(results)
    }

    /// Like [`Self::search_adaptive`], but re-ranks results by a weighted
    /// blend of similarity, recency, and importance, and mixes in mental
    /// models.
    ///
    /// When the query reads as temporal, recency weight is boosted by
    /// `temporal_boost * factor` (and optionally a faster decay curve is
    /// used); the remaining weight is split between similarity and importance
    /// in their default 50:20 proportion. Results without tier metadata fall
    /// back to recency 1.0 / importance 0.5, as do mental-model hits.
    pub async fn search_adaptive_multi_factor(
        &mut self,
        query: &str,
        conversation_id: Option<&str>,
    ) -> Result<Vec<TieredSearchResult>> {
        let mut results = self.search_adaptive(query, conversation_id).await?;
        let ids: Vec<&str> = results
            .iter()
            .filter_map(|r| r.original_message_id.as_deref())
            .collect();
        // Best-effort metadata fetch: missing map means defaults below.
        let meta_map = self.tier_metadata.get_many(&ids).await.unwrap_or_default();
        let now_secs = chrono::Utc::now().timestamp();
        let temporal_factor = detect_temporal_query(query);
        let use_fast_decay = self.config.fast_decay && temporal_factor > 0.0;
        let extra_recency = self.config.temporal_boost * temporal_factor;
        let rec_w = (RECENCY_WEIGHT + extra_recency).min(1.0);
        // Redistribute whatever weight recency did not claim.
        let remaining = 1.0 - rec_w;
        let sim_share = SIMILARITY_WEIGHT / (SIMILARITY_WEIGHT + IMPORTANCE_WEIGHT);
        let sim_w = sim_share * remaining;
        let imp_w = remaining - sim_w;
        for result in &mut results {
            let similarity = result.score;
            let (recency, importance) = if let Some(id) = &result.original_message_id {
                if let Some(meta) = meta_map.get(id.as_str()) {
                    let hours_since = (now_secs - meta.last_accessed).max(0) as f32 / 3600.0;
                    let rec = if use_fast_decay {
                        MultiFactorScore::recency_from_hours_fast(hours_since)
                    } else {
                        MultiFactorScore::recency_from_hours(hours_since)
                    };
                    (rec, meta.importance)
                } else {
                    // Untracked id: treat as fresh with neutral importance.
                    (1.0_f32, 0.5_f32)
                }
            } else {
                (1.0_f32, 0.5_f32)
            };
            result.multi_factor_score = Some(MultiFactorScore::compute_with_weights(
                similarity, recency, importance, sim_w, rec_w, imp_w,
            ));
        }
        // Mental-model search failures are non-fatal; results just omit them.
        if let Ok(mm_results) = self.search_mental_models(query, 5).await {
            for mut mm in mm_results {
                mm.multi_factor_score = Some(MultiFactorScore::compute_with_weights(
                    mm.score, 1.0, 0.5, sim_w, rec_w, imp_w,
                ));
                results.push(mm);
            }
        }
        results.sort_by(|a, b| {
            let sa = a
                .multi_factor_score
                .as_ref()
                .map_or(a.score, |s| s.combined);
            let sb = b
                .multi_factor_score
                .as_ref()
                .map_or(b.score, |s| s.combined);
            sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(results)
    }

    /// Moves a message's tier marker to `Warm` and stores its summary.
    /// The hot-tier message itself is not deleted here.
    pub async fn demote_to_warm(
        &mut self,
        message_id: &str,
        summary: MessageSummary,
    ) -> Result<()> {
        if let Some(mut meta) = self.tier_metadata.get(message_id).await? {
            meta.tier = MemoryTier::Warm;
            self.tier_metadata.update(meta).await?;
        }
        self.warm.add(summary).await
    }

    /// Replaces a warm summary with a cold fact.
    ///
    /// NOTE(review): unlike `demote_to_warm`, this does not update the
    /// message's tier metadata to `Cold` — confirm whether the caller handles
    /// that or it is an omission.
    pub async fn demote_to_cold(&mut self, summary_id: &str, fact: KeyFact) -> Result<()> {
        self.warm.delete(summary_id).await?;
        self.cold.add(fact).await
    }

    /// Marks a message's tier metadata as `Hot` and records an access.
    ///
    /// NOTE(review): always returns `Ok(None)` — the original message is not
    /// reconstructed or fetched back into the hot store. Looks like a partial
    /// implementation; verify intended semantics with callers.
    pub async fn promote_to_hot(&mut self, message_id: &str) -> Result<Option<MessageMetadata>> {
        if let Some(mut meta) = self.tier_metadata.get(message_id).await? {
            meta.tier = MemoryTier::Hot;
            meta.record_access();
            self.tier_metadata.update(meta).await?;
        }
        Ok(None)
    }

    /// Returns up to `count` message ids in the given tier with the lowest
    /// retention scores — i.e. the best candidates to demote next.
    pub async fn get_demotion_candidates(
        &self,
        tier: MemoryTier,
        count: usize,
    ) -> Result<Vec<String>> {
        let all_metadata = self.tier_metadata.get_by_tier(tier).await?;
        let mut candidates: Vec<_> = all_metadata
            .into_iter()
            .map(|m| (m.message_id.clone(), m.retention_score()))
            .collect();
        // Ascending by score: weakest retention first.
        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        Ok(candidates
            .into_iter()
            .take(count)
            .map(|(id, _)| id)
            .collect())
    }

    /// Snapshot of per-tier entry counts.
    ///
    /// Hot count comes from tier metadata; warm/cold from their stores; a
    /// mental-model count failure degrades to zero rather than erroring.
    pub async fn get_stats(&self) -> Result<TieredMemoryStats> {
        let hot_count = self.tier_metadata.count_by_tier(MemoryTier::Hot).await?;
        let warm_count = self.warm.count().await?;
        let cold_count = self.cold.count().await?;
        let mental_model_count = self.mental_model.count().await.unwrap_or(0);
        let total_tracked = self.tier_metadata.count().await?;
        Ok(TieredMemoryStats {
            hot_count,
            warm_count,
            cold_count,
            mental_model_count,
            total_tracked,
        })
    }

    /// Cheap non-LLM summary: first 75 words with a "..." marker, or the
    /// full content when it is already short enough.
    pub fn fallback_summarize(&self, content: &str) -> String {
        let words: Vec<&str> = content.split_whitespace().collect();
        if words.len() <= 75 {
            content.to_string()
        } else {
            format!("{}...", words[..75].join(" "))
        }
    }

    /// Cheap non-LLM fact: wraps a summary verbatim as an `Other`-typed fact.
    pub fn fallback_fact(&self, summary: &MessageSummary) -> KeyFact {
        KeyFact {
            fact_id: Uuid::new_v4().to_string(),
            original_message_ids: vec![summary.original_message_id.clone()],
            conversation_id: summary.conversation_id.clone(),
            fact: summary.summary.clone(),
            fact_type: FactType::Other,
            created_at: Utc::now().timestamp(),
        }
    }

    /// Stores a synthesized mental model built from the given facts and
    /// returns its new id. Ensures the backing table exists first.
    pub async fn synthesize_mental_model(
        &mut self,
        fact_ids: &[String],
        model_text: String,
        model_type: ModelType,
        conversation_id: String,
    ) -> Result<String> {
        self.mental_model.ensure_table().await?;
        let mut model =
            MentalModel::new(model_text, model_type, conversation_id, fact_ids.to_vec());
        model.evidence_count = fact_ids.len() as u32;
        let id = model.model_id.clone();
        self.mental_model.add(model).await?;
        Ok(id)
    }

    /// Searches mental models, mapping hits into `TieredSearchResult`s tagged
    /// `MemoryTier::MentalModel`; the first source fact id (if any) is used
    /// as the originating id.
    pub async fn search_mental_models(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<TieredSearchResult>> {
        let raw = self.mental_model.search(query, limit).await?;
        Ok(raw
            .into_iter()
            .map(|(model, score)| TieredSearchResult {
                content: model.model_text.clone(),
                score,
                tier: MemoryTier::MentalModel,
                original_message_id: model.source_fact_ids.first().cloned(),
                metadata: None,
                multi_factor_score: None,
            })
            .collect())
    }
}
/// Per-tier entry counts reported by [`TieredMemory::get_stats`].
#[derive(Debug, Clone)]
pub struct TieredMemoryStats {
    /// Messages whose tier metadata says `Hot`.
    pub hot_count: usize,
    /// Rows in the warm summary store.
    pub warm_count: usize,
    /// Rows in the cold fact store.
    pub cold_count: usize,
    /// Stored mental models (0 if the count failed).
    pub mental_model_count: usize,
    /// Total entries in the tier-metadata store, across all tiers.
    pub total_tracked: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- MultiFactorScore scoring ---

    #[test]
    fn test_multi_factor_score_weights_sum_to_one() {
        let score = MultiFactorScore::compute(1.0, 1.0, 1.0);
        assert!(
            (score.combined - 1.0).abs() < 1e-6,
            "all-one inputs should yield combined=1"
        );
    }

    #[test]
    fn test_multi_factor_score_zero_inputs() {
        let score = MultiFactorScore::compute(0.0, 0.0, 0.0);
        assert_eq!(score.combined, 0.0);
    }

    #[test]
    fn test_recency_factor_fresh_entry() {
        let r = MultiFactorScore::recency_from_hours(0.0);
        assert!((r - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_recency_factor_decays_over_time() {
        let r_now = MultiFactorScore::recency_from_hours(0.0);
        let r_day = MultiFactorScore::recency_from_hours(24.0);
        let r_week = MultiFactorScore::recency_from_hours(168.0);
        assert!(
            r_now > r_day,
            "fresh entry must score higher than 1-day-old"
        );
        assert!(
            r_day > r_week,
            "1-day-old must score higher than 1-week-old"
        );
        assert!(r_week > 0.0, "recency factor must remain positive");
    }

    #[test]
    fn test_high_similarity_low_recency_can_be_beaten_by_balanced_entry() {
        let stale =
            MultiFactorScore::compute(0.95, MultiFactorScore::recency_from_hours(168.0), 0.0);
        let fresh = MultiFactorScore::compute(0.70, MultiFactorScore::recency_from_hours(1.0), 0.9);
        assert!(
            fresh.combined > stale.combined,
            "fresh important entry ({:.3}) should beat stale high-similarity entry ({:.3})",
            fresh.combined,
            stale.combined
        );
    }

    // --- Tier transitions ---

    #[test]
    fn test_tier_demotion() {
        assert_eq!(MemoryTier::Hot.demote(), Some(MemoryTier::Warm));
        assert_eq!(MemoryTier::Warm.demote(), Some(MemoryTier::Cold));
        assert_eq!(MemoryTier::Cold.demote(), Some(MemoryTier::MentalModel));
        assert_eq!(MemoryTier::MentalModel.demote(), None);
    }

    #[test]
    fn test_tier_promotion() {
        assert_eq!(MemoryTier::Hot.promote(), None);
        assert_eq!(MemoryTier::Warm.promote(), Some(MemoryTier::Hot));
        assert_eq!(MemoryTier::Cold.promote(), Some(MemoryTier::Warm));
        assert_eq!(MemoryTier::MentalModel.promote(), Some(MemoryTier::Cold));
    }

    #[test]
    fn test_tier_metadata_retention_score() {
        let mut meta = TierMetadata::new("test-1".to_string(), 0.8);
        let score1 = meta.retention_score();
        assert!(score1 > 0.0);
        meta.record_access();
        let score2 = meta.retention_score();
        // 0.9 slack allows for clock movement between the two score calls.
        assert!(score2 >= score1 * 0.9);
    }

    // --- Configuration ---

    #[test]
    fn test_default_config() {
        let config = TieredMemoryConfig::default();
        assert_eq!(config.hot_retention_hours, 24);
        assert_eq!(config.warm_retention_hours, 168);
        assert!(config.hot_importance_threshold > 0.0);
        assert!(config.session_ttl_secs.is_none());
    }

    #[test]
    fn test_config_with_session_ttl() {
        let config = TieredMemoryConfig {
            session_ttl_secs: Some(3600),
            ..TieredMemoryConfig::default()
        };
        assert_eq!(config.session_ttl_secs, Some(3600));
    }

    // --- MemoryAuthority ---

    #[test]
    fn test_memory_authority_default() {
        assert_eq!(MemoryAuthority::default(), MemoryAuthority::Session);
    }

    #[test]
    fn test_memory_authority_round_trip() {
        for auth in [
            MemoryAuthority::Ephemeral,
            MemoryAuthority::Session,
            MemoryAuthority::Canonical,
        ] {
            assert_eq!(MemoryAuthority::parse(auth.as_str()), auth);
        }
    }

    #[test]
    fn test_memory_authority_unknown_defaults_to_session() {
        assert_eq!(MemoryAuthority::parse("bogus"), MemoryAuthority::Session);
    }

    #[test]
    fn test_tier_metadata_default_authority() {
        let meta = TierMetadata::new("m-1".to_string(), 0.5);
        assert_eq!(meta.authority, MemoryAuthority::Session);
    }

    #[test]
    fn test_tier_metadata_with_authority() {
        let meta = TierMetadata::with_authority("m-2".to_string(), 0.9, MemoryAuthority::Canonical);
        assert_eq!(meta.authority, MemoryAuthority::Canonical);
        assert_eq!(meta.importance, 0.9);
    }

    #[test]
    fn test_canonical_write_token_is_crate_private() {
        let _token = CanonicalWriteToken::new();
    }

    // --- Temporal query detection ---

    #[test]
    fn test_detect_temporal_query_empty() {
        assert_eq!(detect_temporal_query(""), 0.0);
    }

    #[test]
    fn test_detect_temporal_query_no_keywords() {
        assert_eq!(detect_temporal_query("how does authentication work?"), 0.0);
    }

    #[test]
    fn test_detect_temporal_query_single_keyword() {
        let score = detect_temporal_query("what is the latest approach?");
        assert!(score > 0.0, "expected score > 0 for 'latest'");
    }

    #[test]
    fn test_detect_temporal_query_dense() {
        let score = detect_temporal_query("what was the latest change today?");
        assert!(score > 0.0);
    }

    #[test]
    fn test_detect_temporal_query_max_clamp() {
        let score = detect_temporal_query("recent latest current today now new");
        assert!(score <= 1.0, "score must not exceed 1.0");
        assert!(score > 0.0);
    }

    // --- Weighted scoring ---

    #[test]
    fn test_compute_with_weights_sum_normalised() {
        let sim_w = 0.4_f32;
        let rec_w = 0.4_f32;
        let imp_w = 0.2_f32;
        let score = MultiFactorScore::compute_with_weights(0.8, 0.9, 0.6, sim_w, rec_w, imp_w);
        let expected = (0.8 * sim_w + 0.9 * rec_w + 0.6 * imp_w).clamp(0.0, 1.0);
        assert!((score.combined - expected).abs() < 1e-5);
    }

    #[test]
    fn test_compute_with_weights_matches_compute_for_default_weights() {
        let a = MultiFactorScore::compute(0.7, 0.8, 0.5);
        let b = MultiFactorScore::compute_with_weights(
            0.7,
            0.8,
            0.5,
            SIMILARITY_WEIGHT,
            RECENCY_WEIGHT,
            IMPORTANCE_WEIGHT,
        );
        assert!((a.combined - b.combined).abs() < 1e-5);
    }

    #[test]
    fn test_temporal_config_defaults() {
        let cfg = TieredMemoryConfig::default();
        assert_eq!(cfg.temporal_boost, 0.3);
        assert!(!cfg.fast_decay);
    }

    #[test]
    fn test_fast_decay_rate_higher_than_normal() {
        let hours = 48.0_f32;
        let normal = MultiFactorScore::recency_from_hours(hours);
        let fast = MultiFactorScore::recency_from_hours_fast(hours);
        assert!(
            fast < normal,
            "fast decay should produce lower recency for old items"
        );
    }
}