#[cfg(test)]
use std::sync::Arc;
#[cfg(test)]
use crate::state_store::StateStore;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use super::{MemoryEntry, MemoryManager, MemoryType};
/// Running tally for a bulk re-index of memory entries.
///
/// `MemoryManager::migrate_from_tfidf` creates one of these with the number
/// of entries it collected and updates the counters as each entry is
/// processed. The struct is serializable so it can be embedded in a
/// [`MigrationReport`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationProgress {
/// Number of entries the run was sized for (the value given to `new`).
pub total: usize,
/// Count of entries that were re-indexed successfully.
pub migrated: usize,
/// Count of entries whose re-indexing returned an error.
pub failed: usize,
/// One `(entry id, error message)` pair per failed entry.
pub errors: Vec<(String, String)>,
}
impl MigrationProgress {
pub fn new(total: usize) -> Self {
Self {
total,
migrated: 0,
failed: 0,
errors: Vec::new(),
}
}
pub fn is_complete(&self) -> bool {
self.migrated + self.failed >= self.total
}
pub fn success_rate(&self) -> f32 {
if self.total == 0 {
1.0
} else {
self.migrated as f32 / self.total as f32
}
}
}
/// Final result of a migration run, as returned by
/// `MemoryManager::migrate_from_tfidf`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationReport {
/// Final counters and per-entry errors for the run.
pub progress: MigrationProgress,
/// Elapsed time of the run in milliseconds.
pub duration_ms: u64,
}
/// Number of entries per batch when `migrate_from_tfidf` walks the collected
/// entries.
const BATCH_SIZE: usize = 32;
impl MemoryManager {
    /// Re-embeds stored memory entries and inserts the resulting vectors into
    /// the vector index.
    ///
    /// Entries are gathered by calling `list` once per memory type
    /// (conversation, session, fact, episode, knowledge) and then processed
    /// sequentially in batches of [`BATCH_SIZE`]. A progress line is logged
    /// after every batch except the last; the final state is covered by the
    /// "Migration complete" log line.
    ///
    /// The run is best-effort:
    /// * if listing a memory type fails, a warning is logged and that type is
    ///   skipped (its entries are not counted in `total`);
    /// * if re-indexing an entry fails, the failure is counted and recorded
    ///   in `progress.errors` as `(entry id, error message)`, and the run
    ///   continues with the next entry.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`; failures are reported
    /// through the returned [`MigrationReport`] and the log instead.
    pub async fn migrate_from_tfidf(&self) -> Result<MigrationReport> {
        let start = std::time::Instant::now();

        // Paired with a static label so a listing failure can be logged
        // without requiring any formatting trait on `MemoryType`.
        let all_types = [
            (MemoryType::Conversation, "conversation"),
            (MemoryType::Session, "session"),
            (MemoryType::Fact, "fact"),
            (MemoryType::Episode, "episode"),
            (MemoryType::Knowledge, "knowledge"),
        ];

        let mut all_entries: Vec<MemoryEntry> = Vec::new();
        for (mt, label) in &all_types {
            // NOTE(review): the second argument is presumably a per-type
            // result cap — confirm against `list`; entries beyond it would
            // not be migrated.
            match self.list(*mt, 100_000).await {
                Ok(entries) => all_entries.extend(entries),
                Err(e) => {
                    // Previously this error was dropped silently, hiding the
                    // fact that a whole memory type was left un-migrated.
                    tracing::warn!(
                        memory_type = *label,
                        error = %e,
                        "Failed to list entries for migration; skipping memory type"
                    );
                }
            }
        }

        let total = all_entries.len();
        let mut progress = MigrationProgress::new(total);
        tracing::info!(total, "Starting TF-IDF migration");

        for chunk in all_entries.chunks(BATCH_SIZE) {
            for entry in chunk {
                match self.re_index_entry(entry).await {
                    Ok(()) => {
                        progress.migrated += 1;
                    }
                    Err(e) => {
                        progress.failed += 1;
                        progress.errors.push((entry.id.clone(), e.to_string()));
                        tracing::warn!(id = %entry.id, error = %e, "Migration failed for entry");
                    }
                }
            }

            // Key the progress log on entries *processed*, not on successes:
            // a success-only counter drifts off batch boundaries after the
            // first failure and the log would stop firing.
            let processed = progress.migrated + progress.failed;
            if processed < progress.total {
                tracing::info!(
                    migrated = progress.migrated,
                    failed = progress.failed,
                    total = progress.total,
                    "Migration progress"
                );
            }
        }

        // `as_millis` yields a `u128`; clamp first so the cast cannot wrap.
        let duration_ms = start.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;
        tracing::info!(
            migrated = progress.migrated,
            failed = progress.failed,
            total = progress.total,
            duration_ms,
            "Migration complete"
        );

        Ok(MigrationReport {
            progress,
            duration_ms,
        })
    }

    /// Computes an embedding for `entry.content` and stores it in the vector
    /// index under `entry.id`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the embedding call; in that case the
    /// index is not touched.
    async fn re_index_entry(&self, entry: &MemoryEntry) -> Result<()> {
        let vector = self.embedding.embed(&entry.content).await?;
        {
            // The write guard is taken only after the await has finished and
            // is released at the end of this scope.
            let mut index = self.vector_index.write();
            index.insert(entry.id.clone(), vector);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new tracker starts with zeroed counters and is not yet complete.
    #[test]
    fn test_migration_progress_new() {
        let progress = MigrationProgress::new(100);
        assert_eq!(
            (progress.total, progress.migrated, progress.failed),
            (100, 0, 0)
        );
        assert!(!progress.is_complete());
    }

    /// Successes plus failures reaching `total` marks the run complete, and
    /// the success rate reflects only the migrated share.
    #[test]
    fn test_migration_progress_complete() {
        let progress = MigrationProgress {
            migrated: 8,
            failed: 2,
            ..MigrationProgress::new(10)
        };
        assert!(progress.is_complete());
        let rate_error = (progress.success_rate() - 0.8).abs();
        assert!(rate_error < 0.01);
    }

    /// A run over zero entries is complete immediately with a rate of 1.0.
    #[test]
    fn test_migration_progress_empty() {
        let progress = MigrationProgress::new(0);
        assert!(progress.is_complete());
        assert_eq!(progress.success_rate(), 1.0);
    }

    /// Migrating a store with no entries yields an empty report.
    #[tokio::test]
    async fn test_migrate_empty_store() {
        let temp_dir = tempfile::tempdir().unwrap();
        let state = StateStore::new(temp_dir.path().to_path_buf()).unwrap();
        let manager = MemoryManager::new(Arc::new(state));

        let report = manager.migrate_from_tfidf().await.unwrap();

        assert_eq!(report.progress.total, 0);
        assert_eq!(report.progress.migrated, 0);
    }

    /// After the vector index is wiped, migration re-indexes every stored
    /// entry.
    #[tokio::test]
    async fn test_migrate_with_entries() {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = Arc::new(StateStore::new(temp_dir.path().to_path_buf()).unwrap());
        let manager = MemoryManager::new(Arc::clone(&store));

        // Built lazily so each entry is created right before it is stored.
        let make_entry = |idx| MemoryEntry {
            id: format!("migrate-test-{}", idx),
            memory_type: MemoryType::Fact,
            content: format!("Test content for migration entry {}", idx),
            source: String::from("test"),
            session_id: None,
            tags: Vec::new(),
            importance: 0.5,
            created_at: chrono::Utc::now(),
            accessed_at: chrono::Utc::now(),
            access_count: 0,
        };
        for entry in (0..5).map(make_entry) {
            manager.remember(entry).await.unwrap();
        }

        // Drop every vector so the migration has to rebuild the index.
        manager.vector_index.write().clear();
        assert_eq!(manager.vector_index_size(), 0);

        let report = manager.migrate_from_tfidf().await.unwrap();

        assert_eq!(report.progress.total, 5);
        assert_eq!(report.progress.migrated, 5);
        assert_eq!(report.progress.failed, 0);
        assert_eq!(manager.vector_index_size(), 5);
    }
}