mod common;
use common::{create_test_step, setup_test_memory, test_context};
use do_memory_core::{
    ComplexityLevel, ExecutionResult, ExecutionStep, Pattern, TaskContext, TaskOutcome, TaskType,
    episode::Episode,
    memory::SelfLearningMemory,
};
use std::time::{Duration, Instant};
use uuid::Uuid;
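
// These integration tests guard against regressions in the self-learning memory:
// pattern-extraction accuracy, retrieval quality and latency at scale, public-API
// stability, and previously fixed bugs.

/// Hand-built historical episodes used to seed the pattern-extraction accuracy test.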
fn load_historical_test_episodes() -> Vec<Episode> {
vec![
create_error_recovery_episode(),
create_tool_sequence_episode(),
create_optimization_episode(),
]
}
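
/// Builds a completed episode in which an initial attempt fails with a timeout and a
/// retry with backoff succeeds, so an error-recovery pattern can be extracted.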
fn create_error_recovery_episode() -> Episode {
let context = TaskContext {
language: Some("rust".to_string()),
domain: "error-handling".to_string(),
tags: vec!["retry".to_string()],
..Default::default()
};
let mut episode = Episode::new(
"Implement retry logic".to_string(),
context,
TaskType::CodeGeneration,
);
let mut error_step = ExecutionStep::new(
1,
"initial_attempt".to_string(),
"Try operation".to_string(),
);
error_step.result = Some(ExecutionResult::Error {
message: "Connection timeout".to_string(),
});
episode.add_step(error_step);
let mut retry_step = ExecutionStep::new(
2,
"retry_handler".to_string(),
"Retry with backoff".to_string(),
);
retry_step.result = Some(ExecutionResult::Success {
output: "Success".to_string(),
});
episode.add_step(retry_step);
episode.complete(TaskOutcome::Success {
verdict: "Retry worked".to_string(),
artifacts: vec![],
});
episode
}
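
/// Builds a completed episode that runs the analyzer, designer, builder, and tester
/// tools in order, so a tool-sequence pattern can be extracted.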
fn create_tool_sequence_episode() -> Episode {
let context = TaskContext {
language: Some("rust".to_string()),
domain: "code-generation".to_string(),
tags: vec!["sequential".to_string()],
..Default::default()
};
let mut episode = Episode::new(
"Build API endpoint".to_string(),
context,
TaskType::CodeGeneration,
);
let tools = ["analyzer", "designer", "builder", "tester"];
for (i, tool) in tools.iter().enumerate() {
let mut step = ExecutionStep::new(i + 1, (*tool).to_string(), format!("{tool} step"));
step.result = Some(ExecutionResult::Success {
output: "Done".to_string(),
});
episode.add_step(step);
}
episode.complete(TaskOutcome::Success {
verdict: "API built".to_string(),
artifacts: vec!["api.rs".to_string()],
});
episode
}
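
/// Builds a completed refactoring episode where a slow profiling step (1000 ms) is
/// followed by a fast optimization step (100 ms).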
fn create_optimization_episode() -> Episode {
let context = TaskContext {
language: Some("rust".to_string()),
domain: "performance".to_string(),
tags: vec!["optimization".to_string()],
..Default::default()
};
let mut episode = Episode::new("Optimize query".to_string(), context, TaskType::Refactoring);
let mut slow_step = ExecutionStep::new(1, "profiler".to_string(), "Profile code".to_string());
    slow_step.latency_ms = 1000;
    slow_step.result = Some(ExecutionResult::Success {
output: "Found bottleneck".to_string(),
});
episode.add_step(slow_step);
let mut fast_step =
ExecutionStep::new(2, "optimizer".to_string(), "Apply optimization".to_string());
    fast_step.latency_ms = 100;
    fast_step.result = Some(ExecutionResult::Success {
output: "Optimized".to_string(),
});
episode.add_step(fast_step);
episode.complete(TaskOutcome::Success {
verdict: "10x faster".to_string(),
artifacts: vec![],
});
episode
}
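
/// Reference patterns the extraction pipeline is expected to approximately reproduce
/// from the historical episodes: a timeout error-recovery pattern and the four-tool sequence.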
fn load_reference_patterns() -> Vec<Pattern> {
use chrono::Duration;
use do_memory_core::PatternEffectiveness;
vec![
Pattern::ErrorRecovery {
id: Uuid::new_v4(),
context: TaskContext::default(),
error_type: "timeout".to_string(),
recovery_steps: vec!["retry_with_backoff".to_string()],
success_rate: 0.9,
effectiveness: PatternEffectiveness::default(),
},
Pattern::ToolSequence {
id: Uuid::new_v4(),
context: TaskContext::default(),
tools: vec![
"analyzer".to_string(),
"designer".to_string(),
"builder".to_string(),
"tester".to_string(),
],
success_rate: 0.85,
avg_latency: Duration::milliseconds(60000),
occurrence_count: 10,
effectiveness: PatternEffectiveness::default(),
},
]
}
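
/// Returns the fraction of reference patterns that have at least one matching extracted
/// pattern, or 0.0 when the reference set is empty.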
#[allow(clippy::cast_precision_loss)]
fn calculate_pattern_similarity(extracted: &[Pattern], reference: &[Pattern]) -> f32 {
if reference.is_empty() {
return 0.0;
}
let mut matches = 0;
for ref_pattern in reference {
for ext_pattern in extracted {
if patterns_match(ref_pattern, ext_pattern) {
matches += 1;
break;
}
}
}
matches as f32 / reference.len() as f32
}
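
/// Two patterns match when they are the same variant and sufficiently similar:
/// error-recovery patterns must have overlapping error types, and tool sequences of
/// equal length must agree on at least two-thirds of their positions (integer
/// arithmetic, so the threshold rounds down).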
fn patterns_match(p1: &Pattern, p2: &Pattern) -> bool {
match (p1, p2) {
(
Pattern::ErrorRecovery { error_type: e1, .. },
Pattern::ErrorRecovery { error_type: e2, .. },
) => e1.contains(e2) || e2.contains(e1),
(Pattern::ToolSequence { tools: t1, .. }, Pattern::ToolSequence { tools: t2, .. }) => {
if t1.len() != t2.len() {
return false;
}
t1.iter().zip(t2.iter()).filter(|(a, b)| a == b).count() >= t1.len() * 2 / 3
}
_ => false,
}
}
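
/// Standard retrieval queries spanning several domains, used to measure average
/// retrieval latency.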
fn load_standard_test_queries() -> Vec<TestQuery> {
vec![
TestQuery {
text: "implement authentication".to_string(),
context: TaskContext {
domain: "web-api".to_string(),
..Default::default()
},
},
TestQuery {
text: "handle errors".to_string(),
context: TaskContext {
domain: "error-handling".to_string(),
..Default::default()
},
},
TestQuery {
text: "optimize performance".to_string(),
context: TaskContext {
domain: "performance".to_string(),
..Default::default()
},
},
TestQuery {
text: "write tests".to_string(),
context: TaskContext {
domain: "testing".to_string(),
..Default::default()
},
},
TestQuery {
text: "refactor code".to_string(),
context: TaskContext {
domain: "refactoring".to_string(),
..Default::default()
},
},
]
}
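
/// A retrieval query paired with the task context it is evaluated against.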
struct TestQuery {
text: String,
context: TaskContext,
}
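
/// Populates a fresh memory with 10,000 completed episodes spread across ten domains
/// and three complexity levels.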
async fn setup_memory_with_10k_episodes() -> SelfLearningMemory {
let memory = setup_test_memory();
for i in 0..10000 {
if i % 1000 == 0 {
println!("Loading test episodes: {i}/10000");
}
let context = TaskContext {
domain: format!("domain_{}", i % 10),
complexity: match i % 3 {
0 => ComplexityLevel::Simple,
1 => ComplexityLevel::Moderate,
_ => ComplexityLevel::Complex,
},
..Default::default()
};
let episode_id = memory
.start_episode(format!("Task {i}"), context, TaskType::CodeGeneration)
.await;
memory
.complete_episode(
episode_id,
TaskOutcome::Success {
verdict: "Done".to_string(),
artifacts: vec![],
},
)
.await
.unwrap();
}
memory
}
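
/// Replays the hand-built historical episodes and checks that the retrieved patterns
/// cover more than half of the reference set.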
#[tokio::test]
#[ignore = "Non-deterministic pattern extraction - may fail intermittently"]
async fn should_maintain_pattern_extraction_accuracy_over_time() {
let test_episodes = load_historical_test_episodes();
let memory = setup_test_memory();
for episode in &test_episodes {
let episode_id = memory
.start_episode(
episode.task_description.clone(),
episode.context.clone(),
episode.task_type,
)
.await;
for step in &episode.steps {
memory.log_step(episode_id, step.clone()).await;
}
if let Some(outcome) = &episode.outcome {
memory
.complete_episode(episode_id, outcome.clone())
.await
.unwrap();
}
}
let reference_patterns = load_reference_patterns();
let patterns = memory
.retrieve_relevant_patterns(&TaskContext::default(), 10)
.await;
let accuracy = calculate_pattern_similarity(&patterns, &reference_patterns);
println!(
"Pattern extraction accuracy: {:.1}% ({} extracted, {} expected)",
accuracy * 100.0,
patterns.len(),
reference_patterns.len()
);
assert!(
accuracy > 0.5,
"Pattern extraction accuracy degraded to {:.1}%",
accuracy * 100.0
);
}
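
/// An episode with a failed attempt followed by a successful retry must yield at least
/// one extracted pattern once completed.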
#[tokio::test]
async fn should_extract_all_pattern_types_correctly() {
let memory = setup_test_memory();
let episode_id = memory
.start_episode(
"Error handling".to_string(),
TaskContext {
domain: "error-handling".to_string(),
..Default::default()
},
TaskType::CodeGeneration,
)
.await;
let mut error_step = ExecutionStep::new(1, "attempt".to_string(), "Try".to_string());
error_step.result = Some(ExecutionResult::Error {
message: "Failed".to_string(),
});
memory.log_step(episode_id, error_step).await;
let mut success_step = ExecutionStep::new(2, "retry".to_string(), "Retry".to_string());
success_step.result = Some(ExecutionResult::Success {
output: "Success".to_string(),
});
memory.log_step(episode_id, success_step).await;
memory
.complete_episode(
episode_id,
TaskOutcome::Success {
verdict: "Recovered".to_string(),
artifacts: vec![],
},
)
.await
.unwrap();
let completed = memory.get_episode(episode_id).await.unwrap();
assert!(
!completed.patterns.is_empty(),
"No patterns extracted from error recovery episode"
);
}
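
/// Measures average retrieval latency across the standard queries against a
/// 10,000-episode store; it must stay under 100 ms per query.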
#[tokio::test]
#[ignore = "Long-running performance test - run with --include-ignored for full validation"]
async fn should_maintain_fast_retrieval_with_large_dataset() {
let memory = setup_memory_with_10k_episodes().await;
let test_queries = load_standard_test_queries();
let mut total_time = Duration::from_millis(0);
for query in &test_queries {
let start = Instant::now();
let _ = memory
.retrieve_relevant_context(query.text.clone(), query.context.clone(), 10)
.await;
total_time += start.elapsed();
}
let avg_time = total_time
/ u32::try_from(test_queries.len()).expect("test_queries length should fit in u32");
println!("Average retrieval time with 10K episodes: {avg_time:?}");
assert!(
avg_time.as_millis() < 100,
"Average retrieval time degraded to {}ms",
avg_time.as_millis()
);
}
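
/// After seeding relevant and irrelevant episodes, at least half of the results for a
/// web-api query must come from the web-api domain.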
#[tokio::test]
async fn should_retrieve_relevant_episodes_by_domain() {
let memory = setup_test_memory();
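    // Seed ten web-api authentication episodes that should match the query below.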
for i in 0..10 {
let context = TaskContext {
domain: "web-api".to_string(),
tags: vec!["authentication".to_string()],
..Default::default()
};
let episode_id = memory
.start_episode(format!("Auth task {i}"), context, TaskType::CodeGeneration)
.await;
memory
.complete_episode(
episode_id,
TaskOutcome::Success {
verdict: "Done".to_string(),
artifacts: vec![],
},
)
.await
.unwrap();
}
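
    // Seed ten unrelated data-processing episodes as noise.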
for i in 0..10 {
let context = TaskContext {
domain: "data-processing".to_string(),
tags: vec!["batch".to_string()],
..Default::default()
};
let episode_id = memory
.start_episode(format!("Batch task {i}"), context, TaskType::CodeGeneration)
.await;
memory
.complete_episode(
episode_id,
TaskOutcome::Success {
verdict: "Done".to_string(),
artifacts: vec![],
},
)
.await
.unwrap();
}
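
    // Querying for authentication in the web-api domain should surface mostly web-api episodes.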
let context = TaskContext {
domain: "web-api".to_string(),
..Default::default()
};
let results = memory
.retrieve_relevant_context("authentication".to_string(), context, 10)
.await;
assert!(!results.is_empty());
let web_api_count = results
.iter()
.filter(|e| e.context.domain == "web-api")
.count();
#[allow(clippy::cast_precision_loss)]
{
assert!(
web_api_count as f64 / results.len() as f64 >= 0.5,
"Retrieval quality degraded - only {}/{} results matched domain",
web_api_count,
results.len()
);
}
}
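
/// Smoke-tests the public API surface (start_episode, log_step, complete_episode,
/// retrieve_relevant_context, get_episode, get_stats, retrieve_relevant_patterns) so
/// that signature changes are caught here.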
#[tokio::test]
async fn should_maintain_backward_compatible_public_api() {
let memory = setup_test_memory();
let episode_id = memory
.start_episode("test".to_string(), test_context(), TaskType::Testing)
.await;
let step = create_test_step(1);
memory.log_step(episode_id, step).await;
let completed = memory
.complete_episode(
episode_id,
TaskOutcome::Success {
verdict: "test".to_string(),
artifacts: vec![],
},
)
.await;
assert!(completed.is_ok());
    // Retrieval must keep accepting (query, context, limit) and returning a Vec;
    // the result contents are not asserted here.
    let _results = memory
        .retrieve_relevant_context("test".to_string(), test_context(), 10)
        .await;
let episode = memory.get_episode(episode_id).await;
assert!(episode.is_ok());
let (_total, _completed, _patterns) = memory.get_stats().await;
let _patterns = memory.retrieve_relevant_patterns(&test_context(), 10).await;
}
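
/// Touches every public field and helper on Episode, ExecutionStep, and TaskContext so
/// that a removed or renamed member breaks this test at compile time.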
#[tokio::test]
async fn should_maintain_data_structure_compatibility() {
let memory = setup_test_memory();
let episode_id = memory
.start_episode("Test".to_string(), test_context(), TaskType::Testing)
.await;
let episode = memory.get_episode(episode_id).await.unwrap();
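    // Every public Episode field and helper below must remain accessible.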
let _ = episode.episode_id;
let _ = episode.task_type;
let _ = episode.task_description;
let _ = episode.context;
let _ = episode.start_time;
let _ = episode.end_time;
let _ = episode.steps;
let _ = episode.outcome;
let _ = episode.reward;
let _ = episode.reflection;
let _ = episode.patterns;
let _ = episode.metadata;
assert!(!episode.is_complete());
let _ = episode.duration();
let _ = episode.successful_steps_count();
let _ = episode.failed_steps_count();
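
    // ExecutionStep public fields and helpers.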
let step = create_test_step(1);
let _ = step.step_number;
let _ = step.timestamp;
let _ = step.tool;
let _ = step.action;
let _ = step.parameters;
let _ = step.result;
let _ = step.latency_ms;
let _ = step.tokens_used;
let _ = step.metadata;
let _ = step.is_success();
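
    // TaskContext public fields.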
let context = test_context();
let _ = context.language;
let _ = context.framework;
let _ = context.complexity;
let _ = context.domain;
let _ = context.tags;
}
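
/// Guards against regressions of previously fixed bugs: duplicate pattern IDs on a
/// completed episode, double completion of the same episode, and completion of an
/// episode with no logged steps.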
#[tokio::test]
async fn should_prevent_previously_fixed_bugs_from_recurring() {
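    // Guard 1: a completed episode must not contain duplicate pattern IDs.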
let memory1 = setup_test_memory();
let episode_id = memory1
.start_episode("Test".to_string(), test_context(), TaskType::CodeGeneration)
.await;
for i in 1..=5 {
memory1.log_step(episode_id, create_test_step(i)).await;
}
memory1
.complete_episode(
episode_id,
TaskOutcome::Success {
verdict: "Done".to_string(),
artifacts: vec![],
},
)
.await
.unwrap();
let completed = memory1.get_episode(episode_id).await.unwrap();
let mut seen = std::collections::HashSet::new();
for pattern_id in &completed.patterns {
assert!(
seen.insert(*pattern_id),
"Duplicate pattern ID found: {pattern_id}"
);
}
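
    // Guard 2: completing the same episode twice must not panic.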
let memory2 = setup_test_memory();
let episode_id2 = memory2
.start_episode("Test".to_string(), test_context(), TaskType::Testing)
.await;
let outcome = TaskOutcome::Success {
verdict: "Done".to_string(),
artifacts: vec![],
};
let result1 = memory2.complete_episode(episode_id2, outcome.clone()).await;
assert!(result1.is_ok());
    // The second call may succeed idempotently or return an error; either outcome is
    // acceptable as long as it does not panic.
    let _result2 = memory2.complete_episode(episode_id2, outcome).await;
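
    // Guard 3: completing an episode with no logged steps must still succeed.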
let memory3 = setup_test_memory();
let episode_id3 = memory3
.start_episode("Test".to_string(), test_context(), TaskType::Testing)
.await;
let result = memory3
.complete_episode(
episode_id3,
TaskOutcome::Success {
verdict: "Done".to_string(),
artifacts: vec![],
},
)
.await;
assert!(result.is_ok());
let episode = memory3.get_episode(episode_id3).await.unwrap();
assert!(episode.is_complete());
assert_eq!(episode.steps.len(), 0);
}
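
/// Creates and completes 100 episodes with three steps each; the whole run must finish
/// in under five seconds.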
#[tokio::test]
async fn should_maintain_baseline_episode_creation_performance() {
let memory = setup_test_memory();
let start = Instant::now();
for i in 0..100 {
let episode_id = memory
.start_episode(
format!("Task {i}"),
test_context(),
TaskType::CodeGeneration,
)
.await;
for j in 1..=3 {
memory.log_step(episode_id, create_test_step(j)).await;
}
memory
.complete_episode(
episode_id,
TaskOutcome::Success {
verdict: "Done".to_string(),
artifacts: vec![],
},
)
.await
.unwrap();
}
let elapsed = start.elapsed();
println!(
"100 episodes with 3 steps each: {:?} ({:.2} eps/sec)",
elapsed,
100.0 / elapsed.as_secs_f32()
);
assert!(
elapsed.as_secs() < 5,
"Performance degraded: took {}ms for 100 episodes",
elapsed.as_millis()
);
}