1use super::error::{MemoryError, Result};
2use super::math_engine::{MathEngine, MemoryParameters};
3use super::models::{Memory, MemoryTier};
4use super::repository::MemoryRepository;
5use crate::config::TierManagerConfig;
6use chrono::{DateTime, Duration, Utc};
7use prometheus::{register_counter, register_gauge, register_histogram, Counter, Gauge, Histogram};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Instant;
12use tokio::sync::RwLock;
13use tokio::time::{interval, sleep, Duration as TokioDuration};
14use tracing::{debug, error, info, warn};
15use uuid::Uuid;
16
17pub struct TierManager {
23 repository: Arc<MemoryRepository>,
24 config: TierManagerConfig,
25 math_engine: MathEngine,
26
27 running: Arc<AtomicBool>,
29 last_scan_time: Arc<RwLock<Option<DateTime<Utc>>>>,
30
31 migrations_completed: Arc<AtomicU64>,
33 migrations_failed: Arc<AtomicU64>,
34 total_scan_time_ms: Arc<AtomicU64>,
35
36 scan_duration_histogram: Histogram,
38 migration_counter: Counter,
39 migration_failure_counter: Counter,
40 memories_per_tier_gauge: Gauge,
41 recall_probability_histogram: Histogram,
42}
43
44#[derive(Debug, Clone)]
45pub struct TierMigrationCandidate {
46 pub memory_id: Uuid,
47 pub current_tier: MemoryTier,
48 pub target_tier: MemoryTier,
49 pub recall_probability: f64,
50 pub migration_reason: String,
51 pub priority_score: f64, }
53
54#[derive(Debug, Clone)]
55pub struct TierMigrationBatch {
56 pub candidates: Vec<TierMigrationCandidate>,
57 pub batch_id: Uuid,
58 pub estimated_duration_ms: u64,
59}
60
61#[derive(Debug, Clone)]
62pub struct TierMigrationResult {
63 pub batch_id: Uuid,
64 pub successful_migrations: Vec<Uuid>,
65 pub failed_migrations: Vec<(Uuid, String)>,
66 pub duration_ms: u64,
67 pub memories_per_second: f64,
68}
69
70#[derive(Debug, Clone)]
71pub struct TierManagerMetrics {
72 pub total_migrations_completed: u64,
73 pub total_migrations_failed: u64,
74 pub last_scan_duration_ms: u64,
75 pub memories_by_tier: HashMap<MemoryTier, u64>,
76 pub average_recall_probability_by_tier: HashMap<MemoryTier, f64>,
77 pub migrations_per_second_recent: f64,
78 pub is_running: bool,
79 pub last_scan_time: Option<DateTime<Utc>>,
80}
81
82impl TierManager {
83 pub fn new(repository: Arc<MemoryRepository>, config: TierManagerConfig) -> Result<Self> {
85 let scan_duration_histogram = register_histogram!(
87 "tier_manager_scan_duration_seconds",
88 "Time taken for tier management scans",
89 vec![0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0]
90 )?;
91
92 let migration_counter = register_counter!(
93 "tier_manager_migrations_total",
94 "Total number of tier migrations completed"
95 )?;
96
97 let migration_failure_counter = register_counter!(
98 "tier_manager_migration_failures_total",
99 "Total number of tier migration failures"
100 )?;
101
102 let memories_per_tier_gauge = register_gauge!(
103 "tier_manager_memories_per_tier",
104 "Number of memories in each tier"
105 )?;
106
107 let recall_probability_histogram = register_histogram!(
108 "tier_manager_recall_probability",
109 "Distribution of recall probabilities",
110 vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
111 )?;
112
113 Ok(Self {
114 repository,
115 config,
116 math_engine: MathEngine::new(),
117 running: Arc::new(AtomicBool::new(false)),
118 last_scan_time: Arc::new(RwLock::new(None)),
119 migrations_completed: Arc::new(AtomicU64::new(0)),
120 migrations_failed: Arc::new(AtomicU64::new(0)),
121 total_scan_time_ms: Arc::new(AtomicU64::new(0)),
122 scan_duration_histogram,
123 migration_counter,
124 migration_failure_counter,
125 memories_per_tier_gauge,
126 recall_probability_histogram,
127 })
128 }
129
130 pub async fn start(&self) -> Result<()> {
132 if self.running.load(Ordering::Relaxed) {
133 return Err(MemoryError::ServiceError(
134 "TierManager is already running".to_string(),
135 ));
136 }
137
138 if !self.config.enabled {
139 info!("TierManager is disabled in configuration");
140 return Ok(());
141 }
142
143 info!(
144 "Starting TierManager service with {} second scan interval",
145 self.config.scan_interval_seconds
146 );
147
148 self.running.store(true, Ordering::Relaxed);
149
150 let manager = self.clone();
152 tokio::spawn(async move {
153 manager.management_loop().await;
154 });
155
156 Ok(())
157 }
158
159 pub async fn stop(&self) {
161 info!("Stopping TierManager service");
162 self.running.store(false, Ordering::Relaxed);
163
164 sleep(TokioDuration::from_secs(2)).await;
166 }
167
168 pub async fn get_metrics(&self) -> Result<TierManagerMetrics> {
170 let memories_by_tier = self.get_memory_counts_by_tier().await?;
171 let recall_probabilities = self.get_average_recall_probabilities_by_tier().await?;
172
173 Ok(TierManagerMetrics {
174 total_migrations_completed: self.migrations_completed.load(Ordering::Relaxed),
175 total_migrations_failed: self.migrations_failed.load(Ordering::Relaxed),
176 last_scan_duration_ms: self.total_scan_time_ms.load(Ordering::Relaxed),
177 memories_by_tier,
178 average_recall_probability_by_tier: recall_probabilities,
179 migrations_per_second_recent: self.calculate_recent_migration_rate().await,
180 is_running: self.running.load(Ordering::Relaxed),
181 last_scan_time: *self.last_scan_time.read().await,
182 })
183 }
184
185 pub async fn force_scan(&self) -> Result<TierMigrationResult> {
187 if !self.running.load(Ordering::Relaxed) {
188 return Err(MemoryError::ServiceError(
189 "TierManager is not running".to_string(),
190 ));
191 }
192
193 info!("Forcing immediate tier management scan");
194 self.perform_tier_management_scan().await
195 }
196}
197
198impl TierManager {
200 async fn management_loop(&self) {
202 let mut scan_interval =
203 interval(TokioDuration::from_secs(self.config.scan_interval_seconds));
204
205 while self.running.load(Ordering::Relaxed) {
206 scan_interval.tick().await;
207
208 if let Err(e) = self.perform_tier_management_scan().await {
209 error!("Tier management scan failed: {}", e);
210 }
212 }
213
214 info!("TierManager management loop stopped");
215 }
216
217 async fn perform_tier_management_scan(&self) -> Result<TierMigrationResult> {
219 let scan_start = Instant::now();
220 let scan_time = Utc::now();
221
222 debug!("Starting tier management scan");
223
224 let candidates = self.find_migration_candidates().await?;
226
227 if candidates.is_empty() {
228 debug!("No migration candidates found");
229 *self.last_scan_time.write().await = Some(scan_time);
230 return Ok(TierMigrationResult {
231 batch_id: Uuid::new_v4(),
232 successful_migrations: Vec::new(),
233 failed_migrations: Vec::new(),
234 duration_ms: scan_start.elapsed().as_millis() as u64,
235 memories_per_second: 0.0,
236 });
237 }
238
239 info!("Found {} migration candidates", candidates.len());
240
241 let batches = self.create_migration_batches(candidates);
243
244 let result = self.process_migration_batches(batches).await?;
246
247 let scan_duration = scan_start.elapsed();
249 self.scan_duration_histogram
250 .observe(scan_duration.as_secs_f64());
251 self.total_scan_time_ms
252 .store(scan_duration.as_millis() as u64, Ordering::Relaxed);
253 *self.last_scan_time.write().await = Some(scan_time);
254
255 self.update_tier_metrics().await?;
257
258 info!(
259 "Tier management scan completed: {} successful, {} failed, {:.2} migrations/sec",
260 result.successful_migrations.len(),
261 result.failed_migrations.len(),
262 result.memories_per_second
263 );
264
265 Ok(result)
266 }
267
268 async fn find_migration_candidates(&self) -> Result<Vec<TierMigrationCandidate>> {
270 let mut candidates = Vec::new();
271
272 for tier in [MemoryTier::Working, MemoryTier::Warm, MemoryTier::Cold] {
274 let tier_candidates = self.find_candidates_for_tier(tier).await?;
275 candidates.extend(tier_candidates);
276 }
277
278 candidates.sort_by(|a, b| {
280 b.priority_score
281 .partial_cmp(&a.priority_score)
282 .unwrap_or(std::cmp::Ordering::Equal)
283 });
284
285 Ok(candidates)
286 }
287
288 async fn find_candidates_for_tier(
290 &self,
291 source_tier: MemoryTier,
292 ) -> Result<Vec<TierMigrationCandidate>> {
293 let min_age_hours = match source_tier {
295 MemoryTier::Working => self.config.min_working_age_hours,
296 MemoryTier::Warm => self.config.min_warm_age_hours,
297 MemoryTier::Cold => self.config.min_cold_age_hours,
298 MemoryTier::Frozen => return Ok(Vec::new()), };
300
301 let min_age_time = Utc::now() - Duration::hours(min_age_hours as i64);
302
303 let query_ids = sqlx::query_scalar!(
306 "SELECT id FROM memories WHERE tier = $1 AND status = 'active' AND updated_at <= $2 ORDER BY updated_at ASC LIMIT 1000",
307 source_tier as MemoryTier,
308 min_age_time
309 )
310 .fetch_all(self.repository.pool())
311 .await?;
312
313 let mut candidates = Vec::new();
314
315 for memory_id in query_ids {
317 if let Ok(memory) = self.repository.get_memory(memory_id).await {
319 if let Some(candidate) = self.evaluate_migration_candidate(&memory).await? {
320 candidates.push(candidate);
321 }
322 }
323 }
324
325 Ok(candidates)
326 }
327
328 async fn evaluate_migration_candidate(
330 &self,
331 memory: &Memory,
332 ) -> Result<Option<TierMigrationCandidate>> {
333 let recall_probability = self.calculate_recall_probability(memory)?;
335
336 if self.config.enable_metrics {
338 self.recall_probability_histogram
339 .observe(recall_probability);
340 }
341
342 let (should_migrate, target_tier, reason) = match memory.tier {
344 MemoryTier::Working => {
345 if recall_probability < self.config.working_to_warm_threshold {
346 (
347 true,
348 MemoryTier::Warm,
349 format!(
350 "Recall probability {:.3} below threshold {:.3}",
351 recall_probability, self.config.working_to_warm_threshold
352 ),
353 )
354 } else {
355 (false, memory.tier, String::new())
356 }
357 }
358 MemoryTier::Warm => {
359 if recall_probability < self.config.warm_to_cold_threshold {
360 (
361 true,
362 MemoryTier::Cold,
363 format!(
364 "Recall probability {:.3} below threshold {:.3}",
365 recall_probability, self.config.warm_to_cold_threshold
366 ),
367 )
368 } else {
369 (false, memory.tier, String::new())
370 }
371 }
372 MemoryTier::Cold => {
373 if recall_probability < self.config.cold_to_frozen_threshold {
374 (
375 true,
376 MemoryTier::Frozen,
377 format!(
378 "Recall probability {:.3} below threshold {:.3}",
379 recall_probability, self.config.cold_to_frozen_threshold
380 ),
381 )
382 } else {
383 (false, memory.tier, String::new())
384 }
385 }
386 MemoryTier::Frozen => (false, memory.tier, String::new()), };
388
389 if !should_migrate {
390 return Ok(None);
391 }
392
393 let age_factor = Utc::now()
395 .signed_duration_since(memory.updated_at)
396 .num_hours() as f64
397 / 24.0;
398 let priority_score = (1.0 - recall_probability) * (1.0 + age_factor.ln().max(0.0));
399
400 Ok(Some(TierMigrationCandidate {
401 memory_id: memory.id,
402 current_tier: memory.tier,
403 target_tier,
404 recall_probability,
405 migration_reason: reason,
406 priority_score,
407 }))
408 }
409
410 fn calculate_recall_probability(&self, memory: &Memory) -> Result<f64> {
412 let params = MemoryParameters {
413 consolidation_strength: memory.consolidation_strength,
414 decay_rate: memory.decay_rate,
415 last_accessed_at: memory.last_accessed_at,
416 created_at: memory.created_at,
417 access_count: memory.access_count,
418 importance_score: memory.importance_score,
419 };
420
421 match self.math_engine.calculate_recall_probability(¶ms) {
422 Ok(result) => Ok(result.recall_probability),
423 Err(e) => {
424 warn!(
425 "Math engine calculation failed for memory {}: {}",
426 memory.id, e
427 );
428 let fallback = (memory.importance_score * memory.consolidation_strength / 10.0)
430 .min(1.0)
431 .max(0.0);
432 Ok(fallback)
433 }
434 }
435 }
436
437 fn create_migration_batches(
439 &self,
440 candidates: Vec<TierMigrationCandidate>,
441 ) -> Vec<TierMigrationBatch> {
442 let mut batches = Vec::new();
443 let batch_size = self.config.migration_batch_size;
444
445 for chunk in candidates.chunks(batch_size) {
446 let batch = TierMigrationBatch {
447 candidates: chunk.to_vec(),
448 batch_id: Uuid::new_v4(),
449 estimated_duration_ms: self.estimate_batch_duration(chunk.len()),
450 };
451 batches.push(batch);
452 }
453
454 batches
455 }
456
457 fn estimate_batch_duration(&self, batch_size: usize) -> u64 {
459 (batch_size as f64 * 1.2) as u64
462 }
463
464 async fn process_migration_batches(
466 &self,
467 batches: Vec<TierMigrationBatch>,
468 ) -> Result<TierMigrationResult> {
469 let start_time = Instant::now();
470 let mut all_successful = Vec::new();
471 let mut all_failed = Vec::new();
472
473 let semaphore = Arc::new(tokio::sync::Semaphore::new(
475 self.config.max_concurrent_migrations,
476 ));
477 let mut handles = Vec::new();
478
479 for batch in batches {
480 let semaphore = semaphore.clone();
481 let repository = self.repository.clone();
482 let config = self.config.clone();
483
484 let handle = tokio::spawn(async move {
485 let _permit = semaphore
486 .acquire()
487 .await
488 .expect("Semaphore acquisition failed");
489 Self::process_single_batch(repository, batch, config).await
490 });
491
492 handles.push(handle);
493 }
494
495 for handle in handles {
497 match handle.await {
498 Ok(Ok(result)) => {
499 all_successful.extend(result.successful_migrations);
500 all_failed.extend(result.failed_migrations);
501 }
502 Ok(Err(e)) => {
503 error!("Batch processing failed: {}", e);
504 }
505 Err(e) => {
506 error!("Batch task panicked: {}", e);
507 }
508 }
509 }
510
511 let duration = start_time.elapsed();
512 let total_migrations = all_successful.len() + all_failed.len();
513 let memories_per_second = if duration.as_secs_f64() > 0.0 {
514 total_migrations as f64 / duration.as_secs_f64()
515 } else {
516 0.0
517 };
518
519 self.migrations_completed
521 .fetch_add(all_successful.len() as u64, Ordering::Relaxed);
522 self.migrations_failed
523 .fetch_add(all_failed.len() as u64, Ordering::Relaxed);
524
525 if self.config.enable_metrics {
526 self.migration_counter.inc_by(all_successful.len() as f64);
527 self.migration_failure_counter
528 .inc_by(all_failed.len() as f64);
529 }
530
531 Ok(TierMigrationResult {
532 batch_id: Uuid::new_v4(),
533 successful_migrations: all_successful,
534 failed_migrations: all_failed,
535 duration_ms: duration.as_millis() as u64,
536 memories_per_second,
537 })
538 }
539
540 async fn process_single_batch(
542 repository: Arc<MemoryRepository>,
543 batch: TierMigrationBatch,
544 config: TierManagerConfig,
545 ) -> Result<TierMigrationResult> {
546 let start_time = Instant::now();
547 let mut successful = Vec::new();
548 let mut failed = Vec::new();
549
550 for candidate in &batch.candidates {
551 match Self::migrate_single_memory(&repository, candidate, &config).await {
552 Ok(_) => {
553 successful.push(candidate.memory_id);
554 debug!(
555 "Successfully migrated memory {} from {:?} to {:?}",
556 candidate.memory_id, candidate.current_tier, candidate.target_tier
557 );
558 }
559 Err(e) => {
560 failed.push((candidate.memory_id, e.to_string()));
561 warn!("Failed to migrate memory {}: {}", candidate.memory_id, e);
562 }
563 }
564 }
565
566 let duration = start_time.elapsed();
567 let memories_per_second = if duration.as_secs_f64() > 0.0 {
568 batch.candidates.len() as f64 / duration.as_secs_f64()
569 } else {
570 0.0
571 };
572
573 Ok(TierMigrationResult {
574 batch_id: batch.batch_id,
575 successful_migrations: successful,
576 failed_migrations: failed,
577 duration_ms: duration.as_millis() as u64,
578 memories_per_second,
579 })
580 }
581
582 async fn migrate_single_memory(
584 repository: &MemoryRepository,
585 candidate: &TierMigrationCandidate,
586 config: &TierManagerConfig,
587 ) -> Result<()> {
588 let mut tx = repository.pool().begin().await?;
589
590 sqlx::query!(
592 "UPDATE memories SET tier = $1, updated_at = NOW() WHERE id = $2",
593 candidate.target_tier as MemoryTier,
594 candidate.memory_id
595 )
596 .execute(&mut *tx)
597 .await?;
598
599 if config.log_migrations {
601 sqlx::query!(
602 r#"
603 INSERT INTO memory_consolidation_log (
604 id, memory_id, old_consolidation_strength, new_consolidation_strength,
605 old_recall_probability, new_recall_probability, consolidation_event,
606 trigger_reason, created_at
607 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
608 "#,
609 Uuid::new_v4(),
610 candidate.memory_id,
611 0.0, 0.0,
613 Some(candidate.recall_probability),
614 Some(candidate.recall_probability),
615 format!(
616 "tier_migration_{}_{}",
617 format!("{:?}", candidate.current_tier).to_lowercase(),
618 format!("{:?}", candidate.target_tier).to_lowercase()
619 ),
620 Some(format!(
621 "{}. Priority score: {:.3}",
622 candidate.migration_reason, candidate.priority_score
623 ))
624 )
625 .execute(&mut *tx)
626 .await?;
627 }
628
629 tx.commit().await?;
630 Ok(())
631 }
632
633 async fn get_memory_counts_by_tier(&self) -> Result<HashMap<MemoryTier, u64>> {
635 let rows = sqlx::query!(
636 r#"
637 SELECT tier as "tier: MemoryTier", COUNT(*) as count
638 FROM memories
639 WHERE status = 'active'
640 GROUP BY tier
641 "#
642 )
643 .fetch_all(self.repository.pool())
644 .await?;
645
646 let mut counts = HashMap::new();
647 for row in rows {
648 counts.insert(row.tier, row.count.unwrap_or(0) as u64);
649 }
650
651 Ok(counts)
652 }
653
654 async fn get_average_recall_probabilities_by_tier(&self) -> Result<HashMap<MemoryTier, f64>> {
656 let rows = sqlx::query!(
657 r#"
658 SELECT tier as "tier: MemoryTier", AVG(recall_probability) as avg_recall_prob
659 FROM memories
660 WHERE status = 'active' AND recall_probability IS NOT NULL
661 GROUP BY tier
662 "#
663 )
664 .fetch_all(self.repository.pool())
665 .await?;
666
667 let mut averages = HashMap::new();
668 for row in rows {
669 if let Some(avg) = row.avg_recall_prob {
670 averages.insert(row.tier, avg);
671 }
672 }
673
674 Ok(averages)
675 }
676
677 async fn calculate_recent_migration_rate(&self) -> f64 {
679 let completed = self.migrations_completed.load(Ordering::Relaxed);
682 let scan_time_ms = self.total_scan_time_ms.load(Ordering::Relaxed);
683
684 if scan_time_ms > 0 {
685 (completed as f64 * 1000.0) / scan_time_ms as f64
686 } else {
687 0.0
688 }
689 }
690
691 async fn update_tier_metrics(&self) -> Result<()> {
693 if !self.config.enable_metrics {
694 return Ok(());
695 }
696
697 let counts = self.get_memory_counts_by_tier().await?;
698
699 for (_tier, count) in counts {
700 self.memories_per_tier_gauge.set(count as f64);
702 }
703
704 Ok(())
705 }
706}
707
708impl Clone for TierManager {
710 fn clone(&self) -> Self {
711 Self {
712 repository: self.repository.clone(),
713 config: self.config.clone(),
714 math_engine: MathEngine::new(), running: self.running.clone(),
716 last_scan_time: self.last_scan_time.clone(),
717 migrations_completed: self.migrations_completed.clone(),
718 migrations_failed: self.migrations_failed.clone(),
719 total_scan_time_ms: self.total_scan_time_ms.clone(),
720 scan_duration_histogram: self.scan_duration_histogram.clone(),
721 migration_counter: self.migration_counter.clone(),
722 migration_failure_counter: self.migration_failure_counter.clone(),
723 memories_per_tier_gauge: self.memories_per_tier_gauge.clone(),
724 recall_probability_histogram: self.recall_probability_histogram.clone(),
725 }
726 }
727}
728
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::connection::create_pool;
    use anyhow::Context;
    use dotenv::dotenv;
    use sqlx::PgPool;

    /// Opens a pool against `DATABASE_URL`, loading `.env` first when present.
    async fn create_test_pool() -> anyhow::Result<PgPool> {
        // A missing .env file is fine; the variable may come from the environment.
        let _ = dotenv();

        let url = std::env::var("DATABASE_URL")
            .context("DATABASE_URL environment variable not set. Ensure .env file is present.")?;

        create_pool(&url, 5).await
    }

    /// Constructing a manager against a live pool succeeds.
    #[tokio::test]
    async fn test_tier_manager_creation() {
        let pool = create_test_pool().await.unwrap();
        let repository = Arc::new(MemoryRepository::new(pool));

        let manager = TierManager::new(repository, TierManagerConfig::default());

        assert!(manager.is_ok());
    }

    /// Builds a Working-tier memory and checks the tier it was given.
    #[tokio::test]
    async fn test_migration_candidate_evaluation() {
        let memory = Memory {
            consolidation_strength: 1.0,
            decay_rate: 1.0,
            recall_probability: Some(0.3),
            tier: MemoryTier::Working,
            ..Memory::default()
        };

        assert_eq!(memory.tier, MemoryTier::Working);
    }
}