use std::ops::RangeInclusive;
use std::path::Path;
use std::sync::Arc;
use redb::{Database, ReadableDatabase, TableDefinition};
use super::{CheckpointRow, CheckpointStore};
use crate::error::CanoError;
use cano_macros::checkpoint_store;
/// The single redb table that holds every checkpoint row written by this store.
///
/// Keys are `(workflow_id, sequence)` pairs; values are the postcard-encoded
/// bytes of a `StoredRow` (see `append`). The table name on disk is
/// `"cano_checkpoints"`.
const CHECKPOINTS: TableDefinition<(&str, u64), &[u8]> = TableDefinition::new("cano_checkpoints");
/// Value payload stored for one checkpoint row.
///
/// This is what `append` encodes with postcard and what `load_run` tries to
/// decode first. The row's `sequence` is not part of the payload: it is carried
/// by the table key instead.
///
/// NOTE(review): postcard's encoding is presumably positional, which would make
/// the field order below part of the on-disk format — confirm before reordering
/// or inserting fields.
#[derive(serde::Serialize, serde::Deserialize)]
struct StoredRow {
// Copied verbatim from `CheckpointRow::state`.
state: String,
// Copied verbatim from `CheckpointRow::task_id`.
task_id: String,
// Copied verbatim from `CheckpointRow::output_blob`; `None` when the row carries no blob.
output_blob: Option<Vec<u8>>,
// Copied verbatim from `CheckpointRow::kind`.
kind: super::RowKind,
// Copied verbatim from `CheckpointRow::workflow_version`; absent in the legacy `StoredRowV0` layout.
workflow_version: u32,
}
/// Legacy payload layout: the same fields as `StoredRow` minus `workflow_version`.
///
/// It only derives `Deserialize` because this file never writes it; `load_run`
/// falls back to it when a value does not decode as `StoredRow`, and reports
/// such rows with `workflow_version` set to `0`.
#[derive(serde::Deserialize)]
struct StoredRowV0 {
// Same meaning as `StoredRow::state`.
state: String,
// Same meaning as `StoredRow::task_id`.
task_id: String,
// Same meaning as `StoredRow::output_blob`.
output_blob: Option<Vec<u8>>,
// Same meaning as `StoredRow::kind`.
kind: super::RowKind,
}
/// Builds the inclusive key range that spans every sequence number of one workflow.
///
/// The range starts at `(workflow_id, u64::MIN)` and ends at
/// `(workflow_id, u64::MAX)`, matching the `(workflow_id, sequence)` key shape
/// of the checkpoint table. Both bounds borrow `workflow_id`.
fn workflow_range(workflow_id: &str) -> RangeInclusive<(&str, u64)> {
    let first_key = (workflow_id, u64::MIN);
    let last_key = (workflow_id, u64::MAX);
    RangeInclusive::new(first_key, last_key)
}
/// Converts any displayable error into a `CanoError::CheckpointStore`.
///
/// The resulting message is the error's `Display` output prefixed with
/// `"redb: "`. Throughout this file it is passed to `map_err` on redb calls.
fn redb_err(e: impl std::fmt::Display) -> CanoError {
    let message = format!("redb: {e}");
    CanoError::CheckpointStore(message)
}
/// A `CheckpointStore` implementation that persists rows in a redb `Database`.
///
/// The database handle sits behind an `Arc`, so cloning the store yields another
/// handle to the same underlying database rather than opening a second one.
#[derive(Clone)]
pub struct RedbCheckpointStore {
// Shared handle to the open redb database.
db: Arc<Database>,
}
impl RedbCheckpointStore {
    /// Opens the checkpoint database at `path` via `Database::create` and wraps it in a store.
    ///
    /// Before returning, the checkpoint table is opened once inside a write
    /// transaction that is then committed, so the table has been touched by a
    /// committed write before any other operation runs against the store.
    ///
    /// # Errors
    ///
    /// Returns `CanoError::CheckpointStore` (message prefixed with `"redb: "`)
    /// if the database cannot be created/opened, the transaction cannot be
    /// started, the table cannot be opened, or the commit fails.
    pub fn new(path: impl AsRef<Path>) -> Result<Self, CanoError> {
        let database = Database::create(path).map_err(redb_err)?;

        let init_tx = database.begin_write().map_err(redb_err)?;
        // The table handle is discarded right away: only the act of opening it
        // matters here, and it must be gone before the transaction is committed.
        init_tx.open_table(CHECKPOINTS).map(drop).map_err(redb_err)?;
        init_tx.commit().map_err(redb_err)?;

        Ok(Self {
            db: Arc::new(database),
        })
    }
}
#[checkpoint_store]
impl CheckpointStore for RedbCheckpointStore {
    /// Stores `row` for `workflow_id` under the key `(workflow_id, row.sequence)`.
    ///
    /// Every field except `sequence` is postcard-encoded into the value; the
    /// sequence is carried by the key only.
    ///
    /// # Errors
    ///
    /// Returns `CanoError::CheckpointStore` when:
    /// - the row cannot be encoded (`"encode checkpoint row: ..."`),
    /// - any redb operation fails (`"redb: ..."`), or
    /// - the key was already occupied (`"checkpoint conflict: ..."`). In that
    ///   case the function returns before committing, so the write transaction
    ///   is dropped uncommitted.
    async fn append(&self, workflow_id: &str, row: CheckpointRow) -> Result<(), CanoError> {
        // Split the row: `sequence` becomes part of the key, the rest is the payload.
        let CheckpointRow {
            sequence,
            state,
            task_id,
            output_blob,
            kind,
            workflow_version,
        } = row;
        let payload = StoredRow {
            state,
            task_id,
            output_blob,
            kind,
            workflow_version,
        };
        let encoded = postcard::to_stdvec(&payload)
            .map_err(|e| CanoError::CheckpointStore(format!("encode checkpoint row: {e}")))?;

        let tx = self.db.begin_write().map_err(redb_err)?;
        let mut table = tx.open_table(CHECKPOINTS).map_err(redb_err)?;

        // `insert` hands back the previous value, if any; a previous value means
        // this (workflow, sequence) slot was already taken.
        let slot_taken = table
            .insert((workflow_id, sequence), encoded.as_slice())
            .map_err(redb_err)?
            .is_some();
        if slot_taken {
            return Err(CanoError::CheckpointStore(format!(
                "checkpoint conflict: workflow {workflow_id:?} already has a row at \
                 sequence {sequence}; resume the existing run or clear it before starting \
                 a new one"
            )));
        }

        // The table borrows the transaction and has to go before the commit.
        drop(table);
        tx.commit().map_err(redb_err)
    }

    /// Loads every stored row whose key falls in `workflow_range(workflow_id)`.
    ///
    /// Rows are returned in the order the table's range iterator yields them.
    /// Each value is decoded as `StoredRow` first; if that fails, it is decoded
    /// as the legacy `StoredRowV0` and reported with `workflow_version` `0`.
    /// An unknown `workflow_id` simply produces an empty vector.
    ///
    /// # Errors
    ///
    /// Returns `CanoError::CheckpointStore` if a redb operation fails
    /// (`"redb: ..."`) or if a value decodes under neither layout
    /// (`"decode checkpoint row: ..."`, carrying the error from the `StoredRow`
    /// attempt). The first failing row aborts the whole load.
    async fn load_run(&self, workflow_id: &str) -> Result<Vec<CheckpointRow>, CanoError> {
        let tx = self.db.begin_read().map_err(redb_err)?;
        let table = tx.open_table(CHECKPOINTS).map_err(redb_err)?;
        let entries = table
            .range(workflow_range(workflow_id))
            .map_err(redb_err)?;

        let rows = entries
            .map(|entry| -> Result<CheckpointRow, CanoError> {
                let (key, value) = entry.map_err(redb_err)?;
                // Second tuple element of the key is the row's sequence number.
                let sequence = key.value().1;
                let bytes = value.value();

                postcard::from_bytes::<StoredRow>(bytes)
                    .map(|current| CheckpointRow {
                        sequence,
                        state: current.state,
                        task_id: current.task_id,
                        output_blob: current.output_blob,
                        kind: current.kind,
                        workflow_version: current.workflow_version,
                    })
                    .or_else(|new_err| {
                        // Current layout did not fit: retry with the legacy one,
                        // but keep reporting the original error if that fails too.
                        postcard::from_bytes::<StoredRowV0>(bytes)
                            .map(|legacy| CheckpointRow {
                                sequence,
                                state: legacy.state,
                                task_id: legacy.task_id,
                                output_blob: legacy.output_blob,
                                kind: legacy.kind,
                                workflow_version: 0,
                            })
                            .map_err(|_| {
                                CanoError::CheckpointStore(format!(
                                    "decode checkpoint row: {new_err}"
                                ))
                            })
                    })
            })
            .collect::<Result<Vec<CheckpointRow>, CanoError>>()?;

        Ok(rows)
    }

    /// Removes every row whose key falls in `workflow_range(workflow_id)`.
    ///
    /// The removal is done with `retain_in` and a predicate that keeps nothing,
    /// inside a single write transaction. Clearing an id that has no rows is
    /// not an error.
    ///
    /// # Errors
    ///
    /// Returns `CanoError::CheckpointStore` (`"redb: ..."`) if any redb
    /// operation, including the commit, fails.
    async fn clear(&self, workflow_id: &str) -> Result<(), CanoError> {
        let tx = self.db.begin_write().map_err(redb_err)?;
        let mut table = tx.open_table(CHECKPOINTS).map_err(redb_err)?;

        table
            .retain_in(workflow_range(workflow_id), |_, _| false)
            .map_err(redb_err)?;

        // Release the table's borrow of the transaction before committing.
        drop(table);
        tx.commit().map_err(redb_err)
    }
}
// Tests for `RedbCheckpointStore`. Each test creates its own database file
// inside a fresh temporary directory, so tests do not share state.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
// Basic lifecycle: appended rows come back from `load_run`, `clear` removes
// them, and clearing an id that was never written succeeds.
#[tokio::test]
async fn append_load_clear_roundtrip() {
let dir = tempdir().unwrap();
let store = RedbCheckpointStore::new(dir.path().join("ckpt.redb")).unwrap();
store
.append("run", CheckpointRow::new(0, "A", "t0"))
.await
.unwrap();
store
.append("run", CheckpointRow::new(1, "B", "t1"))
.await
.unwrap();
store
.append(
"run",
CheckpointRow::new(2, "C", "t2").with_output(vec![7, 8, 9]),
)
.await
.unwrap();
let rows = store.load_run("run").await.unwrap();
assert_eq!(
rows.iter().map(|r| r.sequence).collect::<Vec<_>>(),
vec![0, 1, 2]
);
assert_eq!(rows[0].state, "A");
assert_eq!(rows[2].task_id, "t2");
assert_eq!(rows[2].output_blob.as_deref(), Some(&[7u8, 8, 9][..]));
store.clear("run").await.unwrap();
assert!(store.load_run("run").await.unwrap().is_empty());
// Clearing a workflow id with no rows must not fail.
store.clear("missing").await.unwrap();
}
// Rows appended out of order are still returned in ascending sequence order.
#[tokio::test]
async fn rows_ordered_by_sequence_regardless_of_insert_order() {
let dir = tempdir().unwrap();
let store = RedbCheckpointStore::new(dir.path().join("ckpt.redb")).unwrap();
for seq in [5u64, 1, 9, 0, 3] {
store
.append("run", CheckpointRow::new(seq, format!("S{seq}"), "t"))
.await
.unwrap();
}
let rows = store.load_run("run").await.unwrap();
assert_eq!(
rows.iter().map(|r| r.sequence).collect::<Vec<_>>(),
vec![0, 1, 3, 5, 9]
);
}
// Rows written through one store instance are visible after that instance is
// dropped and a new store is opened on the same file.
#[tokio::test]
async fn rows_survive_reopen() {
let dir = tempdir().unwrap();
let path = dir.path().join("ckpt.redb");
{
// Inner scope: the first store is dropped at the closing brace.
let store = RedbCheckpointStore::new(&path).unwrap();
store
.append("run", CheckpointRow::new(0, "A", "t0"))
.await
.unwrap();
store
.append("run", CheckpointRow::new(1, "B", "t1"))
.await
.unwrap();
}
let reopened = RedbCheckpointStore::new(&path).unwrap();
let rows = reopened.load_run("run").await.unwrap();
assert_eq!(
rows.iter().map(|r| r.sequence).collect::<Vec<_>>(),
vec![0, 1]
);
assert_eq!(rows[1].state, "B");
}
// Clearing workflow "a" must leave workflow "b" untouched.
#[tokio::test]
async fn clear_isolates_workflows() {
let dir = tempdir().unwrap();
let store = RedbCheckpointStore::new(dir.path().join("ckpt.redb")).unwrap();
store
.append("a", CheckpointRow::new(0, "A0", "t"))
.await
.unwrap();
store
.append("a", CheckpointRow::new(1, "A1", "t"))
.await
.unwrap();
store
.append("b", CheckpointRow::new(0, "B0", "t"))
.await
.unwrap();
store.clear("a").await.unwrap();
assert!(store.load_run("a").await.unwrap().is_empty());
let b = store.load_run("b").await.unwrap();
assert_eq!(b.len(), 1);
assert_eq!(b[0].state, "B0");
}
// Loading an id that was never written yields an empty list, not an error.
#[tokio::test]
async fn load_run_unknown_id_is_empty() {
let dir = tempdir().unwrap();
let store = RedbCheckpointStore::new(dir.path().join("ckpt.redb")).unwrap();
assert!(store.load_run("nope").await.unwrap().is_empty());
}
// A second append at an occupied (workflow_id, sequence) is rejected with a
// "conflict" error, the original row is kept, and the store remains usable.
#[tokio::test]
async fn append_rejects_duplicate_sequence() {
let dir = tempdir().unwrap();
let store = RedbCheckpointStore::new(dir.path().join("ckpt.redb")).unwrap();
store
.append("run", CheckpointRow::new(0, "A", "t0"))
.await
.unwrap();
let err = store
.append("run", CheckpointRow::new(0, "A-again", "t0"))
.await
.expect_err("duplicate (workflow_id, sequence) must be rejected");
assert_eq!(err.category(), "checkpoint_store");
assert!(
err.message().contains("conflict"),
"unexpected message: {err}"
);
// A later, non-conflicting append still works after the rejection.
store
.append("run", CheckpointRow::new(1, "B", "t1"))
.await
.unwrap();
let rows = store.load_run("run").await.unwrap();
// Sequence 0 must still hold the first write ("A"), not "A-again".
assert_eq!(
rows.iter()
.map(|r| (r.sequence, r.state.clone()))
.collect::<Vec<_>>(),
vec![(0, "A".to_string()), (1, "B".to_string())]
);
// The same sequence under a different workflow id is not a conflict.
store
.append("other", CheckpointRow::new(0, "A", "t0"))
.await
.unwrap();
}
// Sixteen tasks, each appending twelve rows to its own workflow id, run on a
// multi-threaded runtime; every run must end up with exactly its own rows.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_appends_distinct_ids_stay_isolated_and_monotonic() {
let dir = tempdir().unwrap();
let store = Arc::new(RedbCheckpointStore::new(dir.path().join("ckpt.redb")).unwrap());
const RUNS: u64 = 16;
const ROWS_PER_RUN: u64 = 12;
let mut handles = Vec::new();
for r in 0..RUNS {
let store = Arc::clone(&store);
handles.push(tokio::spawn(async move {
let id = format!("run-{r}");
for s in 0..ROWS_PER_RUN {
store
.append(&id, CheckpointRow::new(s, format!("S{r}-{s}"), "t"))
.await
.unwrap();
// Yield between appends to give the other tasks a chance to interleave.
tokio::task::yield_now().await;
}
}));
}
for h in handles {
h.await.unwrap();
}
for r in 0..RUNS {
let rows = store.load_run(&format!("run-{r}")).await.unwrap();
assert_eq!(rows.len() as u64, ROWS_PER_RUN, "run {r} row count");
assert_eq!(
rows.iter().map(|r| r.sequence).collect::<Vec<_>>(),
(0..ROWS_PER_RUN).collect::<Vec<_>>(),
"run {r} sequences"
);
// The state string embeds the run number, so a row from another run would fail here.
assert!(
rows.iter()
.enumerate()
.all(|(i, row)| row.state == format!("S{r}-{i}")),
"run {r} rows belong only to that run"
);
}
}
// Thirty-two tasks append distinct sequences to the same workflow id
// concurrently; all of them must be stored, in order.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_appends_same_id_distinct_sequences_all_land() {
let dir = tempdir().unwrap();
let store = Arc::new(RedbCheckpointStore::new(dir.path().join("ckpt.redb")).unwrap());
const N: u64 = 32;
let mut handles = Vec::new();
for s in 0..N {
let store = Arc::clone(&store);
handles.push(tokio::spawn(async move {
store
.append("run", CheckpointRow::new(s, format!("S{s}"), "t"))
.await
.unwrap();
}));
}
for h in handles {
h.await.unwrap();
}
let rows = store.load_run("run").await.unwrap();
assert_eq!(
rows.iter().map(|r| r.sequence).collect::<Vec<_>>(),
(0..N).collect::<Vec<_>>(),
"every distinct sequence landed exactly once, in order"
);
}
// A 5 MiB output blob is stored and read back unchanged.
#[tokio::test]
async fn large_output_blob_roundtrips() {
let dir = tempdir().unwrap();
let store = RedbCheckpointStore::new(dir.path().join("ckpt.redb")).unwrap();
// Non-constant fill pattern (i mod 251) so the content is not a single repeated byte.
let blob: Vec<u8> = (0..5 * 1024 * 1024usize).map(|i| (i % 251) as u8).collect();
store
.append(
"run",
CheckpointRow::new(0, "Big", "t").with_output(blob.clone()),
)
.await
.unwrap();
let rows = store.load_run("run").await.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].output_blob.as_deref(), Some(blob.as_slice()));
}
// A value that is not a valid encoded row makes `load_run` return a
// "decode" error instead of panicking.
#[tokio::test]
async fn corrupted_stored_row_is_a_decode_error_not_a_panic() {
let dir = tempdir().unwrap();
let store = RedbCheckpointStore::new(dir.path().join("ckpt.redb")).unwrap();
store
.append("run", CheckpointRow::new(0, "A", "t0"))
.await
.unwrap();
{
// Bypass `append` and write raw garbage bytes directly into the table at sequence 1.
let tx = store.db.begin_write().unwrap();
{
let mut table = tx.open_table(CHECKPOINTS).unwrap();
table
.insert(("run", 1u64), [0xFFu8, 0xFF, 0xFF, 0xFF].as_slice())
.unwrap();
}
tx.commit().unwrap();
}
let err = store
.load_run("run")
.await
.expect_err("a corrupted row must surface as an error, not panic");
assert_eq!(err.category(), "checkpoint_store");
assert!(
err.message().contains("decode"),
"unexpected message: {err}"
);
}
// Rows built with `new`, `with_output` and `with_cursor` are expected to come
// back with kinds `StateEntry`, `CompensationCompletion` and `StepCursor`
// respectively, with their blobs intact.
#[tokio::test]
async fn all_row_kinds_roundtrip() {
let dir = tempdir().unwrap();
let store = RedbCheckpointStore::new(dir.path().join("ckpt.redb")).unwrap();
store
.append("run", CheckpointRow::new(0, "A", "task-a"))
.await
.unwrap();
store
.append(
"run",
CheckpointRow::new(1, "B", "task-b").with_output(vec![1, 2, 3]),
)
.await
.unwrap();
store
.append(
"run",
CheckpointRow::new(2, "C", "task-c").with_cursor(vec![4, 5]),
)
.await
.unwrap();
let rows = store.load_run("run").await.unwrap();
assert_eq!(rows.len(), 3);
assert_eq!(rows[0].sequence, 0);
assert_eq!(rows[0].state, "A");
assert_eq!(rows[0].kind, super::super::RowKind::StateEntry);
assert_eq!(rows[0].output_blob, None);
assert_eq!(rows[1].sequence, 1);
assert_eq!(rows[1].state, "B");
assert_eq!(rows[1].kind, super::super::RowKind::CompensationCompletion);
assert_eq!(rows[1].output_blob.as_deref(), Some(&[1u8, 2, 3][..]));
assert_eq!(rows[2].sequence, 2);
assert_eq!(rows[2].state, "C");
assert_eq!(rows[2].kind, super::super::RowKind::StepCursor);
assert_eq!(rows[2].output_blob.as_deref(), Some(&[4u8, 5][..]));
}
// A non-zero `workflow_version` is stored and read back unchanged.
#[tokio::test]
async fn workflow_version_roundtrips_through_redb_store() {
let dir = tempdir().unwrap();
let store = RedbCheckpointStore::new(dir.path().join("ckpt.redb")).unwrap();
let row = CheckpointRow::new(0, "Start", "task").with_workflow_version(7);
store.append("wf-roundtrip", row).await.unwrap();
let loaded = store.load_run("wf-roundtrip").await.unwrap();
assert_eq!(loaded.len(), 1);
assert_eq!(loaded[0].workflow_version, 7);
}
// A value written in the legacy layout (no `workflow_version` field) is still
// readable and is reported with `workflow_version == 0`.
#[tokio::test]
async fn legacy_stored_row_decodes_as_workflow_version_zero() {
// Serializable mirror of the legacy layout; the production `StoredRowV0` is deserialize-only.
#[derive(serde::Serialize)]
struct LegacyWrite {
state: String,
task_id: String,
output_blob: Option<Vec<u8>>,
kind: super::super::RowKind,
}
let dir = tempdir().unwrap();
let path = dir.path().join("ckpt.redb");
{
// Write the legacy-encoded value straight into the table, without going through the store.
let db = Database::create(&path).unwrap();
let tx = db.begin_write().unwrap();
{
let mut table = tx.open_table(CHECKPOINTS).unwrap();
let legacy = LegacyWrite {
state: "Start".into(),
task_id: "task".into(),
output_blob: None,
kind: super::super::RowKind::StateEntry,
};
let bytes = postcard::to_stdvec(&legacy).unwrap();
table.insert(("wf-legacy", 0u64), bytes.as_slice()).unwrap();
}
tx.commit().unwrap();
}
// The raw database handle was dropped above; now open the same file through the store.
let store = RedbCheckpointStore::new(&path).unwrap();
let rows = store.load_run("wf-legacy").await.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].sequence, 0);
assert_eq!(rows[0].state, "Start");
assert_eq!(rows[0].task_id, "task");
assert_eq!(rows[0].kind, super::super::RowKind::StateEntry);
assert_eq!(rows[0].workflow_version, 0);
}
}