codex_memory/memory/
simple_consolidation.rs

1//! Simple Memory Consolidation Engine
2//!
3//! This module implements the simplified consolidation formulas from Story 2:
4//! - P(recall) = r × exp(-g × t / (1 + n)) × cos_similarity
5//! - Consolidation update: gn = gn-1 + (1 - e^-t)/(1 + e^-t)
6//!
7//! This is intentionally simpler than the complex cognitive consolidation
8//! implementation and focuses on fast, efficient batch processing.
9
10use super::error::{MemoryError, Result};
11use super::models::*;
12use chrono::Utc;
13use pgvector::Vector;
14use serde::{Deserialize, Serialize};
15use std::time::Instant;
16use tracing::{debug, info, warn};
17use uuid::Uuid;
18
/// Configuration for simple consolidation engine
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimpleConsolidationConfig {
    /// Base recall strength (the `r` parameter in
    /// P(recall) = r × exp(-g × t / (1 + n)) × cos_similarity)
    pub base_recall_strength: f64,

    /// Recall-probability threshold: memories whose recall probability falls
    /// below this value become migration candidates
    pub migration_threshold: f64,

    /// Upper bound applied to consolidation strength after each update
    pub max_consolidation_strength: f64,

    /// Weight of the cosine-similarity factor in the recall calculation
    /// (0.0 ignores similarity entirely; 1.0 applies it at full strength)
    pub similarity_weight: f64,

    /// Multiplier converting hours-since-access into normalized time units
    /// (smaller values produce slower decay)
    pub time_scale_factor: f64,
}
37
impl Default for SimpleConsolidationConfig {
    /// Conservative defaults: high base recall, high migration threshold,
    /// and gentle similarity/time weighting.
    fn default() -> Self {
        Self {
            base_recall_strength: 0.95,
            migration_threshold: 0.86,
            max_consolidation_strength: 10.0,
            similarity_weight: 0.1, // similarity nudges recall by at most ~10%
            time_scale_factor: 0.1, // Much slower time decay
        }
    }
}
49
/// Result of simple consolidation calculation
#[derive(Debug, Clone)]
pub struct SimpleConsolidationResult {
    /// Consolidation strength after applying the update formula (bounded)
    pub new_consolidation_strength: f64,
    /// Recall probability evaluated at the new strength, in [0, 1]
    pub recall_probability: f64,
    /// True when recall probability fell below the configured migration threshold
    pub should_migrate: bool,
    /// Wall-clock time spent on this calculation, in milliseconds
    pub calculation_time_ms: u64,
    /// Hours elapsed since the memory was last accessed (or created)
    pub time_since_access_hours: f64,
}
59
/// Simple consolidation engine for fast batch processing
pub struct SimpleConsolidationEngine {
    // Tuning parameters for the recall/consolidation formulas
    config: SimpleConsolidationConfig,
}
64
65impl SimpleConsolidationEngine {
66    pub fn new(config: SimpleConsolidationConfig) -> Self {
67        Self { config }
68    }
69
70    /// Calculate recall probability using Story 2 formula:
71    /// P(recall) = r × exp(-g × t / (1 + n)) × cos_similarity
72    pub fn calculate_recall_probability(
73        &self,
74        memory: &Memory,
75        cos_similarity: Option<f64>,
76    ) -> Result<f64> {
77        let r = self.config.base_recall_strength;
78        let g = memory.consolidation_strength;
79        let n = memory.recall_count() as f64;
80
81        // Calculate time since last access in hours
82        let current_time = Utc::now();
83        let last_access = memory.last_accessed_at.unwrap_or(memory.created_at);
84        let t = current_time
85            .signed_duration_since(last_access)
86            .num_seconds() as f64
87            / 3600.0;
88
89        // Normalize time with scaling factor
90        let t_normalized = t * self.config.time_scale_factor;
91
92        // Base recall calculation: r × exp(-g × t / (1 + n))
93        let base_recall = r * (-g * t_normalized / (1.0 + n)).exp();
94
95        // Apply cosine similarity if available
96        let similarity_factor = cos_similarity.unwrap_or(1.0);
97        let weighted_similarity = 1.0 + (similarity_factor - 1.0) * self.config.similarity_weight;
98
99        let recall_probability = base_recall * weighted_similarity;
100
101        // Ensure bounds [0, 1]
102        Ok(recall_probability.max(0.0).min(1.0))
103    }
104
105    /// Update consolidation strength using Story 2 formula:
106    /// gn = gn-1 + (1 - e^-t)/(1 + e^-t)
107    pub fn update_consolidation_strength(
108        &self,
109        current_strength: f64,
110        time_since_access_hours: f64,
111    ) -> Result<f64> {
112        let t = time_since_access_hours * self.config.time_scale_factor;
113
114        // Calculate strength increment: (1 - e^-t)/(1 + e^-t)
115        let exp_neg_t = (-t).exp();
116        let increment = (1.0 - exp_neg_t) / (1.0 + exp_neg_t);
117
118        let new_strength = current_strength + increment;
119
120        // Apply bounds
121        Ok(new_strength
122            .max(0.1)
123            .min(self.config.max_consolidation_strength))
124    }
125
126    /// Process consolidation for a single memory
127    pub fn process_memory_consolidation(
128        &self,
129        memory: &Memory,
130        cos_similarity: Option<f64>,
131    ) -> Result<SimpleConsolidationResult> {
132        let start_time = Instant::now();
133
134        // Calculate time since last access
135        let current_time = Utc::now();
136        let last_access = memory.last_accessed_at.unwrap_or(memory.created_at);
137        let time_since_access_hours = current_time
138            .signed_duration_since(last_access)
139            .num_seconds() as f64
140            / 3600.0;
141
142        // Update consolidation strength
143        let new_consolidation_strength = self.update_consolidation_strength(
144            memory.consolidation_strength,
145            time_since_access_hours,
146        )?;
147
148        // Create temporary memory with updated strength for recall calculation
149        let mut updated_memory = memory.clone();
150        updated_memory.consolidation_strength = new_consolidation_strength;
151
152        // Calculate recall probability
153        let recall_probability =
154            self.calculate_recall_probability(&updated_memory, cos_similarity)?;
155
156        // Check if migration is needed
157        let should_migrate = recall_probability < self.config.migration_threshold;
158
159        let calculation_time = start_time.elapsed().as_millis() as u64;
160
161        Ok(SimpleConsolidationResult {
162            new_consolidation_strength,
163            recall_probability,
164            should_migrate,
165            calculation_time_ms: calculation_time,
166            time_since_access_hours,
167        })
168    }
169
170    /// Process consolidation for a batch of memories
171    /// Target: Process 1000 memories in < 1 second
172    pub fn process_batch_consolidation(
173        &self,
174        memories: &[Memory],
175        similarities: Option<&[f64]>,
176    ) -> Result<Vec<SimpleConsolidationResult>> {
177        let start_time = Instant::now();
178        let mut results = Vec::with_capacity(memories.len());
179
180        for (i, memory) in memories.iter().enumerate() {
181            let cos_similarity = similarities.and_then(|sims| sims.get(i)).copied();
182
183            match self.process_memory_consolidation(memory, cos_similarity) {
184                Ok(result) => results.push(result),
185                Err(e) => {
186                    warn!(
187                        "Failed to process consolidation for memory {}: {}",
188                        memory.id, e
189                    );
190                    // Continue processing other memories
191                }
192            }
193        }
194
195        let total_time = start_time.elapsed();
196        info!(
197            "Processed {} memories in {:.3}s ({:.1} memories/sec)",
198            results.len(),
199            total_time.as_secs_f64(),
200            results.len() as f64 / total_time.as_secs_f64()
201        );
202
203        Ok(results)
204    }
205
206    /// Calculate cosine similarity between two vectors (helper function)
207    pub fn calculate_cosine_similarity(&self, vec1: &Vector, vec2: &Vector) -> Result<f64> {
208        let slice1 = vec1.as_slice();
209        let slice2 = vec2.as_slice();
210
211        if slice1.len() != slice2.len() {
212            return Err(MemoryError::InvalidRequest {
213                message: "Vector dimensions must match for similarity calculation".to_string(),
214            });
215        }
216
217        let dot_product: f64 = slice1
218            .iter()
219            .zip(slice2.iter())
220            .map(|(a, b)| (*a as f64) * (*b as f64))
221            .sum();
222
223        let norm1: f64 = slice1
224            .iter()
225            .map(|x| (*x as f64).powi(2))
226            .sum::<f64>()
227            .sqrt();
228        let norm2: f64 = slice2
229            .iter()
230            .map(|x| (*x as f64).powi(2))
231            .sum::<f64>()
232            .sqrt();
233
234        if norm1 == 0.0 || norm2 == 0.0 {
235            return Ok(0.0);
236        }
237
238        Ok(dot_product / (norm1 * norm2))
239    }
240
241    /// Get migration candidates based on recall probability threshold
242    pub fn get_migration_candidates(
243        &self,
244        memories: &[Memory],
245    ) -> Result<Vec<(usize, MemoryTier)>> {
246        let mut candidates = Vec::new();
247
248        for (i, memory) in memories.iter().enumerate() {
249            let result = self.process_memory_consolidation(memory, None)?;
250
251            if result.should_migrate {
252                if let Some(next_tier) = memory.next_tier() {
253                    candidates.push((i, next_tier));
254                }
255            }
256        }
257
258        Ok(candidates)
259    }
260}
261
/// Background consolidation processor for efficient batch processing
pub struct ConsolidationProcessor {
    // Engine that performs the per-memory consolidation math
    engine: SimpleConsolidationEngine,
    // Maximum number of memories fetched/processed per batch
    batch_size: usize,
}
267
268impl ConsolidationProcessor {
269    pub fn new(config: SimpleConsolidationConfig, batch_size: usize) -> Self {
270        Self {
271            engine: SimpleConsolidationEngine::new(config),
272            batch_size,
273        }
274    }
275
276    /// Process memories in batches to maintain performance targets
277    pub async fn process_consolidation_batch(
278        &self,
279        repository: &crate::memory::repository::MemoryRepository,
280        tier: Option<MemoryTier>,
281    ) -> Result<ConsolidationBatchResult> {
282        let start_time = Instant::now();
283
284        // Get memories to process
285        let memories = self.get_memories_for_processing(repository, tier).await?;
286
287        if memories.is_empty() {
288            debug!("No memories found for consolidation processing");
289            return Ok(ConsolidationBatchResult::default());
290        }
291
292        debug!("Processing consolidation for {} memories", memories.len());
293
294        // Process in batches
295        let mut processed_count = 0;
296        let mut migration_candidates = Vec::new();
297        let mut consolidation_updates = Vec::new();
298
299        for chunk in memories.chunks(self.batch_size) {
300            let results = self.engine.process_batch_consolidation(chunk, None)?;
301
302            for (memory, result) in chunk.iter().zip(results.iter()) {
303                processed_count += 1;
304
305                // Collect migration candidates
306                if result.should_migrate {
307                    if let Some(next_tier) = memory.next_tier() {
308                        migration_candidates.push((memory.id, next_tier));
309                    }
310                }
311
312                // Collect consolidation updates
313                consolidation_updates.push((
314                    memory.id,
315                    result.new_consolidation_strength,
316                    result.recall_probability,
317                ));
318            }
319        }
320
321        let total_time = start_time.elapsed();
322
323        Ok(ConsolidationBatchResult {
324            processed_count,
325            migration_candidates,
326            consolidation_updates,
327            processing_time_ms: total_time.as_millis() as u64,
328        })
329    }
330
331    async fn get_memories_for_processing(
332        &self,
333        repository: &crate::memory::repository::MemoryRepository,
334        tier: Option<MemoryTier>,
335    ) -> Result<Vec<Memory>> {
336        let tier_filter = if let Some(tier) = tier {
337            format!("AND tier = '{:?}'", tier).to_lowercase()
338        } else {
339            String::new()
340        };
341
342        let query = format!(
343            r#"
344            SELECT * FROM memories 
345            WHERE status = 'active' 
346            AND (last_accessed_at IS NULL OR last_accessed_at < NOW() - INTERVAL '1 hour')
347            {}
348            ORDER BY last_accessed_at ASC NULLS FIRST
349            LIMIT $1
350            "#,
351            tier_filter
352        );
353
354        let memories = sqlx::query_as::<_, Memory>(&query)
355            .bind(self.batch_size as i64)
356            .fetch_all(repository.pool())
357            .await?;
358
359        Ok(memories)
360    }
361}
362
/// Result of batch consolidation processing
#[derive(Debug, Clone, Default)]
pub struct ConsolidationBatchResult {
    /// Number of memories successfully processed (failures are skipped)
    pub processed_count: usize,
    /// Memories that should move to a colder tier: (memory id, target tier)
    pub migration_candidates: Vec<(Uuid, MemoryTier)>,
    /// Per-memory updates to persist: (id, new_strength, recall_prob)
    pub consolidation_updates: Vec<(Uuid, f64, f64)>,
    /// Total wall-clock time for the batch, in milliseconds
    pub processing_time_ms: u64,
}
371
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;

    /// Build a memory fixture last accessed two hours ago with moderate
    /// strength and importance.
    fn create_test_memory() -> Memory {
        let mut fixture = Memory::default();
        fixture.consolidation_strength = 2.0;
        fixture.access_count = 5;
        fixture.last_accessed_at = Some(Utc::now() - Duration::hours(2));
        fixture.importance_score = 0.7;
        fixture
    }

    #[test]
    fn test_recall_probability_calculation() {
        let engine = SimpleConsolidationEngine::new(SimpleConsolidationConfig::default());

        let prob = engine
            .calculate_recall_probability(&create_test_memory(), Some(0.8))
            .unwrap();

        // Probability is clamped to [0, 1] and should be strictly positive here.
        assert!((0.0..=1.0).contains(&prob));
        assert!(prob > 0.0);
    }

    #[test]
    fn test_consolidation_strength_update() {
        let engine = SimpleConsolidationEngine::new(SimpleConsolidationConfig::default());

        let updated = engine.update_consolidation_strength(2.0, 1.0).unwrap();

        // Strength grows with elapsed time but never exceeds the configured cap.
        assert!(updated > 2.0);
        assert!(updated <= 10.0);
    }

    #[test]
    fn test_migration_threshold() {
        let config = SimpleConsolidationConfig {
            migration_threshold: 0.5,
            ..Default::default()
        };
        let engine = SimpleConsolidationEngine::new(config);

        let outcome = engine
            .process_memory_consolidation(&create_test_memory(), None)
            .unwrap();

        // Migration is decided purely by comparing recall against the threshold.
        assert_eq!(outcome.should_migrate, outcome.recall_probability < 0.5);
    }

    #[test]
    fn test_batch_processing_performance() {
        let engine = SimpleConsolidationEngine::new(SimpleConsolidationConfig::default());

        // 100-memory batch should process well inside the performance target.
        let batch: Vec<Memory> = (0..100).map(|_| create_test_memory()).collect();

        let started = Instant::now();
        let outcomes = engine.process_batch_consolidation(&batch, None).unwrap();
        let elapsed = started.elapsed();

        assert_eq!(outcomes.len(), 100);
        assert!(elapsed.as_millis() < 100);
    }
}