use std::borrow::Cow;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use crate::error::CanoError;
use crate::observer::WorkflowObserver;
use crate::recovery::CheckpointStore;
use crate::resource::Resources;
use crate::saga::{CompensatableTask, ErasedCompensatable};
use crate::task::stepped::{ErasedSteppedTask, SteppedAdapter, SteppedTask};
use crate::task::{RouterTask, Task};
#[cfg(feature = "tracing")]
use tracing::{Span, info_span};
mod compensation;
mod execution;
mod join;
#[cfg(test)]
mod test_support;
pub use execution::StateEntry;
pub use join::{JoinConfig, JoinStrategy, SplitResult, SplitTaskResult};
/// Calls `f` once for every observer, in slice order.
///
/// Each call is wrapped in `catch_unwind`, so a panicking observer is reported
/// and skipped instead of unwinding into the caller; the remaining observers
/// are still notified.
#[inline]
fn notify_observers(observers: &[Arc<dyn WorkflowObserver>], f: impl Fn(&dyn WorkflowObserver)) {
    use std::panic::{AssertUnwindSafe, catch_unwind};

    for observer in observers {
        let target: &dyn WorkflowObserver = &**observer;
        // The common case is "no panic": move on to the next observer.
        let Err(payload) = catch_unwind(AssertUnwindSafe(|| f(target))) else {
            continue;
        };
        let msg = panic_payload_message(&*payload);
        #[cfg(feature = "tracing")]
        tracing::error!(panic = %msg, "workflow observer panicked");
        #[cfg(not(feature = "tracing"))]
        observer_panic_notice(&msg);
    }
}
/// Reports an observer panic on stderr when the `tracing` feature is off.
///
/// Only the first call in the process prints anything; a process-wide flag
/// suppresses every later notice.
#[cfg(not(feature = "tracing"))]
fn observer_panic_notice(msg: &str) {
    use std::sync::atomic::{AtomicBool, Ordering};

    static EMITTED: AtomicBool = AtomicBool::new(false);

    // `swap` returns the previous value: `true` means an earlier call already reported.
    let already_reported = EMITTED.swap(true, Ordering::Relaxed);
    if already_reported {
        return;
    }
    eprintln!(
        "cano: workflow observer panicked (further panics will be silent until process restart): {msg}"
    );
}
/// Extracts a human-readable message from a panic payload.
///
/// Payloads that are a `&'static str` or a `String` are returned as-is; any
/// other payload type yields the placeholder `"<non-string panic payload>"`.
pub(crate) fn panic_payload_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|literal| (*literal).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "<non-string panic payload>".to_string())
}
/// Awaits `fut`, converting a panic raised while polling it into a `CanoError`.
///
/// A future that completes normally has its own result returned untouched.
/// On panic the payload message is wrapped in a task-execution error of the
/// form `panic: <message>`. `panic_label` is only used in the log line emitted
/// when the `tracing` feature is enabled.
pub(crate) async fn catch_panic_to_error<T, F>(
    fut: F,
    panic_label: &'static str,
) -> Result<T, CanoError>
where
    F: std::future::Future<Output = Result<T, CanoError>>,
{
    use futures_util::FutureExt;
    use std::panic::AssertUnwindSafe;

    let payload = match AssertUnwindSafe(fut).catch_unwind().await {
        Ok(task_result) => return task_result,
        Err(payload) => payload,
    };

    let msg = panic_payload_message(&*payload);
    #[cfg(feature = "tracing")]
    tracing::error!(panic = %msg, "{} panicked", panic_label);
    // Without tracing the label has no consumer; touch it to avoid an unused warning.
    #[cfg(not(feature = "tracing"))]
    let _ = panic_label;
    Err(CanoError::task_execution(format!("panic: {msg}")))
}
/// A workflow definition: per-state handlers, exit states, timeouts, observers
/// and optional recovery/compensation wiring.
///
/// Instances are assembled with the consuming builder methods (`register*`,
/// `add_exit_state*`, `with_*`) and run with `orchestrate`.
#[must_use]
pub struct Workflow<TState, TResourceKey = Cow<'static, str>>
where
TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
TResourceKey: Hash + Eq + Send + Sync + 'static,
{
/// Handler entry registered for each state; a later registration for the
/// same state replaces the earlier one.
states: HashMap<TState, Arc<StateEntry<TState, TResourceKey>>>,
/// Shared resources; `orchestrate` sets them up before the run and tears
/// them down afterwards.
pub(crate) resources: Arc<Resources<TResourceKey>>,
/// Timeout set by `with_timeout`. Used as an outer `tokio` timeout only
/// when no total budget is resolved; otherwise it caps `total_timeout`.
workflow_timeout: Option<Duration>,
/// Timeout set by `with_total_timeout`; the basis of the total budget.
pub(crate) total_timeout: Option<Duration>,
/// Timeout set by `with_compensation_timeout`.
// NOTE(review): consumed outside this file section — presumably by the compensation module; confirm.
pub(crate) compensation_timeout: Option<Duration>,
/// Exit states, kept free of duplicates by the `add_exit_state*` builders.
exit_states: Vec<TState>,
/// Result of `validate()` cached by the first `orchestrate` call; a clone
/// starts with an empty cache.
validated: OnceLock<Result<(), CanoError>>,
/// Observers attached via `with_observer`, in registration order.
observers: Vec<Arc<dyn WorkflowObserver>>,
/// Compensatable tasks keyed by their task name.
compensators: HashMap<Arc<str>, Arc<dyn ErasedCompensatable<TState, TResourceKey>>>,
/// Optional checkpoint store attached via `with_checkpoint_store`.
checkpoint_store: Option<Arc<dyn CheckpointStore>>,
/// Optional identifier attached via `with_workflow_id`; also recorded on
/// the default tracing span.
workflow_id: Option<Arc<str>>,
/// Version set via `with_workflow_version`; defaults to 0.
workflow_version: u32,
/// Caller-supplied span used by `orchestrate` instead of the default
/// `workflow_orchestrate` span.
#[cfg(feature = "tracing")]
tracing_span: Option<Span>,
}
impl<TState, TResourceKey> Workflow<TState, TResourceKey>
where
TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
TResourceKey: Hash + Eq + Send + Sync + 'static,
{
/// Creates an empty workflow that owns `resources`.
///
/// The new workflow has no handlers, no exit states, no timeouts, no
/// observers, no checkpoint store, no workflow id, and version 0.
pub fn new(resources: Resources<TResourceKey>) -> Self {
    let shared_resources = Arc::new(resources);
    Self {
        // State machine definition.
        states: HashMap::new(),
        exit_states: Vec::new(),
        compensators: HashMap::new(),
        resources: shared_resources,
        // Timeouts: all disabled until configured.
        workflow_timeout: None,
        total_timeout: None,
        compensation_timeout: None,
        // Run-time wiring.
        observers: Vec::new(),
        checkpoint_store: None,
        workflow_id: None,
        workflow_version: 0,
        validated: OnceLock::new(),
        #[cfg(feature = "tracing")]
        tracing_span: None,
    }
}
/// Sets the workflow timeout and returns the updated workflow.
///
/// This value acts as the outer timeout when no total timeout is configured,
/// and as an upper bound on the total budget when one is.
pub fn with_timeout(self, timeout: Duration) -> Self {
    let mut workflow = self;
    workflow.workflow_timeout = Option::from(timeout);
    workflow
}
/// Sets the total timeout from which the run's total budget is derived, and
/// returns the updated workflow.
pub fn with_total_timeout(self, timeout: Duration) -> Self {
    let mut workflow = self;
    workflow.total_timeout = Option::from(timeout);
    workflow
}
/// Sets the compensation timeout and returns the updated workflow.
pub fn with_compensation_timeout(self, timeout: Duration) -> Self {
    let mut workflow = self;
    workflow.compensation_timeout = Option::from(timeout);
    workflow
}
pub fn register<T>(mut self, state: TState, task: T) -> Self
where
T: Task<TState, TResourceKey> + Send + Sync + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(task.config());
self.states.insert(
state,
Arc::new(StateEntry::Single {
task: Arc::new(task),
config,
}),
);
self
}
pub fn register_router<T>(mut self, state: TState, task: T) -> Self
where
T: RouterTask<TState, TResourceKey> + Task<TState, TResourceKey> + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(Task::config(&task));
self.states.insert(
state,
Arc::new(StateEntry::Router {
task: Arc::new(task),
config,
}),
);
self
}
pub fn register_split<T>(
mut self,
state: TState,
tasks: Vec<T>,
join_config: JoinConfig<TState>,
) -> Self
where
T: Task<TState, TResourceKey> + Send + Sync + 'static,
{
self.forget_compensator_for(&state);
let configs: Vec<Arc<crate::task::TaskConfig>> =
tasks.iter().map(|t| Arc::new(t.config())).collect();
let arc_tasks: Vec<Arc<dyn Task<TState, TResourceKey> + Send + Sync>> =
tasks.into_iter().map(|t| Arc::new(t) as Arc<_>).collect();
self.states.insert(
state,
Arc::new(StateEntry::Split {
tasks: arc_tasks,
configs: Arc::new(configs),
join_config: Arc::new(join_config),
}),
);
self
}
/// Removes the compensator belonging to the handler currently registered for
/// `state`, if that handler is a compensatable task.
///
/// Does nothing when `state` has no handler or its handler is of another kind.
fn forget_compensator_for(&mut self, state: &TState) {
    let current_entry = self.states.get(state).map(|entry| &**entry);
    let stale_name: Option<Arc<str>> = match current_entry {
        Some(StateEntry::CompensatableSingle { task, .. }) => Some(Arc::from(task.name())),
        _ => None,
    };
    if let Some(name) = stale_name {
        self.compensators.remove(&*name);
    }
}
pub fn register_with_compensation<T>(mut self, state: TState, task: T) -> Self
where
T: CompensatableTask<TState, TResourceKey> + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(task.config());
let name: Arc<str> = Arc::from(task.name());
let erased: Arc<dyn ErasedCompensatable<TState, TResourceKey>> =
Arc::new(crate::saga::CompensatableAdapter(Arc::new(task)));
self.compensators.insert(name, Arc::clone(&erased));
self.states.insert(
state,
Arc::new(StateEntry::CompensatableSingle {
task: erased,
config,
}),
);
self
}
pub fn register_stepped<T>(mut self, state: TState, task: T) -> Self
where
T: SteppedTask<TState, TResourceKey> + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(SteppedTask::config(&task));
let erased: Arc<dyn ErasedSteppedTask<TState, TResourceKey>> =
Arc::new(SteppedAdapter(Arc::new(task)));
self.states.insert(
state,
Arc::new(StateEntry::Stepped {
task: erased,
config,
}),
);
self
}
pub fn add_exit_state(mut self, state: TState) -> Self {
if !self.exit_states.contains(&state) {
self.exit_states.push(state);
}
self
}
pub fn add_exit_states(mut self, states: Vec<TState>) -> Self {
for state in states {
if !self.exit_states.contains(&state) {
self.exit_states.push(state);
}
}
self
}
/// Appends `observer` to the workflow's observers and returns the updated
/// workflow. Observers are kept in the order they were added.
pub fn with_observer(self, observer: Arc<dyn WorkflowObserver>) -> Self {
    let mut workflow = self;
    workflow.observers.push(observer);
    workflow
}
/// Attaches a checkpoint store, replacing any store set earlier, and returns
/// the updated workflow.
pub fn with_checkpoint_store(self, checkpoint_store: Arc<dyn CheckpointStore>) -> Self {
    let mut workflow = self;
    workflow.checkpoint_store = Option::from(checkpoint_store);
    workflow
}
/// Sets the workflow identifier, replacing any identifier set earlier, and
/// returns the updated workflow.
pub fn with_workflow_id(self, workflow_id: impl Into<Arc<str>>) -> Self {
    let id: Arc<str> = workflow_id.into();
    let mut workflow = self;
    workflow.workflow_id = Some(id);
    workflow
}
/// Sets the workflow version (0 unless set) and returns the updated workflow.
pub fn with_workflow_version(self, version: u32) -> Self {
    let mut workflow = self;
    workflow.workflow_version = version;
    workflow
}
/// Supplies the span `orchestrate` runs under, in place of the default
/// `workflow_orchestrate` span. Only available with the `tracing` feature.
#[cfg(feature = "tracing")]
pub fn with_tracing_span(self, span: Span) -> Self {
    let mut workflow = self;
    workflow.tracing_span = Option::from(span);
    workflow
}
/// Returns the observers as a shared slice, or `None` when no observer is
/// attached.
fn observer_slice(&self) -> Option<Arc<[Arc<dyn WorkflowObserver>]>> {
    let has_observers = !self.observers.is_empty();
    has_observers.then(|| Arc::from(self.observers.as_slice()))
}
/// Produces the task configuration to use for a task named `task_name`.
///
/// Without observers the `base` configuration is shared as-is (a cheap `Arc`
/// clone). With observers, `base` is copied and the copy gets the observer
/// slice and the task name attached.
fn config_with_observers(
    base: &Arc<crate::task::TaskConfig>,
    observers: &Option<Arc<[Arc<dyn WorkflowObserver>]>>,
    task_name: &str,
) -> Arc<crate::task::TaskConfig> {
    let Some(slice) = observers else {
        return Arc::clone(base);
    };

    let mut cfg = crate::task::TaskConfig::clone(base);
    cfg.observers = Some(Arc::clone(slice));
    cfg.task_name = Some(Cow::Owned(task_name.to_owned()));
    Arc::new(cfg)
}
/// Checks a split's join configuration for settings that cannot work.
///
/// # Errors
///
/// Returns a configuration error when:
/// - the strategy is `PartialTimeout` but no timeout is configured;
/// - the strategy is `Percentage(p)` and `p` is not a finite value in `(0.0, 1.0]`;
/// - a bulkhead is configured with 0 permits.
///
/// `_total_tasks` is currently unused.
fn validate_join_config(
    join_config: &JoinConfig<TState>,
    _total_tasks: usize,
) -> Result<(), CanoError> {
    match join_config.strategy {
        JoinStrategy::PartialTimeout if join_config.timeout.is_none() => {
            return Err(CanoError::configuration(
                "PartialTimeout strategy requires a timeout to be configured",
            ));
        }
        // Written as a negated "is valid" test so NaN (not finite) is rejected too.
        JoinStrategy::Percentage(p) if !(p.is_finite() && p > 0.0 && p <= 1.0) => {
            return Err(CanoError::configuration(format!(
                "Percentage strategy requires a finite value in (0.0, 1.0], got {p}"
            )));
        }
        _ => {}
    }

    if matches!(join_config.bulkhead, Some(0)) {
        return Err(CanoError::configuration(
            "bulkhead requires a positive permit count, got 0",
        ));
    }
    Ok(())
}
pub fn validate(&self) -> Result<(), CanoError> {
if self.states.is_empty() {
return Err(CanoError::configuration(
"Workflow has no registered state handlers",
));
}
if self.exit_states.is_empty() {
return Err(CanoError::configuration(
"Workflow has no exit states defined — orchestration may loop forever",
));
}
for entry in self.states.values() {
if let StateEntry::Split {
tasks, join_config, ..
} = entry.as_ref()
{
Self::validate_join_config(join_config, tasks.len())?;
let js = &join_config.join_state;
if !self.states.contains_key(js) && !self.exit_states.contains(js) {
return Err(CanoError::configuration(format!(
"Split join_state {:?} is neither registered nor an exit state",
js
)));
}
}
}
Ok(())
}
/// Checks that `state` is usable as a starting point: it must either have a
/// registered handler or be an exit state.
///
/// # Errors
///
/// Returns a configuration error naming the state when it is neither.
pub fn validate_initial_state(&self, state: &TState) -> Result<(), CanoError> {
    let is_registered = self.states.contains_key(state);
    if is_registered || self.exit_states.contains(state) {
        return Ok(());
    }
    Err(CanoError::configuration(format!(
        "Initial state {:?} is neither registered nor an exit state",
        state
    )))
}
/// Runs the workflow starting from `initial_state` and returns the state it
/// finished in.
///
/// Steps, in order:
/// 1. validate the workflow definition (the verdict is cached after the first
///    call) and the initial state;
/// 2. set up all resources;
/// 3. run the workflow;
/// 4. tear down the resources, whether the run succeeded or failed.
///
/// With the `tracing` feature, the whole run executes inside the span given
/// to `with_tracing_span`, or a default `workflow_orchestrate` span. The span
/// is attached to the future with `tracing::Instrument` rather than a
/// `Span::enter` guard: a guard held across `.await` points stays entered
/// while the task is suspended, which the `tracing` documentation warns
/// produces incorrect traces. Instrumenting enters the span only while this
/// future is being polled.
///
/// # Errors
///
/// Returns the validation error, a resource setup error, or the error the
/// workflow run itself produced.
pub async fn orchestrate(&self, initial_state: TState) -> Result<TState, CanoError> {
    let lifecycle = async move {
        let cached_validation = self.validated.get_or_init(|| self.validate());
        if let Err(e) = cached_validation {
            return Err(e.clone());
        }
        self.validate_initial_state(&initial_state)?;

        self.resources.setup_all().await?;
        let result = self.run_workflow(initial_state).await;
        // Teardown runs on both success and failure of the run.
        self.resources
            .teardown_range(0..self.resources.lifecycle_len())
            .await;
        result
    };

    #[cfg(feature = "tracing")]
    let lifecycle = {
        let workflow_span = self.tracing_span.clone().unwrap_or_else(|| {
            if tracing::enabled!(tracing::Level::INFO) {
                info_span!(
                    "workflow_orchestrate",
                    workflow_id = self.workflow_id.as_deref()
                )
            } else {
                tracing::Span::none()
            }
        });
        tracing::Instrument::instrument(lifecycle, workflow_span)
    };

    lifecycle.await
}
/// Executes the workflow from `initial_state` under the configured timeouts.
///
/// The start instant taken here anchors both the total budget and the
/// duration reported to metrics. With the `metrics` feature, an
/// active-workflow guard is held for the duration of the call.
async fn run_workflow(&self, initial_state: TState) -> Result<TState, CanoError> {
    #[cfg(feature = "metrics")]
    let _active = crate::metrics::WorkflowActiveGuard::new();

    let started = std::time::Instant::now();
    let total_budget = self.resolve_total_budget(started);
    self.await_with_outer_timeout(
        self.execute_workflow(initial_state, total_budget),
        total_budget,
        started,
    )
    .await
}
/// Resolves the total time budget for a run that began at `started`.
///
/// Returns `None` when no total timeout is configured, even if a workflow
/// timeout is. Otherwise returns `started` paired with the total timeout,
/// capped by the workflow timeout when that is also set.
pub(crate) fn resolve_total_budget(
    &self,
    started: std::time::Instant,
) -> Option<(std::time::Instant, Duration)> {
    let budget = match (self.workflow_timeout, self.total_timeout) {
        (_, None) => return None,
        (None, Some(total)) => total,
        (Some(legacy), Some(total)) => legacy.min(total),
    };
    Some((started, budget))
}
/// Awaits `fut`, applying the workflow timeout as an outer `tokio` timeout
/// only when a workflow timeout is set and no total budget was resolved.
///
/// In every other case `fut` is awaited directly.
///
/// With the `metrics` feature, exactly one run outcome is recorded per call:
/// `"timeout"` when the outer timeout fires, otherwise `"completed"` or
/// `"failed"` depending on the future's result.
///
/// # Errors
///
/// Returns a workflow error `"Workflow timeout exceeded"` when the outer
/// timeout fires; otherwise propagates the future's own error.
pub(crate) async fn await_with_outer_timeout<F, T>(
    &self,
    fut: F,
    total_budget: Option<(std::time::Instant, Duration)>,
    #[allow(unused_variables)] started: std::time::Instant,
) -> Result<T, CanoError>
where
    F: std::future::Future<Output = Result<T, CanoError>>,
{
    // The outer timeout applies only when there is no total budget.
    let outer_limit = match total_budget {
        None => self.workflow_timeout,
        Some(_) => None,
    };

    let result = match outer_limit {
        None => fut.await,
        Some(limit) => {
            let Ok(inner) = tokio::time::timeout(limit, fut).await else {
                // Returning here skips the completed/failed record below, so a
                // timeout is counted once.
                #[cfg(feature = "metrics")]
                crate::metrics::workflow_run("timeout", started.elapsed());
                return Err(CanoError::workflow("Workflow timeout exceeded"));
            };
            inner
        }
    };

    #[cfg(feature = "metrics")]
    crate::metrics::workflow_run(
        if result.is_ok() {
            "completed"
        } else {
            "failed"
        },
        started.elapsed(),
    );
    result
}
}
/// Cloning shares the resources, handlers, observers and stores behind their
/// `Arc`s, and copies the plain configuration values. The clone starts with an
/// empty validation cache.
impl<TState, TResourceKey> Clone for Workflow<TState, TResourceKey>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        // Exhaustive destructuring: a field added to the struct must be
        // handled here before this compiles.
        let Self {
            states,
            resources,
            workflow_timeout,
            total_timeout,
            compensation_timeout,
            exit_states,
            validated: _,
            observers,
            compensators,
            checkpoint_store,
            workflow_id,
            workflow_version,
            #[cfg(feature = "tracing")]
            tracing_span,
        } = self;

        Self {
            states: states.clone(),
            resources: Arc::clone(resources),
            workflow_timeout: *workflow_timeout,
            total_timeout: *total_timeout,
            compensation_timeout: *compensation_timeout,
            exit_states: exit_states.clone(),
            validated: OnceLock::new(),
            observers: observers.clone(),
            compensators: compensators.clone(),
            checkpoint_store: checkpoint_store.clone(),
            workflow_id: workflow_id.clone(),
            workflow_version: *workflow_version,
            #[cfg(feature = "tracing")]
            tracing_span: tracing_span.clone(),
        }
    }
}
impl<TState> Workflow<TState, Cow<'static, str>>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
{
    /// Creates a workflow with the default string resource key and a freshly
    /// created, empty `Resources` collection.
    pub fn bare() -> Self {
        let no_resources = Resources::new();
        Self::new(no_resources)
    }
}
/// Debug output summarises the handler and compensator tables by count and
/// shows the checkpoint store only as a presence flag; observers, resources
/// and the validation cache are omitted.
impl<TState, TResourceKey> std::fmt::Debug for Workflow<TState, TResourceKey>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let states_summary = format!("{} states", self.states.len());
        let compensators_summary = format!("{} compensators", self.compensators.len());
        let has_checkpoint_store = self.checkpoint_store.is_some();

        let mut out = f.debug_struct("Workflow");
        out.field("states", &states_summary);
        out.field("exit_states", &self.exit_states);
        out.field("workflow_timeout", &self.workflow_timeout);
        out.field("total_timeout", &self.total_timeout);
        out.field("compensation_timeout", &self.compensation_timeout);
        out.field("workflow_id", &self.workflow_id);
        out.field("workflow_version", &self.workflow_version);
        out.field("checkpoint_store", &has_checkpoint_store);
        out.field("compensators", &compensators_summary);
        out.finish()
    }
}
// Tests for the metrics emitted by a workflow run. Each test runs a workflow
// inside `run_with_recorder` and inspects the recorded rows. Compiled only for
// test builds with the `metrics` feature.
#[cfg(all(test, feature = "metrics"))]
mod metrics_tests {
use crate::metrics::test_support::*;
use crate::prelude::*;
// Minimal three-state machine used by every test in this module.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum S {
Start,
Mid,
Done,
}
// Task that always transitions to the state it wraps.
struct GoTo(S);
#[crate::task]
impl Task<S> for GoTo {
async fn run_bare(&self) -> Result<TaskResult<S>, CanoError> {
Ok(TaskResult::Single(self.0.clone()))
}
}
// Task that always fails with a task-execution error.
struct Boom;
#[crate::task]
impl Task<S> for Boom {
fn config(&self) -> TaskConfig {
TaskConfig::minimal()
}
async fn run_bare(&self) -> Result<TaskResult<S>, CanoError> {
Err(CanoError::task_execution("boom"))
}
}
// Start -> Mid -> Done, with Done as the only exit state.
fn ok_workflow() -> Workflow<S> {
Workflow::bare()
.register(S::Start, GoTo(S::Mid))
.register(S::Mid, GoTo(S::Done))
.add_exit_state(S::Done)
}
// A successful run records one "completed" run and duration sample, and
// leaves the active gauge at 0.
#[test]
fn successful_run_records_outcome_duration_and_clears_active_gauge() {
let (res, rows) = run_with_recorder(|| async { ok_workflow().orchestrate(S::Start).await });
assert_eq!(res.unwrap(), S::Done);
assert_eq!(
counter(
&rows,
"cano_workflow_runs_total",
&[("outcome", "completed")]
),
1
);
assert_eq!(
histogram_count(
&rows,
"cano_workflow_duration_seconds",
&[("outcome", "completed")]
),
1
);
assert_eq!(gauge(&rows, "cano_workflow_active", &[]), 0.0);
}
// A run whose task errors records one "failed" run and leaves the active
// gauge at 0.
#[test]
fn failed_run_records_failed_outcome() {
let (res, rows) = run_with_recorder(|| async {
Workflow::bare()
.register(S::Start, Boom)
.add_exit_state(S::Done)
.orchestrate(S::Start)
.await
});
assert!(res.is_err());
assert_eq!(
counter(&rows, "cano_workflow_runs_total", &[("outcome", "failed")]),
1
);
assert_eq!(gauge(&rows, "cano_workflow_active", &[]), 0.0);
}
// With only `with_timeout` set, a 500 ms task against a 20 ms timeout is
// counted as "timeout" and not additionally as "failed".
#[test]
fn legacy_timeout_on_orchestrate_only_increments_timeout_counter() {
struct Slow;
#[crate::task]
impl Task<S> for Slow {
fn config(&self) -> TaskConfig {
TaskConfig::minimal()
}
async fn run_bare(&self) -> Result<TaskResult<S>, CanoError> {
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
Ok(TaskResult::Single(S::Done))
}
}
let (res, rows) = run_with_recorder(|| async {
Workflow::bare()
.with_timeout(std::time::Duration::from_millis(20))
.register(S::Start, Slow)
.add_exit_state(S::Done)
.orchestrate(S::Start)
.await
});
assert!(res.is_err());
assert_eq!(
counter(&rows, "cano_workflow_runs_total", &[("outcome", "timeout")]),
1
);
assert_eq!(
counter_opt(&rows, "cano_workflow_runs_total", &[("outcome", "failed")]).unwrap_or(0),
0,
"legacy timeout must not double-count as both `timeout` and `failed`"
);
}
// Each single-task state visited contributes one task-duration sample
// labelled with its state name and kind "single".
#[test]
fn per_state_task_durations_are_recorded_single_and_split() {
let (res, rows) = run_with_recorder(|| async { ok_workflow().orchestrate(S::Start).await });
assert_eq!(res.unwrap(), S::Done);
assert_eq!(
histogram_count(
&rows,
"cano_task_duration_seconds",
&[("state", "Start"), ("kind", "single")]
),
1
);
assert_eq!(
histogram_count(
&rows,
"cano_task_duration_seconds",
&[("state", "Mid"), ("kind", "single")]
),
1
);
}
// Split branch that either fails or routes to `Done`, depending on `fail`.
#[derive(Clone)]
struct Branch {
fail: bool,
}
#[crate::task]
impl Task<S> for Branch {
fn config(&self) -> TaskConfig {
TaskConfig::minimal()
}
async fn run_bare(&self) -> Result<TaskResult<S>, CanoError> {
if self.fail {
Err(CanoError::task_execution("nope"))
} else {
Ok(TaskResult::Single(S::Done))
}
}
}
// A three-branch split (two succeed, one fails) with `PartialResults(2)`
// records two "success" and one "failure" branch results, no "cancelled"
// row, and one task-duration sample of kind "split".
#[test]
fn split_records_branch_results_and_a_split_kind_duration() {
let (res, rows) = run_with_recorder(|| async {
Workflow::bare()
.register_split(
S::Start,
vec![
Branch { fail: false },
Branch { fail: true },
Branch { fail: false },
],
JoinConfig::new(JoinStrategy::PartialResults(2), S::Done),
)
.add_exit_state(S::Done)
.orchestrate(S::Start)
.await
});
assert_eq!(res.unwrap(), S::Done);
assert_eq!(
counter(
&rows,
"cano_split_branch_results_total",
&[("result", "success")]
),
2
);
assert_eq!(
counter(
&rows,
"cano_split_branch_results_total",
&[("result", "failure")]
),
1
);
assert_eq!(
counter_opt(
&rows,
"cano_split_branch_results_total",
&[("result", "cancelled")]
),
None
);
assert_eq!(
histogram_count(
&rows,
"cano_task_duration_seconds",
&[("state", "Start"), ("kind", "split")]
),
1
);
}
}
// Unit tests for workflow construction, validation and orchestration. The
// `TestState`, `SimpleTask` and `DataTask` fixtures come from `test_support`.
#[cfg(test)]
mod tests {
use super::test_support::*;
use super::*;
use crate::resource::Resources;
use crate::task;
use crate::task::{Task, TaskResult};
use cano_macros::task as task_macro;
use tokio;
// A bare workflow starts with no handlers and no exit states.
#[tokio::test]
async fn test_workflow_creation() {
let workflow = Workflow::<TestState>::bare();
assert_eq!(workflow.states.len(), 0);
assert_eq!(workflow.exit_states.len(), 0);
}
// One handler routing straight to the exit state.
#[tokio::test]
async fn test_simple_workflow() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
}
// Two chained handlers: Start -> Process -> Complete.
#[tokio::test]
async fn test_multi_step_workflow() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Process))
.register(TestState::Process, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
}
// A task can write into a store supplied through the workflow's resources;
// the value is readable from the store after the run.
#[tokio::test]
async fn test_workflow_with_data() {
let store = crate::store::MemoryStore::new();
let resources = Resources::new().insert("store", store.clone());
let workflow = Workflow::new(resources)
.register(
TestState::Start,
DataTask::new("test_key", "test_value", TestState::Complete),
)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
let data: String = store.get("test_key").unwrap();
assert_eq!(data, "test_value");
}
// Orchestrating a workflow without handlers fails validation with a
// configuration error.
#[tokio::test]
async fn test_unregistered_state_error() {
let workflow = Workflow::<TestState>::bare().add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await;
let err = result.unwrap_err();
assert_eq!(err.category(), "configuration");
assert!(err.to_string().contains("no registered state handlers"));
}
// `validate` rejects a workflow with no handlers.
#[test]
fn test_validate_empty_workflow() {
let workflow = Workflow::<TestState>::bare();
let result = workflow.validate();
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("no registered state handlers")
);
}
// `validate` rejects a workflow with handlers but no exit state.
#[test]
fn test_validate_no_exit_states() {
let workflow =
Workflow::bare().register(TestState::Start, SimpleTask::new(TestState::Complete));
let result = workflow.validate();
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("no exit states defined")
);
}
// One handler plus one exit state is a valid workflow.
#[test]
fn test_validate_valid_workflow() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
assert!(workflow.validate().is_ok());
}
// A split whose join_state (`Process`) is neither registered nor an exit
// state fails validation.
#[test]
fn test_validate_split_join_state_unregistered() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Join)],
JoinConfig::new(JoinStrategy::All, TestState::Process), )
.add_exit_state(TestState::Complete);
let result = workflow.validate();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("join_state"));
}
// A split may join directly into an exit state.
#[test]
fn test_validate_split_join_state_as_exit_state() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::All, TestState::Complete),
)
.add_exit_state(TestState::Complete);
assert!(workflow.validate().is_ok());
}
// `PartialTimeout` without a configured timeout is a configuration error.
#[test]
fn test_validate_rejects_partial_timeout_without_timeout() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::PartialTimeout, TestState::Complete),
)
.add_exit_state(TestState::Complete);
let err = workflow
.validate()
.expect_err("PartialTimeout without timeout must fail validation");
assert!(matches!(err, CanoError::Configuration(_)), "got {err:?}");
assert!(err.to_string().contains("requires a timeout"));
}
// `Percentage` values outside (0.0, 1.0], including NaN, are rejected.
#[test]
fn test_validate_rejects_invalid_percentage() {
for value in [0.0, 1.5, f64::NAN] {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::Percentage(value), TestState::Complete),
)
.add_exit_state(TestState::Complete);
let err = workflow
.validate()
.expect_err("invalid Percentage strategy must fail validation");
assert!(matches!(err, CanoError::Configuration(_)), "got {err:?}");
assert!(err.to_string().contains("Percentage strategy"));
}
}
// A bulkhead with zero permits is rejected.
#[test]
fn test_validate_rejects_zero_bulkhead() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::All, TestState::Complete).with_bulkhead(0),
)
.add_exit_state(TestState::Complete);
let err = workflow
.validate()
.expect_err("bulkhead=0 must fail validation");
assert!(matches!(err, CanoError::Configuration(_)), "got {err:?}");
assert!(err.to_string().contains("bulkhead"));
}
// Registered states and exit states are valid initial states; anything
// else is rejected.
#[test]
fn test_validate_initial_state() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
assert!(workflow.validate_initial_state(&TestState::Start).is_ok());
assert!(
workflow
.validate_initial_state(&TestState::Complete)
.is_ok()
);
let result = workflow.validate_initial_state(&TestState::Process);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("neither registered nor an exit state")
);
}
// Task implemented through `run_bare` only.
struct BareWorkflowTask;
#[task_macro]
impl Task<TestState> for BareWorkflowTask {
async fn run_bare(&self) -> Result<TaskResult<TestState>, CanoError> {
Ok(TaskResult::Single(TestState::Complete))
}
}
// A `run_bare`-only task works in a workflow built with `Workflow::bare()`.
#[tokio::test]
async fn test_workflow_bare_runs_task_with_run_bare() {
let result = Workflow::bare()
.register(TestState::Start, BareWorkflowTask)
.add_exit_state(TestState::Complete)
.orchestrate(TestState::Start)
.await
.unwrap();
assert_eq!(result, TestState::Complete);
}
// A router state hands over to a regular task state, which reaches the exit.
#[tokio::test]
async fn test_register_router_orchestration_round_trip() {
use crate::task::{RouterTask, TaskConfig};
struct RouteToProcess;
#[task::router]
impl RouterTask<TestState> for RouteToProcess {
fn config(&self) -> TaskConfig {
TaskConfig::minimal()
}
async fn route(&self, _res: &Resources) -> Result<TaskResult<TestState>, CanoError> {
Ok(TaskResult::Single(TestState::Process))
}
}
let result = Workflow::bare()
.register_router(TestState::Start, RouteToProcess)
.register(TestState::Process, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete)
.orchestrate(TestState::Start)
.await
.unwrap();
assert_eq!(result, TestState::Complete);
}
// A workflow whose only handler is a router passes validation.
#[test]
fn test_validate_passes_with_router_state() {
use crate::task::{RouterTask, TaskConfig};
struct RouteToComplete;
#[task::router]
impl RouterTask<TestState> for RouteToComplete {
fn config(&self) -> TaskConfig {
TaskConfig::minimal()
}
async fn route(&self, _res: &Resources) -> Result<TaskResult<TestState>, CanoError> {
Ok(TaskResult::Single(TestState::Complete))
}
}
let workflow = Workflow::bare()
.register_router(TestState::Start, RouteToComplete)
.add_exit_state(TestState::Complete);
assert!(workflow.validate().is_ok());
}
// Total and compensation timeouts are stored by the builders and survive
// a clone.
#[test]
fn test_with_total_timeout_stores_value_and_clones() {
use std::time::Duration;
let wf = Workflow::<TestState>::bare()
.with_total_timeout(Duration::from_secs(5))
.with_compensation_timeout(Duration::from_millis(750))
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
assert_eq!(wf.total_timeout, Some(Duration::from_secs(5)));
assert_eq!(wf.compensation_timeout, Some(Duration::from_millis(750)));
let cloned = wf.clone();
assert_eq!(cloned.total_timeout, Some(Duration::from_secs(5)));
assert_eq!(
cloned.compensation_timeout,
Some(Duration::from_millis(750))
);
}
// Both timeouts default to `None`.
#[test]
fn test_with_total_timeout_defaults_to_none() {
let wf = Workflow::<TestState>::bare();
assert_eq!(wf.total_timeout, None);
assert_eq!(wf.compensation_timeout, None);
}
// The Debug output names both timeout fields.
#[test]
fn test_workflow_debug_includes_total_and_compensation_timeouts() {
use std::time::Duration;
let wf = Workflow::<TestState>::bare()
.with_total_timeout(Duration::from_secs(5))
.with_compensation_timeout(Duration::from_millis(750));
let debug_str = format!("{wf:?}");
assert!(debug_str.contains("total_timeout"), "got: {debug_str}");
assert!(
debug_str.contains("compensation_timeout"),
"got: {debug_str}"
);
}
// A panic inside an observer callback must not fail the run.
#[tokio::test]
async fn orchestrate_survives_observer_panic_via_notify_observers_catch_unwind() {
use crate::observer::WorkflowObserver;
struct PanickyObserver;
impl WorkflowObserver for PanickyObserver {
fn on_state_enter(&self, _state: &str) {
panic!("observer panic — must not abort the workflow");
}
}
let result = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete)
.with_observer(Arc::new(PanickyObserver))
.orchestrate(TestState::Start)
.await
.expect("orchestrate must complete despite observer panic");
assert_eq!(result, TestState::Complete);
}
// An observer that panics on every callback is still invoked each time
// (three times for this run) and the workflow still completes.
#[tokio::test]
async fn repeated_observer_panics_do_not_block_workflow_progress() {
use crate::observer::WorkflowObserver;
use std::sync::atomic::{AtomicUsize, Ordering};
struct CountThenPanic(Arc<AtomicUsize>);
impl WorkflowObserver for CountThenPanic {
fn on_state_enter(&self, _state: &str) {
self.0.fetch_add(1, Ordering::SeqCst);
panic!("repeated panic — must not block the FSM loop");
}
}
let count = Arc::new(AtomicUsize::new(0));
let result = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Process))
.register(TestState::Process, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete)
.with_observer(Arc::new(CountThenPanic(Arc::clone(&count))))
.orchestrate(TestState::Start)
.await
.expect("orchestrate must complete despite repeated observer panics");
assert_eq!(result, TestState::Complete);
assert_eq!(count.load(Ordering::SeqCst), 3);
}
// Starting in an exit state returns it immediately without running any
// handler.
#[tokio::test]
async fn orchestrate_with_exit_state_as_initial_returns_immediately() {
let start = SimpleTask::new(TestState::Complete);
let wf = Workflow::bare()
.register(TestState::Start, start.clone())
.add_exit_state(TestState::Complete);
let result = wf.orchestrate(TestState::Complete).await.unwrap();
assert_eq!(result, TestState::Complete);
assert_eq!(
start.count(),
0,
"no handler runs when the initial state is an exit state"
);
}
// A state that is both registered and an exit state ends the run; its
// handler is never executed.
#[tokio::test]
async fn exit_state_takes_precedence_over_registered_handler() {
let process = SimpleTask::new(TestState::Start); let wf = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Process))
.register(TestState::Process, process.clone())
.add_exit_state(TestState::Process);
let result = wf.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Process);
assert_eq!(
process.count(),
0,
"an exit state short-circuits before its handler runs"
);
}
// Routing to a state with no handler that is not an exit state is only
// detected at run time.
#[tokio::test]
async fn transition_to_unregistered_state_errors_at_runtime() {
let wf = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Process))
.add_exit_state(TestState::Complete);
let err = wf.orchestrate(TestState::Start).await.unwrap_err();
assert!(err.to_string().contains("No task registered"), "got: {err}");
}
// An unknown initial state is rejected with a configuration error.
#[tokio::test]
async fn orchestrate_from_unknown_initial_state_errors() {
let wf = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
let err = wf.orchestrate(TestState::Process).await.unwrap_err();
assert_eq!(err.category(), "configuration");
assert!(
err.to_string()
.contains("neither registered nor an exit state"),
"got: {err}"
);
}
// Task that returns a split result even though it is registered as a
// single task.
struct ReturnsSplit;
#[task_macro]
impl Task<TestState> for ReturnsSplit {
async fn run_bare(&self) -> Result<TaskResult<TestState>, CanoError> {
Ok(TaskResult::Split(vec![TestState::Complete]))
}
}
// A single-task handler returning `TaskResult::Split` is an error that
// points the user at `register_split`.
#[tokio::test]
async fn single_task_returning_split_errors() {
let wf = Workflow::bare()
.register(TestState::Start, ReturnsSplit)
.add_exit_state(TestState::Complete);
let err = wf.orchestrate(TestState::Start).await.unwrap_err();
assert!(err.to_string().contains("use register_split"), "got: {err}");
}
// Registering a second handler for the same state replaces the first.
#[tokio::test]
async fn register_replaces_a_prior_handler_for_the_same_state() {
let first = SimpleTask::new(TestState::Process); let wf = Workflow::bare()
.register(TestState::Start, first.clone())
.register(TestState::Start, SimpleTask::new(TestState::Complete)) .add_exit_state(TestState::Complete);
let result = wf.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete); assert_eq!(first.count(), 0, "the replaced handler must not run");
}
// Task that routes back to `Start` until it has run `limit` times, then
// routes to `Complete`.
struct LoopUntil {
limit: u32,
count: Arc<std::sync::atomic::AtomicU32>,
}
#[task_macro]
impl Task<TestState> for LoopUntil {
async fn run_bare(&self) -> Result<TaskResult<TestState>, CanoError> {
use std::sync::atomic::Ordering;
let n = self.count.fetch_add(1, Ordering::SeqCst) + 1;
if n >= self.limit {
Ok(TaskResult::Single(TestState::Complete))
} else {
Ok(TaskResult::Single(TestState::Start)) }
}
}
// A state may transition to itself; the handler runs once per visit
// (five times here).
#[tokio::test]
async fn self_looping_state_runs_until_it_routes_to_exit() {
let count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let wf = Workflow::bare()
.register(
TestState::Start,
LoopUntil {
limit: 5,
count: Arc::clone(&count),
},
)
.add_exit_state(TestState::Complete);
let result = wf.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 5);
}
// A split with no tasks must proceed to its join state rather than hang;
// the 5 s timeout turns a hang into a test failure.
#[tokio::test]
async fn empty_split_joins_to_its_join_state() {
let wf = Workflow::bare()
.register_split(
TestState::Start,
Vec::<SimpleTask>::new(),
JoinConfig::new(JoinStrategy::All, TestState::Complete),
)
.add_exit_state(TestState::Complete);
let result = tokio::time::timeout(
std::time::Duration::from_secs(5),
wf.orchestrate(TestState::Start),
)
.await
.expect("orchestrate of an empty split must not hang");
assert_eq!(result.unwrap(), TestState::Complete);
}
// Exit states are deduplicated across both builder methods.
#[test]
fn add_exit_state_and_add_exit_states_deduplicate() {
let wf = Workflow::<TestState>::bare()
.add_exit_state(TestState::Complete)
.add_exit_state(TestState::Complete) .add_exit_states(vec![
TestState::Complete, TestState::Error,
TestState::Error, ]);
assert_eq!(wf.exit_states.len(), 2);
assert!(wf.exit_states.contains(&TestState::Complete));
assert!(wf.exit_states.contains(&TestState::Error));
}
// Workflow id defaults to `None` and version to 0; both are stored by the
// builders and survive a clone.
#[test]
fn workflow_id_and_version_are_stored_defaulted_and_cloned() {
let wf = Workflow::<TestState>::bare();
assert_eq!(wf.workflow_version, 0); assert!(wf.workflow_id.is_none());
let wf = wf.with_workflow_id("run-42").with_workflow_version(7);
assert_eq!(wf.workflow_version, 7);
assert_eq!(wf.workflow_id.as_deref(), Some("run-42"));
let cloned = wf.clone();
assert_eq!(cloned.workflow_version, 7);
assert_eq!(cloned.workflow_id.as_deref(), Some("run-42"));
}
}
/// Tests for `Workflow::await_with_outer_timeout`: the legacy `with_timeout`
/// wrapper is expected to apply only when no total budget is passed in.
#[cfg(test)]
mod await_with_outer_timeout_tests {
    use super::test_support::TestState;
    use super::*;
    use std::time::{Duration, Instant};

    /// Builds a bare workflow, applying `with_timeout` and/or
    /// `with_total_timeout` only for the limits that are `Some`.
    fn workflow_with(
        workflow_timeout: Option<Duration>,
        total_timeout: Option<Duration>,
    ) -> Workflow<TestState> {
        let base = Workflow::<TestState>::bare();
        let base = match workflow_timeout {
            Some(limit) => base.with_timeout(limit),
            None => base,
        };
        match total_timeout {
            Some(limit) => base.with_total_timeout(limit),
            None => base,
        }
    }

    /// Resolves immediately to `Ok(TestState::Complete)`.
    async fn complete_now() -> Result<TestState, CanoError> {
        Ok(TestState::Complete)
    }

    /// Sleeps for `delay`, then resolves to `Ok(TestState::Complete)`.
    async fn complete_after(delay: Duration) -> Result<TestState, CanoError> {
        tokio::time::sleep(delay).await;
        Ok(TestState::Complete)
    }

    /// With no timeout configured at all, the future's result is returned as-is.
    #[tokio::test]
    async fn neither_timeout_just_awaits_future() {
        let workflow = workflow_with(None, None);
        let started = Instant::now();

        let outcome = workflow
            .await_with_outer_timeout(complete_now(), None, started)
            .await;

        assert_eq!(outcome.unwrap(), TestState::Complete);
    }

    /// A generous legacy timeout does not interfere with a fast future.
    #[tokio::test]
    async fn only_with_timeout_passes_through_when_future_is_fast() {
        let workflow = workflow_with(Some(Duration::from_secs(60)), None);
        let started = Instant::now();

        let outcome = workflow
            .await_with_outer_timeout(complete_now(), None, started)
            .await;

        assert_eq!(outcome.unwrap(), TestState::Complete);
    }

    /// A short legacy timeout cuts off a slow future with the legacy error shape.
    #[tokio::test]
    async fn only_with_timeout_fires_legacy_timeout_on_slow_future() {
        let workflow = workflow_with(Some(Duration::from_millis(10)), None);
        let started = Instant::now();

        let err = workflow
            .await_with_outer_timeout(complete_after(Duration::from_secs(1)), None, started)
            .await
            .expect_err("legacy timeout must fire");

        let has_legacy_shape = match &err {
            CanoError::Workflow(message) => message.contains("Workflow timeout exceeded"),
            _ => false,
        };
        assert!(has_legacy_shape, "expected legacy shape, got: {err}");

        let waited = started.elapsed();
        assert!(
            waited < Duration::from_millis(500),
            "must bound to the legacy timeout, not the inner sleep"
        );
    }

    /// When only a total budget is supplied, a fast future still passes through.
    #[tokio::test]
    async fn only_total_budget_skips_legacy_path() {
        let workflow = workflow_with(None, Some(Duration::from_millis(10)));
        let total_budget = Some((Instant::now(), Duration::from_millis(10)));
        let started = Instant::now();

        let outcome = workflow
            .await_with_outer_timeout(complete_now(), total_budget, started)
            .await;

        assert_eq!(outcome.unwrap(), TestState::Complete);
    }

    /// With a total budget present, the 5 ms legacy timeout must not cut off a
    /// future that needs 20 ms.
    #[tokio::test]
    async fn both_timeouts_set_skips_legacy_path() {
        let legacy_limit = Some(Duration::from_millis(5));
        let total_limit = Some(Duration::from_secs(60));
        let workflow = workflow_with(legacy_limit, total_limit);
        let total_budget = Some((Instant::now(), Duration::from_secs(60)));
        let started = Instant::now();

        let outcome = workflow
            .await_with_outer_timeout(
                complete_after(Duration::from_millis(20)),
                total_budget,
                started,
            )
            .await;

        assert_eq!(outcome.unwrap(), TestState::Complete);
        let waited = started.elapsed();
        assert!(
            waited >= Duration::from_millis(20),
            "future must run to completion; legacy wrapper must NOT be applied"
        );
    }

    /// An error produced by the inner future comes back unchanged on the legacy path.
    #[tokio::test]
    async fn legacy_path_propagates_inner_errors_unchanged() {
        let workflow = workflow_with(Some(Duration::from_secs(60)), None);
        let started = Instant::now();

        let err = workflow
            .await_with_outer_timeout(
                async { Err::<TestState, _>(CanoError::task_execution("inner boom")) },
                None,
                started,
            )
            .await
            .expect_err("inner err must propagate");

        let is_verbatim = match &err {
            CanoError::TaskExecution(message) => message == "inner boom",
            _ => false,
        };
        assert!(is_verbatim, "must propagate verbatim, got: {err}");
    }
}
/// Tests for `Workflow::resolve_total_budget`: which budget, if any, is produced
/// for each combination of `with_timeout` and `with_total_timeout`.
#[cfg(test)]
mod resolve_total_budget_tests {
    use super::test_support::TestState;
    use super::*;
    use std::time::{Duration, Instant};

    /// No timeout configured: no budget.
    #[test]
    fn neither_set_returns_none() {
        let workflow = Workflow::<TestState>::bare();
        let budget = workflow.resolve_total_budget(Instant::now());
        assert!(budget.is_none());
    }

    /// `with_timeout` on its own does not produce a budget.
    #[test]
    fn only_with_timeout_set_returns_none() {
        let workflow = Workflow::<TestState>::bare().with_timeout(Duration::from_secs(1));
        let budget = workflow.resolve_total_budget(Instant::now());
        assert!(
            budget.is_none(),
            "with_timeout alone goes through the legacy wrapper; FSM gets no budget"
        );
    }

    /// `with_total_timeout` on its own yields the given start instant and limit.
    #[test]
    fn only_total_timeout_set_returns_total() {
        let workflow = Workflow::<TestState>::bare().with_total_timeout(Duration::from_secs(7));
        let now = Instant::now();

        let (budget_start, budget_limit) = workflow.resolve_total_budget(now).unwrap();

        assert_eq!(budget_start, now);
        assert_eq!(budget_limit, Duration::from_secs(7));
    }

    /// Both set, `with_timeout` smaller: the smaller `with_timeout` value wins.
    #[test]
    fn both_set_returns_min_via_with_timeout_as_floor() {
        let workflow = Workflow::<TestState>::bare()
            .with_timeout(Duration::from_millis(50))
            .with_total_timeout(Duration::from_secs(60));
        let now = Instant::now();

        let (_, budget_limit) = workflow.resolve_total_budget(now).unwrap();

        assert_eq!(budget_limit, Duration::from_millis(50));
    }

    /// Both set, total smaller: the smaller total value wins.
    #[test]
    fn both_set_total_smaller_returns_total() {
        let workflow = Workflow::<TestState>::bare()
            .with_timeout(Duration::from_secs(60))
            .with_total_timeout(Duration::from_millis(50));

        let (_, budget_limit) = workflow.resolve_total_budget(Instant::now()).unwrap();

        assert_eq!(budget_limit, Duration::from_millis(50));
    }
}
/// Tests for `catch_panic_to_error`: non-panicking results pass through
/// untouched, and panics surface as `CanoError::TaskExecution` with a
/// `panic: `-prefixed message.
#[cfg(test)]
mod catch_panic_to_error_tests {
    use super::*;

    /// An `Ok` value from the wrapped future is returned as-is.
    #[tokio::test]
    async fn passes_through_ok_result_unchanged() {
        let outcome =
            catch_panic_to_error(async { Ok::<u32, CanoError>(42) }, "non-panicking").await;
        let value = outcome.unwrap();
        assert_eq!(value, 42);
    }

    /// An `Err` value from the wrapped future is returned as-is.
    #[tokio::test]
    async fn passes_through_err_result_unchanged() {
        let outcome = catch_panic_to_error::<u32, _>(
            async { Err(CanoError::task_execution("explicit failure")) },
            "non-panicking",
        )
        .await;

        let message = match outcome {
            Err(CanoError::TaskExecution(message)) => message,
            other => panic!("expected explicit TaskExecution err, got: {other:?}"),
        };
        assert_eq!(message, "explicit failure");
    }

    /// A panic carrying a `&'static str` payload keeps its text.
    #[tokio::test]
    async fn str_literal_panic_payload_yields_panic_prefixed_task_execution_err() {
        let outcome = catch_panic_to_error::<u32, _>(
            async {
                panic!("static literal");
            },
            "labelled",
        )
        .await;

        let message = match outcome {
            Err(CanoError::TaskExecution(message)) => message,
            other => panic!("expected TaskExecution(\"panic: ...\"), got: {other:?}"),
        };
        assert_eq!(message, "panic: static literal");
    }

    /// A panic carrying a formatted `String` payload keeps its text.
    #[tokio::test]
    async fn formatted_string_panic_payload_preserves_message() {
        let outcome = catch_panic_to_error::<u32, _>(
            async {
                let detail = 99;
                panic!("formatted message {detail}");
            },
            "labelled",
        )
        .await;

        let message = match outcome {
            Err(CanoError::TaskExecution(message)) => message,
            other => panic!("expected formatted TaskExecution, got: {other:?}"),
        };
        assert_eq!(message, "panic: formatted message 99");
    }

    /// A panic whose payload is not a string falls back to a fixed placeholder.
    #[tokio::test]
    async fn non_string_panic_payload_yields_fallback_message() {
        let outcome = catch_panic_to_error::<u32, _>(
            async {
                std::panic::panic_any(42i32);
            },
            "labelled",
        )
        .await;

        let message = match outcome {
            Err(CanoError::TaskExecution(message)) => message,
            other => panic!("expected fallback TaskExecution, got: {other:?}"),
        };
        assert_eq!(message, "panic: <non-string panic payload>");
    }
}