#[cfg(test)]
mod state_tests {
    use super::super::*;
    use crate::cook::execution::mapreduce::state::persistence::InMemoryStateStore;
    use crate::cook::execution::mapreduce::{AgentResult, AgentStatus, MapReduceConfig};
    use std::sync::Arc;
    use std::time::Duration;

    /// Creates a `StateManager` backed by an `InMemoryStateStore`.
    async fn create_test_state_manager() -> Arc<StateManager> {
        let store = Arc::new(InMemoryStateStore::new());
        Arc::new(StateManager::new(store))
    }

    /// Builds a successful `AgentResult` for `item_{index}`.
    ///
    /// Returns the item id alongside the result so callers can record the same
    /// id in both `processed_items` and `agent_results` without re-formatting it.
    fn successful_agent_result(index: usize) -> (String, AgentResult) {
        let item_id = format!("item_{}", index);
        let result = AgentResult {
            item_id: item_id.clone(),
            status: AgentStatus::Success,
            output: Some(format!("output_{}", index)),
            commits: vec![],
            duration: Duration::from_secs(1),
            error: None,
            worktree_path: None,
            branch_name: None,
            worktree_session_id: None,
            files_modified: vec![],
            json_log_location: None,
            cleanup_status: None,
        };
        (item_id, result)
    }

    /// Creates job `job_id` with 10 total items, of which `item_0`..`item_4`
    /// are marked processed with successful agent results.
    ///
    /// Returns the job state as read back from the manager after seeding.
    async fn create_test_job(manager: &StateManager, job_id: &str) -> JobState {
        manager
            .create_job(&MapReduceConfig::default(), job_id.to_string())
            .await
            .unwrap();
        manager
            .update_state(job_id, |state| {
                state.total_items = 10;
                for index in 0..5 {
                    let (item_id, result) = successful_agent_result(index);
                    state.processed_items.insert(item_id.clone());
                    state.agent_results.insert(item_id, result);
                }
                Ok(())
            })
            .await
            .unwrap();
        manager.get_state(job_id).await.unwrap().unwrap()
    }

    /// Walks a job through Setup -> Map -> Reduce -> Completed.
    #[tokio::test]
    async fn test_state_lifecycle() {
        let manager = create_test_state_manager().await;
        let job_id = "test-lifecycle";

        let state = manager
            .create_job(&MapReduceConfig::default(), job_id.to_string())
            .await
            .unwrap();
        assert_eq!(state.phase, PhaseType::Setup);
        assert!(!state.is_complete);

        manager.mark_job_started(job_id).await.unwrap();
        let state = manager.get_state(job_id).await.unwrap().unwrap();
        assert_eq!(state.phase, PhaseType::Map);

        manager.mark_reduce_started(job_id).await.unwrap();
        let state = manager.get_state(job_id).await.unwrap().unwrap();
        assert_eq!(state.phase, PhaseType::Reduce);

        manager.mark_job_completed(job_id).await.unwrap();
        let state = manager.get_state(job_id).await.unwrap().unwrap();
        assert_eq!(state.phase, PhaseType::Completed);
        assert!(state.is_complete);
    }

    /// A checkpoint of a half-processed job records the processed items, and
    /// recovery skips them while leaving the rest pending.
    #[tokio::test]
    async fn test_checkpoint_creation_and_recovery() {
        let manager = create_test_state_manager().await;
        let job_id = "test-checkpoint";
        create_test_job(&manager, job_id).await;

        let checkpoint = manager.create_checkpoint(job_id).await.unwrap();
        assert_eq!(checkpoint.version, 1);
        assert_eq!(checkpoint.items_processed.len(), 5);

        let recovery_plan = manager.recover_from_checkpoint(job_id, None).await.unwrap();
        assert_eq!(recovery_plan.skip_items.len(), 5);
        assert!(!recovery_plan.pending_items.is_empty());
    }

    /// From Setup, only Map and Failed are reachable; Reduce and Completed are not.
    #[tokio::test]
    async fn test_state_transitions_validation() {
        let manager = create_test_state_manager().await;
        let job_id = "test-transitions";
        manager
            .create_job(&MapReduceConfig::default(), job_id.to_string())
            .await
            .unwrap();

        assert!(manager
            .can_transition(job_id, PhaseType::Map)
            .await
            .unwrap());
        assert!(manager
            .can_transition(job_id, PhaseType::Failed)
            .await
            .unwrap());
        assert!(!manager
            .can_transition(job_id, PhaseType::Reduce)
            .await
            .unwrap());
        assert!(!manager
            .can_transition(job_id, PhaseType::Completed)
            .await
            .unwrap());

        let transitions = manager.get_valid_transitions(job_id).await.unwrap();
        assert!(transitions.contains(&PhaseType::Map));
        assert!(transitions.contains(&PhaseType::Failed));
    }

    /// Two tasks marking disjoint item ranges concurrently must not lose updates.
    #[tokio::test]
    async fn test_concurrent_state_updates() {
        let manager = create_test_state_manager().await;
        let job_id = "test-concurrent";
        manager
            .create_job(&MapReduceConfig::default(), job_id.to_string())
            .await
            .unwrap();

        let manager1 = Arc::clone(&manager);
        let job_id1 = job_id.to_string();
        let handle1 = tokio::spawn(async move {
            for i in 0..5 {
                manager1
                    .mark_items_processed(&job_id1, vec![format!("item_{}", i)])
                    .await
                    .unwrap();
            }
        });

        let manager2 = Arc::clone(&manager);
        let job_id2 = job_id.to_string();
        let handle2 = tokio::spawn(async move {
            for i in 5..10 {
                manager2
                    .mark_items_processed(&job_id2, vec![format!("item_{}", i)])
                    .await
                    .unwrap();
            }
        });

        // Propagate panics from either task (e.g. a failed unwrap) into the test.
        handle1.await.unwrap();
        handle2.await.unwrap();

        let state = manager.get_state(job_id).await.unwrap().unwrap();
        assert_eq!(state.processed_items.len(), 10);
    }

    /// Items passed to `mark_items_failed` appear in `failed_items`.
    #[tokio::test]
    async fn test_failed_items_tracking() {
        let manager = create_test_state_manager().await;
        let job_id = "test-failed";
        manager
            .create_job(&MapReduceConfig::default(), job_id.to_string())
            .await
            .unwrap();

        manager
            .mark_items_failed(job_id, vec!["item_0".to_string(), "item_1".to_string()])
            .await
            .unwrap();

        let state = manager.get_state(job_id).await.unwrap().unwrap();
        assert_eq!(state.failed_items.len(), 2);
        assert!(state.failed_items.contains(&"item_0".to_string()));
        assert!(state.failed_items.contains(&"item_1".to_string()));
    }

    /// Every state-changing call leaves a trace in the job's event history.
    #[tokio::test]
    async fn test_state_history() {
        let manager = create_test_state_manager().await;
        let job_id = "test-history";
        manager
            .create_job(&MapReduceConfig::default(), job_id.to_string())
            .await
            .unwrap();
        manager.mark_job_started(job_id).await.unwrap();
        manager
            .mark_items_processed(job_id, vec!["item_0".to_string()])
            .await
            .unwrap();
        manager
            .mark_items_failed(job_id, vec!["item_1".to_string()])
            .await
            .unwrap();
        manager.mark_job_completed(job_id).await.unwrap();

        let history = manager.get_state_history(job_id).await;
        // Five state-changing calls were made above; the test expects at least
        // one recorded event per call.
        assert!(history.len() >= 5);

        let event_types: Vec<_> = history.iter().map(|e| &e.event_type).collect();
        assert!(event_types
            .iter()
            .any(|e| matches!(e, StateEventType::JobCreated)));
        assert!(event_types
            .iter()
            .any(|e| matches!(e, StateEventType::JobCompleted)));
    }

    /// Recovery of a job with 5 successes and 2 failures out of 10 skips the
    /// successes and leaves at least the 5 unfinished items pending.
    #[tokio::test]
    async fn test_recovery_with_partial_completion() {
        let manager = create_test_state_manager().await;
        let job_id = "test-partial-recovery";
        manager
            .create_job(&MapReduceConfig::default(), job_id.to_string())
            .await
            .unwrap();

        // Seeded in a single update so the checkpoint sees one consistent snapshot.
        manager
            .update_state(job_id, |state| {
                state.total_items = 10;
                for index in 0..5 {
                    let (item_id, result) = successful_agent_result(index);
                    state.processed_items.insert(item_id.clone());
                    state.agent_results.insert(item_id, result);
                }
                state.failed_items.push("item_5".to_string());
                state.failed_items.push("item_6".to_string());
                Ok(())
            })
            .await
            .unwrap();

        manager.create_checkpoint(job_id).await.unwrap();
        let plan = manager.recover_from_checkpoint(job_id, None).await.unwrap();
        assert_eq!(plan.skip_items.len(), 5);
        // Only item_0..item_4 succeeded; whether failed items are counted as
        // pending is left to the recovery implementation, hence `>=`.
        assert!(plan.pending_items.len() >= 5);
    }

    /// Once Completed, a job rejects further phase transitions and failure marking.
    #[tokio::test]
    async fn test_terminal_state_restrictions() {
        let manager = create_test_state_manager().await;
        let job_id = "test-terminal";
        manager
            .create_job(&MapReduceConfig::default(), job_id.to_string())
            .await
            .unwrap();
        manager.mark_job_started(job_id).await.unwrap();
        manager.mark_job_completed(job_id).await.unwrap();

        let result = manager.transition_to_phase(job_id, PhaseType::Map).await;
        assert!(result.is_err());

        let result = manager.mark_job_failed(job_id, "test".to_string()).await;
        assert!(result.is_err());

        let state = manager.get_state(job_id).await.unwrap().unwrap();
        assert_eq!(state.phase, PhaseType::Completed);
        assert!(state.is_complete);
    }

    /// Successive checkpoints get increasing versions, and `get_checkpoint(.., None)`
    /// returns the latest one.
    #[tokio::test]
    async fn test_checkpoint_version_management() {
        let manager = create_test_state_manager().await;
        let job_id = "test-versions";
        create_test_job(&manager, job_id).await;

        let checkpoint1 = manager.create_checkpoint(job_id).await.unwrap();
        assert_eq!(checkpoint1.version, 1);

        manager
            .mark_items_processed(job_id, vec!["item_5".to_string()])
            .await
            .unwrap();

        let checkpoint2 = manager.create_checkpoint(job_id).await.unwrap();
        assert_eq!(checkpoint2.version, 2);

        let latest = manager.get_checkpoint(job_id, None).await.unwrap().unwrap();
        assert_eq!(latest.version, 2);
    }
}