use std::borrow::Cow;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use crate::error::CanoError;
use crate::observer::WorkflowObserver;
use crate::recovery::CheckpointStore;
use crate::resource::Resources;
use crate::saga::{CompensatableTask, ErasedCompensatable};
use crate::task::stepped::{ErasedSteppedTask, SteppedAdapter, SteppedTask};
use crate::task::{RouterTask, Task};
#[cfg(feature = "tracing")]
use tracing::{Span, info_span};
mod compensation;
mod execution;
mod join;
#[cfg(test)]
mod test_support;
pub use execution::StateEntry;
pub use join::{JoinConfig, JoinStrategy, SplitResult, SplitTaskResult};
/// Calls `f` once for every observer in `observers`, in slice order.
///
/// An empty slice results in no calls.
#[inline]
fn notify_observers(observers: &[Arc<dyn WorkflowObserver>], f: impl Fn(&dyn WorkflowObserver)) {
    observers.iter().for_each(|handle| f(&**handle));
}
/// Extracts a human-readable message from a panic payload.
///
/// A `&'static str` or `String` payload is returned as an owned `String`;
/// any other payload type yields the fixed placeholder
/// `"<non-string panic payload>"`.
fn panic_payload_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|text| String::from(*text))
        // Only probe for `String` when the payload was not a string slice.
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| String::from("<non-string panic payload>"))
}
/// A state-machine workflow: a table of per-state handlers plus the shared
/// configuration used when the workflow is orchestrated.
///
/// `TState` identifies states; `TResourceKey` keys the shared [`Resources`]
/// and defaults to `Cow<'static, str>`.
#[must_use]
pub struct Workflow<TState, TResourceKey = Cow<'static, str>>
where
TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
TResourceKey: Hash + Eq + Send + Sync + 'static,
{
/// Handler entry registered for each state; registering a state again replaces its entry.
states: HashMap<TState, Arc<StateEntry<TState, TResourceKey>>>,
/// Shared resources; `orchestrate` sets them up before the run and tears them down after it.
pub(crate) resources: Arc<Resources<TResourceKey>>,
/// Optional limit applied to the whole run by `run_workflow`.
workflow_timeout: Option<Duration>,
/// Exit states, kept free of duplicates by the `add_exit_state*` builders.
/// NOTE(review): presumably execution stops on reaching one of these — the
/// execution loop is not in this file; confirm there.
exit_states: Vec<TState>,
/// Result of `validate`, computed on the first `orchestrate` call and reused afterwards.
validated: OnceLock<Result<(), CanoError>>,
/// Observers added through `with_observer`, in insertion order.
observers: Vec<Arc<dyn WorkflowObserver>>,
/// Compensatable tasks keyed by task name, filled by `register_with_compensation`.
compensators: HashMap<String, Arc<dyn ErasedCompensatable<TState, TResourceKey>>>,
/// Optional checkpoint store set by `with_checkpoint_store` (consumed outside this file).
checkpoint_store: Option<Arc<dyn CheckpointStore>>,
/// Optional identifier set by `with_workflow_id` (consumed outside this file).
workflow_id: Option<String>,
/// Caller-supplied span that `orchestrate` uses instead of creating its own.
#[cfg(feature = "tracing")]
tracing_span: Option<Span>,
}
impl<TState, TResourceKey> Workflow<TState, TResourceKey>
where
TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
TResourceKey: Hash + Eq + Send + Sync + 'static,
{
/// Creates an empty workflow that owns `resources`.
///
/// The new workflow has no handlers, no exit states, no timeout, no
/// observers, no compensators, no checkpoint store and no workflow id.
pub fn new(resources: Resources<TResourceKey>) -> Self {
    let resources = Arc::new(resources);
    Self {
        resources,
        states: HashMap::new(),
        exit_states: Vec::new(),
        observers: Vec::new(),
        compensators: HashMap::new(),
        validated: OnceLock::new(),
        workflow_timeout: None,
        checkpoint_store: None,
        workflow_id: None,
        #[cfg(feature = "tracing")]
        tracing_span: None,
    }
}
/// Sets the time limit for a whole workflow run and returns the workflow.
///
/// A later call overwrites the previously configured limit.
pub fn with_timeout(self, timeout: Duration) -> Self {
    let mut workflow = self;
    workflow.workflow_timeout = Some(timeout);
    workflow
}
pub fn register<T>(mut self, state: TState, task: T) -> Self
where
T: Task<TState, TResourceKey> + Send + Sync + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(task.config());
self.states.insert(
state,
Arc::new(StateEntry::Single {
task: Arc::new(task),
config,
}),
);
self
}
pub fn register_router<T>(mut self, state: TState, task: T) -> Self
where
T: RouterTask<TState, TResourceKey> + Task<TState, TResourceKey> + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(Task::config(&task));
self.states.insert(
state,
Arc::new(StateEntry::Router {
task: Arc::new(task),
config,
}),
);
self
}
pub fn register_split<T>(
mut self,
state: TState,
tasks: Vec<T>,
join_config: JoinConfig<TState>,
) -> Self
where
T: Task<TState, TResourceKey> + Send + Sync + 'static,
{
self.forget_compensator_for(&state);
let configs: Vec<Arc<crate::task::TaskConfig>> =
tasks.iter().map(|t| Arc::new(t.config())).collect();
let arc_tasks: Vec<Arc<dyn Task<TState, TResourceKey> + Send + Sync>> =
tasks.into_iter().map(|t| Arc::new(t) as Arc<_>).collect();
self.states.insert(
state,
Arc::new(StateEntry::Split {
tasks: arc_tasks,
configs: Arc::new(configs),
join_config: Arc::new(join_config),
}),
);
self
}
/// Drops the compensator belonging to the handler currently registered for
/// `state`, if that handler is a compensatable single task.
///
/// Does nothing when `state` has no handler or its handler is of another kind.
fn forget_compensator_for(&mut self, state: &TState) {
    let Some(current) = self.states.get(state) else {
        return;
    };
    let StateEntry::CompensatableSingle { task, .. } = current.as_ref() else {
        return;
    };
    // Own the name so the borrow of `states` ends before `compensators` is mutated.
    let stale_name = task.name().into_owned();
    self.compensators.remove(&stale_name);
}
pub fn register_with_compensation<T>(mut self, state: TState, task: T) -> Self
where
T: CompensatableTask<TState, TResourceKey> + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(task.config());
let name = task.name().into_owned();
let erased: Arc<dyn ErasedCompensatable<TState, TResourceKey>> =
Arc::new(crate::saga::CompensatableAdapter(Arc::new(task)));
self.compensators.insert(name, Arc::clone(&erased));
self.states.insert(
state,
Arc::new(StateEntry::CompensatableSingle {
task: erased,
config,
}),
);
self
}
pub fn register_stepped<T>(mut self, state: TState, task: T) -> Self
where
T: SteppedTask<TState, TResourceKey> + 'static,
{
self.forget_compensator_for(&state);
let config = Arc::new(SteppedTask::config(&task));
let erased: Arc<dyn ErasedSteppedTask<TState, TResourceKey>> =
Arc::new(SteppedAdapter(Arc::new(task)));
self.states.insert(
state,
Arc::new(StateEntry::Stepped {
task: erased,
config,
}),
);
self
}
pub fn add_exit_state(mut self, state: TState) -> Self {
if !self.exit_states.contains(&state) {
self.exit_states.push(state);
}
self
}
pub fn add_exit_states(mut self, states: Vec<TState>) -> Self {
for state in states {
if !self.exit_states.contains(&state) {
self.exit_states.push(state);
}
}
self
}
/// Appends `observer` to the workflow's observer list and returns the workflow.
///
/// Observers are kept in the order they were added; no de-duplication is done.
pub fn with_observer(self, observer: Arc<dyn WorkflowObserver>) -> Self {
    let mut workflow = self;
    workflow.observers.push(observer);
    workflow
}
/// Attaches a checkpoint store, replacing any previously attached one.
pub fn with_checkpoint_store(self, store: Arc<dyn CheckpointStore>) -> Self {
    let mut workflow = self;
    workflow.checkpoint_store = Some(store);
    workflow
}
/// Sets the workflow identifier, replacing any previously set one.
///
/// Accepts anything convertible into a `String`.
pub fn with_workflow_id(self, workflow_id: impl Into<String>) -> Self {
    let id: String = workflow_id.into();
    let mut workflow = self;
    workflow.workflow_id = Some(id);
    workflow
}
/// Supplies the span that `orchestrate` uses instead of creating its own.
///
/// Only available with the `tracing` feature.
#[cfg(feature = "tracing")]
pub fn with_tracing_span(self, span: Span) -> Self {
    let mut workflow = self;
    workflow.tracing_span = Some(span);
    workflow
}
/// Returns the registered observers as a shared slice, or `None` when no
/// observer has been registered.
fn observer_slice(&self) -> Option<Arc<[Arc<dyn WorkflowObserver>]>> {
    if self.observers.is_empty() {
        return None;
    }
    let shared: Arc<[Arc<dyn WorkflowObserver>]> = self.observers.as_slice().into();
    Some(shared)
}
/// Produces the task configuration to use for a task named `task_name`.
///
/// Without observers the `base` configuration is shared as-is. With
/// observers, a copy of `base` is returned with its `observers` set to the
/// shared slice and its `task_name` set to an owned copy of `task_name`.
fn config_with_observers(
    base: &Arc<crate::task::TaskConfig>,
    observers: &Option<Arc<[Arc<dyn WorkflowObserver>]>>,
    task_name: &str,
) -> Arc<crate::task::TaskConfig> {
    let Some(shared) = observers else {
        // Nothing to attach: hand back another handle to the same config.
        return Arc::clone(base);
    };
    let mut patched = (**base).clone();
    patched.observers = Some(Arc::clone(shared));
    patched.task_name = Some(Cow::Owned(String::from(task_name)));
    Arc::new(patched)
}
/// Checks a split's join configuration for settings that cannot work.
///
/// Returns a configuration error when:
/// - the strategy is `PartialTimeout` but no timeout is set,
/// - the strategy is `Percentage(p)` and `p` is not a finite value in `(0.0, 1.0]`,
/// - the bulkhead permit count is `0`.
///
/// `_total_tasks` is currently unused.
fn validate_join_config(
    join_config: &JoinConfig<TState>,
    _total_tasks: usize,
) -> Result<(), CanoError> {
    let needs_timeout = matches!(join_config.strategy, JoinStrategy::PartialTimeout);
    if needs_timeout && join_config.timeout.is_none() {
        return Err(CanoError::configuration(
            "PartialTimeout strategy requires a timeout to be configured",
        ));
    }

    if let JoinStrategy::Percentage(p) = join_config.strategy {
        // NaN and infinities fail `is_finite`, so they are rejected as well.
        let in_range = p.is_finite() && p > 0.0 && p <= 1.0;
        if !in_range {
            return Err(CanoError::configuration(format!(
                "Percentage strategy requires a finite value in (0.0, 1.0], got {p}"
            )));
        }
    }

    if matches!(join_config.bulkhead, Some(0)) {
        return Err(CanoError::configuration(
            "bulkhead requires a positive permit count, got 0",
        ));
    }

    Ok(())
}
/// Checks the workflow's structure without running it.
///
/// # Errors
///
/// Returns a configuration error when no handler is registered, when no exit
/// state is defined, when a split's join configuration is rejected by
/// `validate_join_config`, or when a split's `join_state` is neither a
/// registered state nor an exit state.
pub fn validate(&self) -> Result<(), CanoError> {
    if self.states.is_empty() {
        return Err(CanoError::configuration(
            "Workflow has no registered state handlers",
        ));
    }
    if self.exit_states.is_empty() {
        return Err(CanoError::configuration(
            "Workflow has no exit states defined — orchestration may loop forever",
        ));
    }

    // Only split entries carry a join configuration to check.
    let splits = self.states.values().filter_map(|entry| match entry.as_ref() {
        StateEntry::Split {
            tasks, join_config, ..
        } => Some((tasks, join_config)),
        _ => None,
    });

    for (tasks, join_config) in splits {
        Self::validate_join_config(join_config, tasks.len())?;

        let target = &join_config.join_state;
        let known = self.states.contains_key(target) || self.exit_states.contains(target);
        if !known {
            return Err(CanoError::configuration(format!(
                "Split join_state {:?} is neither registered nor an exit state",
                target
            )));
        }
    }

    Ok(())
}
/// Checks that `state` can be used as a starting point.
///
/// # Errors
///
/// Returns a configuration error when `state` has no registered handler and
/// is not an exit state.
pub fn validate_initial_state(&self, state: &TState) -> Result<(), CanoError> {
    let known = self.states.contains_key(state) || self.exit_states.contains(state);
    if known {
        Ok(())
    } else {
        Err(CanoError::configuration(format!(
            "Initial state {:?} is neither registered nor an exit state",
            state
        )))
    }
}
/// Validates the workflow, sets up its resources, runs it from
/// `initial_state`, tears the resources down and returns the run's result.
///
/// Structural validation (`validate`) is computed on the first call and the
/// verdict is reused by later calls; `initial_state` is checked on every call.
/// Teardown runs whether the run succeeded or failed, but not when validation
/// or resource setup fails first.
///
/// With the `tracing` feature the whole operation runs inside the span given
/// to `with_tracing_span`, or inside a `workflow_orchestrate` INFO span when
/// none was given.
///
/// # Errors
///
/// Returns the validation error, the resource setup error, or the error
/// produced by the run itself.
pub async fn orchestrate(&self, initial_state: TState) -> Result<TState, CanoError> {
    let run = async move {
        if let Err(e) = self.validated.get_or_init(|| self.validate()) {
            return Err(e.clone());
        }
        self.validate_initial_state(&initial_state)?;
        self.resources.setup_all().await?;
        let result = self.run_workflow(initial_state).await;
        // Tear down every lifecycle resource before reporting the outcome.
        self.resources
            .teardown_range(0..self.resources.lifecycle_len())
            .await;
        result
    };

    // A `Span::enter()` guard must not be held across `.await`: while the task
    // is suspended the span would stay entered on the executor thread and leak
    // into unrelated work. `Instrument` enters the span only while this future
    // is being polled.
    #[cfg(feature = "tracing")]
    let run = {
        let workflow_span = self.tracing_span.clone().unwrap_or_else(|| {
            if tracing::enabled!(tracing::Level::INFO) {
                info_span!("workflow_orchestrate")
            } else {
                tracing::Span::none()
            }
        });
        tracing::Instrument::instrument(run, workflow_span)
    };

    run.await
}
/// Executes the workflow from `initial_state`, enforcing the configured
/// workflow timeout when one is set.
///
/// # Errors
///
/// Returns the execution's own error, or a workflow error with the message
/// `"Workflow timeout exceeded"` when the timeout elapses first.
async fn run_workflow(&self, initial_state: TState) -> Result<TState, CanoError> {
    let run = self.execute_workflow(initial_state);
    let Some(limit) = self.workflow_timeout else {
        // No limit configured: just drive the execution to completion.
        return run.await;
    };
    tokio::time::timeout(limit, run)
        .await
        .unwrap_or_else(|_elapsed| Err(CanoError::workflow("Workflow timeout exceeded")))
}
}
impl<TState, TResourceKey> Clone for Workflow<TState, TResourceKey>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    /// Copies the workflow's configuration.
    ///
    /// The resources are shared with the original through the same `Arc`.
    /// The cached validation verdict is not copied: the clone starts with an
    /// empty cache and validates again on its first `orchestrate` call.
    fn clone(&self) -> Self {
        let states = self.states.clone();
        let exit_states = self.exit_states.clone();
        let observers = self.observers.clone();
        let compensators = self.compensators.clone();
        Self {
            states,
            exit_states,
            observers,
            compensators,
            resources: self.resources.clone(),
            checkpoint_store: self.checkpoint_store.as_ref().map(Arc::clone),
            workflow_id: self.workflow_id.clone(),
            workflow_timeout: self.workflow_timeout,
            validated: OnceLock::new(),
            #[cfg(feature = "tracing")]
            tracing_span: self.tracing_span.clone(),
        }
    }
}
impl<TState> Workflow<TState, Cow<'static, str>>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
{
    /// Creates an empty workflow backed by a freshly created `Resources`,
    /// using the default `Cow<'static, str>` resource key type.
    pub fn bare() -> Self {
        let empty = Resources::new();
        Self::new(empty)
    }
}
impl<TState, TResourceKey> std::fmt::Debug for Workflow<TState, TResourceKey>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    /// Formats a summary of the workflow: handler and compensator counts,
    /// exit states, timeout, workflow id, and whether a checkpoint store is
    /// attached. Handlers, observers and resources are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state_summary = format!("{} states", self.states.len());
        let compensator_summary = format!("{} compensators", self.compensators.len());
        let has_checkpoint_store = self.checkpoint_store.is_some();

        let mut out = f.debug_struct("Workflow");
        out.field("states", &state_summary);
        out.field("exit_states", &self.exit_states);
        out.field("workflow_timeout", &self.workflow_timeout);
        out.field("workflow_id", &self.workflow_id);
        out.field("checkpoint_store", &has_checkpoint_store);
        out.field("compensators", &compensator_summary);
        out.finish()
    }
}
// Unit tests for the workflow builder, validation and basic orchestration.
// `TestState` and the helper tasks (`SimpleTask`, `DataTask`) come from `test_support`.
#[cfg(test)]
mod tests {
use super::test_support::*;
use super::*;
use crate::resource::Resources;
use crate::task;
use crate::task::{Task, TaskResult};
use cano_macros::task as task_macro;
use tokio;
// A freshly built bare workflow has no handlers and no exit states.
#[tokio::test]
async fn test_workflow_creation() {
let workflow = Workflow::<TestState>::bare();
assert_eq!(workflow.states.len(), 0);
assert_eq!(workflow.exit_states.len(), 0);
}
// One registered task leading to the exit state: orchestration ends in `Complete`.
#[tokio::test]
async fn test_simple_workflow() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
}
// Two chained tasks (Start -> Process -> Complete) end in `Complete`.
#[tokio::test]
async fn test_multi_step_workflow() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Process))
.register(TestState::Process, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
}
// A task writes into a store shared through `Resources`; the value is read back afterwards.
#[tokio::test]
async fn test_workflow_with_data() {
let store = crate::store::MemoryStore::new();
let resources = Resources::new().insert("store", store.clone());
let workflow = Workflow::new(resources)
.register(
TestState::Start,
DataTask::new("test_key", "test_value", TestState::Complete),
)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
let data: String = store.get("test_key").unwrap();
assert_eq!(data, "test_value");
}
// With no handlers at all, orchestration fails in validation before the
// initial state is even looked at, hence the "no registered state handlers" message.
#[tokio::test]
async fn test_unregistered_state_error() {
let workflow = Workflow::<TestState>::bare().add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await;
let err = result.unwrap_err();
assert_eq!(err.category(), "configuration");
assert!(err.to_string().contains("no registered state handlers"));
}
// `validate` rejects a workflow without any handler.
#[test]
fn test_validate_empty_workflow() {
let workflow = Workflow::<TestState>::bare();
let result = workflow.validate();
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("no registered state handlers")
);
}
// `validate` rejects a workflow that has handlers but no exit state.
#[test]
fn test_validate_no_exit_states() {
let workflow =
Workflow::bare().register(TestState::Start, SimpleTask::new(TestState::Complete));
let result = workflow.validate();
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("no exit states defined")
);
}
// One handler plus one exit state passes validation.
#[test]
fn test_validate_valid_workflow() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
assert!(workflow.validate().is_ok());
}
// A split whose join state (`Process`) is neither registered nor an exit state is rejected.
#[test]
fn test_validate_split_join_state_unregistered() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Join)],
JoinConfig::new(JoinStrategy::All, TestState::Process), )
.add_exit_state(TestState::Complete);
let result = workflow.validate();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("join_state"));
}
// A split may join directly into an exit state.
#[test]
fn test_validate_split_join_state_as_exit_state() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::All, TestState::Complete),
)
.add_exit_state(TestState::Complete);
assert!(workflow.validate().is_ok());
}
// `PartialTimeout` without a configured timeout is a configuration error.
#[test]
fn test_validate_rejects_partial_timeout_without_timeout() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::PartialTimeout, TestState::Complete),
)
.add_exit_state(TestState::Complete);
let err = workflow
.validate()
.expect_err("PartialTimeout without timeout must fail validation");
assert!(matches!(err, CanoError::Configuration(_)), "got {err:?}");
assert!(err.to_string().contains("requires a timeout"));
}
// `Percentage` values outside (0.0, 1.0] — zero, above one, NaN — are configuration errors.
#[test]
fn test_validate_rejects_invalid_percentage() {
for value in [0.0, 1.5, f64::NAN] {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::Percentage(value), TestState::Complete),
)
.add_exit_state(TestState::Complete);
let err = workflow
.validate()
.expect_err("invalid Percentage strategy must fail validation");
assert!(matches!(err, CanoError::Configuration(_)), "got {err:?}");
assert!(err.to_string().contains("Percentage strategy"));
}
}
// A bulkhead of zero permits is a configuration error.
#[test]
fn test_validate_rejects_zero_bulkhead() {
let workflow = Workflow::bare()
.register_split(
TestState::Start,
vec![SimpleTask::new(TestState::Complete)],
JoinConfig::new(JoinStrategy::All, TestState::Complete).with_bulkhead(0),
)
.add_exit_state(TestState::Complete);
let err = workflow
.validate()
.expect_err("bulkhead=0 must fail validation");
assert!(matches!(err, CanoError::Configuration(_)), "got {err:?}");
assert!(err.to_string().contains("bulkhead"));
}
// Registered states and exit states are valid initial states; anything else is rejected.
#[test]
fn test_validate_initial_state() {
let workflow = Workflow::bare()
.register(TestState::Start, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
assert!(workflow.validate_initial_state(&TestState::Start).is_ok());
assert!(
workflow
.validate_initial_state(&TestState::Complete)
.is_ok()
);
let result = workflow.validate_initial_state(&TestState::Process);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("neither registered nor an exit state")
);
}
// Task implemented only through `run_bare`, i.e. without taking resources.
struct BareWorkflowTask;
#[task_macro]
impl Task<TestState> for BareWorkflowTask {
async fn run_bare(&self) -> Result<TaskResult<TestState>, CanoError> {
Ok(TaskResult::Single(TestState::Complete))
}
}
// A `run_bare`-only task works in a workflow built with `Workflow::bare()`.
#[tokio::test]
async fn test_workflow_bare_runs_task_with_run_bare() {
let result = Workflow::bare()
.register(TestState::Start, BareWorkflowTask)
.add_exit_state(TestState::Complete)
.orchestrate(TestState::Start)
.await
.unwrap();
assert_eq!(result, TestState::Complete);
}
// A router registered for `Start` routes to `Process`, whose task then completes the run.
#[tokio::test]
async fn test_register_router_orchestration_round_trip() {
use crate::task::{RouterTask, TaskConfig};
struct RouteToProcess;
#[task::router]
impl RouterTask<TestState> for RouteToProcess {
fn config(&self) -> TaskConfig {
TaskConfig::minimal()
}
async fn route(&self, _res: &Resources) -> Result<TaskResult<TestState>, CanoError> {
Ok(TaskResult::Single(TestState::Process))
}
}
let result = Workflow::bare()
.register_router(TestState::Start, RouteToProcess)
.register(TestState::Process, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete)
.orchestrate(TestState::Start)
.await
.unwrap();
assert_eq!(result, TestState::Complete);
}
// A router entry counts as a registered handler for validation purposes.
#[test]
fn test_validate_passes_with_router_state() {
use crate::task::{RouterTask, TaskConfig};
struct RouteToComplete;
#[task::router]
impl RouterTask<TestState> for RouteToComplete {
fn config(&self) -> TaskConfig {
TaskConfig::minimal()
}
async fn route(&self, _res: &Resources) -> Result<TaskResult<TestState>, CanoError> {
Ok(TaskResult::Single(TestState::Complete))
}
}
let workflow = Workflow::bare()
.register_router(TestState::Start, RouteToComplete)
.add_exit_state(TestState::Complete);
assert!(workflow.validate().is_ok());
}
}