// codex_memory/memory/simple_consolidation.rs

//! Simple Memory Consolidation Engine
//!
//! This module implements the simplified consolidation formulas from Story 2:
//! - P(recall) = r × exp(-g × t / (1 + n)) × cos_similarity
//! - Consolidation update: g_n = g_{n-1} + (1 - e^-t)/(1 + e^-t)
//!
//! This is intentionally simpler than the complex cognitive consolidation
//! implementation and focuses on fast, efficient batch processing.
9
10use super::error::{MemoryError, Result};
11use super::models::*;
12use chrono::Utc;
13use pgvector::Vector;
14use serde::{Deserialize, Serialize};
15use std::time::Instant;
16use tracing::{debug, info, warn};
17use uuid::Uuid;
18
/// Configuration for the simple consolidation engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimpleConsolidationConfig {
    /// Base recall strength (the `r` parameter in the recall formula)
    pub base_recall_strength: f64,

    /// Recall probabilities below this threshold mark a memory for migration
    pub migration_threshold: f64,

    /// Upper bound applied when updating consolidation strength
    pub max_consolidation_strength: f64,

    /// Scale factor applied to elapsed hours before the decay formulas
    /// (smaller values slow down time-based decay)
    pub time_scale_factor: f64,
}
34
impl Default for SimpleConsolidationConfig {
    fn default() -> Self {
        // Threshold constant is shared with the main math engine so both
        // consolidation paths agree on when a memory should go cold.
        use super::math_engine::constants;
        Self {
            base_recall_strength: 0.95,
            migration_threshold: constants::COLD_MIGRATION_THRESHOLD,
            max_consolidation_strength: 10.0,
            time_scale_factor: 0.1, // Much slower time decay
        }
    }
}
46
/// Result of a simple consolidation calculation for one memory.
#[derive(Debug, Clone)]
pub struct SimpleConsolidationResult {
    /// Updated strength g_n after applying the consolidation formula
    pub new_consolidation_strength: f64,
    /// P(recall) computed with the updated strength, clamped to [0, 1]
    pub recall_probability: f64,
    /// True when `recall_probability` fell below the migration threshold
    pub should_migrate: bool,
    /// Wall-clock time spent on this calculation
    pub calculation_time_ms: u64,
    /// Hours since last access (creation time when never accessed)
    pub time_since_access_hours: f64,
}
56
/// Simple consolidation engine for fast batch processing.
///
/// Stateless apart from its configuration; all methods take the memories
/// to process as arguments.
pub struct SimpleConsolidationEngine {
    config: SimpleConsolidationConfig,
}
61
62impl SimpleConsolidationEngine {
63    pub fn new(config: SimpleConsolidationConfig) -> Self {
64        Self { config }
65    }
66
67    /// Calculate recall probability using Story 2 formula:
68    /// P(recall) = r × exp(-g × t / (1 + n)) × cos_similarity
69    pub fn calculate_recall_probability(
70        &self,
71        memory: &Memory,
72        cos_similarity: Option<f64>,
73    ) -> Result<f64> {
74        let r = self.config.base_recall_strength;
75        let g = memory.consolidation_strength;
76        let n = memory.recall_count() as f64;
77
78        // Calculate time since last access in hours
79        let current_time = Utc::now();
80        let last_access = memory.last_accessed_at.unwrap_or(memory.created_at);
81        let t = current_time
82            .signed_duration_since(last_access)
83            .num_seconds() as f64
84            / 3600.0;
85
86        // Normalize time with scaling factor
87        let t_normalized = t * self.config.time_scale_factor;
88
89        // Base recall calculation: r × exp(-g × t / (1 + n))
90        let base_recall = r * (-g * t_normalized / (1.0 + n)).exp();
91
92        // Apply cosine similarity if available (Story 2 formula: direct multiplication)
93        let similarity_factor = cos_similarity.unwrap_or(1.0);
94
95        // Story 2 formula: P(recall) = r × exp(-g × t / (1 + n)) × cos_similarity
96        let recall_probability = base_recall * similarity_factor;
97
98        // Ensure bounds [0, 1]
99        Ok(recall_probability.max(0.0).min(1.0))
100    }
101
102    /// Update consolidation strength using Story 2 formula:
103    /// gn = gn-1 + (1 - e^-t)/(1 + e^-t)
104    pub fn update_consolidation_strength(
105        &self,
106        current_strength: f64,
107        time_since_access_hours: f64,
108    ) -> Result<f64> {
109        let t = time_since_access_hours * self.config.time_scale_factor;
110
111        // Calculate strength increment: (1 - e^-t)/(1 + e^-t)
112        let exp_neg_t = (-t).exp();
113        let increment = (1.0 - exp_neg_t) / (1.0 + exp_neg_t);
114
115        let new_strength = current_strength + increment;
116
117        // Apply bounds
118        Ok(new_strength
119            .max(0.1)
120            .min(self.config.max_consolidation_strength))
121    }
122
123    /// Process consolidation for a single memory
124    pub fn process_memory_consolidation(
125        &self,
126        memory: &Memory,
127        cos_similarity: Option<f64>,
128    ) -> Result<SimpleConsolidationResult> {
129        let start_time = Instant::now();
130
131        // Calculate time since last access
132        let current_time = Utc::now();
133        let last_access = memory.last_accessed_at.unwrap_or(memory.created_at);
134        let time_since_access_hours = current_time
135            .signed_duration_since(last_access)
136            .num_seconds() as f64
137            / 3600.0;
138
139        // Update consolidation strength
140        let new_consolidation_strength = self.update_consolidation_strength(
141            memory.consolidation_strength,
142            time_since_access_hours,
143        )?;
144
145        // Create temporary memory with updated strength for recall calculation
146        let mut updated_memory = memory.clone();
147        updated_memory.consolidation_strength = new_consolidation_strength;
148
149        // Calculate recall probability
150        let recall_probability =
151            self.calculate_recall_probability(&updated_memory, cos_similarity)?;
152
153        // Check if migration is needed
154        let should_migrate = recall_probability < self.config.migration_threshold;
155
156        let calculation_time = start_time.elapsed().as_millis() as u64;
157
158        Ok(SimpleConsolidationResult {
159            new_consolidation_strength,
160            recall_probability,
161            should_migrate,
162            calculation_time_ms: calculation_time,
163            time_since_access_hours,
164        })
165    }
166
167    /// Process consolidation for a batch of memories
168    /// Target: Process 1000 memories in < 1 second
169    pub fn process_batch_consolidation(
170        &self,
171        memories: &[Memory],
172        similarities: Option<&[f64]>,
173    ) -> Result<Vec<SimpleConsolidationResult>> {
174        let start_time = Instant::now();
175        let mut results = Vec::with_capacity(memories.len());
176
177        for (i, memory) in memories.iter().enumerate() {
178            let cos_similarity = similarities.and_then(|sims| sims.get(i)).copied();
179
180            match self.process_memory_consolidation(memory, cos_similarity) {
181                Ok(result) => results.push(result),
182                Err(e) => {
183                    warn!(
184                        "Failed to process consolidation for memory {}: {}",
185                        memory.id, e
186                    );
187                    // Continue processing other memories
188                }
189            }
190        }
191
192        let total_time = start_time.elapsed();
193        info!(
194            "Processed {} memories in {:.3}s ({:.1} memories/sec)",
195            results.len(),
196            total_time.as_secs_f64(),
197            results.len() as f64 / total_time.as_secs_f64()
198        );
199
200        Ok(results)
201    }
202
203    /// Calculate cosine similarity between two vectors (helper function)
204    pub fn calculate_cosine_similarity(&self, vec1: &Vector, vec2: &Vector) -> Result<f64> {
205        let slice1 = vec1.as_slice();
206        let slice2 = vec2.as_slice();
207
208        if slice1.len() != slice2.len() {
209            return Err(MemoryError::InvalidRequest {
210                message: "Vector dimensions must match for similarity calculation".to_string(),
211            });
212        }
213
214        let dot_product: f64 = slice1
215            .iter()
216            .zip(slice2.iter())
217            .map(|(a, b)| (*a as f64) * (*b as f64))
218            .sum();
219
220        let norm1: f64 = slice1
221            .iter()
222            .map(|x| (*x as f64).powi(2))
223            .sum::<f64>()
224            .sqrt();
225        let norm2: f64 = slice2
226            .iter()
227            .map(|x| (*x as f64).powi(2))
228            .sum::<f64>()
229            .sqrt();
230
231        if norm1 == 0.0 || norm2 == 0.0 {
232            return Ok(0.0);
233        }
234
235        Ok(dot_product / (norm1 * norm2))
236    }
237
238    /// Get migration candidates based on recall probability threshold
239    pub fn get_migration_candidates(
240        &self,
241        memories: &[Memory],
242    ) -> Result<Vec<(usize, MemoryTier)>> {
243        let mut candidates = Vec::new();
244
245        for (i, memory) in memories.iter().enumerate() {
246            let result = self.process_memory_consolidation(memory, None)?;
247
248            if result.should_migrate {
249                if let Some(next_tier) = memory.next_tier() {
250                    candidates.push((i, next_tier));
251                }
252            }
253        }
254
255        Ok(candidates)
256    }
257}
258
/// Background consolidation processor for efficient batch processing.
///
/// Wraps a `SimpleConsolidationEngine` and fetches stale memories from the
/// repository in batches of `batch_size`.
pub struct ConsolidationProcessor {
    engine: SimpleConsolidationEngine,
    batch_size: usize,
}
264
265impl ConsolidationProcessor {
266    pub fn new(config: SimpleConsolidationConfig, batch_size: usize) -> Self {
267        Self {
268            engine: SimpleConsolidationEngine::new(config),
269            batch_size,
270        }
271    }
272
273    /// Process memories in batches to maintain performance targets
274    pub async fn process_consolidation_batch(
275        &self,
276        repository: &crate::memory::repository::MemoryRepository,
277        tier: Option<MemoryTier>,
278    ) -> Result<ConsolidationBatchResult> {
279        let start_time = Instant::now();
280
281        // Get memories to process
282        let memories = self.get_memories_for_processing(repository, tier).await?;
283
284        if memories.is_empty() {
285            debug!("No memories found for consolidation processing");
286            return Ok(ConsolidationBatchResult::default());
287        }
288
289        debug!("Processing consolidation for {} memories", memories.len());
290
291        // Process in batches
292        let mut processed_count = 0;
293        let mut migration_candidates = Vec::new();
294        let mut consolidation_updates = Vec::new();
295
296        for chunk in memories.chunks(self.batch_size) {
297            let results = self.engine.process_batch_consolidation(chunk, None)?;
298
299            for (memory, result) in chunk.iter().zip(results.iter()) {
300                processed_count += 1;
301
302                // Collect migration candidates
303                if result.should_migrate {
304                    if let Some(next_tier) = memory.next_tier() {
305                        migration_candidates.push((memory.id, next_tier));
306                    }
307                }
308
309                // Collect consolidation updates
310                consolidation_updates.push((
311                    memory.id,
312                    result.new_consolidation_strength,
313                    result.recall_probability,
314                ));
315            }
316        }
317
318        let total_time = start_time.elapsed();
319
320        Ok(ConsolidationBatchResult {
321            processed_count,
322            migration_candidates,
323            consolidation_updates,
324            processing_time_ms: total_time.as_millis() as u64,
325        })
326    }
327
328    async fn get_memories_for_processing(
329        &self,
330        repository: &crate::memory::repository::MemoryRepository,
331        tier: Option<MemoryTier>,
332    ) -> Result<Vec<Memory>> {
333        // Use parameterized query to prevent SQL injection
334        let base_query = "SELECT * FROM memories WHERE status = 'active' AND (last_accessed_at IS NULL OR last_accessed_at < NOW() - INTERVAL '1 hour')";
335
336        let (query, has_tier) = if let Some(_tier) = tier {
337            let query_with_tier = format!(
338                "{base_query} AND tier = $2 ORDER BY last_accessed_at ASC NULLS FIRST LIMIT $1"
339            );
340            (query_with_tier, true)
341        } else {
342            let query_no_tier =
343                format!("{base_query} ORDER BY last_accessed_at ASC NULLS FIRST LIMIT $1");
344            (query_no_tier, false)
345        };
346
347        let mut sqlx_query = sqlx::query_as::<_, Memory>(&query).bind(self.batch_size as i64);
348
349        if has_tier {
350            if let Some(tier) = tier {
351                sqlx_query = sqlx_query.bind(tier);
352            }
353        }
354
355        let memories = sqlx_query.fetch_all(repository.pool()).await?;
356        Ok(memories)
357    }
358}
359
/// Result of batch consolidation processing.
#[derive(Debug, Clone, Default)]
pub struct ConsolidationBatchResult {
    /// Number of memories successfully processed
    pub processed_count: usize,
    /// Memories that should move to a colder tier: (memory id, target tier)
    pub migration_candidates: Vec<(Uuid, MemoryTier)>,
    /// Pending database writes: (id, new_strength, recall_prob)
    pub consolidation_updates: Vec<(Uuid, f64, f64)>,
    /// Total wall-clock processing time for the batch
    pub processing_time_ms: u64,
}
368
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;

    /// A memory accessed 2 hours ago with moderate strength and history.
    /// Uses struct-update syntax instead of default-then-reassign
    /// (clippy::field_reassign_with_default).
    fn create_test_memory() -> Memory {
        Memory {
            consolidation_strength: 2.0,
            access_count: 5,
            last_accessed_at: Some(Utc::now() - Duration::hours(2)),
            importance_score: 0.7,
            ..Memory::default()
        }
    }

    #[test]
    fn test_recall_probability_calculation() {
        let engine = SimpleConsolidationEngine::new(SimpleConsolidationConfig::default());
        let memory = create_test_memory();

        let recall_prob = engine
            .calculate_recall_probability(&memory, Some(0.8))
            .unwrap();

        assert!((0.0..=1.0).contains(&recall_prob));
        assert!(recall_prob > 0.0); // Should have some recall probability
    }

    #[test]
    fn test_consolidation_strength_update() {
        let engine = SimpleConsolidationEngine::new(SimpleConsolidationConfig::default());

        let new_strength = engine.update_consolidation_strength(2.0, 1.0).unwrap();

        assert!(new_strength > 2.0); // Should increase
        assert!(new_strength <= 10.0); // Should respect max bound
    }

    #[test]
    fn test_migration_threshold() {
        let config = SimpleConsolidationConfig {
            migration_threshold: 0.5,
            ..SimpleConsolidationConfig::default()
        };

        let engine = SimpleConsolidationEngine::new(config);
        let memory = create_test_memory();

        let result = engine.process_memory_consolidation(&memory, None).unwrap();

        // Migration decision should be based on recall probability vs threshold
        assert_eq!(result.should_migrate, result.recall_probability < 0.5);
    }

    #[test]
    fn test_batch_processing_performance() {
        let engine = SimpleConsolidationEngine::new(SimpleConsolidationConfig::default());

        // Create test batch
        let memories: Vec<Memory> = (0..100).map(|_| create_test_memory()).collect();

        let start = Instant::now();
        let results = engine.process_batch_consolidation(&memories, None).unwrap();
        let duration = start.elapsed();

        assert_eq!(results.len(), 100);
        // Proportional to the 1000-memories-in-1s target; may be tight in
        // unoptimized debug builds.
        assert!(duration.as_millis() < 100);
    }
}