use do_memory_core::{
ComplexityLevel, ExecutionStep, MemoryConfig, SelfLearningMemory, TaskContext, TaskOutcome,
TaskType,
};
use do_memory_mcp::{MemoryMCPServer, SandboxConfig};
use do_memory_storage_redb::RedbStorage;
use std::sync::Arc;
use tempfile::TempDir;
/// Builds a [`SelfLearningMemory`] backed by a redb database file that lives
/// inside a freshly created temporary directory.
///
/// The same redb storage handle is supplied for both storage arguments of
/// `SelfLearningMemory::with_storage`. `quality_threshold` is forced to `0.0`
/// (presumably so no episode is rejected on quality grounds — confirm against
/// `MemoryConfig`); every other setting keeps its default.
///
/// # Returns
/// The memory wrapped in an `Arc`, plus the `TempDir` that owns the database
/// file. Callers must keep the `TempDir` alive while the memory is in use,
/// since `tempfile` removes the directory when it is dropped.
///
/// # Errors
/// Propagates failures from creating the temporary directory or from opening
/// the redb storage.
async fn setup_persistent_memory() -> anyhow::Result<(Arc<SelfLearningMemory>, TempDir)> {
    let scratch = TempDir::new()?;
    let db_file = scratch.path().join("test_memory.redb");

    let backend: Arc<dyn do_memory_core::StorageBackend> =
        Arc::new(RedbStorage::new(&db_file).await?);

    let config = MemoryConfig {
        quality_threshold: 0.0,
        ..Default::default()
    };
    let first_handle = backend.clone();
    let memory = Arc::new(SelfLearningMemory::with_storage(
        config,
        first_handle,
        backend,
    ));

    Ok((memory, scratch))
}
#[cfg(test)]
mod persistent_storage_tests {
use super::*;
/// Records one episode with a single step, completes it, and checks that it
/// can be read back both directly from the memory and through an MCP
/// `query_memory` call.
#[tokio::test]
async fn test_episode_persistence_in_redb() {
    println!("🧪 Testing Episode Persistence in redb");
    println!("======================================");

    // `_temp_dir` must stay bound for the whole test so the database file
    // is not removed while the memory is still using it.
    let (memory, _temp_dir) = setup_persistent_memory().await.unwrap();

    let episode_id = memory
        .start_episode(
            "Persistent Episode Test".to_string(),
            TaskContext {
                domain: "persistence".to_string(),
                language: Some("rust".to_string()),
                framework: Some("tokio".to_string()),
                complexity: ComplexityLevel::Simple,
                tags: vec!["test".to_string(), "persistence".to_string()],
            },
            TaskType::Testing,
        )
        .await;
    println!("✅ Episode created: {}", episode_id);

    let step = ExecutionStep::new(
        1,
        "test_tool".to_string(),
        "Testing persistence".to_string(),
    );
    memory.log_step(episode_id, step).await;

    let outcome = TaskOutcome::Success {
        verdict: "Episode persisted successfully".to_string(),
        artifacts: vec!["test_result.txt".to_string()],
    };
    memory.complete_episode(episode_id, outcome).await.unwrap();
    println!("✅ Episode completed and steps logged");

    // Direct read-back through the memory API.
    let retrieved_episode = memory.get_episode(episode_id).await.unwrap();
    assert_eq!(
        retrieved_episode.task_description,
        "Persistent Episode Test"
    );
    assert_eq!(retrieved_episode.steps.len(), 1);
    assert!(retrieved_episode.is_complete());
    println!("✅ Episode verified in persistent storage");
    println!(" Description: {}", retrieved_episode.task_description);
    println!(" Steps: {}", retrieved_episode.steps.len());
    println!(" Completed: {}", retrieved_episode.is_complete());
    println!(" Outcome: {:?}", retrieved_episode.outcome);

    // Read-back through the MCP server, which shares the same memory handle.
    let mcp_server = Arc::new(
        MemoryMCPServer::new(SandboxConfig::restrictive(), memory.clone())
            .await
            .unwrap(),
    );
    let query_result = mcp_server
        .query_memory(
            "Persistent Episode".to_string(),
            "persistence".to_string(),
            None,
            10,
            "relevance".to_string(),
            None,
        )
        .await
        .unwrap();

    let episodes = query_result["episodes"].as_array().unwrap();
    assert_eq!(episodes.len(), 1);
    let episode = &episodes[0];
    assert_eq!(episode["task_description"], "Persistent Episode Test");
    println!("✅ MCP query verified episode in persistent storage");
    println!(" Episodes found: {}", episodes.len());
    println!(" Episode ID: {}", episode["episode_id"]);
}
/// Completes three structurally identical code-generation episodes, then
/// compares pattern retrieval done directly on the memory with the MCP
/// `analyze_patterns` result.
#[tokio::test]
async fn test_pattern_persistence_in_redb() {
    println!("🧪 Testing Pattern Persistence in redb");
    println!("=====================================");

    // `_temp_dir` must stay bound for the whole test so the database file
    // is not removed while the memory is still using it.
    let (memory, _temp_dir) = setup_persistent_memory().await.unwrap();

    // Each episode uses the same two tool steps so that they share a pattern.
    for i in 1..=3 {
        let episode_id = memory
            .start_episode(
                format!("Pattern Test Episode {}", i),
                TaskContext {
                    domain: "patterns".to_string(),
                    language: Some("rust".to_string()),
                    framework: Some("tokio".to_string()),
                    complexity: ComplexityLevel::Simple,
                    tags: vec!["pattern".to_string(), "test".to_string()],
                },
                TaskType::CodeGeneration,
            )
            .await;

        let step1 = ExecutionStep::new(1, "cargo".to_string(), "create_project".to_string());
        let step2 = ExecutionStep::new(
            2,
            "rust_analyzer".to_string(),
            "implement_feature".to_string(),
        );
        memory.log_step(episode_id, step1).await;
        memory.log_step(episode_id, step2).await;

        let outcome = TaskOutcome::Success {
            verdict: format!("Pattern episode {} completed", i),
            artifacts: vec![format!("feature_{}.rs", i)],
        };
        memory.complete_episode(episode_id, outcome).await.unwrap();
    }
    println!("✅ Created 3 episodes with similar patterns");

    let patterns = memory
        .retrieve_relevant_patterns(
            &TaskContext {
                domain: "patterns".to_string(),
                language: Some("rust".to_string()),
                framework: Some("tokio".to_string()),
                complexity: ComplexityLevel::Simple,
                tags: vec!["pattern".to_string()],
            },
            10,
        )
        .await;
    println!(
        "✅ Retrieved {} patterns from persistent storage",
        patterns.len()
    );

    let mcp_server = Arc::new(
        MemoryMCPServer::new(SandboxConfig::restrictive(), memory.clone())
            .await
            .unwrap(),
    );
    let pattern_result = mcp_server
        .analyze_patterns("CodeGeneration".to_string(), 0.0, 10, None)
        .await
        .unwrap();

    let mcp_patterns = pattern_result["patterns"].as_array().unwrap();
    let stats = &pattern_result["statistics"];
    println!("✅ MCP pattern analysis:");
    println!(" Patterns found: {}", mcp_patterns.len());
    println!(" Total patterns: {}", stats["total_patterns"]);
    // A precision spec has no effect when the JSON value itself is formatted
    // via `Display`, so pull the number out first; fall back to the raw value
    // when the field is not numeric.
    match stats["avg_success_rate"].as_f64() {
        Some(rate) => println!(" Avg success rate: {:.2}", rate),
        None => println!(" Avg success rate: {}", stats["avg_success_rate"]),
    }

    // Only demands MCP patterns when direct retrieval found some as well; the
    // check passes trivially if direct retrieval returned nothing.
    assert!(!mcp_patterns.is_empty() || patterns.is_empty());
    println!("✅ Pattern persistence verified");
}
#[tokio::test]
async fn test_cross_session_persistence() {
println!("๐งช Testing Cross-Session Persistence");
println!("====================================");
let temp_dir = TempDir::new().unwrap();
let redb_path = temp_dir.path().join("test_memory.redb");
let redb_storage = Arc::new(RedbStorage::new(&redb_path).await.unwrap());
let memory = SelfLearningMemory::with_storage(
MemoryConfig {
quality_threshold: 0.0, ..Default::default()
},
redb_storage.clone(),
redb_storage,
);
let memory = Arc::new(memory);
let episode_id = memory
.start_episode(
"Cross-Session Test".to_string(),
TaskContext {
domain: "persistence".to_string(),
language: Some("rust".to_string()),
framework: Some("tokio".to_string()),
complexity: ComplexityLevel::Simple,
tags: vec!["cross-session".to_string()],
},
TaskType::Testing,
)
.await;
let step = ExecutionStep::new(
1,
"persistence_test".to_string(),
"Testing cross-session data persistence".to_string(),
);
memory.log_step(episode_id, step).await;
let outcome = TaskOutcome::Success {
verdict: "Cross-session persistence test completed".to_string(),
artifacts: vec!["persistence_test.log".to_string()],
};
memory.complete_episode(episode_id, outcome).await.unwrap();
println!("โ
Episode created and stored in persistent storage");
let episode = memory.get_episode(episode_id).await.unwrap();
assert_eq!(episode.task_description, "Cross-Session Test");
assert_eq!(episode.steps.len(), 1);
println!("โ
Episode verified in persistent storage");
let mcp_server = Arc::new(
MemoryMCPServer::new(SandboxConfig::restrictive(), memory.clone())
.await
.unwrap(),
);
let query_result = mcp_server
.query_memory(
"Cross-Session".to_string(),
"persistence".to_string(),
None,
10,
"relevance".to_string(),
None,
)
.await
.unwrap();
let episodes = query_result["episodes"].as_array().unwrap();
assert_eq!(episodes.len(), 1);
let episode = &episodes[0];
assert_eq!(episode["task_description"], "Cross-Session Test");
assert_eq!(episode["steps"].as_array().unwrap().len(), 1);
println!("โ
MCP query verified episode persistence");
println!(" Episodes found: {}", episodes.len());
println!(" Episode description: {}", episode["task_description"]);
println!(
" Steps count: {}",
episode["steps"].as_array().unwrap().len()
);
println!("โ
Single-session persistence test completed successfully");
println!(" Note: Cross-instance persistence requires full database setup");
}
/// Logs three steps on one episode and checks that the step count seen via
/// direct memory access matches the count reported by an MCP query.
#[tokio::test]
async fn test_storage_backend_synchronization() {
    println!("🧪 Testing Storage Backend Synchronization");
    println!("===========================================");

    // `_temp_dir` must stay bound for the whole test so the database file
    // is not removed while the memory is still using it.
    let (memory, _temp_dir) = setup_persistent_memory().await.unwrap();

    let episode_id = memory
        .start_episode(
            "Sync Test Episode".to_string(),
            TaskContext {
                domain: "sync".to_string(),
                language: Some("rust".to_string()),
                framework: Some("tokio".to_string()),
                complexity: ComplexityLevel::Simple,
                tags: vec!["sync".to_string(), "test".to_string()],
            },
            TaskType::Testing,
        )
        .await;

    for i in 1..=3 {
        let step = ExecutionStep::new(
            i,
            format!("sync_tool_{}", i),
            format!("Synchronization step {}", i),
        );
        memory.log_step(episode_id, step).await;
    }

    memory
        .complete_episode(
            episode_id,
            TaskOutcome::Success {
                verdict: "Storage synchronization test completed".to_string(),
                artifacts: vec!["sync_test.log".to_string()],
            },
        )
        .await
        .unwrap();
    println!("✅ Episode created with 3 steps");

    // View 1: the episode as returned directly by the memory.
    let direct_episode = memory.get_episode(episode_id).await.unwrap();
    assert_eq!(direct_episode.steps.len(), 3);

    // View 2: the same episode as returned by an MCP query.
    let mcp_server = Arc::new(
        MemoryMCPServer::new(SandboxConfig::restrictive(), memory.clone())
            .await
            .unwrap(),
    );
    let query_result = mcp_server
        .query_memory(
            "Sync Test".to_string(),
            "sync".to_string(),
            None,
            10,
            "relevance".to_string(),
            None,
        )
        .await
        .unwrap();

    let episodes = query_result["episodes"].as_array().unwrap();
    assert_eq!(episodes.len(), 1);
    let mcp_episode = &episodes[0];
    let mcp_step_count = mcp_episode["steps"].as_array().unwrap().len();
    assert_eq!(mcp_step_count, 3);

    println!("✅ Storage synchronization verified:");
    println!(" Direct access: {} steps", direct_episode.steps.len());
    println!(" MCP access: {} steps", mcp_step_count);
    println!(" Episode completed: {}", direct_episode.is_complete());
}
/// Creates five completed episodes with two steps each and checks that an MCP
/// query returns every one of them with the expected title and step count.
#[tokio::test]
async fn test_bulk_data_persistence() {
    println!("🧪 Testing Bulk Data Persistence");
    println!("================================");

    // `_temp_dir` must stay bound for the whole test so the database file
    // is not removed while the memory is still using it.
    let (memory, _temp_dir) = setup_persistent_memory().await.unwrap();

    let episode_count = 5;
    for i in 1..=episode_count {
        let episode_id = memory
            .start_episode(
                format!("Bulk Test Episode {}", i),
                TaskContext {
                    domain: "bulk".to_string(),
                    language: Some("rust".to_string()),
                    framework: Some("tokio".to_string()),
                    complexity: ComplexityLevel::Simple,
                    tags: vec!["bulk".to_string(), "test".to_string()],
                },
                TaskType::Testing,
            )
            .await;

        for j in 1..=2 {
            let step = ExecutionStep::new(
                j,
                format!("bulk_tool_{}", j),
                format!("Bulk operation {} for episode {}", j, i),
            );
            memory.log_step(episode_id, step).await;
        }

        memory
            .complete_episode(
                episode_id,
                TaskOutcome::Success {
                    verdict: format!("Bulk episode {} completed", i),
                    artifacts: vec![format!("bulk_{}.log", i)],
                },
            )
            .await
            .unwrap();
    }
    println!("✅ Created {} episodes with bulk data", episode_count);

    let mcp_server = Arc::new(
        MemoryMCPServer::new(SandboxConfig::restrictive(), memory.clone())
            .await
            .unwrap(),
    );
    let query_result = mcp_server
        .query_memory(
            "Bulk Test".to_string(),
            "bulk".to_string(),
            None,
            10,
            "relevance".to_string(),
            None,
        )
        .await
        .unwrap();

    let episodes = query_result["episodes"].as_array().unwrap();
    assert_eq!(episodes.len(), episode_count);
    println!("✅ Bulk persistence verified:");
    println!(" Episodes created: {}", episode_count);
    println!(" Episodes retrieved: {}", episodes.len());

    // Collect the returned titles so each expected episode can be checked
    // regardless of the order the query returned them in.
    let mut found_episodes = std::collections::HashSet::new();
    for episode in episodes {
        let title = episode["task_description"].as_str().unwrap();
        let step_count = episode["steps"].as_array().unwrap().len();
        assert!(title.starts_with("Bulk Test Episode"));
        assert_eq!(step_count, 2);
        found_episodes.insert(title.to_string());
        println!(" Found episode: {} - {} steps", title, step_count);
    }

    for i in 1..=episode_count {
        let expected_title = format!("Bulk Test Episode {}", i);
        assert!(
            found_episodes.contains(&expected_title),
            "Missing episode: {}",
            expected_title
        );
    }
    println!("✅ All bulk data persisted and retrievable");
}
}