use std::borrow::Cow;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use crate::error::CanoError;
use crate::observer::WorkflowObserver;
use crate::recovery::CheckpointStore;
use crate::resource::Resources;
use crate::saga::{CompensatableTask, ErasedCompensatable};
use crate::task::stepped::{ErasedSteppedTask, SteppedAdapter, SteppedTask};
use crate::task::{RouterTask, Task};
#[cfg(feature = "tracing")]
use tracing::{Span, info_span};
mod compensation;
mod execution;
mod join;
#[cfg(test)]
mod test_support;
pub use execution::StateEntry;
pub use join::{JoinConfig, JoinStrategy, SplitResult, SplitTaskResult};
/// Calls `f` once for every observer in `observers`, in slice order.
///
/// An empty slice results in no calls.
#[inline]
fn notify_observers(observers: &[Arc<dyn WorkflowObserver>], f: impl Fn(&dyn WorkflowObserver)) {
    observers
        .iter()
        .for_each(|registered| f(registered.as_ref()));
}
/// Extracts a human-readable message from a panic payload.
///
/// Payloads of type `&'static str` or `String` are returned as an owned
/// `String`; any other payload type yields a fixed placeholder text.
fn panic_payload_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "<non-string panic payload>".to_string())
}
/// A state-machine workflow: a table of per-state handlers plus the settings
/// used when the workflow is run through `orchestrate`.
///
/// Instances are assembled with the consuming builder methods (`register*`,
/// `add_exit_state*`, `with_*`), which is why the type is `#[must_use]`.
///
/// `TResourceKey` is the key type of the attached [`Resources`] and defaults to
/// `Cow<'static, str>`.
#[must_use]
pub struct Workflow<TState, TResourceKey = Cow<'static, str>>
where
TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
TResourceKey: Hash + Eq + Send + Sync + 'static,
{
/// Handler entry registered for each state.
states: HashMap<TState, Arc<StateEntry<TState, TResourceKey>>>,
/// Shared resources; `orchestrate` calls `setup_all` on them before the run
/// and `teardown_range` afterwards.
pub(crate) resources: Arc<Resources<TResourceKey>>,
/// Optional limit applied to the whole run in `run_workflow`.
workflow_timeout: Option<Duration>,
/// States listed as exit states; kept free of duplicates by the builders.
exit_states: Vec<TState>,
/// Result of `validate()`, filled on the first `orchestrate` call.
/// A clone starts with an empty cache.
validated: OnceLock<Result<(), CanoError>>,
/// Observers added through `with_observer`, in insertion order.
observers: Vec<Arc<dyn WorkflowObserver>>,
/// Compensatable tasks keyed by their task name (see
/// `register_with_compensation`).
compensators: HashMap<String, Arc<dyn ErasedCompensatable<TState, TResourceKey>>>,
/// Optional checkpoint store; it is only stored here, its use is outside
/// this file.
checkpoint_store: Option<Arc<dyn CheckpointStore>>,
/// Optional identifier; recorded as the `workflow_id` field of the default
/// tracing span created by `orchestrate`.
workflow_id: Option<String>,
/// Caller-supplied version number, `0` unless set with
/// `with_workflow_version`.
workflow_version: u32,
/// Caller-supplied span used by `orchestrate` instead of the default one.
#[cfg(feature = "tracing")]
tracing_span: Option<Span>,
}
impl<TState, TResourceKey> Workflow<TState, TResourceKey>
where
TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
TResourceKey: Hash + Eq + Send + Sync + 'static,
{
/// Creates an empty workflow that owns `resources`.
///
/// The new workflow has no registered states, no exit states, no observers,
/// no timeout, no checkpoint store, no workflow id and version `0`.
pub fn new(resources: Resources<TResourceKey>) -> Self {
    let shared_resources = Arc::new(resources);
    Self {
        resources: shared_resources,
        states: HashMap::new(),
        exit_states: Vec::new(),
        observers: Vec::new(),
        compensators: HashMap::new(),
        validated: OnceLock::new(),
        workflow_timeout: None,
        checkpoint_store: None,
        workflow_id: None,
        workflow_version: 0,
        #[cfg(feature = "tracing")]
        tracing_span: None,
    }
}
/// Sets the time limit applied to a whole workflow run.
///
/// When the limit elapses, the run fails with a "Workflow timeout exceeded"
/// workflow error (see `run_workflow`).
pub fn with_timeout(self, timeout: Duration) -> Self {
    let mut workflow = self;
    workflow.workflow_timeout = Some(timeout);
    workflow
}
pub fn register<T>(mut self, state: TState, task: T) -> Self
where
T: Task<TState, TResourceKey> + Send + Sync + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(task.config());
self.states.insert(
state,
Arc::new(StateEntry::Single {
task: Arc::new(task),
config,
}),
);
self
}
pub fn register_router<T>(mut self, state: TState, task: T) -> Self
where
T: RouterTask<TState, TResourceKey> + Task<TState, TResourceKey> + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(Task::config(&task));
self.states.insert(
state,
Arc::new(StateEntry::Router {
task: Arc::new(task),
config,
}),
);
self
}
pub fn register_split<T>(
mut self,
state: TState,
tasks: Vec<T>,
join_config: JoinConfig<TState>,
) -> Self
where
T: Task<TState, TResourceKey> + Send + Sync + 'static,
{
self.forget_compensator_for(&state);
let configs: Vec<Arc<crate::task::TaskConfig>> =
tasks.iter().map(|t| Arc::new(t.config())).collect();
let arc_tasks: Vec<Arc<dyn Task<TState, TResourceKey> + Send + Sync>> =
tasks.into_iter().map(|t| Arc::new(t) as Arc<_>).collect();
self.states.insert(
state,
Arc::new(StateEntry::Split {
tasks: arc_tasks,
configs: Arc::new(configs),
join_config: Arc::new(join_config),
}),
);
self
}
fn forget_compensator_for(&mut self, state: &TState) {
let stale_name = if let Some(StateEntry::CompensatableSingle { task, .. }) =
self.states.get(state).map(|e| e.as_ref())
{
Some(task.name().into_owned())
} else {
None
};
if let Some(name) = stale_name {
self.compensators.remove(&name);
}
}
pub fn register_with_compensation<T>(mut self, state: TState, task: T) -> Self
where
T: CompensatableTask<TState, TResourceKey> + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(task.config());
let name = task.name().into_owned();
let erased: Arc<dyn ErasedCompensatable<TState, TResourceKey>> =
Arc::new(crate::saga::CompensatableAdapter(Arc::new(task)));
self.compensators.insert(name, Arc::clone(&erased));
self.states.insert(
state,
Arc::new(StateEntry::CompensatableSingle {
task: erased,
config,
}),
);
self
}
pub fn register_stepped<T>(mut self, state: TState, task: T) -> Self
where
T: SteppedTask<TState, TResourceKey> + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(SteppedTask::config(&task));
let erased: Arc<dyn ErasedSteppedTask<TState, TResourceKey>> =
Arc::new(SteppedAdapter(Arc::new(task)));
self.states.insert(
state,
Arc::new(StateEntry::Stepped {
task: erased,
config,
}),
);
self
}
pub fn add_exit_state(mut self, state: TState) -> Self {
if !self.exit_states.contains(&state) {
self.exit_states.push(state);
}
self
}
pub fn add_exit_states(mut self, states: Vec<TState>) -> Self {
for state in states {
if !self.exit_states.contains(&state) {
self.exit_states.push(state);
}
}
self
}
/// Appends `observer` to the workflow's observer list.
///
/// Observers are kept in the order they were added; no de-duplication is
/// performed.
pub fn with_observer(self, observer: Arc<dyn WorkflowObserver>) -> Self {
    let mut workflow = self;
    workflow.observers.push(observer);
    workflow
}
/// Attaches `checkpoint_store` to the workflow, replacing any store set
/// earlier.
pub fn with_checkpoint_store(self, checkpoint_store: Arc<dyn CheckpointStore>) -> Self {
    let mut workflow = self;
    workflow.checkpoint_store = Some(checkpoint_store);
    workflow
}
/// Sets the workflow identifier, replacing any identifier set earlier.
///
/// The identifier is recorded as the `workflow_id` field of the default
/// tracing span created by `orchestrate`.
pub fn with_workflow_id(mut self, workflow_id: impl Into<String>) -> Self {
    let id: String = workflow_id.into();
    self.workflow_id = Some(id);
    self
}
/// Sets the workflow version number (the default is `0`).
pub fn with_workflow_version(self, version: u32) -> Self {
    let mut workflow = self;
    workflow.workflow_version = version;
    workflow
}
/// Supplies the tracing span that `orchestrate` uses for the run instead of
/// creating its default `workflow_orchestrate` span.
#[cfg(feature = "tracing")]
pub fn with_tracing_span(self, span: Span) -> Self {
    let mut workflow = self;
    workflow.tracing_span = Some(span);
    workflow
}
/// Returns the registered observers as a shared slice, or `None` when no
/// observer has been registered.
///
/// Each call builds a new `Arc<[_]>` by cloning the observer handles.
fn observer_slice(&self) -> Option<Arc<[Arc<dyn WorkflowObserver>]>> {
    let registered = self.observers.as_slice();
    match registered {
        [] => None,
        _ => Some(Arc::from(registered)),
    }
}
/// Produces the task configuration to use for a task named `task_name`.
///
/// Without observers the `base` configuration is shared as-is. With
/// observers, `base` is copied and the copy receives the observer slice
/// and `task_name`; `base` itself is never modified.
fn config_with_observers(
    base: &Arc<crate::task::TaskConfig>,
    observers: &Option<Arc<[Arc<dyn WorkflowObserver>]>>,
    task_name: &str,
) -> Arc<crate::task::TaskConfig> {
    let Some(shared_observers) = observers else {
        return Arc::clone(base);
    };
    let mut enriched = (**base).clone();
    enriched.observers = Some(Arc::clone(shared_observers));
    enriched.task_name = Some(Cow::Owned(String::from(task_name)));
    Arc::new(enriched)
}
/// Checks a split's join configuration for settings that can be rejected
/// without running anything.
///
/// `_total_tasks` is accepted but not consulted by any check.
///
/// # Errors
///
/// Returns a configuration error when:
/// - the strategy is `PartialTimeout` and no timeout is set;
/// - the strategy is `Percentage(p)` and `p` is not a finite value in
///   `(0.0, 1.0]`;
/// - `bulkhead` is `Some(0)`.
fn validate_join_config(
    join_config: &JoinConfig<TState>,
    _total_tasks: usize,
) -> Result<(), CanoError> {
    match join_config.strategy {
        JoinStrategy::PartialTimeout if join_config.timeout.is_none() => {
            return Err(CanoError::configuration(
                "PartialTimeout strategy requires a timeout to be configured",
            ));
        }
        // Written as a negated range test so that NaN is rejected too.
        JoinStrategy::Percentage(p) if !(p.is_finite() && p > 0.0 && p <= 1.0) => {
            return Err(CanoError::configuration(format!(
                "Percentage strategy requires a finite value in (0.0, 1.0], got {p}"
            )));
        }
        _ => {}
    }
    if matches!(join_config.bulkhead, Some(0)) {
        return Err(CanoError::configuration(
            "bulkhead requires a positive permit count, got 0",
        ));
    }
    Ok(())
}
/// Checks the workflow's static configuration.
///
/// # Errors
///
/// Returns a configuration error when:
/// - no state handler is registered;
/// - no exit state is defined;
/// - a split's join configuration is rejected by `validate_join_config`;
/// - a split's `join_state` is neither a registered state nor an exit
///   state.
///
/// Splits are visited in the map's iteration order, so when several splits
/// are invalid, which error is reported is unspecified.
pub fn validate(&self) -> Result<(), CanoError> {
    if self.states.is_empty() {
        return Err(CanoError::configuration(
            "Workflow has no registered state handlers",
        ));
    }
    if self.exit_states.is_empty() {
        return Err(CanoError::configuration(
            "Workflow has no exit states defined — orchestration may loop forever",
        ));
    }
    self.states
        .values()
        .try_for_each(|entry| match entry.as_ref() {
            StateEntry::Split {
                tasks, join_config, ..
            } => {
                Self::validate_join_config(join_config, tasks.len())?;
                let target = &join_config.join_state;
                let reachable =
                    self.states.contains_key(target) || self.exit_states.contains(target);
                if reachable {
                    Ok(())
                } else {
                    Err(CanoError::configuration(format!(
                        "Split join_state {:?} is neither registered nor an exit state",
                        target
                    )))
                }
            }
            _ => Ok(()),
        })
}
/// Checks that `state` can be used as the starting state of a run.
///
/// # Errors
///
/// Returns a configuration error when `state` is neither a registered
/// state nor an exit state.
pub fn validate_initial_state(&self, state: &TState) -> Result<(), CanoError> {
    let known = self.states.contains_key(state) || self.exit_states.contains(state);
    if known {
        Ok(())
    } else {
        Err(CanoError::configuration(format!(
            "Initial state {:?} is neither registered nor an exit state",
            state
        )))
    }
}
/// Runs the workflow starting at `initial_state` and returns the state the
/// run ended in.
///
/// Steps, in order:
/// 1. validate the configuration (the result is cached after the first
///    call);
/// 2. check that `initial_state` is registered or is an exit state;
/// 3. set up all resources;
/// 4. run the workflow (subject to the optional workflow timeout);
/// 5. tear the resources down, whether or not the run succeeded.
///
/// With the `tracing` feature, the whole sequence runs inside the span set
/// via `with_tracing_span`, or a default `workflow_orchestrate` span.
///
/// # Errors
///
/// Returns the validation error, the initial-state error, the resource
/// setup error, or the error produced by the run itself.
pub async fn orchestrate(&self, initial_state: TState) -> Result<TState, CanoError> {
    let run = async move {
        let cached_validation = self.validated.get_or_init(|| self.validate());
        if let Err(e) = cached_validation {
            return Err(e.clone());
        }
        self.validate_initial_state(&initial_state)?;
        self.resources.setup_all().await?;
        let result = self.run_workflow(initial_state).await;
        // Teardown happens on both success and failure of the run.
        self.resources
            .teardown_range(0..self.resources.lifecycle_len())
            .await;
        result
    };

    // Attach the span to the future rather than holding a `Span::enter`
    // guard across `.await` points: such a guard yields incorrect traces
    // in async code and makes the returned future `!Send`.
    #[cfg(feature = "tracing")]
    let run = {
        let workflow_span = self.tracing_span.clone().unwrap_or_else(|| {
            if tracing::enabled!(tracing::Level::INFO) {
                info_span!(
                    "workflow_orchestrate",
                    workflow_id = self.workflow_id.as_deref()
                )
            } else {
                tracing::Span::none()
            }
        });
        tracing::Instrument::instrument(run, workflow_span)
    };

    run.await
}
/// Executes the workflow from `initial_state`, applying the optional
/// workflow timeout.
///
/// With the `metrics` feature, an active-workflow guard is held for the
/// duration of the call and one run outcome (`"completed"`, `"failed"` or
/// `"timeout"`) is reported together with the elapsed time.
///
/// # Errors
///
/// Returns a "Workflow timeout exceeded" workflow error when the timeout
/// elapses, otherwise whatever error the execution itself produced.
async fn run_workflow(&self, initial_state: TState) -> Result<TState, CanoError> {
    #[cfg(feature = "metrics")]
    let _active = crate::metrics::WorkflowActiveGuard::new();
    #[cfg(feature = "metrics")]
    let _started = std::time::Instant::now();

    let execution = self.execute_workflow(initial_state);
    let outcome = match self.workflow_timeout {
        None => execution.await,
        Some(limit) => match tokio::time::timeout(limit, execution).await {
            Ok(finished) => finished,
            Err(_) => {
                #[cfg(feature = "metrics")]
                crate::metrics::workflow_run("timeout", _started.elapsed());
                return Err(CanoError::workflow("Workflow timeout exceeded"));
            }
        },
    };

    #[cfg(feature = "metrics")]
    {
        let label = if outcome.is_ok() {
            "completed"
        } else {
            "failed"
        };
        crate::metrics::workflow_run(label, _started.elapsed());
    }
    outcome
}
}
/// Cloning shares the handlers, resources, observers, compensators and
/// checkpoint store with the original (they are reference-counted) and
/// copies the remaining settings. The validation cache is not carried over:
/// the clone starts with an empty cache.
impl<TState, TResourceKey> Clone for Workflow<TState, TResourceKey>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        // Destructure so that every field is named here explicitly.
        let Self {
            states,
            resources,
            workflow_timeout,
            exit_states,
            validated: _,
            observers,
            compensators,
            checkpoint_store,
            workflow_id,
            workflow_version,
            #[cfg(feature = "tracing")]
            tracing_span,
        } = self;
        Self {
            states: states.clone(),
            resources: Arc::clone(resources),
            workflow_timeout: *workflow_timeout,
            exit_states: exit_states.clone(),
            validated: OnceLock::new(),
            observers: observers.clone(),
            compensators: compensators.clone(),
            checkpoint_store: checkpoint_store.clone(),
            workflow_id: workflow_id.clone(),
            workflow_version: *workflow_version,
            #[cfg(feature = "tracing")]
            tracing_span: tracing_span.clone(),
        }
    }
}
impl<TState> Workflow<TState, Cow<'static, str>>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
{
    /// Creates an empty workflow backed by a freshly created, empty
    /// `Resources` value keyed by `Cow<'static, str>`.
    pub fn bare() -> Self {
        let resources = Resources::new();
        Self::new(resources)
    }
}
/// Summary-style debug output: handlers and compensators are reported as
/// counts, the checkpoint store as a presence flag; resources, observers,
/// the validation cache and the tracing span are not printed.
impl<TState, TResourceKey> std::fmt::Debug for Workflow<TState, TResourceKey>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state_summary = format!("{} states", self.states.len());
        let compensator_summary = format!("{} compensators", self.compensators.len());
        let has_checkpoint_store = self.checkpoint_store.is_some();

        let mut out = f.debug_struct("Workflow");
        out.field("states", &state_summary);
        out.field("exit_states", &self.exit_states);
        out.field("workflow_timeout", &self.workflow_timeout);
        out.field("workflow_id", &self.workflow_id);
        out.field("workflow_version", &self.workflow_version);
        out.field("checkpoint_store", &has_checkpoint_store);
        out.field("compensators", &compensator_summary);
        out.finish()
    }
}
// Tests for the metrics emitted while a workflow runs. Compiled only for test
// builds with the `metrics` feature enabled.
#[cfg(all(test, feature = "metrics"))]
mod metrics_tests {
// Recorder helpers (`run_with_recorder`, `counter`, `counter_opt`,
// `histogram_count`, `gauge`) are expected to come from this module.
use crate::metrics::test_support::*;
use crate::prelude::*;
// Three-state machine used by every test in this module.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum S {
Start,
Mid,
Done,
}
// Task that always succeeds and transitions to the wrapped state.
struct GoTo(S);
#[crate::task]
impl Task<S> for GoTo {
async fn run_bare(&self) -> Result<TaskResult<S>, CanoError> {
Ok(TaskResult::Single(self.0.clone()))
}
}
// Task that always fails with a task-execution error.
struct Boom;
#[crate::task]
impl Task<S> for Boom {
// NOTE(review): `TaskConfig::minimal()` is presumably chosen so the failure
// surfaces without retries — confirm against `TaskConfig`.
fn config(&self) -> TaskConfig {
TaskConfig::minimal()
}
async fn run_bare(&self) -> Result<TaskResult<S>, CanoError> {
Err(CanoError::task_execution("boom"))
}
}
// Builds the happy-path workflow: Start -> Mid -> Done, with Done as exit state.
fn ok_workflow() -> Workflow<S> {
Workflow::bare()
.register(S::Start, GoTo(S::Mid))
.register(S::Mid, GoTo(S::Done))
.add_exit_state(S::Done)
}
// A successful run reports one "completed" run, one "completed" duration
// sample, and leaves the active-workflow gauge at zero.
#[test]
fn successful_run_records_outcome_duration_and_clears_active_gauge() {
let (res, rows) = run_with_recorder(|| async { ok_workflow().orchestrate(S::Start).await });
assert_eq!(res.unwrap(), S::Done);
assert_eq!(
counter(
&rows,
"cano_workflow_runs_total",
&[("outcome", "completed")]
),
1
);
assert_eq!(
histogram_count(
&rows,
"cano_workflow_duration_seconds",
&[("outcome", "completed")]
),
1
);
assert_eq!(gauge(&rows, "cano_workflow_active", &[]), 0.0);
}
// A run whose only task fails reports one "failed" run and leaves the
// active-workflow gauge at zero.
#[test]
fn failed_run_records_failed_outcome() {
let (res, rows) = run_with_recorder(|| async {
Workflow::bare()
.register(S::Start, Boom)
.add_exit_state(S::Done)
.orchestrate(S::Start)
.await
});
assert!(res.is_err());
assert_eq!(
counter(&rows, "cano_workflow_runs_total", &[("outcome", "failed")]),
1
);
assert_eq!(gauge(&rows, "cano_workflow_active", &[]), 0.0);
}
// Each single-task state contributes exactly one task-duration sample labelled
// with its state name and kind "single". The split kind is asserted in
// `split_records_branch_results_and_a_split_kind_duration` below.
#[test]
fn per_state_task_durations_are_recorded_single_and_split() {
let (res, rows) = run_with_recorder(|| async { ok_workflow().orchestrate(S::Start).await });
assert_eq!(res.unwrap(), S::Done);
assert_eq!(
histogram_count(
&rows,
"cano_task_duration_seconds",
&[("state", "Start"), ("kind", "single")]
),
1
);
assert_eq!(
histogram_count(
&rows,
"cano_task_duration_seconds",
&[("state", "Mid"), ("kind", "single")]
),
1
);
}
// Split branch that fails when `fail` is set and otherwise transitions to Done.
#[derive(Clone)]
struct Branch {
fail: bool,
}
#[crate::task]
impl Task<S> for Branch {
fn config(&self) -> TaskConfig {
TaskConfig::minimal()
}
async fn run_bare(&self) -> Result<TaskResult<S>, CanoError> {
if self.fail {
Err(CanoError::task_execution("nope"))
} else {
Ok(TaskResult::Single(S::Done))
}
}
}
// A three-branch split (two succeed, one fails) joined with PartialResults(2)
// reaches Done and reports two "success" results, one "failure" result, no
// "cancelled" series, and one task-duration sample of kind "split".
#[test]
fn split_records_branch_results_and_a_split_kind_duration() {
let (res, rows) = run_with_recorder(|| async {
Workflow::bare()
.register_split(
S::Start,
vec![
Branch { fail: false },
Branch { fail: true },
Branch { fail: false },
],
JoinConfig::new(JoinStrategy::PartialResults(2), S::Done),
)
.add_exit_state(S::Done)
.orchestrate(S::Start)
.await
});
assert_eq!(res.unwrap(), S::Done);
assert_eq!(
counter(
&rows,
"cano_split_branch_results_total",
&[("result", "success")]
),
2
);
assert_eq!(
counter(
&rows,
"cano_split_branch_results_total",
&[("result", "failure")]
),
1
);
// `counter_opt` returning `None` means no "cancelled" series was recorded.
assert_eq!(
counter_opt(
&rows,
"cano_split_branch_results_total",
&[("result", "cancelled")]
),
None
);
assert_eq!(
histogram_count(
&rows,
"cano_task_duration_seconds",
&[("state", "Start"), ("kind", "split")]
),
1
);
}
}
// Unit tests for workflow construction, validation and basic orchestration.
#[cfg(test)]
mod tests {
// `TestState`, `SimpleTask` and `DataTask` are expected to come from
// `test_support`.
use super::test_support::*;
use super::*;
use crate::resource::Resources;
use crate::task;
use crate::task::{Task, TaskResult};
use cano_macros::task as task_macro;
use tokio;
// A bare workflow starts with no handlers and no exit states.
#[tokio::test]
async fn test_workflow_creation() {
let workflow = Workflow::<TestState>::bare();
assert_eq!(workflow.states.len(), 0);
assert_eq!(workflow.exit_states.len(), 0);
}
// One task leading straight to the exit state.
#[tokio::test]
async fn test_simple_workflow() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
}
// Two chained tasks: Start -> Process -> Complete.
#[tokio::test]
async fn test_multi_step_workflow() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Process))
.register(TestState::Process, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
}
// A task writes a value into a store shared through the workflow's resources;
// the value is read back through the test's own handle to the same store.
#[tokio::test]
async fn test_workflow_with_data() {
let store = crate::store::MemoryStore::new();
let resources = Resources::new().insert("store", store.clone());
let workflow = Workflow::new(resources)
.register(
TestState::Start,
DataTask::new("test_key", "test_value", TestState::Complete),
)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
let data: String = store.get("test_key").unwrap();
assert_eq!(data, "test_value");
}
// No handler is registered at all, so `orchestrate` fails in `validate()` on
// the empty handler table before the initial state is examined.
#[tokio::test]
async fn test_unregistered_state_error() {
let workflow = Workflow::<TestState>::bare().add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await;
let err = result.unwrap_err();
assert_eq!(err.category(), "configuration");
assert!(err.to_string().contains("no registered state handlers"));
}
// `validate` rejects a workflow without any handler.
#[test]
fn test_validate_empty_workflow() {
let workflow = Workflow::<TestState>::bare();
let result = workflow.validate();
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("no registered state handlers")
);
}
// `validate` rejects a workflow that has handlers but no exit state.
#[test]
fn test_validate_no_exit_states() {
let workflow =
Workflow::bare().register(TestState::Start, SimpleTask::new(TestState::Complete));
let result = workflow.validate();
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("no exit states defined")
);
}
// One handler plus one exit state is a valid configuration.
#[test]
fn test_validate_valid_workflow() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
assert!(workflow.validate().is_ok());
}
// The split's join state (`Process`) is neither registered nor an exit state,
// so validation must fail and mention `join_state`.
#[test]
fn test_validate_split_join_state_unregistered() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Join)],
JoinConfig::new(JoinStrategy::All, TestState::Process), )
.add_exit_state(TestState::Complete);
let result = workflow.validate();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("join_state"));
}
// A join state that is only an exit state (not a registered handler) is accepted.
#[test]
fn test_validate_split_join_state_as_exit_state() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::All, TestState::Complete),
)
.add_exit_state(TestState::Complete);
assert!(workflow.validate().is_ok());
}
// `PartialTimeout` without a configured timeout is a configuration error.
#[test]
fn test_validate_rejects_partial_timeout_without_timeout() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::PartialTimeout, TestState::Complete),
)
.add_exit_state(TestState::Complete);
let err = workflow
.validate()
.expect_err("PartialTimeout without timeout must fail validation");
assert!(matches!(err, CanoError::Configuration(_)), "got {err:?}");
assert!(err.to_string().contains("requires a timeout"));
}
// `Percentage` values outside (0.0, 1.0] are rejected: the lower bound itself,
// a value above the upper bound, and NaN.
#[test]
fn test_validate_rejects_invalid_percentage() {
for value in [0.0, 1.5, f64::NAN] {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::Percentage(value), TestState::Complete),
)
.add_exit_state(TestState::Complete);
let err = workflow
.validate()
.expect_err("invalid Percentage strategy must fail validation");
assert!(matches!(err, CanoError::Configuration(_)), "got {err:?}");
assert!(err.to_string().contains("Percentage strategy"));
}
}
// A bulkhead of zero permits is a configuration error.
#[test]
fn test_validate_rejects_zero_bulkhead() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::All, TestState::Complete).with_bulkhead(0),
)
.add_exit_state(TestState::Complete);
let err = workflow
.validate()
.expect_err("bulkhead=0 must fail validation");
assert!(matches!(err, CanoError::Configuration(_)), "got {err:?}");
assert!(err.to_string().contains("bulkhead"));
}
// A registered state and an exit state are both valid initial states; a state
// that is neither (`Process`) is rejected.
#[test]
fn test_validate_initial_state() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
assert!(workflow.validate_initial_state(&TestState::Start).is_ok());
assert!(
workflow
.validate_initial_state(&TestState::Complete)
.is_ok()
);
let result = workflow.validate_initial_state(&TestState::Process);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("neither registered nor an exit state")
);
}
// Task that implements only `run_bare` (no resource access).
struct BareWorkflowTask;
#[task_macro]
impl Task<TestState> for BareWorkflowTask {
async fn run_bare(&self) -> Result<TaskResult<TestState>, CanoError> {
Ok(TaskResult::Single(TestState::Complete))
}
}
// A `run_bare`-only task works in a workflow created with `Workflow::bare()`.
#[tokio::test]
async fn test_workflow_bare_runs_task_with_run_bare() {
let result = Workflow::bare()
.register(TestState::Start, BareWorkflowTask)
.add_exit_state(TestState::Complete)
.orchestrate(TestState::Start)
.await
.unwrap();
assert_eq!(result, TestState::Complete);
}
// A router registered for Start hands over to Process, whose task completes
// the workflow.
#[tokio::test]
async fn test_register_router_orchestration_round_trip() {
use crate::task::{RouterTask, TaskConfig};
struct RouteToProcess;
#[task::router]
impl RouterTask<TestState> for RouteToProcess {
fn config(&self) -> TaskConfig {
TaskConfig::minimal()
}
async fn route(&self, _res: &Resources) -> Result<TaskResult<TestState>, CanoError> {
Ok(TaskResult::Single(TestState::Process))
}
}
let result = Workflow::bare()
.register_router(TestState::Start, RouteToProcess)
.register(TestState::Process, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete)
.orchestrate(TestState::Start)
.await
.unwrap();
assert_eq!(result, TestState::Complete);
}
// A router entry counts as a registered handler for validation purposes.
#[test]
fn test_validate_passes_with_router_state() {
use crate::task::{RouterTask, TaskConfig};
struct RouteToComplete;
#[task::router]
impl RouterTask<TestState> for RouteToComplete {
fn config(&self) -> TaskConfig {
TaskConfig::minimal()
}
async fn route(&self, _res: &Resources) -> Result<TaskResult<TestState>, CanoError> {
Ok(TaskResult::Single(TestState::Complete))
}
}
let workflow = Workflow::bare()
.register_router(TestState::Start, RouteToComplete)
.add_exit_state(TestState::Complete);
assert!(workflow.validate().is_ok());
}
}