use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use crate::error::CanoError;
use crate::store::{KeyValueStore, MemoryStore};
use crate::task::{DefaultTaskParams, Task, TaskResult};
#[cfg(feature = "tracing")]
use tracing::{Span, debug, info, info_span, warn};
use futures_util::stream::{FuturesUnordered, StreamExt};
/// Owns a tokio `JoinHandle` and aborts the spawned task when this wrapper is dropped.
///
/// Awaiting the wrapper yields exactly what awaiting the inner handle would.
struct AbortOnDrop<T>(tokio::task::JoinHandle<T>);

impl<T> Drop for AbortOnDrop<T> {
    fn drop(&mut self) {
        let AbortOnDrop(handle) = self;
        handle.abort();
    }
}

impl<T> std::future::Future for AbortOnDrop<T> {
    type Output = Result<T, tokio::task::JoinError>;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // `JoinHandle` is `Unpin`, so the wrapper is too and may be unpinned directly.
        let inner = &mut self.get_mut().0;
        std::future::Future::poll(std::pin::Pin::new(inner), cx)
    }
}
/// Rule a split state uses to decide whether its parallel tasks collectively succeeded.
///
/// [`JoinStrategy::is_satisfied`] compares a completion count against the total number of
/// tasks in the split.
#[derive(Clone, Debug, PartialEq)]
pub enum JoinStrategy {
    /// Every task must complete.
    All,
    /// At least one task must complete.
    Any,
    /// At least `n` tasks must complete.
    Quorum(usize),
    /// At least `ceil(total * p)` tasks must complete; `p` is a fraction (e.g. `0.75`).
    Percentage(f64),
    /// At least `min` tasks must complete.
    PartialResults(usize),
    /// At least one task must complete.
    PartialTimeout,
}

impl JoinStrategy {
    /// Returns `true` when `completed` tasks out of `total` meet this strategy's threshold.
    ///
    /// For [`JoinStrategy::Percentage`] the threshold is `ceil(total * p)`, with two guards:
    /// a product within a few ULPs of an integer is taken as that integer (so representation
    /// error in `p`, e.g. `100.0 * 0.07 == 7.000000000000001`, does not demand an extra task),
    /// and a NaN product (NaN `p`, or `0 * inf`) is never satisfied instead of silently
    /// becoming a threshold of zero.
    pub fn is_satisfied(&self, completed: usize, total: usize) -> bool {
        match self {
            JoinStrategy::All => completed >= total,
            JoinStrategy::Any | JoinStrategy::PartialTimeout => completed >= 1,
            JoinStrategy::Quorum(n) => completed >= *n,
            JoinStrategy::PartialResults(min) => completed >= *min,
            JoinStrategy::Percentage(p) => Self::percentage_threshold(total, *p)
                .is_some_and(|required| completed >= required),
        }
    }

    /// Minimum number of completions for `Percentage(p)` over `total` tasks.
    ///
    /// Returns `None` when `total * p` is NaN. Thresholds at or above `usize::MAX` saturate to
    /// `usize::MAX`; negative thresholds saturate to zero through the `as` cast.
    fn percentage_threshold(total: usize, p: f64) -> Option<usize> {
        let raw = total as f64 * p;
        if raw.is_nan() {
            return None;
        }
        let nearest = raw.round();
        // Absorb floating-point noise from multiplying by an inexact fraction: a value this
        // close to an integer is that integer, not "slightly more" requiring one extra task.
        let required = if (raw - nearest).abs() <= raw.abs() * 4.0 * f64::EPSILON {
            nearest
        } else {
            raw.ceil()
        };
        Some(if required >= usize::MAX as f64 {
            usize::MAX
        } else {
            required as usize
        })
    }
}
#[derive(Clone, Debug)]
pub struct SplitTaskResult<TState> {
pub task_index: usize,
pub result: Result<TaskResult<TState>, CanoError>,
}
#[derive(Clone, Debug)]
pub struct SplitResult<TState> {
pub successes: Vec<SplitTaskResult<TState>>,
pub errors: Vec<SplitTaskResult<TState>>,
pub cancelled: Vec<usize>,
}
impl<TState> SplitResult<TState> {
pub fn new() -> Self {
Self {
successes: Vec::new(),
errors: Vec::new(),
cancelled: Vec::new(),
}
}
pub fn completed_count(&self) -> usize {
self.successes.len() + self.errors.len()
}
pub fn total_count(&self) -> usize {
self.successes.len() + self.errors.len() + self.cancelled.len()
}
}
impl<TState> Default for SplitResult<TState> {
fn default() -> Self {
Self::new()
}
}
/// Settings controlling how a split state joins its parallel tasks.
#[must_use]
#[derive(Clone)]
pub struct JoinConfig<TState> {
    /// Rule deciding whether the split as a whole succeeded.
    pub strategy: JoinStrategy,
    /// Optional deadline for collecting task results.
    pub timeout: Option<Duration>,
    /// State the workflow moves to when the join succeeds.
    pub join_state: TState,
    /// Whether to write split summary counters into the workflow store.
    pub store_partial_results: bool,
}

impl<TState> JoinConfig<TState>
where
    TState: Clone,
{
    /// Creates a config with no timeout that does not store partial results.
    pub fn new(strategy: JoinStrategy, join_state: TState) -> Self {
        Self {
            join_state,
            strategy,
            store_partial_results: false,
            timeout: None,
        }
    }

    /// Sets the deadline for collecting split results.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Enables or disables writing split summary counters into the store.
    pub fn with_store_partial_results(self, store: bool) -> Self {
        Self {
            store_partial_results: store,
            ..self
        }
    }
}
/// What the workflow runs while it is in a given state.
pub enum StateEntry<TState, TStore = MemoryStore>
where
    TState: Clone + Send + Sync + 'static,
    TStore: Send + Sync + 'static,
{
    /// A single task whose result names the next state.
    Single {
        task: Arc<dyn Task<TState, TStore, DefaultTaskParams> + Send + Sync>,
    },
    /// Several tasks run concurrently and joined according to `join_config`.
    Split {
        tasks: Vec<Arc<dyn Task<TState, TStore, DefaultTaskParams> + Send + Sync>>,
        join_config: Arc<JoinConfig<TState>>,
    },
}

// Written by hand because `#[derive(Clone)]` would add a `TStore: Clone` bound; only the
// `Arc`s are cloned here.
impl<TState, TStore> Clone for StateEntry<TState, TStore>
where
    TState: Clone + Send + Sync + 'static,
    TStore: Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        match self {
            Self::Single { task } => Self::Single {
                task: Arc::clone(task),
            },
            Self::Split { tasks, join_config } => Self::Split {
                join_config: Arc::clone(join_config),
                tasks: tasks.to_vec(),
            },
        }
    }
}
/// A state-machine workflow over states of type `TState`.
///
/// Each registered state maps to a [`StateEntry`]: either a single task whose result names the
/// next state, or a split of tasks run concurrently and joined into a configured state.
/// [`Workflow::orchestrate`] moves from state to state until it reaches an exit state.
#[must_use]
pub struct Workflow<TState, TStore = MemoryStore>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
    TStore: KeyValueStore + 'static,
{
    /// Handler for every non-exit state.
    states: HashMap<TState, Arc<StateEntry<TState, TStore>>>,
    /// Store passed to every task; shared via `Arc` so spawned split tasks can use it.
    store: Arc<TStore>,
    /// Optional limit on the total duration of one `orchestrate` call.
    workflow_timeout: Option<Duration>,
    /// States at which orchestration stops successfully.
    exit_states: std::collections::HashSet<TState>,
    /// Span for the run; `orchestrate` creates a `workflow_orchestrate` span when `None`.
    #[cfg(feature = "tracing")]
    tracing_span: Option<Span>,
}

impl<TState, TStore> Workflow<TState, TStore>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
    TStore: KeyValueStore + 'static,
{
    /// Creates an empty workflow whose tasks will all share `store`.
    pub fn new(store: TStore) -> Self {
        Self {
            states: HashMap::new(),
            store: Arc::new(store),
            workflow_timeout: None,
            exit_states: std::collections::HashSet::new(),
            #[cfg(feature = "tracing")]
            tracing_span: None,
        }
    }

    /// Limits the total duration of each [`Workflow::orchestrate`] call.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.workflow_timeout = Some(timeout);
        self
    }

    /// Registers `task` as the handler for `state`, replacing any previous handler.
    pub fn register<T>(mut self, state: TState, task: T) -> Self
    where
        T: Task<TState, TStore, DefaultTaskParams> + Send + Sync + 'static,
    {
        self.states.insert(
            state,
            Arc::new(StateEntry::Single {
                task: Arc::new(task),
            }),
        );
        self
    }

    /// Registers `tasks` to run concurrently when the workflow enters `state`.
    ///
    /// Their outcomes are combined according to `join_config`. Any previous handler for `state`
    /// is replaced.
    pub fn register_split<T>(
        mut self,
        state: TState,
        tasks: Vec<T>,
        join_config: JoinConfig<TState>,
    ) -> Self
    where
        T: Task<TState, TStore, DefaultTaskParams> + Send + Sync + 'static,
    {
        let arc_tasks: Vec<Arc<dyn Task<TState, TStore, DefaultTaskParams> + Send + Sync>> =
            tasks.into_iter().map(|t| Arc::new(t) as Arc<_>).collect();
        self.states.insert(
            state,
            Arc::new(StateEntry::Split {
                tasks: arc_tasks,
                join_config: Arc::new(join_config),
            }),
        );
        self
    }

    /// Marks `state` as terminal: orchestration returns as soon as it is reached.
    pub fn add_exit_state(mut self, state: TState) -> Self {
        self.exit_states.insert(state);
        self
    }

    /// Marks every state in `states` as terminal.
    pub fn add_exit_states(mut self, states: Vec<TState>) -> Self {
        self.exit_states.extend(states);
        self
    }

    /// Uses `span` for orchestration instead of a fresh `workflow_orchestrate` span.
    #[cfg(feature = "tracing")]
    pub fn with_tracing_span(mut self, span: Span) -> Self {
        self.tracing_span = Some(span);
        self
    }

    /// Checks the workflow's static configuration.
    ///
    /// # Errors
    /// Returns a configuration error if any of the following hold:
    /// - no states are registered;
    /// - no exit states are defined;
    /// - a split's `join_state` is neither registered nor an exit state.
    pub fn validate(&self) -> Result<(), CanoError> {
        if self.states.is_empty() {
            return Err(CanoError::configuration(
                "Workflow has no registered state handlers",
            ));
        }
        if self.exit_states.is_empty() {
            return Err(CanoError::configuration(
                "Workflow has no exit states defined — orchestration may loop forever",
            ));
        }
        for entry in self.states.values() {
            if let StateEntry::Split { join_config, .. } = entry.as_ref() {
                let js = &join_config.join_state;
                if !self.states.contains_key(js) && !self.exit_states.contains(js) {
                    return Err(CanoError::configuration(format!(
                        "Split join_state {:?} is neither registered nor an exit state",
                        js
                    )));
                }
            }
        }
        Ok(())
    }

    /// Checks that `state` can start a run, i.e. it is registered or is an exit state.
    ///
    /// # Errors
    /// Returns a configuration error otherwise.
    pub fn validate_initial_state(&self, state: &TState) -> Result<(), CanoError> {
        if !self.states.contains_key(state) && !self.exit_states.contains(state) {
            return Err(CanoError::configuration(format!(
                "Initial state {:?} is neither registered nor an exit state",
                state
            )));
        }
        Ok(())
    }

    /// Runs the workflow from `initial_state` until an exit state is reached and returns it.
    ///
    /// With the `tracing` feature, the whole run is instrumented with the configured span, or
    /// with a new `workflow_orchestrate` span. `Instrument` is used rather than holding a
    /// `Span::enter` guard, because a guard held across `.await` points mis-attributes other
    /// work to the span and makes the future non-`Send`.
    ///
    /// # Errors
    /// Returns an error if any of the following happen:
    /// - the workflow timeout elapses;
    /// - a reached state has no handler;
    /// - a task fails;
    /// - a split's join condition is not met.
    pub async fn orchestrate(&self, initial_state: TState) -> Result<TState, CanoError> {
        #[cfg(feature = "tracing")]
        use tracing::Instrument;

        let run = async move {
            match self.workflow_timeout {
                Some(limit) => tokio::time::timeout(limit, self.execute_workflow(initial_state))
                    .await
                    .unwrap_or_else(|_| Err(CanoError::workflow("Workflow timeout exceeded"))),
                None => self.execute_workflow(initial_state).await,
            }
        };
        #[cfg(feature = "tracing")]
        let run = {
            let workflow_span = self
                .tracing_span
                .clone()
                .unwrap_or_else(|| info_span!("workflow_orchestrate"));
            run.instrument(workflow_span)
        };
        run.await
    }

    /// State-machine loop: runs the handler of the current state until an exit state is hit.
    async fn execute_workflow(&self, initial_state: TState) -> Result<TState, CanoError> {
        let mut current_state = initial_state;
        #[cfg(feature = "tracing")]
        info!(initial_state = ?current_state, "Starting workflow execution");
        loop {
            if self.exit_states.contains(&current_state) {
                #[cfg(feature = "tracing")]
                info!(final_state = ?current_state, "Workflow completed successfully");
                return Ok(current_state);
            }
            let state_entry = self.states.get(&current_state).ok_or_else(|| {
                CanoError::workflow(format!("No task registered for state: {:?}", current_state))
            })?;
            #[cfg(feature = "tracing")]
            debug!(current_state = ?current_state, "Executing state");
            current_state = match state_entry.as_ref() {
                StateEntry::Single { task } => {
                    self.execute_single_task(Arc::clone(task)).await?
                }
                StateEntry::Split { tasks, join_config } => {
                    self.execute_split_join(tasks.clone(), Arc::clone(join_config))
                        .await?
                }
            };
        }
    }

    /// Runs one task under its retry config and returns the next state it names.
    ///
    /// # Errors
    /// Propagates the task's final error. Returns a workflow error if the task produced a split
    /// result.
    async fn execute_single_task(
        &self,
        task: Arc<dyn Task<TState, TStore, DefaultTaskParams> + Send + Sync>,
    ) -> Result<TState, CanoError> {
        use crate::task::run_with_retries;
        #[cfg(feature = "tracing")]
        use tracing::Instrument;

        let config = task.config();
        let attempt = run_with_retries(&config, || {
            let task_clone = Arc::clone(&task);
            let store_clone = Arc::clone(&self.store);
            async move { task_clone.run(&*store_clone).await }
        });
        #[cfg(feature = "tracing")]
        let attempt = attempt.instrument(info_span!("single_task_execution"));

        match attempt.await? {
            TaskResult::Single(next_state) => {
                #[cfg(feature = "tracing")]
                debug!(next_state = ?next_state, "Single task completed");
                Ok(next_state)
            }
            TaskResult::Split(_) => Err(CanoError::workflow(
                "Single task returned split result - use register_split() for split tasks",
            )),
        }
    }

    /// Rejects join configurations that cannot be evaluated correctly.
    ///
    /// `PartialTimeout` needs a timeout, and `Percentage(p)` needs `p` in `(0.0, 1.0]`. NaN is
    /// tested explicitly because it fails every comparison, so it would slip past a plain
    /// `p <= 0.0 || p > 1.0` range check.
    fn check_join_config(join_config: &JoinConfig<TState>) -> Result<(), CanoError> {
        match join_config.strategy {
            JoinStrategy::PartialTimeout if join_config.timeout.is_none() => {
                Err(CanoError::configuration(
                    "PartialTimeout strategy requires a timeout to be configured",
                ))
            }
            JoinStrategy::Percentage(p) if p.is_nan() || p <= 0.0 || p > 1.0 => {
                Err(CanoError::configuration(format!(
                    "Percentage strategy requires a value in (0.0, 1.0], got {p}"
                )))
            }
            _ => Ok(()),
        }
    }

    /// Runs a split state and returns its join state if the strategy is satisfied.
    ///
    /// Spawns one tokio task per entry in `tasks` and collects results per `join_config`.
    /// When `store_partial_results` is set, it writes these summary keys to the store:
    /// `split_results_summary`, `split_successes_count`, `split_errors_count` and
    /// `split_cancelled_count`.
    ///
    /// # Errors
    /// - Configuration errors for an invalid join config.
    /// - Workflow errors for a join timeout or an unmet join condition.
    /// - Store errors from writing the summary.
    async fn execute_split_join(
        &self,
        tasks: Vec<Arc<dyn Task<TState, TStore, DefaultTaskParams> + Send + Sync>>,
        join_config: Arc<JoinConfig<TState>>,
    ) -> Result<TState, CanoError> {
        use crate::task::run_with_retries;
        #[cfg(feature = "tracing")]
        use tracing::Instrument;

        let total_tasks = tasks.len();
        #[cfg(feature = "tracing")]
        info!(
            total_tasks = total_tasks,
            strategy = ?join_config.strategy,
            "Starting split execution"
        );
        Self::check_join_config(&join_config)?;

        let mut handles = Vec::with_capacity(total_tasks);
        #[cfg_attr(not(feature = "tracing"), allow(unused_variables))]
        for (idx, task) in tasks.into_iter().enumerate() {
            let config = task.config();
            let store = Arc::clone(&self.store);
            let work = async move {
                #[cfg(feature = "tracing")]
                debug!(task_id = idx, "Executing split task");
                let result = run_with_retries(&config, || {
                    let t = Arc::clone(&task);
                    let s = Arc::clone(&store);
                    async move { t.run(&*s).await }
                })
                .await;
                #[cfg(feature = "tracing")]
                match &result {
                    Ok(_) => debug!(task_id = idx, "Split task completed successfully"),
                    Err(e) => warn!(task_id = idx, error = %e, "Split task failed"),
                }
                result
            };
            // Instrumenting (instead of entering a span inside the task) keeps the spawned
            // future `Send` and scopes the span to this task's polls only.
            #[cfg(feature = "tracing")]
            let work = work.instrument(info_span!("split_task", task_id = idx));
            handles.push(tokio::spawn(work));
        }

        let split_result = self
            .collect_results(handles, &join_config, total_tasks)
            .await?;
        let successful = split_result.successes.len();
        let failed = split_result.errors.len();
        let cancelled = split_result.cancelled.len();
        #[cfg(feature = "tracing")]
        info!(
            successful = successful,
            failed = failed,
            cancelled = cancelled,
            total = total_tasks,
            "Split execution completed"
        );

        if join_config.store_partial_results {
            let summary = format!(
                "Split results: {} succeeded, {} failed, {} cancelled",
                successful, failed, cancelled
            );
            self.store.put("split_results_summary", summary)?;
            self.store.put("split_successes_count", successful)?;
            self.store.put("split_errors_count", failed)?;
            self.store.put("split_cancelled_count", cancelled)?;
        }

        match &join_config.strategy {
            JoinStrategy::PartialResults(min) => {
                if join_config.strategy.is_satisfied(successful, total_tasks) {
                    Ok(join_config.join_state.clone())
                } else {
                    Err(CanoError::workflow(format!(
                        "Partial results condition not met: {} completed successfully, {} required",
                        successful, min
                    )))
                }
            }
            // Counts every task that reported back (success or failure) before the deadline.
            JoinStrategy::PartialTimeout => {
                if split_result.completed_count() >= 1 {
                    Ok(join_config.join_state.clone())
                } else {
                    Err(CanoError::workflow(
                        "PartialTimeout: No tasks completed before timeout",
                    ))
                }
            }
            strategy => {
                if strategy.is_satisfied(successful, total_tasks) {
                    Ok(join_config.join_state.clone())
                } else {
                    Err(CanoError::workflow(format!(
                        "Join condition not met: {} of {} tasks completed successfully, strategy: {:?}",
                        successful, total_tasks, strategy
                    )))
                }
            }
        }
    }

    /// Awaits the spawned split tasks and sorts their outcomes into a [`SplitResult`].
    ///
    /// Collection stops early for `Any` (first success) and for `PartialResults(min)` (after
    /// `min` successes). If `join_config.timeout` elapses, `PartialTimeout` keeps what has
    /// finished so far, and every other strategy fails. Tasks that never reported are listed as
    /// cancelled. They are aborted when their wrapped handles are dropped on return.
    ///
    /// # Errors
    /// Returns a workflow error when the join timeout elapses for any strategy other than
    /// `PartialTimeout`.
    async fn collect_results(
        &self,
        handles: Vec<tokio::task::JoinHandle<Result<TaskResult<TState>, CanoError>>>,
        join_config: &JoinConfig<TState>,
        total_tasks: usize,
    ) -> Result<SplitResult<TState>, CanoError> {
        let mut split_result = SplitResult::new();
        let mut completed_indices = std::collections::HashSet::new();
        let mut futures = FuturesUnordered::new();
        for (idx, handle) in handles.into_iter().enumerate() {
            // Wrap before pushing. If the guard were built inside the async block, a future that
            // was dropped without ever being polled would detach its task instead of aborting it.
            let guarded = AbortOnDrop(handle);
            futures.push(async move { (idx, guarded.await) });
        }

        let deadline = join_config.timeout.map(|d| tokio::time::Instant::now() + d);
        loop {
            let next_result = match deadline {
                Some(d) => match tokio::time::timeout_at(d, futures.next()).await {
                    Ok(res) => res,
                    Err(_) if matches!(join_config.strategy, JoinStrategy::PartialTimeout) => {
                        break;
                    }
                    Err(_) => return Err(CanoError::workflow("Split task timeout exceeded")),
                },
                None => futures.next().await,
            };
            let Some((index, result)) = next_result else {
                break;
            };
            completed_indices.insert(index);
            match result {
                Ok(Ok(task_result)) => split_result.successes.push(SplitTaskResult {
                    task_index: index,
                    result: Ok(task_result),
                }),
                Ok(Err(e)) => split_result.errors.push(SplitTaskResult {
                    task_index: index,
                    result: Err(e),
                }),
                Err(e) => split_result.errors.push(SplitTaskResult {
                    task_index: index,
                    result: Err(CanoError::workflow(format!("Task panic: {:?}", e))),
                }),
            }
            let done_early = match &join_config.strategy {
                JoinStrategy::Any => !split_result.successes.is_empty(),
                JoinStrategy::PartialResults(min) => split_result.successes.len() >= *min,
                _ => false,
            };
            if done_early {
                break;
            }
        }

        split_result
            .cancelled
            .extend((0..total_tasks).filter(|idx| !completed_indices.contains(idx)));
        Ok(split_result)
    }
}
impl<TState, TStore> Clone for Workflow<TState, TStore>
where
TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
TStore: KeyValueStore + 'static,
{
fn clone(&self) -> Self {
Self {
states: self.states.clone(),
store: self.store.clone(),
workflow_timeout: self.workflow_timeout,
exit_states: self.exit_states.clone(),
#[cfg(feature = "tracing")]
tracing_span: self.tracing_span.clone(),
}
}
}
// Summarises the registered handlers by count (tasks are trait objects without `Debug`) and
// omits the store and tracing span.
impl<TState, TStore> std::fmt::Debug for Workflow<TState, TStore>
where
    TState: Clone + std::fmt::Debug + std::hash::Hash + Eq + Send + Sync + 'static,
    TStore: KeyValueStore + 'static,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state_summary = format!("{} states", self.states.len());
        let mut out = f.debug_struct("Workflow");
        out.field("states", &state_summary);
        out.field("exit_states", &self.exit_states);
        out.field("workflow_timeout", &self.workflow_timeout);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::store::KeyValueStore;
use crate::task::Task;
use async_trait::async_trait;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio;
/// States used by the workflow tests.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum TestState {
    Start,
    Process,
    Split,
    Join,
    Complete,
    // Not constructed by the tests; `allow` silences the dead-code lint.
    #[allow(dead_code)]
    Error,
}

/// Task that bumps a shared counter and then moves to a fixed next state.
#[derive(Clone)]
struct SimpleTask {
    next_state: TestState,
    counter: Arc<AtomicU32>,
}

impl SimpleTask {
    fn new(next_state: TestState) -> Self {
        let counter = Arc::new(AtomicU32::new(0));
        Self {
            counter,
            next_state,
        }
    }

    /// How many times `run` has been called on this task or any of its clones.
    #[allow(dead_code)]
    fn count(&self) -> u32 {
        self.counter.load(Ordering::SeqCst)
    }
}

#[async_trait]
impl Task<TestState> for SimpleTask {
    async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
        self.counter.fetch_add(1, Ordering::SeqCst);
        let next = self.next_state.clone();
        Ok(TaskResult::Single(next))
    }
}

/// Task that writes `value` under `key` in the store and then moves to `next_state`.
#[derive(Clone)]
struct DataTask {
    key: String,
    value: String,
    next_state: TestState,
}

impl DataTask {
    fn new(key: &str, value: &str, next_state: TestState) -> Self {
        Self {
            key: String::from(key),
            value: String::from(value),
            next_state,
        }
    }
}

#[async_trait]
impl Task<TestState> for DataTask {
    async fn run(&self, store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
        store.put(&self.key, self.value.clone())?;
        let next = self.next_state.clone();
        Ok(TaskResult::Single(next))
    }
}

/// Task that either fails with a task-execution error or moves to `Complete`.
#[derive(Clone)]
struct FailTask {
    should_fail: bool,
}

impl FailTask {
    fn new(should_fail: bool) -> Self {
        FailTask { should_fail }
    }
}

#[async_trait]
impl Task<TestState> for FailTask {
    async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
        if self.should_fail {
            return Err(CanoError::task_execution("Task intentionally failed"));
        }
        Ok(TaskResult::Single(TestState::Complete))
    }
}
#[tokio::test]
async fn test_workflow_creation() {
    let workflow = Workflow::<TestState>::new(MemoryStore::new());
    assert!(workflow.states.is_empty());
    assert!(workflow.exit_states.is_empty());
}

#[tokio::test]
async fn test_simple_workflow() {
    let workflow = Workflow::new(MemoryStore::new())
        .register(TestState::Start, SimpleTask::new(TestState::Complete))
        .add_exit_state(TestState::Complete);
    let final_state = workflow.orchestrate(TestState::Start).await.unwrap();
    assert_eq!(final_state, TestState::Complete);
}

#[tokio::test]
async fn test_multi_step_workflow() {
    // Registration order does not matter; Start -> Process -> Complete.
    let workflow = Workflow::new(MemoryStore::new())
        .register(TestState::Process, SimpleTask::new(TestState::Complete))
        .register(TestState::Start, SimpleTask::new(TestState::Process))
        .add_exit_state(TestState::Complete);
    assert_eq!(
        workflow.orchestrate(TestState::Start).await.unwrap(),
        TestState::Complete
    );
}

#[tokio::test]
async fn test_workflow_with_data() {
    let store = MemoryStore::new();
    let writer = DataTask::new("test_key", "test_value", TestState::Complete);
    let workflow = Workflow::new(store.clone())
        .register(TestState::Start, writer)
        .add_exit_state(TestState::Complete);
    assert_eq!(
        workflow.orchestrate(TestState::Start).await.unwrap(),
        TestState::Complete
    );
    let stored: String = store.get("test_key").unwrap();
    assert_eq!(stored, "test_value");
}
#[tokio::test]
async fn test_split_all_strategy() {
let store = MemoryStore::new();
let tasks = vec![
SimpleTask::new(TestState::Join),
SimpleTask::new(TestState::Join),
SimpleTask::new(TestState::Join),
];
let join_config = JoinConfig::new(JoinStrategy::All, TestState::Complete);
let workflow = Workflow::new(store)
.register_split(TestState::Start, tasks, join_config)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
}
#[tokio::test]
async fn test_split_any_strategy() {
let store = MemoryStore::new();
let tasks = vec![
SimpleTask::new(TestState::Join),
SimpleTask::new(TestState::Join),
SimpleTask::new(TestState::Join),
];
let join_config = JoinConfig::new(JoinStrategy::Any, TestState::Complete);
let workflow = Workflow::new(store)
.register_split(TestState::Start, tasks, join_config)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
}
#[tokio::test]
async fn test_split_quorum_strategy() {
let store = MemoryStore::new();
let tasks = vec![
SimpleTask::new(TestState::Join),
SimpleTask::new(TestState::Join),
SimpleTask::new(TestState::Join),
SimpleTask::new(TestState::Join),
];
let join_config = JoinConfig::new(JoinStrategy::Quorum(3), TestState::Complete);
let workflow = Workflow::new(store)
.register_split(TestState::Start, tasks, join_config)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
}
#[tokio::test]
async fn test_split_percentage_strategy() {
let store = MemoryStore::new();
let tasks = vec![
SimpleTask::new(TestState::Join),
SimpleTask::new(TestState::Join),
SimpleTask::new(TestState::Join),
SimpleTask::new(TestState::Join),
];
let join_config = JoinConfig::new(JoinStrategy::Percentage(0.75), TestState::Complete);
let workflow = Workflow::new(store)
.register_split(TestState::Start, tasks, join_config)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
}
#[tokio::test]
async fn test_split_with_failures_all_strategy() {
let store = MemoryStore::new();
let tasks = vec![
FailTask::new(false),
FailTask::new(true), FailTask::new(false),
];
let join_config = JoinConfig::new(JoinStrategy::All, TestState::Complete);
let workflow = Workflow::new(store)
.register_split(TestState::Start, tasks, join_config)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_split_with_failures_quorum_strategy() {
let store = MemoryStore::new();
let tasks = vec![
FailTask::new(false),
FailTask::new(false),
FailTask::new(true), ];
let join_config = JoinConfig::new(JoinStrategy::Quorum(2), TestState::Complete);
let workflow = Workflow::new(store)
.register_split(TestState::Start, tasks, join_config)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
}
#[tokio::test]
async fn test_split_with_timeout() {
let store = MemoryStore::new();
#[derive(Clone)]
struct SlowTask;
#[async_trait]
impl Task<TestState> for SlowTask {
async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
tokio::time::sleep(Duration::from_millis(200)).await;
Ok(TaskResult::Single(TestState::Complete))
}
}
let tasks = vec![SlowTask, SlowTask];
let join_config = JoinConfig::new(JoinStrategy::All, TestState::Complete)
.with_timeout(Duration::from_millis(50));
let workflow = Workflow::new(store)
.register_split(TestState::Start, tasks, join_config)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("timeout"));
}
#[tokio::test]
async fn test_workflow_timeout() {
let store = MemoryStore::new();
#[derive(Clone)]
struct SlowTask;
#[async_trait]
impl Task<TestState> for SlowTask {
async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
tokio::time::sleep(Duration::from_millis(200)).await;
Ok(TaskResult::Single(TestState::Complete))
}
}
let workflow = Workflow::new(store)
.with_timeout(Duration::from_millis(50))
.register(TestState::Start, SlowTask)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Workflow timeout"));
}
#[tokio::test]
async fn test_split_with_data_sharing() {
let store = MemoryStore::new();
let tasks = vec![
DataTask::new("task1", "value1", TestState::Join),
DataTask::new("task2", "value2", TestState::Join),
DataTask::new("task3", "value3", TestState::Join),
];
let join_config = JoinConfig::new(JoinStrategy::All, TestState::Complete);
let workflow = Workflow::new(store.clone())
.register_split(TestState::Start, tasks, join_config)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
let data1: String = store.get("task1").unwrap();
let data2: String = store.get("task2").unwrap();
let data3: String = store.get("task3").unwrap();
assert_eq!(data1, "value1");
assert_eq!(data2, "value2");
assert_eq!(data3, "value3");
}
#[tokio::test]
async fn test_complex_workflow_with_split_join() {
let store = MemoryStore::new();
let split_tasks = vec![
DataTask::new("parallel1", "data1", TestState::Join),
DataTask::new("parallel2", "data2", TestState::Join),
];
let join_config = JoinConfig::new(JoinStrategy::All, TestState::Process);
let workflow = Workflow::new(store.clone())
.register(
TestState::Start,
DataTask::new("init", "initialized", TestState::Split),
)
.register_split(TestState::Split, split_tasks, join_config)
.register(TestState::Process, SimpleTask::new(TestState::Complete))
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
let init: String = store.get("init").unwrap();
let parallel1: String = store.get("parallel1").unwrap();
let parallel2: String = store.get("parallel2").unwrap();
assert_eq!(init, "initialized");
assert_eq!(parallel1, "data1");
assert_eq!(parallel2, "data2");
}
#[tokio::test]
async fn test_join_strategy_is_satisfied() {
assert!(JoinStrategy::All.is_satisfied(3, 3));
assert!(!JoinStrategy::All.is_satisfied(2, 3));
assert!(JoinStrategy::Any.is_satisfied(1, 3));
assert!(!JoinStrategy::Any.is_satisfied(0, 3));
assert!(JoinStrategy::Quorum(2).is_satisfied(2, 3));
assert!(JoinStrategy::Quorum(2).is_satisfied(3, 3));
assert!(!JoinStrategy::Quorum(2).is_satisfied(1, 3));
assert!(JoinStrategy::Percentage(0.5).is_satisfied(2, 4));
assert!(JoinStrategy::Percentage(0.75).is_satisfied(3, 4));
assert!(!JoinStrategy::Percentage(0.75).is_satisfied(2, 4));
assert!(JoinStrategy::PartialResults(2).is_satisfied(2, 4));
assert!(JoinStrategy::PartialResults(2).is_satisfied(3, 4));
assert!(!JoinStrategy::PartialResults(2).is_satisfied(1, 4));
assert!(JoinStrategy::PartialTimeout.is_satisfied(1, 4));
assert!(JoinStrategy::PartialTimeout.is_satisfied(3, 4));
assert!(!JoinStrategy::PartialTimeout.is_satisfied(0, 4));
}
#[tokio::test]
async fn test_partial_results_strategy() {
    #[derive(Clone)]
    struct DelayedTask {
        delay_ms: u64,
        #[allow(dead_code)]
        task_id: usize,
    }

    #[async_trait]
    impl Task<TestState> for DelayedTask {
        async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
            tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
            Ok(TaskResult::Single(TestState::Complete))
        }
    }

    let store = MemoryStore::new();
    // The two fast tasks satisfy PartialResults(2); the slow pair is cancelled.
    let tasks: Vec<DelayedTask> = [50, 100, 500, 600]
        .into_iter()
        .zip(1..)
        .map(|(delay_ms, task_id)| DelayedTask { delay_ms, task_id })
        .collect();
    let join_config = JoinConfig::new(JoinStrategy::PartialResults(2), TestState::Complete)
        .with_store_partial_results(true);
    let workflow = Workflow::new(store.clone())
        .register_split(TestState::Start, tasks, join_config)
        .add_exit_state(TestState::Complete);

    assert_eq!(
        workflow.orchestrate(TestState::Start).await.unwrap(),
        TestState::Complete
    );
    let successes: usize = store.get("split_successes_count").unwrap();
    let cancelled: usize = store.get("split_cancelled_count").unwrap();
    assert_eq!((successes, cancelled), (2, 2));
}

#[tokio::test]
async fn test_partial_results_with_failures() {
    #[derive(Clone)]
    struct MixedTask {
        delay_ms: u64,
        should_fail: bool,
    }

    #[async_trait]
    impl Task<TestState> for MixedTask {
        fn config(&self) -> crate::task::TaskConfig {
            crate::task::TaskConfig::minimal()
        }

        async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
            tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
            if !self.should_fail {
                return Ok(TaskResult::Single(TestState::Complete));
            }
            Err(CanoError::task_execution("Task failed"))
        }
    }

    let store = MemoryStore::new();
    let tasks: Vec<MixedTask> = [(50, false), (100, true), (500, false), (600, false)]
        .into_iter()
        .map(|(delay_ms, should_fail)| MixedTask {
            delay_ms,
            should_fail,
        })
        .collect();
    let join_config = JoinConfig::new(JoinStrategy::PartialResults(2), TestState::Complete)
        .with_store_partial_results(true);
    let workflow = Workflow::new(store.clone())
        .register_split(TestState::Start, tasks, join_config)
        .add_exit_state(TestState::Complete);

    assert_eq!(
        workflow.orchestrate(TestState::Start).await.unwrap(),
        TestState::Complete
    );
    let successes: usize = store.get("split_successes_count").unwrap();
    let errors: usize = store.get("split_errors_count").unwrap();
    let cancelled: usize = store.get("split_cancelled_count").unwrap();
    assert_eq!((successes, errors, cancelled), (2, 1, 1));
}

#[tokio::test]
async fn test_partial_results_minimum_not_met() {
    #[derive(Clone)]
    struct SlowTask;

    #[async_trait]
    impl Task<TestState> for SlowTask {
        async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
            tokio::time::sleep(Duration::from_millis(500)).await;
            Ok(TaskResult::Single(TestState::Complete))
        }
    }

    let join_config = JoinConfig::new(JoinStrategy::PartialResults(3), TestState::Complete)
        .with_timeout(Duration::from_millis(100));
    let workflow = Workflow::new(MemoryStore::new())
        .register_split(TestState::Start, vec![SlowTask; 3], join_config)
        .add_exit_state(TestState::Complete);
    let err = workflow.orchestrate(TestState::Start).await.unwrap_err();
    assert!(err.to_string().contains("timeout"));
}
#[tokio::test]
async fn test_partial_timeout_strategy() {
let store = MemoryStore::new();
#[derive(Clone)]
struct DelayedTask {
delay_ms: u64,
#[allow(dead_code)]
task_id: usize,
}
#[async_trait]
impl Task<TestState> for DelayedTask {
async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
Ok(TaskResult::Single(TestState::Complete))
}
}
let tasks = vec![
DelayedTask {
delay_ms: 50,
task_id: 1,
},
DelayedTask {
delay_ms: 100,
task_id: 2,
},
DelayedTask {
delay_ms: 500,
task_id: 3,
}, DelayedTask {
delay_ms: 600,
task_id: 4,
}, ];
let join_config = JoinConfig::new(JoinStrategy::PartialTimeout, TestState::Complete)
.with_timeout(Duration::from_millis(200))
.with_store_partial_results(true);
let workflow = Workflow::new(store.clone())
.register_split(TestState::Start, tasks, join_config)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
let successes: usize = store.get("split_successes_count").unwrap();
let cancelled: usize = store.get("split_cancelled_count").unwrap();
assert_eq!(successes, 2);
assert_eq!(cancelled, 2);
}
#[tokio::test]
async fn test_partial_timeout_with_failures() {
let store = MemoryStore::new();
#[derive(Clone)]
struct MixedTask {
delay_ms: u64,
should_fail: bool,
}
#[async_trait]
impl Task<TestState> for MixedTask {
fn config(&self) -> crate::task::TaskConfig {
crate::task::TaskConfig::minimal()
}
async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
if self.should_fail {
Err(CanoError::task_execution("Task failed"))
} else {
Ok(TaskResult::Single(TestState::Complete))
}
}
}
let tasks = vec![
MixedTask {
delay_ms: 50,
should_fail: false,
}, MixedTask {
delay_ms: 100,
should_fail: true,
}, MixedTask {
delay_ms: 150,
should_fail: false,
}, MixedTask {
delay_ms: 500,
should_fail: false,
}, ];
let join_config = JoinConfig::new(JoinStrategy::PartialTimeout, TestState::Complete)
.with_timeout(Duration::from_millis(200))
.with_store_partial_results(true);
let workflow = Workflow::new(store.clone())
.register_split(TestState::Start, tasks, join_config)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await.unwrap();
assert_eq!(result, TestState::Complete);
let successes: usize = store.get("split_successes_count").unwrap();
let errors: usize = store.get("split_errors_count").unwrap();
let cancelled: usize = store.get("split_cancelled_count").unwrap();
assert_eq!(successes, 2);
assert_eq!(errors, 1);
assert_eq!(cancelled, 1);
}
/// Under `PartialTimeout`, a split whose tasks all finish well inside the
/// configured timeout records every task as a success and none as cancelled.
#[tokio::test]
async fn test_partial_timeout_all_complete() {
    // Sleeps for `delay_ms`, then reports `Complete`.
    #[derive(Clone)]
    struct DelayedTask {
        delay_ms: u64,
    }

    #[async_trait]
    impl Task<TestState> for DelayedTask {
        async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
            let pause = Duration::from_millis(self.delay_ms);
            tokio::time::sleep(pause).await;
            Ok(TaskResult::Single(TestState::Complete))
        }
    }

    let store = MemoryStore::new();

    // The slowest task (40 ms) is far below the 500 ms timeout.
    let tasks: Vec<DelayedTask> = vec![20, 30, 40]
        .into_iter()
        .map(|delay_ms| DelayedTask { delay_ms })
        .collect();

    let join_config = JoinConfig::new(JoinStrategy::PartialTimeout, TestState::Complete)
        .with_timeout(Duration::from_millis(500))
        .with_store_partial_results(true);

    let workflow = Workflow::new(store.clone())
        .register_split(TestState::Start, tasks, join_config)
        .add_exit_state(TestState::Complete);

    let final_state = workflow.orchestrate(TestState::Start).await.unwrap();
    assert_eq!(final_state, TestState::Complete);

    // Partial-result counters are written to the shared store.
    let successes: usize = store.get("split_successes_count").unwrap();
    let cancelled: usize = store.get("split_cancelled_count").unwrap();
    assert_eq!(successes, 3);
    assert_eq!(cancelled, 0);
}
/// `PartialTimeout` without `with_timeout` makes the run fail, and the error
/// message mentions the missing timeout.
#[tokio::test]
async fn test_partial_timeout_no_timeout_configured() {
    // Trivial task that immediately reports `Complete`.
    #[derive(Clone)]
    struct InstantTask;

    #[async_trait]
    impl Task<TestState> for InstantTask {
        async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
            Ok(TaskResult::Single(TestState::Complete))
        }
    }

    let store = MemoryStore::new();
    // Deliberately no `.with_timeout(..)` here.
    let join_config = JoinConfig::new(JoinStrategy::PartialTimeout, TestState::Complete);
    let workflow = Workflow::new(store)
        .register_split(TestState::Start, vec![InstantTask, InstantTask], join_config)
        .add_exit_state(TestState::Complete);

    let outcome = workflow.orchestrate(TestState::Start).await;
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("requires a timeout"));
}
/// Starting from a state that has no handler and is not an exit state fails
/// with a "No task registered" error.
#[tokio::test]
async fn test_unregistered_state_error() {
    // Only `Complete` is known to the workflow; `Start` has no handler.
    let workflow =
        Workflow::<TestState>::new(MemoryStore::new()).add_exit_state(TestState::Complete);

    let outcome = workflow.orchestrate(TestState::Start).await;
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("No task registered"));
}
/// `validate` rejects a workflow with no state handlers at all.
#[test]
fn test_validate_empty_workflow() {
    let empty = Workflow::<TestState>::new(MemoryStore::new());

    let verdict = empty.validate();
    assert!(verdict.is_err());
    let message = verdict.unwrap_err().to_string();
    assert!(message.contains("no registered state handlers"));
}
/// `validate` rejects a workflow that has a handler but no exit state.
#[test]
fn test_validate_no_exit_states() {
    let without_exit = Workflow::new(MemoryStore::new())
        .register(TestState::Start, SimpleTask::new(TestState::Complete));

    let verdict = without_exit.validate();
    assert!(verdict.is_err());
    let message = verdict.unwrap_err().to_string();
    assert!(message.contains("no exit states defined"));
}
/// A handler registered for `Start` plus a declared exit state passes
/// validation.
#[test]
fn test_validate_valid_workflow() {
    let workflow = Workflow::new(MemoryStore::new())
        .register(TestState::Start, SimpleTask::new(TestState::Complete))
        .add_exit_state(TestState::Complete);

    let verdict = workflow.validate();
    assert!(verdict.is_ok());
}
/// `validate` flags a split whose join target (`Process`) is neither
/// registered nor declared as an exit state.
#[test]
fn test_validate_split_join_state_unregistered() {
    let workflow = Workflow::new(MemoryStore::new())
        .register_split(
            TestState::Start,
            vec![SimpleTask::new(TestState::Join)],
            // `Process` is the join target, but nothing handles it.
            JoinConfig::new(JoinStrategy::All, TestState::Process),
        )
        .add_exit_state(TestState::Complete);

    let verdict = workflow.validate();
    assert!(verdict.is_err());
    let message = verdict.unwrap_err().to_string();
    assert!(message.contains("join_state"));
}
/// A split may join directly into a declared exit state without a handler
/// being registered for that state.
#[test]
fn test_validate_split_join_state_as_exit_state() {
    let workflow = Workflow::new(MemoryStore::new())
        .register_split(
            TestState::Start,
            vec![SimpleTask::new(TestState::Complete)],
            // Join target is the exit state itself.
            JoinConfig::new(JoinStrategy::All, TestState::Complete),
        )
        .add_exit_state(TestState::Complete);

    let verdict = workflow.validate();
    assert!(verdict.is_ok());
}
/// `validate_initial_state` accepts a registered state or an exit state, and
/// rejects any state that is neither.
#[test]
fn test_validate_initial_state() {
    let workflow = Workflow::new(MemoryStore::new())
        .register(TestState::Start, SimpleTask::new(TestState::Complete))
        .add_exit_state(TestState::Complete);

    // `Start` has a handler; `Complete` is an exit state.
    for accepted in [TestState::Start, TestState::Complete] {
        assert!(workflow.validate_initial_state(&accepted).is_ok());
    }

    // `Process` is unknown to this workflow.
    let rejected = workflow.validate_initial_state(&TestState::Process);
    assert!(rejected.is_err());
    let message = rejected.unwrap_err().to_string();
    assert!(message.contains("neither registered nor an exit state"));
}
/// A split registered with zero tasks still reaches its join state under
/// `JoinStrategy::All`.
#[tokio::test]
async fn test_empty_split_task_list() {
    let store = MemoryStore::new();
    let no_tasks: Vec<SimpleTask> = Vec::new();
    let join_config = JoinConfig::new(JoinStrategy::All, TestState::Complete);
    let workflow = Workflow::new(store)
        .register_split(TestState::Start, no_tasks, join_config)
        .add_exit_state(TestState::Complete);

    let final_state = workflow.orchestrate(TestState::Start).await.unwrap();
    assert_eq!(final_state, TestState::Complete);
}
/// `JoinStrategy::Percentage(0.0)` causes the run to fail with a
/// `CanoError::Configuration`.
#[tokio::test]
async fn test_percentage_zero() {
    let store = MemoryStore::new();
    // Presumably both tasks fail (`FailTask::new(true)`); TODO confirm against FailTask.
    let failing = vec![FailTask::new(true), FailTask::new(true)];
    let zero_pct = JoinConfig::new(JoinStrategy::Percentage(0.0), TestState::Complete);
    let workflow = Workflow::new(store)
        .register_split(TestState::Start, failing, zero_pct)
        .add_exit_state(TestState::Complete);

    let outcome = workflow.orchestrate(TestState::Start).await;
    let err = outcome.unwrap_err();
    assert!(
        matches!(err, CanoError::Configuration(_)),
        "expected Configuration error, got {err:?}"
    );
}
/// `JoinStrategy::Percentage(1.0)` succeeds when every task succeeds, and the
/// run fails when one of two tasks fails.
#[tokio::test]
async fn test_percentage_one() {
    // Case 1: both tasks succeed, so the join is satisfied.
    {
        let store = MemoryStore::new();
        let tasks = vec![
            SimpleTask::new(TestState::Join),
            SimpleTask::new(TestState::Join),
        ];
        let join_config = JoinConfig::new(JoinStrategy::Percentage(1.0), TestState::Complete);
        let workflow = Workflow::new(store)
            .register_split(TestState::Start, tasks, join_config)
            .add_exit_state(TestState::Complete);

        let final_state = workflow.orchestrate(TestState::Start).await.unwrap();
        assert_eq!(final_state, TestState::Complete);
    }

    // Case 2: one of two tasks fails (assuming `FailTask::new(true)` fails),
    // so the run returns an error.
    {
        let store = MemoryStore::new();
        let tasks = vec![FailTask::new(false), FailTask::new(true)];
        let join_config = JoinConfig::new(JoinStrategy::Percentage(1.0), TestState::Complete);
        let workflow = Workflow::new(store)
            .register_split(TestState::Start, tasks, join_config)
            .add_exit_state(TestState::Complete);

        let outcome = workflow.orchestrate(TestState::Start).await;
        assert!(outcome.is_err());
    }
}
/// A `JoinStrategy::Percentage` above 1.0 (here 1.5) causes the run to fail
/// with a `CanoError::Configuration`.
#[tokio::test]
async fn test_percentage_over_one() {
    let store = MemoryStore::new();
    // Presumably one succeeding and one failing task; TODO confirm against FailTask.
    let mixed = vec![FailTask::new(false), FailTask::new(true)];
    let over_full = JoinConfig::new(JoinStrategy::Percentage(1.5), TestState::Complete);
    let workflow = Workflow::new(store)
        .register_split(TestState::Start, mixed, over_full)
        .add_exit_state(TestState::Complete);

    let outcome = workflow.orchestrate(TestState::Start).await;
    let err = outcome.unwrap_err();
    assert!(
        matches!(err, CanoError::Configuration(_)),
        "expected Configuration error, got {err:?}"
    );
}
/// `JoinStrategy::Quorum(0)` is accepted, and the run reaches the join state
/// even though every task is built with `FailTask::new(true)`.
#[tokio::test]
async fn test_quorum_zero() {
    let store = MemoryStore::new();
    let failing = vec![FailTask::new(true), FailTask::new(true)];
    let no_quorum = JoinConfig::new(JoinStrategy::Quorum(0), TestState::Complete);
    let workflow = Workflow::new(store)
        .register_split(TestState::Start, failing, no_quorum)
        .add_exit_state(TestState::Complete);

    let final_state = workflow.orchestrate(TestState::Start).await.unwrap();
    assert_eq!(final_state, TestState::Complete);
}
/// A split containing exactly one task behaves like a normal split and
/// reaches the join state.
#[tokio::test]
async fn test_single_task_register_split() {
    let store = MemoryStore::new();
    let only_task = vec![SimpleTask::new(TestState::Complete)];
    let join_config = JoinConfig::new(JoinStrategy::All, TestState::Complete);
    let workflow = Workflow::new(store)
        .register_split(TestState::Start, only_task, join_config)
        .add_exit_state(TestState::Complete);

    let final_state = workflow.orchestrate(TestState::Start).await.unwrap();
    assert_eq!(final_state, TestState::Complete);
}
/// Without any exit state, the run fails once it leaves `Start`. The next state
/// (presumably `Complete`, as returned by `SimpleTask`) has no handler, which
/// surfaces as "No task registered".
#[tokio::test]
async fn test_workflow_no_exit_states() {
    let workflow = Workflow::new(MemoryStore::new())
        .register(TestState::Start, SimpleTask::new(TestState::Complete));

    let outcome = workflow.orchestrate(TestState::Start).await;
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("No task registered"));
}
/// A task registered with plain `register` that returns `TaskResult::Split`
/// makes the run fail, and the error message points at `register_split`.
#[tokio::test]
async fn test_split_task_from_single_register() {
    // Always asks to fan out, which a non-split registration cannot honour.
    #[derive(Clone)]
    struct SplittingTask;

    #[async_trait]
    impl Task<TestState> for SplittingTask {
        async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
            let branches = vec![TestState::Complete];
            Ok(TaskResult::Split(branches))
        }
    }

    let workflow = Workflow::new(MemoryStore::new())
        .register(TestState::Start, SplittingTask)
        .add_exit_state(TestState::Complete);

    let outcome = workflow.orchestrate(TestState::Start).await;
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("register_split"));
}
#[tokio::test]
async fn test_node_in_workflow_no_double_retry() {
use crate::node::Node;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
struct CountingNode {
call_count: Arc<std::sync::atomic::AtomicUsize>,
}
#[async_trait]
impl Node<TestState> for CountingNode {
type PrepResult = ();
type ExecResult = ();
fn config(&self) -> crate::task::TaskConfig {
crate::task::TaskConfig::new().with_fixed_retry(2, Duration::from_millis(1))
}
async fn prep(&self, _store: &MemoryStore) -> Result<(), CanoError> {
self.call_count.fetch_add(1, Ordering::SeqCst);
Err(CanoError::preparation("always fails"))
}
async fn exec(&self, _: ()) -> () {}
async fn post(&self, _store: &MemoryStore, _: ()) -> Result<TestState, CanoError> {
Ok(TestState::Complete)
}
}
let call_count = Arc::new(AtomicUsize::new(0));
let node = CountingNode {
call_count: Arc::clone(&call_count),
};
let store = MemoryStore::new();
let workflow = Workflow::new(store)
.register(TestState::Start, node)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await;
assert!(result.is_err());
assert_eq!(
call_count.load(Ordering::SeqCst),
3,
"Node should be called exactly 3 times (1 + 2 retries), not double-retried"
);
}
#[tokio::test]
async fn test_split_task_retry_config_honoured() {
use std::sync::atomic::{AtomicUsize, Ordering};
#[derive(Clone)]
struct RetryCountingTask {
call_count: Arc<AtomicUsize>,
succeed_after: usize,
}
#[async_trait]
impl Task<TestState> for RetryCountingTask {
fn config(&self) -> crate::task::TaskConfig {
crate::task::TaskConfig::new()
.with_fixed_retry(4, std::time::Duration::from_millis(1))
}
async fn run(&self, _store: &MemoryStore) -> Result<TaskResult<TestState>, CanoError> {
let count = self.call_count.fetch_add(1, Ordering::SeqCst) + 1;
if count >= self.succeed_after {
Ok(TaskResult::Single(TestState::Complete))
} else {
Err(CanoError::task_execution("not ready yet"))
}
}
}
let call_count = Arc::new(AtomicUsize::new(0));
let tasks = vec![RetryCountingTask {
call_count: Arc::clone(&call_count),
succeed_after: 3, }];
let join_config = JoinConfig::new(JoinStrategy::All, TestState::Complete);
let store = MemoryStore::new();
let workflow = Workflow::new(store)
.register_split(TestState::Start, tasks, join_config)
.add_exit_state(TestState::Complete);
let result = workflow.orchestrate(TestState::Start).await;
assert!(result.is_ok(), "workflow should succeed after retries");
assert_eq!(
call_count.load(Ordering::SeqCst),
3,
"task should have been called exactly 3 times (2 failures + 1 success)"
);
}
}