1use std::fmt::Debug;
43
44use serde::{Deserialize, Serialize};
45
46use crate::state::State;
47use crate::state::workflow_state::WorkflowState;
48
/// Unique identifier of a checkpoint, wrapping a UUID.
///
/// [`Checkpoint::new`] assigns a freshly generated random (v4) UUID.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CheckpointId(pub uuid::Uuid);
54
55impl std::fmt::Display for CheckpointId {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 write!(f, "{}", self.0)
58 }
59}
60
61#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
65pub struct NodeId(pub String);
66
67impl std::fmt::Display for NodeId {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 write!(f, "{}", self.0)
70 }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct Checkpoint<S: WorkflowState = State> {
94 pub checkpoint_id: CheckpointId,
96 pub current_node: NodeId,
98 pub state: S::Checkpoint,
100 pub graph_hash: u64,
102 pub created_at: std::time::SystemTime,
104}
105
106impl<S: WorkflowState> Checkpoint<S> {
107 pub fn new(current_node: impl Into<String>, state: &S, graph_hash: u64) -> Self {
109 Self {
110 checkpoint_id: CheckpointId(uuid::Uuid::new_v4()),
111 current_node: NodeId(current_node.into()),
112 state: state.snapshot(),
113 graph_hash,
114 created_at: std::time::SystemTime::now(),
115 }
116 }
117
118 pub fn restore_state(self) -> S {
120 S::restore(self.state)
121 }
122}
123
/// A checkpoint in serialized form, together with its identifying metadata.
///
/// NOTE(review): `data` is presumably the encoded [`Checkpoint`]; the encoding
/// is not visible here — confirm against the store implementation.
#[derive(Debug, Clone)]
pub struct CheckpointBlob {
    /// Identifier of the checkpoint these bytes belong to.
    pub id: CheckpointId,
    /// Opaque serialized payload.
    pub data: Vec<u8>,
    /// Graph hash stored alongside the payload (cf. `Checkpoint::graph_hash`).
    pub graph_hash: u64,
    /// Creation time of the checkpoint.
    pub created_at: std::time::SystemTime,
}
144
145impl CheckpointBlob {
146 pub fn new(
147 id: CheckpointId,
148 data: Vec<u8>,
149 graph_hash: u64,
150 created_at: std::time::SystemTime,
151 ) -> Self {
152 Self {
153 id,
154 data,
155 graph_hash,
156 created_at,
157 }
158 }
159}
160
/// Errors reported by checkpoint storage operations.
#[derive(Debug, thiserror::Error)]
pub enum CheckpointStoreError {
    /// Storage-layer failure, described by the contained message.
    #[error("storage error: {0}")]
    Storage(String),
    /// No checkpoint was found for the given id.
    #[error("checkpoint not found: {0}")]
    NotFound(CheckpointId),
    /// Stored checkpoint data is corrupted; the message gives details.
    #[error("corrupted checkpoint: {0}")]
    Corrupted(String),
    /// A (de)serialization step failed; the message gives details.
    #[error("serialization error: {0}")]
    Serialization(String),
    /// The checkpoint's graph hash differs from the expected one.
    ///
    /// Both hashes are rendered as `0x`-prefixed, zero-padded 16-digit hex
    /// (`#018x`: width 18 includes the `0x` prefix).
    #[error("graph mismatch: expected hash {expected:#018x}, got {actual:#018x}")]
    GraphMismatch { expected: u64, actual: u64 },
}
177
178pub use crate::ids::TraceId;
185
/// One recorded execution position: which graph and node, a snapshot of the
/// workflow state, and a cursor.
///
/// `Debug` is implemented by hand so that it only requires
/// `S::Checkpoint: Debug`.
// NOTE(review): the reason for `allow(deprecated)` is not visible in this
// file — document it or remove the attribute if it is no longer needed.
#[allow(deprecated)]
#[derive(Clone, Serialize, Deserialize)]
pub struct Frame<S: WorkflowState = State> {
    /// Identifier of the graph this frame belongs to (left empty by `MemorySink`).
    pub graph_id: String,

    /// Identifier of the node this frame refers to.
    pub node_id: String,

    /// State snapshot produced by `WorkflowState::snapshot`.
    pub state: S::Checkpoint,

    /// Position counter; `MemorySink` fills it from `FrameInfo::step`.
    pub cursor: usize,
}
210
211impl<S: WorkflowState> Frame<S> {
212 pub fn new(graph_id: String, node_id: String, state: &S, cursor: usize) -> Self {
214 Self {
215 graph_id,
216 node_id,
217 state: state.snapshot(),
218 cursor,
219 }
220 }
221}
222
223impl<S: WorkflowState> Debug for Frame<S>
224where
225 S::Checkpoint: Debug,
226{
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 f.debug_struct("Frame")
229 .field("graph_id", &self.graph_id)
230 .field("node_id", &self.node_id)
231 .field("state", &self.state)
232 .field("cursor", &self.cursor)
233 .finish()
234 }
235}
236
237#[derive(Clone, Serialize, Deserialize)]
241pub struct FrameStack<S: WorkflowState = State>
242where
243 S::Checkpoint: Debug,
244{
245 frames: Vec<Frame<S>>,
247}
248
249impl<S: WorkflowState> FrameStack<S>
250where
251 S::Checkpoint: Debug,
252{
253 pub fn new() -> Self {
255 Self { frames: Vec::new() }
256 }
257
258 pub fn push(&mut self, frame: Frame<S>) {
260 self.frames.push(frame);
261 }
262
263 pub fn pop(&mut self) -> Option<Frame<S>> {
265 self.frames.pop()
266 }
267
268 pub fn current(&self) -> Option<&Frame<S>> {
270 self.frames.last()
271 }
272
273 pub fn depth(&self) -> usize {
275 self.frames.len()
276 }
277
278 pub fn is_empty(&self) -> bool {
280 self.frames.is_empty()
281 }
282
283 pub fn frames(&self) -> &[Frame<S>] {
285 &self.frames
286 }
287}
288
289impl<S: WorkflowState> Default for FrameStack<S>
290where
291 S::Checkpoint: Debug,
292{
293 fn default() -> Self {
294 Self::new()
295 }
296}
297
298impl<S: WorkflowState> Debug for FrameStack<S>
299where
300 S::Checkpoint: Debug,
301{
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 f.debug_struct("FrameStack")
304 .field("frames", &self.frames)
305 .finish()
306 }
307}
308
/// Lightweight description of the execution position passed to
/// [`CheckpointSink::on_checkpoint`].
#[derive(Debug, Clone)]
pub struct FrameInfo {
    /// Node the checkpoint refers to.
    pub node_id: String,
    /// Step counter at the time of the checkpoint.
    // NOTE(review): the tests in this file observe 1 for the first executed
    // node and 2 for the second, i.e. it looks 1-based — confirm in the engine.
    pub step: usize,
}
326
327impl FrameInfo {
328 pub fn new(node_id: impl Into<String>, step: usize) -> Self {
330 Self {
331 node_id: node_id.into(),
332 step,
333 }
334 }
335}
336
/// Receiver for checkpoints emitted while a workflow graph executes.
///
/// Implementors must be `Send + Sync`. Each call receives a borrowed view of
/// the current state plus a [`FrameInfo`] describing the position.
// NOTE(review): judging from the tests in this file, the engine appears to
// call `on_checkpoint` once per executed node — confirm in `ExecutionEngine`.
pub trait CheckpointSink<S: WorkflowState>: Send + Sync {
    /// Called with the current workflow `state` and its position `frame`.
    fn on_checkpoint(&mut self, state: &S, frame: &FrameInfo);
}
362
/// A [`CheckpointSink`] that ignores every checkpoint.
#[derive(Debug, Default)]
pub struct NoopCheckpointSink;
368
369impl<S: WorkflowState> CheckpointSink<S> for NoopCheckpointSink {
370 fn on_checkpoint(&mut self, _state: &S, _frame: &FrameInfo) {
371 }
373}
374
/// A [`CheckpointSink`] that keeps every checkpoint in memory as a [`Frame`].
///
/// `Debug` is implemented by hand so that it only requires
/// `S::Checkpoint: Debug`.
pub struct MemorySink<S: WorkflowState = State> {
    /// Recorded frames, in the order `on_checkpoint` was called.
    pub frames: Vec<Frame<S>>,
}
383
384impl<S: WorkflowState> Debug for MemorySink<S>
385where
386 S::Checkpoint: Debug,
387{
388 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389 f.debug_struct("MemorySink")
390 .field("frames", &self.frames)
391 .finish()
392 }
393}
394
395impl<S: WorkflowState> Default for MemorySink<S> {
396 fn default() -> Self {
397 Self { frames: Vec::new() }
398 }
399}
400
401impl<S: WorkflowState> MemorySink<S> {
402 pub fn new() -> Self {
403 Self::default()
404 }
405
406 pub fn into_frames(self) -> Vec<Frame<S>> {
407 self.frames
408 }
409}
410
411impl<S: WorkflowState> CheckpointSink<S> for MemorySink<S>
412where
413 S::Checkpoint: Sync,
414{
415 fn on_checkpoint(&mut self, state: &S, frame: &FrameInfo) {
416 self.frames.push(Frame {
417 graph_id: String::new(), node_id: frame.node_id.clone(),
419 state: state.snapshot(),
420 cursor: frame.step,
421 });
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::StateMerge;
    use crate::{GraphBuilder, NodeKind, TaskNode};
    use std::sync::Arc;
    use tokio_util::sync::CancellationToken;

    /// Runs the linear graph `a -> b` with a [`MemorySink`] attached and checks
    /// that one frame is recorded per executed node, in execution order.
    #[tokio::test]
    async fn test_auto_checkpoint_via_memory_sink() {
        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
        // Start at `a`, end at `b`; both tasks are no-ops that succeed.
        builder.start("a");
        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
        builder.node("b", NodeKind::Task(TaskNode::new("b", |_| Ok(()))));
        builder.end("b");
        builder.edge("a", "b");
        let graph = Arc::new(builder.build().unwrap());

        let mut sink = MemorySink::<State>::new();

        let mut state = State::new();
        // NOTE(review): the meaning of the positional `None` arguments is not
        // visible here; the sink is passed as the fourth argument.
        let mut engine: crate::ExecutionEngine<'_, State> = crate::ExecutionEngine::new(
            &mut state,
            None,
            CancellationToken::new(),
            Some(&mut sink),
            None,
        );

        let mut cb = crate::graph::NoopStepCallback;
        // `100` is presumably a step limit — confirm against `run_inline`.
        graph.run_inline(&mut engine, 100, &mut cb).await.unwrap();

        // One frame per node; `cursor` is copied from `FrameInfo::step`,
        // which is observed here as 1 then 2.
        assert_eq!(sink.frames.len(), 2);
        assert_eq!(sink.frames[0].node_id, "a");
        assert_eq!(sink.frames[1].node_id, "b");
        assert_eq!(sink.frames[0].cursor, 1);
        assert_eq!(sink.frames[1].cursor, 2);
    }

    /// Smoke test: running a single-node graph with a [`NoopCheckpointSink`]
    /// must succeed. No checkpoint output is asserted.
    #[tokio::test]
    async fn test_noop_checkpoint_sink() {
        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
        // Single node that is both start and end.
        builder.start("a");
        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
        builder.end("a");
        let graph = Arc::new(builder.build().unwrap());

        let mut sink = NoopCheckpointSink;
        let mut state = State::new();
        let mut engine: crate::ExecutionEngine<'_, State> = crate::ExecutionEngine::new(
            &mut state,
            None,
            CancellationToken::new(),
            Some(&mut sink),
            None,
        );

        let mut cb = crate::graph::NoopStepCallback;
        graph.run_inline(&mut engine, 100, &mut cb).await.unwrap();
    }

    /// `FrameInfo::new` stores the node id and step exactly as given.
    #[test]
    fn test_frame_info_minimal() {
        let info = FrameInfo::new("test_node", 42);
        assert_eq!(info.node_id, "test_node");
        assert_eq!(info.step, 42);
    }
}