use crate::error::CanoError;
use cano_macros::checkpoint_store;
#[cfg(feature = "recovery")]
mod redb;
#[cfg(feature = "recovery")]
pub use redb::RedbCheckpointStore;
/// Discriminates what a [`CheckpointRow`] records.
///
/// NOTE: variant declaration order is part of the serialized form for
/// non-self-describing formats such as postcard (exercised by this file's
/// tests), which encodes the variant index. Append new variants at the end
/// rather than inserting or reordering them.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum RowKind {
    /// Plain row with no blob; the default kind and what
    /// [`CheckpointRow::new`] produces.
    #[default]
    StateEntry,
    /// Row carrying an output blob; set by [`CheckpointRow::with_output`].
    CompensationCompletion,
    /// Row carrying a cursor blob; set by [`CheckpointRow::with_cursor`].
    StepCursor,
}
/// One persisted checkpoint record belonging to a workflow run.
///
/// Rows are written with [`CheckpointStore::append`] and read back with
/// [`CheckpointStore::load_run`]. Field declaration order determines the
/// positional postcard encoding, so do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct CheckpointRow {
    /// Position of this row within its run. Rows are returned ordered by it,
    /// and a sequence may appear only once per workflow id.
    pub sequence: u64,
    /// State label this row refers to (e.g. `"Process"`, `"Done"` in the tests).
    pub state: String,
    /// Identifier of the task associated with this row.
    pub task_id: String,
    /// Optional payload. Holds an output blob for
    /// [`RowKind::CompensationCompletion`] rows or a cursor blob for
    /// [`RowKind::StepCursor`] rows. It is `None` for rows built by `new` alone.
    pub output_blob: Option<Vec<u8>>,
    /// Kind of record this row represents; see [`RowKind`].
    pub kind: RowKind,
    /// Workflow version stamped on the row. It is `0` unless set via
    /// [`CheckpointRow::with_workflow_version`].
    pub workflow_version: u32,
}
impl CheckpointRow {
    /// Creates a [`RowKind::StateEntry`] row with no blob and `workflow_version` 0.
    ///
    /// `state` and `task_id` accept anything convertible into `String`.
    #[must_use]
    pub fn new(sequence: u64, state: impl Into<String>, task_id: impl Into<String>) -> Self {
        Self {
            sequence,
            state: state.into(),
            task_id: task_id.into(),
            output_blob: None,
            kind: RowKind::StateEntry,
            workflow_version: 0,
        }
    }

    /// Attaches `output_blob` and marks the row as [`RowKind::CompensationCompletion`].
    ///
    /// This overwrites any blob and kind set by an earlier [`Self::with_cursor`]
    /// call, so the last blob builder wins. An empty vector is stored as
    /// `Some(vec![])`, not `None`.
    #[must_use = "builder methods consume the row and return the updated one"]
    pub fn with_output(mut self, output_blob: Vec<u8>) -> Self {
        self.output_blob = Some(output_blob);
        self.kind = RowKind::CompensationCompletion;
        self
    }

    /// Attaches `cursor_blob` (stored in `output_blob`) and marks the row as
    /// [`RowKind::StepCursor`].
    ///
    /// This overwrites any blob and kind set by an earlier [`Self::with_output`]
    /// call, so the last blob builder wins.
    #[must_use = "builder methods consume the row and return the updated one"]
    pub fn with_cursor(mut self, cursor_blob: Vec<u8>) -> Self {
        self.output_blob = Some(cursor_blob);
        self.kind = RowKind::StepCursor;
        self
    }

    /// Sets the workflow version stamped on this row.
    ///
    /// This leaves `kind` and `output_blob` untouched, so it composes with
    /// the blob builders in any order.
    #[must_use = "builder methods consume the row and return the updated one"]
    pub fn with_workflow_version(mut self, version: u32) -> Self {
        self.workflow_version = version;
        self
    }
}
/// Persistence backend for workflow checkpoint rows, keyed by workflow id.
///
/// The `Send + Sync + 'static` bounds let a store be shared across tasks. The
/// tests erase it to `Arc<dyn CheckpointStore>`. The `#[checkpoint_store]`
/// attribute (from `cano_macros`) is applied to both the trait and its impls.
/// Presumably it rewrites the `async fn`s into a dyn-compatible form; TODO
/// confirm against the macro.
#[checkpoint_store]
pub trait CheckpointStore: Send + Sync + 'static {
    /// Appends `row` to the run identified by `workflow_id`.
    ///
    /// The tests in this module expect some behavior for duplicates. A row whose
    /// `sequence` already exists for the same `workflow_id` must be rejected with
    /// a `checkpoint_store`-category [`CanoError`], leaving the stored row intact.
    /// The same sequence under a different workflow id is accepted.
    async fn append(&self, workflow_id: &str, row: CheckpointRow) -> Result<(), CanoError>;
    /// Returns every row recorded for `workflow_id`, sorted ascending by
    /// `sequence` regardless of append order.
    ///
    /// An unknown id yields an empty `Vec`, not an error.
    async fn load_run(&self, workflow_id: &str) -> Result<Vec<CheckpointRow>, CanoError>;
    /// Removes all rows for `workflow_id` without affecting other ids.
    ///
    /// Clearing an id that was never written succeeds.
    async fn clear(&self, workflow_id: &str) -> Result<(), CanoError>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// Minimal in-memory [`CheckpointStore`] used to exercise the trait contract.
    ///
    /// It keeps rows per workflow id, rejects duplicate sequences, and
    /// returns rows sorted by sequence.
    #[derive(Default)]
    struct InMemoryStore(Mutex<HashMap<String, Vec<CheckpointRow>>>);

    #[checkpoint_store]
    impl CheckpointStore for InMemoryStore {
        async fn append(&self, workflow_id: &str, row: CheckpointRow) -> Result<(), CanoError> {
            let mut runs = self.0.lock().unwrap();
            let rows = runs.entry(workflow_id.to_string()).or_default();
            // A sequence may appear at most once per run; reject instead of overwriting.
            if rows.iter().any(|r| r.sequence == row.sequence) {
                return Err(CanoError::checkpoint_store(format!(
                    "checkpoint conflict: {workflow_id:?} already has sequence {}",
                    row.sequence
                )));
            }
            rows.push(row);
            Ok(())
        }

        async fn load_run(&self, workflow_id: &str) -> Result<Vec<CheckpointRow>, CanoError> {
            // The guard is a temporary of this statement, so the lock is released
            // before sorting. Insertion order need not match sequence order.
            let mut rows = self
                .0
                .lock()
                .unwrap()
                .get(workflow_id)
                .cloned()
                .unwrap_or_default();
            rows.sort_by_key(|r| r.sequence);
            Ok(rows)
        }

        async fn clear(&self, workflow_id: &str) -> Result<(), CanoError> {
            self.0.lock().unwrap().remove(workflow_id);
            Ok(())
        }
    }

    // Compile-time check: the trait can be type-erased behind `Arc<dyn _>`.
    #[test]
    fn checkpoint_store_is_dyn_compatible() {
        let _erased: std::sync::Arc<dyn CheckpointStore> =
            std::sync::Arc::new(InMemoryStore::default());
    }

    #[test]
    fn checkpoint_row_constructors() {
        let bare = CheckpointRow::new(3, "Process", "worker");
        assert_eq!(bare.sequence, 3);
        assert_eq!(bare.state, "Process");
        assert_eq!(bare.task_id, "worker");
        assert_eq!(bare.output_blob, None);
        assert_eq!(bare.kind, RowKind::StateEntry);
        let carried = CheckpointRow::new(4, "Done", "worker").with_output(vec![1, 2, 3]);
        assert_eq!(carried.sequence, 4);
        assert_eq!(carried.output_blob.as_deref(), Some(&[1u8, 2, 3][..]));
        assert_eq!(carried.kind, RowKind::CompensationCompletion);
        let cursor = CheckpointRow::new(5, "Step", "stepper").with_cursor(vec![9, 8, 7]);
        assert_eq!(cursor.sequence, 5);
        assert_eq!(cursor.output_blob.as_deref(), Some(&[9u8, 8, 7][..]));
        assert_eq!(cursor.kind, RowKind::StepCursor);
    }

    #[tokio::test]
    async fn trait_roundtrip_append_load_clear() {
        let store = InMemoryStore::default();
        store
            .append("run", CheckpointRow::new(0, "A", "t0"))
            .await
            .unwrap();
        store
            .append("run", CheckpointRow::new(1, "B", "t1"))
            .await
            .unwrap();
        store
            .append("run", CheckpointRow::new(2, "C", "t2").with_output(vec![9]))
            .await
            .unwrap();
        let rows = store.load_run("run").await.unwrap();
        assert_eq!(rows.len(), 3);
        assert_eq!(
            rows.iter().map(|r| r.sequence).collect::<Vec<_>>(),
            vec![0, 1, 2]
        );
        assert_eq!(rows[2].output_blob.as_deref(), Some(&[9u8][..]));
        store.clear("run").await.unwrap();
        assert!(store.load_run("run").await.unwrap().is_empty());
        // Clearing an id that was never written is not an error.
        store.clear("never-existed").await.unwrap();
    }

    #[tokio::test]
    async fn load_run_unknown_id_is_empty() {
        let store = InMemoryStore::default();
        assert!(store.load_run("nope").await.unwrap().is_empty());
    }

    #[test]
    fn checkpoint_row_default_workflow_version_is_zero() {
        let row = CheckpointRow::new(0, "Start", "task");
        assert_eq!(row.workflow_version, 0);
    }

    #[test]
    fn checkpoint_row_with_workflow_version_sets_field() {
        let row = CheckpointRow::new(0, "Start", "task").with_workflow_version(42);
        assert_eq!(row.workflow_version, 42);
    }

    #[tokio::test]
    async fn append_rejects_duplicate_sequence() {
        let store = InMemoryStore::default();
        store
            .append("run", CheckpointRow::new(0, "A", "t0"))
            .await
            .unwrap();
        let err = store
            .append("run", CheckpointRow::new(0, "A-again", "t0"))
            .await
            .expect_err("duplicate sequence must be rejected");
        assert_eq!(err.category(), "checkpoint_store");
        store
            .append("run", CheckpointRow::new(1, "B", "t1"))
            .await
            .unwrap();
        // The rejected duplicate must not have replaced the original row.
        let rows = store.load_run("run").await.unwrap();
        assert_eq!(
            rows.iter()
                .map(|r| (r.sequence, r.state.as_str()))
                .collect::<Vec<_>>(),
            vec![(0, "A"), (1, "B")]
        );
        // Sequence uniqueness is per workflow id, not global.
        store
            .append("other", CheckpointRow::new(0, "A", "t0"))
            .await
            .unwrap();
    }

    #[test]
    fn checkpoint_row_json_roundtrip_preserves_all_fields() {
        let rows = [
            CheckpointRow::new(0, "Start", "t0"),
            CheckpointRow::new(1, "Pay", "charge").with_output(vec![1, 2, 3]),
            CheckpointRow::new(2, "Step", "stepper").with_cursor(vec![]),
            CheckpointRow::new(3, "Start", "t0").with_workflow_version(99),
        ];
        for row in rows {
            let bytes = serde_json::to_vec(&row).expect("serialize");
            let back: CheckpointRow = serde_json::from_slice(&bytes).expect("deserialize");
            assert_eq!(back, row, "JSON round-trip must preserve every field");
        }
    }

    #[cfg(feature = "recovery")]
    #[test]
    fn checkpoint_row_postcard_roundtrip_preserves_all_fields() {
        let rows = [
            CheckpointRow::new(0, "Start", "t0"),
            CheckpointRow::new(1, "Pay", "charge").with_output(vec![1, 2, 3]),
            CheckpointRow::new(2, "Step", "stepper").with_cursor(vec![9, 8]),
            CheckpointRow::new(3, "Start", "t0").with_workflow_version(99),
        ];
        for row in rows {
            let bytes = postcard::to_stdvec(&row).expect("serialize");
            let back: CheckpointRow = postcard::from_bytes(&bytes).expect("deserialize");
            assert_eq!(back, row, "postcard round-trip must preserve every field");
        }
    }

    #[test]
    fn rowkind_default_is_state_entry() {
        assert_eq!(RowKind::default(), RowKind::StateEntry);
    }

    #[test]
    fn with_output_then_with_workflow_version_is_order_independent() {
        let a = CheckpointRow::new(0, "S", "t")
            .with_output(vec![7])
            .with_workflow_version(5);
        let b = CheckpointRow::new(0, "S", "t")
            .with_workflow_version(5)
            .with_output(vec![7]);
        assert_eq!(a, b);
        assert_eq!(a.kind, RowKind::CompensationCompletion);
        assert_eq!(a.workflow_version, 5);
        assert_eq!(a.output_blob.as_deref(), Some(&[7u8][..]));
    }

    #[test]
    fn last_blob_builder_wins() {
        let cursor_wins = CheckpointRow::new(0, "S", "t")
            .with_output(vec![1])
            .with_cursor(vec![2]);
        assert_eq!(cursor_wins.kind, RowKind::StepCursor);
        assert_eq!(cursor_wins.output_blob.as_deref(), Some(&[2u8][..]));
        let output_wins = CheckpointRow::new(0, "S", "t")
            .with_cursor(vec![1])
            .with_output(vec![2]);
        assert_eq!(output_wins.kind, RowKind::CompensationCompletion);
        assert_eq!(output_wins.output_blob.as_deref(), Some(&[2u8][..]));
    }

    #[test]
    fn with_output_empty_blob_is_some_not_none() {
        let row = CheckpointRow::new(0, "S", "t").with_output(vec![]);
        assert_eq!(row.output_blob, Some(vec![]));
        assert_eq!(row.kind, RowKind::CompensationCompletion);
    }

    #[tokio::test]
    async fn load_run_returns_rows_sorted_even_when_appended_out_of_order() {
        let store = InMemoryStore::default();
        for seq in [2u64, 0, 1] {
            store
                .append("run", CheckpointRow::new(seq, "S", "t"))
                .await
                .unwrap();
        }
        let rows = store.load_run("run").await.unwrap();
        assert_eq!(
            rows.iter().map(|r| r.sequence).collect::<Vec<_>>(),
            vec![0, 1, 2],
            "load_run must return rows sorted ascending by sequence"
        );
    }

    #[tokio::test]
    async fn clear_isolates_by_workflow_id() {
        let store = InMemoryStore::default();
        store
            .append("a", CheckpointRow::new(0, "S", "t"))
            .await
            .unwrap();
        store
            .append("b", CheckpointRow::new(0, "S", "t"))
            .await
            .unwrap();
        store.clear("a").await.unwrap();
        assert!(store.load_run("a").await.unwrap().is_empty());
        assert_eq!(
            store.load_run("b").await.unwrap().len(),
            1,
            "clearing one id must not affect another"
        );
    }

    // Appends arrive through a shared `Arc<dyn CheckpointStore>` from spawned tasks.
    #[tokio::test]
    async fn shared_store_accepts_appends_from_many_tasks() {
        let store: std::sync::Arc<dyn CheckpointStore> =
            std::sync::Arc::new(InMemoryStore::default());
        let mut handles = Vec::new();
        for seq in 0..20u64 {
            let s = std::sync::Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                s.append("run", CheckpointRow::new(seq, "S", "t"))
                    .await
                    .unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let rows = store.load_run("run").await.unwrap();
        assert_eq!(
            rows.iter().map(|r| r.sequence).collect::<Vec<_>>(),
            (0..20).collect::<Vec<_>>()
        );
    }
}