1use super::error::{MemoryError, Result};
11use super::models::*;
12use chrono::Utc;
13use pgvector::Vector;
14use serde::{Deserialize, Serialize};
15use std::time::Instant;
16use tracing::{debug, info, warn};
17use uuid::Uuid;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct SimpleConsolidationConfig {
22 pub base_recall_strength: f64,
24
25 pub migration_threshold: f64,
27
28 pub max_consolidation_strength: f64,
30
31 pub similarity_weight: f64,
33
34 pub time_scale_factor: f64,
36}
37
38impl Default for SimpleConsolidationConfig {
39 fn default() -> Self {
40 Self {
41 base_recall_strength: 0.95,
42 migration_threshold: 0.86,
43 max_consolidation_strength: 10.0,
44 similarity_weight: 0.1,
45 time_scale_factor: 0.1, }
47 }
48}
49
/// Outcome of running the consolidation model on a single memory.
#[derive(Debug, Clone)]
pub struct SimpleConsolidationResult {
    /// Consolidation strength after the update step.
    pub new_consolidation_strength: f64,
    /// Recall probability computed with the updated strength.
    pub recall_probability: f64,
    /// `true` when the recall probability is below the configured migration
    /// threshold.
    pub should_migrate: bool,
    /// Wall-clock time spent on the calculation, in whole milliseconds.
    pub calculation_time_ms: u64,
    /// Hours between the memory's last access (or its creation when it was
    /// never accessed) and the moment of calculation.
    pub time_since_access_hours: f64,
}
59
60pub struct SimpleConsolidationEngine {
62 config: SimpleConsolidationConfig,
63}
64
65impl SimpleConsolidationEngine {
66 pub fn new(config: SimpleConsolidationConfig) -> Self {
67 Self { config }
68 }
69
70 pub fn calculate_recall_probability(
73 &self,
74 memory: &Memory,
75 cos_similarity: Option<f64>,
76 ) -> Result<f64> {
77 let r = self.config.base_recall_strength;
78 let g = memory.consolidation_strength;
79 let n = memory.recall_count() as f64;
80
81 let current_time = Utc::now();
83 let last_access = memory.last_accessed_at.unwrap_or(memory.created_at);
84 let t = current_time
85 .signed_duration_since(last_access)
86 .num_seconds() as f64
87 / 3600.0;
88
89 let t_normalized = t * self.config.time_scale_factor;
91
92 let base_recall = r * (-g * t_normalized / (1.0 + n)).exp();
94
95 let similarity_factor = cos_similarity.unwrap_or(1.0);
97 let weighted_similarity = 1.0 + (similarity_factor - 1.0) * self.config.similarity_weight;
98
99 let recall_probability = base_recall * weighted_similarity;
100
101 Ok(recall_probability.max(0.0).min(1.0))
103 }
104
105 pub fn update_consolidation_strength(
108 &self,
109 current_strength: f64,
110 time_since_access_hours: f64,
111 ) -> Result<f64> {
112 let t = time_since_access_hours * self.config.time_scale_factor;
113
114 let exp_neg_t = (-t).exp();
116 let increment = (1.0 - exp_neg_t) / (1.0 + exp_neg_t);
117
118 let new_strength = current_strength + increment;
119
120 Ok(new_strength
122 .max(0.1)
123 .min(self.config.max_consolidation_strength))
124 }
125
126 pub fn process_memory_consolidation(
128 &self,
129 memory: &Memory,
130 cos_similarity: Option<f64>,
131 ) -> Result<SimpleConsolidationResult> {
132 let start_time = Instant::now();
133
134 let current_time = Utc::now();
136 let last_access = memory.last_accessed_at.unwrap_or(memory.created_at);
137 let time_since_access_hours = current_time
138 .signed_duration_since(last_access)
139 .num_seconds() as f64
140 / 3600.0;
141
142 let new_consolidation_strength = self.update_consolidation_strength(
144 memory.consolidation_strength,
145 time_since_access_hours,
146 )?;
147
148 let mut updated_memory = memory.clone();
150 updated_memory.consolidation_strength = new_consolidation_strength;
151
152 let recall_probability =
154 self.calculate_recall_probability(&updated_memory, cos_similarity)?;
155
156 let should_migrate = recall_probability < self.config.migration_threshold;
158
159 let calculation_time = start_time.elapsed().as_millis() as u64;
160
161 Ok(SimpleConsolidationResult {
162 new_consolidation_strength,
163 recall_probability,
164 should_migrate,
165 calculation_time_ms: calculation_time,
166 time_since_access_hours,
167 })
168 }
169
170 pub fn process_batch_consolidation(
173 &self,
174 memories: &[Memory],
175 similarities: Option<&[f64]>,
176 ) -> Result<Vec<SimpleConsolidationResult>> {
177 let start_time = Instant::now();
178 let mut results = Vec::with_capacity(memories.len());
179
180 for (i, memory) in memories.iter().enumerate() {
181 let cos_similarity = similarities.and_then(|sims| sims.get(i)).copied();
182
183 match self.process_memory_consolidation(memory, cos_similarity) {
184 Ok(result) => results.push(result),
185 Err(e) => {
186 warn!(
187 "Failed to process consolidation for memory {}: {}",
188 memory.id, e
189 );
190 }
192 }
193 }
194
195 let total_time = start_time.elapsed();
196 info!(
197 "Processed {} memories in {:.3}s ({:.1} memories/sec)",
198 results.len(),
199 total_time.as_secs_f64(),
200 results.len() as f64 / total_time.as_secs_f64()
201 );
202
203 Ok(results)
204 }
205
206 pub fn calculate_cosine_similarity(&self, vec1: &Vector, vec2: &Vector) -> Result<f64> {
208 let slice1 = vec1.as_slice();
209 let slice2 = vec2.as_slice();
210
211 if slice1.len() != slice2.len() {
212 return Err(MemoryError::InvalidRequest {
213 message: "Vector dimensions must match for similarity calculation".to_string(),
214 });
215 }
216
217 let dot_product: f64 = slice1
218 .iter()
219 .zip(slice2.iter())
220 .map(|(a, b)| (*a as f64) * (*b as f64))
221 .sum();
222
223 let norm1: f64 = slice1
224 .iter()
225 .map(|x| (*x as f64).powi(2))
226 .sum::<f64>()
227 .sqrt();
228 let norm2: f64 = slice2
229 .iter()
230 .map(|x| (*x as f64).powi(2))
231 .sum::<f64>()
232 .sqrt();
233
234 if norm1 == 0.0 || norm2 == 0.0 {
235 return Ok(0.0);
236 }
237
238 Ok(dot_product / (norm1 * norm2))
239 }
240
241 pub fn get_migration_candidates(
243 &self,
244 memories: &[Memory],
245 ) -> Result<Vec<(usize, MemoryTier)>> {
246 let mut candidates = Vec::new();
247
248 for (i, memory) in memories.iter().enumerate() {
249 let result = self.process_memory_consolidation(memory, None)?;
250
251 if result.should_migrate {
252 if let Some(next_tier) = memory.next_tier() {
253 candidates.push((i, next_tier));
254 }
255 }
256 }
257
258 Ok(candidates)
259 }
260}
261
262pub struct ConsolidationProcessor {
264 engine: SimpleConsolidationEngine,
265 batch_size: usize,
266}
267
268impl ConsolidationProcessor {
269 pub fn new(config: SimpleConsolidationConfig, batch_size: usize) -> Self {
270 Self {
271 engine: SimpleConsolidationEngine::new(config),
272 batch_size,
273 }
274 }
275
276 pub async fn process_consolidation_batch(
278 &self,
279 repository: &crate::memory::repository::MemoryRepository,
280 tier: Option<MemoryTier>,
281 ) -> Result<ConsolidationBatchResult> {
282 let start_time = Instant::now();
283
284 let memories = self.get_memories_for_processing(repository, tier).await?;
286
287 if memories.is_empty() {
288 debug!("No memories found for consolidation processing");
289 return Ok(ConsolidationBatchResult::default());
290 }
291
292 debug!("Processing consolidation for {} memories", memories.len());
293
294 let mut processed_count = 0;
296 let mut migration_candidates = Vec::new();
297 let mut consolidation_updates = Vec::new();
298
299 for chunk in memories.chunks(self.batch_size) {
300 let results = self.engine.process_batch_consolidation(chunk, None)?;
301
302 for (memory, result) in chunk.iter().zip(results.iter()) {
303 processed_count += 1;
304
305 if result.should_migrate {
307 if let Some(next_tier) = memory.next_tier() {
308 migration_candidates.push((memory.id, next_tier));
309 }
310 }
311
312 consolidation_updates.push((
314 memory.id,
315 result.new_consolidation_strength,
316 result.recall_probability,
317 ));
318 }
319 }
320
321 let total_time = start_time.elapsed();
322
323 Ok(ConsolidationBatchResult {
324 processed_count,
325 migration_candidates,
326 consolidation_updates,
327 processing_time_ms: total_time.as_millis() as u64,
328 })
329 }
330
331 async fn get_memories_for_processing(
332 &self,
333 repository: &crate::memory::repository::MemoryRepository,
334 tier: Option<MemoryTier>,
335 ) -> Result<Vec<Memory>> {
336 let tier_filter = if let Some(tier) = tier {
337 format!("AND tier = '{:?}'", tier).to_lowercase()
338 } else {
339 String::new()
340 };
341
342 let query = format!(
343 r#"
344 SELECT * FROM memories
345 WHERE status = 'active'
346 AND (last_accessed_at IS NULL OR last_accessed_at < NOW() - INTERVAL '1 hour')
347 {}
348 ORDER BY last_accessed_at ASC NULLS FIRST
349 LIMIT $1
350 "#,
351 tier_filter
352 );
353
354 let memories = sqlx::query_as::<_, Memory>(&query)
355 .bind(self.batch_size as i64)
356 .fetch_all(repository.pool())
357 .await?;
358
359 Ok(memories)
360 }
361}
362
363#[derive(Debug, Clone, Default)]
365pub struct ConsolidationBatchResult {
366 pub processed_count: usize,
367 pub migration_candidates: Vec<(Uuid, MemoryTier)>,
368 pub consolidation_updates: Vec<(Uuid, f64, f64)>, pub processing_time_ms: u64,
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;

    /// Fixture: a default memory last accessed two hours ago, with a
    /// consolidation strength of 2.0, five accesses and importance 0.7.
    fn create_test_memory() -> Memory {
        let mut memory = Memory::default();
        memory.last_accessed_at = Some(Utc::now() - Duration::hours(2));
        memory.access_count = 5;
        memory.importance_score = 0.7;
        memory.consolidation_strength = 2.0;
        memory
    }

    /// Engine configured with the default parameter set.
    fn default_engine() -> SimpleConsolidationEngine {
        SimpleConsolidationEngine::new(SimpleConsolidationConfig::default())
    }

    /// The recall probability is a strictly positive value within [0, 1].
    #[test]
    fn test_recall_probability_calculation() {
        let memory = create_test_memory();

        let recall_prob = default_engine()
            .calculate_recall_probability(&memory, Some(0.8))
            .unwrap();

        assert!((0.0..=1.0).contains(&recall_prob));
        assert!(recall_prob > 0.0);
    }

    /// Updating after one hour raises the strength without exceeding the cap.
    #[test]
    fn test_consolidation_strength_update() {
        let new_strength = default_engine()
            .update_consolidation_strength(2.0, 1.0)
            .unwrap();

        assert!(new_strength > 2.0);
        assert!(new_strength <= 10.0);
    }

    /// The migration flag agrees with the configured threshold.
    #[test]
    fn test_migration_threshold() {
        let config = SimpleConsolidationConfig {
            migration_threshold: 0.5,
            ..SimpleConsolidationConfig::default()
        };
        let engine = SimpleConsolidationEngine::new(config);

        let result = engine
            .process_memory_consolidation(&create_test_memory(), None)
            .unwrap();

        assert_eq!(result.should_migrate, result.recall_probability < 0.5);
    }

    /// One hundred memories are all processed, within 100 ms of wall-clock
    /// time.
    #[test]
    fn test_batch_processing_performance() {
        let engine = default_engine();
        let memories: Vec<Memory> = std::iter::repeat_with(create_test_memory)
            .take(100)
            .collect();

        let timer = Instant::now();
        let results = engine.process_batch_consolidation(&memories, None).unwrap();
        let duration = timer.elapsed();

        assert_eq!(results.len(), 100);
        assert!(duration.as_millis() < 100);
    }
}