Skip to main content

erio_workflow/
checkpoint.rs

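//! Workflow checkpointing and recovery.
//!
//! A [`Checkpoint`] records the output of each completed workflow step and can
//! be saved to, and loaded from, a JSON file so an interrupted workflow can be
//! resumed instead of restarted.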
2
3use std::collections::HashMap;
4use std::path::Path;
5
6use serde::{Deserialize, Serialize};
7
8use crate::WorkflowError;
9use crate::context::WorkflowContext;
10use crate::step::StepOutput;
11
12/// A serializable snapshot of workflow progress.
13///
14/// Stores which steps have completed and their outputs, allowing
15/// a workflow to be resumed from the last checkpoint after a crash.
16#[derive(Debug, Clone, Serialize, Deserialize, Default)]
17pub struct Checkpoint {
18    completed: HashMap<String, StepOutput>,
19}
20
21impl Checkpoint {
22    /// Creates an empty checkpoint.
23    pub fn new() -> Self {
24        Self::default()
25    }
26
27    /// Marks a step as completed with its output.
28    pub fn mark_completed(&mut self, step_id: &str, output: StepOutput) {
29        self.completed.insert(step_id.into(), output);
30    }
31
32    /// Returns `true` if the step has been completed.
33    pub fn is_completed(&self, step_id: &str) -> bool {
34        self.completed.contains_key(step_id)
35    }
36
37    /// Returns the IDs of all completed steps.
38    pub fn completed_ids(&self) -> Vec<&str> {
39        self.completed.keys().map(String::as_str).collect()
40    }
41
42    /// Returns the output of a completed step.
43    pub fn output(&self, step_id: &str) -> Option<&StepOutput> {
44        self.completed.get(step_id)
45    }
46
47    /// Converts this checkpoint into a `WorkflowContext` for resuming execution.
48    pub fn into_context(self) -> WorkflowContext {
49        let mut ctx = WorkflowContext::new();
50        for (id, output) in self.completed {
51            ctx.set_output(&id, output);
52        }
53        ctx
54    }
55
56    /// Saves the checkpoint to a JSON file.
57    pub async fn save(&self, path: &Path) -> Result<(), WorkflowError> {
58        let json = serde_json::to_string_pretty(self).map_err(|e| WorkflowError::Checkpoint {
59            message: format!("serialize failed: {e}"),
60        })?;
61        tokio::fs::write(path, json)
62            .await
63            .map_err(|e| WorkflowError::Checkpoint {
64                message: format!("write failed: {e}"),
65            })
66    }
67
68    /// Loads a checkpoint from a JSON file.
69    pub async fn load(path: &Path) -> Result<Self, WorkflowError> {
70        let json =
71            tokio::fs::read_to_string(path)
72                .await
73                .map_err(|e| WorkflowError::Checkpoint {
74                    message: format!("read failed: {e}"),
75                })?;
76        serde_json::from_str(&json).map_err(|e| WorkflowError::Checkpoint {
77            message: format!("deserialize failed: {e}"),
78        })
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    // === Checkpoint Data Tests ===

    #[test]
    fn checkpoint_stores_completed_steps() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("A result"));
        checkpoint.mark_completed("b", StepOutput::new("B result"));

        assert!(checkpoint.is_completed("a"));
        assert!(checkpoint.is_completed("b"));
        assert!(!checkpoint.is_completed("c"));
    }

    #[test]
    fn checkpoint_returns_completed_ids() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("x", StepOutput::new("X"));
        checkpoint.mark_completed("y", StepOutput::new("Y"));

        // `completed_ids` has no guaranteed order, so sort before comparing.
        let mut ids = checkpoint.completed_ids();
        ids.sort_unstable();
        assert_eq!(ids, vec!["x", "y"]);
    }

    #[test]
    fn checkpoint_returns_output_for_completed_step() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("hello"));

        let output = checkpoint.output("a").unwrap();
        assert_eq!(output.value(), "hello");
        assert!(checkpoint.output("missing").is_none());
    }

    #[test]
    fn mark_completed_replaces_previous_output() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("first"));
        checkpoint.mark_completed("a", StepOutput::new("second"));

        assert_eq!(checkpoint.completed_ids(), vec!["a"]);
        assert_eq!(checkpoint.output("a").unwrap().value(), "second");
    }

    // === Serialization Tests ===

    #[test]
    fn checkpoint_serializes_to_json() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("result_a"));

        let json = serde_json::to_string(&checkpoint).unwrap();
        assert!(json.contains("result_a"));
    }

    #[test]
    fn checkpoint_roundtrips_through_json() {
        let mut original = Checkpoint::new();
        original.mark_completed("a", StepOutput::new("A"));
        original.mark_completed("b", StepOutput::new("B"));

        let json = serde_json::to_string(&original).unwrap();
        let restored: Checkpoint = serde_json::from_str(&json).unwrap();

        assert!(restored.is_completed("a"));
        assert!(restored.is_completed("b"));
        assert_eq!(restored.output("a").unwrap().value(), "A");
        assert_eq!(restored.output("b").unwrap().value(), "B");
    }

    // === File Persistence Tests ===

    #[tokio::test]
    async fn saves_and_loads_from_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");

        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("step_1", StepOutput::new("output_1"));
        checkpoint.mark_completed("step_2", StepOutput::new("output_2"));

        checkpoint.save(&path).await.unwrap();
        assert!(path.exists());

        let loaded = Checkpoint::load(&path).await.unwrap();
        assert!(loaded.is_completed("step_1"));
        assert!(loaded.is_completed("step_2"));
        assert_eq!(loaded.output("step_1").unwrap().value(), "output_1");
        assert_eq!(loaded.output("step_2").unwrap().value(), "output_2");
    }

    #[tokio::test]
    async fn save_overwrites_existing_checkpoint() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");

        let mut first = Checkpoint::new();
        first.mark_completed("old", StepOutput::new("stale"));
        first.save(&path).await.unwrap();

        let mut second = Checkpoint::new();
        second.mark_completed("new", StepOutput::new("fresh"));
        second.save(&path).await.unwrap();

        let loaded = Checkpoint::load(&path).await.unwrap();
        assert!(!loaded.is_completed("old"));
        assert_eq!(loaded.output("new").unwrap().value(), "fresh");
    }

    #[tokio::test]
    async fn save_leaves_only_the_checkpoint_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");

        Checkpoint::new().save(&path).await.unwrap();

        // No temporary or partial files may be left next to the checkpoint.
        let entries: Vec<std::ffi::OsString> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(entries, vec![std::ffi::OsString::from("checkpoint.json")]);
    }

    #[tokio::test]
    async fn load_returns_error_for_missing_file() {
        // A fresh temp dir guarantees the path does not exist on any platform,
        // unlike a hard-coded location such as `/tmp/...`.
        let dir = tempfile::tempdir().unwrap();
        let result = Checkpoint::load(&dir.path().join("missing.json")).await;
        assert!(matches!(result, Err(WorkflowError::Checkpoint { .. })));
    }

    #[tokio::test]
    async fn load_returns_error_for_corrupt_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");
        tokio::fs::write(&path, "{ not valid json").await.unwrap();

        let result = Checkpoint::load(&path).await;
        assert!(matches!(result, Err(WorkflowError::Checkpoint { .. })));
    }

    // === Integration with WorkflowContext ===

    #[test]
    fn converts_to_workflow_context() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("A"));
        checkpoint.mark_completed("b", StepOutput::new("B"));

        let ctx = checkpoint.into_context();

        assert!(ctx.is_completed("a"));
        assert!(ctx.is_completed("b"));
        assert_eq!(ctx.output("a").unwrap().value(), "A");
        assert_eq!(ctx.output("b").unwrap().value(), "B");
    }
}