use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::error::{ClientError, Result};
use crate::ConnectionTrait;
/// Identifier of a workflow run. It is interpolated verbatim into storage keys.
pub type RunId = String;
/// Identifier of a node within a run. It is interpolated verbatim into
/// checkpoint storage keys.
pub type NodeId = String;
/// Sequence number for checkpoints and events. Storage keys render it as
/// 16 zero-padded hexadecimal digits.
pub type SeqNo = u64;
/// Per-run bookkeeping record, stored as JSON under the run's `meta` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunMetadata {
/// Identifier of the run this record describes.
pub run_id: RunId,
/// Workflow name supplied when the run was created.
pub workflow: String,
/// Current lifecycle status of the run.
pub status: RunStatus,
/// Creation time, in milliseconds since the Unix epoch.
pub created_at: u64,
/// Time of the most recent metadata write, in milliseconds since the Unix epoch.
pub updated_at: u64,
/// Arbitrary JSON parameters supplied when the run was created.
pub params: HashMap<String, serde_json::Value>,
/// Highest checkpoint sequence number handed out so far (0 for a fresh run).
pub latest_checkpoint_seq: u64,
/// Highest event sequence number handed out so far (0 for a fresh run).
pub latest_event_seq: u64,
}
/// Lifecycle status recorded in [`RunMetadata`].
///
/// Runs created by `create_run` in this file start as `Running`; later values
/// are written through `update_run_status`. This file does not restrict which
/// transitions are allowed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunStatus {
Running,
Completed,
Failed,
Cancelled,
Paused,
}
/// A stored snapshot of one node's state, including the state bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
/// Run the checkpoint belongs to.
pub run_id: RunId,
/// Node whose state was captured.
pub node_id: NodeId,
/// Sequence number of this checkpoint. The default store draws it from the
/// run-wide checkpoint counter, so it is unique per run, not per node.
pub seq: SeqNo,
/// Opaque state bytes; serialized as a hex string via `hex_serde`.
#[serde(with = "hex_serde")]
pub state: Vec<u8>,
/// Time the checkpoint was saved, in milliseconds since the Unix epoch.
pub timestamp: u64,
/// Caller-supplied string annotations (empty when none were given).
pub metadata: HashMap<String, String>,
}
/// Descriptor of a [`Checkpoint`] without the state bytes themselves.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMeta {
/// Run the checkpoint belongs to.
pub run_id: RunId,
/// Node whose state was captured.
pub node_id: NodeId,
/// Sequence number of the checkpoint.
pub seq: SeqNo,
/// Time the checkpoint was saved, in milliseconds since the Unix epoch.
pub timestamp: u64,
/// Length of the checkpoint's state in bytes.
pub state_size: usize,
/// Caller-supplied string annotations copied from the checkpoint.
pub metadata: HashMap<String, String>,
}
/// One entry in a run's event log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowEvent {
/// Run the event belongs to.
pub run_id: RunId,
/// Position in the run's event log. `append_event` in this file overwrites
/// whatever value the caller supplied.
pub seq: SeqNo,
/// Free-form event type label.
pub event_type: String,
/// Milliseconds since the Unix epoch. `append_event` in this file overwrites
/// whatever value the caller supplied.
pub timestamp: u64,
/// Node the event relates to, if any.
pub node_id: Option<NodeId>,
/// Arbitrary JSON payload.
pub payload: serde_json::Value,
}
/// Storage interface for workflow runs, their per-node checkpoints, and their
/// event logs.
///
/// The per-method notes describe what `DefaultCheckpointStore` in this file
/// does; other implementors may behave differently.
pub trait CheckpointStore {
/// Stores `state` as a new checkpoint for `node_id` in `run_id` and returns
/// its descriptor. A `None` `metadata` is stored as an empty map.
///
/// The default store fails with `ClientError::NotFound` when the run does
/// not exist.
fn save_checkpoint(
&self,
run_id: &RunId,
node_id: &NodeId,
state: &[u8],
metadata: Option<HashMap<String, String>>,
) -> Result<CheckpointMeta>;
/// Loads the latest checkpoint for `node_id`, or `None` if there is none.
///
/// The default store returns the last entry of a prefix scan, so "latest"
/// holds only if the connection scans in ascending key order.
fn load_checkpoint(
&self,
run_id: &RunId,
node_id: &NodeId,
) -> Result<Option<Checkpoint>>;
/// Loads the checkpoint stored for `node_id` at exactly `seq`, or `None`.
fn load_checkpoint_at(
&self,
run_id: &RunId,
node_id: &NodeId,
seq: SeqNo,
) -> Result<Option<Checkpoint>>;
/// Lists descriptors for every checkpoint of every node in `run_id`.
fn list_checkpoints(&self, run_id: &RunId) -> Result<Vec<CheckpointMeta>>;
/// Lists descriptors for every checkpoint of `node_id` in `run_id`.
fn list_node_checkpoints(
&self,
run_id: &RunId,
node_id: &NodeId,
) -> Result<Vec<CheckpointMeta>>;
/// Creates the metadata record for a run and returns it.
///
/// The default store sets the status to `Running` and both sequence
/// counters to 0. It writes unconditionally, so an existing record for the
/// same `run_id` is replaced.
fn create_run(
&self,
run_id: &RunId,
workflow: &str,
params: HashMap<String, serde_json::Value>,
) -> Result<RunMetadata>;
/// Returns the metadata record for `run_id`, or `None` if it does not exist.
fn get_run(&self, run_id: &RunId) -> Result<Option<RunMetadata>>;
/// Sets the status of `run_id`.
///
/// The default store returns `Ok(())` without writing anything when the run
/// does not exist.
fn update_run_status(&self, run_id: &RunId, status: RunStatus) -> Result<()>;
/// Appends `event` to the log of `event.run_id` and returns the sequence
/// number assigned to it.
///
/// The default store overwrites `event.seq` and `event.timestamp` and fails
/// with `ClientError::NotFound` when the run does not exist.
fn append_event(&self, event: WorkflowEvent) -> Result<SeqNo>;
/// Returns events of `run_id` whose sequence number is strictly greater
/// than `since_seq` (treated as 0 when `None`), bounded by `limit`.
fn get_events(
&self,
run_id: &RunId,
since_seq: Option<SeqNo>,
limit: usize,
) -> Result<Vec<WorkflowEvent>>;
/// Deletes a run with its checkpoints and events. Returns `false` when the
/// run does not exist and `true` after a deletion.
fn delete_run(&self, run_id: &RunId) -> Result<bool>;
}
/// Key prefix under which every run's metadata, checkpoints, and events are stored.
const CHECKPOINT_PREFIX: &str = "_checkpoints/";
/// [`CheckpointStore`] implementation that keeps everything as JSON values in
/// the key-value connection `C`, under keys starting with `CHECKPOINT_PREFIX`.
pub struct DefaultCheckpointStore<C: ConnectionTrait> {
// Connection used for every get/put/scan/delete issued by the store.
conn: C,
}
impl<C: ConnectionTrait> DefaultCheckpointStore<C> {
pub fn new(conn: C) -> Self {
Self { conn }
}
fn run_meta_key(run_id: &RunId) -> Vec<u8> {
format!("{}{}/meta", CHECKPOINT_PREFIX, run_id).into_bytes()
}
fn checkpoint_key(run_id: &RunId, node_id: &NodeId, seq: SeqNo) -> Vec<u8> {
format!(
"{}{}/nodes/{}/{:016x}",
CHECKPOINT_PREFIX, run_id, node_id, seq
).into_bytes()
}
fn checkpoint_prefix(run_id: &RunId, node_id: &NodeId) -> Vec<u8> {
format!(
"{}{}/nodes/{}/",
CHECKPOINT_PREFIX, run_id, node_id
).into_bytes()
}
fn all_checkpoints_prefix(run_id: &RunId) -> Vec<u8> {
format!("{}{}/nodes/", CHECKPOINT_PREFIX, run_id).into_bytes()
}
fn event_key(run_id: &RunId, seq: SeqNo) -> Vec<u8> {
format!(
"{}{}/events/{:016x}",
CHECKPOINT_PREFIX, run_id, seq
).into_bytes()
}
fn events_prefix(run_id: &RunId) -> Vec<u8> {
format!("{}{}/events/", CHECKPOINT_PREFIX, run_id).into_bytes()
}
fn now_millis() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64
}
fn get_and_increment_checkpoint_seq(&self, run_id: &RunId) -> Result<SeqNo> {
let meta = self.get_run(run_id)?
.ok_or_else(|| ClientError::NotFound(format!("Run {} not found", run_id)))?;
let new_seq = meta.latest_checkpoint_seq + 1;
self.update_checkpoint_seq(run_id, new_seq)?;
Ok(new_seq)
}
fn get_and_increment_event_seq(&self, run_id: &RunId) -> Result<SeqNo> {
let meta = self.get_run(run_id)?
.ok_or_else(|| ClientError::NotFound(format!("Run {} not found", run_id)))?;
let new_seq = meta.latest_event_seq + 1;
self.update_event_seq(run_id, new_seq)?;
Ok(new_seq)
}
fn update_checkpoint_seq(&self, run_id: &RunId, seq: SeqNo) -> Result<()> {
let key = Self::run_meta_key(run_id);
if let Some(data) = self.conn.get(&key)? {
let mut meta: RunMetadata = serde_json::from_slice(&data)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
meta.latest_checkpoint_seq = seq;
meta.updated_at = Self::now_millis();
let value = serde_json::to_vec(&meta)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
}
Ok(())
}
fn update_event_seq(&self, run_id: &RunId, seq: SeqNo) -> Result<()> {
let key = Self::run_meta_key(run_id);
if let Some(data) = self.conn.get(&key)? {
let mut meta: RunMetadata = serde_json::from_slice(&data)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
meta.latest_event_seq = seq;
meta.updated_at = Self::now_millis();
let value = serde_json::to_vec(&meta)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
}
Ok(())
}
}
impl<C: ConnectionTrait> CheckpointStore for DefaultCheckpointStore<C> {
fn save_checkpoint(
&self,
run_id: &RunId,
node_id: &NodeId,
state: &[u8],
metadata: Option<HashMap<String, String>>,
) -> Result<CheckpointMeta> {
let seq = self.get_and_increment_checkpoint_seq(run_id)?;
let timestamp = Self::now_millis();
let checkpoint = Checkpoint {
run_id: run_id.clone(),
node_id: node_id.clone(),
seq,
state: state.to_vec(),
timestamp,
metadata: metadata.unwrap_or_default(),
};
let key = Self::checkpoint_key(run_id, node_id, seq);
let value = serde_json::to_vec(&checkpoint)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
Ok(CheckpointMeta {
run_id: run_id.clone(),
node_id: node_id.clone(),
seq,
timestamp,
state_size: state.len(),
metadata: checkpoint.metadata,
})
}
fn load_checkpoint(
&self,
run_id: &RunId,
node_id: &NodeId,
) -> Result<Option<Checkpoint>> {
let prefix = Self::checkpoint_prefix(run_id, node_id);
let results = self.conn.scan(&prefix)?;
if let Some((_, value)) = results.into_iter().last() {
let checkpoint: Checkpoint = serde_json::from_slice(&value)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
Ok(Some(checkpoint))
} else {
Ok(None)
}
}
fn load_checkpoint_at(
&self,
run_id: &RunId,
node_id: &NodeId,
seq: SeqNo,
) -> Result<Option<Checkpoint>> {
let key = Self::checkpoint_key(run_id, node_id, seq);
if let Some(data) = self.conn.get(&key)? {
let checkpoint: Checkpoint = serde_json::from_slice(&data)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
Ok(Some(checkpoint))
} else {
Ok(None)
}
}
fn list_checkpoints(&self, run_id: &RunId) -> Result<Vec<CheckpointMeta>> {
let prefix = Self::all_checkpoints_prefix(run_id);
let results = self.conn.scan(&prefix)?;
let mut metas = Vec::new();
for (_, value) in results {
let cp: Checkpoint = serde_json::from_slice(&value)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
metas.push(CheckpointMeta {
run_id: cp.run_id,
node_id: cp.node_id,
seq: cp.seq,
timestamp: cp.timestamp,
state_size: cp.state.len(),
metadata: cp.metadata,
});
}
Ok(metas)
}
fn list_node_checkpoints(
&self,
run_id: &RunId,
node_id: &NodeId,
) -> Result<Vec<CheckpointMeta>> {
let prefix = Self::checkpoint_prefix(run_id, node_id);
let results = self.conn.scan(&prefix)?;
let mut metas = Vec::new();
for (_, value) in results {
let cp: Checkpoint = serde_json::from_slice(&value)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
metas.push(CheckpointMeta {
run_id: cp.run_id,
node_id: cp.node_id,
seq: cp.seq,
timestamp: cp.timestamp,
state_size: cp.state.len(),
metadata: cp.metadata,
});
}
Ok(metas)
}
fn create_run(
&self,
run_id: &RunId,
workflow: &str,
params: HashMap<String, serde_json::Value>,
) -> Result<RunMetadata> {
let now = Self::now_millis();
let meta = RunMetadata {
run_id: run_id.clone(),
workflow: workflow.to_string(),
status: RunStatus::Running,
created_at: now,
updated_at: now,
params,
latest_checkpoint_seq: 0,
latest_event_seq: 0,
};
let key = Self::run_meta_key(run_id);
let value = serde_json::to_vec(&meta)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
Ok(meta)
}
fn get_run(&self, run_id: &RunId) -> Result<Option<RunMetadata>> {
let key = Self::run_meta_key(run_id);
if let Some(data) = self.conn.get(&key)? {
let meta: RunMetadata = serde_json::from_slice(&data)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
Ok(Some(meta))
} else {
Ok(None)
}
}
fn update_run_status(&self, run_id: &RunId, status: RunStatus) -> Result<()> {
let key = Self::run_meta_key(run_id);
if let Some(data) = self.conn.get(&key)? {
let mut meta: RunMetadata = serde_json::from_slice(&data)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
meta.status = status;
meta.updated_at = Self::now_millis();
let value = serde_json::to_vec(&meta)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
}
Ok(())
}
fn append_event(&self, mut event: WorkflowEvent) -> Result<SeqNo> {
let seq = self.get_and_increment_event_seq(&event.run_id)?;
event.seq = seq;
event.timestamp = Self::now_millis();
let key = Self::event_key(&event.run_id, seq);
let value = serde_json::to_vec(&event)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
Ok(seq)
}
fn get_events(
&self,
run_id: &RunId,
since_seq: Option<SeqNo>,
limit: usize,
) -> Result<Vec<WorkflowEvent>> {
let prefix = Self::events_prefix(run_id);
let results = self.conn.scan(&prefix)?;
let since = since_seq.unwrap_or(0);
let mut events = Vec::new();
for (_, value) in results {
let event: WorkflowEvent = serde_json::from_slice(&value)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
if event.seq > since {
events.push(event);
if events.len() >= limit {
break;
}
}
}
Ok(events)
}
fn delete_run(&self, run_id: &RunId) -> Result<bool> {
if self.get_run(run_id)?.is_none() {
return Ok(false);
}
let cp_prefix = Self::all_checkpoints_prefix(run_id);
for (key, _) in self.conn.scan(&cp_prefix)? {
self.conn.delete(&key)?;
}
let ev_prefix = Self::events_prefix(run_id);
for (key, _) in self.conn.scan(&ev_prefix)? {
self.conn.delete(&key)?;
}
self.conn.delete(&Self::run_meta_key(run_id))?;
Ok(true)
}
}
mod hex_serde {
use serde::{Deserialize, Deserializer, Serializer};
pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
serializer.serialize_str(&hex)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
(0..s.len())
.step_by(2)
.map(|i| {
u8::from_str_radix(&s[i..i + 2], 16)
.map_err(serde::de::Error::custom)
})
.collect()
}
}
#[cfg(test)]
mod tests {
// NOTE(review): this module is empty; nothing in this file is covered by
// unit tests yet.
}