// codex_memory/memory/tier_manager.rs

1use super::error::{MemoryError, Result};
2use super::math_engine::{MathEngine, MemoryParameters};
3use super::models::{Memory, MemoryTier};
4use super::repository::MemoryRepository;
5use crate::config::TierManagerConfig;
6use chrono::{DateTime, Duration, Utc};
7use prometheus::{register_counter, register_gauge, register_histogram, Counter, Gauge, Histogram};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Instant;
12use tokio::sync::RwLock;
13use tokio::time::{interval, sleep, Duration as TokioDuration};
14use tracing::{debug, error, info, warn};
15use uuid::Uuid;
16
/// Centralized tier management service implementing cognitive memory research principles
///
/// This service periodically scans memories, computes each memory's recall probability
/// with [`MathEngine`], and migrates a memory one tier colder (Working → Warm → Cold →
/// Frozen) when that probability drops below the configured threshold for its current tier.
/// The approach is motivated by Ebbinghaus's forgetting curve and spaced repetition research.
///
/// Runtime state is held behind `Arc`s so that a clone of the manager shares state with the
/// original; `start` moves such a clone into the background task.
pub struct TierManager {
    // Shared database access used for candidate queries and migrations.
    repository: Arc<MemoryRepository>,
    // Thresholds, intervals, batch sizes and feature flags for the service.
    config: TierManagerConfig,
    // Computes recall probabilities from a memory's consolidation/decay parameters.
    math_engine: MathEngine,

    // Service state
    // `running` gates the background loop; `last_scan_time` is the wall-clock time at which
    // the most recent successful scan started.
    running: Arc<AtomicBool>,
    last_scan_time: Arc<RwLock<Option<DateTime<Utc>>>>,

    // Performance tracking
    // Cumulative success/failure counts since construction.
    migrations_completed: Arc<AtomicU64>,
    migrations_failed: Arc<AtomicU64>,
    // NOTE(review): despite the name, this is overwritten (not accumulated) with the duration
    // of the most recent scan that had migration candidates.
    total_scan_time_ms: Arc<AtomicU64>,

    // Prometheus metrics
    scan_duration_histogram: Histogram,
    migration_counter: Counter,
    migration_failure_counter: Counter,
    // Single unlabeled gauge: it cannot hold a separate value per tier.
    memories_per_tier_gauge: Gauge,
    recall_probability_histogram: Histogram,
}
43
/// A memory selected for migration during a scan, with the reason and urgency of the move.
#[derive(Debug, Clone)]
pub struct TierMigrationCandidate {
    /// Identifier of the memory to migrate.
    pub memory_id: Uuid,
    /// Tier the memory was in when it was evaluated.
    pub current_tier: MemoryTier,
    /// Tier the memory should be moved to.
    pub target_tier: MemoryTier,
    /// Recall probability computed when the memory was evaluated.
    pub recall_probability: f64,
    /// Human-readable explanation; written to the consolidation log when logging is enabled.
    pub migration_reason: String,
    /// Sort key for processing order: higher means more urgent migration.
    pub priority_score: f64, // Higher means more urgent migration
}
53
/// A group of migration candidates that is processed by one task.
#[derive(Debug, Clone)]
pub struct TierMigrationBatch {
    /// Candidates in this batch, in priority order.
    pub candidates: Vec<TierMigrationCandidate>,
    /// Identifier of this batch, echoed back in the batch's `TierMigrationResult`.
    pub batch_id: Uuid,
    /// Rough duration estimate (about 1.2 ms per candidate). It is not read anywhere else in
    /// this module.
    pub estimated_duration_ms: u64,
}
60
/// Outcome of migrating one batch, or of a whole scan when batch results are aggregated.
#[derive(Debug, Clone)]
pub struct TierMigrationResult {
    /// The batch's id; aggregated scan results carry a freshly generated id.
    pub batch_id: Uuid,
    /// Memories whose tier was updated.
    pub successful_migrations: Vec<Uuid>,
    /// Memories whose migration failed, with the error text.
    pub failed_migrations: Vec<(Uuid, String)>,
    /// Wall-clock time spent, in milliseconds.
    pub duration_ms: u64,
    /// Migration attempts per second, counting failures as well as successes.
    pub memories_per_second: f64,
}
69
/// Point-in-time snapshot of the tier manager's state, returned by `get_metrics`.
#[derive(Debug, Clone)]
pub struct TierManagerMetrics {
    /// Successful migrations since the manager was constructed.
    pub total_migrations_completed: u64,
    /// Failed migrations since the manager was constructed.
    pub total_migrations_failed: u64,
    /// Duration of the most recent scan that had migration candidates. Scans that found no
    /// candidates do not update this value.
    pub last_scan_duration_ms: u64,
    /// Number of active memories in each tier.
    pub memories_by_tier: HashMap<MemoryTier, u64>,
    /// Average of the stored `recall_probability` column per tier, over active memories with
    /// a non-null value. These are stored values, not probabilities recomputed for the snapshot.
    pub average_recall_probability_by_tier: HashMap<MemoryTier, f64>,
    /// Approximate rate: cumulative successful migrations divided by the last scan's duration.
    pub migrations_per_second_recent: f64,
    /// Whether the background loop is flagged as running.
    pub is_running: bool,
    /// Start time of the most recent successful scan, if any.
    pub last_scan_time: Option<DateTime<Utc>>,
}
81
82impl TierManager {
83    /// Create a new TierManager instance
84    pub fn new(repository: Arc<MemoryRepository>, config: TierManagerConfig) -> Result<Self> {
85        // Initialize Prometheus metrics
86        let scan_duration_histogram = register_histogram!(
87            "tier_manager_scan_duration_seconds",
88            "Time taken for tier management scans",
89            vec![0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0]
90        )?;
91
92        let migration_counter = register_counter!(
93            "tier_manager_migrations_total",
94            "Total number of tier migrations completed"
95        )?;
96
97        let migration_failure_counter = register_counter!(
98            "tier_manager_migration_failures_total",
99            "Total number of tier migration failures"
100        )?;
101
102        let memories_per_tier_gauge = register_gauge!(
103            "tier_manager_memories_per_tier",
104            "Number of memories in each tier"
105        )?;
106
107        let recall_probability_histogram = register_histogram!(
108            "tier_manager_recall_probability",
109            "Distribution of recall probabilities",
110            vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
111        )?;
112
113        Ok(Self {
114            repository,
115            config,
116            math_engine: MathEngine::new(),
117            running: Arc::new(AtomicBool::new(false)),
118            last_scan_time: Arc::new(RwLock::new(None)),
119            migrations_completed: Arc::new(AtomicU64::new(0)),
120            migrations_failed: Arc::new(AtomicU64::new(0)),
121            total_scan_time_ms: Arc::new(AtomicU64::new(0)),
122            scan_duration_histogram,
123            migration_counter,
124            migration_failure_counter,
125            memories_per_tier_gauge,
126            recall_probability_histogram,
127        })
128    }
129
130    /// Start the tier management service as a background task
131    pub async fn start(&self) -> Result<()> {
132        if self.running.load(Ordering::Relaxed) {
133            return Err(MemoryError::ServiceError(
134                "TierManager is already running".to_string(),
135            ));
136        }
137
138        if !self.config.enabled {
139            info!("TierManager is disabled in configuration");
140            return Ok(());
141        }
142
143        info!(
144            "Starting TierManager service with {} second scan interval",
145            self.config.scan_interval_seconds
146        );
147
148        self.running.store(true, Ordering::Relaxed);
149
150        // Start the main management loop
151        let manager = self.clone();
152        tokio::spawn(async move {
153            manager.management_loop().await;
154        });
155
156        Ok(())
157    }
158
159    /// Stop the tier management service
160    pub async fn stop(&self) {
161        info!("Stopping TierManager service");
162        self.running.store(false, Ordering::Relaxed);
163
164        // Give time for any running operations to complete
165        sleep(TokioDuration::from_secs(2)).await;
166    }
167
168    /// Get current metrics for monitoring
169    pub async fn get_metrics(&self) -> Result<TierManagerMetrics> {
170        let memories_by_tier = self.get_memory_counts_by_tier().await?;
171        let recall_probabilities = self.get_average_recall_probabilities_by_tier().await?;
172
173        Ok(TierManagerMetrics {
174            total_migrations_completed: self.migrations_completed.load(Ordering::Relaxed),
175            total_migrations_failed: self.migrations_failed.load(Ordering::Relaxed),
176            last_scan_duration_ms: self.total_scan_time_ms.load(Ordering::Relaxed),
177            memories_by_tier,
178            average_recall_probability_by_tier: recall_probabilities,
179            migrations_per_second_recent: self.calculate_recent_migration_rate().await,
180            is_running: self.running.load(Ordering::Relaxed),
181            last_scan_time: *self.last_scan_time.read().await,
182        })
183    }
184
185    /// Force an immediate tier management scan (for testing/manual triggering)
186    pub async fn force_scan(&self) -> Result<TierMigrationResult> {
187        if !self.running.load(Ordering::Relaxed) {
188            return Err(MemoryError::ServiceError(
189                "TierManager is not running".to_string(),
190            ));
191        }
192
193        info!("Forcing immediate tier management scan");
194        self.perform_tier_management_scan().await
195    }
196}
197
198// Private implementation methods
199impl TierManager {
200    /// Main management loop that runs continuously
201    async fn management_loop(&self) {
202        let mut scan_interval =
203            interval(TokioDuration::from_secs(self.config.scan_interval_seconds));
204
205        while self.running.load(Ordering::Relaxed) {
206            scan_interval.tick().await;
207
208            if let Err(e) = self.perform_tier_management_scan().await {
209                error!("Tier management scan failed: {}", e);
210                // Continue running despite errors
211            }
212        }
213
214        info!("TierManager management loop stopped");
215    }
216
217    /// Perform a complete tier management scan
218    async fn perform_tier_management_scan(&self) -> Result<TierMigrationResult> {
219        let scan_start = Instant::now();
220        let scan_time = Utc::now();
221
222        debug!("Starting tier management scan");
223
224        // Find migration candidates for each tier transition
225        let candidates = self.find_migration_candidates().await?;
226
227        if candidates.is_empty() {
228            debug!("No migration candidates found");
229            *self.last_scan_time.write().await = Some(scan_time);
230            return Ok(TierMigrationResult {
231                batch_id: Uuid::new_v4(),
232                successful_migrations: Vec::new(),
233                failed_migrations: Vec::new(),
234                duration_ms: scan_start.elapsed().as_millis() as u64,
235                memories_per_second: 0.0,
236            });
237        }
238
239        info!("Found {} migration candidates", candidates.len());
240
241        // Create migration batches
242        let batches = self.create_migration_batches(candidates);
243
244        // Process batches with concurrency control
245        let result = self.process_migration_batches(batches).await?;
246
247        // Update metrics
248        let scan_duration = scan_start.elapsed();
249        self.scan_duration_histogram
250            .observe(scan_duration.as_secs_f64());
251        self.total_scan_time_ms
252            .store(scan_duration.as_millis() as u64, Ordering::Relaxed);
253        *self.last_scan_time.write().await = Some(scan_time);
254
255        // Update tier count metrics
256        self.update_tier_metrics().await?;
257
258        info!(
259            "Tier management scan completed: {} successful, {} failed, {:.2} migrations/sec",
260            result.successful_migrations.len(),
261            result.failed_migrations.len(),
262            result.memories_per_second
263        );
264
265        Ok(result)
266    }
267
268    /// Find all memories that should be migrated to different tiers
269    async fn find_migration_candidates(&self) -> Result<Vec<TierMigrationCandidate>> {
270        let mut candidates = Vec::new();
271
272        // Check each tier for migration candidates
273        for tier in [MemoryTier::Working, MemoryTier::Warm, MemoryTier::Cold] {
274            let tier_candidates = self.find_candidates_for_tier(tier).await?;
275            candidates.extend(tier_candidates);
276        }
277
278        // Sort by priority (higher priority first)
279        candidates.sort_by(|a, b| {
280            b.priority_score
281                .partial_cmp(&a.priority_score)
282                .unwrap_or(std::cmp::Ordering::Equal)
283        });
284
285        Ok(candidates)
286    }
287
288    /// Find migration candidates for a specific source tier
289    async fn find_candidates_for_tier(
290        &self,
291        source_tier: MemoryTier,
292    ) -> Result<Vec<TierMigrationCandidate>> {
293        // Get minimum age threshold for this tier
294        let min_age_hours = match source_tier {
295            MemoryTier::Working => self.config.min_working_age_hours,
296            MemoryTier::Warm => self.config.min_warm_age_hours,
297            MemoryTier::Cold => self.config.min_cold_age_hours,
298            MemoryTier::Frozen => return Ok(Vec::new()), // Frozen memories don't migrate
299        };
300
301        let min_age_time = Utc::now() - Duration::hours(min_age_hours as i64);
302
303        // Get all memories in this tier using a more robust approach
304        // We'll get a limited set and filter by age in Rust to avoid schema issues
305        let query_ids = sqlx::query_scalar!(
306            "SELECT id FROM memories WHERE tier = $1 AND status = 'active' AND updated_at <= $2 ORDER BY updated_at ASC LIMIT 1000",
307            source_tier as MemoryTier,
308            min_age_time
309        )
310        .fetch_all(self.repository.pool())
311        .await?;
312
313        let mut candidates = Vec::new();
314
315        // Process memories in batches to avoid overwhelming the system
316        for memory_id in query_ids {
317            // Use the repository's get_memory method to handle schema variations properly
318            if let Ok(memory) = self.repository.get_memory(memory_id).await {
319                if let Some(candidate) = self.evaluate_migration_candidate(&memory).await? {
320                    candidates.push(candidate);
321                }
322            }
323        }
324
325        Ok(candidates)
326    }
327
328    /// Evaluate if a memory should be migrated and determine target tier
329    async fn evaluate_migration_candidate(
330        &self,
331        memory: &Memory,
332    ) -> Result<Option<TierMigrationCandidate>> {
333        // Calculate current recall probability using the math engine
334        let recall_probability = self.calculate_recall_probability(memory)?;
335
336        // Record this measurement for metrics
337        if self.config.enable_metrics {
338            self.recall_probability_histogram
339                .observe(recall_probability);
340        }
341
342        // Determine if migration is needed based on thresholds
343        let (should_migrate, target_tier, reason) = match memory.tier {
344            MemoryTier::Working => {
345                if recall_probability < self.config.working_to_warm_threshold {
346                    (
347                        true,
348                        MemoryTier::Warm,
349                        format!(
350                            "Recall probability {:.3} below threshold {:.3}",
351                            recall_probability, self.config.working_to_warm_threshold
352                        ),
353                    )
354                } else {
355                    (false, memory.tier, String::new())
356                }
357            }
358            MemoryTier::Warm => {
359                if recall_probability < self.config.warm_to_cold_threshold {
360                    (
361                        true,
362                        MemoryTier::Cold,
363                        format!(
364                            "Recall probability {:.3} below threshold {:.3}",
365                            recall_probability, self.config.warm_to_cold_threshold
366                        ),
367                    )
368                } else {
369                    (false, memory.tier, String::new())
370                }
371            }
372            MemoryTier::Cold => {
373                if recall_probability < self.config.cold_to_frozen_threshold {
374                    (
375                        true,
376                        MemoryTier::Frozen,
377                        format!(
378                            "Recall probability {:.3} below threshold {:.3}",
379                            recall_probability, self.config.cold_to_frozen_threshold
380                        ),
381                    )
382                } else {
383                    (false, memory.tier, String::new())
384                }
385            }
386            MemoryTier::Frozen => (false, memory.tier, String::new()), // Frozen never migrates
387        };
388
389        if !should_migrate {
390            return Ok(None);
391        }
392
393        // Calculate priority score (lower recall probability = higher priority)
394        let age_factor = Utc::now()
395            .signed_duration_since(memory.updated_at)
396            .num_hours() as f64
397            / 24.0;
398        let priority_score = (1.0 - recall_probability) * (1.0 + age_factor.ln().max(0.0));
399
400        Ok(Some(TierMigrationCandidate {
401            memory_id: memory.id,
402            current_tier: memory.tier,
403            target_tier,
404            recall_probability,
405            migration_reason: reason,
406            priority_score,
407        }))
408    }
409
410    /// Calculate recall probability for a memory using the math engine
411    fn calculate_recall_probability(&self, memory: &Memory) -> Result<f64> {
412        let params = MemoryParameters {
413            consolidation_strength: memory.consolidation_strength,
414            decay_rate: memory.decay_rate,
415            last_accessed_at: memory.last_accessed_at,
416            created_at: memory.created_at,
417            access_count: memory.access_count,
418            importance_score: memory.importance_score,
419        };
420
421        match self.math_engine.calculate_recall_probability(&params) {
422            Ok(result) => Ok(result.recall_probability),
423            Err(e) => {
424                warn!(
425                    "Math engine calculation failed for memory {}: {}",
426                    memory.id, e
427                );
428                // Use fallback calculation based on consolidation strength and importance
429                let fallback = (memory.importance_score * memory.consolidation_strength / 10.0)
430                    .min(1.0)
431                    .max(0.0);
432                Ok(fallback)
433            }
434        }
435    }
436
437    /// Create migration batches from candidates
438    fn create_migration_batches(
439        &self,
440        candidates: Vec<TierMigrationCandidate>,
441    ) -> Vec<TierMigrationBatch> {
442        let mut batches = Vec::new();
443        let batch_size = self.config.migration_batch_size;
444
445        for chunk in candidates.chunks(batch_size) {
446            let batch = TierMigrationBatch {
447                candidates: chunk.to_vec(),
448                batch_id: Uuid::new_v4(),
449                estimated_duration_ms: self.estimate_batch_duration(chunk.len()),
450            };
451            batches.push(batch);
452        }
453
454        batches
455    }
456
457    /// Estimate how long a batch will take to process
458    fn estimate_batch_duration(&self, batch_size: usize) -> u64 {
459        // Based on target performance: 1000 migrations/second means 1ms per migration
460        // Add 20% overhead for safety
461        (batch_size as f64 * 1.2) as u64
462    }
463
464    /// Process migration batches with concurrency control
465    async fn process_migration_batches(
466        &self,
467        batches: Vec<TierMigrationBatch>,
468    ) -> Result<TierMigrationResult> {
469        let start_time = Instant::now();
470        let mut all_successful = Vec::new();
471        let mut all_failed = Vec::new();
472
473        // Process batches with concurrency limit
474        let semaphore = Arc::new(tokio::sync::Semaphore::new(
475            self.config.max_concurrent_migrations,
476        ));
477        let mut handles = Vec::new();
478
479        for batch in batches {
480            let semaphore = semaphore.clone();
481            let repository = self.repository.clone();
482            let config = self.config.clone();
483
484            let handle = tokio::spawn(async move {
485                let _permit = semaphore
486                    .acquire()
487                    .await
488                    .expect("Semaphore acquisition failed");
489                Self::process_single_batch(repository, batch, config).await
490            });
491
492            handles.push(handle);
493        }
494
495        // Wait for all batches to complete
496        for handle in handles {
497            match handle.await {
498                Ok(Ok(result)) => {
499                    all_successful.extend(result.successful_migrations);
500                    all_failed.extend(result.failed_migrations);
501                }
502                Ok(Err(e)) => {
503                    error!("Batch processing failed: {}", e);
504                }
505                Err(e) => {
506                    error!("Batch task panicked: {}", e);
507                }
508            }
509        }
510
511        let duration = start_time.elapsed();
512        let total_migrations = all_successful.len() + all_failed.len();
513        let memories_per_second = if duration.as_secs_f64() > 0.0 {
514            total_migrations as f64 / duration.as_secs_f64()
515        } else {
516            0.0
517        };
518
519        // Update counters
520        self.migrations_completed
521            .fetch_add(all_successful.len() as u64, Ordering::Relaxed);
522        self.migrations_failed
523            .fetch_add(all_failed.len() as u64, Ordering::Relaxed);
524
525        if self.config.enable_metrics {
526            self.migration_counter.inc_by(all_successful.len() as f64);
527            self.migration_failure_counter
528                .inc_by(all_failed.len() as f64);
529        }
530
531        Ok(TierMigrationResult {
532            batch_id: Uuid::new_v4(),
533            successful_migrations: all_successful,
534            failed_migrations: all_failed,
535            duration_ms: duration.as_millis() as u64,
536            memories_per_second,
537        })
538    }
539
540    /// Process a single migration batch
541    async fn process_single_batch(
542        repository: Arc<MemoryRepository>,
543        batch: TierMigrationBatch,
544        config: TierManagerConfig,
545    ) -> Result<TierMigrationResult> {
546        let start_time = Instant::now();
547        let mut successful = Vec::new();
548        let mut failed = Vec::new();
549
550        for candidate in &batch.candidates {
551            match Self::migrate_single_memory(&repository, candidate, &config).await {
552                Ok(_) => {
553                    successful.push(candidate.memory_id);
554                    debug!(
555                        "Successfully migrated memory {} from {:?} to {:?}",
556                        candidate.memory_id, candidate.current_tier, candidate.target_tier
557                    );
558                }
559                Err(e) => {
560                    failed.push((candidate.memory_id, e.to_string()));
561                    warn!("Failed to migrate memory {}: {}", candidate.memory_id, e);
562                }
563            }
564        }
565
566        let duration = start_time.elapsed();
567        let memories_per_second = if duration.as_secs_f64() > 0.0 {
568            batch.candidates.len() as f64 / duration.as_secs_f64()
569        } else {
570            0.0
571        };
572
573        Ok(TierMigrationResult {
574            batch_id: batch.batch_id,
575            successful_migrations: successful,
576            failed_migrations: failed,
577            duration_ms: duration.as_millis() as u64,
578            memories_per_second,
579        })
580    }
581
582    /// Migrate a single memory to its target tier
583    async fn migrate_single_memory(
584        repository: &MemoryRepository,
585        candidate: &TierMigrationCandidate,
586        config: &TierManagerConfig,
587    ) -> Result<()> {
588        let mut tx = repository.pool().begin().await?;
589
590        // Update memory tier
591        sqlx::query!(
592            "UPDATE memories SET tier = $1, updated_at = NOW() WHERE id = $2",
593            candidate.target_tier as MemoryTier,
594            candidate.memory_id
595        )
596        .execute(&mut *tx)
597        .await?;
598
599        // Log migration if enabled
600        if config.log_migrations {
601            sqlx::query!(
602                r#"
603                INSERT INTO memory_consolidation_log (
604                    id, memory_id, old_consolidation_strength, new_consolidation_strength, 
605                    old_recall_probability, new_recall_probability, consolidation_event, 
606                    trigger_reason, created_at
607                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
608                "#,
609                Uuid::new_v4(),
610                candidate.memory_id,
611                0.0, // We don't change consolidation strength during tier migration
612                0.0,
613                Some(candidate.recall_probability),
614                Some(candidate.recall_probability),
615                format!(
616                    "tier_migration_{}_{}",
617                    format!("{:?}", candidate.current_tier).to_lowercase(),
618                    format!("{:?}", candidate.target_tier).to_lowercase()
619                ),
620                Some(format!(
621                    "{}. Priority score: {:.3}",
622                    candidate.migration_reason, candidate.priority_score
623                ))
624            )
625            .execute(&mut *tx)
626            .await?;
627        }
628
629        tx.commit().await?;
630        Ok(())
631    }
632
633    /// Get memory counts by tier for metrics
634    async fn get_memory_counts_by_tier(&self) -> Result<HashMap<MemoryTier, u64>> {
635        let rows = sqlx::query!(
636            r#"
637            SELECT tier as "tier: MemoryTier", COUNT(*) as count
638            FROM memories
639            WHERE status = 'active'
640            GROUP BY tier
641            "#
642        )
643        .fetch_all(self.repository.pool())
644        .await?;
645
646        let mut counts = HashMap::new();
647        for row in rows {
648            counts.insert(row.tier, row.count.unwrap_or(0) as u64);
649        }
650
651        Ok(counts)
652    }
653
654    /// Get average recall probabilities by tier
655    async fn get_average_recall_probabilities_by_tier(&self) -> Result<HashMap<MemoryTier, f64>> {
656        let rows = sqlx::query!(
657            r#"
658            SELECT tier as "tier: MemoryTier", AVG(recall_probability) as avg_recall_prob
659            FROM memories
660            WHERE status = 'active' AND recall_probability IS NOT NULL
661            GROUP BY tier
662            "#
663        )
664        .fetch_all(self.repository.pool())
665        .await?;
666
667        let mut averages = HashMap::new();
668        for row in rows {
669            if let Some(avg) = row.avg_recall_prob {
670                averages.insert(row.tier, avg);
671            }
672        }
673
674        Ok(averages)
675    }
676
677    /// Calculate recent migration rate for performance monitoring
678    async fn calculate_recent_migration_rate(&self) -> f64 {
679        // This is a simplified calculation - in production you might want to track
680        // migrations over a sliding time window
681        let completed = self.migrations_completed.load(Ordering::Relaxed);
682        let scan_time_ms = self.total_scan_time_ms.load(Ordering::Relaxed);
683
684        if scan_time_ms > 0 {
685            (completed as f64 * 1000.0) / scan_time_ms as f64
686        } else {
687            0.0
688        }
689    }
690
691    /// Update Prometheus metrics for tier counts
692    async fn update_tier_metrics(&self) -> Result<()> {
693        if !self.config.enable_metrics {
694            return Ok(());
695        }
696
697        let counts = self.get_memory_counts_by_tier().await?;
698
699        for (_tier, count) in counts {
700            // Set gauge with tier label - this is simplified; in production you'd use labeled metrics
701            self.memories_per_tier_gauge.set(count as f64);
702        }
703
704        Ok(())
705    }
706}
707
708// Clone implementation for moving the manager into async tasks
709impl Clone for TierManager {
710    fn clone(&self) -> Self {
711        Self {
712            repository: self.repository.clone(),
713            config: self.config.clone(),
714            math_engine: MathEngine::new(), // Math engine is stateless
715            running: self.running.clone(),
716            last_scan_time: self.last_scan_time.clone(),
717            migrations_completed: self.migrations_completed.clone(),
718            migrations_failed: self.migrations_failed.clone(),
719            total_scan_time_ms: self.total_scan_time_ms.clone(),
720            scan_duration_histogram: self.scan_duration_histogram.clone(),
721            migration_counter: self.migration_counter.clone(),
722            migration_failure_counter: self.migration_failure_counter.clone(),
723            memories_per_tier_gauge: self.memories_per_tier_gauge.clone(),
724            recall_probability_histogram: self.recall_probability_histogram.clone(),
725        }
726    }
727}
728
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::connection::create_pool;
    use anyhow::Context;
    use dotenv::dotenv;
    use sqlx::PgPool;

    /// Build a small (5-connection) pool from `DATABASE_URL`, loading `.env` first if present.
    async fn create_test_pool() -> anyhow::Result<PgPool> {
        // A missing .env file is fine: the variable may already be set in the environment.
        dotenv().ok();

        let url_lookup = std::env::var("DATABASE_URL");
        let database_url = url_lookup
            .context("DATABASE_URL environment variable not set. Ensure .env file is present.")?;

        create_pool(&database_url, 5).await
    }

    #[tokio::test]
    async fn test_tier_manager_creation() {
        let repository = Arc::new(MemoryRepository::new(create_test_pool().await.unwrap()));
        let outcome = TierManager::new(repository, TierManagerConfig::default());
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_migration_candidate_evaluation() {
        // Placeholder: a real check needs a TierManager, which needs a database-backed
        // repository. For now this only verifies the fixture is built as intended.
        let memory = Memory {
            consolidation_strength: 1.0,
            decay_rate: 1.0,
            // Intended to fall below the Working -> Warm threshold (verify against defaults).
            recall_probability: Some(0.3),
            tier: MemoryTier::Working,
            ..Memory::default()
        };

        // TODO: assert that this memory gets flagged for migration once a test DB exists.
        assert_eq!(memory.tier, MemoryTier::Working);
    }
}
772}