1use super::error::{MemoryError, Result};
11use super::models::*;
12use chrono::Utc;
13use pgvector::Vector;
14use serde::{Deserialize, Serialize};
15use std::time::Instant;
16use tracing::{debug, info, warn};
17use uuid::Uuid;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct SimpleConsolidationConfig {
22 pub base_recall_strength: f64,
24
25 pub migration_threshold: f64,
27
28 pub max_consolidation_strength: f64,
30
31 pub time_scale_factor: f64,
33}
34
35impl Default for SimpleConsolidationConfig {
36 fn default() -> Self {
37 use super::math_engine::constants;
38 Self {
39 base_recall_strength: 0.95,
40 migration_threshold: constants::COLD_MIGRATION_THRESHOLD,
41 max_consolidation_strength: 10.0,
42 time_scale_factor: 0.1, }
44 }
45}
46
/// Outcome of evaluating a single memory with [`SimpleConsolidationEngine`].
#[derive(Debug, Clone)]
pub struct SimpleConsolidationResult {
    /// Consolidation strength after applying the update for this evaluation.
    pub new_consolidation_strength: f64,
    /// Estimated recall probability in `[0, 1]`, computed with the updated strength.
    pub recall_probability: f64,
    /// `true` when `recall_probability` is below the configured migration threshold.
    pub should_migrate: bool,
    /// Wall-clock time spent computing this result, in milliseconds.
    pub calculation_time_ms: u64,
    /// Hours elapsed since the memory was last accessed (or created, if never accessed).
    pub time_since_access_hours: f64,
}
56
57pub struct SimpleConsolidationEngine {
59 config: SimpleConsolidationConfig,
60}
61
62impl SimpleConsolidationEngine {
63 pub fn new(config: SimpleConsolidationConfig) -> Self {
64 Self { config }
65 }
66
67 pub fn calculate_recall_probability(
70 &self,
71 memory: &Memory,
72 cos_similarity: Option<f64>,
73 ) -> Result<f64> {
74 let r = self.config.base_recall_strength;
75 let g = memory.consolidation_strength;
76 let n = memory.recall_count() as f64;
77
78 let current_time = Utc::now();
80 let last_access = memory.last_accessed_at.unwrap_or(memory.created_at);
81 let t = current_time
82 .signed_duration_since(last_access)
83 .num_seconds() as f64
84 / 3600.0;
85
86 let t_normalized = t * self.config.time_scale_factor;
88
89 let base_recall = r * (-g * t_normalized / (1.0 + n)).exp();
91
92 let similarity_factor = cos_similarity.unwrap_or(1.0);
94
95 let recall_probability = base_recall * similarity_factor;
97
98 Ok(recall_probability.max(0.0).min(1.0))
100 }
101
102 pub fn update_consolidation_strength(
105 &self,
106 current_strength: f64,
107 time_since_access_hours: f64,
108 ) -> Result<f64> {
109 let t = time_since_access_hours * self.config.time_scale_factor;
110
111 let exp_neg_t = (-t).exp();
113 let increment = (1.0 - exp_neg_t) / (1.0 + exp_neg_t);
114
115 let new_strength = current_strength + increment;
116
117 Ok(new_strength
119 .max(0.1)
120 .min(self.config.max_consolidation_strength))
121 }
122
123 pub fn process_memory_consolidation(
125 &self,
126 memory: &Memory,
127 cos_similarity: Option<f64>,
128 ) -> Result<SimpleConsolidationResult> {
129 let start_time = Instant::now();
130
131 let current_time = Utc::now();
133 let last_access = memory.last_accessed_at.unwrap_or(memory.created_at);
134 let time_since_access_hours = current_time
135 .signed_duration_since(last_access)
136 .num_seconds() as f64
137 / 3600.0;
138
139 let new_consolidation_strength = self.update_consolidation_strength(
141 memory.consolidation_strength,
142 time_since_access_hours,
143 )?;
144
145 let mut updated_memory = memory.clone();
147 updated_memory.consolidation_strength = new_consolidation_strength;
148
149 let recall_probability =
151 self.calculate_recall_probability(&updated_memory, cos_similarity)?;
152
153 let should_migrate = recall_probability < self.config.migration_threshold;
155
156 let calculation_time = start_time.elapsed().as_millis() as u64;
157
158 Ok(SimpleConsolidationResult {
159 new_consolidation_strength,
160 recall_probability,
161 should_migrate,
162 calculation_time_ms: calculation_time,
163 time_since_access_hours,
164 })
165 }
166
167 pub fn process_batch_consolidation(
170 &self,
171 memories: &[Memory],
172 similarities: Option<&[f64]>,
173 ) -> Result<Vec<SimpleConsolidationResult>> {
174 let start_time = Instant::now();
175 let mut results = Vec::with_capacity(memories.len());
176
177 for (i, memory) in memories.iter().enumerate() {
178 let cos_similarity = similarities.and_then(|sims| sims.get(i)).copied();
179
180 match self.process_memory_consolidation(memory, cos_similarity) {
181 Ok(result) => results.push(result),
182 Err(e) => {
183 warn!(
184 "Failed to process consolidation for memory {}: {}",
185 memory.id, e
186 );
187 }
189 }
190 }
191
192 let total_time = start_time.elapsed();
193 info!(
194 "Processed {} memories in {:.3}s ({:.1} memories/sec)",
195 results.len(),
196 total_time.as_secs_f64(),
197 results.len() as f64 / total_time.as_secs_f64()
198 );
199
200 Ok(results)
201 }
202
203 pub fn calculate_cosine_similarity(&self, vec1: &Vector, vec2: &Vector) -> Result<f64> {
205 let slice1 = vec1.as_slice();
206 let slice2 = vec2.as_slice();
207
208 if slice1.len() != slice2.len() {
209 return Err(MemoryError::InvalidRequest {
210 message: "Vector dimensions must match for similarity calculation".to_string(),
211 });
212 }
213
214 let dot_product: f64 = slice1
215 .iter()
216 .zip(slice2.iter())
217 .map(|(a, b)| (*a as f64) * (*b as f64))
218 .sum();
219
220 let norm1: f64 = slice1
221 .iter()
222 .map(|x| (*x as f64).powi(2))
223 .sum::<f64>()
224 .sqrt();
225 let norm2: f64 = slice2
226 .iter()
227 .map(|x| (*x as f64).powi(2))
228 .sum::<f64>()
229 .sqrt();
230
231 if norm1 == 0.0 || norm2 == 0.0 {
232 return Ok(0.0);
233 }
234
235 Ok(dot_product / (norm1 * norm2))
236 }
237
238 pub fn get_migration_candidates(
240 &self,
241 memories: &[Memory],
242 ) -> Result<Vec<(usize, MemoryTier)>> {
243 let mut candidates = Vec::new();
244
245 for (i, memory) in memories.iter().enumerate() {
246 let result = self.process_memory_consolidation(memory, None)?;
247
248 if result.should_migrate {
249 if let Some(next_tier) = memory.next_tier() {
250 candidates.push((i, next_tier));
251 }
252 }
253 }
254
255 Ok(candidates)
256 }
257}
258
259pub struct ConsolidationProcessor {
261 engine: SimpleConsolidationEngine,
262 batch_size: usize,
263}
264
265impl ConsolidationProcessor {
266 pub fn new(config: SimpleConsolidationConfig, batch_size: usize) -> Self {
267 Self {
268 engine: SimpleConsolidationEngine::new(config),
269 batch_size,
270 }
271 }
272
273 pub async fn process_consolidation_batch(
275 &self,
276 repository: &crate::memory::repository::MemoryRepository,
277 tier: Option<MemoryTier>,
278 ) -> Result<ConsolidationBatchResult> {
279 let start_time = Instant::now();
280
281 let memories = self.get_memories_for_processing(repository, tier).await?;
283
284 if memories.is_empty() {
285 debug!("No memories found for consolidation processing");
286 return Ok(ConsolidationBatchResult::default());
287 }
288
289 debug!("Processing consolidation for {} memories", memories.len());
290
291 let mut processed_count = 0;
293 let mut migration_candidates = Vec::new();
294 let mut consolidation_updates = Vec::new();
295
296 for chunk in memories.chunks(self.batch_size) {
297 let results = self.engine.process_batch_consolidation(chunk, None)?;
298
299 for (memory, result) in chunk.iter().zip(results.iter()) {
300 processed_count += 1;
301
302 if result.should_migrate {
304 if let Some(next_tier) = memory.next_tier() {
305 migration_candidates.push((memory.id, next_tier));
306 }
307 }
308
309 consolidation_updates.push((
311 memory.id,
312 result.new_consolidation_strength,
313 result.recall_probability,
314 ));
315 }
316 }
317
318 let total_time = start_time.elapsed();
319
320 Ok(ConsolidationBatchResult {
321 processed_count,
322 migration_candidates,
323 consolidation_updates,
324 processing_time_ms: total_time.as_millis() as u64,
325 })
326 }
327
328 async fn get_memories_for_processing(
329 &self,
330 repository: &crate::memory::repository::MemoryRepository,
331 tier: Option<MemoryTier>,
332 ) -> Result<Vec<Memory>> {
333 let base_query = "SELECT * FROM memories WHERE status = 'active' AND (last_accessed_at IS NULL OR last_accessed_at < NOW() - INTERVAL '1 hour')";
335
336 let (query, has_tier) = if let Some(_tier) = tier {
337 let query_with_tier = format!(
338 "{base_query} AND tier = $2 ORDER BY last_accessed_at ASC NULLS FIRST LIMIT $1"
339 );
340 (query_with_tier, true)
341 } else {
342 let query_no_tier =
343 format!("{base_query} ORDER BY last_accessed_at ASC NULLS FIRST LIMIT $1");
344 (query_no_tier, false)
345 };
346
347 let mut sqlx_query = sqlx::query_as::<_, Memory>(&query).bind(self.batch_size as i64);
348
349 if has_tier {
350 if let Some(tier) = tier {
351 sqlx_query = sqlx_query.bind(tier);
352 }
353 }
354
355 let memories = sqlx_query.fetch_all(repository.pool()).await?;
356 Ok(memories)
357 }
358}
359
360#[derive(Debug, Clone, Default)]
362pub struct ConsolidationBatchResult {
363 pub processed_count: usize,
364 pub migration_candidates: Vec<(Uuid, MemoryTier)>,
365 pub consolidation_updates: Vec<(Uuid, f64, f64)>, pub processing_time_ms: u64,
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;

    /// A memory last accessed two hours ago with moderate consolidation.
    fn create_test_memory() -> Memory {
        let mut memory = Memory::default();
        memory.consolidation_strength = 2.0;
        memory.access_count = 5;
        memory.last_accessed_at = Some(Utc::now() - Duration::hours(2));
        memory.importance_score = 0.7;
        memory
    }

    fn default_engine() -> SimpleConsolidationEngine {
        SimpleConsolidationEngine::new(SimpleConsolidationConfig::default())
    }

    #[test]
    fn test_recall_probability_calculation() {
        let engine = default_engine();
        let memory = create_test_memory();

        let recall_prob = engine
            .calculate_recall_probability(&memory, Some(0.8))
            .unwrap();

        assert!((0.0..=1.0).contains(&recall_prob));
        assert!(recall_prob > 0.0);
    }

    #[test]
    fn test_recall_probability_scales_with_similarity() {
        let engine = default_engine();
        let memory = create_test_memory();

        let full = engine.calculate_recall_probability(&memory, None).unwrap();
        let partial = engine
            .calculate_recall_probability(&memory, Some(0.8))
            .unwrap();
        let none = engine
            .calculate_recall_probability(&memory, Some(0.0))
            .unwrap();
        let negative = engine
            .calculate_recall_probability(&memory, Some(-0.5))
            .unwrap();

        assert!(partial < full);
        assert_eq!(none, 0.0);
        // Negative similarity must be clamped to the valid probability range.
        assert_eq!(negative, 0.0);
    }

    #[test]
    fn test_consolidation_strength_update() {
        let engine = default_engine();

        let new_strength = engine.update_consolidation_strength(2.0, 1.0).unwrap();

        assert!(new_strength > 2.0);
        assert!(new_strength <= 10.0);
    }

    #[test]
    fn test_consolidation_strength_is_clamped() {
        let engine = default_engine();

        // Upper bound: max_consolidation_strength (10.0 by default).
        assert_eq!(engine.update_consolidation_strength(10.0, 1000.0).unwrap(), 10.0);
        // Lower bound: 0.1, reached when there is no increment.
        assert_eq!(engine.update_consolidation_strength(0.0, 0.0).unwrap(), 0.1);
    }

    #[test]
    fn test_migration_threshold() {
        let mut config = SimpleConsolidationConfig::default();
        config.migration_threshold = 0.5;

        let engine = SimpleConsolidationEngine::new(config);
        let memory = create_test_memory();

        let result = engine.process_memory_consolidation(&memory, None).unwrap();

        assert_eq!(result.should_migrate, result.recall_probability < 0.5);
    }

    #[test]
    fn test_batch_processing() {
        let engine = default_engine();
        let memories: Vec<Memory> = (0..100).map(|_| create_test_memory()).collect();

        // No wall-clock assertion here: timing bounds are flaky in debug builds
        // and on loaded CI machines. Benchmark with a dedicated harness instead.
        let results = engine.process_batch_consolidation(&memories, None).unwrap();

        assert_eq!(results.len(), memories.len());
        for result in &results {
            assert!((0.0..=1.0).contains(&result.recall_probability));
            assert!(result.new_consolidation_strength <= 10.0);
            assert!(result.new_consolidation_strength >= 0.1);
        }
    }

    #[test]
    fn test_cosine_similarity() {
        let engine = default_engine();
        let a = Vector::from(vec![1.0_f32, 2.0, 3.0]);
        let parallel = Vector::from(vec![2.0_f32, 4.0, 6.0]);
        let zero = Vector::from(vec![0.0_f32; 3]);
        let short = Vector::from(vec![1.0_f32]);
        let x = Vector::from(vec![1.0_f32, 0.0]);
        let y = Vector::from(vec![0.0_f32, 1.0]);

        let same_direction = engine.calculate_cosine_similarity(&a, &parallel).unwrap();
        assert!((same_direction - 1.0).abs() < 1e-9);
        assert_eq!(engine.calculate_cosine_similarity(&x, &y).unwrap(), 0.0);
        assert_eq!(engine.calculate_cosine_similarity(&a, &zero).unwrap(), 0.0);
        assert!(engine.calculate_cosine_similarity(&a, &short).is_err());
    }
}