use do_memory_core::ExecutionStep;
use do_memory_core::memory::SelfLearningMemory;
use do_memory_core::types::{
ComplexityLevel, ExecutionResult, MemoryConfig, TaskContext, TaskOutcome, TaskType,
};
fn create_high_quality_episode_data() -> (String, TaskContext, TaskType, Vec<ExecutionStep>) {
let context = TaskContext {
language: Some("rust".to_string()),
framework: Some("tokio".to_string()),
complexity: ComplexityLevel::Complex,
domain: "async-web-api".to_string(),
tags: vec!["http".to_string(), "rest".to_string()],
};
let task_description =
"Implement comprehensive async REST API with authentication and validation".to_string();
let task_type = TaskType::CodeGeneration;
let mut steps = Vec::new();
let mut execution_step1 = ExecutionStep::new(
1,
"planner".to_string(),
"Choose async implementation strategy".to_string(),
);
execution_step1.parameters = serde_json::json!({
"strategy": "tokio-async",
"approach": "layered-architecture"
});
execution_step1.result = Some(ExecutionResult::Success {
output: "Strategy selected: async with layered architecture".to_string(),
});
execution_step1.latency_ms = 120;
steps.push(execution_step1);
for i in 2..=4 {
let mut step = ExecutionStep::new(i, format!("builder_{i}"), format!("Build layer {i}"));
step.result = Some(ExecutionResult::Success {
output: format!("Layer {i} complete"),
});
step.latency_ms = 100 + (i as u64 * 10);
steps.push(step);
}
let mut error_step = ExecutionStep::new(
5,
"validator".to_string(),
"Validate API endpoints".to_string(),
);
error_step.result = Some(ExecutionResult::Error {
message: "Validation failed: missing authentication header".to_string(),
});
error_step.latency_ms = 80;
steps.push(error_step);
let mut recovery_step = ExecutionStep::new(
6,
"validator".to_string(),
"Add authentication header validation".to_string(),
);
recovery_step.result = Some(ExecutionResult::Success {
output: "Authentication validation added".to_string(),
});
recovery_step.latency_ms = 150;
steps.push(recovery_step);
for i in 7..=10 {
let mut step = ExecutionStep::new(
i,
format!("integrator_{}", i % 3),
format!("Integration step {i}"),
);
step.result = Some(ExecutionResult::Success {
output: format!("Integration {i} complete"),
});
step.latency_ms = 90 + (i as u64 * 5);
steps.push(step);
}
(task_description, context, task_type, steps)
}
/// Builds the inputs for a deliberately sparse episode.
///
/// The task has a one-word description, no framework, no tags, `Simple`
/// complexity, and only two identical successful `simple_tool` steps.
///
/// Returns `(task_description, context, task_type, steps)`.
fn create_low_quality_episode_data() -> (String, TaskContext, TaskType, Vec<ExecutionStep>) {
    let context = TaskContext {
        language: Some("rust".to_string()),
        framework: None,
        complexity: ComplexityLevel::Simple,
        domain: "testing".to_string(),
        tags: vec![],
    };

    let steps: Vec<ExecutionStep> = (1..=2)
        .map(|n| {
            let mut trivial =
                ExecutionStep::new(n, "simple_tool".to_string(), "Simple action".to_string());
            trivial.result = Some(ExecutionResult::Success {
                output: "OK".to_string(),
            });
            trivial.latency_ms = 10;
            trivial
        })
        .collect();

    ("Test".to_string(), context, TaskType::Testing, steps)
}
/// An information-rich episode must pass a 0.5 quality threshold, and every
/// salient-feature category must be populated on the stored episode.
#[tokio::test]
async fn test_high_quality_episode_accepted() {
    let memory = SelfLearningMemory::with_config(MemoryConfig {
        quality_threshold: 0.5,
        ..Default::default()
    });

    let (task_description, context, task_type, steps) = create_high_quality_episode_data();
    let episode_id = memory
        .start_episode(task_description, context, task_type)
        .await;
    for step in steps {
        memory.log_step(episode_id, step).await;
    }
    memory.flush_steps(episode_id).await.unwrap();

    let outcome = TaskOutcome::Success {
        verdict: "API implemented successfully with comprehensive authentication, validation, and error recovery patterns"
            .to_string(),
        artifacts: ["api.rs", "auth.rs", "validation.rs", "tests.rs"]
            .iter()
            .map(|name| name.to_string())
            .collect(),
    };
    let completion = memory.complete_episode(episode_id, outcome).await;
    assert!(
        completion.is_ok(),
        "High-quality episode should be accepted, got error: {:?}",
        completion.err()
    );

    let episode = memory.get_episode(episode_id).await.unwrap();
    assert!(episode.is_complete());
    assert!(episode.reward.is_some());
    assert!(episode.reflection.is_some());
    assert!(
        episode.salient_features.is_some(),
        "Salient features should be extracted for high-quality episode"
    );

    let features = episode.salient_features.as_ref().unwrap();
    assert!(!features.is_empty(), "Salient features should not be empty");
    assert!(features.count() > 0, "Should have extracted some features");

    // Each category is checked in order; the first empty one fails the test.
    let category_checks = [
        (
            features.critical_decisions.is_empty(),
            "Should have extracted critical decisions",
        ),
        (
            features.tool_combinations.is_empty(),
            "Should have extracted tool combinations",
        ),
        (
            features.error_recovery_patterns.is_empty(),
            "Should have extracted error recovery patterns",
        ),
        (
            features.key_insights.is_empty(),
            "Should have extracted key insights",
        ),
    ];
    for (missing, message) in category_checks {
        assert!(!missing, "{message}");
    }
}
#[tokio::test]
async fn test_low_quality_episode_rejected() {
let memory = SelfLearningMemory::new();
let (task_description, context, task_type, steps) = create_low_quality_episode_data();
let episode_id = memory
.start_episode(task_description, context, task_type)
.await;
for step in steps {
memory.log_step(episode_id, step).await;
}
let result = memory
.complete_episode(
episode_id,
TaskOutcome::Success {
verdict: "Done".to_string(),
artifacts: vec![],
},
)
.await;
assert!(result.is_err(), "Low-quality episode should be rejected");
match result.unwrap_err() {
do_memory_core::Error::ValidationFailed(msg) => {
assert!(msg.contains("quality score"));
assert!(msg.contains("below threshold"));
}
other => panic!("Expected ValidationFailed error, got: {other:?}"),
}
}
#[tokio::test]
async fn test_custom_quality_threshold() {
let config = MemoryConfig {
quality_threshold: 0.4,
..Default::default()
};
let memory = SelfLearningMemory::with_config(config);
let (task_description, context, task_type, steps) = create_low_quality_episode_data();
let episode_id = memory
.start_episode(task_description, context, task_type)
.await;
for step in steps {
memory.log_step(episode_id, step).await;
}
let result = memory
.complete_episode(
episode_id,
TaskOutcome::Success {
verdict: "Done".to_string(),
artifacts: vec![],
},
)
.await;
match result {
Ok(()) => {
let episode = memory.get_episode(episode_id).await.unwrap();
assert!(episode.salient_features.is_some());
}
Err(err) => {
match err {
do_memory_core::Error::ValidationFailed(_) => {
}
other => panic!("Unexpected error: {other:?}"),
}
}
}
}
/// After a high-quality episode completes, the episode read back from memory
/// must carry salient features with well-formed entries:
/// - non-empty decisions,
/// - tool combinations of at least two tools,
/// - recovery patterns containing `->`.
#[tokio::test]
async fn test_salient_features_storage_in_cache() {
    let memory = SelfLearningMemory::with_config(MemoryConfig {
        quality_threshold: 0.5,
        ..Default::default()
    });

    let (task_description, context, task_type, steps) = create_high_quality_episode_data();
    let episode_id = memory
        .start_episode(task_description, context, task_type)
        .await;
    for step in steps {
        memory.log_step(episode_id, step).await;
    }

    let outcome = TaskOutcome::Success {
        verdict: "Success".to_string(),
        artifacts: vec!["artifact.rs".to_string()],
    };
    memory.complete_episode(episode_id, outcome).await.unwrap();

    let episode = memory.get_episode(episode_id).await.unwrap();
    assert!(episode.salient_features.is_some());
    let features = episode.salient_features.unwrap();

    assert!(!features.critical_decisions.is_empty());
    features.critical_decisions.iter().for_each(|decision| {
        assert!(!decision.is_empty(), "Decision should have content");
    });

    assert!(!features.tool_combinations.is_empty());
    features.tool_combinations.iter().for_each(|combo| {
        assert!(
            combo.len() >= 2,
            "Tool combination should have at least 2 tools"
        );
    });

    assert!(!features.error_recovery_patterns.is_empty());
    features.error_recovery_patterns.iter().for_each(|pattern| {
        assert!(
            pattern.contains("->"),
            "Recovery pattern should show error->recovery"
        );
    });
}
/// Measures the wall-clock time of `complete_episode` for a high-quality
/// episode and requires it to stay under a fixed budget.
///
/// The measurement is printed *before* any assertion runs. The test harness
/// shows captured stdout for failing tests, so the duration is visible exactly
/// when the budget is exceeded. NOTE(review): a wall-clock bound can be
/// sensitive to build profile (debug vs release) and CI load. Confirm the
/// budget is appropriate for the environments that run this test.
#[tokio::test]
async fn test_performance_overhead() {
    use std::time::Instant;

    /// Upper bound, in milliseconds, for a single `complete_episode` call.
    const MAX_COMPLETE_MS: u128 = 100;

    let config = MemoryConfig {
        quality_threshold: 0.5,
        ..Default::default()
    };
    let memory = SelfLearningMemory::with_config(config);
    let (task_description, context, task_type, steps) = create_high_quality_episode_data();
    let episode_id = memory
        .start_episode(task_description, context, task_type)
        .await;
    for step in steps {
        memory.log_step(episode_id, step).await;
    }

    // Time only the completion call, not episode setup or step logging.
    let start = Instant::now();
    let result = memory
        .complete_episode(
            episode_id,
            TaskOutcome::Success {
                verdict: "Success".to_string(),
                artifacts: vec!["artifact.rs".to_string()],
            },
        )
        .await;
    let elapsed_ms = start.elapsed().as_millis();

    // Report first so the number survives a failing assertion below.
    println!("Complete episode with PREMem overhead: {elapsed_ms}ms");

    assert!(
        result.is_ok(),
        "complete_episode should succeed for a high-quality episode, got error: {:?}",
        result.err()
    );
    assert!(
        elapsed_ms < MAX_COMPLETE_MS,
        "Complete episode with PREMem should take < {MAX_COMPLETE_MS}ms, took: {elapsed_ms}ms"
    );
}
/// The user-facing (`Display`) rendering of a quality rejection must mention
/// the quality score and the threshold, and must contain parentheses.
///
/// Each check prints the rendered message on failure so wording drift in the
/// library is diagnosable from the test output alone.
#[tokio::test]
async fn test_rejection_logging() {
    let memory = SelfLearningMemory::new();
    let (task_description, context, task_type, steps) = create_low_quality_episode_data();
    let episode_id = memory
        .start_episode(task_description, context, task_type)
        .await;
    for step in steps {
        memory.log_step(episode_id, step).await;
    }

    let result = memory
        .complete_episode(
            episode_id,
            TaskOutcome::Success {
                verdict: "Done".to_string(),
                artifacts: vec![],
            },
        )
        .await;

    let err = match result {
        Ok(()) => panic!("Low-quality episode should be rejected under the default threshold"),
        Err(err) => err,
    };

    let err_msg = err.to_string();
    assert!(
        err_msg.contains("quality score"),
        "rejection message should mention the quality score, got: {err_msg}"
    );
    assert!(
        err_msg.contains("threshold"),
        "rejection message should mention the threshold, got: {err_msg}"
    );
    assert!(
        err_msg.contains('(') && err_msg.contains(')'),
        "rejection message should contain parentheses, got: {err_msg}"
    );
}