use std::fmt;
use crate::graph::command::Interrupt;
use crate::harness::ids::NodeId;
/// Origin tag of a checkpoint.
///
/// [`Checkpoint::to_metadata`] parses this from the `"source"` key of a
/// checkpoint's JSON metadata. With serde it is (de)serialized as the lowercase
/// variant name, which is the same string [`CheckpointSource::as_str`] returns.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CheckpointSource {
    /// Serialized as `"input"`.
    Input,
    /// Serialized as `"loop"`.
    Loop,
    /// Serialized as `"update"`.
    Update,
    /// Serialized as `"fork"`.
    Fork,
}
impl CheckpointSource {
    /// Returns the canonical lowercase name of this source.
    ///
    /// The result is one of `"input"`, `"loop"`, `"update"` or `"fork"`, and is
    /// accepted back by [`CheckpointSource::parse`].
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Input => "input",
            Self::Loop => "loop",
            Self::Update => "update",
            Self::Fork => "fork",
        }
    }

    /// Parses a canonical lowercase name back into a [`CheckpointSource`].
    ///
    /// Matching is exact and case-sensitive: only the four strings produced by
    /// [`CheckpointSource::as_str`] are recognised; anything else yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // Deriving the lookup from `as_str` keeps the two directions in sync.
        const EVERY: [CheckpointSource; 4] = [
            CheckpointSource::Input,
            CheckpointSource::Loop,
            CheckpointSource::Update,
            CheckpointSource::Fork,
        ];
        EVERY.iter().copied().find(|candidate| candidate.as_str() == s)
    }
}
impl fmt::Display for CheckpointSource {
    /// Writes the canonical lowercase name (see [`CheckpointSource::as_str`]).
    ///
    /// Uses [`fmt::Formatter::pad`] rather than `write_str` so that width, fill,
    /// alignment and precision flags are honoured exactly as they are for `str`.
    /// For example, `format!("{:<6}|", CheckpointSource::Loop)` yields `"loop  |"`.
    /// Without flags the output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
/// Durability mode selector; [`DurabilityMode::Sync`] is the default.
///
/// With serde it is (de)serialized as the lowercase variant name (`"sync"`,
/// `"async"`, `"exit"`).
/// NOTE(review): what each mode means for when checkpoints are persisted is
/// implemented outside this file — confirm against the consuming code.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DurabilityMode {
    /// Default mode. Serialized as `"sync"`.
    #[default]
    Sync,
    /// Serialized as `"async"`.
    Async,
    /// Serialized as `"exit"`.
    Exit,
}
/// Names a thread and, optionally, a specific checkpoint within it.
///
/// The derived `Default` has an empty `thread_id`, no `checkpoint_id` and an
/// empty `namespace`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct CheckpointConfig {
    /// Thread identifier.
    pub thread_id: String,
    /// Specific checkpoint to address. [`CheckpointConfig::latest`] leaves this
    /// as `None` (presumably "most recent checkpoint" — confirm with the saver).
    pub checkpoint_id: Option<String>,
    /// Namespace segments.
    pub namespace: Vec<String>,
}
impl CheckpointConfig {
    /// Builds a config for `thread_id` that does not pin a checkpoint.
    ///
    /// `checkpoint_id` is `None` and `namespace` is empty. Both come from the
    /// derived [`Default`]; only `thread_id` is taken from the argument.
    pub fn latest(thread_id: impl Into<String>) -> Self {
        let thread_id: String = thread_id.into();
        Self {
            thread_id,
            ..Self::default()
        }
    }
}
/// A checkpoint bundled with its addressing config, its parent's config, and
/// a list of pending writes.
#[derive(Debug, Clone)]
pub struct CheckpointTuple<State> {
    /// Config addressing `checkpoint`.
    pub config: CheckpointConfig,
    /// The checkpoint itself.
    pub checkpoint: Checkpoint<State>,
    /// Config of the parent checkpoint, if there is one.
    pub parent_config: Option<CheckpointConfig>,
    /// Pending writes attached to this tuple.
    pub pending_writes: Vec<PendingWrite>,
}
/// A checkpoint record: a `State` value plus identifiers and bookkeeping.
///
/// [`Checkpoint::to_metadata`] produces a [`CheckpointMetadata`] summary of it.
#[derive(Clone, Debug)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Checkpoint<State> {
    /// Identifier of the thread this checkpoint belongs to.
    pub thread_id: String,
    /// Identifier of this checkpoint.
    pub checkpoint_id: String,
    /// Optional run identifier.
    pub run_id: Option<String>,
    /// Identifier of the parent checkpoint, if any.
    pub parent_checkpoint_id: Option<String>,
    /// Namespace segments.
    pub namespace: Vec<String>,
    /// The user-supplied state value.
    pub state: State,
    /// Nodes recorded as next to run.
    pub next_nodes: Vec<NodeId>,
    /// Nodes recorded as completed.
    pub completed_tasks: Vec<NodeId>,
    /// Writes recorded against nodes.
    pub pending_writes: Vec<PendingWrite>,
    /// Interrupts; a non-empty list sets `has_interrupts` in the metadata summary.
    pub interrupts: Vec<Interrupt>,
    /// Free-form JSON; `to_metadata` reads its `"source"` and `"step"` keys.
    pub metadata: serde_json::Value,
}
impl<State> Checkpoint<State> {
    /// Builds a [`CheckpointMetadata`] summary of this checkpoint.
    ///
    /// The id fields, `namespace` and `next_nodes` are cloned. `state`,
    /// `completed_tasks`, `pending_writes` and the interrupt list itself are not
    /// carried over; only `has_interrupts` is, and it is `true` iff `interrupts`
    /// is non-empty. Two fields are read from the free-form `metadata` JSON:
    ///
    /// * `source`: the `"source"` string, parsed with [`CheckpointSource::parse`].
    ///   It falls back to [`CheckpointSource::Loop`] when the key is absent, is
    ///   not a string, or is not a recognised name.
    /// * `step`: the `"step"` value as a `u64`. It falls back to `0` when the key
    ///   is absent or is not a non-negative integer. Values that do not fit in
    ///   `usize` (possible only on targets narrower than 64 bits) saturate at
    ///   `usize::MAX` instead of being silently truncated.
    pub fn to_metadata(&self) -> CheckpointMetadata {
        // Keeps `usize::try_from` resolvable on editions where `TryFrom` is not
        // in the prelude.
        use std::convert::TryFrom;

        let source = self
            .metadata
            .get("source")
            .and_then(|v| v.as_str())
            .and_then(CheckpointSource::parse)
            .unwrap_or(CheckpointSource::Loop);
        let step = self
            .metadata
            .get("step")
            .and_then(|v| v.as_u64())
            // `u64 as usize` would wrap on 32-bit targets and scramble step
            // ordering; saturating keeps an oversized step "later" than the rest.
            .map(|raw| usize::try_from(raw).unwrap_or(usize::MAX))
            .unwrap_or(0);
        CheckpointMetadata {
            thread_id: self.thread_id.clone(),
            checkpoint_id: self.checkpoint_id.clone(),
            run_id: self.run_id.clone(),
            parent_checkpoint_id: self.parent_checkpoint_id.clone(),
            namespace: self.namespace.clone(),
            next_nodes: self.next_nodes.clone(),
            has_interrupts: !self.interrupts.is_empty(),
            source,
            step,
        }
    }
}
/// A JSON payload attributed to a node.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct PendingWrite {
    /// Node the write is attributed to.
    pub node: NodeId,
    /// Arbitrary JSON payload.
    pub payload: serde_json::Value,
}
/// Summary of a checkpoint without its state, as built by
/// [`Checkpoint::to_metadata`].
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct CheckpointMetadata {
    /// Identifier of the owning thread.
    pub thread_id: String,
    /// Identifier of the summarised checkpoint.
    pub checkpoint_id: String,
    /// Optional run identifier.
    pub run_id: Option<String>,
    /// Identifier of the parent checkpoint, if any.
    pub parent_checkpoint_id: Option<String>,
    /// Namespace segments.
    pub namespace: Vec<String>,
    /// Nodes recorded as next to run.
    pub next_nodes: Vec<NodeId>,
    /// `true` when the checkpoint carried at least one interrupt.
    pub has_interrupts: bool,
    /// Source parsed from the checkpoint's metadata (`Loop` when missing or unknown).
    pub source: CheckpointSource,
    /// Step number from the checkpoint's metadata (`0` when missing or invalid).
    pub step: usize,
}