use crate::context::WorkflowContext;
use crate::task::{RetryPolicy, UntypedCoreTask};
use sha2::{Digest, Sha256};
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::Arc;
/// What a `Loop` node does once it has run `max_iterations` times.
///
/// Note the two string forms: serde (de)serializes the Rust variant names
/// (`Fail`, `ExitWithLast`), while strum's `Display`/`FromStr` use the
/// snake_case strings `fail` and `exit_with_last`. The strum form is what
/// `SerializableContinuation::compute_definition_hash` feeds into the hash.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
    strum::EnumString,
    strum::Display,
)]
pub enum MaxIterationsPolicy {
    /// Presumably treats hitting the iteration limit as a failure.
    /// NOTE(review): enforced by the executor, not in this file — confirm.
    #[strum(serialize = "fail")]
    Fail,
    /// Presumably leaves the loop carrying the last iteration's result.
    /// NOTE(review): enforced by the executor, not in this file — confirm.
    #[strum(serialize = "exit_with_last")]
    ExitWithLast,
}
/// Generates `pub(crate) fn find_duplicate_id(&self) -> Option<String>` for a
/// continuation-shaped enum.
///
/// The generated method walks every node, including fork branches, branch
/// arms and defaults, loop bodies and child workflows. It returns the first id
/// that is encountered a second time, or `None` when all ids are unique.
///
/// Parameters:
/// * `$name` — the enum to implement the method for.
/// * `task_fields` / `delay_extra` — extra tokens appended to the `Task` and
///   `Delay`/`AwaitSignal` patterns (both call sites pass `..`).
/// * `deref_branch` — closure turning a fork-branch element into `&$name`.
/// * `deref_branch_map` — closure turning a branch-map value into `&$name`.
///
/// Branch arms are visited in key order. When several ids are duplicated, the
/// one reported therefore does not depend on `HashMap` iteration order.
macro_rules! impl_find_duplicate_id {
    ($name:ident, task_fields: { $($task_extra:tt)* }, delay_extra: { $($delay_extra:tt)* }, deref_branch: $deref:expr, deref_branch_map: $deref_map:expr) => {
        impl $name {
            pub(crate) fn find_duplicate_id(&self) -> Option<String> {
                // Depth-first walk; `seen` accumulates every id visited so far.
                fn collect(
                    cont: &$name,
                    seen: &mut ::std::collections::HashSet<String>,
                ) -> Option<String> {
                    match cont {
                        $name::Task { id, next, $($task_extra)* } => {
                            if !seen.insert(id.clone()) {
                                return Some(id.clone());
                            }
                            next.as_ref().and_then(|n| collect(n, seen))
                        }
                        $name::Fork { id, branches, join } => {
                            if !seen.insert(id.clone()) {
                                return Some(id.clone());
                            }
                            let deref_fn: fn(&_) -> &$name = $deref;
                            branches
                                .iter()
                                .find_map(|b| collect(deref_fn(b), seen))
                                .or_else(|| join.as_ref().and_then(|j| collect(j, seen)))
                        }
                        $name::Branch { id, branches, default, next, .. } => {
                            if !seen.insert(id.clone()) {
                                return Some(id.clone());
                            }
                            let deref_map_fn: fn(&_) -> &$name = $deref_map;
                            // Sort arms by key so the result is deterministic.
                            let mut arms: Vec<_> = branches.iter().collect();
                            arms.sort_by(|a, b| a.0.cmp(b.0));
                            arms.into_iter()
                                .find_map(|(_, b)| collect(deref_map_fn(b), seen))
                                .or_else(|| default.as_ref().and_then(|d| collect(d, seen)))
                                .or_else(|| next.as_ref().and_then(|n| collect(n, seen)))
                        }
                        $name::Delay { id, next, $($delay_extra)* }
                        | $name::AwaitSignal { id, next, $($delay_extra)* } => {
                            if !seen.insert(id.clone()) {
                                return Some(id.clone());
                            }
                            next.as_ref().and_then(|n| collect(n, seen))
                        }
                        $name::Loop { id, body, next, .. } => {
                            if !seen.insert(id.clone()) {
                                return Some(id.clone());
                            }
                            collect(body, seen)
                                .or_else(|| next.as_ref().and_then(|n| collect(n, seen)))
                        }
                        $name::ChildWorkflow { id, child, next } => {
                            if !seen.insert(id.clone()) {
                                return Some(id.clone());
                            }
                            collect(child, seen)
                                .or_else(|| next.as_ref().and_then(|n| collect(n, seen)))
                        }
                    }
                }
                collect(self, &mut ::std::collections::HashSet::new())
            }
        }
    };
}
/// Variant tag of a [`WorkflowContinuation`] node, as reported in
/// [`NodeInfo::kind`].
///
/// Displays, parses and `as_ref()`s in snake_case (e.g. `task`,
/// `await_signal`, `child_workflow`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, strum::AsRefStr, strum::Display, strum::EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum NodeKind {
    Task,
    Fork,
    Delay,
    AwaitSignal,
    Branch,
    Loop,
    ChildWorkflow,
}
/// Borrowed description of one node, yielded by [`NodeIter`].
#[derive(Debug, Clone)]
pub struct NodeInfo<'a> {
    /// The node's own id.
    pub id: &'a str,
    /// Which `WorkflowContinuation` variant the node is.
    pub kind: NodeKind,
    /// Id of the node whose traversal reached this one; `None` for the root.
    pub predecessor_id: Option<&'a str>,
    /// Task timeout for `Task`, the delay length for `Delay`, the signal
    /// timeout for `AwaitSignal`; `None` for every other kind.
    pub timeout: Option<std::time::Duration>,
    /// Retry policy (set only for `Task` nodes).
    pub retry_policy: Option<&'a RetryPolicy>,
    /// Scheduling priority (set only for `Task` nodes).
    pub priority: Option<u8>,
    /// Task tags; an empty slice for non-task nodes.
    pub tags: &'a [String],
    /// Task version (set only for `Task` nodes).
    pub version: Option<&'a str>,
}
/// Iterator over every node of a [`WorkflowContinuation`], created by
/// `WorkflowContinuation::iter_nodes`.
pub struct NodeIter<'a> {
    /// Pending nodes, each paired with the id of the node that pushed it.
    stack: Vec<(&'a WorkflowContinuation, Option<&'a str>)>,
}
/// Shared empty tag slice reported for nodes that carry no tags.
const EMPTY_TAGS: &[String] = &[];
/// Depth-first, pre-order walk: a node is yielded before its children.
///
/// Children are pushed onto the stack in reverse so that they pop in this
/// order:
/// - `Task` / `Delay` / `AwaitSignal`: `next`
/// - `Fork`: each branch in order, then `join`
/// - `Branch`: arms sorted by key, then `default`, then `next`
/// - `Loop`: `body`, then `next`
/// - `ChildWorkflow`: `child`, then `next`
///
/// Every child records the current node's id as its `predecessor_id`.
impl<'a> Iterator for NodeIter<'a> {
    type Item = NodeInfo<'a>;
    #[allow(clippy::too_many_lines)]
    fn next(&mut self) -> Option<Self::Item> {
        // The iterator is exhausted once the stack is empty.
        let (cont, predecessor) = self.stack.pop()?;
        // Per-variant metadata. Only `Task` carries retry policy, priority,
        // tags and version; `Delay` reports its duration in the timeout slot.
        let (id, kind, timeout, retry_policy, priority, tags, version) = match cont {
            WorkflowContinuation::Task {
                id,
                timeout,
                retry_policy,
                priority,
                tags,
                version,
                ..
            } => (
                id.as_str(),
                NodeKind::Task,
                *timeout,
                retry_policy.as_ref(),
                *priority,
                tags.as_slice(),
                version.as_deref(),
            ),
            WorkflowContinuation::Fork { id, .. } => (
                id.as_str(),
                NodeKind::Fork,
                None,
                None,
                None,
                EMPTY_TAGS,
                None,
            ),
            WorkflowContinuation::Delay { id, duration, .. } => (
                id.as_str(),
                NodeKind::Delay,
                Some(*duration),
                None,
                None,
                EMPTY_TAGS,
                None,
            ),
            WorkflowContinuation::AwaitSignal { id, timeout, .. } => (
                id.as_str(),
                NodeKind::AwaitSignal,
                *timeout,
                None,
                None,
                EMPTY_TAGS,
                None,
            ),
            WorkflowContinuation::Branch { id, .. } => (
                id.as_str(),
                NodeKind::Branch,
                None,
                None,
                None,
                EMPTY_TAGS,
                None,
            ),
            WorkflowContinuation::Loop { id, .. } => (
                id.as_str(),
                NodeKind::Loop,
                None,
                None,
                None,
                EMPTY_TAGS,
                None,
            ),
            WorkflowContinuation::ChildWorkflow { id, .. } => (
                id.as_str(),
                NodeKind::ChildWorkflow,
                None,
                None,
                None,
                EMPTY_TAGS,
                None,
            ),
        };
        // Push children in the reverse of their visiting order (LIFO stack).
        match cont {
            WorkflowContinuation::Task { id, next, .. }
            | WorkflowContinuation::Delay { id, next, .. }
            | WorkflowContinuation::AwaitSignal { id, next, .. } => {
                if let Some(n) = next {
                    self.stack.push((n, Some(id)));
                }
            }
            WorkflowContinuation::Fork { id, branches, join } => {
                // `join` goes in first so it is visited after every branch.
                if let Some(j) = join {
                    self.stack.push((j, Some(id)));
                }
                for b in branches.iter().rev() {
                    self.stack.push((b, Some(id)));
                }
            }
            WorkflowContinuation::Branch {
                id,
                branches,
                default,
                next,
                ..
            } => {
                if let Some(n) = next {
                    self.stack.push((n, Some(id)));
                }
                if let Some(d) = default {
                    self.stack.push((d, Some(id)));
                }
                // Sorting makes the traversal independent of HashMap order.
                let mut keys: Vec<&String> = branches.keys().collect();
                keys.sort();
                for k in keys.into_iter().rev() {
                    self.stack.push((&branches[k], Some(id)));
                }
            }
            WorkflowContinuation::Loop { id, body, next, .. } => {
                if let Some(n) = next {
                    self.stack.push((n, Some(id)));
                }
                self.stack.push((body, Some(id)));
            }
            WorkflowContinuation::ChildWorkflow {
                id, child, next, ..
            } => {
                if let Some(n) = next {
                    self.stack.push((n, Some(id)));
                }
                self.stack.push((child, Some(id)));
            }
        }
        Some(NodeInfo {
            id,
            kind,
            predecessor_id: predecessor,
            timeout,
            retry_policy,
            priority,
            tags,
            version,
        })
    }
}
/// Runnable workflow definition: a tree of nodes, each optionally chained to
/// a successor (`next`, or `join` for forks).
///
/// [`SerializableContinuation`] is the registry-independent mirror of this
/// type; convert with `to_serializable` / `to_runnable`.
pub enum WorkflowContinuation {
    /// A unit of work.
    Task {
        /// Node id; also the name the task is registered under in the registry.
        id: String,
        /// Task implementation. NOTE(review): `None` is never produced by
        /// `to_runnable` — presumably used by builders before binding; confirm.
        func: Option<UntypedCoreTask>,
        /// Optional per-task timeout.
        timeout: Option<std::time::Duration>,
        /// Optional retry policy.
        retry_policy: Option<RetryPolicy>,
        /// Optional task version string.
        version: Option<String>,
        /// Optional scheduling priority.
        priority: Option<u8>,
        /// Free-form tags.
        tags: Vec<String>,
        /// Successor node.
        next: Option<Box<WorkflowContinuation>>,
    },
    /// Runs several branches, then continues with `join`.
    Fork {
        id: String,
        /// Branch roots, shared via `Arc`.
        branches: Box<[Arc<WorkflowContinuation>]>,
        /// Node that follows the fork.
        join: Option<Box<WorkflowContinuation>>,
    },
    /// Waits for a fixed duration.
    Delay {
        id: String,
        duration: std::time::Duration,
        next: Option<Box<WorkflowContinuation>>,
    },
    /// Waits for a named signal, optionally bounded by `timeout`.
    AwaitSignal {
        id: String,
        signal_name: String,
        timeout: Option<std::time::Duration>,
        next: Option<Box<WorkflowContinuation>>,
    },
    /// Chooses one arm by key, falling back to `default`.
    Branch {
        id: String,
        /// Key selector; rebuilt from the registry entry named `key_fn_id(id)`.
        key_fn: Option<UntypedCoreTask>,
        /// Arms keyed by the selector's output.
        branches: HashMap<String, Box<WorkflowContinuation>>,
        /// Arm used when no key matches (presumably — confirm in executor).
        default: Option<Box<WorkflowContinuation>>,
        next: Option<Box<WorkflowContinuation>>,
    },
    /// Repeats `body` up to `max_iterations` times.
    Loop {
        id: String,
        body: Box<WorkflowContinuation>,
        max_iterations: u32,
        /// Behaviour once the iteration limit is reached.
        on_max: MaxIterationsPolicy,
        next: Option<Box<WorkflowContinuation>>,
    },
    /// Embeds another workflow's continuation.
    ChildWorkflow {
        id: String,
        child: Arc<WorkflowContinuation>,
        next: Option<Box<WorkflowContinuation>>,
    },
}
// Generates `WorkflowContinuation::find_duplicate_id`. Fork branches are
// `Arc`-wrapped, so the closure borrows through the `Arc`; branch-map values
// are `Box`es, which deref-coerce to `&WorkflowContinuation` at the call.
impl_find_duplicate_id!(
    WorkflowContinuation,
    task_fields: { .. },
    delay_extra: { .. },
    deref_branch: |b: &Arc<WorkflowContinuation>| -> &WorkflowContinuation { b },
    deref_branch_map: |b: &WorkflowContinuation| -> &WorkflowContinuation { b }
);
/// Registry name of the key-selector function belonging to branch `branch_id`.
///
/// The result is `branch_id` followed by the literal suffix `::key_fn`.
#[must_use]
pub fn key_fn_id(branch_id: &str) -> String {
    [branch_id, "::key_fn"].concat()
}
/// Id assigned to the loop node numbered `counter`, e.g. `loop_3`.
#[must_use]
pub fn loop_node_id(counter: usize) -> String {
    let mut node_id = String::from("loop_");
    node_id.push_str(&counter.to_string());
    node_id
}
impl WorkflowContinuation {
#[must_use]
pub fn derive_fork_id(branch_ids: &[&str]) -> String {
branch_ids.join("||")
}
#[must_use]
pub fn id(&self) -> &str {
match self {
WorkflowContinuation::Task { id, .. }
| WorkflowContinuation::Fork { id, .. }
| WorkflowContinuation::Delay { id, .. }
| WorkflowContinuation::AwaitSignal { id, .. }
| WorkflowContinuation::Branch { id, .. }
| WorkflowContinuation::Loop { id, .. }
| WorkflowContinuation::ChildWorkflow { id, .. } => id,
}
}
#[must_use]
pub fn get_next(&self) -> Option<&WorkflowContinuation> {
match self {
Self::Task { next, .. }
| Self::Delay { next, .. }
| Self::AwaitSignal { next, .. }
| Self::Branch { next, .. }
| Self::Loop { next, .. }
| Self::ChildWorkflow { next, .. } => next.as_deref(),
Self::Fork { join, .. } => join.as_deref(),
}
}
#[must_use]
pub fn first_task_id(&self) -> &str {
match self {
WorkflowContinuation::Task { id, .. }
| WorkflowContinuation::Delay { id, .. }
| WorkflowContinuation::AwaitSignal { id, .. }
| WorkflowContinuation::Branch { id, .. } => id,
WorkflowContinuation::Fork { branches, .. } => {
if let Some(first_branch) = branches.first() {
first_branch.first_task_id()
} else {
"unknown"
}
}
WorkflowContinuation::Loop { body, .. } => body.first_task_id(),
WorkflowContinuation::ChildWorkflow { child, .. } => child.first_task_id(),
}
}
#[must_use]
pub fn first_task_priority(&self) -> Option<u8> {
match self {
WorkflowContinuation::Task { priority, .. } => *priority,
WorkflowContinuation::Delay { .. }
| WorkflowContinuation::AwaitSignal { .. }
| WorkflowContinuation::Branch { .. } => None,
WorkflowContinuation::Fork { branches, .. } => {
branches.first().and_then(|b| b.first_task_priority())
}
WorkflowContinuation::Loop { body, .. } => body.first_task_priority(),
WorkflowContinuation::ChildWorkflow { child, .. } => child.first_task_priority(),
}
}
#[must_use]
pub fn first_task_tags(&self) -> Vec<String> {
match self {
WorkflowContinuation::Task { tags, .. } => tags.clone(),
WorkflowContinuation::Delay { .. }
| WorkflowContinuation::AwaitSignal { .. }
| WorkflowContinuation::Branch { .. } => vec![],
WorkflowContinuation::Fork { branches, .. } => branches
.first()
.map(|b| b.first_task_tags())
.unwrap_or_default(),
WorkflowContinuation::Loop { body, .. } => body.first_task_tags(),
WorkflowContinuation::ChildWorkflow { child, .. } => child.first_task_tags(),
}
}
#[must_use]
pub fn first_task_hint(&self) -> crate::snapshot::TaskHint {
crate::snapshot::TaskHint::new(
self.first_task_id(),
self.first_task_priority(),
&self.first_task_tags(),
)
}
#[must_use]
pub fn terminal_task_id(&self) -> &str {
let mut current = self;
while let Some(next) = current.get_next() {
current = next;
}
current.first_task_id()
}
#[must_use]
pub fn find_task_name(&self, target_id: &crate::TaskId) -> Option<&str> {
self.find_task(target_id).and_then(|n| match n {
WorkflowContinuation::Task { id, .. } => Some(id.as_str()),
_ => None,
})
}
fn find_task(&self, target_id: &crate::TaskId) -> Option<&Self> {
match self {
WorkflowContinuation::Task { id, next, .. } => {
if crate::TaskId::from(id.as_str()) == *target_id {
return Some(self);
}
next.as_ref().and_then(|n| n.find_task(target_id))
}
WorkflowContinuation::Delay { next, .. }
| WorkflowContinuation::AwaitSignal { next, .. } => {
next.as_ref().and_then(|n| n.find_task(target_id))
}
WorkflowContinuation::Fork { branches, join, .. } => {
for branch in branches {
if let Some(found) = branch.find_task(target_id) {
return Some(found);
}
}
join.as_ref().and_then(|j| j.find_task(target_id))
}
WorkflowContinuation::Branch {
branches,
default,
next,
..
} => {
for branch in branches.values() {
if let Some(found) = branch.find_task(target_id) {
return Some(found);
}
}
if let Some(d) = default
&& let Some(found) = d.find_task(target_id)
{
return Some(found);
}
next.as_ref().and_then(|n| n.find_task(target_id))
}
WorkflowContinuation::Loop { body, next, .. } => body
.find_task(target_id)
.or_else(|| next.as_ref().and_then(|n| n.find_task(target_id))),
WorkflowContinuation::ChildWorkflow { child, next, .. } => child
.find_task(target_id)
.or_else(|| next.as_ref().and_then(|n| n.find_task(target_id))),
}
}
fn find_task_mut(&mut self, target_id: &crate::TaskId) -> Option<&mut Self> {
match self {
WorkflowContinuation::Task { id, .. }
if crate::TaskId::from(id.as_str()) == *target_id =>
{
Some(self)
}
WorkflowContinuation::Task { next, .. } => {
next.as_mut().and_then(|n| n.find_task_mut(target_id))
}
WorkflowContinuation::Delay { next, .. }
| WorkflowContinuation::AwaitSignal { next, .. } => {
next.as_mut().and_then(|n| n.find_task_mut(target_id))
}
WorkflowContinuation::Fork { join, .. } => {
join.as_mut().and_then(|j| j.find_task_mut(target_id))
}
WorkflowContinuation::Branch {
branches,
default,
next,
..
} => {
for branch in branches.values_mut() {
if let Some(found) = branch.find_task_mut(target_id) {
return Some(found);
}
}
if let Some(d) = default
&& let Some(found) = d.find_task_mut(target_id)
{
return Some(found);
}
next.as_mut().and_then(|n| n.find_task_mut(target_id))
}
WorkflowContinuation::Loop { body, next, .. } => {
if let Some(found) = body.find_task_mut(target_id) {
return Some(found);
}
next.as_mut().and_then(|n| n.find_task_mut(target_id))
}
WorkflowContinuation::ChildWorkflow { next, .. } => {
next.as_mut().and_then(|n| n.find_task_mut(target_id))
}
}
}
pub fn set_task_timeout(
&mut self,
target_id: &crate::TaskId,
timeout: Option<std::time::Duration>,
) {
if let Some(WorkflowContinuation::Task { timeout: t, .. }) = self.find_task_mut(target_id) {
*t = timeout;
}
}
pub fn set_task_retry_policy(
&mut self,
target_id: &crate::TaskId,
policy: Option<RetryPolicy>,
) {
if let Some(WorkflowContinuation::Task { retry_policy, .. }) = self.find_task_mut(target_id)
{
*retry_policy = policy;
}
}
pub fn set_task_version(&mut self, target_id: &crate::TaskId, ver: Option<String>) {
if let Some(WorkflowContinuation::Task { version, .. }) = self.find_task_mut(target_id) {
*version = ver;
}
}
#[must_use]
pub fn get_task_retry_policy(&self, task_id: &crate::TaskId) -> Option<&RetryPolicy> {
match self.find_task(task_id)? {
WorkflowContinuation::Task { retry_policy, .. } => retry_policy.as_ref(),
_ => None,
}
}
#[must_use]
pub fn get_task_timeout(&self, task_id: &crate::TaskId) -> Option<std::time::Duration> {
match self.find_task(task_id)? {
WorkflowContinuation::Task { timeout, .. } => *timeout,
_ => None,
}
}
#[must_use]
pub fn get_task_priority(&self, task_id: &crate::TaskId) -> Option<u8> {
match self.find_task(task_id)? {
WorkflowContinuation::Task { priority, .. } => *priority,
_ => None,
}
}
#[must_use]
pub fn get_task_tags(&self, task_id: &crate::TaskId) -> Vec<String> {
match self.find_task(task_id) {
Some(WorkflowContinuation::Task { tags, .. }) => tags.clone(),
_ => vec![],
}
}
pub fn set_task_tags(&mut self, target_id: &crate::TaskId, new_tags: Vec<String>) {
if let Some(WorkflowContinuation::Task { tags, .. }) = self.find_task_mut(target_id) {
*tags = new_tags;
}
}
#[must_use]
pub fn build_task_metadata(&self, task_id: &crate::TaskId) -> crate::task::TaskMetadata {
match self.find_task(task_id) {
Some(WorkflowContinuation::Task {
timeout,
retry_policy,
version,
priority,
tags,
..
}) => crate::task::TaskMetadata::from_node_fields(
*timeout,
retry_policy.clone(),
version.clone(),
*priority,
tags.clone(),
),
_ => crate::task::TaskMetadata::default(),
}
}
#[must_use]
pub fn iter_nodes(&self) -> NodeIter<'_> {
NodeIter {
stack: vec![(self, None)],
}
}
#[must_use]
pub fn to_serializable(&self) -> SerializableContinuation {
match self {
#[allow(clippy::cast_possible_truncation)] WorkflowContinuation::Task {
id,
timeout,
retry_policy,
version,
priority,
tags,
next,
..
} => SerializableContinuation::Task {
id: id.clone(),
timeout_ms: timeout.map(|d| d.as_millis() as u64),
retry_policy: retry_policy.clone(),
version: version.clone(),
priority: *priority,
tags: tags.clone(),
next: next.as_ref().map(|n| Box::new(n.to_serializable())),
},
WorkflowContinuation::Fork { id, branches, join } => SerializableContinuation::Fork {
id: id.clone(),
branches: branches.iter().map(|b| b.to_serializable()).collect(),
join: join.as_ref().map(|j| Box::new(j.to_serializable())),
},
#[allow(clippy::cast_possible_truncation)] WorkflowContinuation::Delay { id, duration, next } => SerializableContinuation::Delay {
id: id.clone(),
duration_ms: duration.as_millis() as u64,
next: next.as_ref().map(|n| Box::new(n.to_serializable())),
},
#[allow(clippy::cast_possible_truncation)]
WorkflowContinuation::AwaitSignal {
id,
signal_name,
timeout,
next,
} => SerializableContinuation::AwaitSignal {
id: id.clone(),
signal_name: signal_name.clone(),
timeout_ms: timeout.map(|d| d.as_millis() as u64),
next: next.as_ref().map(|n| Box::new(n.to_serializable())),
},
WorkflowContinuation::Branch {
id,
branches,
default,
next,
..
} => SerializableContinuation::Branch {
id: id.clone(),
branches: branches
.iter()
.map(|(k, v)| (k.clone(), Box::new(v.to_serializable())))
.collect(),
default: default.as_ref().map(|d| Box::new(d.to_serializable())),
next: next.as_ref().map(|n| Box::new(n.to_serializable())),
},
WorkflowContinuation::ChildWorkflow { id, child, next } => {
SerializableContinuation::ChildWorkflow {
id: id.clone(),
child: Box::new(child.to_serializable()),
next: next.as_ref().map(|n| Box::new(n.to_serializable())),
}
}
WorkflowContinuation::Loop {
id,
body,
max_iterations,
on_max,
next,
} => SerializableContinuation::Loop {
id: id.clone(),
body: Box::new(body.to_serializable()),
max_iterations: *max_iterations,
on_max: *on_max,
next: next.as_ref().map(|n| Box::new(n.to_serializable())),
},
}
}
pub fn append_to_chain(&mut self, new_node: WorkflowContinuation) {
match self {
WorkflowContinuation::Task { next, .. }
| WorkflowContinuation::Delay { next, .. }
| WorkflowContinuation::AwaitSignal { next, .. }
| WorkflowContinuation::Branch { next, .. }
| WorkflowContinuation::Loop { next, .. }
| WorkflowContinuation::ChildWorkflow { next, .. } => match next {
Some(next_box) => next_box.append_to_chain(new_node),
None => *next = Some(Box::new(new_node)),
},
WorkflowContinuation::Fork { join, .. } => match join {
Some(join_box) => join_box.append_to_chain(new_node),
None => *join = Some(Box::new(new_node)),
},
}
}
}
/// Registry-independent, serde-friendly mirror of [`WorkflowContinuation`].
///
/// Task functions and branch key selectors are omitted; `to_runnable` resolves
/// them by id. Durations are stored as whole milliseconds, and fork branches
/// are plain values rather than `Arc`s.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SerializableContinuation {
    Task {
        /// Task id; must be registered in the `TaskRegistry` used to rebuild.
        id: String,
        /// Timeout in milliseconds; omitted from the output when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        timeout_ms: Option<u64>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        retry_policy: Option<RetryPolicy>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        version: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        priority: Option<u8>,
        /// Omitted from the output when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tags: Vec<String>,
        next: Option<Box<SerializableContinuation>>,
    },
    Fork {
        id: String,
        branches: Vec<SerializableContinuation>,
        join: Option<Box<SerializableContinuation>>,
    },
    Delay {
        id: String,
        /// Delay length in milliseconds.
        duration_ms: u64,
        next: Option<Box<SerializableContinuation>>,
    },
    AwaitSignal {
        id: String,
        signal_name: String,
        /// Signal timeout in milliseconds; omitted when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        timeout_ms: Option<u64>,
        next: Option<Box<SerializableContinuation>>,
    },
    Branch {
        /// Branch id; its key selector is looked up as `key_fn_id(id)`.
        id: String,
        branches: HashMap<String, Box<SerializableContinuation>>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        default: Option<Box<SerializableContinuation>>,
        next: Option<Box<SerializableContinuation>>,
    },
    Loop {
        id: String,
        body: Box<SerializableContinuation>,
        max_iterations: u32,
        on_max: MaxIterationsPolicy,
        next: Option<Box<SerializableContinuation>>,
    },
    ChildWorkflow {
        id: String,
        child: Box<SerializableContinuation>,
        next: Option<Box<SerializableContinuation>>,
    },
}
// Generates `SerializableContinuation::find_duplicate_id`, used by
// `to_runnable` to reject definitions with repeated node ids. Fork branches
// are plain values and branch-map values are `Box`es (deref-coerced at the
// call), so both closures are identity borrows.
impl_find_duplicate_id!(
    SerializableContinuation,
    task_fields: { .. },
    delay_extra: { .. },
    deref_branch: |b: &SerializableContinuation| -> &SerializableContinuation { b },
    deref_branch_map: |b: &SerializableContinuation| -> &SerializableContinuation { b }
);
impl SerializableContinuation {
    /// Rebuilds a runnable [`WorkflowContinuation`] from this definition.
    ///
    /// Each task function is resolved from `registry` by task id. Each branch's
    /// key selector is resolved under [`key_fn_id`] of the branch id.
    ///
    /// # Errors
    ///
    /// * `BuildError::DuplicateTaskId` if any node id occurs more than once.
    /// * `BuildError::TaskNotFound` if a task or key selector is not
    ///   registered. Branch arms are checked in key order, so the first missing
    ///   id in definition order is the one reported.
    pub fn to_runnable(
        &self,
        registry: &crate::registry::TaskRegistry,
    ) -> Result<WorkflowContinuation, crate::error::BuildError> {
        if let Some(dup) = self.find_duplicate_id() {
            return Err(crate::error::BuildError::DuplicateTaskId(dup));
        }
        self.to_runnable_unchecked(registry)
    }

    /// Recursive worker for [`Self::to_runnable`]; skips the duplicate-id
    /// check.
    ///
    /// Errors propagate as soon as a subtree fails, so later siblings are not
    /// converted needlessly.
    #[allow(clippy::too_many_lines)]
    fn to_runnable_unchecked(
        &self,
        registry: &crate::registry::TaskRegistry,
    ) -> Result<WorkflowContinuation, crate::error::BuildError> {
        match self {
            SerializableContinuation::Task {
                id,
                timeout_ms,
                retry_policy,
                version,
                priority,
                tags,
                next,
            } => {
                let func = registry
                    .get(id)
                    .ok_or_else(|| crate::error::BuildError::TaskNotFound(id.clone()))?;
                let next = next
                    .as_ref()
                    .map(|n| n.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                Ok(WorkflowContinuation::Task {
                    id: id.clone(),
                    func: Some(func),
                    timeout: timeout_ms.map(std::time::Duration::from_millis),
                    retry_policy: retry_policy.clone(),
                    version: version.clone(),
                    priority: *priority,
                    tags: tags.clone(),
                    next,
                })
            }
            SerializableContinuation::Fork { id, branches, join } => {
                // Fail on the first bad branch before touching `join`.
                let branches = branches
                    .iter()
                    .map(|b| b.to_runnable_unchecked(registry).map(Arc::new))
                    .collect::<Result<Vec<_>, _>>()?;
                let join = join
                    .as_ref()
                    .map(|j| j.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                Ok(WorkflowContinuation::Fork {
                    id: id.clone(),
                    branches: branches.into_boxed_slice(),
                    join,
                })
            }
            SerializableContinuation::Delay {
                id,
                duration_ms,
                next,
            } => {
                let next = next
                    .as_ref()
                    .map(|n| n.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                Ok(WorkflowContinuation::Delay {
                    id: id.clone(),
                    duration: std::time::Duration::from_millis(*duration_ms),
                    next,
                })
            }
            SerializableContinuation::AwaitSignal {
                id,
                signal_name,
                timeout_ms,
                next,
            } => {
                let next = next
                    .as_ref()
                    .map(|n| n.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                Ok(WorkflowContinuation::AwaitSignal {
                    id: id.clone(),
                    signal_name: signal_name.clone(),
                    timeout: timeout_ms.map(std::time::Duration::from_millis),
                    next,
                })
            }
            SerializableContinuation::Branch {
                id,
                branches,
                default,
                next,
            } => {
                let kf_id = key_fn_id(id);
                let key_fn = registry
                    .get(&kf_id)
                    .ok_or(crate::error::BuildError::TaskNotFound(kf_id))?;
                // Convert arms in key order and stop at the first failure, so
                // the reported error does not depend on HashMap order.
                let branches = Self::sorted_arms(branches)
                    .into_iter()
                    .map(|(key, arm)| {
                        arm.to_runnable_unchecked(registry)
                            .map(|c| (key.to_owned(), Box::new(c)))
                    })
                    .collect::<Result<HashMap<_, _>, _>>()?;
                let default = default
                    .as_ref()
                    .map(|d| d.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                let next = next
                    .as_ref()
                    .map(|n| n.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                Ok(WorkflowContinuation::Branch {
                    id: id.clone(),
                    key_fn: Some(key_fn),
                    branches,
                    default,
                    next,
                })
            }
            SerializableContinuation::Loop {
                id,
                body,
                max_iterations,
                on_max,
                next,
            } => {
                let body = body.to_runnable_unchecked(registry)?;
                let next = next
                    .as_ref()
                    .map(|n| n.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                Ok(WorkflowContinuation::Loop {
                    id: id.clone(),
                    body: Box::new(body),
                    max_iterations: *max_iterations,
                    on_max: *on_max,
                    next,
                })
            }
            SerializableContinuation::ChildWorkflow { id, child, next } => {
                let child = child.to_runnable_unchecked(registry)?;
                let next = next
                    .as_ref()
                    .map(|n| n.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                Ok(WorkflowContinuation::ChildWorkflow {
                    id: id.clone(),
                    child: Arc::new(child),
                    next,
                })
            }
        }
    }

    /// Arms of a branch map as `(key, arm)` pairs sorted by key.
    ///
    /// Traversals use this so their results do not depend on `HashMap`
    /// iteration order.
    fn sorted_arms(
        branches: &HashMap<String, Box<SerializableContinuation>>,
    ) -> Vec<(&str, &SerializableContinuation)> {
        let mut arms: Vec<(&str, &SerializableContinuation)> = branches
            .iter()
            .map(|(key, arm)| (key.as_str(), &**arm))
            .collect();
        arms.sort_unstable_by(|a, b| a.0.cmp(b.0));
        arms
    }

    /// Ids of every node in depth-first pre-order.
    ///
    /// This covers all node kinds, not only tasks: forks, delays, signals,
    /// branches, loops and child workflows too. Children are visited in this
    /// order: fork branches then `join`; branch arms sorted by key, then
    /// `default`, then `next`; loop body then `next`; child then `next`. The
    /// output order is therefore deterministic.
    #[must_use]
    pub fn task_ids(&self) -> Vec<&str> {
        fn collect<'a>(cont: &'a SerializableContinuation, ids: &mut Vec<&'a str>) {
            match cont {
                SerializableContinuation::Task { id, next, .. }
                | SerializableContinuation::Delay { id, next, .. }
                | SerializableContinuation::AwaitSignal { id, next, .. } => {
                    ids.push(id.as_str());
                    if let Some(n) = next {
                        collect(n, ids);
                    }
                }
                SerializableContinuation::Fork { id, branches, join } => {
                    ids.push(id.as_str());
                    for b in branches {
                        collect(b, ids);
                    }
                    if let Some(j) = join {
                        collect(j, ids);
                    }
                }
                SerializableContinuation::Branch {
                    id,
                    branches,
                    default,
                    next,
                } => {
                    ids.push(id.as_str());
                    for (_, arm) in SerializableContinuation::sorted_arms(branches) {
                        collect(arm, ids);
                    }
                    if let Some(d) = default {
                        collect(d, ids);
                    }
                    if let Some(n) = next {
                        collect(n, ids);
                    }
                }
                SerializableContinuation::Loop { id, body, next, .. } => {
                    ids.push(id.as_str());
                    collect(body, ids);
                    if let Some(n) = next {
                        collect(n, ids);
                    }
                }
                SerializableContinuation::ChildWorkflow { id, child, next } => {
                    ids.push(id.as_str());
                    collect(child, ids);
                    if let Some(n) = next {
                        collect(n, ids);
                    }
                }
            }
        }
        let mut ids = vec![];
        collect(self, &mut ids);
        ids
    }

    /// SHA-256 over a canonical textual encoding of this definition.
    ///
    /// Included: node kinds and ids, task timeouts and versions, delay
    /// lengths, signal names and timeouts, loop limits and policies, and the
    /// nesting structure. For retry policies, only `max_retries`,
    /// `initial_delay` (ms) and `backoff_multiplier` are hashed. Task
    /// `priority` and `tags` are NOT part of the hash. Branch arms are hashed
    /// in sorted key order.
    ///
    /// The encoding must stay byte-stable: persisted states are compared
    /// against this hash when they are rebuilt.
    #[must_use]
    #[allow(clippy::too_many_lines)]
    pub fn compute_definition_hash(&self) -> crate::DefinitionHash {
        #[allow(clippy::too_many_lines)]
        fn hash_continuation(cont: &SerializableContinuation, hasher: &mut Sha256) {
            match cont {
                SerializableContinuation::Task {
                    id,
                    timeout_ms,
                    retry_policy,
                    version,
                    next,
                    ..
                } => {
                    hasher.update(b"T:");
                    hasher.update(id.as_bytes());
                    if let Some(ms) = timeout_ms {
                        hasher.update(b":t:");
                        hasher.update(ms.to_string().as_bytes());
                    }
                    if let Some(rp) = retry_policy {
                        hasher.update(b":r:");
                        hasher.update(rp.max_retries.to_string().as_bytes());
                        hasher.update(b":");
                        hasher.update(rp.initial_delay.as_millis().to_string().as_bytes());
                        hasher.update(b":");
                        hasher.update(rp.backoff_multiplier.to_string().as_bytes());
                    }
                    if let Some(v) = version {
                        hasher.update(b":v:");
                        hasher.update(v.as_bytes());
                    }
                    hasher.update(b";");
                    if let Some(n) = next {
                        hash_continuation(n, hasher);
                    }
                }
                SerializableContinuation::Fork { id, branches, join } => {
                    hasher.update(b"F:");
                    hasher.update(id.as_bytes());
                    hasher.update(b"[");
                    for branch in branches {
                        hash_continuation(branch, hasher);
                        hasher.update(b",");
                    }
                    hasher.update(b"]");
                    if let Some(j) = join {
                        hasher.update(b"J:");
                        hash_continuation(j, hasher);
                    }
                }
                SerializableContinuation::Delay {
                    id,
                    duration_ms,
                    next,
                } => {
                    hasher.update(b"D:");
                    hasher.update(id.as_bytes());
                    hasher.update(b":");
                    hasher.update(duration_ms.to_string().as_bytes());
                    hasher.update(b";");
                    if let Some(n) = next {
                        hash_continuation(n, hasher);
                    }
                }
                SerializableContinuation::AwaitSignal {
                    id,
                    signal_name,
                    timeout_ms,
                    next,
                } => {
                    hasher.update(b"S:");
                    hasher.update(id.as_bytes());
                    hasher.update(b":");
                    hasher.update(signal_name.as_bytes());
                    if let Some(ms) = timeout_ms {
                        hasher.update(b":t:");
                        hasher.update(ms.to_string().as_bytes());
                    }
                    hasher.update(b";");
                    if let Some(n) = next {
                        hash_continuation(n, hasher);
                    }
                }
                SerializableContinuation::Branch {
                    id,
                    branches,
                    default,
                    next,
                } => {
                    hasher.update(b"B:");
                    hasher.update(id.as_bytes());
                    hasher.update(b"{");
                    // Sorted so the hash is independent of HashMap order.
                    let mut keys: Vec<&String> = branches.keys().collect();
                    keys.sort();
                    for key in keys {
                        hasher.update(key.as_bytes());
                        hasher.update(b"=>");
                        if let Some(branch) = branches.get(key) {
                            hash_continuation(branch, hasher);
                        }
                        hasher.update(b",");
                    }
                    hasher.update(b"}");
                    if let Some(d) = default {
                        hasher.update(b"_=>");
                        hash_continuation(d, hasher);
                    }
                    hasher.update(b";");
                    if let Some(n) = next {
                        hash_continuation(n, hasher);
                    }
                }
                SerializableContinuation::Loop {
                    id,
                    body,
                    max_iterations,
                    on_max,
                    next,
                } => {
                    hasher.update(b"L:");
                    hasher.update(id.as_bytes());
                    hasher.update(b":");
                    hasher.update(max_iterations.to_string().as_bytes());
                    hasher.update(b":");
                    hasher.update(on_max.to_string().as_bytes());
                    hasher.update(b"{");
                    hash_continuation(body, hasher);
                    hasher.update(b"}");
                    hasher.update(b";");
                    if let Some(n) = next {
                        hash_continuation(n, hasher);
                    }
                }
                SerializableContinuation::ChildWorkflow { id, child, next } => {
                    hasher.update(b"CW:");
                    hasher.update(id.as_bytes());
                    hasher.update(b"{");
                    hash_continuation(child, hasher);
                    hasher.update(b"}");
                    hasher.update(b";");
                    if let Some(n) = next {
                        hash_continuation(n, hasher);
                    }
                }
            }
        }
        let mut hasher = Sha256::new();
        hash_continuation(self, &mut hasher);
        crate::DefinitionHash::from_hash(crate::Hash32::from_digest(hasher))
    }
}
/// Persistable snapshot of a workflow definition.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializedWorkflowState {
    /// Workflow name at the time of serialization.
    pub workflow_id: String,
    /// Definition hash; must match the live workflow's hash when the state is
    /// rebuilt via `SerializableWorkflow::to_runnable`.
    pub definition_hash: crate::DefinitionHash,
    /// The serialized continuation tree.
    pub continuation: SerializableContinuation,
}
/// Policy for a conflict with an already-existing workflow.
///
/// NOTE(review): the conflict is detected by callers outside this file —
/// presumably when starting a workflow whose id is already in use; confirm.
///
/// Displays and parses in snake_case. `UseExisting` and `TerminateExisting`
/// also parse from their camelCase spellings. Defaults to `Fail`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Eq,
    strum::EnumString,
    strum::Display,
    strum::VariantNames,
)]
#[strum(serialize_all = "snake_case")]
pub enum ConflictPolicy {
    /// Default policy (`fail`).
    #[default]
    Fail,
    /// `use_existing` / `useExisting`.
    #[strum(serialize = "use_existing", serialize = "useExisting")]
    UseExisting,
    /// `terminate_existing` / `terminateExisting`.
    #[strum(serialize = "terminate_existing", serialize = "terminateExisting")]
    TerminateExisting,
}
impl ConflictPolicy {
    /// Parses an optional policy name.
    ///
    /// A missing value (`None`) yields the default policy, `Fail`.
    ///
    /// # Errors
    ///
    /// Returns the offending input string unchanged when it does not name a
    /// policy.
    pub fn parse_optional(s: Option<&str>) -> Result<Self, &str> {
        let Some(raw) = s else {
            return Ok(Self::default());
        };
        raw.parse::<Self>().or(Err(raw))
    }

    /// Names exposed by strum's `VariantNames` derive, e.g. for error messages.
    #[must_use]
    pub fn valid_names() -> &'static [&'static str] {
        <ConflictPolicy as strum::VariantNames>::VARIANTS
    }
}
/// Lifecycle state of a workflow run.
///
/// `as_ref()` yields the snake_case status name (e.g. `awaiting_signal`).
/// `WorkflowStatusKind` is the generated field-less discriminant.
#[derive(Debug, strum::AsRefStr, strum::EnumDiscriminants)]
#[strum_discriminants(name(WorkflowStatusKind))]
#[strum_discriminants(derive(strum::AsRefStr))]
#[strum_discriminants(strum(serialize_all = "snake_case"))]
#[strum_discriminants(doc = "Fieldless discriminant of [`WorkflowStatus`] for string comparisons.")]
pub enum WorkflowStatus {
    #[strum(serialize = "in_progress")]
    InProgress,
    #[strum(serialize = "completed")]
    Completed,
    /// Carries an error message.
    #[strum(serialize = "failed")]
    Failed(String),
    /// Optional reason and actor for the cancellation.
    #[strum(serialize = "cancelled")]
    Cancelled {
        reason: Option<String>,
        cancelled_by: Option<String>,
    },
    /// Optional reason and actor for the pause.
    #[strum(serialize = "paused")]
    Paused {
        reason: Option<String>,
        paused_by: Option<String>,
    },
    /// Sleeping until `wake_at`; `delay_id` presumably identifies the delay
    /// node being waited on — confirm.
    #[strum(serialize = "waiting")]
    Waiting {
        wake_at: chrono::DateTime<chrono::Utc>,
        delay_id: crate::TaskId,
    },
    /// Waiting for signal `signal_name`, optionally only until `wake_at`.
    #[strum(serialize = "awaiting_signal")]
    AwaitingSignal {
        signal_id: crate::TaskId,
        signal_name: String,
        wake_at: Option<chrono::DateTime<chrono::Utc>>,
    },
}
/// String-only flattening of [`WorkflowStatus`], built via `From`.
///
/// Only the fields relevant to the source variant are `Some`. Timestamps are
/// RFC 3339 strings; task ids are rendered with `TaskId::to_hex`.
#[derive(Debug, Default)]
pub struct FlatWorkflowStatus {
    /// Snake_case status name (`WorkflowStatus::as_ref`).
    pub status: String,
    /// Error message of `Failed`.
    pub error: Option<String>,
    /// Reason of `Cancelled` or `Paused`.
    pub reason: Option<String>,
    pub cancelled_by: Option<String>,
    pub paused_by: Option<String>,
    /// Wake time of `Waiting` or `AwaitingSignal`.
    pub wake_at: Option<String>,
    pub delay_id: Option<String>,
    pub signal_id: Option<String>,
    pub signal_name: Option<String>,
}
impl From<WorkflowStatus> for FlatWorkflowStatus {
    /// Flattens `status`.
    ///
    /// The name is always set; each variant's payload is copied into the
    /// matching optional fields, and all other fields stay `None`.
    fn from(status: WorkflowStatus) -> Self {
        // Capture the name before the match consumes `status`.
        let base = Self {
            status: status.as_ref().to_owned(),
            ..Self::default()
        };
        match status {
            WorkflowStatus::InProgress | WorkflowStatus::Completed => base,
            WorkflowStatus::Failed(message) => Self {
                error: Some(message),
                ..base
            },
            WorkflowStatus::Cancelled {
                reason,
                cancelled_by,
            } => Self {
                reason,
                cancelled_by,
                ..base
            },
            WorkflowStatus::Paused { reason, paused_by } => Self {
                reason,
                paused_by,
                ..base
            },
            WorkflowStatus::Waiting { wake_at, delay_id } => Self {
                wake_at: Some(wake_at.to_rfc3339()),
                delay_id: Some(delay_id.to_hex()),
                ..base
            },
            WorkflowStatus::AwaitingSignal {
                signal_id,
                signal_name,
                wake_at,
            } => Self {
                signal_id: Some(signal_id.to_hex()),
                signal_name: Some(signal_name),
                wake_at: wake_at.map(|t| t.to_rfc3339()),
                ..base
            },
        }
    }
}
pub use crate::builder::{
BranchCollector, ContinuationState, ForkBuilder, NoContinuation, NoRegistry, RegistryBehavior,
RouteBuilder, SubBuilder, WorkflowBuilder,
};
use crate::registry::TaskRegistry;
/// A built workflow.
///
/// It holds the definition hash, the execution context (codec `C`, metadata
/// `M`), the root continuation and a task index. `Input` is tracked at the
/// type level only.
pub struct Workflow<C, Input, M = ()> {
    /// Hash of the definition; compared against persisted state on rebuild.
    pub(crate) definition_hash: crate::DefinitionHash,
    /// Name, codec and metadata shared by the run.
    pub(crate) context: WorkflowContext<C, M>,
    /// Root of the continuation tree.
    pub(crate) continuation: WorkflowContinuation,
    /// Shared task index.
    pub(crate) task_index: Arc<crate::task_index::TaskIndex>,
    pub(crate) _phantom: PhantomData<Input>,
}
impl<C, Input, M> Workflow<C, Input, M> {
#[must_use]
pub fn workflow_id(&self) -> &str {
&self.context.workflow_name
}
#[must_use]
pub fn definition_hash(&self) -> &crate::DefinitionHash {
&self.definition_hash
}
#[must_use]
pub fn context(&self) -> &WorkflowContext<C, M> {
&self.context
}
#[must_use]
pub fn codec(&self) -> &Arc<C> {
&self.context.codec
}
#[must_use]
pub fn continuation(&self) -> &WorkflowContinuation {
&self.continuation
}
#[must_use]
pub fn task_index(&self) -> &crate::task_index::TaskIndex {
&self.task_index
}
#[must_use]
pub fn task_index_arc(&self) -> Arc<crate::task_index::TaskIndex> {
Arc::clone(&self.task_index)
}
#[must_use]
pub fn metadata(&self) -> &Arc<M> {
&self.context.metadata
}
#[must_use]
pub fn iter_nodes(&self) -> NodeIter<'_> {
self.continuation.iter_nodes()
}
#[must_use]
pub fn into_continuation(self) -> WorkflowContinuation {
self.continuation
}
}
/// A [`Workflow`] paired with the task registry needed to rebuild it from a
/// [`SerializedWorkflowState`].
pub struct SerializableWorkflow<C, Input, M = ()> {
    /// The wrapped workflow.
    pub(crate) inner: Workflow<C, Input, M>,
    /// Registry used to resolve task and key-selector functions by id.
    pub(crate) registry: TaskRegistry,
}
impl<C, Input, M> SerializableWorkflow<C, Input, M> {
#[must_use]
pub fn workflow_id(&self) -> &str {
self.inner.workflow_id()
}
#[must_use]
pub fn definition_hash(&self) -> &crate::DefinitionHash {
self.inner.definition_hash()
}
#[must_use]
pub fn workflow(&self) -> &Workflow<C, Input, M> {
&self.inner
}
#[must_use]
pub fn context(&self) -> &WorkflowContext<C, M> {
self.inner.context()
}
#[must_use]
pub fn codec(&self) -> &Arc<C> {
self.inner.codec()
}
#[must_use]
pub fn continuation(&self) -> &WorkflowContinuation {
self.inner.continuation()
}
#[must_use]
pub fn metadata(&self) -> &Arc<M> {
self.inner.metadata()
}
#[must_use]
pub fn registry(&self) -> &TaskRegistry {
&self.registry
}
#[must_use]
pub fn into_parts(self) -> (WorkflowContinuation, TaskRegistry) {
(self.inner.continuation, self.registry)
}
#[must_use]
pub fn to_serializable(&self) -> SerializedWorkflowState {
SerializedWorkflowState {
workflow_id: self.inner.workflow_id().to_string(),
definition_hash: self.inner.definition_hash,
continuation: self.inner.continuation().to_serializable(),
}
}
pub fn to_runnable(
&self,
state: &SerializedWorkflowState,
) -> Result<WorkflowContinuation, crate::error::BuildError> {
if state.definition_hash != self.inner.definition_hash {
return Err(crate::error::BuildError::DefinitionMismatch {
expected: self.inner.definition_hash,
found: state.definition_hash,
});
}
state.continuation.to_runnable(&self.registry)
}
}
/// Lets a `SerializableWorkflow` be used wherever a `&Workflow` is expected.
impl<C, Input, M> Deref for SerializableWorkflow<C, Input, M> {
    type Target = Workflow<C, Input, M>;

    fn deref(&self) -> &Self::Target {
        self.workflow()
    }
}
#[cfg(test)]
// Tests deliberately use panicking assertions, literal fixtures and index
// access for brevity; these lints are relaxed here only.
#[allow(
    clippy::unwrap_used,
    clippy::panic,
    clippy::cast_lossless,
    clippy::cast_possible_truncation,
    clippy::uninlined_format_args,
    clippy::manual_let_else,
    clippy::too_many_lines,
    clippy::items_after_statements,
    clippy::indexing_slicing
)]
mod tests {
    use crate::codec::{Decoder, Encoder, sealed};
    use crate::error::BoxError;
    use crate::workflow::WorkflowBuilder;
    use bytes::Bytes;

    /// Placeholder codec: encoding always yields empty bytes and decoding
    /// always fails. It is only suitable for tests that inspect workflow
    /// structure rather than run tasks.
    struct DummyCodec;
    impl Encoder for DummyCodec {}
    impl Decoder for DummyCodec {}
    impl<Input> sealed::EncodeValue<Input> for DummyCodec {
        fn encode_value(&self, _value: &Input) -> Result<Bytes, BoxError> {
            Ok(Bytes::new())
        }
    }
    impl<Output> sealed::DecodeValue<Output> for DummyCodec {
        fn decode_value(&self, _bytes: Bytes) -> Result<Output, BoxError> {
            Err("Not implemented".into())
        }
    }

    /// A single-task workflow builds and starts at that task.
    #[test]
    fn test_workflow_build() {
        use crate::context::WorkflowContext;
        use crate::workflow::Workflow;
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .then("test", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        assert_eq!(workflow.continuation().first_task_id(), "test");
    }

    /// Metadata passed to the context is reachable from the built workflow.
    #[test]
    fn test_workflow_with_metadata() {
        use crate::context::WorkflowContext;
        use crate::workflow::Workflow;
        use std::sync::Arc;
        let ctx = WorkflowContext::new(
            "test-workflow",
            Arc::new(DummyCodec),
            Arc::new("test_metadata"),
        );
        let workflow: Workflow<DummyCodec, u32, &str> = WorkflowBuilder::new(ctx)
            .then("test", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        assert_eq!(**workflow.metadata(), "test_metadata");
    }

    /// Chained `then` calls produce a `Task` chain in insertion order.
    #[test]
    fn test_task_order() {
        use crate::context::WorkflowContext;
        use crate::workflow::Workflow;
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .then("first", |i: u32| async move { Ok(i + 1) })
            .then("second", |i: u32| async move { Ok(i + 2) })
            .then("third", |i: u32| async move { Ok(i + 3) })
            .build()
            .unwrap();
        let mut current = workflow.continuation();
        let mut task_ids = vec![];
        while let crate::workflow::WorkflowContinuation::Task { id, next, .. } = current {
            task_ids.push(id.clone());
            match next {
                Some(next_box) => current = next_box.as_ref(),
                None => break,
            }
        }
        assert_eq!(
            task_ids,
            vec!["first", "second", "third"],
            "Tasks should execute in the order they were added"
        );
    }

    /// Fork branches may return different output types; the join receives them
    /// type-erased via `BranchOutputs`.
    #[test]
    fn test_heterogeneous_fork_join_compiles() {
        use crate::context::WorkflowContext;
        use crate::task::BranchOutputs;
        use crate::workflow::Workflow;
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .then("prepare", |i: u32| async move { Ok(i) })
            .branches(|b| {
                b.add("count", |i: u32| async move { Ok(i * 2) });
                b.add("name", |i: u32| async move { Ok(format!("item_{}", i)) });
                b.add("ratio", |i: u32| async move { Ok(i as f64 / 100.0) });
            })
            .join("combine", |outputs: BranchOutputs<DummyCodec>| async move {
                Ok(format!("combined {} branches", outputs.len()))
            })
            .then("final", |s: String| async move { Ok(s.len() as u32) })
            .build()
            .unwrap();
        assert_eq!(workflow.continuation().first_task_id(), "prepare");
    }

    /// Two branches with the same id are reported as `DuplicateTaskId`.
    #[test]
    fn test_duplicate_branch_id_returns_error() {
        use crate::context::WorkflowContext;
        use crate::error::BuildError;
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let result = WorkflowBuilder::<_, u32, _>::new(ctx)
            .then("prepare", |i: u32| async move { Ok(i) })
            .branches(|b| {
                b.add("count", |i: u32| async move { Ok(i * 2) });
                // Same id as the branch above: must be rejected at build time.
                b.add("count", |i: u32| async move { Ok(i * 3) });
            })
            .join("combine", |_outputs| async move { Ok(0u32) })
            .build();
        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("expected build error"),
        };
        assert!(
            err.iter()
                .any(|e| matches!(e, BuildError::DuplicateTaskId(id) if id == "count"))
        );
    }

    /// Hydrating a serialized continuation fails with `TaskNotFound` against an
    /// empty registry and succeeds once every task id is registered.
    #[test]
    fn test_serializable_continuation() {
        use crate::context::WorkflowContext;
        use crate::error::BuildError;
        use crate::registry::TaskRegistry;
        use std::sync::Arc;
        let codec = Arc::new(DummyCodec);
        let ctx = WorkflowContext::new("test-workflow", codec.clone(), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();
        let serializable = workflow.continuation().to_serializable();
        let task_ids = serializable.task_ids();
        assert_eq!(task_ids, vec!["step1", "step2"]);
        let empty_registry = TaskRegistry::new();
        let result = serializable.to_runnable(&empty_registry);
        assert!(matches!(result, Err(BuildError::TaskNotFound(id)) if id == "step1"));
        let mut registry = TaskRegistry::new();
        registry.register_fn("step1", codec.clone(), |i: u32| async move { Ok(i + 1) });
        registry.register_fn("step2", codec.clone(), |i: u32| async move { Ok(i * 2) });
        let hydrated = serializable.to_runnable(&registry);
        assert!(hydrated.is_ok());
    }

    /// A serialized fork exposes the prepare task, the fork node (id = branch
    /// ids joined by `||`), each branch, and the join.
    #[test]
    fn test_serializable_fork_join() {
        use crate::context::WorkflowContext;
        use crate::task::BranchOutputs;
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .then("prepare", |i: u32| async move { Ok(i) })
            .branches(|b| {
                b.add("branch_a", |i: u32| async move { Ok(i * 2) });
                b.add("branch_b", |i: u32| async move { Ok(i + 10) });
            })
            .join(
                "merge",
                |_: BranchOutputs<DummyCodec>| async move { Ok(0u32) },
            )
            .build()
            .unwrap();
        let serializable = workflow.continuation().to_serializable();
        let task_ids = serializable.task_ids();
        assert!(task_ids.contains(&"prepare"));
        assert!(task_ids.contains(&"branch_a||branch_b"));
        assert!(task_ids.contains(&"branch_a"));
        assert!(task_ids.contains(&"branch_b"));
        assert!(task_ids.contains(&"merge"));
        assert_eq!(task_ids.len(), 5);
    }

    /// `with_registry` registers every inline task and round-trips state.
    #[test]
    fn test_serializable_workflow_builder() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;
        let codec = Arc::new(DummyCodec);
        let ctx = WorkflowContext::new("test-workflow", codec, Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();
        assert!(workflow.registry().contains("step1"));
        assert!(workflow.registry().contains("step2"));
        assert_eq!(workflow.registry().len(), 2);
        let serializable = workflow.to_serializable();
        assert_eq!(serializable.continuation.task_ids(), vec!["step1", "step2"]);
        let hydrated = workflow.to_runnable(&serializable);
        assert!(hydrated.is_ok());
    }

    /// Tasks referenced via `then_registered` come from a pre-built registry.
    #[test]
    fn test_with_existing_registry_and_then_registered() {
        use crate::context::WorkflowContext;
        use crate::registry::TaskRegistry;
        use crate::workflow::SerializableWorkflow;
        use std::sync::Arc;
        let codec = Arc::new(DummyCodec);
        let mut registry = TaskRegistry::new();
        registry.register_fn("double", codec.clone(), |i: u32| async move { Ok(i * 2) });
        registry.register_fn("add_ten", codec.clone(), |i: u32| async move { Ok(i + 10) });
        let ctx = WorkflowContext::new("test-workflow", codec.clone(), Arc::new(()));
        let workflow: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx)
            .with_existing_registry(registry)
            .then_registered::<u32>("double")
            .then_registered::<u32>("add_ten")
            .build()
            .unwrap();
        assert!(workflow.registry().contains("double"));
        assert!(workflow.registry().contains("add_ten"));
        let serializable = workflow.to_serializable();
        assert_eq!(
            serializable.continuation.task_ids(),
            vec!["double", "add_ten"]
        );
        let hydrated = workflow.to_runnable(&serializable);
        assert!(hydrated.is_ok());
    }

    /// Inline tasks are added to an existing registry alongside the
    /// pre-registered ones.
    #[test]
    fn test_mixed_inline_and_registered_tasks() {
        use crate::context::WorkflowContext;
        use crate::registry::TaskRegistry;
        use crate::workflow::SerializableWorkflow;
        use std::sync::Arc;
        let codec = Arc::new(DummyCodec);
        let mut registry = TaskRegistry::new();
        registry.register_fn(
            "preregistered",
            codec.clone(),
            |i: u32| async move { Ok(i * 2) },
        );
        let ctx = WorkflowContext::new("test-workflow", codec.clone(), Arc::new(()));
        let workflow: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx)
            .with_existing_registry(registry)
            .then_registered::<u32>("preregistered")
            .then("inline", |i: u32| async move { Ok(i + 5) })
            .build()
            .unwrap();
        assert!(workflow.registry().contains("preregistered"));
        assert!(workflow.registry().contains("inline"));
        assert_eq!(workflow.registry().len(), 2);
    }

    /// The workflow id and a non-zero definition hash are carried into the
    /// serialized state.
    #[test]
    fn test_workflow_id_and_definition_hash() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;
        let ctx = WorkflowContext::new("my-workflow-id", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();
        assert_eq!(workflow.workflow_id(), "my-workflow-id");
        assert_ne!(
            *workflow.definition_hash(),
            crate::DefinitionHash::from_bytes([0u8; 32])
        );
        let state = workflow.to_serializable();
        assert_eq!(state.workflow_id, "my-workflow-id");
        assert_eq!(&state.definition_hash, workflow.definition_hash());
    }

    /// Adding a task changes the definition hash.
    #[test]
    fn test_definition_hash_changes_with_structure() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;
        let ctx1 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow1 = WorkflowBuilder::new(ctx1)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        let ctx2 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow2 = WorkflowBuilder::new(ctx2)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();
        assert_ne!(workflow1.definition_hash(), workflow2.definition_hash());
    }

    /// State whose hash was tampered with is rejected by `to_runnable`.
    #[test]
    fn test_definition_mismatch_error() {
        use crate::context::WorkflowContext;
        use crate::error::BuildError;
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        let mut state = workflow.to_serializable();
        state.definition_hash = crate::DefinitionHash::sha256(b"wrong-hash");
        let result = workflow.to_runnable(&state);
        assert!(matches!(result, Err(BuildError::DefinitionMismatch { .. })));
    }

    /// A hand-crafted continuation repeating an id is rejected on hydration,
    /// even though every id resolves in the registry.
    #[test]
    fn test_duplicate_id_tampering_detection() {
        use crate::error::BuildError;
        use crate::registry::TaskRegistry;
        use crate::workflow::SerializableContinuation;
        use std::sync::Arc;
        let codec = Arc::new(DummyCodec);
        let mut registry = TaskRegistry::new();
        registry.register_fn("step1", codec.clone(), |i: u32| async move { Ok(i + 1) });
        registry.register_fn("step2", codec.clone(), |i: u32| async move { Ok(i * 2) });
        let tampered = SerializableContinuation::Task {
            id: "step1".to_string(),
            timeout_ms: None,
            retry_policy: None,
            version: None,
            priority: None,
            tags: vec![],
            next: Some(Box::new(SerializableContinuation::Task {
                // Same id as the parent node: this is the tampering under test.
                id: "step1".to_string(),
                timeout_ms: None,
                retry_policy: None,
                version: None,
                priority: None,
                tags: vec![],
                next: None,
            })),
        };
        let result = tampered.to_runnable(&registry);
        assert!(matches!(
            result,
            Err(BuildError::DuplicateTaskId(id)) if id == "step1"
        ));
    }

    /// `delay` inserts a `Delay` node between tasks with the given duration.
    #[test]
    fn test_delay_builder() {
        use crate::context::WorkflowContext;
        use crate::workflow::{Workflow, WorkflowContinuation};
        use std::sync::Arc;
        use std::time::Duration;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait_1s", Duration::from_secs(1))
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();
        let mut ids = vec![];
        let mut current = workflow.continuation();
        loop {
            match current {
                WorkflowContinuation::Task { id, next, .. } => {
                    ids.push(format!("task:{id}"));
                    match next {
                        Some(n) => current = n,
                        None => break,
                    }
                }
                WorkflowContinuation::Delay {
                    id, duration, next, ..
                } => {
                    ids.push(format!("delay:{id}:{}ms", duration.as_millis()));
                    match next {
                        Some(n) => current = n,
                        None => break,
                    }
                }
                _ => break,
            }
        }
        assert_eq!(
            ids,
            vec!["task:step1", "delay:wait_1s:1000ms", "task:step2"]
        );
    }

    /// Delay durations serialize as milliseconds and hydrate back cleanly.
    #[test]
    fn test_delay_serialization_roundtrip() {
        use crate::context::WorkflowContext;
        use crate::workflow::SerializableContinuation;
        use std::sync::Arc;
        use std::time::Duration;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait_5s", Duration::from_secs(5))
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();
        let serializable = workflow.to_serializable();
        let task_ids = serializable.continuation.task_ids();
        assert_eq!(task_ids, vec!["step1", "wait_5s", "step2"]);
        match &serializable.continuation {
            SerializableContinuation::Task { next, .. } => {
                let next = next.as_ref().unwrap();
                match next.as_ref() {
                    SerializableContinuation::Delay {
                        id, duration_ms, ..
                    } => {
                        assert_eq!(id, "wait_5s");
                        assert_eq!(*duration_ms, 5000);
                    }
                    other => panic!("Expected Delay, got {other:?}"),
                }
            }
            other => panic!("Expected Task, got {other:?}"),
        }
        let hydrated = workflow.to_runnable(&serializable);
        assert!(hydrated.is_ok());
    }

    /// A leading delay is reported as the first node id.
    #[test]
    fn test_delay_first_task_id() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;
        use std::time::Duration;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .delay("initial_delay", Duration::from_secs(10))
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        assert_eq!(workflow.continuation().first_task_id(), "initial_delay");
    }

    /// Delay ids share the namespace of task ids for duplicate detection.
    #[test]
    fn test_delay_duplicate_id_detection() {
        use crate::context::WorkflowContext;
        use crate::error::BuildError;
        use std::sync::Arc;
        use std::time::Duration;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let result = WorkflowBuilder::<_, u32, _>::new(ctx)
            .then("dup", |i: u32| async move { Ok(i + 1) })
            .delay("dup", Duration::from_secs(1))
            .build();
        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("expected build error"),
        };
        assert!(
            err.iter()
                .any(|e| matches!(e, BuildError::DuplicateTaskId(id) if id == "dup"))
        );
    }

    /// Changing only a delay's duration changes the definition hash.
    #[test]
    fn test_delay_definition_hash_includes_duration() {
        use crate::context::WorkflowContext;
        use crate::workflow::SerializableWorkflow;
        use std::sync::Arc;
        use std::time::Duration;
        let ctx1 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf1: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx1)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait", Duration::from_secs(1))
            .build()
            .unwrap();
        let ctx2 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf2: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx2)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait", Duration::from_mins(1))
            .build()
            .unwrap();
        assert_ne!(wf1.definition_hash(), wf2.definition_hash());
    }

    /// A delay and a task with the same id hash differently.
    #[test]
    fn test_delay_definition_hash_differs_from_task() {
        use crate::context::WorkflowContext;
        use crate::workflow::SerializableWorkflow;
        use std::sync::Arc;
        use std::time::Duration;
        let ctx1 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf1: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx1)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        let ctx2 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf2: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx2)
            .with_registry()
            .delay("step1", Duration::from_secs(1))
            .build()
            .unwrap();
        assert_ne!(wf1.definition_hash(), wf2.definition_hash());
    }

    /// Delay ids appear in `task_ids` in sequence order.
    #[test]
    fn test_delay_task_ids() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;
        use std::time::Duration;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .then("fetch", |i: u32| async move { Ok(i) })
            .delay("wait_24h", Duration::from_hours(24))
            .then("process", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        let serializable = workflow.continuation().to_serializable();
        let ids = serializable.task_ids();
        assert_eq!(ids, vec!["fetch", "wait_24h", "process"]);
    }

    /// A workflow consisting only of a delay is valid.
    #[test]
    fn test_delay_only_workflow() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;
        use std::time::Duration;
        use crate::workflow::Workflow;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .delay("just_wait", Duration::from_millis(10))
            .build()
            .unwrap();
        assert_eq!(workflow.continuation().first_task_id(), "just_wait");
        let serializable = workflow.continuation().to_serializable();
        assert_eq!(serializable.task_ids(), vec!["just_wait"]);
    }

    /// Hydrating a lone delay needs no registry entries.
    #[test]
    fn test_delay_to_runnable_no_registry_needed() {
        use crate::registry::TaskRegistry;
        use crate::workflow::SerializableContinuation;
        let delay = SerializableContinuation::Delay {
            id: "wait".to_string(),
            duration_ms: 5000,
            next: None,
        };
        let empty_registry = TaskRegistry::new();
        let result = delay.to_runnable(&empty_registry);
        assert!(result.is_ok());
        let runnable = result.unwrap();
        match runnable {
            crate::workflow::WorkflowContinuation::Delay {
                id, duration, next, ..
            } => {
                assert_eq!(id, "wait");
                assert_eq!(duration, std::time::Duration::from_secs(5));
                assert!(next.is_none());
            }
            _ => panic!("Expected Delay variant"),
        }
    }

    /// Task timeouts serialize as `timeout_ms` and hydrate back to a `Duration`.
    #[test]
    fn test_timeout_serialization_roundtrip() {
        use crate::context::WorkflowContext;
        use crate::task::TaskMetadata;
        use crate::workflow::SerializableContinuation;
        use std::sync::Arc;
        use std::time::Duration;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .with_metadata(TaskMetadata {
                timeout: Some(Duration::from_secs(30)),
                ..Default::default()
            })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();
        let serializable = workflow.to_serializable();
        match &serializable.continuation {
            SerializableContinuation::Task { id, timeout_ms, .. } => {
                assert_eq!(id, "step1");
                assert_eq!(*timeout_ms, Some(30_000));
            }
            other => panic!("Expected Task, got {other:?}"),
        }
        let hydrated = workflow.to_runnable(&serializable).unwrap();
        match &hydrated {
            crate::workflow::WorkflowContinuation::Task { id, timeout, .. } => {
                assert_eq!(id, "step1");
                assert_eq!(*timeout, Some(Duration::from_secs(30)));
            }
            _ => panic!("Expected Task variant"),
        }
    }

    /// Adding a timeout changes the definition hash.
    #[test]
    fn test_timeout_changes_definition_hash() {
        use crate::context::WorkflowContext;
        use crate::task::TaskMetadata;
        use crate::workflow::SerializableWorkflow;
        use std::sync::Arc;
        use std::time::Duration;
        let ctx1 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf1: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx1)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        let ctx2 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf2: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx2)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .with_metadata(TaskMetadata {
                timeout: Some(Duration::from_secs(30)),
                ..Default::default()
            })
            .build()
            .unwrap();
        assert_ne!(wf1.definition_hash(), wf2.definition_hash());
    }

    /// An unset timeout is omitted from the JSON form entirely.
    #[test]
    fn test_no_timeout_field_absent_in_serialization() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        let serializable = workflow.to_serializable();
        let json = serde_json::to_string(&serializable.continuation).unwrap();
        assert!(
            !json.contains("timeout_ms"),
            "timeout_ms should be absent when None: {json}"
        );
    }

    /// Task versions participate in the definition hash deterministically.
    #[test]
    fn test_task_version_changes_definition_hash() {
        use crate::context::WorkflowContext;
        use crate::task::TaskMetadata;
        use crate::workflow::SerializableWorkflow;
        use std::sync::Arc;
        let ctx1 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf_no_version: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx1)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        let ctx2 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf_v1: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx2)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .with_metadata(TaskMetadata {
                version: Some("1.0".into()),
                ..Default::default()
            })
            .build()
            .unwrap();
        let ctx3 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf_v2: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx3)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .with_metadata(TaskMetadata {
                version: Some("2.0".into()),
                ..Default::default()
            })
            .build()
            .unwrap();
        let ctx4 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf_v1_again: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx4)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .with_metadata(TaskMetadata {
                version: Some("1.0".into()),
                ..Default::default()
            })
            .build()
            .unwrap();
        assert_ne!(
            wf_no_version.definition_hash(),
            wf_v1.definition_hash(),
            "Adding version should change hash"
        );
        assert_ne!(
            wf_v1.definition_hash(),
            wf_v2.definition_hash(),
            "Different versions should produce different hashes"
        );
        assert_eq!(
            wf_v1.definition_hash(),
            wf_v1_again.definition_hash(),
            "Same version should produce same hash"
        );
    }

    /// An unset version is omitted from the JSON form.
    #[test]
    fn test_version_absent_in_serialization_when_none() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        let serializable = workflow.to_serializable();
        let json = serde_json::to_string(&serializable.continuation).unwrap();
        assert!(
            !json.contains("version"),
            "version should be absent when None: {json}"
        );
    }

    /// A set version is emitted verbatim in the JSON form.
    #[test]
    fn test_version_present_in_serialization_when_set() {
        use crate::context::WorkflowContext;
        use crate::task::TaskMetadata;
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .with_metadata(TaskMetadata {
                version: Some("3.0".into()),
                ..Default::default()
            })
            .build()
            .unwrap();
        let serializable = workflow.to_serializable();
        let json = serde_json::to_string(&serializable.continuation).unwrap();
        assert!(
            json.contains(r#""version":"3.0""#),
            "version should be present in JSON: {json}"
        );
    }

    /// A single task yields one root node with no predecessor.
    #[test]
    fn test_nodes_single_task() {
        use crate::context::WorkflowContext;
        use crate::workflow::{NodeKind, Workflow};
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .then("only", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();
        let nodes: Vec<_> = workflow.iter_nodes().collect();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].id, "only");
        assert_eq!(nodes[0].kind, NodeKind::Task);
        assert!(nodes[0].predecessor_id.is_none());
    }

    /// In a linear chain, each node's predecessor is the previous task.
    #[test]
    fn test_nodes_chain_order() {
        use crate::context::WorkflowContext;
        use crate::workflow::{NodeKind, Workflow};
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .then("a", |i: u32| async move { Ok(i + 1) })
            .then("b", |i: u32| async move { Ok(i + 2) })
            .then("c", |i: u32| async move { Ok(i + 3) })
            .build()
            .unwrap();
        let nodes: Vec<_> = workflow.iter_nodes().collect();
        let ids: Vec<&str> = nodes.iter().map(|n| n.id).collect();
        assert_eq!(ids, vec!["a", "b", "c"]);
        assert!(nodes.iter().all(|n| n.kind == NodeKind::Task));
        assert_eq!(nodes[0].predecessor_id, None);
        assert_eq!(nodes[1].predecessor_id, Some("a"));
        assert_eq!(nodes[2].predecessor_id, Some("b"));
    }

    /// Branches and the join both report the fork node as predecessor.
    #[test]
    fn test_nodes_fork_with_join() {
        use crate::context::WorkflowContext;
        use crate::task::BranchOutputs;
        use crate::workflow::{NodeKind, Workflow};
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .then("prepare", |i: u32| async move { Ok(i) })
            .branches(|b| {
                b.add("left", |i: u32| async move { Ok(i * 2) });
                b.add("right", |i: u32| async move { Ok(i + 10) });
            })
            .join(
                "merge",
                |_: BranchOutputs<DummyCodec>| async move { Ok(0u32) },
            )
            .build()
            .unwrap();
        let nodes: Vec<_> = workflow.iter_nodes().collect();
        let ids: Vec<&str> = nodes.iter().map(|n| n.id).collect();
        assert_eq!(ids[0], "prepare");
        assert_eq!(nodes[1].kind, NodeKind::Fork);
        assert!(ids.contains(&"left"));
        assert!(ids.contains(&"right"));
        assert_eq!(*ids.last().unwrap(), "merge");
        assert_eq!(nodes[1].predecessor_id, Some("prepare"));
        let fork_id = nodes[1].id;
        let left_node = nodes.iter().find(|n| n.id == "left").unwrap();
        let right_node = nodes.iter().find(|n| n.id == "right").unwrap();
        assert_eq!(left_node.predecessor_id, Some(fork_id));
        assert_eq!(right_node.predecessor_id, Some(fork_id));
        let merge_node = nodes.iter().find(|n| n.id == "merge").unwrap();
        assert_eq!(merge_node.predecessor_id, Some(fork_id));
    }

    /// A loop is listed before its body. Both the body and the following task
    /// name the loop node as predecessor.
    #[test]
    fn test_nodes_loop() {
        use crate::context::WorkflowContext;
        use crate::loop_result::LoopResult;
        use crate::workflow::{NodeKind, Workflow};
        use std::sync::Arc;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .loop_task(
                "iterate",
                |i: u32| async move { Ok(LoopResult::Done(i)) },
                5,
            )
            .then("after", |i: u32| async move { Ok(i) })
            .build()
            .unwrap();
        let nodes: Vec<_> = workflow.iter_nodes().collect();
        assert_eq!(nodes[0].kind, NodeKind::Loop);
        assert_eq!(nodes[1].id, "iterate");
        assert_eq!(nodes[1].kind, NodeKind::Task);
        assert_eq!(nodes[2].id, "after");
        assert_eq!(nodes[2].kind, NodeKind::Task);
        assert_eq!(nodes[0].predecessor_id, None);
        // The loop body hangs off the loop node...
        assert_eq!(nodes[1].predecessor_id, Some(nodes[0].id));
        // ...and so does the task that follows the loop.
        assert_eq!(nodes[2].predecessor_id, Some(nodes[0].id));
    }

    /// A delay node exposes its duration through the `timeout` field.
    #[test]
    fn test_nodes_delay_reports_duration_as_timeout() {
        use crate::context::WorkflowContext;
        use crate::workflow::{NodeKind, Workflow};
        use std::sync::Arc;
        use std::time::Duration;
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .delay("wait_5s", Duration::from_secs(5))
            .then("after", |i: u32| async move { Ok(i) })
            .build()
            .unwrap();
        let nodes: Vec<_> = workflow.iter_nodes().collect();
        assert_eq!(nodes[0].id, "wait_5s");
        assert_eq!(nodes[0].kind, NodeKind::Delay);
        assert_eq!(nodes[0].timeout, Some(Duration::from_secs(5)));
        assert_eq!(nodes[0].predecessor_id, None);
        assert_eq!(nodes[1].id, "after");
        assert_eq!(nodes[1].predecessor_id, Some("wait_5s"));
    }

    /// Timeout and retry policy from `TaskMetadata` are surfaced on the node.
    #[test]
    fn test_nodes_metadata_extraction() {
        use crate::context::WorkflowContext;
        use crate::task::{RetryPolicy, TaskMetadata};
        use crate::workflow::NodeKind;
        use std::sync::Arc;
        use std::time::Duration;
        let retry = RetryPolicy {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            max_delay: Some(Duration::from_secs(10)),
        };
        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step", |i: u32| async move { Ok(i) })
            .with_metadata(TaskMetadata {
                timeout: Some(Duration::from_secs(30)),
                retries: Some(retry.clone()),
                version: Some("2.0".into()),
                ..Default::default()
            })
            .build()
            .unwrap();
        let nodes: Vec<_> = workflow.iter_nodes().collect();
        assert_eq!(nodes.len(), 1);
        let node = &nodes[0];
        assert_eq!(node.id, "step");
        assert_eq!(node.kind, NodeKind::Task);
        assert_eq!(node.timeout, Some(Duration::from_secs(30)));
        assert_eq!(node.retry_policy.unwrap().max_retries, 3);
    }
}
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::indexing_slicing,
clippy::too_many_lines,
clippy::items_after_statements
)]
mod proptests {
use super::{MaxIterationsPolicy, SerializableContinuation};
use proptest::prelude::*;
/// Strategy for short node ids: 1 to 8 characters drawn from `[a-z0-9]`.
///
/// Each id is drawn independently, so separate nodes may receive the same
/// id. `arb_unique_continuation` derives ids from a path prefix instead.
fn arb_id() -> impl Strategy<Value = String> {
"[a-z0-9]{1,8}"
}
fn arb_continuation(depth: usize) -> BoxedStrategy<SerializableContinuation> {
let leaf = arb_id().prop_map(|id| SerializableContinuation::Task {
id,
timeout_ms: None,
retry_policy: None,
version: None,
priority: None,
tags: vec![],
next: None,
});
if depth == 0 {
return leaf.boxed();
}
prop_oneof![
(
arb_id(),
prop::option::of(any::<u64>()),
prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
)
.prop_map(|(id, timeout_ms, next)| SerializableContinuation::Task {
id,
timeout_ms,
retry_policy: None,
version: None,
priority: None,
tags: vec![],
next,
}),
(
arb_id(),
prop::collection::vec(arb_continuation(depth - 1), 0..3),
prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
)
.prop_map(|(id, branches, join)| SerializableContinuation::Fork {
id,
branches,
join,
}),
(
arb_id(),
any::<u64>(),
prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
)
.prop_map(|(id, duration_ms, next)| SerializableContinuation::Delay {
id,
duration_ms,
next,
}),
(
arb_id(),
arb_id(),
prop::option::of(any::<u64>()),
prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
)
.prop_map(|(id, signal_name, timeout_ms, next)| {
SerializableContinuation::AwaitSignal {
id,
signal_name,
timeout_ms,
next,
}
}),
(
arb_id(),
prop::collection::hash_map(
arb_id(),
arb_continuation(depth - 1).prop_map(Box::new),
0..3
),
prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
)
.prop_map(|(id, branches, default, next)| {
SerializableContinuation::Branch {
id,
branches,
default,
next,
}
}),
(
arb_id(),
arb_continuation(depth - 1).prop_map(Box::new),
1..100u32,
prop::bool::ANY.prop_map(|b| if b {
MaxIterationsPolicy::Fail
} else {
MaxIterationsPolicy::ExitWithLast
}),
prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
)
.prop_map(|(id, body, max_iterations, on_max, next)| {
SerializableContinuation::Loop {
id,
body,
max_iterations,
on_max,
next,
}
}),
(
arb_id(),
arb_continuation(depth - 1).prop_map(Box::new),
prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
)
.prop_map(|(id, child, next)| {
SerializableContinuation::ChildWorkflow { id, child, next }
}),
]
.boxed()
}
fn arb_unique_continuation(
depth: usize,
prefix: &str,
) -> BoxedStrategy<SerializableContinuation> {
let id = format!("{prefix}n");
if depth == 0 {
return Just(SerializableContinuation::Task {
id,
timeout_ms: None,
retry_policy: None,
version: None,
priority: None,
tags: vec![],
next: None,
})
.boxed();
}
let id_clone = id.clone();
prop_oneof![
prop::option::of(
arb_unique_continuation(depth - 1, &format!("{prefix}0_")).prop_map(Box::new),
)
.prop_map(move |next| SerializableContinuation::Task {
id: id_clone.clone(),
timeout_ms: None,
retry_policy: None,
version: None,
priority: None,
tags: vec![],
next,
}),
{
let id_f = id.clone();
let prefix_f = prefix.to_string();
(0..3u8)
.prop_flat_map(move |branch_count| {
let id_inner = id_f.clone();
let prefix_inner = prefix_f.clone();
let branches: Vec<BoxedStrategy<SerializableContinuation>> = (0
..branch_count)
.map(|i| {
arb_unique_continuation(depth - 1, &format!("{prefix_inner}b{i}_"))
})
.collect();
let join = prop::option::of(
arb_unique_continuation(depth - 1, &format!("{prefix_inner}j_"))
.prop_map(Box::new),
);
(branches, join).prop_map(move |(branches, join)| {
SerializableContinuation::Fork {
id: id_inner.clone(),
branches,
join,
}
})
})
.boxed()
},
{
let id_d = id.clone();
let prefix_d = prefix.to_string();
(
any::<u64>(),
prop::option::of(
arb_unique_continuation(depth - 1, &format!("{prefix_d}d_"))
.prop_map(Box::new),
),
)
.prop_map(move |(duration_ms, next)| {
SerializableContinuation::Delay {
id: id_d.clone(),
duration_ms,
next,
}
})
},
{
let id_s = id.clone();
let prefix_s = prefix.to_string();
(
arb_id(),
prop::option::of(any::<u64>()),
prop::option::of(
arb_unique_continuation(depth - 1, &format!("{prefix_s}s_"))
.prop_map(Box::new),
),
)
.prop_map(move |(signal_name, timeout_ms, next)| {
SerializableContinuation::AwaitSignal {
id: id_s.clone(),
signal_name,
timeout_ms,
next,
}
})
},
{
let id_b = id.clone();
let prefix_b = prefix.to_string();
let b0 = arb_unique_continuation(depth - 1, &format!("{prefix_b}br0_"))
.prop_map(Box::new);
let b1 = arb_unique_continuation(depth - 1, &format!("{prefix_b}br1_"))
.prop_map(Box::new);
let default = prop::option::of(
arb_unique_continuation(depth - 1, &format!("{prefix_b}bd_"))
.prop_map(Box::new),
);
let next = prop::option::of(
arb_unique_continuation(depth - 1, &format!("{prefix_b}bn_"))
.prop_map(Box::new),
);
(b0, b1, default, next).prop_map(move |(branch0, branch1, default, next)| {
let mut branches = std::collections::HashMap::new();
branches.insert("k0".to_string(), branch0);
branches.insert("k1".to_string(), branch1);
SerializableContinuation::Branch {
id: id_b.clone(),
branches,
default,
next,
}
})
},
{
let id_l = id.clone();
let prefix_l = prefix.to_string();
let body = arb_unique_continuation(depth - 1, &format!("{prefix_l}lb_"))
.prop_map(Box::new);
let next = prop::option::of(
arb_unique_continuation(depth - 1, &format!("{prefix_l}ln_"))
.prop_map(Box::new),
);
(
body,
1..100u32,
prop::bool::ANY.prop_map(|b| {
if b {
MaxIterationsPolicy::Fail
} else {
MaxIterationsPolicy::ExitWithLast
}
}),
next,
)
.prop_map(move |(body, max_iterations, on_max, next)| {
SerializableContinuation::Loop {
id: id_l.clone(),
body,
max_iterations,
on_max,
next,
}
})
},
{
let id_cw = id;
let prefix_cw = prefix.to_string();
let child = arb_unique_continuation(depth - 1, &format!("{prefix_cw}cc_"))
.prop_map(Box::new);
let next = prop::option::of(
arb_unique_continuation(depth - 1, &format!("{prefix_cw}cn_"))
.prop_map(Box::new),
);
(child, next).prop_map(move |(child, next)| {
SerializableContinuation::ChildWorkflow {
id: id_cw.clone(),
child,
next,
}
})
},
]
.boxed()
}
/// Returns every node id in `cont`, in depth-first pre-order.
///
/// A node's own id comes before the ids of its children. Children are
/// visited in this order:
/// - `next` for `Task`/`Delay`/`AwaitSignal`;
/// - fork `branches`, then `join`;
/// - branch-map values (in the map's iteration order), then `default`, then `next`;
/// - loop `body`, then `next`;
/// - workflow `child`, then `next`.
fn collect_ids(cont: &SerializableContinuation) -> Vec<String> {
    let mut ids = Vec::new();
    // Explicit work stack instead of recursion; the top entry is visited next.
    let mut stack: Vec<&SerializableContinuation> = vec![cont];
    while let Some(node) = stack.pop() {
        let mut children: Vec<&SerializableContinuation> = Vec::new();
        let own_id = match node {
            SerializableContinuation::Task { id, next, .. }
            | SerializableContinuation::Delay { id, next, .. }
            | SerializableContinuation::AwaitSignal { id, next, .. } => {
                children.extend(next.as_deref());
                id
            }
            SerializableContinuation::Fork { id, branches, join } => {
                for branch in branches {
                    children.push(branch);
                }
                children.extend(join.as_deref());
                id
            }
            SerializableContinuation::Branch {
                id,
                branches,
                default,
                next,
            } => {
                for arm in branches.values() {
                    children.push(arm);
                }
                children.extend(default.as_deref());
                children.extend(next.as_deref());
                id
            }
            SerializableContinuation::Loop { id, body, next, .. } => {
                children.push(body);
                children.extend(next.as_deref());
                id
            }
            SerializableContinuation::ChildWorkflow { id, child, next } => {
                children.push(child);
                children.extend(next.as_deref());
                id
            }
        };
        ids.push(own_id.clone());
        // Push in reverse so the first child is popped next, keeping pre-order.
        stack.extend(children.into_iter().rev());
    }
    ids
}
/// Returns a copy of `cont` whose root node's id is replaced by `dup_id`.
///
/// All other fields of the root, and the whole subtree below it, are cloned
/// unchanged. When `dup_id` is already the id of some descendant, the result
/// therefore contains that id twice.
fn inject_duplicate(cont: &SerializableContinuation, dup_id: &str) -> SerializableContinuation {
    // Clone the node and overwrite only its id. Rebuilding each variant field by
    // field used to reset `Task::priority` and `Task::tags`. It would also silently
    // drop any field added to a variant later.
    let mut tampered = cont.clone();
    match &mut tampered {
        SerializableContinuation::Task { id, .. }
        | SerializableContinuation::Fork { id, .. }
        | SerializableContinuation::Delay { id, .. }
        | SerializableContinuation::AwaitSignal { id, .. }
        | SerializableContinuation::Branch { id, .. }
        | SerializableContinuation::Loop { id, .. }
        | SerializableContinuation::ChildWorkflow { id, .. } => *id = dup_id.to_owned(),
    }
    tampered
}
proptest! {
    // Hashing the same tree twice must yield the same digest.
    #[test]
    fn hash_is_deterministic(cont in arb_continuation(3)) {
        let h1 = cont.compute_definition_hash();
        let h2 = cont.compute_definition_hash();
        prop_assert_eq!(h1, h2);
    }
    // A JSON round-trip must not change the definition hash.
    #[test]
    fn serde_roundtrip_preserves_hash(cont in arb_continuation(3)) {
        let original_hash = cont.compute_definition_hash();
        let json = serde_json::to_string(&cont).unwrap();
        let recovered: SerializableContinuation = serde_json::from_str(&json).unwrap();
        prop_assert_eq!(original_hash, recovered.compute_definition_hash());
    }
    // Trees built with path-derived ids contain no duplicate ids.
    #[test]
    fn unique_ids_means_none(cont in arb_unique_continuation(3, "r_")) {
        prop_assert!(cont.find_duplicate_id().is_none());
    }
    // Renaming the root to one of its descendants' ids must be reported as
    // exactly that id.
    #[test]
    fn injected_duplicate_is_detected(cont in arb_unique_continuation(3, "r_")) {
        let ids = collect_ids(&cont);
        // A single-node tree has no descendant to collide with. Reject the case
        // instead of letting it pass without asserting anything.
        prop_assume!(ids.len() >= 2);
        // `ids[1]` is the root's first descendant in pre-order. All generated
        // ids are distinct, so giving the root that id leaves it as the only
        // duplicated id in the tree, and it is the one that must be reported.
        let dup_id = ids[1].clone();
        let tampered = inject_duplicate(&cont, &dup_id);
        prop_assert_eq!(tampered.find_duplicate_id(), Some(dup_id));
    }
}
}