use std::collections::HashMap;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::config::RunnableConfig;
/// Separator string for checkpoint namespace segments (`"|"`).
///
/// NOTE(review): the `CheckpointNamespace` string helpers in this file
/// hard-code the `'|'` character instead of reading this constant; keep the
/// two in sync if either one changes.
pub const CHECKPOINT_NS_SEPARATOR: &str = "|";
/// Error type returned by the `CheckpointSaver` methods declared in this file.
///
/// The boxed variants keep the underlying error as their `source` and include
/// it in the display message.
#[derive(Debug, thiserror::Error)]
pub enum CheckpointError {
/// A serialization step failed; wraps the underlying error.
#[error("Serialization failed: {0}")]
Serialize(#[source] Box<dyn std::error::Error + Send + Sync>),
/// A deserialization step failed; wraps the underlying error.
#[error("Deserialization failed: {0}")]
Deserialize(#[source] Box<dyn std::error::Error + Send + Sync>),
/// No checkpoint was found for the given thread id / checkpoint id pair.
#[error("Checkpoint not found: thread={thread_id}, id={checkpoint_id}")]
NotFound {
/// Thread id that was looked up (shown as `thread=` in the message).
thread_id: String,
/// Checkpoint id that was looked up (shown as `id=` in the message).
checkpoint_id: String,
},
/// The storage layer reported a failure; wraps the underlying error.
#[error("Storage error: {0}")]
Storage(#[source] Box<dyn std::error::Error + Send + Sync>),
/// Any other failure, described only by a free-form message.
#[error("Checkpoint error: {0}")]
Other(String),
}
impl From<serde_json::Error> for CheckpointError {
fn from(err: serde_json::Error) -> Self {
Self::Serialize(Box::new(err))
}
}
/// One `node_name:invocation_id` component of a `CheckpointNamespace`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct NamespaceSegment {
/// Node name; rendered before the `:` in the segment's string form.
pub node_name: String,
/// Invocation id; rendered after the `:` in the segment's string form.
pub invocation_id: String,
}
impl NamespaceSegment {
    /// Builds a segment from an owned node name and invocation id.
    #[must_use]
    pub const fn new(node_name: String, invocation_id: String) -> Self {
        Self {
            node_name,
            invocation_id,
        }
    }

    /// Renders the segment as `node_name:invocation_id`.
    ///
    /// Despite the `as_` prefix this allocates and returns a fresh `String`.
    #[must_use]
    pub fn as_str(&self) -> String {
        [self.node_name.as_str(), self.invocation_id.as_str()].join(":")
    }
}
/// Displays a segment exactly as [`NamespaceSegment::as_str`] renders it.
///
/// Width, fill and precision flags on the caller's format spec are not applied.
impl std::fmt::Display for NamespaceSegment {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = self.as_str();
        f.write_str(&rendered)
    }
}
/// An ordered list of [`NamespaceSegment`]s; an empty list is the root namespace.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct CheckpointNamespace {
/// Segments in order: `child` appends at the end and `parent` drops the last one.
pub segments: Vec<NamespaceSegment>,
}
impl CheckpointNamespace {
    /// Returns the root namespace, i.e. one with no segments.
    #[must_use]
    pub const fn root() -> Self {
        Self::new(Vec::new())
    }

    /// Wraps an already-built list of segments (outermost first).
    #[must_use]
    pub const fn new(segments: Vec<NamespaceSegment>) -> Self {
        Self { segments }
    }

    /// Returns a copy of this namespace with one extra segment appended.
    ///
    /// `self` is left untouched.
    #[must_use]
    pub fn child(&self, node_name: &str, invocation_id: &str) -> Self {
        let mut segments = Vec::with_capacity(self.segments.len() + 1);
        segments.extend_from_slice(&self.segments);
        segments.push(NamespaceSegment {
            node_name: node_name.to_owned(),
            invocation_id: invocation_id.to_owned(),
        });
        Self { segments }
    }

    /// Returns the namespace with the last segment removed, or `None` when
    /// this is already the root.
    #[must_use]
    pub fn parent(&self) -> Option<Self> {
        let (_innermost, ancestors) = self.segments.split_last()?;
        Some(Self {
            segments: ancestors.to_vec(),
        })
    }

    /// Reports whether this namespace has no segments.
    #[must_use]
    pub const fn is_root(&self) -> bool {
        self.segments.is_empty()
    }

    /// Renders the namespace as `|node:id|node:id…`.
    ///
    /// Every segment, including the first, is preceded by `'|'`; the root
    /// namespace renders as the empty string. Despite the `as_` prefix this
    /// allocates and returns a fresh `String`.
    #[must_use]
    pub fn as_str(&self) -> String {
        let mut rendered = String::new();
        for segment in &self.segments {
            rendered.push('|');
            rendered.push_str(&segment.node_name);
            rendered.push(':');
            rendered.push_str(&segment.invocation_id);
        }
        rendered
    }

    /// Inherent alias for [`Self::as_str`].
    #[allow(
        clippy::should_implement_trait,
        clippy::inherent_to_string_shadow_display,
        reason = "required by design spec 04-027"
    )]
    #[must_use]
    pub fn to_string(&self) -> String {
        self.as_str()
    }

    /// Parses the string form produced by [`Self::as_str`].
    ///
    /// Leading `'|'` characters are stripped, the rest is split on `'|'`, and
    /// each piece is split at its first `':'` into node name and invocation
    /// id. Parsing never fails: pieces without a `':'` (including empty
    /// pieces) are skipped, so an empty or entirely malformed input yields the
    /// root namespace.
    #[must_use]
    pub fn parse(s: &str) -> Self {
        let mut segments = Vec::new();
        // An empty `s` yields a single empty piece, which is skipped below.
        for piece in s.trim_start_matches('|').split('|') {
            if let Some((node_name, invocation_id)) = piece.split_once(':') {
                segments.push(NamespaceSegment {
                    node_name: node_name.to_owned(),
                    invocation_id: invocation_id.to_owned(),
                });
            }
        }
        Self { segments }
    }
}
/// Displays a namespace exactly as [`CheckpointNamespace::as_str`] renders it.
///
/// Width, fill and precision flags on the caller's format spec are not applied.
impl std::fmt::Display for CheckpointNamespace {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = self.as_str();
        f.write_str(&rendered)
    }
}
/// Builds a namespace directly from a list of segments, keeping their order.
impl From<Vec<NamespaceSegment>> for CheckpointNamespace {
    fn from(segments: Vec<NamespaceSegment>) -> Self {
        Self { segments }
    }
}
impl From<&str> for CheckpointNamespace {
fn from(s: &str) -> Self {
Self::parse(s)
}
}
/// Async storage interface for checkpoints.
///
/// Implementors must be `Send + Sync + 'static`. All methods are async (via
/// `async_trait`) and report failures as [`CheckpointError`].
///
/// NOTE(review): this file only declares signatures; the per-method notes
/// below are inferred from names and types and must be confirmed against a
/// concrete backend.
#[async_trait]
pub trait CheckpointSaver: Send + Sync + 'static {
/// Presumably looks up the single checkpoint addressed by `config`.
///
/// The `Option` in the return type allows "nothing found" to be reported as
/// `Ok(None)` rather than as an error.
async fn get_tuple(
&self,
config: &RunnableConfig,
) -> Result<Option<CheckpointTuple>, CheckpointError>;
/// Presumably returns the checkpoints matching `config`, narrowed by
/// `filter` when one is given — ordering is not specified here.
async fn list(
&self,
config: &RunnableConfig,
filter: Option<CheckpointFilter>,
) -> Result<Vec<CheckpointTuple>, CheckpointError>;
/// Presumably stores `checkpoint` together with `metadata`.
///
/// Returns a `RunnableConfig` on success — presumably one that addresses the
/// stored checkpoint; TODO confirm.
async fn put(
&self,
config: &RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
) -> Result<RunnableConfig, CheckpointError>;
/// Presumably records `writes` produced by the task `task_id` against the
/// checkpoint addressed by `config`.
async fn put_writes(
&self,
config: &RunnableConfig,
writes: Vec<PendingWrite>,
task_id: &str,
) -> Result<(), CheckpointError>;
}
/// Serializable checkpoint record (derives serde `Serialize`/`Deserialize`).
///
/// NOTE(review): field meanings below are inferred from names and types only;
/// confirm against the code that produces and consumes checkpoints.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Checkpoint {
/// Identifier of this checkpoint.
pub id: String,
/// Channel values, held as a single JSON value.
pub channel_values: serde_json::Value,
/// Presumably channel name -> current version number.
pub channel_versions: HashMap<String, u64>,
/// Presumably node name -> (channel name -> version last seen) — TODO confirm.
pub versions_seen: HashMap<String, HashMap<String, u64>>,
/// Tasks recorded as pending in this checkpoint.
pub pending_tasks: Vec<CheckpointPendingTask>,
/// Sends recorded as pending in this checkpoint.
pub pending_sends: Vec<SerializedSend>,
/// Interrupt signals recorded as pending. Deserializes to an empty list when
/// the field is absent from the input (`#[serde(default)]`).
#[serde(default)]
pub pending_interrupts: Vec<crate::interrupt::InterruptSignal>,
/// NOTE(review): looks like a format version number; how it differs from `v`
/// is not visible in this file.
pub schema_version: u32,
/// Creation time as a string; the exact format is not fixed by this type.
pub created_at: String,
/// NOTE(review): looks like another version number (compare `schema_version`) — confirm.
pub v: u32,
/// Presumably the channel versions introduced by this checkpoint — TODO confirm.
pub new_versions: HashMap<String, u64>,
/// Per-key [`DeltaCounters`]; presumably counts accumulated since the last
/// delta snapshot for that key.
pub counters_since_delta_snapshot: HashMap<String, DeltaCounters>,
}
/// A pair of counters; both default to zero.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct DeltaCounters {
/// Update count; this is the value `exceeds_frequency` compares against the threshold.
pub updates: u64,
/// Superstep count. NOTE(review): not read by any code in this file.
pub supersteps: u64,
}
impl DeltaCounters {
    /// Creates counters with both fields set to zero.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            updates: 0,
            supersteps: 0,
        }
    }

    /// Reports whether `updates` has reached `snapshot_frequency`.
    ///
    /// Returns `true` when `updates >= snapshot_frequency`. A frequency of `0`
    /// therefore always yields `true`. Only `updates` is consulted;
    /// `supersteps` plays no part in the decision.
    #[must_use]
    pub fn exceeds_frequency(&self, snapshot_frequency: usize) -> bool {
        match usize::try_from(self.updates) {
            // Also covers `snapshot_frequency == 0`: every count is >= 0.
            Ok(updates) => updates >= snapshot_frequency,
            // The count does not fit in `usize`, so it exceeds any threshold.
            Err(_) => true,
        }
    }
}
/// Serializable metadata stored alongside a [`Checkpoint`] (see `CheckpointSaver::put`).
///
/// NOTE(review): field meanings are inferred from names and types — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CheckpointMetadata {
/// What kind of event produced the checkpoint.
pub source: CheckpointSource,
/// Step number; signed, so negative values are representable.
pub step: i64,
/// JSON values keyed by string — presumably the writes made at this step, keyed by node or channel; verify.
pub writes: HashMap<String, serde_json::Value>,
/// String-to-string map — presumably namespace -> parent checkpoint id; verify.
pub parents: HashMap<String, String>,
/// Identifier of the run this checkpoint belongs to.
pub run_id: String,
}
/// Origin tag for a checkpoint, stored in `CheckpointMetadata::source`.
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
/// NOTE(review): the per-variant descriptions are inferred from the variant names.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum CheckpointSource {
/// Presumably created from the initial input.
Input,
/// Presumably created inside the execution loop.
Loop,
/// Presumably created by an explicit state update.
Update,
/// Presumably created by forking an existing checkpoint.
Fork,
/// Presumably created because of an interrupt; carries a node name.
Interrupt { node: String },
}
/// A checkpoint bundled with its config, metadata and pending writes.
///
/// This is the item type returned by `CheckpointSaver::get_tuple` and
/// `CheckpointSaver::list`.
#[derive(Clone, Debug)]
pub struct CheckpointTuple {
/// Config associated with this checkpoint — presumably the one that addresses it.
pub config: RunnableConfig,
/// The checkpoint record itself.
pub checkpoint: Checkpoint,
/// Metadata stored with the checkpoint.
pub metadata: CheckpointMetadata,
/// Writes recorded for this checkpoint (compare `CheckpointSaver::put_writes`).
pub pending_writes: Vec<PendingWrite>,
/// Config of the parent checkpoint, when there is one.
pub parent_config: Option<RunnableConfig>,
}
/// A single serializable write: a JSON value destined for a named channel.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PendingWrite {
/// Id of the task the write is attributed to.
pub task_id: String,
/// Name of the channel the value is written to.
pub channel: String,
/// The written value as JSON.
pub value: serde_json::Value,
}
/// Serializable description of a task stored in `Checkpoint::pending_tasks`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CheckpointPendingTask {
/// Task identifier.
pub id: String,
/// Name of the node the task belongs to.
pub node: String,
/// Trigger names — presumably the channels that caused the task to be scheduled; verify.
pub triggers: Vec<String>,
/// Optional JSON state — presumably used in place of the shared state when present; TODO confirm.
pub state_override: Option<serde_json::Value>,
}
/// Serializable form of a send stored in `Checkpoint::pending_sends`: a node
/// name paired with a JSON state value.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SerializedSend {
/// Name of the node — presumably the target of the send.
pub node: String,
/// JSON state carried by the send.
pub state: serde_json::Value,
}
/// Kind of delta operation.
///
/// NOTE(review): nothing else in this file uses this enum; the variant
/// descriptions are inferred from their names.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum DeltaOp {
/// Presumably appends to the existing value.
Append,
/// Presumably replaces the existing value.
Replace,
}
/// Optional criteria passed to `CheckpointSaver::list`.
///
/// `Default` leaves every field `None`.
/// NOTE(review): how each criterion is applied is up to the `CheckpointSaver`
/// implementation; the notes below are inferred from field names.
#[derive(Clone, Debug, Default)]
pub struct CheckpointFilter {
/// Presumably restricts results to checkpoints with this source.
pub source: Option<CheckpointSource>,
/// Presumably a lower bound on `CheckpointMetadata::step` (inclusive, going by the `gte` suffix).
pub step_gte: Option<i64>,
/// Presumably an upper bound on `CheckpointMetadata::step` (inclusive, going by the `lte` suffix).
pub step_lte: Option<i64>,
/// Presumably a checkpoint id; only checkpoints before it are returned — TODO confirm.
pub before: Option<String>,
/// Presumably a checkpoint id; only checkpoints after it are returned — TODO confirm.
pub after: Option<String>,
/// Presumably the maximum number of results to return.
pub limit: Option<usize>,
}
/// Typed view of state `S` bundled with checkpoint config, metadata and task info.
///
/// NOTE(review): field meanings are inferred from names and types — confirm
/// against the code that builds snapshots.
#[derive(Clone, Debug)]
pub struct StateSnapshot<S: crate::State> {
/// The state values, as the caller's state type.
pub values: S,
/// Names — presumably of the nodes scheduled to run next.
pub next: Vec<String>,
/// Config associated with this snapshot.
pub config: RunnableConfig,
/// Checkpoint metadata for this snapshot.
pub metadata: CheckpointMetadata,
/// Creation time as a string (compare `Checkpoint::created_at`).
pub created_at: String,
/// Config of the parent checkpoint, when there is one.
pub parent_config: Option<RunnableConfig>,
/// Per-task information.
pub tasks: Vec<PregelTaskInfo>,
}
/// Serializable summary of one task, as listed in `StateSnapshot::tasks`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PregelTaskInfo {
/// Task identifier.
pub id: String,
/// Name of the node the task belongs to.
pub node_name: String,
/// Error message, if one was recorded for the task.
pub error: Option<String>,
/// Interrupt payloads for the task, as raw JSON values.
pub interrupts: Vec<serde_json::Value>,
}
/// Generates a new checkpoint id: a version-6 UUID in its default string form.
///
/// A fresh random 6-byte node id is drawn on every call and handed to
/// `Uuid::now_v6`.
#[must_use]
pub fn generate_checkpoint_id() -> String {
    let node_id = rand::random::<[u8; 6]>();
    let id = uuid::Uuid::now_v6(&node_id);
    id.to_string()
}