use crate::error::CanoError;
use cano_macros::checkpoint_store;
#[cfg(feature = "recovery")]
mod redb;
#[cfg(feature = "recovery")]
pub use redb::RedbCheckpointStore;
/// Tag describing what a [`CheckpointRow`] records.
///
/// The enum is `#[non_exhaustive]`, so `match`es written outside this crate need a
/// wildcard arm. Its `Default` is [`RowKind::StateEntry`].
///
/// NOTE(review): with index-based serde formats the variant order may be part of the
/// stored encoding — confirm before reordering or inserting variants.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum RowKind {
/// Kind assigned by [`CheckpointRow::new`]; such rows start without a payload.
#[default]
StateEntry,
/// Kind assigned by [`CheckpointRow::with_output`]; the row carries an output blob.
/// NOTE(review): the name suggests a completed compensatable task — confirm with the engine.
CompensationCompletion,
/// Kind assigned by [`CheckpointRow::with_cursor`]; the row carries a cursor blob.
StepCursor,
}
/// One record of a workflow run as handed to, and returned by, a [`CheckpointStore`].
///
/// Build one with [`CheckpointRow::new`] and optionally attach a payload with
/// [`CheckpointRow::with_output`] or [`CheckpointRow::with_cursor`].
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct CheckpointRow {
/// Sequence number of the row. The in-memory store in this file's tests keeps it
/// unique per workflow id and returns loaded rows sorted by it.
pub sequence: u64,
/// State name recorded for this row (presumably the workflow state's label — confirm).
pub state: String,
/// Identifier of the task associated with this row (semantics defined by the caller).
pub task_id: String,
/// Optional opaque byte payload: `None` for rows from [`CheckpointRow::new`]; both
/// `with_output` and `with_cursor` store their bytes in this same field.
pub output_blob: Option<Vec<u8>>,
/// What this row represents; also tells apart which builder filled `output_blob`.
pub kind: RowKind,
}
impl CheckpointRow {
    /// Creates a payload-free row of kind [`RowKind::StateEntry`].
    ///
    /// `state` and `task_id` accept anything convertible into a `String`.
    pub fn new(sequence: u64, state: impl Into<String>, task_id: impl Into<String>) -> Self {
        let state = state.into();
        let task_id = task_id.into();
        Self {
            sequence,
            state,
            task_id,
            output_blob: None,
            kind: RowKind::StateEntry,
        }
    }

    /// Returns this row with `output_blob` as its payload and its kind switched to
    /// [`RowKind::CompensationCompletion`]; every other field is carried over.
    pub fn with_output(self, output_blob: Vec<u8>) -> Self {
        Self {
            output_blob: Some(output_blob),
            kind: RowKind::CompensationCompletion,
            ..self
        }
    }

    /// Returns this row with `cursor_blob` as its payload and its kind switched to
    /// [`RowKind::StepCursor`]; every other field is carried over.
    ///
    /// The bytes land in the same `output_blob` field that `with_output` uses, so a
    /// later call replaces whatever an earlier one stored.
    pub fn with_cursor(self, cursor_blob: Vec<u8>) -> Self {
        Self {
            output_blob: Some(cursor_blob),
            kind: RowKind::StepCursor,
            ..self
        }
    }
}
/// Storage backend for per-workflow [`CheckpointRow`] logs.
///
/// Implementations must be `Send + Sync + 'static`. The `#[checkpoint_store]`
/// attribute comes from `cano_macros`; the tests in this file use the trait as
/// `Arc<dyn CheckpointStore>`, so the macro presumably rewrites the `async fn`s
/// into a dyn-compatible form — confirm against the macro's documentation.
#[checkpoint_store]
pub trait CheckpointStore: Send + Sync + 'static {
/// Records `row` under `workflow_id`.
///
/// NOTE(review): the in-memory test implementation in this file rejects a row whose
/// `sequence` already exists for the same `workflow_id` with a checkpoint-store
/// error; whether every backend must do so is not visible here — confirm.
async fn append(&self, workflow_id: &str, row: CheckpointRow) -> Result<(), CanoError>;
/// Returns the rows recorded under `workflow_id`.
///
/// NOTE(review): the in-memory test implementation returns them sorted by `sequence`
/// and yields an empty vector for an unknown id — presumably the expected contract.
async fn load_run(&self, workflow_id: &str) -> Result<Vec<CheckpointRow>, CanoError>;
/// Discards the rows recorded under `workflow_id`.
///
/// NOTE(review): the tests in this file expect clearing an unknown id to succeed.
async fn clear(&self, workflow_id: &str) -> Result<(), CanoError>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// Minimal in-memory [`CheckpointStore`] for exercising the trait: a mutex-guarded
    /// map from workflow id to the rows appended under that id.
    #[derive(Default)]
    struct InMemoryStore(Mutex<HashMap<String, Vec<CheckpointRow>>>);

    #[checkpoint_store]
    impl CheckpointStore for InMemoryStore {
        /// Appends `row` unless its sequence number is already present for the workflow.
        async fn append(&self, workflow_id: &str, row: CheckpointRow) -> Result<(), CanoError> {
            let mut logs = self.0.lock().unwrap();
            let log = logs.entry(workflow_id.to_string()).or_default();
            let sequence_taken = log.iter().any(|stored| stored.sequence == row.sequence);
            if sequence_taken {
                Err(CanoError::checkpoint_store(format!(
                    "checkpoint conflict: {workflow_id:?} already has sequence {}",
                    row.sequence
                )))
            } else {
                log.push(row);
                Ok(())
            }
        }

        /// Returns a sequence-ordered copy of the workflow's rows (empty when unknown).
        async fn load_run(&self, workflow_id: &str) -> Result<Vec<CheckpointRow>, CanoError> {
            // The lock guard is a temporary of this statement, so it is released
            // before the copy is sorted.
            let mut snapshot = match self.0.lock().unwrap().get(workflow_id) {
                Some(stored) => stored.clone(),
                None => Vec::new(),
            };
            snapshot.sort_by_key(|stored| stored.sequence);
            Ok(snapshot)
        }

        /// Drops every row of the workflow; unknown ids are not an error.
        async fn clear(&self, workflow_id: &str) -> Result<(), CanoError> {
            let mut logs = self.0.lock().unwrap();
            logs.remove(workflow_id);
            Ok(())
        }
    }

    /// Compiles only if the trait can be used as a trait object.
    #[test]
    fn checkpoint_store_is_dyn_compatible() {
        use std::sync::Arc;

        let concrete = Arc::new(InMemoryStore::default());
        let _erased: Arc<dyn CheckpointStore> = concrete;
    }

    /// Each constructor/builder sets the payload and kind it is responsible for.
    #[test]
    fn checkpoint_row_constructors() {
        let CheckpointRow {
            sequence,
            state,
            task_id,
            output_blob,
            kind,
        } = CheckpointRow::new(3, "Process", "worker");
        assert_eq!(sequence, 3);
        assert_eq!(state, "Process");
        assert_eq!(task_id, "worker");
        assert!(output_blob.is_none());
        assert_eq!(kind, RowKind::StateEntry);

        let with_payload = CheckpointRow::new(4, "Done", "worker").with_output(vec![1, 2, 3]);
        assert_eq!(with_payload.sequence, 4);
        assert_eq!(with_payload.output_blob, Some(vec![1u8, 2, 3]));
        assert_eq!(with_payload.kind, RowKind::CompensationCompletion);

        let with_cursor = CheckpointRow::new(5, "Step", "stepper").with_cursor(vec![9, 8, 7]);
        assert_eq!(with_cursor.sequence, 5);
        assert_eq!(with_cursor.output_blob, Some(vec![9u8, 8, 7]));
        assert_eq!(with_cursor.kind, RowKind::StepCursor);
    }

    /// Appended rows come back in order, and `clear` empties the run.
    #[tokio::test]
    async fn trait_roundtrip_append_load_clear() {
        let store = InMemoryStore::default();
        let seeded = [
            CheckpointRow::new(0, "A", "t0"),
            CheckpointRow::new(1, "B", "t1"),
            CheckpointRow::new(2, "C", "t2").with_output(vec![9]),
        ];
        for row in seeded {
            store.append("run", row).await.unwrap();
        }

        let loaded = store.load_run("run").await.unwrap();
        assert_eq!(loaded.len(), 3);
        let sequences: Vec<u64> = loaded.iter().map(|row| row.sequence).collect();
        assert_eq!(sequences, vec![0, 1, 2]);
        assert_eq!(loaded[2].output_blob, Some(vec![9u8]));

        store.clear("run").await.unwrap();
        let after_clear = store.load_run("run").await.unwrap();
        assert!(after_clear.is_empty());

        // Clearing an id that was never written must also succeed.
        store.clear("never-existed").await.unwrap();
    }

    /// Loading an id that was never written yields no rows rather than an error.
    #[tokio::test]
    async fn load_run_unknown_id_is_empty() {
        let store = InMemoryStore::default();
        let loaded = store.load_run("nope").await.unwrap();
        assert!(loaded.is_empty());
    }

    /// A repeated sequence number is refused without disturbing the stored rows.
    #[tokio::test]
    async fn append_rejects_duplicate_sequence() {
        let store = InMemoryStore::default();
        let first = CheckpointRow::new(0, "A", "t0");
        store.append("run", first).await.unwrap();

        let rejected = store
            .append("run", CheckpointRow::new(0, "A-again", "t0"))
            .await;
        let err = rejected.expect_err("duplicate sequence must be rejected");
        assert_eq!(err.category(), "checkpoint_store");

        // A fresh sequence number is still accepted after the rejection.
        let next = CheckpointRow::new(1, "B", "t1");
        store.append("run", next).await.unwrap();

        let loaded = store.load_run("run").await.unwrap();
        let summary: Vec<(u64, &str)> = loaded
            .iter()
            .map(|row| (row.sequence, row.state.as_str()))
            .collect();
        assert_eq!(summary, vec![(0, "A"), (1, "B")]);

        // Sequence numbers are scoped per workflow id.
        let elsewhere = CheckpointRow::new(0, "A", "t0");
        store.append("other", elsewhere).await.unwrap();
    }
}