openmemory 0.1.1

OpenMemory - Cognitive memory system for AI applications
Documentation
//! Memory decay and reinforcement system
//!
//! Implements time-based decay of memory salience and reinforcement
//! on access. Uses a three-tier system (hot, warm, cold).

use crate::core::db::Database;
use crate::core::error::Result;
use crate::core::types::{DecayStats, Sector};
use crate::utils::now_ms;
use std::sync::Arc;

/// Temperature tier assigned to a memory when decay is applied.
///
/// Each tier maps to its own decay rate via [`MemoryTier::decay_lambda`]:
/// the hotter the tier, the smaller the rate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryTier {
    /// Recently seen and highly active or highly salient.
    Hot,
    /// Either recently seen or still moderately salient.
    Warm,
    /// Neither recent nor salient.
    Cold,
}

impl MemoryTier {
    /// Returns the decay rate constant associated with this tier.
    ///
    /// The values are strictly increasing from `Hot` to `Cold`.
    pub fn decay_lambda(&self) -> f64 {
        const HOT_LAMBDA: f64 = 0.005;
        const WARM_LAMBDA: f64 = 0.02;
        const COLD_LAMBDA: f64 = 0.05;

        match *self {
            Self::Hot => HOT_LAMBDA,
            Self::Warm => WARM_LAMBDA,
            Self::Cold => COLD_LAMBDA,
        }
    }
}

/// Tunable parameters for the decay engine.
///
/// Obtain the baseline values through [`Default::default`].
#[derive(Debug, Clone)]
pub struct DecayConfig {
    /// Base decay lambda.
    ///
    /// NOTE(review): not read by the `DecayEngine` methods in this module,
    /// which derive lambda from the sector and tier instead.
    pub decay_lambda: f64,
    /// Decay ratio per interval.
    ///
    /// NOTE(review): not read by the `DecayEngine` methods in this module.
    pub decay_ratio: f64,
    /// Cold memory threshold.
    ///
    /// NOTE(review): not read by the `DecayEngine` methods in this module.
    pub cold_threshold: f64,
    /// When `false`, `DecayEngine::reinforce` returns without touching the memory.
    pub reinforce_on_query: bool,
    /// Reinforcement strength used when the caller supplies no explicit boost.
    pub alpha_reinforce: f64,
    /// Upper bound applied to salience after decay and after reinforcement.
    pub max_salience: f64,
}

impl Default for DecayConfig {
    /// Builds the baseline configuration: reinforcement enabled, salience
    /// capped at `1.0`.
    fn default() -> Self {
        DecayConfig {
            max_salience: 1.0,
            alpha_reinforce: 0.08,
            reinforce_on_query: true,
            cold_threshold: 0.25,
            decay_ratio: 0.03,
            decay_lambda: 0.02,
        }
    }
}

/// Memory decay engine.
///
/// Pairs a shared database handle with a [`DecayConfig`]; the methods on this
/// type read memories from the database, recompute their salience and write
/// the result back.
pub struct DecayEngine {
    /// Shared handle to the memory database.
    db: Arc<Database>,
    /// Parameters controlling reinforcement and the salience cap.
    config: DecayConfig,
}

impl DecayEngine {
    /// Milliseconds in one day; converts `now_ms()` deltas into fractional days.
    const MS_PER_DAY: f64 = 1000.0 * 60.0 * 60.0 * 24.0;

    /// Create a new decay engine from a database handle and an explicit config.
    pub fn new(db: Arc<Database>, config: DecayConfig) -> Self {
        Self { db, config }
    }

    /// Create a decay engine that uses [`DecayConfig::default`].
    pub fn with_defaults(db: Arc<Database>) -> Self {
        Self::new(db, DecayConfig::default())
    }

    /// Run one decay pass over all memories.
    ///
    /// Memories are fetched in batches. For each one the time since it was
    /// last seen selects a tier, the tier and sector select a lambda, and the
    /// decayed salience is written back when it differs from the stored value
    /// by more than `0.001`. `last_seen_at` is preserved; only the salience
    /// and the update timestamp change.
    ///
    /// Returns counters for the pass together with its wall-clock duration.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the database calls.
    pub fn run_decay(&self) -> Result<DecayStats> {
        let start = std::time::Instant::now();
        let now = now_ms();

        let mut processed = 0;
        let mut decayed = 0;
        let mut compressed = 0;

        // Process memories in batches
        let batch_size = 100;
        let mut offset = 0;

        loop {
            let memories = self.db.get_all_memories(batch_size, offset)?;
            if memories.is_empty() {
                break;
            }

            for mem in &memories {
                processed += 1;

                // Days since last seen, floored at zero: a `last_seen_at` that lies
                // in the future (clock skew, imported data) would otherwise yield a
                // negative age, turning the decay factor into a growth factor.
                let days_since =
                    ((now - mem.last_seen_at) as f64 / Self::MS_PER_DAY).max(0.0);

                // TODO: coactivations field is not yet in DB schema (JS SDK bug too)
                // Default to 0 for now, which means no activity boost
                let coactivations = 0;

                // Determine tier (matches JS: 6 days threshold, coactivations check)
                let tier = self.classify_tier(mem.salience, days_since, coactivations);

                // Apply decay with activity boost
                // NOTE(review): the age is always measured from `last_seen_at`, which
                // this pass does not advance, so consecutive passes re-apply the full
                // age to an already decayed salience. Confirm this compounding is
                // intended.
                let lambda = self.effective_lambda(&mem.primary_sector, tier);
                let new_salience =
                    self.compute_decay(mem.salience, lambda, days_since, coactivations);

                // Skip the write when the change is negligible.
                if (new_salience - mem.salience).abs() > 0.001 {
                    self.db.update_memory_seen(
                        &mem.id,
                        mem.last_seen_at,
                        new_salience,
                        now,
                    )?;
                    decayed += 1;

                    // Cold memories whose salience dropped below 70% of its previous
                    // value are compression candidates. The divisor is floored so a
                    // zero salience cannot divide by zero.
                    let decay_factor = new_salience / mem.salience.max(0.001);
                    if tier == MemoryTier::Cold && decay_factor < 0.7 {
                        // NOTE(review): only the candidate is counted; no compression
                        // is performed yet.
                        compressed += 1;
                        // TODO: Implement vector compression (compress_vector in JS)
                    }
                }
            }

            offset += batch_size;
        }

        // Log maintenance operation
        self.db.log_maintenance("decay", processed as i32)?;

        Ok(DecayStats {
            processed,
            decayed,
            compressed,
            duration_ms: start.elapsed().as_millis() as u64,
        })
    }

    /// Classify memory into temperature tier
    ///
    /// Matches JS implementation:
    /// - Hot: recent (< 6 days) AND (high coactivations > 5 OR salience > 0.7)
    /// - Warm: recent (< 6 days) OR salience > 0.4
    /// - Cold: otherwise
    ///
    /// All comparisons are strict, so exactly 6 days is not recent and a
    /// salience of exactly 0.7 or 0.4 does not pass the respective threshold.
    fn classify_tier(&self, salience: f64, days_since: f64, coactivations: i32) -> MemoryTier {
        let recent = days_since < 6.0;
        let high = coactivations > 5 || salience > 0.7;

        if recent && high {
            MemoryTier::Hot
        } else if recent || salience > 0.4 {
            MemoryTier::Warm
        } else {
            MemoryTier::Cold
        }
    }

    /// Get effective lambda considering sector and tier
    ///
    /// The result is the arithmetic mean of the sector's default lambda and
    /// the tier's lambda.
    fn effective_lambda(&self, sector: &Sector, tier: MemoryTier) -> f64 {
        let sector_lambda = sector.default_decay_lambda();
        let tier_lambda = tier.decay_lambda();

        // Blend sector and tier lambdas
        (sector_lambda + tier_lambda) / 2.0
    }

    /// Compute decayed salience
    ///
    /// Matches JS implementation:
    /// - Activity boost: salience * (1 + ln(1 + coactivations)), clamped to [0, 1]
    /// - Decay factor: exp(-lambda * (days / (boosted + 0.1)))
    /// - Final: boosted * decay_factor, clamped to [0, max_salience]
    ///
    /// Negative `coactivations` are treated as zero.
    fn compute_decay(
        &self,
        initial: f64,
        lambda: f64,
        days: f64,
        coactivations: i32,
    ) -> f64 {
        // Activity boost: higher coactivation count preserves memory longer
        let act = coactivations.max(0) as f64;
        let boosted = (initial * (1.0 + (1.0 + act).ln())).clamp(0.0, 1.0);

        // Decay factor: divide by (salience + 0.1) so higher salience decays slower;
        // the 0.1 offset also keeps the divisor non-zero when salience is 0.
        let f = (-lambda * (days / (boosted + 0.1))).exp();

        (boosted * f).clamp(0.0, self.config.max_salience)
    }

    /// Reinforce a memory (called on access)
    ///
    /// Adds `alpha * (1 - exp(-decay_lambda * days_since_last_seen))` to the
    /// memory's salience, where `alpha` is `boost` when given and
    /// `config.alpha_reinforce` otherwise, then stores the result with
    /// `last_seen_at` set to now. The stored salience is kept within
    /// `[0, config.max_salience]`.
    ///
    /// Does nothing when `config.reinforce_on_query` is `false` or when no
    /// memory with `id` exists.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the database calls.
    pub fn reinforce(&self, id: &str, boost: Option<f64>) -> Result<()> {
        if !self.config.reinforce_on_query {
            return Ok(());
        }

        let mem = match self.db.get_memory(id)? {
            Some(m) => m,
            None => return Ok(()),
        };

        let now = now_ms();
        // Floor the age at zero: with a `last_seen_at` in the future the term below
        // would become negative and an access would *lower* the salience.
        let days_since = ((now - mem.last_seen_at) as f64 / Self::MS_PER_DAY).max(0.0);

        // Calculate reinforcement
        let alpha = boost.unwrap_or(self.config.alpha_reinforce);
        let reinforcement = alpha * (1.0 - (-mem.decay_lambda * days_since).exp());

        // Keep the stored value inside the valid salience range; the lower bound
        // guards against a negative `boost` pushing salience below zero.
        let new_salience = (mem.salience + reinforcement)
            .max(0.0)
            .min(self.config.max_salience);

        self.db.update_memory_seen(id, now, new_salience, now)?;

        Ok(())
    }

    /// Reinforce a memory that was returned by a query.
    ///
    /// Equivalent to calling [`DecayEngine::reinforce`] with the configured
    /// `alpha_reinforce`; no separate hit counter is maintained here.
    pub fn on_query_hit(&self, id: &str) -> Result<()> {
        self.reinforce(id, Some(self.config.alpha_reinforce))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the expected baseline values.
    #[test]
    fn test_decay_config_default() {
        let config = DecayConfig::default();
        assert!((config.decay_lambda - 0.02).abs() < 1e-6);
        assert!((config.max_salience - 1.0).abs() < 1e-6);
    }

    /// Hotter tiers decay more slowly than colder ones.
    #[test]
    fn test_memory_tier_lambda() {
        assert!(MemoryTier::Hot.decay_lambda() < MemoryTier::Warm.decay_lambda());
        assert!(MemoryTier::Warm.decay_lambda() < MemoryTier::Cold.decay_lambda());
    }

    /// Mirror of the engine's decay formula.
    ///
    /// The engine method needs a database-backed instance, so the arithmetic is
    /// replicated here: boost by coactivations, clamp to [0, 1], then apply
    /// `exp(-lambda * days / (boosted + 0.1))` and clamp to `[0, max_salience]`.
    fn decay_formula(
        initial: f64,
        lambda: f64,
        days: f64,
        coactivations: i32,
        max_salience: f64,
    ) -> f64 {
        let act = coactivations.max(0) as f64;
        let boosted = (initial * (1.0 + (1.0 + act).ln())).clamp(0.0, 1.0);
        let f = (-lambda * (days / (boosted + 0.1))).exp();
        (boosted * f).clamp(0.0, max_salience)
    }

    /// Decay shrinks salience over time, leaves it untouched at zero elapsed
    /// days, and never produces a negative value.
    #[test]
    fn test_compute_decay() {
        let initial: f64 = 1.0;
        let lambda: f64 = 0.02;
        let days: f64 = 30.0;

        let decayed = decay_formula(initial, lambda, days, 0, 1.0);
        assert!(decayed < initial);
        assert!(decayed > 0.0);

        // 30 days, lambda = 0.02, salience = 1.0: e^(-0.6 / 1.1) ≈ 0.5796
        assert!((decayed - 0.5796).abs() < 1e-3);

        // No elapsed time means no decay.
        assert!((decay_formula(0.5, lambda, 0.0, 0, 1.0) - 0.5).abs() < 1e-12);

        // Lower salience decays faster than higher salience over the same span.
        let low = decay_formula(0.2, lambda, days, 0, 1.0) / 0.2;
        let high = decay_formula(0.9, lambda, days, 0, 1.0) / 0.9;
        assert!(low < high);
    }

    /// Tier classification, including the strict threshold boundaries.
    #[test]
    fn test_tier_classification() {
        // Hot: recent (< 6 days) AND (coactivations > 5 OR salience > 0.7)
        // recent + high coactivations -> Hot
        assert!(matches_tier(0.5, 3.0, 10, "Hot"));
        // recent + high salience -> Hot
        assert!(matches_tier(0.8, 3.0, 0, "Hot"));

        // Warm: recent (< 6 days) OR salience > 0.4
        // recent but low activity/salience -> Warm
        assert!(matches_tier(0.3, 3.0, 0, "Warm"));
        // old but medium salience -> Warm
        assert!(matches_tier(0.5, 10.0, 0, "Warm"));

        // Cold: old AND low salience
        assert!(matches_tier(0.2, 10.0, 0, "Cold"));

        // Boundaries are strict comparisons.
        // exactly 6 days is no longer recent, even with high salience
        assert!(matches_tier(0.8, 6.0, 0, "Warm"));
        // salience of exactly 0.7 and exactly 5 coactivations are not "high"
        assert!(matches_tier(0.7, 3.0, 5, "Warm"));
        // salience of exactly 0.4 does not keep an old memory warm
        assert!(matches_tier(0.4, 10.0, 0, "Cold"));
    }

    /// Mirror of the engine's tier classification, returning whether the
    /// computed tier name equals `expected`.
    fn matches_tier(salience: f64, days: f64, coactivations: i32, expected: &str) -> bool {
        let recent = days < 6.0;
        let high = coactivations > 5 || salience > 0.7;

        let tier = if recent && high {
            "Hot"
        } else if recent || salience > 0.4 {
            "Warm"
        } else {
            "Cold"
        };

        tier == expected
    }

    /// Activity boost grows with coactivations and saturates at 1.0.
    #[test]
    fn test_activity_boost() {
        // Test activity boost formula: salience * (1 + ln(1 + coactivations))
        let initial: f64 = 0.5;

        // No coactivations: boost = 1 + ln(1) = 1 + 0 = 1
        let boosted_0 = initial * (1.0 + (1.0_f64).ln());
        assert!((boosted_0 - 0.5).abs() < 0.001);

        // 10 coactivations: boost = 1 + ln(11) ≈ 1 + 2.4 = 3.4
        let boosted_10 = (initial * (1.0 + (11.0_f64).ln())).min(1.0);
        assert!(boosted_10 > boosted_0);
        // boosted would be 0.5 * 3.4 = 1.7, clamped to 1.0
        assert!((boosted_10 - 1.0).abs() < 0.001);
    }
}