mod common;
use chrono::Utc;
use common::{
ContextBuilder, PatternType, StepBuilder, create_completed_episode_with_pattern,
create_test_episode_with_domain, setup_simple_test_memory, setup_test_memory, test_context,
};
use do_memory_core::{Pattern, TaskOutcome, TaskType};
use serde_json::json;
use uuid::Uuid;
#[tokio::test]
async fn should_create_episodes_with_unique_ids_and_timestamps() {
    use std::collections::HashSet;

    let memory = setup_test_memory();

    // A freshly started episode: valid non-nil id, start time not in the
    // future, and no end time, steps, or outcome yet.
    let episode_id = memory
        .start_episode("Test task".to_string(), test_context(), TaskType::Testing)
        .await;
    let episode = memory.get_episode(episode_id).await.unwrap();
    assert_ne!(episode.episode_id, Uuid::nil());
    assert_eq!(episode.episode_id, episode_id);
    assert!(episode.start_time <= Utc::now());
    assert!(episode.end_time.is_none());
    assert!(!episode.is_complete());
    assert_eq!(episode.steps.len(), 0);
    assert!(episode.outcome.is_none());

    // Starting many episodes must yield pairwise-distinct ids. Collecting
    // into a HashSet gives an O(n) uniqueness check instead of the O(n^2)
    // pairwise comparison loop.
    let mut episode_ids = Vec::new();
    for i in 0..10 {
        let id = memory
            .start_episode(
                format!("Task {i}"),
                test_context(),
                TaskType::CodeGeneration,
            )
            .await;
        episode_ids.push(id);
    }
    let unique: HashSet<_> = episode_ids.iter().collect();
    assert_eq!(
        unique.len(),
        episode_ids.len(),
        "Episode IDs must be unique"
    );
}
#[tokio::test]
async fn should_log_execution_steps_with_ordering_and_metadata() {
    let memory = setup_test_memory();
    let episode_id = memory
        .start_episode("Test task".to_string(), test_context(), TaskType::Testing)
        .await;

    // A single step with parameters, latency, and token usage is recorded
    // faithfully once flushed.
    let first_step = StepBuilder::new(1, "test_tool", "Test action")
        .parameters(json!({"key": "value"}))
        .latency_ms(10)
        .tokens_used(50)
        .success("Success")
        .build();
    memory.log_step(episode_id, first_step).await;
    memory.flush_steps(episode_id).await.unwrap();

    let episode = memory.get_episode(episode_id).await.unwrap();
    assert_eq!(episode.steps.len(), 1);
    assert_eq!(episode.steps[0].tool, "test_tool");
    assert_eq!(episode.steps[0].latency_ms, 10);
    assert_eq!(episode.steps[0].tokens_used, Some(50));
    assert!(episode.steps[0].is_success());

    // Additional steps preserve insertion order: step_number is 1-based and
    // matches the position in the stored step list.
    for n in 2..=5 {
        let extra_step = StepBuilder::new(n, format!("tool_{n}"), format!("Action {n}"))
            .success("OK")
            .build();
        memory.log_step(episode_id, extra_step).await;
    }
    memory.flush_steps(episode_id).await.unwrap();

    let episode = memory.get_episode(episode_id).await.unwrap();
    assert_eq!(episode.steps.len(), 5);
    for (index, step) in episode.steps.iter().enumerate() {
        assert_eq!(step.step_number, index + 1);
    }

    // Custom metadata entries round-trip alongside the structured fields.
    let rich_step = StepBuilder::new(6, "metadata_tool", "Metadata action")
        .latency_ms(150)
        .tokens_used(1200)
        .parameters(json!({"model": "claude-3", "temperature": 0.7}))
        .metadata("custom_key", "custom_value")
        .success("OK")
        .build();
    memory.log_step(episode_id, rich_step).await;
    memory.flush_steps(episode_id).await.unwrap();

    let episode = memory.get_episode(episode_id).await.unwrap();
    let recorded = &episode.steps[5];
    assert_eq!(recorded.latency_ms, 150);
    assert_eq!(recorded.tokens_used, Some(1200));
    assert_eq!(recorded.parameters["model"], "claude-3");
    assert_eq!(
        recorded.metadata.get("custom_key"),
        Some(&"custom_value".to_string())
    );
}
#[tokio::test]
async fn should_complete_episodes_with_reward_scoring_and_reflection() {
    let memory = setup_test_memory();
    let episode_id = memory
        .start_episode("Test task".to_string(), test_context(), TaskType::Testing)
        .await;

    // Completing with a success outcome should stamp the end time and
    // attach both a reward and a reflection.
    let success = TaskOutcome::Success {
        verdict: "Test passed".to_string(),
        artifacts: vec!["test.rs".to_string()],
    };
    memory.complete_episode(episode_id, success).await.unwrap();

    let finished = memory.get_episode(episode_id).await.unwrap();
    assert!(finished.end_time.is_some());
    assert!(finished.reward.is_some());
    assert!(finished.reflection.is_some());

    // Full success earns a base reward of exactly 1.0.
    let reward = finished.reward.unwrap();
    assert!(reward.total >= 0.0);
    assert!((reward.base - 1.0).abs() < f32::EPSILON);

    // A reflection on success records at least one success or insight.
    let reflection = finished.reflection.unwrap();
    assert!(!reflection.successes.is_empty() || !reflection.insights.is_empty());
}
#[tokio::test]
async fn should_handle_failed_episodes_with_improvements() {
    let memory = setup_test_memory();
    let episode_id = memory
        .start_episode("Test task".to_string(), test_context(), TaskType::Testing)
        .await;

    // A failure outcome carries a reason plus optional error details.
    let failure = TaskOutcome::Failure {
        reason: "Test failed".to_string(),
        error_details: Some("Assertion error".to_string()),
    };
    memory.complete_episode(episode_id, failure).await.unwrap();

    let finished = memory.get_episode(episode_id).await.unwrap();

    // Failure scores a base reward of exactly 0.0 ...
    let reward = finished.reward.unwrap();
    assert!((reward.base - 0.0).abs() < f32::EPSILON);

    // ... and the reflection must propose at least one improvement.
    let reflection = finished.reflection.unwrap();
    assert!(!reflection.improvements.is_empty());
}
#[tokio::test]
async fn should_score_partial_success_between_failure_and_success() {
    let memory = setup_test_memory();
    let episode_id = memory
        .start_episode("Test task".to_string(), test_context(), TaskType::Testing)
        .await;

    // Two of three sub-tasks succeeded: a partial success.
    let partial = TaskOutcome::PartialSuccess {
        verdict: "Some tests passed".to_string(),
        completed: vec!["test1".to_string(), "test2".to_string()],
        failed: vec!["test3".to_string()],
    };
    memory.complete_episode(episode_id, partial).await.unwrap();

    // Partial success must score strictly between failure (0.0) and full
    // success (1.0), with at least 0.5 base and a positive total.
    let finished = memory.get_episode(episode_id).await.unwrap();
    let reward = finished.reward.unwrap();
    assert!(reward.base >= 0.5 && reward.base < 1.0);
    assert!(reward.total > 0.0);
}
#[tokio::test]
async fn should_extract_patterns_from_completed_episodes() {
    let memory = setup_test_memory();

    // The helper builds and completes an episode whose step sequence
    // exhibits a retry-after-error shape for the extractor to detect.
    let episode_id =
        create_completed_episode_with_pattern(&memory, PatternType::ErrorRecovery).await;

    let episode = memory.get_episode(episode_id).await.unwrap();
    assert!(
        !episode.patterns.is_empty(),
        "Expected patterns to be extracted from episode with clear retry pattern"
    );

    // The extracted pattern must also be retrievable via a matching context.
    let context = ContextBuilder::new("error-handling").build();
    let patterns = memory.retrieve_relevant_patterns(&context, 10).await;
    assert!(!patterns.is_empty(), "Expected at least one pattern");

    let found_error_recovery = patterns
        .iter()
        .any(|p| matches!(p, Pattern::ErrorRecovery { .. }));
    assert!(
        found_error_recovery,
        "Expected ErrorRecovery pattern to be extracted"
    );
}
#[tokio::test]
async fn should_extract_different_pattern_types_based_on_episode_structure() {
    let memory = setup_test_memory();
    let episode_id = memory
        .start_episode(
            "Multi-step task".to_string(),
            test_context(),
            TaskType::CodeGeneration,
        )
        .await;

    // Log three successful steps. NOTE(review): flush_steps is not called
    // here, unlike the step-logging test — presumably complete_episode
    // flushes pending steps itself before pattern extraction; confirm.
    for step_number in 1..=3 {
        let step = StepBuilder::new(
            step_number,
            format!("tool_{step_number}"),
            format!("Action {step_number}"),
        )
        .success("Done")
        .build();
        memory.log_step(episode_id, step).await;
    }

    let outcome = TaskOutcome::Success {
        verdict: "Done".to_string(),
        artifacts: vec![],
    };
    memory.complete_episode(episode_id, outcome).await.unwrap();

    // A multi-step successful episode should yield at least one pattern.
    let episode = memory.get_episode(episode_id).await.unwrap();
    assert!(
        !episode.patterns.is_empty(),
        "Expected patterns from multi-step episode"
    );
}
#[tokio::test]
async fn should_retrieve_relevant_episodes_with_context_filtering_and_limits() {
    // --- Domain filtering: results should be dominated by the query domain.
    let memory = setup_test_memory();
    for i in 0..20 {
        let domain = if i % 2 == 0 { "web-api" } else { "cli-tool" };
        create_test_episode_with_domain(&memory, domain).await;
    }
    let web_context = ContextBuilder::new("web-api").build();
    let results = memory
        .retrieve_relevant_context("test query".to_string(), web_context, 10)
        .await;
    assert!(!results.is_empty());
    assert!(results.len() <= 10);
    let web_count = results
        .iter()
        .filter(|e| e.context.domain == "web-api")
        .count();
    // Strict majority via integers: web_count / len > 0.5 <=> 2 * web_count > len.
    // Avoids the float cast (and the clippy::cast_precision_loss allow).
    assert!(
        web_count * 2 > results.len(),
        "Expected majority of results to match domain filter"
    );

    // --- Limits: the requested cap is honored exactly when enough episodes exist.
    let memory2 = setup_simple_test_memory();
    for _ in 0..50 {
        create_test_episode_with_domain(&memory2, "test-domain").await;
    }
    let context = ContextBuilder::new("test-domain").build();
    let results_5 = memory2
        .retrieve_relevant_context("query".to_string(), context.clone(), 5)
        .await;
    let results_20 = memory2
        .retrieve_relevant_context("query".to_string(), context, 20)
        .await;
    assert_eq!(results_5.len(), 5, "Should respect limit of 5");
    assert_eq!(results_20.len(), 20, "Should respect limit of 20");

    // --- Language filtering: five completed episodes per language, then
    // query with a rust context and expect rust episodes among the results.
    let memory3 = setup_simple_test_memory();
    for lang in ["rust", "python", "typescript"] {
        let context = ContextBuilder::new("code-gen").language(lang).build();
        for _ in 0..5 {
            let episode_id = memory3
                .start_episode(
                    "Task".to_string(),
                    context.clone(),
                    TaskType::CodeGeneration,
                )
                .await;
            memory3
                .complete_episode(
                    episode_id,
                    TaskOutcome::Success {
                        verdict: "Done".to_string(),
                        artifacts: vec![],
                    },
                )
                .await
                .unwrap();
        }
    }
    let (total, completed, _) = memory3.get_stats().await;
    assert_eq!(total, 15, "Should have 15 episodes total");
    assert_eq!(completed, 15, "All 15 episodes should be completed");
    let rust_context = ContextBuilder::new("code-gen").language("rust").build();
    let results = memory3
        .retrieve_relevant_context("task".to_string(), rust_context, 10)
        .await;
    let rust_count = results
        .iter()
        .filter(|e| e.context.language.as_deref() == Some("rust"))
        .count();
    assert!(
        !results.is_empty(),
        "Should return some results (got {} results)",
        results.len()
    );
    assert!(
        rust_count > 0,
        "Should return rust episodes (got {} results, {} rust)",
        results.len(),
        rust_count
    );
}
#[tokio::test]
#[ignore = "Requires MCP server implementation"]
async fn should_execute_typescript_code_in_secure_sandbox() {
    // Placeholder: intentionally empty until an MCP server implementation
    // exists to host the sandboxed TypeScript runtime (see the #[ignore]).
}
#[tokio::test]
#[ignore = "Requires MCP server implementation"]
async fn should_generate_mcp_tools_from_memory_patterns() {
    // Placeholder: intentionally empty until an MCP server implementation
    // exists to expose memory patterns as generated tools (see the #[ignore]).
}
#[tokio::test]
async fn should_maintain_episode_integrity_after_completion() {
    let memory = setup_test_memory();
    let episode_id = memory
        .start_episode("Test".to_string(), test_context(), TaskType::Testing)
        .await;

    // Finish the episode first...
    let outcome = TaskOutcome::Success {
        verdict: "Done".to_string(),
        artifacts: vec![],
    };
    memory.complete_episode(episode_id, outcome).await.unwrap();

    // ...then attempt to append a step afterwards. The episode must still
    // report itself as complete.
    let late_step = StepBuilder::new(1, "test_tool", "Test action")
        .success("OK")
        .build();
    memory.log_step(episode_id, late_step).await;

    let episode = memory.get_episode(episode_id).await.unwrap();
    assert!(episode.is_complete());
}
#[tokio::test]
async fn should_report_accurate_statistics() {
    let memory = setup_test_memory();

    // Start five episodes, completing only the first three (i = 0, 1, 2).
    for i in 0..5 {
        let episode_id = memory
            .start_episode(
                format!("Task {i}"),
                test_context(),
                TaskType::CodeGeneration,
            )
            .await;
        let should_complete = i < 3;
        if should_complete {
            let outcome = TaskOutcome::Success {
                verdict: "Done".to_string(),
                artifacts: vec![],
            };
            memory.complete_episode(episode_id, outcome).await.unwrap();
        }
    }

    // Stats must report all starts and only the completed subset.
    let (total, completed, _patterns) = memory.get_stats().await;
    assert_eq!(total, 5);
    assert_eq!(completed, 3);
}