use super::{JobState, RecoveryPlan, StateError, StateEvent, StateEventType, StateManager};
use chrono::Utc;
use serde_json::Value;
use tracing::{debug, info};
impl StateManager {
    /// Build a [`RecoveryPlan`] for `job_id` from its persisted state.
    ///
    /// When `checkpoint_version` is `Some`, the stored checkpoint must exist
    /// and carry exactly that version; otherwise a validation error is
    /// returned. The state is then validated and the set of still-pending
    /// work items is computed.
    ///
    /// # Errors
    ///
    /// * [`StateError::NotFound`] if no state exists for `job_id`.
    /// * [`StateError::ValidationError`] on a checkpoint-version mismatch, a
    ///   missing checkpoint (when a specific version was requested), or a
    ///   failed checkpoint validation.
    pub async fn recover_from_checkpoint(
        &self,
        job_id: &str,
        checkpoint_version: Option<u32>,
    ) -> Result<RecoveryPlan, StateError> {
        let state = self
            .get_state(job_id)
            .await?
            .ok_or_else(|| StateError::NotFound(job_id.to_string()))?;

        // If the caller pinned a checkpoint version, verify it matches what
        // is actually stored before doing any recovery work.
        if let Some(version) = checkpoint_version {
            match state.checkpoint {
                Some(ref checkpoint) if checkpoint.version != version => {
                    return Err(StateError::ValidationError(format!(
                        "Requested checkpoint version {} but found {}",
                        version, checkpoint.version
                    )));
                }
                Some(_) => {}
                None => {
                    return Err(StateError::ValidationError(format!(
                        "No checkpoint found for job {}",
                        job_id
                    )));
                }
            }
        }

        self.validate_checkpoint(&state)?;

        // Recovery grants no extra retry budget beyond the single re-queue
        // of failed items performed by calculate_pending_items.
        let pending_items = self.calculate_pending_items(&state, 0)?;

        self.log_event(StateEvent {
            timestamp: Utc::now(),
            event_type: StateEventType::RecoveryStarted {
                // 0 acts as a sentinel when no checkpoint version exists.
                checkpoint_version: state.checkpoint.as_ref().map(|c| c.version).unwrap_or(0),
            },
            job_id: job_id.to_string(),
            details: Some(format!("Recovering {} pending items", pending_items.len())),
        })
        .await;

        let plan = RecoveryPlan {
            resume_phase: state.phase,
            pending_items,
            skip_items: state.processed_items.clone(),
            variables: state.variables.clone(),
            agent_results: state.agent_results.clone(),
        };
        info!(
            "Created recovery plan for job {}: {} items to process, {} to skip",
            job_id,
            plan.pending_items.len(),
            plan.skip_items.len()
        );
        Ok(plan)
    }

    /// Determine which work items still need processing for `state`.
    ///
    /// Items that were never attempted (in neither `processed_items` nor
    /// `failed_items`) are queued first, followed by previously failed items
    /// that have not since been processed successfully.
    ///
    /// `max_additional_retries` is accepted for interface stability but is
    /// not yet honored: every eligible failed item is re-queued exactly once.
    pub fn calculate_pending_items(
        &self,
        state: &JobState,
        max_additional_retries: u32,
    ) -> Result<Vec<Value>, StateError> {
        let work_items = self.get_work_items_from_state(state)?;
        let mut pending_items = Vec::new();

        // Pass 1: items never attempted. Item ids follow the "item_{index}"
        // scheme produced by get_work_items_from_state.
        for (i, item) in work_items.iter().enumerate() {
            let item_id = format!("item_{}", i);
            if !state.processed_items.contains(&item_id) && !state.failed_items.contains(&item_id) {
                pending_items.push(item.clone());
                debug!("Adding never-attempted item: {}", item_id);
            }
        }

        // TODO: honor the retry budget instead of discarding it.
        let _max_retries = max_additional_retries;

        // Pass 2: re-queue failed items for a retry. An item that failed and
        // was later processed successfully remains recorded in failed_items
        // (mark_items_processed does not clear it), so it must be skipped
        // here to avoid re-running completed work.
        for failed_item_id in &state.failed_items {
            if state.processed_items.contains(failed_item_id) {
                continue;
            }
            if let Some(idx) = failed_item_id
                .strip_prefix("item_")
                .and_then(|s| s.parse::<usize>().ok())
            {
                // Ignore ids that point past the current work-item list.
                if idx < work_items.len() {
                    pending_items.push(work_items[idx].clone());
                    debug!("Adding failed item for retry: {}", failed_item_id);
                }
            }
        }

        info!(
            "Calculated {} pending items for recovery",
            pending_items.len()
        );
        Ok(pending_items)
    }

    /// Reconstruct the job's work items from its state.
    ///
    /// Placeholder implementation: synthesizes `total_items` string values
    /// named "item_{index}". Real item payloads are not persisted here —
    /// TODO confirm whether callers ever need the original payloads back.
    fn get_work_items_from_state(&self, state: &JobState) -> Result<Vec<Value>, StateError> {
        let items = (0..state.total_items)
            .map(|i| Value::String(format!("item_{}", i)))
            .collect();
        Ok(items)
    }

    /// A job can be resumed iff its state exists, it is not complete, and a
    /// checkpoint has been recorded. Any lookup error is treated as
    /// "cannot resume" rather than propagated.
    pub async fn can_resume_job(&self, job_id: &str) -> bool {
        matches!(
            self.get_state(job_id).await,
            Ok(Some(state)) if !state.is_complete && state.checkpoint.is_some()
        )
    }

    /// Overwrite the job's variables, agent results, processed-item set, and
    /// phase with the contents of `plan`, preparing it to resume.
    ///
    /// # Errors
    ///
    /// Propagates any [`StateError`] from the underlying state update.
    pub async fn apply_recovery_plan(
        &self,
        job_id: &str,
        plan: &RecoveryPlan,
    ) -> Result<(), StateError> {
        self.update_state(job_id, |state| {
            state.variables = plan.variables.clone();
            state.agent_results = plan.agent_results.clone();
            state.processed_items = plan.skip_items.clone();
            if state.phase != plan.resume_phase {
                debug!(
                    "Updating phase from {:?} to {:?}",
                    state.phase, plan.resume_phase
                );
                state.phase = plan.resume_phase;
            }
            Ok(())
        })
        .await?;
        info!(
            "Applied recovery plan to job {}: resuming from phase {:?}",
            job_id, plan.resume_phase
        );
        Ok(())
    }

    /// Record `item_ids` as processed for `job_id` and log an
    /// `ItemsProcessed` event carrying the count.
    ///
    /// # Errors
    ///
    /// Propagates any [`StateError`] from the underlying state update.
    pub async fn mark_items_processed(
        &self,
        job_id: &str,
        item_ids: Vec<String>,
    ) -> Result<(), StateError> {
        // Capture the count before the closure takes ownership of item_ids.
        let count = item_ids.len();
        self.update_state(job_id, |state| {
            for item_id in item_ids {
                // HashSet insert makes repeated marking idempotent.
                state.processed_items.insert(item_id);
            }
            Ok(())
        })
        .await?;
        self.log_event(StateEvent {
            timestamp: Utc::now(),
            event_type: StateEventType::ItemsProcessed { count },
            job_id: job_id.to_string(),
            details: None,
        })
        .await;
        Ok(())
    }

    /// Record `item_ids` as failed for `job_id` (deduplicated) and log an
    /// `ItemsFailed` event carrying the count of ids submitted.
    ///
    /// # Errors
    ///
    /// Propagates any [`StateError`] from the underlying state update.
    pub async fn mark_items_failed(
        &self,
        job_id: &str,
        item_ids: Vec<String>,
    ) -> Result<(), StateError> {
        let count = item_ids.len();
        self.update_state(job_id, |state| {
            for item_id in item_ids {
                // failed_items is a Vec, so dedup manually before pushing.
                if !state.failed_items.contains(&item_id) {
                    state.failed_items.push(item_id);
                }
            }
            Ok(())
        })
        .await?;
        self.log_event(StateEvent {
            timestamp: Utc::now(),
            event_type: StateEventType::ItemsFailed { count },
            job_id: job_id.to_string(),
            details: None,
        })
        .await;
        Ok(())
    }
}
/// Options controlling how a job is resumed from persisted state.
///
/// NOTE(review): this type is not referenced elsewhere in this module, so
/// field semantics below are inferred from the names — confirm against the
/// call sites that consume it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResumeOptions {
    // Presumably forces a resume even when validation would block it.
    pub force_resume: bool,
    // Extra retry budget for previously failed items (see
    // `calculate_pending_items`, which currently does not consume it).
    pub max_additional_retries: u32,
    // Presumably skips checkpoint validation before resuming.
    pub skip_validation: bool,
    // Pin recovery to a specific checkpoint version; `None` uses the latest.
    pub from_checkpoint: Option<u32>,
}

impl Default for ResumeOptions {
    // Manual impl: `#[derive(Default)]` cannot express the non-zero
    // default retry budget.
    fn default() -> Self {
        Self {
            force_resume: false,
            max_additional_retries: 2,
            skip_validation: false,
            from_checkpoint: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cook::execution::mapreduce::state::persistence::InMemoryStateStore;
    use crate::cook::execution::mapreduce::state::PhaseType;
    use crate::cook::execution::mapreduce::{AgentResult, AgentStatus, MapReduceConfig};
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::time::Duration;

    /// End-to-end: processed/failed bookkeeping flows into the recovery plan.
    #[tokio::test]
    async fn test_recovery_plan_creation() {
        let store = Arc::new(InMemoryStateStore::new());
        let manager = StateManager::new(store);
        let config = MapReduceConfig::default();
        let job_id = "test-job-recovery".to_string();
        let _state = manager.create_job(&config, job_id.clone()).await.unwrap();
        manager
            .update_state(&job_id, |state| {
                // 5 items total: 0 and 1 processed, 2 failed, 3 and 4 untouched.
                state.total_items = 5;
                state.processed_items.insert("item_0".to_string());
                state.processed_items.insert("item_1".to_string());
                state.agent_results.insert(
                    "item_0".to_string(),
                    AgentResult {
                        item_id: "item_0".to_string(),
                        status: AgentStatus::Success,
                        output: Some("output".to_string()),
                        commits: vec![],
                        duration: Duration::from_secs(1),
                        error: None,
                        worktree_path: None,
                        branch_name: None,
                        worktree_session_id: None,
                        files_modified: vec![],
                        json_log_location: None,
                        cleanup_status: None,
                    },
                );
                state.agent_results.insert(
                    "item_1".to_string(),
                    AgentResult {
                        item_id: "item_1".to_string(),
                        status: AgentStatus::Success,
                        output: Some("output".to_string()),
                        commits: vec![],
                        duration: Duration::from_secs(1),
                        error: None,
                        worktree_path: None,
                        branch_name: None,
                        worktree_session_id: None,
                        files_modified: vec![],
                        json_log_location: None,
                        cleanup_status: None,
                    },
                );
                state.failed_items.push("item_2".to_string());
                Ok(())
            })
            .await
            .unwrap();
        manager.create_checkpoint(&job_id).await.unwrap();
        let plan = manager
            .recover_from_checkpoint(&job_id, None)
            .await
            .unwrap();
        // The two processed items are skipped; the failed item plus the two
        // never-attempted items are pending.
        assert_eq!(plan.skip_items.len(), 2);
        assert!(plan.skip_items.contains("item_0"));
        assert!(plan.skip_items.contains("item_1"));
        assert_eq!(plan.pending_items.len(), 3);
    }

    /// Unit test of the pending-item calculation on a hand-built state.
    #[tokio::test]
    async fn test_calculate_pending_items() {
        let store = Arc::new(InMemoryStateStore::new());
        let manager = StateManager::new(store);
        let mut state = JobState {
            id: "test-job".to_string(),
            phase: PhaseType::Map,
            checkpoint: None,
            processed_items: HashSet::new(),
            failed_items: Vec::new(),
            variables: Default::default(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            config: MapReduceConfig::default(),
            agent_results: Default::default(),
            is_complete: false,
            total_items: 5,
        };
        state.processed_items.insert("item_0".to_string());
        state.processed_items.insert("item_1".to_string());
        state.failed_items.push("item_2".to_string());
        // Expect items 3 and 4 (never attempted) plus item 2 (retry).
        let pending = manager.calculate_pending_items(&state, 0).unwrap();
        assert_eq!(pending.len(), 3);
    }

    /// Marking items processed persists them into the processed set.
    #[tokio::test]
    async fn test_mark_items_processed() {
        let store = Arc::new(InMemoryStateStore::new());
        let manager = StateManager::new(store);
        let config = MapReduceConfig::default();
        let job_id = "test-job-mark".to_string();
        manager.create_job(&config, job_id.clone()).await.unwrap();
        manager
            .mark_items_processed(&job_id, vec!["item_0".to_string(), "item_1".to_string()])
            .await
            .unwrap();
        let state = manager.get_state(&job_id).await.unwrap().unwrap();
        assert_eq!(state.processed_items.len(), 2);
        assert!(state.processed_items.contains("item_0"));
        assert!(state.processed_items.contains("item_1"));
    }

    /// Resume requires a checkpoint and an incomplete job.
    #[tokio::test]
    async fn test_can_resume_job() {
        let store = Arc::new(InMemoryStateStore::new());
        let manager = StateManager::new(store);
        let config = MapReduceConfig::default();
        let job_id = "test-job-resume".to_string();
        manager.create_job(&config, job_id.clone()).await.unwrap();
        // No checkpoint yet -> cannot resume.
        assert!(!manager.can_resume_job(&job_id).await);
        manager.create_checkpoint(&job_id).await.unwrap();
        assert!(manager.can_resume_job(&job_id).await);
        // Completed jobs are never resumable, checkpoint or not.
        manager
            .update_state(&job_id, |state| {
                state.is_complete = true;
                Ok(())
            })
            .await
            .unwrap();
        assert!(!manager.can_resume_job(&job_id).await);
    }
}