erio_workflow/
checkpoint.rs1use std::collections::HashMap;
4use std::path::Path;
5
6use serde::{Deserialize, Serialize};
7
8use crate::WorkflowError;
9use crate::context::WorkflowContext;
10use crate::step::StepOutput;
11
12#[derive(Debug, Clone, Serialize, Deserialize, Default)]
17pub struct Checkpoint {
18 completed: HashMap<String, StepOutput>,
19}
20
21impl Checkpoint {
22 pub fn new() -> Self {
24 Self::default()
25 }
26
27 pub fn mark_completed(&mut self, step_id: &str, output: StepOutput) {
29 self.completed.insert(step_id.into(), output);
30 }
31
32 pub fn is_completed(&self, step_id: &str) -> bool {
34 self.completed.contains_key(step_id)
35 }
36
37 pub fn completed_ids(&self) -> Vec<&str> {
39 self.completed.keys().map(String::as_str).collect()
40 }
41
42 pub fn output(&self, step_id: &str) -> Option<&StepOutput> {
44 self.completed.get(step_id)
45 }
46
47 pub fn into_context(self) -> WorkflowContext {
49 let mut ctx = WorkflowContext::new();
50 for (id, output) in self.completed {
51 ctx.set_output(&id, output);
52 }
53 ctx
54 }
55
56 pub async fn save(&self, path: &Path) -> Result<(), WorkflowError> {
58 let json = serde_json::to_string_pretty(self).map_err(|e| WorkflowError::Checkpoint {
59 message: format!("serialize failed: {e}"),
60 })?;
61 tokio::fs::write(path, json)
62 .await
63 .map_err(|e| WorkflowError::Checkpoint {
64 message: format!("write failed: {e}"),
65 })
66 }
67
68 pub async fn load(path: &Path) -> Result<Self, WorkflowError> {
70 let json =
71 tokio::fs::read_to_string(path)
72 .await
73 .map_err(|e| WorkflowError::Checkpoint {
74 message: format!("read failed: {e}"),
75 })?;
76 serde_json::from_str(&json).map_err(|e| WorkflowError::Checkpoint {
77 message: format!("deserialize failed: {e}"),
78 })
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn checkpoint_stores_completed_steps() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("A result"));
        checkpoint.mark_completed("b", StepOutput::new("B result"));

        assert!(checkpoint.is_completed("a"));
        assert!(checkpoint.is_completed("b"));
        assert!(!checkpoint.is_completed("c"));
    }

    #[test]
    fn checkpoint_returns_completed_ids() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("x", StepOutput::new("X"));
        checkpoint.mark_completed("y", StepOutput::new("Y"));

        let mut ids = checkpoint.completed_ids();
        ids.sort_unstable();
        assert_eq!(ids, vec!["x", "y"]);
    }

    #[test]
    fn checkpoint_returns_output_for_completed_step() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("hello"));

        let output = checkpoint.output("a").unwrap();
        assert_eq!(output.value(), "hello");
    }

    #[test]
    fn checkpoint_serializes_to_json() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("result_a"));

        let json = serde_json::to_string(&checkpoint).unwrap();
        assert!(json.contains("result_a"));
    }

    #[test]
    fn checkpoint_roundtrips_through_json() {
        let mut original = Checkpoint::new();
        original.mark_completed("a", StepOutput::new("A"));
        original.mark_completed("b", StepOutput::new("B"));

        let json = serde_json::to_string(&original).unwrap();
        let restored: Checkpoint = serde_json::from_str(&json).unwrap();

        assert!(restored.is_completed("a"));
        assert!(restored.is_completed("b"));
        assert_eq!(restored.output("a").unwrap().value(), "A");
        assert_eq!(restored.output("b").unwrap().value(), "B");
    }

    #[tokio::test]
    async fn saves_and_loads_from_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");

        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("step_1", StepOutput::new("output_1"));
        checkpoint.mark_completed("step_2", StepOutput::new("output_2"));

        checkpoint.save(&path).await.unwrap();
        assert!(path.exists());

        let loaded = Checkpoint::load(&path).await.unwrap();
        assert!(loaded.is_completed("step_1"));
        assert!(loaded.is_completed("step_2"));
        assert_eq!(loaded.output("step_1").unwrap().value(), "output_1");
    }

    #[tokio::test]
    async fn save_overwrites_existing_checkpoint() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");

        let mut first = Checkpoint::new();
        first.mark_completed("a", StepOutput::new("A"));
        first.save(&path).await.unwrap();

        let mut second = Checkpoint::new();
        second.mark_completed("b", StepOutput::new("B"));
        second.save(&path).await.unwrap();

        let loaded = Checkpoint::load(&path).await.unwrap();
        assert!(!loaded.is_completed("a"));
        assert!(loaded.is_completed("b"));
    }

    #[tokio::test]
    async fn save_leaves_no_stray_files() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");

        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("A"));
        checkpoint.save(&path).await.unwrap();

        let mut names = Vec::new();
        let mut entries = tokio::fs::read_dir(dir.path()).await.unwrap();
        while let Some(entry) = entries.next_entry().await.unwrap() {
            names.push(entry.file_name());
        }
        assert_eq!(names, vec![std::ffi::OsString::from("checkpoint.json")]);
    }

    #[tokio::test]
    async fn load_returns_error_for_missing_file() {
        // A path inside a fresh temp dir is guaranteed not to exist, unlike a
        // fixed global location such as /tmp that other processes may touch.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nonexistent_ckpt.json");

        let result = Checkpoint::load(&path).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn load_returns_error_for_malformed_json() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");
        tokio::fs::write(&path, "not json").await.unwrap();

        let result = Checkpoint::load(&path).await;
        assert!(result.is_err());
    }

    #[test]
    fn converts_to_workflow_context() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.mark_completed("a", StepOutput::new("A"));
        checkpoint.mark_completed("b", StepOutput::new("B"));

        let ctx = checkpoint.into_context();

        assert!(ctx.is_completed("a"));
        assert!(ctx.is_completed("b"));
        assert_eq!(ctx.output("a").unwrap().value(), "A");
    }
}