use serde::{Deserialize, Serialize};
use crate::state::State;
use crate::workflow_state::WorkflowState;
/// Unique identifier of a [`Checkpoint`], wrapping a [`uuid::Uuid`].
///
/// Derives `Hash` + `Eq`, so it can be used as a key in hash maps/sets, and
/// serde `Serialize`/`Deserialize` for persistence. `Checkpoint::new` fills it
/// with a freshly generated v4 UUID.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CheckpointId(pub uuid::Uuid);
impl std::fmt::Display for CheckpointId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// Name-based identifier of a workflow node, as stored in
/// `Checkpoint::current_node`.
///
/// Derives `Hash` alongside `Eq` (matching [`CheckpointId`]) so node ids can
/// be used as keys in `HashMap`/`HashSet`; the derived `Hash` hashes the inner
/// `String`, which is consistent with the derived `Eq`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NodeId(pub String);
impl std::fmt::Display for NodeId {
    /// Writes the node name.
    ///
    /// Uses [`std::fmt::Formatter::pad`] (the same routine `str`'s `Display`
    /// uses) so the caller's format spec applies to the name: width, fill,
    /// alignment and precision (`{:>12}`, `{:.8}`) all work. For a plain `{}`
    /// the output is just the inner string, as before.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write!(f, "{}", ..)` would build fresh format args and silently
        // drop the caller's padding/precision; `pad` honors them.
        f.pad(&self.0)
    }
}
/// A snapshot of workflow state captured at a particular node.
///
/// Generic over the state type `S`, which defaults to [`State`]; the
/// [`CheckpointStore`] trait operates on the default `Checkpoint<State>`.
/// Field order matters for serde-derived encodings, so keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint<S = State> {
    /// Unique id of this snapshot.
    pub checkpoint_id: CheckpointId,
    /// Node the workflow was at when the snapshot was taken
    /// (presumably the node to resume from — confirm against the executor).
    pub current_node: NodeId,
    /// The captured workflow state.
    pub state: S,
    /// Wall-clock creation time; `Checkpoint::new` sets it from `SystemTime::now()`.
    pub created_at: std::time::SystemTime,
}
impl<S: WorkflowState> Checkpoint<S> {
    /// Captures `state` as a new checkpoint positioned at `current_node`.
    ///
    /// The checkpoint receives a freshly generated random (v4)
    /// [`CheckpointId`] and is stamped with the current wall-clock time.
    pub fn new(current_node: impl Into<String>, state: S) -> Self {
        // Same evaluation order as a field-by-field literal: id, node, timestamp.
        let checkpoint_id = CheckpointId(uuid::Uuid::new_v4());
        let node_name: String = current_node.into();
        let created_at = std::time::SystemTime::now();

        Self {
            checkpoint_id,
            current_node: NodeId(node_name),
            state,
            created_at,
        }
    }
}
/// Errors returned by [`CheckpointStore`] operations.
#[derive(Debug, thiserror::Error)]
pub enum CheckpointStoreError {
    /// Failure reported by the underlying storage backend; the payload is a
    /// human-readable description (presumably I/O or driver errors — confirm
    /// per implementation).
    #[error("storage error: {0}")]
    Storage(String),
    /// The referenced checkpoint id could not be found.
    /// NOTE(review): `load` reports absence as `Ok(None)`, so which operations
    /// return this variant is up to the store implementation.
    #[error("checkpoint not found: {0}")]
    NotFound(CheckpointId),
    /// A stored checkpoint appears unreadable or invalid; the payload describes
    /// the problem (exact conditions are implementation-defined).
    #[error("corrupted checkpoint: {0}")]
    Corrupted(String),
}
/// Persistence backend for [`Checkpoint`]s, grouped per [`TraceId`].
///
/// Implementors must be `Send + Sync`, so a single store can be shared across
/// threads/tasks. Methods are async via `async_trait`, and every method
/// reports failures as [`CheckpointStoreError`].
#[async_trait::async_trait]
pub trait CheckpointStore: Send + Sync {
    /// Persists `checkpoint` and associates it with `trace_id`.
    async fn save_with_trace(
        &self,
        trace_id: &TraceId,
        checkpoint: &Checkpoint,
    ) -> Result<(), CheckpointStoreError>;
    /// Loads the checkpoint with the given id; `Ok(None)` presumably means no
    /// such checkpoint exists.
    async fn load(&self, id: &CheckpointId) -> Result<Option<Checkpoint>, CheckpointStoreError>;
    /// Returns the most recent checkpoint recorded for `trace_id`, if any.
    /// NOTE(review): "most recent" ordering (by `created_at` or by insertion)
    /// is up to the implementation — confirm.
    async fn load_latest(
        &self,
        trace_id: &TraceId,
    ) -> Result<Option<Checkpoint>, CheckpointStoreError>;
    /// Lists the ids of checkpoints recorded for `trace_id` (order not
    /// specified by this trait).
    async fn list(&self, trace_id: &TraceId) -> Result<Vec<CheckpointId>, CheckpointStoreError>;
    /// Deletes the checkpoint with the given id; the `bool` presumably reports
    /// whether a checkpoint was actually removed.
    async fn delete(&self, id: &CheckpointId) -> Result<bool, CheckpointStoreError>;
    /// Removes checkpoints for `trace_id` beyond `keep` retained ones
    /// (presumably keeping the newest); presumably returns how many were removed.
    async fn prune(&self, trace_id: &TraceId, keep: usize) -> Result<usize, CheckpointStoreError>;
}
/// Selects how often checkpoints are taken.
///
/// The default is [`CheckpointPolicy::EveryNode`]. The per-variant semantics
/// below are inferred from the names; confirm them against the executor that
/// consumes this policy.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum CheckpointPolicy {
    /// Presumably checkpoint after every node (the default).
    #[default]
    EveryNode,
    /// Presumably checkpoint only at barrier/synchronization points.
    BarrierOnly,
    /// Presumably checkpoint only when explicitly requested.
    Manual,
}
pub use crate::ids::TraceId;