use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::state::State;
use crate::state::workflow_state::WorkflowState;
/// Unique identifier of a [`Checkpoint`], wrapping a UUID.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CheckpointId(pub uuid::Uuid);

impl std::fmt::Display for CheckpointId {
    /// Writes the wrapped UUID using the UUID's own `Display` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let CheckpointId(inner) = self;
        f.write_fmt(format_args!("{}", inner))
    }
}
/// Name of a node within a workflow graph.
///
/// Derives `Hash` (like [`CheckpointId`]) so node ids can be used as keys in
/// `HashMap`/`HashSet`. `Hash` is derived alongside `PartialEq`/`Eq`, so the
/// two stay consistent.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NodeId(pub String);

impl std::fmt::Display for NodeId {
    /// Formats the raw node name.
    ///
    /// This delegates to `str`'s `Display` instead of `write!(f, "{}", ..)`.
    /// `write!` would start a fresh format spec and silently drop the caller's
    /// width/fill/alignment/precision flags (e.g. `{:>12}`). With delegation,
    /// those flags are honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self.0.as_str(), f)
    }
}
/// A snapshot of a workflow run.
///
/// It records the node at which it was taken, the state snapshot, and the hash
/// of the graph it was taken against.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint<S: WorkflowState = State> {
    /// Random (UUID v4) id assigned by [`Checkpoint::new`].
    pub checkpoint_id: CheckpointId,
    /// Node the checkpoint was taken at.
    pub current_node: NodeId,
    /// State snapshot produced by `WorkflowState::snapshot`.
    pub state: S::Checkpoint,
    /// Hash of the graph this checkpoint belongs to.
    /// NOTE(review): presumably compared on restore to raise
    /// `CheckpointStoreError::GraphMismatch`; the comparison is not in this file.
    pub graph_hash: u64,
    /// Wall-clock time at which the checkpoint was created.
    pub created_at: std::time::SystemTime,
}

impl<S: WorkflowState> Checkpoint<S> {
    /// Captures `state` into a new checkpoint positioned at `current_node`.
    ///
    /// A fresh v4 UUID is generated for the id, and `created_at` is set to
    /// `SystemTime::now()`.
    pub fn new(current_node: impl Into<String>, state: &S, graph_hash: u64) -> Self {
        let checkpoint_id = CheckpointId(uuid::Uuid::new_v4());
        let node = NodeId(current_node.into());
        let snapshot = state.snapshot();
        Self {
            checkpoint_id,
            current_node: node,
            state: snapshot,
            graph_hash,
            created_at: std::time::SystemTime::now(),
        }
    }

    /// Consumes the checkpoint and rebuilds the workflow state from its snapshot
    /// via `WorkflowState::restore`.
    ///
    /// The id, node, hash and timestamp are discarded.
    pub fn restore_state(self) -> S {
        let Self { state, .. } = self;
        S::restore(state)
    }
}
/// Storage-level form of a checkpoint: an opaque byte payload plus metadata.
///
/// NOTE(review): `data` is presumably the serialized `Checkpoint`; the encoding
/// is not visible in this file.
#[derive(Debug, Clone)]
pub struct CheckpointBlob {
    /// Id of the checkpoint this payload belongs to.
    pub id: CheckpointId,
    /// Raw payload bytes.
    pub data: Vec<u8>,
    /// Graph hash stored alongside the payload.
    pub graph_hash: u64,
    /// Creation time stored alongside the payload.
    pub created_at: std::time::SystemTime,
}

impl CheckpointBlob {
    /// Bundles `data` with its metadata. No validation is performed.
    pub fn new(
        id: CheckpointId,
        data: Vec<u8>,
        graph_hash: u64,
        created_at: std::time::SystemTime,
    ) -> Self {
        CheckpointBlob { id, data, graph_hash, created_at }
    }
}
/// Error type for checkpoint-store operations.
///
/// NOTE(review): the store trait that returns this error is not in this chunk.
#[derive(Debug, thiserror::Error)]
pub enum CheckpointStoreError {
/// Backend storage failure, carrying a free-form description.
#[error("storage error: {0}")]
Storage(String),
/// No checkpoint exists for the given id.
#[error("checkpoint not found: {0}")]
NotFound(CheckpointId),
/// A checkpoint was found to be corrupted; the string describes the problem.
#[error("corrupted checkpoint: {0}")]
Corrupted(String),
/// Serialization or deserialization failed; the string carries the details.
#[error("serialization error: {0}")]
Serialization(String),
/// The graph hash did not match the expected one.
/// Both hashes are displayed as `0x`-prefixed, zero-padded 16-digit hex
/// (`#018x`: width 18 includes the `0x`).
#[error("graph mismatch: expected hash {expected:#018x}, got {actual:#018x}")]
GraphMismatch { expected: u64, actual: u64 },
}
pub use crate::ids::TraceId;
/// A position in workflow execution together with a snapshot of the state there.
///
/// Frames are held by `FrameStack` and collected by `MemorySink`.
// NOTE(review): the reason for `allow(deprecated)` is not visible in this file;
// confirm it is still needed and document why.
#[allow(deprecated)]
#[derive(Clone, Serialize, Deserialize)]
pub struct Frame<S: WorkflowState = State> {
    /// Graph identifier (`MemorySink` leaves this empty).
    pub graph_id: String,
    /// Node identifier.
    pub node_id: String,
    /// State snapshot produced by `WorkflowState::snapshot`.
    pub state: S::Checkpoint,
    /// Position counter (`MemorySink` fills it from `FrameInfo::step`).
    pub cursor: usize,
}

impl<S: WorkflowState> Frame<S> {
    /// Builds a frame, snapshotting `state` via `WorkflowState::snapshot`.
    pub fn new(graph_id: String, node_id: String, state: &S, cursor: usize) -> Self {
        let state = state.snapshot();
        Frame { graph_id, node_id, state, cursor }
    }
}
impl<S: WorkflowState> Debug for Frame<S>
where
S::Checkpoint: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Frame")
.field("graph_id", &self.graph_id)
.field("node_id", &self.node_id)
.field("state", &self.state)
.field("cursor", &self.cursor)
.finish()
}
}
/// A LIFO stack of [`Frame`]s.
///
/// The type itself places no `Debug` requirement on `S::Checkpoint`. Only the
/// `Debug` impl below needs that bound, so it is declared there. The type, its
/// constructor and its stack operations stay usable for checkpoint types that
/// do not implement `Debug`.
#[derive(Clone, Serialize, Deserialize)]
pub struct FrameStack<S: WorkflowState = State> {
    /// Bottom of the stack is index 0; the top is the last element.
    frames: Vec<Frame<S>>,
}

impl<S: WorkflowState> FrameStack<S> {
    /// Creates an empty stack.
    pub fn new() -> Self {
        Self { frames: Vec::new() }
    }

    /// Pushes `frame` on top of the stack.
    pub fn push(&mut self, frame: Frame<S>) {
        self.frames.push(frame);
    }

    /// Removes and returns the top frame, or `None` if the stack is empty.
    pub fn pop(&mut self) -> Option<Frame<S>> {
        self.frames.pop()
    }

    /// Returns the top frame without removing it, or `None` if the stack is empty.
    pub fn current(&self) -> Option<&Frame<S>> {
        self.frames.last()
    }

    /// Number of frames currently on the stack.
    pub fn depth(&self) -> usize {
        self.frames.len()
    }

    /// `true` when the stack holds no frames.
    pub fn is_empty(&self) -> bool {
        self.frames.is_empty()
    }

    /// All frames, ordered from bottom (first pushed) to top (last pushed).
    pub fn frames(&self) -> &[Frame<S>] {
        &self.frames
    }
}

impl<S: WorkflowState> Default for FrameStack<S> {
    /// Same as [`FrameStack::new`]: an empty stack.
    fn default() -> Self {
        Self::new()
    }
}

impl<S: WorkflowState> Debug for FrameStack<S>
where
    S::Checkpoint: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FrameStack")
            .field("frames", &self.frames)
            .finish()
    }
}
/// Position information handed to a `CheckpointSink` on each checkpoint.
#[derive(Debug, Clone)]
pub struct FrameInfo {
    /// Node the checkpoint refers to.
    pub node_id: String,
    /// Step number (`MemorySink` stores it as `Frame::cursor`).
    pub step: usize,
}

impl FrameInfo {
    /// Creates a `FrameInfo` for `node_id` at `step`.
    pub fn new(node_id: impl Into<String>, step: usize) -> Self {
        let owned_id: String = node_id.into();
        FrameInfo { node_id: owned_id, step }
    }
}
/// Observer that receives the workflow state and a [`FrameInfo`] through
/// `on_checkpoint`. Implementors must be `Send + Sync`.
///
/// The tests in this file show the execution engine invoking it once per
/// executed node. NOTE(review): confirm this holds for every node kind.
pub trait CheckpointSink<S: WorkflowState>: Send + Sync {
/// Called with a borrow of the current state and the position info for this checkpoint.
fn on_checkpoint(&mut self, state: &S, frame: &FrameInfo);
}
/// A [`CheckpointSink`] that discards every checkpoint.
#[derive(Debug, Default)]
pub struct NoopCheckpointSink;

impl<S: WorkflowState> CheckpointSink<S> for NoopCheckpointSink {
    /// Intentionally does nothing.
    fn on_checkpoint(&mut self, _state: &S, _frame: &FrameInfo) {}
}
pub struct MemorySink<S: WorkflowState = State> {
pub frames: Vec<Frame<S>>,
}
impl<S: WorkflowState> Debug for MemorySink<S>
where
S::Checkpoint: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MemorySink")
.field("frames", &self.frames)
.finish()
}
}
impl<S: WorkflowState> Default for MemorySink<S> {
fn default() -> Self {
Self { frames: Vec::new() }
}
}
impl<S: WorkflowState> MemorySink<S> {
pub fn new() -> Self {
Self::default()
}
pub fn into_frames(self) -> Vec<Frame<S>> {
self.frames
}
}
impl<S: WorkflowState> CheckpointSink<S> for MemorySink<S>
where
S::Checkpoint: Sync,
{
fn on_checkpoint(&mut self, state: &S, frame: &FrameInfo) {
self.frames.push(Frame {
graph_id: String::new(), node_id: frame.node_id.clone(),
state: state.snapshot(),
cursor: frame.step,
});
}
}
// Unit tests for the checkpoint sinks and `FrameInfo`.
#[cfg(test)]
mod tests {
use super::*;
use crate::state::StateMerge;
use crate::{GraphBuilder, NodeKind, TaskNode};
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
// Runs a two-node graph (a -> b) with a `MemorySink` attached to the engine.
// It expects one captured frame per executed node, in execution order. Each
// frame's `cursor` (copied from `FrameInfo::step`) should be 1 for the first
// node and 2 for the second.
#[tokio::test]
async fn test_auto_checkpoint_via_memory_sink() {
let mut builder = GraphBuilder::<State, StateMerge>::new("test");
builder.start("a");
builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
builder.node("b", NodeKind::Task(TaskNode::new("b", |_| Ok(()))));
builder.end("b");
builder.edge("a", "b");
let graph = Arc::new(builder.build().unwrap());
let mut sink = MemorySink::<State>::new();
let mut state = State::new();
// The sink is handed to the engine as `Some(&mut sink)`; `sink` is read back
// after the run completes.
let mut engine: crate::ExecutionEngine<'_, State> = crate::ExecutionEngine::new(
&mut state,
None,
CancellationToken::new(),
Some(&mut sink),
None,
);
let mut cb = crate::graph::NoopStepCallback;
graph.run_inline(&mut engine, 100, &mut cb).await.unwrap();
assert_eq!(sink.frames.len(), 2);
assert_eq!(sink.frames[0].node_id, "a");
assert_eq!(sink.frames[1].node_id, "b");
assert_eq!(sink.frames[0].cursor, 1);
assert_eq!(sink.frames[1].cursor, 2);
}
// Smoke test: a single-node graph runs to completion with `NoopCheckpointSink`.
// It asserts only that execution succeeds.
#[tokio::test]
async fn test_noop_checkpoint_sink() {
let mut builder = GraphBuilder::<State, StateMerge>::new("test");
builder.start("a");
builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
builder.end("a");
let graph = Arc::new(builder.build().unwrap());
let mut sink = NoopCheckpointSink;
let mut state = State::new();
let mut engine: crate::ExecutionEngine<'_, State> = crate::ExecutionEngine::new(
&mut state,
None,
CancellationToken::new(),
Some(&mut sink),
None,
);
let mut cb = crate::graph::NoopStepCallback;
graph.run_inline(&mut engine, 100, &mut cb).await.unwrap();
}
// Checks that `FrameInfo::new` stores its arguments unchanged.
#[test]
fn test_frame_info_minimal() {
let info = FrameInfo::new("test_node", 42);
assert_eq!(info.node_id, "test_node");
assert_eq!(info.step, 42);
}
}