use std::collections::HashMap;
use std::path::Path;
use serde::{Deserialize, Serialize};
use crate::WorkflowError;
use crate::context::WorkflowContext;
use crate::step::StepOutput;
/// Record of the workflow steps that have finished, together with the output
/// each one produced.
///
/// The type derives `Serialize` and `Deserialize`, which `save` and `load`
/// rely on to persist it as JSON.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Output of every completed step, keyed by the step's id.
    completed: HashMap<String, StepOutput>,
}
impl Checkpoint {
    /// Creates an empty checkpoint with no completed steps.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `output` as the result of `step_id`.
    ///
    /// If the step was already marked as completed, its stored output is
    /// replaced.
    pub fn mark_completed(&mut self, step_id: &str, output: StepOutput) {
        self.completed.insert(step_id.to_owned(), output);
    }

    /// Returns `true` if an output has been recorded for `step_id`.
    pub fn is_completed(&self, step_id: &str) -> bool {
        self.completed.contains_key(step_id)
    }

    /// Returns the ids of all completed steps.
    ///
    /// The ids come straight from a `HashMap`, so their order is unspecified;
    /// sort the result if a stable order is needed.
    pub fn completed_ids(&self) -> Vec<&str> {
        self.completed.keys().map(String::as_str).collect()
    }

    /// Returns the recorded output of `step_id`, or `None` if the step has
    /// not been marked as completed.
    pub fn output(&self, step_id: &str) -> Option<&StepOutput> {
        self.completed.get(step_id)
    }

    /// Consumes the checkpoint and builds a fresh [`WorkflowContext`] that
    /// holds every recorded step output.
    pub fn into_context(self) -> WorkflowContext {
        let mut ctx = WorkflowContext::new();
        for (id, output) in self.completed {
            ctx.set_output(&id, output);
        }
        ctx
    }

    /// Serializes the checkpoint as pretty-printed JSON and stores it at `path`.
    ///
    /// The JSON is first written to a sibling staging file (`path` with
    /// `.tmp` appended) and then renamed over `path`. A failure or
    /// interruption while the data is being written therefore leaves any
    /// existing checkpoint at `path` untouched instead of truncating it.
    /// Concurrent saves to the same `path` share the staging file and are not
    /// coordinated.
    ///
    /// # Errors
    ///
    /// Returns [`WorkflowError::Checkpoint`] if serialization fails, or if the
    /// staging file cannot be written or moved into place.
    pub async fn save(&self, path: &Path) -> Result<(), WorkflowError> {
        let json = serde_json::to_string_pretty(self).map_err(|e| WorkflowError::Checkpoint {
            message: format!("serialize failed: {e}"),
        })?;

        // Stage next to the destination so the final rename does not have to
        // cross a filesystem boundary.
        let mut staging = path.as_os_str().to_owned();
        staging.push(".tmp");
        let staging = std::path::PathBuf::from(staging);

        if let Err(e) = tokio::fs::write(&staging, json).await {
            // Best-effort cleanup of a partially written staging file; the
            // write error is the one worth reporting.
            let _ = tokio::fs::remove_file(&staging).await;
            return Err(WorkflowError::Checkpoint {
                message: format!("write failed: {e}"),
            });
        }

        // From the caller's point of view the rename is still part of the
        // write, so it reports under the same message prefix.
        if let Err(e) = tokio::fs::rename(&staging, path).await {
            let _ = tokio::fs::remove_file(&staging).await;
            return Err(WorkflowError::Checkpoint {
                message: format!("write failed: {e}"),
            });
        }

        Ok(())
    }

    /// Reads the file at `path` and deserializes it as a JSON checkpoint.
    ///
    /// # Errors
    ///
    /// Returns [`WorkflowError::Checkpoint`] if the file cannot be read or its
    /// contents are not a valid serialized checkpoint.
    pub async fn load(path: &Path) -> Result<Self, WorkflowError> {
        let json = tokio::fs::read_to_string(path)
            .await
            .map_err(|e| WorkflowError::Checkpoint {
                message: format!("read failed: {e}"),
            })?;

        serde_json::from_str(&json).map_err(|e| WorkflowError::Checkpoint {
            message: format!("deserialize failed: {e}"),
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `Checkpoint`: in-memory bookkeeping, JSON round-trips,
    //! file persistence, and conversion into a `WorkflowContext`.

    use super::*;

    #[test]
    fn checkpoint_stores_completed_steps() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("A result"));
        checkpoint.mark_completed("b", StepOutput::new("B result"));

        assert!(checkpoint.is_completed("a"));
        assert!(checkpoint.is_completed("b"));
        assert!(!checkpoint.is_completed("c"));
    }

    #[test]
    fn checkpoint_returns_completed_ids() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("x", StepOutput::new("X"));
        checkpoint.mark_completed("y", StepOutput::new("Y"));

        // The ids come out of a HashMap in arbitrary order, so sort first.
        let mut ids = checkpoint.completed_ids();
        ids.sort_unstable();
        assert_eq!(ids, vec!["x", "y"]);
    }

    #[test]
    fn checkpoint_returns_output_for_completed_step() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("hello"));

        let output = checkpoint.output("a").unwrap();
        assert_eq!(output.value(), "hello");
        assert!(checkpoint.output("missing").is_none());
    }

    #[test]
    fn checkpoint_serializes_to_json() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("result_a"));

        let json = serde_json::to_string(&checkpoint).unwrap();
        assert!(json.contains("result_a"));
    }

    #[test]
    fn checkpoint_roundtrips_through_json() {
        let mut original = Checkpoint::new();
        original.mark_completed("a", StepOutput::new("A"));
        original.mark_completed("b", StepOutput::new("B"));

        let json = serde_json::to_string(&original).unwrap();
        let restored: Checkpoint = serde_json::from_str(&json).unwrap();

        assert!(restored.is_completed("a"));
        assert!(restored.is_completed("b"));
        assert_eq!(restored.output("a").unwrap().value(), "A");
        assert_eq!(restored.output("b").unwrap().value(), "B");
    }

    #[tokio::test]
    async fn saves_and_loads_from_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");

        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("step_1", StepOutput::new("output_1"));
        checkpoint.mark_completed("step_2", StepOutput::new("output_2"));
        checkpoint.save(&path).await.unwrap();
        assert!(path.exists());

        let loaded = Checkpoint::load(&path).await.unwrap();
        assert!(loaded.is_completed("step_1"));
        assert!(loaded.is_completed("step_2"));
        assert_eq!(loaded.output("step_1").unwrap().value(), "output_1");
        assert_eq!(loaded.output("step_2").unwrap().value(), "output_2");
    }

    #[tokio::test]
    async fn save_overwrites_existing_checkpoint() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");

        let mut first = Checkpoint::new();
        first.mark_completed("old", StepOutput::new("old_output"));
        first.save(&path).await.unwrap();

        let mut second = Checkpoint::new();
        second.mark_completed("new", StepOutput::new("new_output"));
        second.save(&path).await.unwrap();

        // Only the most recent save must be visible afterwards.
        let loaded = Checkpoint::load(&path).await.unwrap();
        assert!(loaded.is_completed("new"));
        assert!(!loaded.is_completed("old"));
    }

    #[tokio::test]
    async fn load_returns_error_for_missing_file() {
        // A path inside a fresh temporary directory is guaranteed not to
        // exist, unlike a hard-coded location under /tmp.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nonexistent_ckpt.json");

        let result = Checkpoint::load(&path).await;
        assert!(matches!(result, Err(WorkflowError::Checkpoint { .. })));
    }

    #[tokio::test]
    async fn load_returns_error_for_invalid_json() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");
        tokio::fs::write(&path, "not json").await.unwrap();

        let result = Checkpoint::load(&path).await;
        assert!(matches!(result, Err(WorkflowError::Checkpoint { .. })));
    }

    #[test]
    fn converts_to_workflow_context() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("A"));
        checkpoint.mark_completed("b", StepOutput::new("B"));

        let ctx = checkpoint.into_context();
        assert!(ctx.is_completed("a"));
        assert!(ctx.is_completed("b"));
        assert_eq!(ctx.output("a").unwrap().value(), "A");
        assert_eq!(ctx.output("b").unwrap().value(), "B");
    }
}