use std::collections::HashMap;
use std::hash::Hash;
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use futures_util::FutureExt;
use crate::error::CanoError;
use crate::recovery::RowKind;
use crate::saga::{CompensationEntry, ErasedCompensatable};
use crate::task::{TaskResult, run_with_retries};
use super::{Workflow, notify_observers, panic_payload_message};
#[cfg(feature = "tracing")]
use tracing::{debug, info, info_span};
impl<TState, TResourceKey> Workflow<TState, TResourceKey>
where
TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
TResourceKey: Hash + Eq + Send + Sync + 'static,
{
/// Rolls back recorded compensatable steps after a forward failure.
///
/// Entries are popped from the end of `stack` (so the most recently pushed entry is undone
/// first) and handed to the compensator registered under the entry's `task_id`. Each
/// compensation is bounded by the compensator's `attempt_timeout` when one is configured,
/// and a panic inside it is caught. A missing compensator, a returned error, a timeout or a
/// panic is recorded and the drain moves on to the next entry.
///
/// This function never returns `Ok`:
/// * empty `stack` — `original` is returned unchanged and the checkpoint log is untouched;
/// * every compensation succeeded — the checkpoint log is cleared and `original` is returned;
/// * anything went wrong while draining — `CanoError::compensation_failed` carrying
///   `original` first, followed by the rollback errors in the order they occurred.
pub(super) async fn run_compensations(
    &self,
    workflow_id: Option<&str>,
    mut stack: Vec<CompensationEntry>,
    original: CanoError,
) -> Result<TState, CanoError> {
    // Nothing was recorded, so there is nothing to undo.
    if stack.is_empty() {
        return Err(original);
    }

    // Slot 0 is always the forward failure; rollback problems are appended behind it.
    let mut errors = vec![original];

    while let Some(entry) = stack.pop() {
        let Some(compensator) = self.compensators.get(&entry.task_id) else {
            errors.push(CanoError::workflow(format!(
                "no compensator registered for task {:?} — cannot roll it back",
                entry.task_id
            )));
            continue;
        };

        #[cfg(feature = "tracing")]
        debug!(task_id = %entry.task_id, "compensating");

        let attempt_timeout = compensator.config().attempt_timeout;
        let compensate_fut = compensator.compensate(&self.resources, &entry.output_blob);

        // Apply the per-attempt deadline only when the compensator configured one.
        let bounded = async {
            if let Some(limit) = attempt_timeout {
                match tokio::time::timeout(limit, compensate_fut).await {
                    Ok(finished) => finished,
                    Err(_) => Err(CanoError::timeout(format!(
                        "compensate for {:?} exceeded attempt_timeout {limit:?}",
                        entry.task_id
                    ))),
                }
            } else {
                compensate_fut.await
            }
        };

        // A panicking compensator must not abort the rest of the drain.
        let attempt = AssertUnwindSafe(bounded).catch_unwind().await;

        #[cfg_attr(not(feature = "metrics"), allow(unused_variables))]
        let compensate_ok = match attempt {
            Ok(Ok(())) => true,
            Ok(Err(failure)) => {
                #[cfg(feature = "tracing")]
                tracing::error!(task_id = %entry.task_id, error = %failure, "compensation failed");
                errors.push(failure);
                false
            }
            Err(payload) => {
                let msg = panic_payload_message(&*payload);
                #[cfg(feature = "tracing")]
                tracing::error!(task_id = %entry.task_id, panic = %msg, "compensation panicked");
                errors.push(CanoError::task_execution(format!(
                    "compensate for {:?} panicked: {msg}",
                    entry.task_id
                )));
                false
            }
        };

        #[cfg(feature = "metrics")]
        crate::metrics::compensation_run(compensate_ok);
    }

    // Only the original failure left means every compensation ran cleanly.
    let clean_drain = errors.len() == 1;

    #[cfg(feature = "metrics")]
    crate::metrics::compensation_drain(clean_drain);

    if !clean_drain {
        return Err(CanoError::compensation_failed(errors));
    }

    self.clear_checkpoint_log(workflow_id).await;
    Err(errors.pop().expect("errors has exactly one element"))
}
/// Best-effort removal of the checkpoint log for `workflow_id`.
///
/// Does nothing unless both a checkpoint store and a workflow id are present. A failing
/// `clear` is never propagated: it is reported to metrics / tracing when those features
/// are enabled and otherwise discarded.
pub(super) async fn clear_checkpoint_log(&self, workflow_id: Option<&str>) {
    let (store, wf_id) = match (self.checkpoint_store.as_ref(), workflow_id) {
        (Some(store), Some(id)) => (store, id),
        _ => return,
    };

    let outcome = store.clear(wf_id).await;

    #[cfg(feature = "metrics")]
    crate::metrics::checkpoint_clear(outcome.is_ok());

    #[cfg(feature = "tracing")]
    if let Err(err) = outcome {
        tracing::warn!(workflow_id = %wf_id, error = %err, "failed to clear checkpoint log");
    }
    // Without tracing there is nowhere to report the failure; drop it explicitly.
    #[cfg(not(feature = "tracing"))]
    let _ = outcome;
}
/// Runs one compensatable task under its retry policy and returns the next state together
/// with the byte blob the task produced (presumably the serialized output later handed to
/// `compensate` — confirm against `ErasedCompensatable`).
///
/// Observers are told about the task start and about its success or failure. A panic
/// inside the task is caught and converted into a `CanoError::task_execution` error.
///
/// # Errors
/// * whatever error the task (after retries) returned;
/// * `CanoError::task_execution` when the task panicked;
/// * `CanoError::workflow` when the task returned `TaskResult::Split`, which a
///   compensatable task is not allowed to produce.
pub(super) async fn execute_compensatable_task(
    &self,
    task: Arc<dyn ErasedCompensatable<TState, TResourceKey>>,
    config: Arc<crate::task::TaskConfig>,
) -> Result<(TState, Vec<u8>), CanoError> {
    let observers = self.observer_slice();
    let task_name = task.name();
    if let Some(ref slice) = observers {
        notify_observers(slice, |o| o.on_task_start(task_name.as_ref()));
    }
    let config = Self::config_with_observers(&config, &observers, &task_name);

    #[cfg(feature = "tracing")]
    let task_span = if tracing::enabled!(tracing::Level::INFO) {
        info_span!("compensatable_task_execution")
    } else {
        tracing::Span::none()
    };

    let run_future = async {
        run_with_retries(&config, || {
            let task_clone = task.clone();
            let resources_clone = Arc::clone(&self.resources);
            async move { task_clone.run(&*resources_clone).await }
        })
        .await
    };

    // Catch panics so one misbehaving task cannot unwind through the workflow engine.
    let guarded = AssertUnwindSafe(run_future).catch_unwind();
    // Attach the span to the future instead of holding a `Span::enter()` guard across the
    // `.await`: a held guard keeps the span entered on the thread while the task is
    // suspended, which mis-attributes events from other tasks polled in the meantime.
    #[cfg(feature = "tracing")]
    let guarded = tracing::Instrument::instrument(guarded, task_span);
    let unwind_result = guarded.await;

    let result: Result<(TaskResult<TState>, Vec<u8>), CanoError> = match unwind_result {
        Ok(inner) => inner,
        Err(payload) => {
            let payload_str = panic_payload_message(&*payload);
            #[cfg(feature = "tracing")]
            tracing::error!(panic = %payload_str, "Compensatable task panicked");
            Err(CanoError::task_execution(format!("panic: {payload_str}")))
        }
    };

    let outcome: Result<(TState, Vec<u8>), CanoError> = match result {
        Ok((TaskResult::Single(next_state), blob)) => Ok((next_state, blob)),
        Ok((TaskResult::Split(_), _)) => Err(CanoError::workflow(
            "Compensatable task returned a split result — split states cannot be compensatable",
        )),
        Err(e) => Err(e),
    };

    if let Some(ref slice) = observers {
        match &outcome {
            Ok(_) => notify_observers(slice, |o| o.on_task_success(task_name.as_ref())),
            Err(e) => notify_observers(slice, |o| o.on_task_failure(task_name.as_ref(), e)),
        }
    }
    outcome
}
/// Maps a checkpointed state label back to a state of this workflow.
///
/// The label is compared against the `Debug` rendering of every registered state and
/// every exit state; the first match is cloned and returned, `None` if nothing matches.
fn state_from_label(&self, label: &str) -> Option<TState> {
    let registered = self.states.keys();
    let exits = self.exit_states.iter();
    for candidate in registered.chain(exits) {
        if format!("{candidate:?}") == label {
            return Some(candidate.clone());
        }
    }
    None
}
/// Resumes a previously checkpointed run identified by `workflow_id`.
///
/// The run continues from the state recorded in the checkpoint row with the highest
/// sequence number; the compensation stack and per-state step cursors are rebuilt from
/// the stored rows before execution starts.
///
/// # Errors
/// * `CanoError::configuration` when no checkpoint store is attached;
/// * the cached validation error when the workflow definition is invalid;
/// * `CanoError::checkpoint_store` when loading fails or no rows exist for the id;
/// * `CanoError::workflow_version_mismatch` when the last row was written by a different
///   workflow version;
/// * `CanoError::workflow` when the stored state label matches no registered or exit
///   state, or when the workflow timeout elapses;
/// * any error from resource setup or from the resumed execution itself.
pub async fn resume_from(&self, workflow_id: impl Into<String>) -> Result<TState, CanoError> {
    let workflow_id = workflow_id.into();

    #[cfg(feature = "tracing")]
    let workflow_span = self.tracing_span.clone().unwrap_or_else(|| {
        if tracing::enabled!(tracing::Level::INFO) {
            info_span!("workflow_resume", workflow_id = workflow_id.as_str())
        } else {
            tracing::Span::none()
        }
    });

    let resumed = self.resume_from_inner(workflow_id);
    // Instrument the future rather than holding a `Span::enter()` guard across every
    // `.await` below: a held guard leaves the span entered while this task is suspended
    // and produces incorrect traces for whatever else the thread polls.
    #[cfg(feature = "tracing")]
    let resumed = tracing::Instrument::instrument(resumed, workflow_span);
    resumed.await
}

/// Body of [`Self::resume_from`]: loads the run, rebuilds in-memory state and executes.
async fn resume_from_inner(&self, workflow_id: String) -> Result<TState, CanoError> {
    let store = self.checkpoint_store.clone().ok_or_else(|| {
        CanoError::configuration(
            "resume_from requires a checkpoint store (call with_checkpoint_store)",
        )
    })?;

    // Validation is computed once and cached; a cached failure is cloned for each caller.
    let cached_validation = self.validated.get_or_init(|| self.validate());
    if let Err(e) = cached_validation {
        return Err(e.clone());
    }

    let rows = store.load_run(&workflow_id).await.map_err(|e| {
        CanoError::checkpoint_store(format!("load checkpoint run {workflow_id:?}: {e}"))
    })?;

    // The row with the highest sequence is the resume point.
    let last = rows.iter().max_by_key(|r| r.sequence).ok_or_else(|| {
        CanoError::checkpoint_store(format!(
            "no checkpoint rows for workflow id {workflow_id:?}"
        ))
    })?;

    if last.workflow_version != self.workflow_version {
        return Err(CanoError::workflow_version_mismatch(
            last.workflow_version,
            self.workflow_version,
        ));
    }

    let resume_state = self.state_from_label(&last.state).ok_or_else(|| {
        CanoError::workflow(format!(
            "checkpoint state {:?} is not a registered or exit state of this workflow",
            last.state
        ))
    })?;

    let start_sequence = last.sequence + 1;
    let resume_sequence = last.sequence;

    // Rebuild the compensation stack from completion rows that carry an output blob.
    // The resume-point row itself is excluded.
    // NOTE(review): presumably because the resumed state's task runs again and records a
    // fresh entry — confirm against the engine loop.
    let compensation_stack: Vec<CompensationEntry> = rows
        .iter()
        .filter(|r| r.sequence != resume_sequence)
        .filter(|r| r.kind == RowKind::CompensationCompletion)
        .filter_map(|r| {
            r.output_blob.as_ref().map(|blob| CompensationEntry {
                task_id: r.task_id.clone(),
                output_blob: blob.clone(),
            })
        })
        .collect();

    // One cursor per state label; a later row for the same state replaces an earlier one.
    let resume_cursors: HashMap<String, Vec<u8>> = rows
        .iter()
        .filter(|r| r.kind == RowKind::StepCursor)
        .filter_map(|r| {
            r.output_blob
                .as_ref()
                .map(|blob| (r.state.clone(), blob.clone()))
        })
        .collect();

    #[cfg(feature = "tracing")]
    info!(workflow_id = %workflow_id, resume_state = ?resume_state, last_sequence = last.sequence, compensation_entries = compensation_stack.len(), cursor_states = resume_cursors.len(), "Resuming workflow from checkpoint");

    notify_observers(&self.observers, |o| {
        o.on_resume(&workflow_id, last.sequence)
    });

    self.resources.setup_all().await?;

    #[cfg(feature = "metrics")]
    let _active = crate::metrics::WorkflowActiveGuard::new();
    #[cfg(feature = "metrics")]
    let _started = std::time::Instant::now();

    let exec = self.execute_workflow_from(
        resume_state,
        start_sequence,
        Some(workflow_id),
        compensation_stack,
        resume_cursors,
    );

    let result = if let Some(timeout_duration) = self.workflow_timeout {
        match tokio::time::timeout(timeout_duration, exec).await {
            Ok(inner) => inner,
            Err(_) => {
                #[cfg(feature = "metrics")]
                crate::metrics::workflow_run("timeout", _started.elapsed());
                // Resources were set up above, so they are torn down on this path too.
                self.resources
                    .teardown_range(0..self.resources.lifecycle_len())
                    .await;
                return Err(CanoError::workflow("Workflow timeout exceeded"));
            }
        }
    } else {
        exec.await
    };

    #[cfg(feature = "metrics")]
    crate::metrics::workflow_run(
        if result.is_ok() {
            "completed"
        } else {
            "failed"
        },
        _started.elapsed(),
    );

    self.resources
        .teardown_range(0..self.resources.lifecycle_len())
        .await;
    result
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::PollErrorPolicy;
use crate::observer::WorkflowObserver;
use crate::resource::Resources;
use crate::saga;
use crate::saga::CompensatableTask;
use crate::task as task_mod;
use crate::task::Task;
use crate::workflow::test_support::*;
use crate::workflow::{JoinConfig, JoinStrategy};
use cano_macros::task;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::atomic::Ordering;
use std::time::Duration;
use crate::recovery::{CheckpointRow, CheckpointStore};
use std::sync::Mutex;
/// In-memory checkpoint store used by the tests.
///
/// `live` holds the rows currently stored per workflow id (removed by `clear`), while
/// `audit` keeps every row that was ever accepted, so tests can inspect a run's history
/// even after its live log has been cleared.
#[derive(Default)]
struct MemCheckpoints {
    live: Mutex<HashMap<String, Vec<CheckpointRow>>>,
    audit: Mutex<Vec<(String, CheckpointRow)>>,
}

#[cano_macros::checkpoint_store]
impl CheckpointStore for MemCheckpoints {
    /// Accepts `row` unless the run already holds a row with the same sequence number.
    async fn append(&self, workflow_id: &str, row: CheckpointRow) -> Result<(), CanoError> {
        let mut live = self.live.lock().unwrap();
        let run = live.entry(workflow_id.to_string()).or_default();

        let duplicate = run.iter().any(|existing| existing.sequence == row.sequence);
        if duplicate {
            return Err(CanoError::checkpoint_store(format!(
                "checkpoint conflict: {workflow_id:?} already has sequence {}",
                row.sequence
            )));
        }

        // Record in the audit trail first, then in the live log.
        let audit_entry = (workflow_id.to_string(), row.clone());
        self.audit.lock().unwrap().push(audit_entry);
        run.push(row);
        Ok(())
    }

    /// Returns the live rows of the run, sorted by sequence.
    async fn load_run(&self, workflow_id: &str) -> Result<Vec<CheckpointRow>, CanoError> {
        Ok(self.rows(workflow_id))
    }

    /// Drops the live rows of the run; the audit trail is left untouched.
    async fn clear(&self, workflow_id: &str) -> Result<(), CanoError> {
        self.live.lock().unwrap().remove(workflow_id);
        Ok(())
    }
}

impl MemCheckpoints {
    /// Live rows for `workflow_id` in ascending sequence order (empty if none).
    fn rows(&self, workflow_id: &str) -> Vec<CheckpointRow> {
        let live = self.live.lock().unwrap();
        let mut snapshot = live.get(workflow_id).cloned().unwrap_or_default();
        drop(live);
        snapshot.sort_by_key(|r| r.sequence);
        snapshot
    }

    /// Every row ever accepted for `workflow_id`, in append order.
    fn audit_rows(&self, workflow_id: &str) -> Vec<CheckpointRow> {
        let audit = self.audit.lock().unwrap();
        audit
            .iter()
            .filter_map(|(id, row)| (id == workflow_id).then(|| row.clone()))
            .collect()
    }

    /// `(sequence, state label)` pairs of the audit trail for `workflow_id`.
    fn audit_states(&self, workflow_id: &str) -> Vec<(u64, String)> {
        let mut pairs = Vec::new();
        for row in self.audit_rows(workflow_id) {
            pairs.push((row.sequence, row.state));
        }
        pairs
    }
}
/// Observer that records every checkpoint / resume notification as
/// `(event kind, workflow id, sequence)` in arrival order.
#[derive(Default)]
struct CkptObserver(Mutex<Vec<(&'static str, String, u64)>>);

impl WorkflowObserver for CkptObserver {
    fn on_checkpoint(&self, workflow_id: &str, sequence: u64) {
        self.record("checkpoint", workflow_id, sequence);
    }

    fn on_resume(&self, workflow_id: &str, sequence: u64) {
        self.record("resume", workflow_id, sequence);
    }
}

impl CkptObserver {
    /// Appends one event to the recorded list.
    fn record(&self, kind: &'static str, workflow_id: &str, sequence: u64) {
        let event = (kind, workflow_id.to_string(), sequence);
        self.0.lock().unwrap().push(event);
    }

    /// Snapshot of all events recorded so far.
    fn events(&self) -> Vec<(&'static str, String, u64)> {
        let recorded = self.0.lock().unwrap();
        recorded.clone()
    }
}
/// A successful run appends one row per state entered (exit state included) with dense
/// sequence numbers, and leaves the live log empty afterwards.
#[tokio::test]
async fn checkpoint_row_written_for_each_state_entered() {
    let checkpoints = Arc::new(MemCheckpoints::default());
    let workflow = Workflow::bare()
        .register(TestState::Start, SimpleTask::new(TestState::Process))
        .register(TestState::Process, SimpleTask::new(TestState::Complete))
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(checkpoints.clone())
        .with_workflow_id("run-1");

    let final_state = workflow.orchestrate(TestState::Start).await.unwrap();
    assert_eq!(final_state, TestState::Complete);

    let expected_history: Vec<(u64, String)> = vec![
        (0, "Start".to_string()),
        (1, "Process".to_string()),
        (2, "Complete".to_string()),
    ];
    assert_eq!(checkpoints.audit_states("run-1"), expected_history);

    let history = checkpoints.audit_rows("run-1");
    let task_rows_named = !history[0].task_id.is_empty() && !history[1].task_id.is_empty();
    assert!(
        task_rows_named,
        "rows for registered states carry the task name"
    );
    assert!(
        history[2].task_id.is_empty(),
        "the exit state has no task, so its row's task_id is empty"
    );

    // The live log is gone once the run has completed.
    assert!(checkpoints.rows("run-1").is_empty());
}
/// Without a checkpoint store a run still completes, but `resume_from` is refused with a
/// configuration error that mentions the missing store.
#[tokio::test]
async fn no_checkpoint_store_means_no_rows_and_resume_is_rejected() {
    let workflow = Workflow::bare()
        .register(TestState::Start, SimpleTask::new(TestState::Complete))
        .add_exit_state(TestState::Complete);

    let finished = workflow.orchestrate(TestState::Start).await.unwrap();
    assert_eq!(finished, TestState::Complete);

    let rejection = workflow.resume_from("whatever").await.unwrap_err();
    assert_eq!(rejection.category(), "configuration");
    assert!(rejection.message().contains("checkpoint store"));
}
/// Resuming re-runs only the last checkpointed state, continues the sequence numbering
/// after the stored rows, and notifies observers of the resume and each new checkpoint.
#[tokio::test]
async fn resume_continues_from_last_checkpointed_state() {
    let store = Arc::new(MemCheckpoints::default());
    // Pre-seed the log as if an earlier run had stopped after entering `Process`.
    for (sequence, label) in [(0, "Start"), (1, "Process")] {
        store
            .append("run-2", CheckpointRow::new(sequence, label, ""))
            .await
            .unwrap();
    }

    let start_task = SimpleTask::new(TestState::Process);
    let process_task = SimpleTask::new(TestState::Complete);
    let observer = Arc::new(CkptObserver::default());
    let workflow = Workflow::bare()
        .register(TestState::Start, start_task.clone())
        .register(TestState::Process, process_task.clone())
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone())
        .with_observer(observer.clone());

    let resumed = workflow.resume_from("run-2").await.unwrap();
    assert_eq!(resumed, TestState::Complete);

    assert_eq!(
        start_task.count(),
        0,
        "state before the resume point never runs"
    );
    assert_eq!(
        process_task.count(),
        1,
        "the resumed state's task re-runs once"
    );

    let expected_history: Vec<(u64, String)> = vec![
        (0, "Start".to_string()),
        (1, "Process".to_string()),
        (2, "Process".to_string()),
        (3, "Complete".to_string()),
    ];
    assert_eq!(store.audit_states("run-2"), expected_history);

    let expected_events = vec![
        ("resume", "run-2".to_string(), 1),
        ("checkpoint", "run-2".to_string(), 2),
        ("checkpoint", "run-2".to_string(), 3),
    ];
    assert_eq!(observer.events(), expected_events);
}
/// When the last checkpointed state is an exit state, resuming returns it without running
/// any task.
#[tokio::test]
async fn resume_from_exit_state_returns_immediately() {
    let store = Arc::new(MemCheckpoints::default());
    for (sequence, label) in [(0, "Start"), (1, "Complete")] {
        store
            .append("done-run", CheckpointRow::new(sequence, label, ""))
            .await
            .unwrap();
    }

    let idle_task = SimpleTask::new(TestState::Complete);
    let workflow = Workflow::bare()
        .register(TestState::Start, idle_task.clone())
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone());

    let resumed = workflow.resume_from("done-run").await.unwrap();
    assert_eq!(resumed, TestState::Complete);
    assert_eq!(
        idle_task.count(),
        0,
        "no task runs when resuming into an exit state"
    );
}
/// A split state with several parallel tasks is checkpointed once, and that row carries no
/// task id.
#[tokio::test]
async fn split_state_writes_a_single_checkpoint_row() {
    let store = Arc::new(MemCheckpoints::default());
    let parallel_tasks: Vec<SimpleTask> =
        std::iter::repeat_with(|| SimpleTask::new(TestState::Complete))
            .take(5)
            .collect();
    let join_config = JoinConfig::new(JoinStrategy::All, TestState::Complete);
    let workflow = Workflow::bare()
        .register_split(TestState::Start, parallel_tasks, join_config)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone())
        .with_workflow_id("split-run");

    let final_state = workflow.orchestrate(TestState::Start).await.unwrap();
    assert_eq!(final_state, TestState::Complete);

    let history = store.audit_rows("split-run");
    let start_rows = history.iter().filter(|r| r.state == "Start").count();
    assert_eq!(
        start_rows, 1,
        "the split state is checkpointed exactly once"
    );

    let expected_history: Vec<(u64, String)> =
        vec![(0, "Start".to_string()), (1, "Complete".to_string())];
    assert_eq!(store.audit_states("split-run"), expected_history);

    assert!(
        history[0].task_id.is_empty(),
        "a split state's checkpoint row has no single task id"
    );
}
/// Attaching a checkpoint store without a workflow id makes `orchestrate` fail with a
/// checkpoint-store error that mentions the missing id.
#[tokio::test]
async fn checkpoint_requires_workflow_id_on_orchestrate() {
    let store = Arc::new(MemCheckpoints::default());
    let workflow = Workflow::bare()
        .register(TestState::Start, SimpleTask::new(TestState::Complete))
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone());

    let failure = workflow.orchestrate(TestState::Start).await.unwrap_err();

    assert_eq!(failure.category(), "checkpoint_store");
    assert!(failure.message().contains("workflow id"));
}
/// Resuming an id that has no stored rows is a checkpoint-store error.
#[tokio::test]
async fn resume_unknown_workflow_id_errors() {
    let empty_store = Arc::new(MemCheckpoints::default());
    let workflow = Workflow::bare()
        .register(TestState::Start, SimpleTask::new(TestState::Complete))
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(empty_store.clone());

    let failure = workflow.resume_from("never-ran").await.unwrap_err();

    assert_eq!(failure.category(), "checkpoint_store");
    assert!(failure.message().contains("no checkpoint rows"));
}
/// A stored state label that matches no state of the workflow makes resume fail with a
/// workflow error.
#[tokio::test]
async fn resume_with_unrecognized_state_label_errors() {
    let store = Arc::new(MemCheckpoints::default());
    let foreign_row = CheckpointRow::new(0, "NotAStateOfThisWorkflow", "");
    store.append("wrong-defn", foreign_row).await.unwrap();

    let workflow = Workflow::bare()
        .register(TestState::Start, SimpleTask::new(TestState::Complete))
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone());

    let failure = workflow.resume_from("wrong-defn").await.unwrap_err();

    assert_eq!(failure.category(), "workflow");
    assert!(failure.message().contains("is not a registered or exit state"));
}
/// A router state is announced through `on_state_enter` but writes no checkpoint row, so
/// the remaining rows are numbered densely from zero.
#[tokio::test]
async fn router_state_produces_no_checkpoint_row_and_sequences_are_dense() {
    use crate::observer::WorkflowObserver;
    use crate::task::{RouterTask, TaskConfig};

    /// Router that always forwards to `Process`.
    struct RouteToWork;
    #[task_mod::router]
    impl RouterTask<TestState> for RouteToWork {
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        async fn route(&self, _res: &Resources) -> Result<TaskResult<TestState>, CanoError> {
            Ok(TaskResult::Single(TestState::Process))
        }
    }

    /// Collects the label of every state the workflow enters.
    #[derive(Default)]
    struct StateEnterObserver(Mutex<Vec<String>>);
    impl WorkflowObserver for StateEnterObserver {
        fn on_state_enter(&self, state: &str) {
            self.0.lock().unwrap().push(state.to_string());
        }
    }

    let store = Arc::new(MemCheckpoints::default());
    let enter_obs = Arc::new(StateEnterObserver::default());
    let ckpt_obs = Arc::new(CkptObserver::default());
    let workflow = Workflow::bare()
        .register_router(TestState::Start, RouteToWork)
        .register(TestState::Process, SimpleTask::new(TestState::Complete))
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone())
        .with_workflow_id("router-run")
        .with_observer(enter_obs.clone())
        .with_observer(ckpt_obs.clone());

    let final_state = workflow.orchestrate(TestState::Start).await.unwrap();
    assert_eq!(final_state, TestState::Complete);

    assert_eq!(
        *enter_obs.0.lock().unwrap(),
        vec!["Start", "Process", "Complete"],
        "on_state_enter fires for every state including the router"
    );

    let expected_history: Vec<(u64, String)> =
        vec![(0, "Process".to_string()), (1, "Complete".to_string())];
    assert_eq!(
        store.audit_states("router-run"),
        expected_history,
        "router state leaves no checkpoint row; sequences are dense from 0"
    );

    let expected_events = vec![
        ("checkpoint", "router-run".to_string(), 0),
        ("checkpoint", "router-run".to_string(), 1),
    ];
    assert_eq!(
        ckpt_obs.events(),
        expected_events,
        "on_checkpoint fires only for non-router states"
    );
}
/// A log that starts at the state behind a router (the router itself left no row) can be
/// resumed: the stored state re-runs once and numbering continues after the stored row.
#[tokio::test]
async fn resume_from_skips_router_rows_not_present_in_checkpoint_log() {
    use crate::task::{RouterTask, TaskConfig};

    /// Router that always forwards to `Process`.
    struct RouteToWork;
    #[task_mod::router]
    impl RouterTask<TestState> for RouteToWork {
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        async fn route(&self, _res: &Resources) -> Result<TaskResult<TestState>, CanoError> {
            Ok(TaskResult::Single(TestState::Process))
        }
    }

    let store = Arc::new(MemCheckpoints::default());
    let only_row = CheckpointRow::new(0, "Process", "");
    store.append("router-resume", only_row).await.unwrap();

    let process_task = SimpleTask::new(TestState::Complete);
    let workflow = Workflow::bare()
        .register_router(TestState::Start, RouteToWork)
        .register(TestState::Process, process_task.clone())
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone());

    let resumed = workflow.resume_from("router-resume").await.unwrap();
    assert_eq!(resumed, TestState::Complete);

    assert_eq!(
        process_task.count(),
        1,
        "Process task re-runs once on resume"
    );

    let expected_history: Vec<(u64, String)> = vec![
        (0, "Process".to_string()),
        (1, "Process".to_string()),
        (2, "Complete".to_string()),
    ];
    assert_eq!(store.audit_states("router-resume"), expected_history);
}
use crate::task::TaskConfig;
/// Shared record of compensations: `(task name, output value)` in the order they ran.
type CompLog = Arc<Mutex<Vec<(String, u32)>>>;

/// Configurable compensatable task for the saga tests.
///
/// The forward step either fails (`fail_forward`) or yields `next_state` with `value` as
/// its output. `compensate` always logs `(name, output)` first and then fails when
/// `fail_compensate` is set.
#[derive(Clone)]
struct CompTask {
    name: &'static str,
    value: u32,
    next_state: TestState,
    log: CompLog,
    fail_forward: bool,
    fail_compensate: bool,
}

impl CompTask {
    /// Task whose forward step and compensation both succeed.
    fn ok(name: &'static str, value: u32, next_state: TestState, log: &CompLog) -> Self {
        let log = log.clone();
        Self {
            name,
            value,
            next_state,
            log,
            fail_forward: false,
            fail_compensate: false,
        }
    }
}

#[saga::task]
impl CompensatableTask<TestState> for CompTask {
    type Output = u32;

    fn config(&self) -> TaskConfig {
        TaskConfig::minimal()
    }

    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed(self.name)
    }

    async fn run(&self, _res: &Resources) -> Result<(TaskResult<TestState>, u32), CanoError> {
        if !self.fail_forward {
            return Ok((TaskResult::Single(self.next_state.clone()), self.value));
        }
        Err(CanoError::task_execution(format!(
            "{} forward failed",
            self.name
        )))
    }

    async fn compensate(&self, _res: &Resources, output: u32) -> Result<(), CanoError> {
        // Log before deciding the outcome so failed compensations are visible too.
        let record = (self.name.to_string(), output);
        self.log.lock().unwrap().push(record);

        if self.fail_compensate {
            Err(CanoError::generic(format!(
                "{} compensate failed",
                self.name
            )))
        } else {
            Ok(())
        }
    }
}
/// When the fourth step fails, the three completed steps are compensated in reverse order
/// and the original failure is returned unchanged.
#[tokio::test]
async fn compensations_run_in_reverse_on_terminal_failure() {
    let log: CompLog = Arc::new(Mutex::new(Vec::new()));
    let failing_last_step = CompTask {
        name: "D",
        value: 4,
        next_state: TestState::Complete,
        log: log.clone(),
        fail_forward: true,
        fail_compensate: false,
    };
    let workflow = Workflow::bare()
        .register_with_compensation(
            TestState::Start,
            CompTask::ok("A", 1, TestState::Process, &log),
        )
        .register_with_compensation(
            TestState::Process,
            CompTask::ok("B", 2, TestState::Split, &log),
        )
        .register_with_compensation(
            TestState::Split,
            CompTask::ok("C", 3, TestState::Join, &log),
        )
        .register_with_compensation(TestState::Join, failing_last_step)
        .add_exit_state(TestState::Complete);

    let failure = workflow.orchestrate(TestState::Start).await.unwrap_err();
    assert_eq!(failure.category(), "task_execution");
    assert_eq!(failure.message(), "D forward failed");

    let expected_rollback = vec![
        ("C".to_string(), 3),
        ("B".to_string(), 2),
        ("A".to_string(), 1),
    ];
    assert_eq!(*log.lock().unwrap(), expected_rollback);
}
/// A plain (non-compensatable) task in the middle of the chain is skipped during rollback.
#[tokio::test]
async fn only_compensatable_tasks_are_compensated() {
    let log: CompLog = Arc::new(Mutex::new(Vec::new()));
    let failing_last_step = CompTask {
        name: "D",
        value: 40,
        next_state: TestState::Complete,
        log: log.clone(),
        fail_forward: true,
        fail_compensate: false,
    };
    let workflow = Workflow::bare()
        .register_with_compensation(
            TestState::Start,
            CompTask::ok("A", 10, TestState::Process, &log),
        )
        .register(TestState::Process, SimpleTask::new(TestState::Split))
        .register_with_compensation(
            TestState::Split,
            CompTask::ok("C", 30, TestState::Join, &log),
        )
        .register_with_compensation(TestState::Join, failing_last_step)
        .add_exit_state(TestState::Complete);

    let failure = workflow.orchestrate(TestState::Start).await.unwrap_err();
    assert_eq!(failure.message(), "D forward failed");

    let expected_rollback = vec![("C".to_string(), 30), ("A".to_string(), 10)];
    assert_eq!(*log.lock().unwrap(), expected_rollback);
}
/// A failing compensation does not stop the drain: the remaining steps are still
/// compensated and the result aggregates the original failure with the rollback failure.
#[tokio::test]
async fn compensation_failure_aggregates_errors_and_keeps_going() {
    let log: CompLog = Arc::new(Mutex::new(Vec::new()));
    let rollback_fails = CompTask {
        name: "B",
        value: 2,
        next_state: TestState::Split,
        log: log.clone(),
        fail_forward: false,
        fail_compensate: true,
    };
    let forward_fails = CompTask {
        name: "D",
        value: 4,
        next_state: TestState::Complete,
        log: log.clone(),
        fail_forward: true,
        fail_compensate: false,
    };
    let workflow = Workflow::bare()
        .register_with_compensation(
            TestState::Start,
            CompTask::ok("A", 1, TestState::Process, &log),
        )
        .register_with_compensation(TestState::Process, rollback_fails)
        .register_with_compensation(
            TestState::Split,
            CompTask::ok("C", 3, TestState::Join, &log),
        )
        .register_with_compensation(TestState::Join, forward_fails)
        .add_exit_state(TestState::Complete);

    let failure = workflow.orchestrate(TestState::Start).await.unwrap_err();
    match failure {
        CanoError::CompensationFailed { errors } => {
            assert_eq!(errors.len(), 2);
            assert_eq!(errors[0].message(), "D forward failed");
            assert!(errors[1].message().contains("B compensate failed"));
        }
        other => panic!("expected CompensationFailed, got {other:?}"),
    }

    // B is logged even though its compensation failed, and A still runs after it.
    let expected_rollback = vec![
        ("C".to_string(), 3),
        ("B".to_string(), 2),
        ("A".to_string(), 1),
    ];
    assert_eq!(*log.lock().unwrap(), expected_rollback);
}
/// Resume rebuilds the compensation stack from stored output rows: when the resumed step
/// fails, the earlier steps are compensated with the outputs read from the log.
#[tokio::test]
async fn resume_rehydrates_compensation_stack_and_rolls_back() {
    let store = Arc::new(MemCheckpoints::default());
    let log: CompLog = Arc::new(Mutex::new(Vec::new()));

    // Entry row followed by an output-carrying row for A and B, then the entry row for C.
    let seeded_rows = vec![
        CheckpointRow::new(0, "Start", "A"),
        CheckpointRow::new(1, "Start", "A").with_output(serde_json::to_vec(&7u32).unwrap()),
        CheckpointRow::new(2, "Process", "B"),
        CheckpointRow::new(3, "Process", "B").with_output(serde_json::to_vec(&8u32).unwrap()),
        CheckpointRow::new(4, "Split", "C"),
    ];
    for row in seeded_rows {
        store.append("saga-run", row).await.unwrap();
    }

    let resumed_step_fails = CompTask {
        name: "C",
        value: 9,
        next_state: TestState::Join,
        log: log.clone(),
        fail_forward: true,
        fail_compensate: false,
    };
    let workflow = Workflow::bare()
        .register_with_compensation(
            TestState::Start,
            CompTask::ok("A", 7, TestState::Process, &log),
        )
        .register_with_compensation(
            TestState::Process,
            CompTask::ok("B", 8, TestState::Split, &log),
        )
        .register_with_compensation(TestState::Split, resumed_step_fails)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone());

    let failure = workflow.resume_from("saga-run").await.unwrap_err();
    assert_eq!(failure.message(), "C forward failed");

    let expected_rollback = vec![("B".to_string(), 8), ("A".to_string(), 7)];
    assert_eq!(*log.lock().unwrap(), expected_rollback);
}
/// A `StepCursor` row at the resume point is not treated as a compensation entry: only the
/// two stored task outputs are rolled back.
#[tokio::test]
async fn resume_step_cursor_row_is_not_added_to_compensation_stack() {
    use crate::recovery::RowKind;

    let store = Arc::new(MemCheckpoints::default());
    let log: CompLog = Arc::new(Mutex::new(Vec::new()));

    let completed_steps = vec![
        CheckpointRow::new(0, "Start", "A"),
        CheckpointRow::new(1, "Start", "A").with_output(serde_json::to_vec(&101u32).unwrap()),
        CheckpointRow::new(2, "Process", "B"),
        CheckpointRow::new(3, "Process", "B").with_output(serde_json::to_vec(&202u32).unwrap()),
    ];
    for row in completed_steps {
        store.append("mixed-run", row).await.unwrap();
    }

    let cursor_row = CheckpointRow::new(4, "Split", "C").with_cursor(vec![9, 9]);
    assert_eq!(
        cursor_row.kind,
        RowKind::StepCursor,
        "sanity: with_cursor sets RowKind::StepCursor"
    );
    store.append("mixed-run", cursor_row).await.unwrap();

    let resumed_step_fails = CompTask {
        name: "C",
        value: 999,
        next_state: TestState::Complete,
        log: log.clone(),
        fail_forward: true,
        fail_compensate: false,
    };
    let workflow = Workflow::bare()
        .register_with_compensation(
            TestState::Start,
            CompTask::ok("A", 101, TestState::Process, &log),
        )
        .register_with_compensation(
            TestState::Process,
            CompTask::ok("B", 202, TestState::Split, &log),
        )
        .register_with_compensation(TestState::Split, resumed_step_fails)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone());

    let failure = workflow.resume_from("mixed-run").await.unwrap_err();
    assert_eq!(failure.message(), "C forward failed");

    let expected_rollback = vec![("B".to_string(), 202), ("A".to_string(), 101)];
    assert_eq!(
        *log.lock().unwrap(),
        expected_rollback,
        "StepCursor and StateEntry rows must not become compensation entries"
    );
}
/// Registering a state a second time removes the compensator of the task it replaces,
/// whether the replacement is compensatable or a plain task.
#[test]
fn re_registering_a_state_drops_the_stale_compensator() {
    let log: CompLog = Arc::new(Mutex::new(Vec::new()));

    // Replaced by another compensatable task.
    let replaced_by_saga = Workflow::bare()
        .register_with_compensation(
            TestState::Start,
            CompTask::ok("A", 1, TestState::Process, &log),
        )
        .register_with_compensation(
            TestState::Start,
            CompTask::ok("A2", 2, TestState::Process, &log),
        );
    assert!(
        !replaced_by_saga.compensators.contains_key("A"),
        "the replaced task's compensator must not linger"
    );
    assert!(replaced_by_saga.compensators.contains_key("A2"));

    // Replaced by a plain task.
    let replaced_by_plain = Workflow::bare()
        .register_with_compensation(
            TestState::Start,
            CompTask::ok("B", 1, TestState::Process, &log),
        )
        .register(TestState::Start, SimpleTask::new(TestState::Process));
    assert!(!replaced_by_plain.compensators.contains_key("B"));
}
/// Builds the `Start -> Process -> Complete` workflow used by the concurrency tests,
/// checkpointing into `store` under the workflow id `id`.
fn three_state_checkpointed(
    store: Arc<MemCheckpoints>,
    id: impl Into<String>,
) -> Workflow<TestState> {
    let to_process = SimpleTask::new(TestState::Process);
    let to_complete = SimpleTask::new(TestState::Complete);
    Workflow::bare()
        .register(TestState::Start, to_process)
        .register(TestState::Process, to_complete)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store)
        .with_workflow_id(id)
}
/// Sixteen runs with distinct ids share one store concurrently; each ends up with exactly
/// its own three audit rows and an empty live log.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn many_workflows_share_one_store_without_cross_talk() {
    const RUNS: usize = 16;
    let store = Arc::new(MemCheckpoints::default());

    // Spawn every run before awaiting any of them so they actually overlap.
    let handles: Vec<_> = (0..RUNS)
        .map(|i| {
            let shared = Arc::clone(&store);
            tokio::spawn(async move {
                three_state_checkpointed(shared, format!("run-{i}"))
                    .orchestrate(TestState::Start)
                    .await
            })
        })
        .collect();

    for handle in handles {
        let outcome = handle.await.unwrap();
        assert_eq!(outcome.unwrap(), TestState::Complete);
    }

    for i in 0..RUNS {
        let id = format!("run-{i}");
        let expected_history: Vec<(u64, String)> = vec![
            (0, "Start".to_string()),
            (1, "Process".to_string()),
            (2, "Complete".to_string()),
        ];
        assert_eq!(
            store.audit_states(&id),
            expected_history,
            "{id}: exactly its own three rows, in order"
        );
        assert!(
            store.rows(&id).is_empty(),
            "{id}: a successful run clears its live log"
        );
    }
}
/// Two runs racing on the same workflow id: each either completes or fails with a
/// checkpoint-store error, and at least one of them completes.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn racing_runs_of_one_id_never_corrupt_and_at_least_one_completes() {
    let store = Arc::new(MemCheckpoints::default());

    let racers: Vec<_> = (0..2)
        .map(|_| {
            let shared = Arc::clone(&store);
            tokio::spawn(async move {
                three_state_checkpointed(shared, "dup")
                    .orchestrate(TestState::Start)
                    .await
            })
        })
        .collect();

    let mut completed = 0;
    for racer in racers {
        let outcome = racer.await.unwrap();
        match outcome {
            Ok(TestState::Complete) => completed += 1,
            Ok(other) => panic!("unexpected success state {other:?}"),
            Err(e) => assert_eq!(e.category(), "checkpoint_store", "unexpected error: {e}"),
        }
    }

    assert!(
        completed >= 1,
        "whichever run wins sequence 0 must run to completion"
    );
}
/// Starting a fresh run under an id whose log still holds rows fails with a conflict and
/// leaves the existing rows in place.
#[tokio::test]
async fn fresh_run_over_an_uncleared_log_is_rejected() {
    let store = Arc::new(MemCheckpoints::default());
    for (sequence, label) in [(0, "Start"), (1, "Process")] {
        store
            .append("run", CheckpointRow::new(sequence, label, "SimpleTask"))
            .await
            .unwrap();
    }

    let workflow = three_state_checkpointed(store.clone(), "run");
    let err = workflow.orchestrate(TestState::Start).await.unwrap_err();

    assert_eq!(err.category(), "checkpoint_store");
    assert!(err.message().contains("conflict"), "got: {err}");
    assert_eq!(store.rows("run").len(), 2);
}
/// When an append fails mid-run, the completed step is compensated; because that
/// compensation also fails here, the result is `CompensationFailed` and the rows written
/// so far stay in the log.
#[tokio::test]
async fn append_failure_mid_run_rolls_back_and_keeps_the_log() {
    /// Store that accepts `ok_appends` appends and rejects every append after that.
    struct FailAfter {
        inner: MemCheckpoints,
        ok_appends: std::sync::atomic::AtomicUsize,
    }
    #[cano_macros::checkpoint_store]
    impl CheckpointStore for FailAfter {
        async fn append(&self, workflow_id: &str, row: CheckpointRow) -> Result<(), CanoError> {
            // Decrement only while budget remains. A plain `fetch_sub` on 0 would wrap the
            // counter to `usize::MAX` and let every append after the first failure succeed.
            let budget_left = self
                .ok_appends
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |left| {
                    left.checked_sub(1)
                })
                .is_ok();
            if !budget_left {
                return Err(CanoError::checkpoint_store("simulated disk failure"));
            }
            self.inner.append(workflow_id, row).await
        }
        async fn load_run(&self, id: &str) -> Result<Vec<CheckpointRow>, CanoError> {
            self.inner.load_run(id).await
        }
        async fn clear(&self, id: &str) -> Result<(), CanoError> {
            self.inner.clear(id).await
        }
    }

    let log: CompLog = Arc::new(Mutex::new(Vec::new()));
    let store = Arc::new(FailAfter {
        inner: MemCheckpoints::default(),
        ok_appends: std::sync::atomic::AtomicUsize::new(2),
    });
    let workflow = Workflow::bare()
        .register_with_compensation(
            TestState::Start,
            CompTask {
                name: "A",
                value: 1,
                next_state: TestState::Process,
                log: log.clone(),
                fail_forward: false,
                fail_compensate: true,
            },
        )
        .register_with_compensation(
            TestState::Process,
            CompTask::ok("B", 2, TestState::Complete, &log),
        )
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone())
        .with_workflow_id("disk");

    let err = workflow.orchestrate(TestState::Start).await.unwrap_err();
    match err {
        CanoError::CompensationFailed { errors } => {
            assert_eq!(errors[0].category(), "checkpoint_store");
            assert!(errors[1].message().contains("A compensate failed"));
        }
        other => panic!("expected CompensationFailed, got {other:?}"),
    }

    // Only A had completed, so only A was compensated; both accepted rows remain.
    assert_eq!(*log.lock().unwrap(), vec![("A".to_string(), 1)]);
    assert_eq!(store.inner.rows("disk").len(), 2);
}
/// Compensatable task whose forward step always succeeds and whose compensation
/// misbehaves according to `on_compensate`.
#[derive(Clone)]
struct CompFault {
    name: &'static str,
    value: u32,
    next: TestState,
    on_compensate: CompFaultKind,
    attempt_timeout: Option<Duration>,
}

/// How a [`CompFault`] compensation misbehaves.
#[derive(Clone, Copy)]
enum CompFaultKind {
    /// `compensate` panics.
    Panic,
    /// `compensate` awaits a future that never resolves.
    Hang,
}

#[saga::task]
impl CompensatableTask<TestState> for CompFault {
    type Output = u32;

    fn config(&self) -> TaskConfig {
        let base = TaskConfig::minimal();
        if let Some(limit) = self.attempt_timeout {
            base.with_attempt_timeout(limit)
        } else {
            base
        }
    }

    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed(self.name)
    }

    async fn run(&self, _res: &Resources) -> Result<(TaskResult<TestState>, u32), CanoError> {
        let next_state = self.next.clone();
        Ok((TaskResult::Single(next_state), self.value))
    }

    async fn compensate(&self, _res: &Resources, _output: u32) -> Result<(), CanoError> {
        match self.on_compensate {
            CompFaultKind::Hang => {
                std::future::pending::<()>().await;
                unreachable!()
            }
            CompFaultKind::Panic => panic!("{} compensate exploded", self.name),
        }
    }
}
/// A compensator that panics must be reported as one of the collected errors while
/// the remaining compensations still run: `A`'s rollback is logged even though `B`'s
/// rollback panics.
#[tokio::test]
async fn panicking_compensator_is_caught_and_the_drain_continues() {
    let log: CompLog = Arc::new(Mutex::new(Vec::new()));

    // A: succeeds forward and records its compensation in the log.
    let step_a = CompTask::ok("A", 1, TestState::Process, &log);
    // B: succeeds forward, panics when asked to compensate.
    let step_b = CompFault {
        name: "B",
        value: 2,
        next: TestState::Split,
        on_compensate: CompFaultKind::Panic,
        attempt_timeout: None,
    };
    // C: fails forward, which is what triggers the rollback.
    let step_c = CompTask {
        name: "C",
        value: 3,
        next_state: TestState::Complete,
        log: log.clone(),
        fail_forward: true,
        fail_compensate: false,
    };

    let workflow = Workflow::bare()
        .register_with_compensation(TestState::Start, step_a)
        .register_with_compensation(TestState::Process, step_b)
        .register_with_compensation(TestState::Split, step_c)
        .add_exit_state(TestState::Complete);

    let failure = workflow.orchestrate(TestState::Start).await.unwrap_err();
    let errors = match failure {
        CanoError::CompensationFailed { errors } => errors,
        other => panic!("expected CompensationFailed, got {other:?}"),
    };

    // The forward failure comes first; the caught panic is somewhere after it.
    assert_eq!(errors[0].message(), "C forward failed");
    let panic_reported = errors
        .iter()
        .skip(1)
        .any(|e| e.message().contains("panicked"));
    assert!(
        panic_reported,
        "the caught panic must be one of the collected errors: {errors:?}"
    );

    assert_eq!(*log.lock().unwrap(), vec![("A".to_string(), 1)]);
}
/// A compensator that never completes must be cut off by its `attempt_timeout`, be
/// reported as a `CanoError::Timeout`, and not prevent `A`'s rollback from running.
#[tokio::test]
async fn hanging_compensator_is_bounded_by_attempt_timeout() {
    let log: CompLog = Arc::new(Mutex::new(Vec::new()));

    // A: succeeds forward and records its compensation in the log.
    let step_a = CompTask::ok("A", 1, TestState::Process, &log);
    // H: succeeds forward, hangs forever in compensate; bounded to 50 ms per attempt.
    let step_h = CompFault {
        name: "H",
        value: 2,
        next: TestState::Split,
        on_compensate: CompFaultKind::Hang,
        attempt_timeout: Some(Duration::from_millis(50)),
    };
    // C: fails forward, which is what triggers the rollback.
    let step_c = CompTask {
        name: "C",
        value: 3,
        next_state: TestState::Complete,
        log: log.clone(),
        fail_forward: true,
        fail_compensate: false,
    };

    let workflow = Workflow::bare()
        .register_with_compensation(TestState::Start, step_a)
        .register_with_compensation(TestState::Process, step_h)
        .register_with_compensation(TestState::Split, step_c)
        .add_exit_state(TestState::Complete);

    let started = std::time::Instant::now();
    let failure = workflow.orchestrate(TestState::Start).await.unwrap_err();
    let took = started.elapsed();
    // Generous ceiling: the point is only that the run terminates at all.
    assert!(
        took < Duration::from_secs(5),
        "a hanging compensator must be bounded, not block the drain forever"
    );

    let errors = match failure {
        CanoError::CompensationFailed { errors } => errors,
        other => panic!("expected CompensationFailed, got {other:?}"),
    };
    assert_eq!(errors[0].message(), "C forward failed");
    let timed_out = errors
        .iter()
        .skip(1)
        .any(|e| matches!(e, CanoError::Timeout(_)));
    assert!(timed_out, "H's compensate must time out: {errors:?}");

    assert_eq!(*log.lock().unwrap(), vec![("A".to_string(), 1)]);
}
/// Resuming a run whose log already holds entry and output rows for `A` and `B`, then
/// failing at `C`, must compensate `B` and `A` once each, in reverse order.
#[tokio::test]
async fn double_compensate_on_resume_is_avoided() {
    let run_id = "crash-after-b";
    let store = Arc::new(MemCheckpoints::default());
    let log: CompLog = Arc::new(Mutex::new(Vec::new()));

    // Pre-populate the log: for each of A and B, one row without and one with output.
    let seeded = vec![
        CheckpointRow::new(0, "Start", "A"),
        CheckpointRow::new(1, "Start", "A").with_output(serde_json::to_vec(&1u32).unwrap()),
        CheckpointRow::new(2, "Process", "B"),
        CheckpointRow::new(3, "Process", "B").with_output(serde_json::to_vec(&2u32).unwrap()),
    ];
    for row in seeded {
        store.append(run_id, row).await.unwrap();
    }

    // C fails forward, so the resumed run has to roll back.
    let failing_c = CompTask {
        name: "C",
        value: 3,
        next_state: TestState::Complete,
        log: log.clone(),
        fail_forward: true,
        fail_compensate: false,
    };
    let workflow = Workflow::bare()
        .register_with_compensation(
            TestState::Start,
            CompTask::ok("A", 1, TestState::Process, &log),
        )
        .register_with_compensation(
            TestState::Process,
            CompTask::ok("B", 2, TestState::Split, &log),
        )
        .register_with_compensation(TestState::Split, failing_c)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone());

    let failure = workflow.resume_from(run_id).await.unwrap_err();
    assert_eq!(failure.message(), "C forward failed");

    let expected_rollback = vec![("B".to_string(), 2), ("A".to_string(), 1)];
    assert_eq!(
        *log.lock().unwrap(),
        expected_rollback,
        "B compensated exactly once"
    );
}
/// Chains a router, a poll task, a batch task and a stepped task in a single workflow
/// and then inspects the audit rows recorded by the checkpoint store: no row for the
/// router state, gap-free sequences, state-entry rows in visiting order, and at least
/// one cursor row for the stepped state.
#[tokio::test]
async fn cross_model_chain_router_poll_batch_stepped_interop() {
    use crate::PollOutcome;
    use crate::recovery::RowKind;
    use crate::task::{BatchTask, PollTask, RouterTask, StepOutcome, SteppedTask, TaskConfig};
    use serde::{Deserialize, Serialize};
    use std::sync::atomic::AtomicU32;

    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum TourStage {
        Route,
        Wait,
        Crunch,
        Grind,
        Done,
    }

    // Router: unconditionally forwards to `Wait`.
    struct TourRouter;
    #[task_mod::router]
    impl RouterTask<TourStage> for TourRouter {
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        async fn route(&self, _res: &Resources) -> Result<TaskResult<TourStage>, CanoError> {
            let destination = TourStage::Wait;
            Ok(TaskResult::Single(destination))
        }
    }

    // Poller: pending on the first poll, ready (-> `Crunch`) from the second poll on.
    struct TourPoller {
        counter: Arc<AtomicU32>,
    }
    #[task_mod::poll]
    impl PollTask<TourStage> for TourPoller {
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        async fn poll(&self, _res: &Resources) -> Result<PollOutcome<TourStage>, CanoError> {
            let polls_so_far = self.counter.fetch_add(1, Ordering::Relaxed) + 1;
            let outcome = if polls_so_far < 2 {
                PollOutcome::Pending { delay_ms: 0 }
            } else {
                PollOutcome::Ready(TaskResult::Single(TourStage::Crunch))
            };
            Ok(outcome)
        }
    }

    // Batch: squares 1, 2 and 3, checks the total, then moves on to `Grind`.
    struct TourBatch;
    #[task_mod::batch]
    impl BatchTask<TourStage> for TourBatch {
        type Item = u32;
        type ItemOutput = u32;
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        fn concurrency(&self) -> usize {
            4
        }
        async fn load(&self, _res: &Resources) -> Result<Vec<u32>, CanoError> {
            Ok(vec![1, 2, 3])
        }
        async fn process_item(&self, item: &u32) -> Result<u32, CanoError> {
            let squared = item * item;
            Ok(squared)
        }
        async fn finish(
            &self,
            _res: &Resources,
            outputs: Vec<Result<u32, CanoError>>,
        ) -> Result<TaskResult<TourStage>, CanoError> {
            // Failed items are dropped from the total rather than propagated.
            let total: u32 = outputs.into_iter().filter_map(Result::ok).sum();
            assert_eq!(total, 1 + 4 + 9, "1²+2²+3² = 14");
            Ok(TaskResult::Single(TourStage::Grind))
        }
    }

    // Stepper: counts its cursor up from 0 and finishes (-> `Done`) once it reaches 3.
    struct TourStepper;
    #[derive(Serialize, Deserialize)]
    struct GrindCursor(u32);
    #[task_mod::stepped]
    impl SteppedTask<TourStage> for TourStepper {
        type Cursor = GrindCursor;
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        async fn step(
            &self,
            _res: &Resources,
            cursor: Option<GrindCursor>,
        ) -> Result<StepOutcome<GrindCursor, TourStage>, CanoError> {
            let progress = cursor.map_or(0, |c| c.0);
            let outcome = if progress < 3 {
                StepOutcome::More(GrindCursor(progress + 1))
            } else {
                StepOutcome::Done(TaskResult::Single(TourStage::Done))
            };
            Ok(outcome)
        }
    }

    let store = Arc::new(MemCheckpoints::default());
    let poll_counter = Arc::new(AtomicU32::new(0));
    let poller = TourPoller {
        counter: Arc::clone(&poll_counter),
    };
    let workflow = Workflow::bare()
        .register_router(TourStage::Route, TourRouter)
        .register(TourStage::Wait, poller)
        .register(TourStage::Crunch, TourBatch)
        .register_stepped(TourStage::Grind, TourStepper)
        .add_exit_state(TourStage::Done)
        .with_checkpoint_store(store.clone())
        .with_workflow_id("tour-interop");

    let final_state = workflow.orchestrate(TourStage::Route).await.unwrap();
    assert_eq!(final_state, TourStage::Done);

    let audit = store.audit_rows("tour-interop");

    // The router state must not appear in the audit trail at all.
    let router_left_no_row = audit.iter().all(|row| row.state != "Route");
    assert!(
        router_left_no_row,
        "router state must leave no checkpoint row; got: {audit:?}"
    );

    // Sequences count up from zero with no holes.
    for (row, idx) in audit.iter().zip(0u64..) {
        assert_eq!(
            row.sequence, idx,
            "sequence gap at index {idx}: got {} expected {idx}",
            row.sequence
        );
    }

    // State-entry rows list the non-router states in the order they were visited.
    let state_entry_states: Vec<&str> = audit
        .iter()
        .filter(|row| row.kind == RowKind::StateEntry)
        .map(|row| row.state.as_str())
        .collect();
    assert_eq!(
        state_entry_states,
        vec!["Wait", "Crunch", "Grind", "Done"],
        "StateEntry rows must cover all non-router states in order"
    );

    let grind_has_cursor_row = audit
        .iter()
        .any(|row| row.kind == RowKind::StepCursor && row.state == "Grind");
    assert!(
        grind_has_cursor_row,
        "Grind stepped task must produce at least one StepCursor row"
    );
}
/// Resumes a run whose log holds only a `Start` row without output, lets the workflow
/// advance until `C` fails, and expects `B` then `A` to be compensated and the live
/// log for the run to be empty afterwards.
#[tokio::test]
async fn resume_then_advance_then_fail_rolls_back_mixed_outputs_and_clears() {
    let run_id = "mid-a";
    let store = Arc::new(MemCheckpoints::default());
    let log: CompLog = Arc::new(Mutex::new(Vec::new()));

    // The only pre-existing row: sequence 0 for `Start`, with no output attached.
    let entry_row = CheckpointRow::new(0, "Start", "A");
    store.append(run_id, entry_row).await.unwrap();

    let step_a = CompTask::ok("A", 11, TestState::Process, &log);
    let step_b = CompTask::ok("B", 22, TestState::Split, &log);
    // C fails forward and thereby triggers the rollback.
    let step_c = CompTask {
        name: "C",
        value: 33,
        next_state: TestState::Complete,
        log: log.clone(),
        fail_forward: true,
        fail_compensate: false,
    };
    let workflow = Workflow::bare()
        .register_with_compensation(TestState::Start, step_a)
        .register_with_compensation(TestState::Process, step_b)
        .register_with_compensation(TestState::Split, step_c)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone());

    let failure = workflow.resume_from(run_id).await.unwrap_err();
    assert_eq!(failure.message(), "C forward failed");

    let expected_rollback = vec![("B".to_string(), 22), ("A".to_string(), 11)];
    assert_eq!(*log.lock().unwrap(), expected_rollback);
    assert!(store.rows(run_id).is_empty());
}
/// When a plain (non-compensatable) task fails as the very first step, the error is
/// returned as-is and the single row already written for the run stays in the store.
#[tokio::test]
async fn empty_compensation_stack_failure_keeps_the_log_for_resume() {
    // Task that fails immediately with a task-execution error.
    #[derive(Clone)]
    struct FailFast;
    #[task]
    impl Task<TestState> for FailFast {
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        async fn run_bare(&self) -> Result<TaskResult<TestState>, CanoError> {
            let failure = CanoError::task_execution("boom");
            Err(failure)
        }
    }

    let run_id = "nope";
    let store = Arc::new(MemCheckpoints::default());
    let workflow = Workflow::bare()
        .register(TestState::Start, FailFast)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone())
        .with_workflow_id(run_id);

    let failure = workflow.orchestrate(TestState::Start).await.unwrap_err();
    assert_eq!(failure.category(), "task_execution");

    let remaining_rows = store.rows(run_id).len();
    assert_eq!(remaining_rows, 1);
}
/// Builds a chain of 60 compensatable steps in which only the last one fails, and
/// expects every earlier step to be compensated, newest first.
#[tokio::test]
async fn deep_compensation_stack_drains_all_in_reverse() {
    const N: u32 = 60;
    let log = Arc::new(Mutex::new(Vec::<u32>::new()));

    // Step `idx` transitions to `idx + 1`, outputs `idx`, and logs that output when
    // compensated; it fails forward instead when `fail` is set.
    #[derive(Clone)]
    struct NumComp {
        idx: u32,
        fail: bool,
        log: Arc<Mutex<Vec<u32>>>,
    }
    #[saga::task]
    impl CompensatableTask<u32> for NumComp {
        type Output = u32;
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        fn name(&self) -> Cow<'static, str> {
            Cow::Owned(format!("n{}", self.idx))
        }
        async fn run(&self, _res: &Resources) -> Result<(TaskResult<u32>, u32), CanoError> {
            if !self.fail {
                let successor = self.idx + 1;
                return Ok((TaskResult::Single(successor), self.idx));
            }
            Err(CanoError::task_execution(format!("n{} failed", self.idx)))
        }
        async fn compensate(&self, _res: &Resources, output: u32) -> Result<(), CanoError> {
            let mut undone = self.log.lock().unwrap();
            undone.push(output);
            Ok(())
        }
    }

    // Register states 0..N, marking only the final one as failing.
    let workflow = (0..N).fold(Workflow::<u32>::bare().add_exit_state(N), |chain, idx| {
        chain.register_with_compensation(
            idx,
            NumComp {
                idx,
                fail: idx == N - 1,
                log: Arc::clone(&log),
            },
        )
    });

    let failure = workflow.orchestrate(0).await.unwrap_err();
    assert_eq!(failure.message(), format!("n{} failed", N - 1));

    let expected: Vec<u32> = (0..N - 1).rev().collect();
    assert_eq!(*log.lock().unwrap(), expected);
}
/// For a five-step chain, tries a failure at every position (plus the no-failure
/// case) and checks that exactly the steps completed before the failure are
/// compensated, in reverse order.
#[tokio::test]
async fn failure_at_every_step_compensates_exactly_the_completed_prefix() {
    const N: u32 = 5;

    // Step `idx` fails forward when `idx == fail_at`; otherwise it transitions to
    // `idx + 1`, outputs `idx`, and logs that output when compensated.
    #[derive(Clone)]
    struct Step {
        idx: u32,
        fail_at: u32,
        log: Arc<Mutex<Vec<u32>>>,
    }
    #[saga::task]
    impl CompensatableTask<u32> for Step {
        type Output = u32;
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        fn name(&self) -> Cow<'static, str> {
            Cow::Owned(format!("s{}", self.idx))
        }
        async fn run(&self, _res: &Resources) -> Result<(TaskResult<u32>, u32), CanoError> {
            if self.idx != self.fail_at {
                let successor = self.idx + 1;
                return Ok((TaskResult::Single(successor), self.idx));
            }
            Err(CanoError::task_execution(format!("s{} failed", self.idx)))
        }
        async fn compensate(&self, _res: &Resources, output: u32) -> Result<(), CanoError> {
            let mut undone = self.log.lock().unwrap();
            undone.push(output);
            Ok(())
        }
    }

    // `fail_at == N` matches no registered step, i.e. the run succeeds.
    for fail_at in 0..=N {
        let log = Arc::new(Mutex::new(Vec::<u32>::new()));
        let workflow = (0..N).fold(Workflow::<u32>::bare().add_exit_state(N), |chain, idx| {
            chain.register_with_compensation(
                idx,
                Step {
                    idx,
                    fail_at,
                    log: Arc::clone(&log),
                },
            )
        });

        let result = workflow.orchestrate(0).await;
        if fail_at < N {
            let failure = result.unwrap_err();
            assert_eq!(failure.message(), format!("s{fail_at} failed"));
            let expected: Vec<u32> = (0..fail_at).rev().collect();
            assert_eq!(
                *log.lock().unwrap(),
                expected,
                "fail at {fail_at}: states 0..{fail_at} compensate in reverse"
            );
        } else {
            assert_eq!(result.unwrap(), N, "no failure ⇒ run completes");
            assert!(
                log.lock().unwrap().is_empty(),
                "completed run compensates nothing"
            );
        }
    }
}
use crate::recovery::RowKind;
use crate::task::stepped::{StepOutcome, SteppedTask};
use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};
/// Stepped test task that advances a `u32` cursor by one per step and counts how
/// many times `step` was invoked.
struct CountStepper {
/// Cursor value at which `step` reports completion instead of asking for more.
target: u32,
/// Shared counter incremented on every `step` call.
calls: Arc<AtomicU32>,
}
impl CountStepper {
    /// Builds a stepper for `target` together with a handle to its call counter.
    ///
    /// The returned `Arc` points at the same counter the stepper increments, so the
    /// caller can read the number of `step` calls after the stepper has been moved.
    fn new(target: u32) -> (Self, Arc<AtomicU32>) {
        let counter = Arc::new(AtomicU32::new(0));
        let stepper = Self {
            target,
            calls: Arc::clone(&counter),
        };
        (stepper, counter)
    }
}
#[task_mod::stepped]
impl SteppedTask<TestState> for CountStepper {
    type Cursor = u32;

    fn config(&self) -> TaskConfig {
        TaskConfig::minimal()
    }

    // Records the call, advances the cursor by one (a missing cursor counts as 0), and
    // finishes with `Complete` once the advanced cursor reaches `target`.
    async fn step(
        &self,
        _res: &Resources,
        cursor: Option<u32>,
    ) -> Result<StepOutcome<u32, TestState>, CanoError> {
        self.calls.fetch_add(1, AtomicOrdering::Relaxed);
        let advanced = cursor.unwrap_or(0) + 1;
        let outcome = if advanced < self.target {
            StepOutcome::More(advanced)
        } else {
            StepOutcome::Done(TaskResult::Single(TestState::Complete))
        };
        Ok(outcome)
    }
}
/// A stepped task with target 4 must be stepped four times, leave a state-entry row,
/// three cursor rows and an exit state-entry row in the audit trail, and have its
/// live log cleared once the run succeeds.
#[tokio::test]
async fn stepped_forward_run_writes_cursor_rows_and_clears_on_success() {
    let run_id = "step-fwd";
    let store = Arc::new(MemCheckpoints::default());
    let (stepper, calls) = CountStepper::new(4);
    let workflow = Workflow::bare()
        .register_stepped(TestState::Start, stepper)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone())
        .with_workflow_id(run_id);

    let final_state = workflow.orchestrate(TestState::Start).await.unwrap();
    assert_eq!(final_state, TestState::Complete);
    assert_eq!(calls.load(AtomicOrdering::Relaxed), 4);

    let audit = store.audit_rows(run_id);
    assert_eq!(audit.len(), 5, "start + 3 cursors + exit = 5 rows");

    // First row: entry into `Start`.
    assert_eq!(audit[0].kind, RowKind::StateEntry);
    assert_eq!(audit[0].state, "Start");
    // Middle rows: one cursor row per non-final step.
    for cursor_row in &audit[1..=3] {
        assert_eq!(cursor_row.kind, RowKind::StepCursor);
    }
    // Last row: entry into the exit state.
    assert_eq!(audit[4].kind, RowKind::StateEntry);
    assert_eq!(audit[4].state, "Complete");

    // The first and last cursor rows carry the serialized cursors 1 and 3.
    let first_cursor =
        serde_json::from_slice::<u32>(audit[1].output_blob.as_ref().unwrap()).unwrap();
    assert_eq!(first_cursor, 1);
    let last_cursor =
        serde_json::from_slice::<u32>(audit[3].output_blob.as_ref().unwrap()).unwrap();
    assert_eq!(last_cursor, 3);

    assert!(
        store.rows(run_id).is_empty(),
        "live log cleared on success"
    );
}
/// Resuming a stepped run whose log already holds cursors 1 and 2 must continue from
/// the last cursor: with target 4 only two more `step` calls are needed.
#[tokio::test]
async fn stepped_resume_continues_from_last_cursor() {
    let run_id = "step-resume";
    let store = Arc::new(MemCheckpoints::default());

    // Pre-populate: the state-entry row followed by two persisted cursors.
    let seeded = vec![
        CheckpointRow::new(0, "Start", "CountStepper"),
        CheckpointRow::new(1, "Start", "CountStepper")
            .with_cursor(serde_json::to_vec(&1u32).unwrap()),
        CheckpointRow::new(2, "Start", "CountStepper")
            .with_cursor(serde_json::to_vec(&2u32).unwrap()),
    ];
    for row in seeded {
        store.append(run_id, row).await.unwrap();
    }

    let (stepper, calls) = CountStepper::new(4);
    let workflow = Workflow::bare()
        .register_stepped(TestState::Start, stepper)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone());

    let final_state = workflow.resume_from(run_id).await.unwrap();
    assert_eq!(final_state, TestState::Complete);

    let step_calls = calls.load(AtomicOrdering::Relaxed);
    assert_eq!(step_calls, 2, "resumed run must not restart from None");
}
/// Cursor rows must not disturb sequence numbering: every audit row of a stepped run
/// carries the sequence equal to its position.
#[tokio::test]
async fn stepped_sequences_are_dense_after_cursors() {
    let run_id = "dense";
    let store = Arc::new(MemCheckpoints::default());
    let (stepper, _calls) = CountStepper::new(3);
    let workflow = Workflow::bare()
        .register_stepped(TestState::Start, stepper)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone())
        .with_workflow_id(run_id);
    workflow.orchestrate(TestState::Start).await.unwrap();

    let audit = store.audit_rows(run_id);

    // Spot-check the first four rows explicitly; indexing also fails the test if
    // fewer than four rows were written.
    assert_eq!(audit[0].sequence, 0);
    assert_eq!(audit[1].sequence, 1);
    assert_eq!(audit[2].sequence, 2);
    assert_eq!(audit[3].sequence, 3);

    // Then verify the whole trail, however long it is.
    for (position, row) in audit.iter().enumerate() {
        let i = position;
        assert_eq!(
            row.sequence, i as u64,
            "sequences must be contiguous, gap at {i}"
        );
    }
}
/// Resuming a stepped run whose log holds only the state-entry row (no cursor rows)
/// must run the full stepping sequence: three `step` calls for target 3.
#[tokio::test]
async fn stepped_resume_without_cursor_rows_restarts_from_none() {
    let run_id = "step-fresh-resume";
    let store = Arc::new(MemCheckpoints::default());

    // The log contains nothing but the entry into `Start`.
    let entry_row = CheckpointRow::new(0, "Start", "CountStepper");
    store.append(run_id, entry_row).await.unwrap();

    let (stepper, calls) = CountStepper::new(3);
    let workflow = Workflow::bare()
        .register_stepped(TestState::Start, stepper)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone());

    let final_state = workflow.resume_from(run_id).await.unwrap();
    assert_eq!(final_state, TestState::Complete);

    let step_calls = calls.load(AtomicOrdering::Relaxed);
    assert_eq!(step_calls, 3);
}
/// Without a checkpoint store a stepped task still runs to completion, taking three
/// `step` calls for target 3.
#[tokio::test]
async fn stepped_no_store_path_unchanged() {
    let (stepper, calls) = CountStepper::new(3);
    let workflow = Workflow::bare()
        .register_stepped(TestState::Start, stepper)
        .add_exit_state(TestState::Complete);

    let final_state = workflow.orchestrate(TestState::Start).await.unwrap();

    assert_eq!(final_state, TestState::Complete);
    let step_calls = calls.load(AtomicOrdering::Relaxed);
    assert_eq!(step_calls, 3);
}
/// On resume, a persisted cursor row must not be treated as a compensatable output:
/// when the stepped task fails, only `A` (which has a real output row) is rolled back.
#[tokio::test]
async fn stepped_cursor_rows_do_not_become_compensation_entries() {
    // Stepped task whose every step fails.
    struct FailingStepper;
    #[task_mod::stepped]
    impl SteppedTask<TestState> for FailingStepper {
        type Cursor = u32;
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        async fn step(
            &self,
            _res: &Resources,
            _cursor: Option<u32>,
        ) -> Result<StepOutcome<u32, TestState>, CanoError> {
            let failure = CanoError::task_execution("stepper failed");
            Err(failure)
        }
    }

    let run_id = "mixed-stepped";
    let store = Arc::new(MemCheckpoints::default());
    let log: CompLog = Arc::new(Mutex::new(Vec::new()));

    // Pre-populate: A's entry row, A's output row (42), then one cursor row for the
    // stepped task registered under `Process`.
    let seeded = vec![
        CheckpointRow::new(0, "Start", "A"),
        CheckpointRow::new(1, "Start", "A").with_output(serde_json::to_vec(&42u32).unwrap()),
        CheckpointRow::new(2, "Process", "CountStepper")
            .with_cursor(serde_json::to_vec(&1u32).unwrap()),
    ];
    for row in seeded {
        store.append(run_id, row).await.unwrap();
    }

    let workflow = Workflow::bare()
        .register_with_compensation(
            TestState::Start,
            CompTask::ok("A", 42, TestState::Process, &log),
        )
        .register_stepped(TestState::Process, FailingStepper)
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone());

    let failure = workflow.resume_from(run_id).await.unwrap_err();
    assert_eq!(failure.message(), "stepper failed");

    let expected_rollback = vec![("A".to_string(), 42u32)];
    assert_eq!(
        *log.lock().unwrap(),
        expected_rollback,
        "StepCursor row must not become a compensation entry"
    );
}
/// Every row appended during a run carries the version configured through
/// `with_workflow_version`.
#[tokio::test]
async fn workflow_version_is_stamped_on_every_appended_row() {
    let run_id = "ver-run";
    let store = Arc::new(MemCheckpoints::default());
    let workflow = Workflow::bare()
        .register(TestState::Start, SimpleTask::new(TestState::Process))
        .register(TestState::Process, SimpleTask::new(TestState::Complete))
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone())
        .with_workflow_id(run_id)
        .with_workflow_version(7);
    workflow.orchestrate(TestState::Start).await.unwrap();

    let rows = store.audit_rows(run_id);
    // Guard against the version check passing vacuously on an empty trail.
    assert!(!rows.is_empty(), "expected at least one appended row");
    let all_stamped = rows.iter().all(|row| row.workflow_version == 7);
    assert!(all_stamped);
}
/// Resuming a run recorded under version 1 with a workflow configured for version 2
/// must fail with the version-mismatch error naming both versions.
#[tokio::test]
async fn resume_from_rejects_workflow_version_mismatch() {
    let run_id = "ver-mismatch";
    let store = Arc::new(MemCheckpoints::default());

    // The persisted row was written by workflow version 1.
    let stale_row = CheckpointRow::new(0, "Start", "").with_workflow_version(1);
    store.append(run_id, stale_row).await.unwrap();

    // The resuming workflow declares version 2.
    let workflow = Workflow::bare()
        .register(TestState::Start, SimpleTask::new(TestState::Complete))
        .add_exit_state(TestState::Complete)
        .with_checkpoint_store(store.clone())
        .with_workflow_version(2);

    let failure = workflow.resume_from(run_id).await.unwrap_err();
    let expected = CanoError::workflow_version_mismatch(1, 2);
    assert_eq!(failure, expected);
}
}
/// Tests for the counters emitted around checkpointing and compensation; compiled
/// only for test builds with the `metrics` feature enabled.
#[cfg(all(test, feature = "metrics"))]
mod metrics_tests {
    use crate::metrics::test_support::*;
    use crate::prelude::*;
    use crate::recovery::{CheckpointRow, CheckpointStore};
    use crate::saga;
    use crate::saga::CompensatableTask;
    use crate::task::TaskConfig;
    use crate::workflow::test_support::{SimpleTask, TestState};
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    /// In-memory checkpoint store: `live` holds the current rows per workflow id,
    /// `audit` keeps every accepted append even after `clear`.
    #[derive(Default)]
    struct MemCheckpoints {
        live: Mutex<HashMap<String, Vec<CheckpointRow>>>,
        audit: Mutex<Vec<(String, CheckpointRow)>>,
    }

    #[cano_macros::checkpoint_store]
    impl CheckpointStore for MemCheckpoints {
        // Rejects a row whose sequence already exists for the workflow; otherwise
        // records it in both the audit trail and the live log.
        async fn append(&self, workflow_id: &str, row: CheckpointRow) -> Result<(), CanoError> {
            let mut live = self.live.lock().unwrap();
            let bucket = live.entry(workflow_id.to_string()).or_default();
            let sequence_taken = bucket.iter().any(|existing| existing.sequence == row.sequence);
            if sequence_taken {
                return Err(CanoError::checkpoint_store(format!(
                    "checkpoint conflict: {workflow_id:?} already has sequence {}",
                    row.sequence
                )));
            }
            let audit_entry = (workflow_id.to_string(), row.clone());
            self.audit.lock().unwrap().push(audit_entry);
            bucket.push(row);
            Ok(())
        }

        // Returns a copy of the workflow's live rows ordered by sequence (empty when
        // the workflow id is unknown).
        async fn load_run(&self, workflow_id: &str) -> Result<Vec<CheckpointRow>, CanoError> {
            let mut snapshot = self
                .live
                .lock()
                .unwrap()
                .get(workflow_id)
                .cloned()
                .unwrap_or_default();
            snapshot.sort_by_key(|row| row.sequence);
            Ok(snapshot)
        }

        // Drops the workflow's live rows; the audit trail is left untouched.
        async fn clear(&self, workflow_id: &str) -> Result<(), CanoError> {
            let mut live = self.live.lock().unwrap();
            live.remove(workflow_id);
            Ok(())
        }
    }

    /// A successful checkpointed run records at least one successful append and
    /// exactly one successful clear.
    #[test]
    fn checkpoint_append_and_clear_counters_on_successful_run() {
        let (res, rows) = run_with_recorder(|| async {
            let store = Arc::new(MemCheckpoints::default());
            let workflow = Workflow::bare()
                .with_checkpoint_store(store.clone())
                .with_workflow_id("metrics-wf")
                .register(TestState::Start, SimpleTask::new(TestState::Process))
                .register(TestState::Process, SimpleTask::new(TestState::Complete))
                .add_exit_state(TestState::Complete);
            workflow.orchestrate(TestState::Start).await
        });
        assert_eq!(res.unwrap(), TestState::Complete);

        let ok_appends = counter(&rows, "cano_checkpoint_appends_total", &[("result", "ok")]);
        assert!(
            ok_appends >= 1,
            "expected at least one successful checkpoint append"
        );

        let ok_clears = counter(&rows, "cano_checkpoint_clears_total", &[("result", "ok")]);
        assert_eq!(
            ok_clears, 1,
            "expected exactly one checkpoint clear on successful run"
        );
    }

    type CompLog = Arc<Mutex<Vec<(String, u32)>>>;

    /// Compensatable task that always succeeds forward and logs `("comp", output)`
    /// when compensated.
    #[derive(Clone)]
    struct CompTask {
        value: u32,
        next_state: TestState,
        log: CompLog,
    }

    #[saga::task]
    impl CompensatableTask<TestState> for CompTask {
        type Output = u32;
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        async fn run(&self, _res: &Resources) -> Result<(TaskResult<TestState>, u32), CanoError> {
            let successor = self.next_state.clone();
            Ok((TaskResult::Single(successor), self.value))
        }
        async fn compensate(&self, _res: &Resources, output: u32) -> Result<(), CanoError> {
            let mut undone = self.log.lock().unwrap();
            undone.push(("comp".to_string(), output));
            Ok(())
        }
    }

    /// Plain task that fails on every run.
    struct AlwaysFailTask;

    #[crate::task]
    impl Task<TestState> for AlwaysFailTask {
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        async fn run_bare(&self) -> Result<TaskResult<TestState>, CanoError> {
            let failure = CanoError::task_execution("intentional failure for saga test");
            Err(failure)
        }
    }

    /// One compensatable step followed by a failing step yields one successful
    /// compensation and one clean drain in the counters.
    #[test]
    fn compensation_run_and_drain_counters_on_clean_rollback() {
        let (res, rows) = run_with_recorder(|| async {
            let log: CompLog = Arc::new(Mutex::new(Vec::new()));
            let compensatable = CompTask {
                value: 42,
                next_state: TestState::Process,
                log: log.clone(),
            };
            let workflow = Workflow::bare()
                .register_with_compensation(TestState::Start, compensatable)
                .register(TestState::Process, AlwaysFailTask)
                .add_exit_state(TestState::Complete);
            workflow.orchestrate(TestState::Start).await
        });
        assert!(res.is_err(), "expected workflow to fail");

        let ok_compensations =
            counter(&rows, "cano_compensations_run_total", &[("result", "ok")]);
        assert_eq!(
            ok_compensations, 1,
            "expected one successful compensate() call"
        );

        let clean_drains = counter(
            &rows,
            "cano_compensation_drains_total",
            &[("outcome", "clean")],
        );
        assert_eq!(clean_drains, 1, "expected a clean compensation drain");
    }
}