use crate::core::db::Database;
use crate::core::error::Result;
use crate::core::types::{DecayStats, Sector};
use crate::utils::now_ms;
use std::sync::Arc;
/// Retention tier a memory is placed in during a decay pass.
///
/// Each tier has its own base decay rate, returned by `MemoryTier::decay_lambda`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryTier {
    /// Highest retention tier; lowest base decay rate.
    Hot,
    /// Intermediate retention tier.
    Warm,
    /// Lowest retention tier; highest base decay rate.
    Cold,
}
impl MemoryTier {
pub fn decay_lambda(&self) -> f64 {
match self {
MemoryTier::Hot => 0.005,
MemoryTier::Warm => 0.02,
MemoryTier::Cold => 0.05,
}
}
}
/// Tuning parameters for salience decay and reinforcement.
#[derive(Debug, Clone)]
pub struct DecayConfig {
    /// Base decay rate.
    /// NOTE(review): not read by `DecayEngine` in this file (it derives rates from the
    /// sector and tier) — confirm whether anything else uses it.
    pub decay_lambda: f64,
    /// NOTE(review): not read by `DecayEngine` in this file — purpose unconfirmed.
    pub decay_ratio: f64,
    /// NOTE(review): not read by `DecayEngine` in this file (tier thresholds are
    /// hard-coded there) — purpose unconfirmed.
    pub cold_threshold: f64,
    /// When `false`, `DecayEngine::reinforce` returns immediately without changes.
    pub reinforce_on_query: bool,
    /// Reinforcement strength used when no explicit boost is supplied.
    pub alpha_reinforce: f64,
    /// Upper bound applied to salience after decay or reinforcement.
    pub max_salience: f64,
}
impl Default for DecayConfig {
fn default() -> Self {
Self {
decay_lambda: 0.02,
decay_ratio: 0.03,
cold_threshold: 0.25,
reinforce_on_query: true,
alpha_reinforce: 0.08,
max_salience: 1.0,
}
}
}
/// Applies time-based salience decay to stored memories and reinforces them on access.
pub struct DecayEngine {
    /// Shared database handle used to read and update memories.
    db: Arc<Database>,
    /// Decay and reinforcement parameters this engine was built with.
    config: DecayConfig,
}
impl DecayEngine {
    /// Milliseconds per day; converts millisecond timestamp differences into days.
    const MS_PER_DAY: f64 = 1000.0 * 60.0 * 60.0 * 24.0;

    /// Creates an engine that decays memories stored in `db` according to `config`.
    pub fn new(db: Arc<Database>, config: DecayConfig) -> Self {
        Self { db, config }
    }

    /// Creates an engine using [`DecayConfig::default`].
    pub fn with_defaults(db: Arc<Database>) -> Self {
        Self::new(db, DecayConfig::default())
    }

    /// Runs one decay pass over every stored memory and records it in the maintenance log.
    ///
    /// Memories are read in pages of 100. For each one:
    /// - the days elapsed since `last_seen_at` and its current salience select a
    ///   [`MemoryTier`];
    /// - the tier's rate is averaged with the memory's sector rate;
    /// - the decayed salience is written back (with `last_seen_at` unchanged) when it
    ///   differs from the stored value by more than 0.001.
    ///
    /// The returned [`DecayStats`] reports:
    /// - `processed`: memories visited;
    /// - `decayed`: memories written back;
    /// - `compressed`: written-back `Cold` memories whose salience dropped below 70% of its
    ///   previous value;
    /// - `duration_ms`: the elapsed wall-clock time.
    ///
    /// # Errors
    /// Propagates any error from reading or updating memories or from `log_maintenance`.
    pub fn run_decay(&self) -> Result<DecayStats> {
        let start = std::time::Instant::now();
        let now = now_ms();
        let mut processed = 0;
        let mut decayed = 0;
        let mut compressed = 0;
        let batch_size = 100;
        let mut offset = 0;
        // NOTE(review): offset paging is only stable if `get_all_memories` orders by a column
        // this loop never writes. Salience is written below — confirm the query's ORDER BY.
        loop {
            let memories = self.db.get_all_memories(batch_size, offset)?;
            if memories.is_empty() {
                break;
            }
            for mem in &memories {
                processed += 1;
                // NOTE(review): `last_seen_at` is written back unchanged. Each run therefore
                // decays the already-decayed salience by the *full* time since last seen, so
                // repeated runs compound. Tracking a last-decayed timestamp would avoid this.
                let days_since = Self::days_since(now as f64, mem.last_seen_at as f64);
                // Co-activation counts are not loaded yet. Passing 0 disables both the
                // activity boost and the co-activation route into the Hot tier.
                let coactivations = 0;
                let tier = self.classify_tier(mem.salience, days_since, coactivations);
                let lambda = self.effective_lambda(&mem.primary_sector, tier);
                let new_salience =
                    self.compute_decay(mem.salience, lambda, days_since, coactivations);
                if (new_salience - mem.salience).abs() > 0.001 {
                    self.db
                        .update_memory_seen(&mem.id, mem.last_seen_at, new_salience, now)?;
                    decayed += 1;
                    // Floor the divisor so near-zero salience cannot blow up the ratio.
                    let decay_factor = new_salience / mem.salience.max(0.001);
                    if tier == MemoryTier::Cold && decay_factor < 0.7 {
                        compressed += 1;
                    }
                }
            }
            offset += batch_size;
        }
        self.db.log_maintenance("decay", processed as i32)?;
        Ok(DecayStats {
            processed,
            decayed,
            compressed,
            duration_ms: start.elapsed().as_millis() as u64,
        })
    }

    /// Fractional days from millisecond timestamp `then` to `now`, floored at zero.
    ///
    /// Both timestamps are subtracted as `f64`, so the subtraction cannot underflow even if
    /// the timestamp type is unsigned. A `then` in the future (e.g. clock skew) yields 0
    /// instead of a negative age. A negative age would make decay *raise* salience and
    /// reinforcement *lower* it.
    fn days_since(now: f64, then: f64) -> f64 {
        ((now - then) / Self::MS_PER_DAY).max(0.0)
    }

    /// Places a memory into a [`MemoryTier`].
    ///
    /// "Recent" means last seen less than 6 days ago. "High" means more than 5
    /// co-activations or salience above 0.7. The rules are:
    /// - recent and high → `Hot`;
    /// - otherwise recent, or salience above 0.4 → `Warm`;
    /// - everything else → `Cold`.
    fn classify_tier(&self, salience: f64, days_since: f64, coactivations: i32) -> MemoryTier {
        let recent = days_since < 6.0;
        let high = coactivations > 5 || salience > 0.7;
        if recent && high {
            MemoryTier::Hot
        } else if recent || salience > 0.4 {
            MemoryTier::Warm
        } else {
            MemoryTier::Cold
        }
    }

    /// Decay rate for a memory: the mean of its sector's default rate and its tier's rate.
    fn effective_lambda(&self, sector: &Sector, tier: MemoryTier) -> f64 {
        let sector_lambda = sector.default_decay_lambda();
        let tier_lambda = tier.decay_lambda();
        (sector_lambda + tier_lambda) / 2.0
    }

    /// Applies activity-boosted exponential decay to a salience value.
    ///
    /// Steps:
    /// 1. Boost `initial` by the factor `1 + ln(1 + coactivations)`, treating negative
    ///    counts as 0, and clamp the result to `[0, 1]`.
    /// 2. Multiply by `exp(-lambda * days / (boosted + 0.1))`. Lower salience therefore
    ///    decays faster.
    /// 3. Bound the result to `[0, max_salience]`.
    fn compute_decay(
        &self,
        initial: f64,
        lambda: f64,
        days: f64,
        coactivations: i32,
    ) -> f64 {
        let act = f64::from(coactivations.max(0));
        let boosted = (initial * (1.0 + (1.0 + act).ln())).clamp(0.0, 1.0);
        let f = (-lambda * (days / (boosted + 0.1))).exp();
        // `f64::clamp` panics when `max_salience` is negative or NaN (it is a public config
        // field). `max`/`min` give the same result for valid configs without the panic.
        (boosted * f).max(0.0).min(self.config.max_salience)
    }

    /// Boosts the salience of memory `id` and marks it as seen now.
    ///
    /// Returns `Ok(())` without changes when `reinforce_on_query` is disabled or no memory
    /// with `id` exists.
    ///
    /// The boost is `alpha * (1 - exp(-decay_lambda * days_since))`, where:
    /// - `alpha` is `boost`, or `config.alpha_reinforce` when `boost` is `None`;
    /// - `decay_lambda` is the memory's stored rate.
    ///
    /// Memories seen very recently therefore gain little. The new salience is capped at
    /// `max_salience`.
    ///
    /// # Errors
    /// Propagates any error from loading or updating the memory.
    pub fn reinforce(&self, id: &str, boost: Option<f64>) -> Result<()> {
        if !self.config.reinforce_on_query {
            return Ok(());
        }
        let mem = match self.db.get_memory(id)? {
            Some(m) => m,
            None => return Ok(()),
        };
        let now = now_ms();
        // Floored at zero: a future `last_seen_at` must not produce a negative boost.
        let days_since = Self::days_since(now as f64, mem.last_seen_at as f64);
        let alpha = boost.unwrap_or(self.config.alpha_reinforce);
        let reinforcement = alpha * (1.0 - (-mem.decay_lambda * days_since).exp());
        let new_salience = (mem.salience + reinforcement).min(self.config.max_salience);
        self.db.update_memory_seen(id, now, new_salience, now)?;
        Ok(())
    }

    /// Reinforces memory `id` using the configured `alpha_reinforce`.
    ///
    /// # Errors
    /// Same as [`Self::reinforce`].
    pub fn on_query_hit(&self, id: &str) -> Result<()> {
        self.reinforce(id, Some(self.config.alpha_reinforce))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decay_config_default() {
        let config = DecayConfig::default();
        assert!((config.decay_lambda - 0.02).abs() < 1e-6);
        assert!((config.decay_ratio - 0.03).abs() < 1e-6);
        assert!((config.cold_threshold - 0.25).abs() < 1e-6);
        assert!(config.reinforce_on_query);
        assert!((config.alpha_reinforce - 0.08).abs() < 1e-6);
        assert!((config.max_salience - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_memory_tier_lambda() {
        assert!(MemoryTier::Hot.decay_lambda() < MemoryTier::Warm.decay_lambda());
        assert!(MemoryTier::Warm.decay_lambda() < MemoryTier::Cold.decay_lambda());
    }

    /// Reference check for plain exponential decay, `initial * e^(-lambda * days)`.
    ///
    /// This does not call `DecayEngine::compute_decay`, because building an engine needs
    /// a `Database`. It only pins the expected magnitude: e^(-0.6) ≈ 0.549.
    #[test]
    fn test_compute_decay() {
        let initial: f64 = 1.0;
        let lambda: f64 = 0.02;
        let days: f64 = 30.0;
        let decayed = initial * (-lambda * days).exp();
        assert!(decayed < initial);
        assert!(decayed > 0.0);
        assert!((decayed - 0.55).abs() < 0.1);
    }

    #[test]
    fn test_tier_classification() {
        assert_eq!(classify(0.5, 3.0, 10), MemoryTier::Hot);
        assert_eq!(classify(0.8, 3.0, 0), MemoryTier::Hot);
        assert_eq!(classify(0.3, 3.0, 0), MemoryTier::Warm);
        assert_eq!(classify(0.5, 10.0, 0), MemoryTier::Warm);
        assert_eq!(classify(0.2, 10.0, 0), MemoryTier::Cold);
    }

    /// Copy of the rule in `DecayEngine::classify_tier`, which cannot be called here without
    /// a `Database`. Keep the two in sync.
    fn classify(salience: f64, days: f64, coactivations: i32) -> MemoryTier {
        let recent = days < 6.0;
        let high = coactivations > 5 || salience > 0.7;
        if recent && high {
            MemoryTier::Hot
        } else if recent || salience > 0.4 {
            MemoryTier::Warm
        } else {
            MemoryTier::Cold
        }
    }

    /// Checks the activity boost `initial * (1 + ln(1 + coactivations))`, capped at 1.0.
    #[test]
    fn test_activity_boost() {
        let initial: f64 = 0.5;
        let boosted_0 = initial * (1.0 + (1.0_f64).ln());
        assert!((boosted_0 - 0.5).abs() < 0.001);
        let boosted_10 = (initial * (1.0 + (11.0_f64).ln())).min(1.0);
        assert!(boosted_10 > boosted_0);
        assert!((boosted_10 - 1.0).abs() < 0.001);
    }
}