use super::{Checkpoint, JobState, StateError, StateEvent, StateEventType, StateManager};
use crate::cook::execution::mapreduce::AgentResult;
use chrono::Utc;
use sha2::{Digest, Sha256};
use std::collections::HashSet;
use tracing::{debug, info, warn};
impl StateManager {
    /// Snapshots the current state of `job_id` into a new [`Checkpoint`] and stores it on the
    /// job's state, replacing any checkpoint stored there before.
    ///
    /// The new checkpoint's version is the previous checkpoint's version plus one, or `1` when
    /// the job has no checkpoint yet. After the state is updated, a `CheckpointCreated` event is
    /// logged.
    ///
    /// # Errors
    ///
    /// Returns [`StateError::NotFound`] when no state exists for `job_id`, and propagates any
    /// error returned while reading or updating the state.
    pub async fn create_checkpoint(&self, job_id: &str) -> Result<Checkpoint, StateError> {
        let state = self
            .get_state(job_id)
            .await?
            .ok_or_else(|| StateError::NotFound(job_id.to_string()))?;

        let version = state
            .checkpoint
            .as_ref()
            .map_or(1, |previous| previous.version + 1);

        // NOTE(review): the snapshot is built from the state read above; an update that lands
        // between `get_state` and `update_state` is not reflected in this checkpoint.
        let checkpoint = Checkpoint {
            phase: state.phase,
            items_processed: state.processed_items.iter().cloned().collect(),
            agent_results: state.agent_results.values().cloned().collect(),
            timestamp: Utc::now(),
            checksum: calculate_checksum(&state),
            version,
        };

        self.update_state(job_id, |current| {
            current.checkpoint = Some(checkpoint.clone());
            Ok(())
        })
        .await?;

        self.log_event(StateEvent {
            timestamp: Utc::now(),
            event_type: StateEventType::CheckpointCreated { version },
            job_id: job_id.to_string(),
            // `processed_items` and `agent_results` are tracked separately and can differ in
            // size, so report both counts instead of labelling the result count as items.
            details: Some(format!(
                "Phase: {:?}, Items processed: {}, Agent results: {}",
                state.phase,
                state.processed_items.len(),
                checkpoint.agent_results.len()
            )),
        })
        .await;

        info!("Created checkpoint v{} for job {}", version, job_id);
        Ok(checkpoint)
    }

    /// Checks `state` for internal consistency. When the state carries a checkpoint, it also
    /// checks that checkpoint against the state.
    ///
    /// A state with `total_items == 0` that is not marked complete is only logged as a warning.
    ///
    /// # Errors
    ///
    /// Returns [`StateError::ValidationError`] in any of these cases:
    /// - the job id is empty;
    /// - more items are processed than `total_items`;
    /// - a processed item has no agent result;
    /// - the stored checkpoint disagrees with the state.
    pub fn validate_checkpoint(&self, state: &JobState) -> Result<(), StateError> {
        if state.id.is_empty() {
            return Err(StateError::ValidationError(
                "Empty job ID in state".to_string(),
            ));
        }

        if state.total_items == 0 && !state.is_complete {
            warn!("State has 0 total items but is not marked complete");
        }

        let total_processed = state.processed_items.len();
        if total_processed > state.total_items {
            return Err(StateError::ValidationError(format!(
                "Processed count ({}) exceeds total items ({})",
                total_processed, state.total_items
            )));
        }

        // Every processed item must have a recorded agent result. The reverse is not required.
        if let Some(item_id) = state
            .processed_items
            .iter()
            .find(|id| !state.agent_results.contains_key(*id))
        {
            return Err(StateError::ValidationError(format!(
                "Processed item {} has no result",
                item_id
            )));
        }

        if let Some(ref checkpoint) = state.checkpoint {
            self.validate_checkpoint_integrity(checkpoint, state)?;
        }

        debug!("Checkpoint validation passed for job {}", state.id);
        Ok(())
    }

    /// Compares `checkpoint` against the current `state`.
    ///
    /// A checksum mismatch is only logged as a warning. The set of processed item ids must
    /// match exactly, and the number of agent results must be equal.
    ///
    /// # Errors
    ///
    /// Returns [`StateError::ValidationError`] when the processed-item sets differ or the
    /// agent-result counts differ.
    fn validate_checkpoint_integrity(
        &self,
        checkpoint: &Checkpoint,
        state: &JobState,
    ) -> Result<(), StateError> {
        let expected_checksum = calculate_checksum(state);
        if checkpoint.checksum != expected_checksum {
            warn!(
                "Checksum mismatch for checkpoint v{}: expected {}, got {}",
                checkpoint.version, expected_checksum, checkpoint.checksum
            );
        }

        // Compare as sets without cloning the item strings. Two sets are equal exactly when
        // they have the same size and one is contained in the other.
        let checkpoint_items: HashSet<&String> = checkpoint.items_processed.iter().collect();
        let items_match = checkpoint_items.len() == state.processed_items.len()
            && checkpoint_items
                .iter()
                .all(|item| state.processed_items.contains(*item));
        if !items_match {
            return Err(StateError::ValidationError(format!(
                "Checkpoint items mismatch: checkpoint has {} items, state has {}",
                checkpoint_items.len(),
                state.processed_items.len()
            )));
        }

        if checkpoint.agent_results.len() != state.agent_results.len() {
            return Err(StateError::ValidationError(format!(
                "Agent results mismatch: checkpoint has {}, state has {}",
                checkpoint.agent_results.len(),
                state.agent_results.len()
            )));
        }

        Ok(())
    }

    /// Returns the checkpoint stored for `job_id`.
    ///
    /// The job state holds a single checkpoint. When `version` is `Some`, it must equal that
    /// checkpoint's version. Returns `Ok(None)` when the job does not exist or has no
    /// checkpoint, whatever `version` is.
    ///
    /// # Errors
    ///
    /// Returns [`StateError::ValidationError`] when a specific `version` is requested and
    /// differs from the stored checkpoint's version. Also propagates errors from reading the
    /// state.
    pub async fn get_checkpoint(
        &self,
        job_id: &str,
        version: Option<u32>,
    ) -> Result<Option<Checkpoint>, StateError> {
        let checkpoint = self
            .get_state(job_id)
            .await?
            .and_then(|state| state.checkpoint);

        match (checkpoint, version) {
            (Some(current), Some(requested)) if requested != current.version => {
                Err(StateError::ValidationError(format!(
                    "Checkpoint version {} not found (current: {})",
                    requested, current.version
                )))
            }
            (checkpoint, _) => Ok(checkpoint),
        }
    }

    /// Placeholder for pruning old checkpoints of `job_id`, keeping the newest `keep_count`.
    ///
    /// The job state stores a single `checkpoint`, which `create_checkpoint` overwrites. This
    /// method currently only logs a debug message and always returns `Ok(())`.
    pub async fn clean_old_checkpoints(
        &self,
        job_id: &str,
        keep_count: usize,
    ) -> Result<(), StateError> {
        debug!(
            "Cleaning old checkpoints for job {} (keeping {})",
            job_id, keep_count
        );
        Ok(())
    }
}
fn calculate_checksum(state: &JobState) -> String {
let mut hasher = Sha256::new();
hasher.update(state.id.as_bytes());
hasher.update(format!("{:?}", state.phase).as_bytes());
hasher.update(state.total_items.to_string().as_bytes());
let mut items: Vec<_> = state.processed_items.iter().cloned().collect();
items.sort();
for item in items {
hasher.update(item.as_bytes());
}
let mut results: Vec<_> = state.agent_results.keys().cloned().collect();
results.sort();
for key in results {
hasher.update(key.as_bytes());
if let Some(result) = state.agent_results.get(&key) {
hasher.update(format!("{:?}", result.status).as_bytes());
}
}
format!("{:x}", hasher.finalize())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cook::execution::mapreduce::state::PhaseType;
    use crate::cook::execution::mapreduce::{AgentStatus, MapReduceConfig};
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::time::Duration;

    /// Builds a successful `AgentResult` for `item_id` with placeholder output. It has no
    /// commits, worktree, or log metadata.
    fn successful_result(item_id: &str) -> AgentResult {
        AgentResult {
            item_id: item_id.to_string(),
            status: AgentStatus::Success,
            output: Some("output".to_string()),
            commits: vec![],
            duration: Duration::from_secs(1),
            error: None,
            worktree_path: None,
            branch_name: None,
            worktree_session_id: None,
            files_modified: vec![],
            json_log_location: None,
            cleanup_status: None,
        }
    }

    #[tokio::test]
    async fn test_checkpoint_creation() {
        let store = Arc::new(super::super::persistence::InMemoryStateStore::new());
        let manager = StateManager::new(store);
        let config = MapReduceConfig::default();
        let job_id = "test-job-checkpoint".to_string();
        let _state = manager.create_job(&config, job_id.clone()).await.unwrap();

        manager
            .update_state(&job_id, |state| {
                state.total_items = 5;
                for item_id in ["item_0", "item_1"] {
                    state.processed_items.insert(item_id.to_string());
                    state
                        .agent_results
                        .insert(item_id.to_string(), successful_result(item_id));
                }
                Ok(())
            })
            .await
            .unwrap();

        let checkpoint = manager.create_checkpoint(&job_id).await.unwrap();
        assert_eq!(checkpoint.version, 1);
        assert_eq!(checkpoint.items_processed.len(), 2);
        assert_eq!(checkpoint.agent_results.len(), 2);

        // The checkpoint must also be persisted on the job state.
        let state = manager.get_state(&job_id).await.unwrap().unwrap();
        assert!(state.checkpoint.is_some());
        assert_eq!(state.checkpoint.as_ref().unwrap().version, 1);
    }

    #[tokio::test]
    async fn test_checkpoint_validation() {
        let store = Arc::new(super::super::persistence::InMemoryStateStore::new());
        let manager = StateManager::new(store);
        let config = MapReduceConfig::default();
        let job_id = "test-job-validation".to_string();
        let _state = manager.create_job(&config, job_id.clone()).await.unwrap();

        // Mark an item as processed without recording an agent result for it.
        manager
            .update_state(&job_id, |state| {
                state.total_items = 5;
                state.processed_items.insert("item_0".to_string());
                Ok(())
            })
            .await
            .unwrap();

        let state = manager.get_state(&job_id).await.unwrap().unwrap();
        let result = manager.validate_checkpoint(&state);
        assert!(
            result.is_err(),
            "Expected validation to fail but it succeeded"
        );
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("has no result"),
            "Expected error message to contain 'has no result' but got: '{}'",
            error_msg
        );
    }

    #[test]
    fn test_checksum_calculation() {
        let mut state1 = JobState {
            id: "test-job".to_string(),
            phase: PhaseType::Map,
            checkpoint: None,
            processed_items: HashSet::new(),
            failed_items: Vec::new(),
            variables: HashMap::new(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            config: MapReduceConfig::default(),
            agent_results: HashMap::new(),
            is_complete: false,
            total_items: 10,
        };

        // The checksum is deterministic for an unchanged state...
        let checksum1 = calculate_checksum(&state1);
        let checksum2 = calculate_checksum(&state1);
        assert_eq!(checksum1, checksum2);

        // ...and changes when a processed item is added.
        state1.processed_items.insert("item_0".to_string());
        let checksum3 = calculate_checksum(&state1);
        assert_ne!(checksum1, checksum3);
    }
}