use super::storage::CheckpointStorage;
use super::types::*;
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
use std::path::PathBuf;
use tokio::fs;
use tracing::{debug, info, warn};
/// Manages the checkpoint lifecycle (create, validate, resume, retention,
/// import/export) for a single MapReduce job, delegating persistence to a
/// pluggable storage backend.
pub struct CheckpointManager {
    // Pluggable persistence backend (e.g. file-based, optionally compressed).
    storage: Box<dyn CheckpointStorage>,
    // Validation flags, checkpoint intervals, and retention policy.
    config: CheckpointConfig,
    // The job all checkpoints created/loaded by this manager belong to.
    job_id: String,
}
impl CheckpointManager {
/// Build a manager for `job_id` that persists checkpoints through `storage`
/// and applies the validation/retention settings from `config`.
pub fn new(
    storage: Box<dyn CheckpointStorage>,
    config: CheckpointConfig,
    job_id: String,
) -> Self {
    Self { storage, config, job_id }
}
/// Snapshot `state` as a new checkpoint for this job.
///
/// Stamps fresh metadata (new id, timestamp, reason, integrity hash),
/// optionally validates the snapshot, persists it via storage, and then
/// prunes old checkpoints according to the retention policy.
pub async fn create_checkpoint(
    &self,
    state: &MapReduceCheckpoint,
    reason: CheckpointReason,
) -> Result<CheckpointId> {
    let id = CheckpointId::new();
    let mut snapshot = state.clone();
    snapshot.metadata.checkpoint_id = id.to_string();
    snapshot.metadata.created_at = Utc::now();
    snapshot.metadata.checkpoint_reason = reason;
    // Hash is computed over the (updated) snapshot so load-time integrity
    // checks can verify it.
    snapshot.metadata.integrity_hash = self.calculate_integrity_hash(&snapshot);
    if self.config.validate_on_save {
        self.validate_checkpoint(&snapshot)?;
    }
    self.storage.save_checkpoint(&snapshot).await?;
    self.cleanup_old_checkpoints().await?;
    info!(
        "Created checkpoint {} for job {}",
        id, self.job_id
    );
    Ok(id)
}
/// Load a checkpoint (the newest one when `checkpoint_id` is `None`),
/// optionally validate it, and turn it into a ready-to-run `ResumeState`
/// with an automatically chosen resume strategy.
pub async fn resume_from_checkpoint(
    &self,
    checkpoint_id: Option<CheckpointId>,
) -> Result<ResumeState> {
    let id = if let Some(id) = checkpoint_id {
        id
    } else {
        self.find_latest_checkpoint()
            .await?
            .ok_or_else(|| anyhow!("No checkpoint found for job {}", self.job_id))?
    };
    let checkpoint = self.storage.load_checkpoint(&id).await?;
    if self.config.validate_on_load {
        self.validate_checkpoint(&checkpoint)?;
        self.validate_integrity(&checkpoint)?;
    }
    self.build_resume_state(checkpoint)
}
/// Like [`resume_from_checkpoint`], but uses a caller-supplied resume
/// strategy instead of deriving one from the checkpoint's phase.
pub async fn resume_from_checkpoint_with_strategy(
    &self,
    checkpoint_id: Option<CheckpointId>,
    strategy: ResumeStrategy,
) -> Result<ResumeState> {
    let id = if let Some(id) = checkpoint_id {
        id
    } else {
        self.find_latest_checkpoint()
            .await?
            .ok_or_else(|| anyhow!("No checkpoint found for job {}", self.job_id))?
    };
    let checkpoint = self.storage.load_checkpoint(&id).await?;
    if self.config.validate_on_load {
        self.validate_checkpoint(&checkpoint)?;
        self.validate_integrity(&checkpoint)?;
    }
    let work_items = self.prepare_work_items_for_resume(&checkpoint, &strategy)?;
    Ok(ResumeState {
        work_items,
        execution_state: checkpoint.execution_state.clone(),
        agents: checkpoint.agent_state.clone(),
        variables: checkpoint.variable_state.clone(),
        resources: checkpoint.resource_state.clone(),
        resume_strategy: strategy,
        checkpoint,
    })
}
/// List every checkpoint recorded for this manager's job.
pub async fn list_checkpoints(&self) -> Result<Vec<CheckpointInfo>> {
    self.storage.list_checkpoints(&self.job_id).await
}
/// Permanently remove a single checkpoint from storage.
pub async fn delete_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<()> {
    self.storage.delete_checkpoint(checkpoint_id).await
}
/// Structural sanity checks on a checkpoint.
///
/// A mismatch between the bucketed item counts and the recorded total is
/// only logged (non-fatal); an agent that holds assignments without an
/// active-agent entry is an error.
fn validate_checkpoint(&self, checkpoint: &MapReduceCheckpoint) -> Result<()> {
    let items = &checkpoint.work_item_state;
    let total_accounted = items.completed_items.len()
        + items.failed_items.len()
        + items.pending_items.len()
        + items.in_progress_items.len();
    if total_accounted != checkpoint.metadata.total_work_items {
        warn!(
            "Work item count mismatch: {} accounted vs {} total",
            total_accounted, checkpoint.metadata.total_work_items
        );
    }
    for agent_id in checkpoint.agent_state.agent_assignments.keys() {
        if !checkpoint.agent_state.active_agents.contains_key(agent_id) {
            return Err(anyhow!(
                "Agent {} has assignments but is not active",
                agent_id
            ));
        }
    }
    Ok(())
}
/// Verify the stored integrity hash matches one recomputed from the
/// checkpoint's current contents.
fn validate_integrity(&self, checkpoint: &MapReduceCheckpoint) -> Result<()> {
    if self.calculate_integrity_hash(checkpoint) == checkpoint.metadata.integrity_hash {
        Ok(())
    } else {
        Err(anyhow!("Checkpoint integrity check failed: hash mismatch"))
    }
}
/// SHA-256 over a fixed set of metadata fields plus the completed/failed
/// item counts, rendered as lowercase hex.
///
/// Note the hash covers `job_id`, so re-keying a checkpoint to another job
/// invalidates it. It deliberately excludes `checkpoint_id`, `created_at`,
/// and `checkpoint_reason`.
fn calculate_integrity_hash(&self, checkpoint: &MapReduceCheckpoint) -> String {
    use sha2::{Digest, Sha256};
    let meta = &checkpoint.metadata;
    let items = &checkpoint.work_item_state;
    let mut hasher = Sha256::new();
    // Fields are fed in a fixed order; changing the order or set of fields
    // would invalidate every previously stored hash.
    for field in [
        meta.job_id.clone(),
        meta.version.to_string(),
        format!("{:?}", meta.phase),
        meta.total_work_items.to_string(),
        meta.completed_items.to_string(),
        items.completed_items.len().to_string(),
        items.failed_items.len().to_string(),
    ] {
        hasher.update(field.as_bytes());
    }
    format!("{:x}", hasher.finalize())
}
/// Most recently created checkpoint for this job, or `None` when the job
/// has no checkpoints.
async fn find_latest_checkpoint(&self) -> Result<Option<CheckpointId>> {
    // `max_by_key` already returns `None` for an empty iterator, so the
    // previous explicit `is_empty()` early return was redundant.
    Ok(self
        .list_checkpoints()
        .await?
        .into_iter()
        .max_by_key(|cp| cp.created_at)
        .map(|cp| CheckpointId::from_string(cp.id)))
}
async fn cleanup_old_checkpoints(&self) -> Result<()> {
if let Some(ref policy) = self.config.retention_policy {
let checkpoints = self.list_checkpoints().await?;
let to_delete = self.select_checkpoints_for_deletion(&checkpoints, policy);
for checkpoint_id in to_delete {
self.delete_checkpoint(&checkpoint_id).await?;
debug!("Deleted old checkpoint {}", checkpoint_id);
}
}
Ok(())
}
/// Choose which checkpoints the retention policy condemns.
///
/// A checkpoint is selected when it is not protected by `keep_final` and it
/// either falls outside the newest `max_checkpoints` entries or is older
/// than `max_age`. Each checkpoint is returned at most once: the previous
/// implementation concatenated the two criteria's results, so a checkpoint
/// that was both in excess AND too old appeared twice and the second delete
/// attempt would fail.
fn select_checkpoints_for_deletion(
    &self,
    checkpoints: &[CheckpointInfo],
    policy: &RetentionPolicy,
) -> Vec<CheckpointId> {
    let mut sorted = checkpoints.to_vec();
    // Oldest first, so the first `excess` entries are the count overflow.
    sorted.sort_by_key(|c| c.created_at);
    let is_deletable = |cp: &CheckpointInfo| !policy.keep_final || !cp.is_final;
    // How many of the oldest entries exceed the allowed count (0 if no
    // count limit or we are under it).
    let excess = policy
        .max_checkpoints
        .filter(|&max| sorted.len() > max)
        .map(|max| sorted.len() - max)
        .unwrap_or(0);
    // Creation-time cutoff implied by the age limit, if configured.
    let cutoff = policy
        .max_age
        .map(|max_age| Utc::now() - chrono::Duration::from_std(max_age).unwrap_or_default());
    sorted
        .iter()
        .enumerate()
        .filter(|(index, cp)| {
            is_deletable(cp)
                && (*index < excess || cutoff.map_or(false, |c| cp.created_at < c))
        })
        .map(|(_, cp)| CheckpointId::from_string(cp.id.clone()))
        .collect()
}
/// Assemble a `ResumeState` from a loaded checkpoint, choosing the resume
/// strategy automatically from the checkpoint's phase and in-flight work.
fn build_resume_state(&self, checkpoint: MapReduceCheckpoint) -> Result<ResumeState> {
    let strategy = self.determine_resume_strategy(&checkpoint);
    let work_items = self.prepare_work_items_for_resume(&checkpoint, &strategy)?;
    Ok(ResumeState {
        work_items,
        execution_state: checkpoint.execution_state.clone(),
        agents: checkpoint.agent_state.clone(),
        variables: checkpoint.variable_state.clone(),
        resources: checkpoint.resource_state.clone(),
        resume_strategy: strategy,
        checkpoint,
    })
}
/// Pick how to resume based on the phase recorded in the checkpoint.
fn determine_resume_strategy(&self, checkpoint: &MapReduceCheckpoint) -> ResumeStrategy {
    match checkpoint.metadata.phase {
        // Setup is redone from scratch on resume.
        PhaseType::Setup => ResumeStrategy::RestartCurrentPhase,
        // Items stuck in-flight mean the previous run died mid-agent;
        // requeue them via validation rather than trusting the snapshot.
        PhaseType::Map if !checkpoint.work_item_state.in_progress_items.is_empty() => {
            ResumeStrategy::ValidateAndContinue
        }
        PhaseType::Map | PhaseType::Reduce | PhaseType::Complete => {
            ResumeStrategy::ContinueFromCheckpoint
        }
    }
}
/// Rework the checkpointed item queues according to the resume strategy.
fn prepare_work_items_for_resume(
    &self,
    checkpoint: &MapReduceCheckpoint,
    strategy: &ResumeStrategy,
) -> Result<WorkItemState> {
    let mut items = checkpoint.work_item_state.clone();
    match strategy {
        // Trust the snapshot exactly as it was written.
        ResumeStrategy::ContinueFromCheckpoint => {}
        // Anything that was mid-flight is requeued; completed work stands.
        ResumeStrategy::ValidateAndContinue => {
            let requeued: Vec<WorkItem> = items
                .in_progress_items
                .drain()
                .map(|(_, progress)| progress.work_item)
                .collect();
            items.pending_items.extend(requeued);
        }
        // Requeue in-flight items and drop completed results.
        // NOTE(review): completed items are cleared but NOT re-queued here,
        // so their work is not redone — confirm this is intended for a
        // "restart current phase".
        ResumeStrategy::RestartCurrentPhase => {
            let requeued: Vec<WorkItem> = items
                .in_progress_items
                .drain()
                .map(|(_, progress)| progress.work_item)
                .collect();
            items.pending_items.extend(requeued);
            items.completed_items.clear();
        }
        // Rebuild the full pending queue from every bucket and start over.
        ResumeStrategy::RestartFromMapPhase => {
            let source = &checkpoint.work_item_state;
            let mut all_items: Vec<WorkItem> = source
                .completed_items
                .iter()
                .map(|c| c.work_item.clone())
                .collect();
            all_items.extend(source.in_progress_items.values().map(|p| p.work_item.clone()));
            all_items.extend(source.pending_items.iter().cloned());
            items.pending_items = all_items;
            items.in_progress_items.clear();
            items.completed_items.clear();
        }
    }
    Ok(items)
}
/// Decide whether a new checkpoint is due, based on how many items have
/// been processed and how long ago the last checkpoint was taken.
/// Either configured trigger alone suffices; unset intervals never fire.
pub fn should_checkpoint(
    &self,
    items_processed: usize,
    last_checkpoint_time: DateTime<Utc>,
) -> bool {
    let item_trigger = self
        .config
        .interval_items
        .map_or(false, |interval| items_processed >= interval);
    let time_trigger = self.config.interval_duration.map_or(false, |interval| {
        let elapsed = Utc::now().signed_duration_since(last_checkpoint_time);
        elapsed >= chrono::Duration::from_std(interval).unwrap_or_default()
    });
    item_trigger || time_trigger
}
/// Serialize a stored checkpoint to pretty-printed JSON at `export_path`,
/// creating parent directories as needed.
pub async fn export_checkpoint(
    &self,
    checkpoint_id: &CheckpointId,
    export_path: PathBuf,
) -> Result<()> {
    info!(
        "Exporting checkpoint {} to {:?}",
        checkpoint_id, export_path
    );
    let checkpoint = self
        .storage
        .load_checkpoint(checkpoint_id)
        .await
        .context("Failed to load checkpoint for export")?;
    if let Some(directory) = export_path.parent() {
        fs::create_dir_all(directory)
            .await
            .context("Failed to create export directory")?;
    }
    let json = serde_json::to_vec_pretty(&checkpoint)
        .context("Failed to serialize checkpoint for export")?;
    fs::write(&export_path, json)
        .await
        .context("Failed to write exported checkpoint")?;
    info!("Successfully exported checkpoint to {:?}", export_path);
    Ok(())
}
/// Import a checkpoint from a JSON export, re-keying it for this job.
///
/// The checkpoint receives a fresh id and this manager's `job_id`. The
/// integrity hash is recomputed afterwards because it covers `job_id`:
/// without this, a checkpoint exported from a *different* job would fail
/// integrity validation on its first resume with `validate_on_load`.
pub async fn import_checkpoint(&self, import_path: PathBuf) -> Result<CheckpointId> {
    info!("Importing checkpoint from {:?}", import_path);
    if !import_path.exists() {
        return Err(anyhow!("Import file does not exist: {:?}", import_path));
    }
    let data = fs::read(&import_path)
        .await
        .context("Failed to read import file")?;
    let mut checkpoint: MapReduceCheckpoint =
        serde_json::from_slice(&data).context("Failed to parse imported checkpoint")?;
    let new_id = CheckpointId::new();
    checkpoint.metadata.checkpoint_id = new_id.to_string();
    checkpoint.metadata.job_id = self.job_id.clone();
    // Re-keying invalidated the stored hash; recompute so validation passes.
    checkpoint.metadata.integrity_hash = self.calculate_integrity_hash(&checkpoint);
    self.storage
        .save_checkpoint(&checkpoint)
        .await
        .context("Failed to save imported checkpoint")?;
    info!("Successfully imported checkpoint with ID {}", new_id);
    Ok(new_id)
}
/// Persist a reduce-phase checkpoint atomically (write to a `.tmp` staging
/// file, then rename) and return the final file path.
pub async fn save_reduce_checkpoint(
    &self,
    reduce_checkpoint: &super::reduce::ReducePhaseCheckpoint,
) -> Result<PathBuf> {
    let checkpoint_dir = self.get_reduce_checkpoint_dir().await?;
    let file_name = format!(
        "reduce-checkpoint-v{}-{}.json",
        reduce_checkpoint.version,
        reduce_checkpoint.timestamp.format("%Y%m%d_%H%M%S")
    );
    let target = checkpoint_dir.join(file_name);
    fs::create_dir_all(&checkpoint_dir).await?;
    let json = serde_json::to_vec_pretty(reduce_checkpoint)
        .context("Failed to serialize reduce checkpoint")?;
    // Write-then-rename keeps readers from ever observing a half-written
    // file; the ".tmp" extension also excludes it from load-time scans.
    let staging = target.with_extension("tmp");
    fs::write(&staging, &json).await?;
    fs::rename(&staging, &target).await?;
    info!("Saved reduce checkpoint to {:?}", target);
    Ok(target)
}
/// Load the most recently modified reduce-phase checkpoint for this job,
/// or `None` when no checkpoint directory or matching file exists.
pub async fn load_reduce_checkpoint(
    &self,
) -> Result<Option<super::reduce::ReducePhaseCheckpoint>> {
    let checkpoint_dir = self.get_reduce_checkpoint_dir().await?;
    if !checkpoint_dir.exists() {
        return Ok(None);
    }
    let mut entries = fs::read_dir(&checkpoint_dir).await?;
    // Track (path, mtime) of the newest candidate directly. The original
    // code re-stat-ed each path via `tokio::fs::metadata` (the `DirEntry`
    // already provides metadata) and compared `Option<SystemTime>` values,
    // which silently relied on `Option`'s ordering semantics. Files whose
    // mtime cannot be read are skipped.
    let mut latest: Option<(PathBuf, std::time::SystemTime)> = None;
    while let Some(entry) = entries.next_entry().await? {
        let path = entry.path();
        let is_checkpoint_file = path
            .file_name()
            .and_then(|s| s.to_str())
            .map(|s| s.starts_with("reduce-checkpoint-") && s.ends_with(".json"))
            .unwrap_or(false);
        if !is_checkpoint_file {
            continue;
        }
        if let Ok(modified) = entry.metadata().await.and_then(|m| m.modified()) {
            if latest.as_ref().map_or(true, |(_, newest)| modified > *newest) {
                latest = Some((path, modified));
            }
        }
    }
    match latest {
        Some((checkpoint_file, _)) => {
            let data = fs::read(&checkpoint_file).await?;
            let checkpoint: super::reduce::ReducePhaseCheckpoint = serde_json::from_slice(&data)
                .context("Failed to deserialize reduce checkpoint")?;
            info!("Loaded reduce checkpoint from {:?}", checkpoint_file);
            Ok(Some(checkpoint))
        }
        None => Ok(None),
    }
}
/// Whether a resumable reduce-phase checkpoint exists for this job.
pub async fn can_resume_reduce(&self) -> Result<bool> {
    let resumable = match self.load_reduce_checkpoint().await? {
        Some(checkpoint) => checkpoint.can_resume(),
        None => false,
    };
    Ok(resumable)
}
/// Directory holding reduce-phase checkpoints for this job:
/// `<storage>/state/reduce_checkpoints/<job_id>`.
/// (Currently performs no I/O; `async` is kept for interface stability.)
async fn get_reduce_checkpoint_dir(&self) -> Result<PathBuf> {
    let storage_dir = crate::storage::get_default_storage_dir()
        .context("Failed to determine storage directory")?;
    let checkpoint_dir = storage_dir
        .join("state")
        .join("reduce_checkpoints")
        .join(&self.job_id);
    Ok(checkpoint_dir)
}
}
#[cfg(test)]
mod tests {
use super::super::storage::{CheckpointStorage, CompressionAlgorithm, FileCheckpointStorage};
use super::*;
use crate::cook::execution::mapreduce::{AgentResult, AgentStatus};
use serde_json::Value;
use std::collections::HashMap;
use std::time::Duration;
/// Baseline fixture: a Map-phase checkpoint whose metadata reports 10 total /
/// 5 completed items while every item/agent/error collection is empty, and
/// whose integrity hash is blank (the manager fills it in on save).
fn create_test_checkpoint(job_id: &str) -> MapReduceCheckpoint {
    MapReduceCheckpoint {
        metadata: CheckpointMetadata {
            checkpoint_id: "test-checkpoint".to_string(),
            job_id: job_id.to_string(),
            version: 1,
            created_at: Utc::now(),
            phase: PhaseType::Map,
            total_work_items: 10,
            completed_items: 5,
            checkpoint_reason: CheckpointReason::Manual,
            integrity_hash: String::new(),
        },
        execution_state: ExecutionState {
            current_phase: PhaseType::Map,
            phase_start_time: Utc::now(),
            setup_results: None,
            map_results: None,
            reduce_results: None,
            workflow_variables: HashMap::new(),
        },
        work_item_state: WorkItemState {
            pending_items: vec![],
            in_progress_items: HashMap::new(),
            completed_items: vec![],
            failed_items: vec![],
            current_batch: None,
        },
        agent_state: AgentState {
            active_agents: HashMap::new(),
            agent_assignments: HashMap::new(),
            agent_results: HashMap::new(),
            resource_allocation: HashMap::new(),
        },
        variable_state: VariableState {
            workflow_variables: HashMap::new(),
            captured_outputs: HashMap::new(),
            environment_variables: HashMap::new(),
            item_variables: HashMap::new(),
        },
        resource_state: ResourceState {
            total_agents_allowed: 10,
            current_agents_active: 0,
            worktrees_created: vec![],
            worktrees_cleaned: vec![],
            disk_usage_bytes: None,
        },
        error_state: ErrorState {
            error_count: 0,
            dlq_items: vec![],
            error_threshold_reached: false,
            last_error: None,
        },
    }
}
fn create_test_checkpoint_with_work_items(job_id: &str) -> MapReduceCheckpoint {
let mut checkpoint = create_test_checkpoint(job_id);
let items = vec![
WorkItem {
id: "item-1".to_string(),
data: Value::String("test1".to_string()),
},
WorkItem {
id: "item-2".to_string(),
data: Value::String("test2".to_string()),
},
WorkItem {
id: "item-3".to_string(),
data: Value::String("test3".to_string()),
},
WorkItem {
id: "item-4".to_string(),
data: Value::String("test4".to_string()),
},
WorkItem {
id: "item-5".to_string(),
data: Value::String("test5".to_string()),
},
];
checkpoint.work_item_state.pending_items = items[2..].to_vec();
checkpoint.work_item_state.completed_items = items[..2]
.iter()
.map(|item| CompletedWorkItem {
work_item: item.clone(),
result: crate::cook::execution::mapreduce::agent::types::AgentResult {
item_id: item.id.clone(),
status: AgentStatus::Success,
output: None,
commits: vec![],
files_modified: vec![],
duration: Duration::from_secs(1),
error: None,
worktree_path: None,
branch_name: None,
worktree_session_id: None,
json_log_location: None,
cleanup_status: None,
},
completed_at: Utc::now(),
})
.collect();
checkpoint
}
#[tokio::test]
async fn test_checkpoint_creation() {
let temp_dir = tempfile::tempdir().unwrap();
let storage = Box::new(FileCheckpointStorage::new(
temp_dir.path().to_path_buf(),
false,
));
let config = CheckpointConfig::default();
let job_id = "test-job".to_string();
let manager = CheckpointManager::new(storage, config, job_id.clone());
let checkpoint = MapReduceCheckpoint {
metadata: CheckpointMetadata {
checkpoint_id: "test-cp".to_string(),
job_id: job_id.clone(),
version: 1,
created_at: Utc::now(),
phase: PhaseType::Map,
total_work_items: 10,
completed_items: 5,
checkpoint_reason: CheckpointReason::Interval,
integrity_hash: String::new(),
},
execution_state: ExecutionState {
current_phase: PhaseType::Map,
phase_start_time: Utc::now(),
setup_results: None,
map_results: None,
reduce_results: None,
workflow_variables: HashMap::new(),
},
work_item_state: WorkItemState {
pending_items: vec![],
in_progress_items: HashMap::new(),
completed_items: vec![],
failed_items: vec![],
current_batch: None,
},
agent_state: AgentState {
active_agents: HashMap::new(),
agent_assignments: HashMap::new(),
agent_results: HashMap::new(),
resource_allocation: HashMap::new(),
},
variable_state: VariableState {
workflow_variables: HashMap::new(),
captured_outputs: HashMap::new(),
environment_variables: HashMap::new(),
item_variables: HashMap::new(),
},
resource_state: ResourceState {
total_agents_allowed: 10,
current_agents_active: 0,
worktrees_created: vec![],
worktrees_cleaned: vec![],
disk_usage_bytes: None,
},
error_state: ErrorState {
error_count: 0,
dlq_items: vec![],
error_threshold_reached: false,
last_error: None,
},
};
let checkpoint_id = manager
.create_checkpoint(&checkpoint, CheckpointReason::Interval)
.await
.unwrap();
assert!(!checkpoint_id.as_str().is_empty());
let checkpoints = manager.list_checkpoints().await.unwrap();
assert_eq!(checkpoints.len(), 1);
assert_eq!(checkpoints[0].job_id, job_id);
}
#[tokio::test]
async fn test_compression_algorithms() {
    // Round-trip a checkpoint through every supported compression codec.
    let temp_dir = tempfile::tempdir().unwrap();
    for algorithm in [
        CompressionAlgorithm::None,
        CompressionAlgorithm::Gzip,
        CompressionAlgorithm::Zstd,
        CompressionAlgorithm::Lz4,
    ] {
        let job_id = format!("test-{:?}", algorithm);
        let storage = Box::new(FileCheckpointStorage::with_compression(
            temp_dir.path().join(format!("{:?}", algorithm)).to_path_buf(),
            algorithm,
        ));
        let manager =
            CheckpointManager::new(storage, CheckpointConfig::default(), job_id.clone());
        let checkpoint = create_test_checkpoint(&job_id);
        let id = manager
            .create_checkpoint(&checkpoint, CheckpointReason::Interval)
            .await
            .unwrap_or_else(|_| panic!("Failed to create checkpoint with {:?}", algorithm));
        let resumed = manager
            .resume_from_checkpoint(Some(id))
            .await
            .unwrap_or_else(|_| panic!("Failed to resume checkpoint with {:?}", algorithm));
        assert_eq!(
            resumed.checkpoint.metadata.job_id,
            checkpoint.metadata.job_id
        );
    }
}
#[tokio::test]
async fn test_checkpoint_integrity_validation() {
    // A create/resume round trip must leave a populated integrity hash.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        true,
    ));
    let manager =
        CheckpointManager::new(storage, CheckpointConfig::default(), "test-job".to_string());
    let checkpoint = create_test_checkpoint("test-job");
    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::Interval)
        .await
        .expect("Failed to create checkpoint");
    let resumed = manager
        .resume_from_checkpoint(Some(id.clone()))
        .await
        .expect("Failed to load checkpoint");
    assert!(!resumed.checkpoint.metadata.integrity_hash.is_empty());
}
#[tokio::test]
async fn test_checkpoint_retention_policies() {
    // With max_checkpoints = 3, creating 5 checkpoints must prune to 3.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let config = CheckpointConfig {
        retention_policy: Some(RetentionPolicy {
            max_checkpoints: Some(3),
            max_age: None,
            keep_final: true,
        }),
        ..Default::default()
    };
    let manager = CheckpointManager::new(storage, config, "test-job".to_string());
    let mut checkpoint_ids = Vec::new();
    for index in 0..5 {
        let mut checkpoint = create_test_checkpoint("test-job");
        checkpoint.metadata.checkpoint_id = format!("cp-{}", index);
        checkpoint_ids.push(
            manager
                .create_checkpoint(&checkpoint, CheckpointReason::Interval)
                .await
                .expect("Failed to create checkpoint"),
        );
    }
    let remaining = manager.list_checkpoints().await.unwrap();
    assert_eq!(
        remaining.len(),
        3,
        "Should have only 3 checkpoints due to retention policy"
    );
}
#[tokio::test]
async fn test_checkpoint_export_import() {
    // Export to JSON on disk, re-import, and resume from the imported copy.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().join("storage").to_path_buf(),
        true,
    ));
    let manager =
        CheckpointManager::new(storage, CheckpointConfig::default(), "test-job".to_string());
    let checkpoint = create_test_checkpoint("test-job");
    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::Manual)
        .await
        .expect("Failed to create checkpoint");
    let export_path = temp_dir.path().join("exported.json");
    manager
        .export_checkpoint(&id, export_path.clone())
        .await
        .expect("Failed to export checkpoint");
    assert!(export_path.exists());
    let imported_id = manager
        .import_checkpoint(export_path)
        .await
        .expect("Failed to import checkpoint");
    let resumed = manager
        .resume_from_checkpoint(Some(imported_id))
        .await
        .expect("Failed to load imported checkpoint");
    assert_eq!(resumed.checkpoint.metadata.job_id, "test-job");
}
#[tokio::test]
async fn test_concurrent_checkpoint_operations() {
use futures::future::join_all;
let temp_dir = tempfile::tempdir().unwrap();
let storage = Box::new(FileCheckpointStorage::new(
temp_dir.path().to_path_buf(),
false,
));
let config = CheckpointConfig::default();
let manager = std::sync::Arc::new(CheckpointManager::new(
storage,
config,
"test-job".to_string(),
));
let tasks: Vec<_> = (0..10)
.map(|i| {
let manager = manager.clone();
tokio::spawn(async move {
let mut checkpoint = create_test_checkpoint("test-job");
checkpoint.metadata.checkpoint_id = format!("concurrent-{}", i);
manager
.create_checkpoint(&checkpoint, CheckpointReason::Interval)
.await
})
})
.collect();
let results = join_all(tasks).await;
for result in results {
assert!(result.unwrap().is_ok());
}
let checkpoints = manager.list_checkpoints().await.unwrap();
assert_eq!(checkpoints.len(), 10);
}
#[tokio::test]
async fn test_checkpoint_resume_strategies() {
    // Exercises every ResumeStrategy against a fixture with 3 pending,
    // 2 completed, and 0 in-progress items.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let config = CheckpointConfig::default();
    let manager = CheckpointManager::new(storage, config, "test-job".to_string());
    let checkpoint = create_test_checkpoint_with_work_items("test-job");
    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::Manual)
        .await
        .expect("Failed to create checkpoint");
    let strategies = vec![
        ResumeStrategy::ContinueFromCheckpoint,
        ResumeStrategy::RestartCurrentPhase,
        ResumeStrategy::RestartFromMapPhase,
        ResumeStrategy::ValidateAndContinue,
    ];
    for strategy in strategies {
        let resume_state = manager
            .resume_from_checkpoint_with_strategy(Some(id.clone()), strategy.clone())
            .await
            .unwrap_or_else(|_| panic!("Failed with strategy {:?}", strategy));
        match strategy {
            ResumeStrategy::ContinueFromCheckpoint => {
                // Queues taken as-is from the snapshot.
                assert_eq!(resume_state.work_items.pending_items.len(), 3);
                assert_eq!(resume_state.work_items.completed_items.len(), 2);
            }
            ResumeStrategy::RestartCurrentPhase => {
                // Completed results are dropped; nothing was in progress.
                assert_eq!(resume_state.work_items.pending_items.len(), 3);
                assert!(resume_state.work_items.completed_items.is_empty());
            }
            ResumeStrategy::RestartFromMapPhase => {
                // All five items (completed + pending) are requeued.
                assert_eq!(resume_state.work_items.pending_items.len(), 5);
                assert!(resume_state.work_items.completed_items.is_empty());
            }
            ResumeStrategy::ValidateAndContinue => {
                // In-flight items (none in this fixture) move back to pending.
                assert!(resume_state.work_items.in_progress_items.is_empty());
            }
        }
    }
}
#[tokio::test]
async fn test_checkpoint_resume() {
let temp_dir = tempfile::tempdir().unwrap();
let storage = Box::new(FileCheckpointStorage::new(
temp_dir.path().to_path_buf(),
true,
));
let config = CheckpointConfig::default();
let job_id = "test-job".to_string();
let manager = CheckpointManager::new(storage, config, job_id.clone());
let checkpoint = MapReduceCheckpoint {
metadata: CheckpointMetadata {
checkpoint_id: "test-cp".to_string(),
job_id: job_id.clone(),
version: 1,
created_at: Utc::now(),
phase: PhaseType::Map,
total_work_items: 10,
completed_items: 3,
checkpoint_reason: CheckpointReason::Interval,
integrity_hash: String::new(),
},
execution_state: ExecutionState {
current_phase: PhaseType::Map,
phase_start_time: Utc::now(),
setup_results: None,
map_results: None,
reduce_results: None,
workflow_variables: HashMap::new(),
},
work_item_state: WorkItemState {
pending_items: vec![
WorkItem {
id: "item_4".to_string(),
data: Value::String("data4".to_string()),
},
WorkItem {
id: "item_5".to_string(),
data: Value::String("data5".to_string()),
},
],
in_progress_items: {
let mut map = HashMap::new();
map.insert(
"item_3".to_string(),
WorkItemProgress {
work_item: WorkItem {
id: "item_3".to_string(),
data: Value::String("data3".to_string()),
},
agent_id: "agent_1".to_string(),
started_at: Utc::now(),
last_update: Utc::now(),
},
);
map
},
completed_items: vec![CompletedWorkItem {
work_item: WorkItem {
id: "item_1".to_string(),
data: Value::String("data1".to_string()),
},
result: AgentResult {
item_id: "item_1".to_string(),
status: AgentStatus::Success,
output: Some("output1".to_string()),
commits: vec![],
duration: Duration::from_secs(10),
error: None,
worktree_path: None,
branch_name: None,
worktree_session_id: None,
files_modified: vec![],
json_log_location: None,
cleanup_status: None,
},
completed_at: Utc::now(),
}],
failed_items: vec![],
current_batch: None,
},
agent_state: AgentState {
active_agents: HashMap::new(),
agent_assignments: HashMap::new(),
agent_results: HashMap::new(),
resource_allocation: HashMap::new(),
},
variable_state: VariableState {
workflow_variables: HashMap::new(),
captured_outputs: HashMap::new(),
environment_variables: HashMap::new(),
item_variables: HashMap::new(),
},
resource_state: ResourceState {
total_agents_allowed: 10,
current_agents_active: 1,
worktrees_created: vec!["wt1".to_string()],
worktrees_cleaned: vec![],
disk_usage_bytes: None,
},
error_state: ErrorState {
error_count: 0,
dlq_items: vec![],
error_threshold_reached: false,
last_error: None,
},
};
let checkpoint_id = manager
.create_checkpoint(&checkpoint, CheckpointReason::Interval)
.await
.unwrap();
let resume_state = manager
.resume_from_checkpoint(Some(checkpoint_id))
.await
.unwrap();
assert_eq!(resume_state.execution_state.current_phase, PhaseType::Map);
assert_eq!(resume_state.work_items.pending_items.len(), 3); assert!(resume_state.work_items.in_progress_items.is_empty());
}
#[tokio::test]
async fn test_checkpoint_interval_check() {
    // Item-count and elapsed-time triggers for should_checkpoint.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let config = CheckpointConfig {
        interval_items: Some(5),
        interval_duration: Some(Duration::from_secs(60)),
        ..Default::default()
    };
    let manager = CheckpointManager::new(storage, config, "test-job".to_string());
    // Below the item threshold: no checkpoint yet.
    assert!(!manager.should_checkpoint(3, Utc::now()));
    // At or beyond the item threshold: checkpoint is due.
    assert!(manager.should_checkpoint(5, Utc::now()));
    assert!(manager.should_checkpoint(10, Utc::now()));
    // The time trigger fires even when the item count is low.
    let stale = Utc::now() - chrono::Duration::seconds(61);
    assert!(manager.should_checkpoint(2, stale));
}
#[tokio::test]
async fn test_checkpoint_id_generation() {
    // Fresh ids are unique and carry the "cp-" prefix.
    let first = CheckpointId::new();
    let second = CheckpointId::new();
    assert_ne!(first.as_str(), second.as_str());
    assert!(first.as_str().starts_with("cp-"));
    // Round-trip a caller-supplied id through from_string and Display.
    let raw = "custom-checkpoint-id".to_string();
    let custom_id = CheckpointId::from_string(raw.clone());
    assert_eq!(custom_id.as_str(), "custom-checkpoint-id");
    assert_eq!(format!("{}", custom_id), "custom-checkpoint-id");
}
#[tokio::test]
async fn test_checkpoint_validation() {
    // An agent with assignments but no active-agent entry must fail
    // save-time validation.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let config = CheckpointConfig {
        validate_on_save: true,
        validate_on_load: true,
        ..Default::default()
    };
    let manager = CheckpointManager::new(storage, config, "test-job".to_string());
    let mut checkpoint = create_test_checkpoint("test-job");
    checkpoint
        .agent_state
        .agent_assignments
        .insert("non-existent-agent".to_string(), vec!["item1".to_string()]);
    let error = manager
        .create_checkpoint(&checkpoint, CheckpointReason::Manual)
        .await
        .expect_err("validation should reject orphaned agent assignments");
    assert!(error.to_string().contains("Agent"));
}
#[tokio::test]
async fn test_checkpoint_integrity() {
    // With validation enabled end-to-end, the stored hash must be populated
    // and the load-time integrity check must pass.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let config = CheckpointConfig {
        validate_on_save: true,
        validate_on_load: true,
        ..Default::default()
    };
    let manager = CheckpointManager::new(storage, config, "test-job".to_string());
    let checkpoint = create_test_checkpoint("test-job");
    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::Manual)
        .await
        .unwrap();
    let resumed = manager
        .resume_from_checkpoint(Some(id.clone()))
        .await
        .unwrap();
    assert!(!resumed.checkpoint.metadata.integrity_hash.is_empty());
}
#[tokio::test]
async fn test_work_item_state_manipulation() {
    // ValidateAndContinue must drain in-progress items back into pending:
    // fixture has 3 pending + 1 in-progress => 4 pending after resume.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let config = CheckpointConfig::default();
    let manager = CheckpointManager::new(storage, config, "test-job".to_string());
    let mut checkpoint = create_test_checkpoint_with_work_items("test-job");
    checkpoint.work_item_state.in_progress_items.insert(
        "item-3".to_string(),
        WorkItemProgress {
            work_item: WorkItem {
                id: "item-3".to_string(),
                data: Value::String("test3".to_string()),
            },
            agent_id: "agent-1".to_string(),
            started_at: Utc::now(),
            last_update: Utc::now(),
        },
    );
    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::Manual)
        .await
        .unwrap();
    let resume_state = manager
        .resume_from_checkpoint_with_strategy(
            Some(id.clone()),
            ResumeStrategy::ValidateAndContinue,
        )
        .await
        .unwrap();
    assert!(resume_state.work_items.in_progress_items.is_empty());
    assert_eq!(resume_state.work_items.pending_items.len(), 4);
}
#[tokio::test]
async fn test_phase_transition_checkpoint() {
    // Each recorded phase maps to the expected automatic resume strategy:
    // Setup restarts its phase; all other phases continue from checkpoint
    // (the fixture has no in-progress items, so Map also continues).
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let manager =
        CheckpointManager::new(storage, CheckpointConfig::default(), "test-job".to_string());
    for phase in [
        PhaseType::Setup,
        PhaseType::Map,
        PhaseType::Reduce,
        PhaseType::Complete,
    ] {
        let mut checkpoint = create_test_checkpoint("test-job");
        checkpoint.metadata.phase = phase;
        checkpoint.execution_state.current_phase = phase;
        let id = manager
            .create_checkpoint(&checkpoint, CheckpointReason::PhaseTransition)
            .await
            .unwrap();
        let resume = manager.resume_from_checkpoint(Some(id)).await.unwrap();
        match phase {
            PhaseType::Setup => {
                assert!(matches!(
                    resume.resume_strategy,
                    ResumeStrategy::RestartCurrentPhase
                ));
            }
            PhaseType::Map | PhaseType::Reduce | PhaseType::Complete => {
                assert!(matches!(
                    resume.resume_strategy,
                    ResumeStrategy::ContinueFromCheckpoint
                ));
            }
        }
    }
}
#[tokio::test]
async fn test_failed_work_items() {
    // Failed items, DLQ entries, and the error count must survive a
    // checkpoint save/resume round trip.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let config = CheckpointConfig::default();
    let manager = CheckpointManager::new(storage, config, "test-job".to_string());
    let mut checkpoint = create_test_checkpoint("test-job");
    checkpoint.work_item_state.failed_items = vec![FailedWorkItem {
        work_item: WorkItem {
            id: "failed-1".to_string(),
            data: Value::String("data".to_string()),
        },
        error: "Processing failed".to_string(),
        failed_at: Utc::now(),
        retry_count: 2,
    }];
    checkpoint.error_state.dlq_items = vec![DlqItem {
        item_id: "dlq-1".to_string(),
        error: "DLQ error".to_string(),
        timestamp: Utc::now(),
        retry_count: 1,
    }];
    checkpoint.error_state.error_count = 2;
    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::ErrorRecovery)
        .await
        .unwrap();
    let resume = manager.resume_from_checkpoint(Some(id)).await.unwrap();
    assert_eq!(resume.checkpoint.work_item_state.failed_items.len(), 1);
    assert_eq!(resume.checkpoint.error_state.dlq_items.len(), 1);
    assert_eq!(resume.checkpoint.error_state.error_count, 2);
}
#[tokio::test]
async fn test_retention_policy_max_age() {
    // With a 5-day max age: the 10-day-old non-final checkpoint is
    // condemned, the 2-day-old one is within the window, and the
    // 15-day-old final one is protected by keep_final.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let config = CheckpointConfig::default();
    let manager = CheckpointManager::new(storage, config, "test-job".to_string());
    let checkpoints = vec![
        CheckpointInfo {
            id: "old-1".to_string(),
            job_id: "test-job".to_string(),
            created_at: Utc::now() - chrono::Duration::days(10),
            phase: PhaseType::Map,
            completed_items: 5,
            total_items: 10,
            is_final: false,
        },
        CheckpointInfo {
            id: "recent-1".to_string(),
            job_id: "test-job".to_string(),
            created_at: Utc::now() - chrono::Duration::days(2),
            phase: PhaseType::Map,
            completed_items: 7,
            total_items: 10,
            is_final: false,
        },
        CheckpointInfo {
            id: "final-old".to_string(),
            job_id: "test-job".to_string(),
            created_at: Utc::now() - chrono::Duration::days(15),
            phase: PhaseType::Complete,
            completed_items: 10,
            total_items: 10,
            is_final: true,
        },
    ];
    let policy = RetentionPolicy {
        max_checkpoints: None,
        // 5 days expressed in seconds.
        max_age: Some(Duration::from_secs(5 * 24 * 3600)),
        keep_final: true,
    };
    let to_delete = manager.select_checkpoints_for_deletion(&checkpoints, &policy);
    assert_eq!(to_delete.len(), 1);
    assert_eq!(to_delete[0].as_str(), "old-1");
}
#[tokio::test]
async fn test_checkpoint_storage_not_found() {
    // Loading an ID that was never saved must surface a "not found" error.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = FileCheckpointStorage::new(temp_dir.path().to_path_buf(), false);

    let missing_id = CheckpointId::from_string("non-existent".to_string());
    let err = storage
        .load_checkpoint(&missing_id)
        .await
        .expect_err("loading a missing checkpoint should fail");
    assert!(err.to_string().contains("not found"));
}
#[tokio::test]
async fn test_checkpoint_exists() {
    // `checkpoint_exists` must track the full save/delete lifecycle.
    let temp_dir = tempfile::tempdir().unwrap();
    // Fix: no CheckpointManager consumes this storage, so boxing it was
    // unnecessary — use the concrete type directly, matching the style of
    // test_checkpoint_storage_not_found.
    let storage = FileCheckpointStorage::new(temp_dir.path().to_path_buf(), false);

    let checkpoint = create_test_checkpoint("test-job");
    let id = CheckpointId::from_string(checkpoint.metadata.checkpoint_id.clone());

    // Not present before saving.
    assert!(!storage.checkpoint_exists(&id).await.unwrap());
    // Present after saving.
    storage.save_checkpoint(&checkpoint).await.unwrap();
    assert!(storage.checkpoint_exists(&id).await.unwrap());
    // Gone again after deletion.
    storage.delete_checkpoint(&id).await.unwrap();
    assert!(!storage.checkpoint_exists(&id).await.unwrap());
}
#[tokio::test]
async fn test_restart_from_map_phase_strategy() {
    // Resuming with RestartFromMapPhase should reset all work-item progress:
    // every item (including previously failed ones) returns to pending, and
    // the completed/in-progress sets are emptied.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let config = CheckpointConfig::default();
    let manager = CheckpointManager::new(storage, config, "test-job".to_string());

    let mut checkpoint = create_test_checkpoint_with_work_items("test-job");
    checkpoint
        .work_item_state
        .failed_items
        .push(FailedWorkItem {
            work_item: WorkItem {
                id: "failed-item".to_string(),
                data: Value::String("failed".to_string()),
            },
            error: "Test error".to_string(),
            failed_at: Utc::now(),
            retry_count: 1,
        });

    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::Manual)
        .await
        .unwrap();
    let resume_state = manager
        .resume_from_checkpoint_with_strategy(Some(id), ResumeStrategy::RestartFromMapPhase)
        .await
        .unwrap();

    // Fix: the first two assertions were jammed onto one line.
    // NOTE(review): the expected 5 pending items presumably reflects what
    // create_test_checkpoint_with_work_items builds — confirm against the helper.
    assert_eq!(resume_state.work_items.pending_items.len(), 5);
    assert!(resume_state.work_items.completed_items.is_empty());
    assert!(resume_state.work_items.in_progress_items.is_empty());
}
#[tokio::test]
async fn test_resource_state_tracking() {
    // Agent limits, worktree bookkeeping, and disk usage recorded in
    // ResourceState must round-trip through a checkpoint unchanged.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let config = CheckpointConfig::default();
    let manager = CheckpointManager::new(storage, config, "test-job".to_string());

    let mut checkpoint = create_test_checkpoint("test-job");
    // Fix: the last field and the closing brace were jammed onto one line.
    checkpoint.resource_state = ResourceState {
        total_agents_allowed: 20,
        current_agents_active: 5,
        worktrees_created: vec!["wt1".to_string(), "wt2".to_string(), "wt3".to_string()],
        worktrees_cleaned: vec!["wt1".to_string()],
        disk_usage_bytes: Some(1024 * 1024 * 100), // 100 MiB
    };

    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::Interval)
        .await
        .unwrap();
    let resume = manager.resume_from_checkpoint(Some(id)).await.unwrap();

    assert_eq!(resume.resources.total_agents_allowed, 20);
    assert_eq!(resume.resources.current_agents_active, 5);
    assert_eq!(resume.resources.worktrees_created.len(), 3);
    assert_eq!(resume.resources.worktrees_cleaned.len(), 1);
    assert_eq!(resume.resources.disk_usage_bytes, Some(1024 * 1024 * 100));
}
#[tokio::test]
async fn test_variable_state_preservation() {
    // Workflow variables, captured outputs, environment variables, and
    // per-item variables must all round-trip through a checkpoint.
    let temp_dir = tempfile::tempdir().unwrap();
    let manager = CheckpointManager::new(
        Box::new(FileCheckpointStorage::new(
            temp_dir.path().to_path_buf(),
            false,
        )),
        CheckpointConfig::default(),
        "test-job".to_string(),
    );

    let mut checkpoint = create_test_checkpoint("test-job");
    {
        // Borrow the variable state once instead of re-navigating per insert.
        let vars = &mut checkpoint.variable_state;
        vars.workflow_variables
            .insert("output_dir".to_string(), "/tmp/output".to_string());
        vars.captured_outputs
            .insert("command_1".to_string(), "Success".to_string());
        vars.environment_variables
            .insert("PRODIGY_MODE".to_string(), "test".to_string());
        let mut item_vars = HashMap::new();
        item_vars.insert("path".to_string(), "/src/file.rs".to_string());
        vars.item_variables.insert("item-1".to_string(), item_vars);
    }

    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::Manual)
        .await
        .unwrap();
    let resume = manager.resume_from_checkpoint(Some(id)).await.unwrap();

    assert_eq!(
        resume.variables.workflow_variables.get("output_dir"),
        Some(&"/tmp/output".to_string())
    );
    assert_eq!(
        resume.variables.captured_outputs.get("command_1"),
        Some(&"Success".to_string())
    );
    assert_eq!(
        resume.variables.environment_variables.get("PRODIGY_MODE"),
        Some(&"test".to_string())
    );
    assert!(resume.variables.item_variables.contains_key("item-1"));
}
#[tokio::test]
async fn test_batch_tracking() {
    // The in-flight batch (id + member items) must survive a checkpoint
    // round-trip.
    let temp_dir = tempfile::tempdir().unwrap();
    let manager = CheckpointManager::new(
        Box::new(FileCheckpointStorage::new(
            temp_dir.path().to_path_buf(),
            false,
        )),
        CheckpointConfig::default(),
        "test-job".to_string(),
    );

    let mut checkpoint = create_test_checkpoint("test-job");
    let batch_items: Vec<String> = ["item-1", "item-2", "item-3"]
        .iter()
        .map(|s| s.to_string())
        .collect();
    checkpoint.work_item_state.current_batch = Some(WorkItemBatch {
        batch_id: "batch-001".to_string(),
        items: batch_items,
        started_at: Utc::now(),
    });

    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::BatchComplete)
        .await
        .unwrap();
    let resume = manager.resume_from_checkpoint(Some(id)).await.unwrap();

    let batch = resume.checkpoint.work_item_state.current_batch.unwrap();
    assert_eq!(batch.batch_id, "batch-001");
    assert_eq!(batch.items.len(), 3);
}
#[tokio::test]
async fn test_map_phase_results() {
    // Aggregated map-phase results (success/failure counts and duration)
    // must round-trip through a checkpoint unchanged.
    let temp_dir = tempfile::tempdir().unwrap();
    let manager = CheckpointManager::new(
        Box::new(FileCheckpointStorage::new(
            temp_dir.path().to_path_buf(),
            false,
        )),
        CheckpointConfig::default(),
        "test-job".to_string(),
    );

    let mut checkpoint = create_test_checkpoint("test-job");
    let results = MapPhaseResults {
        successful_count: 42,
        failed_count: 3,
        total_duration: Duration::from_secs(120),
    };
    checkpoint.execution_state.map_results = Some(results);

    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::PhaseTransition)
        .await
        .unwrap();
    let resume = manager.resume_from_checkpoint(Some(id)).await.unwrap();

    let map_results = resume.checkpoint.execution_state.map_results.unwrap();
    assert_eq!(map_results.successful_count, 42);
    assert_eq!(map_results.failed_count, 3);
    assert_eq!(map_results.total_duration, Duration::from_secs(120));
}
#[tokio::test]
async fn test_error_threshold_state() {
    // The error-threshold flag, error count, and last error message must
    // round-trip through a checkpoint.
    let temp_dir = tempfile::tempdir().unwrap();
    let manager = CheckpointManager::new(
        Box::new(FileCheckpointStorage::new(
            temp_dir.path().to_path_buf(),
            false,
        )),
        CheckpointConfig::default(),
        "test-job".to_string(),
    );

    let mut checkpoint = create_test_checkpoint("test-job");
    checkpoint.error_state = ErrorState {
        error_count: 10,
        dlq_items: vec![],
        error_threshold_reached: true,
        last_error: Some("Critical error occurred".to_string()),
    };

    let id = manager
        .create_checkpoint(&checkpoint, CheckpointReason::ErrorRecovery)
        .await
        .unwrap();
    let restored = manager
        .resume_from_checkpoint(Some(id))
        .await
        .unwrap()
        .checkpoint
        .error_state;

    assert_eq!(restored.error_count, 10);
    assert!(restored.error_threshold_reached);
    assert_eq!(
        restored.last_error,
        Some("Critical error occurred".to_string())
    );
}
#[tokio::test]
async fn test_find_latest_checkpoint_empty() {
    // With no checkpoints saved for the job, find_latest_checkpoint returns None.
    let temp_dir = tempfile::tempdir().unwrap();
    let manager = CheckpointManager::new(
        Box::new(FileCheckpointStorage::new(
            temp_dir.path().to_path_buf(),
            false,
        )),
        CheckpointConfig::default(),
        "empty-job".to_string(),
    );
    assert!(manager.find_latest_checkpoint().await.unwrap().is_none());
}
#[tokio::test]
async fn test_find_latest_checkpoint_multiple() {
    // With several checkpoints saved for the same job, find_latest_checkpoint
    // must return one of them.
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = Box::new(FileCheckpointStorage::new(
        temp_dir.path().to_path_buf(),
        false,
    ));
    let config = CheckpointConfig::default();
    let manager = CheckpointManager::new(storage, config, "test-job".to_string());

    let mut checkpoint_ids = Vec::new();
    for i in 0..3 {
        let mut checkpoint = create_test_checkpoint("test-job");
        // NOTE(review): this backdating appears ineffective — create_checkpoint
        // overwrites metadata.created_at with Utc::now() before saving, so the
        // stored checkpoints are ordered by creation call order regardless.
        // Confirm whether the offset was ever meant to influence ordering.
        checkpoint.metadata.created_at = Utc::now() - chrono::Duration::seconds(10 - i);
        let id = manager
            .create_checkpoint(&checkpoint, CheckpointReason::Interval)
            .await
            .unwrap();
        checkpoint_ids.push(id);
    }

    let latest = manager.find_latest_checkpoint().await.unwrap().unwrap();
    assert!(!latest.as_str().is_empty());
    // Fix: checkpoint_ids was collected but never used, and the test only
    // checked that the returned ID was non-empty. Assert the result is
    // actually one of the checkpoints we just created.
    assert!(
        checkpoint_ids
            .iter()
            .any(|id| id.as_str() == latest.as_str()),
        "latest checkpoint should be one of the created checkpoints"
    );
}
}