use crate::{Condition, Signature};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Policy describing how a workflow reacts when a task returns an error.
///
/// `StopOnError` is the `Default` variant. `allows_continue` returns `false`
/// only for `StopOnError`; every other variant reports `true`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum ErrorStrategy {
/// Stop on the first error (by name). Default variant; the only one for
/// which `allows_continue` is `false`.
#[default]
StopOnError,
/// Keep going after a task error (by name).
ContinueOnError,
/// Retry the failing task.
RetryOnError {
/// Maximum number of retries.
max_retries: u32,
/// Optional delay between retries; `None` when built via `retry`.
/// NOTE(review): `Display` renders it with an `s` suffix, so presumably seconds — confirm.
delay: Option<u64>,
},
/// Run an alternative task when the original fails (by name).
Fallback {
/// The alternative task.
fallback: Signature,
},
/// Hand the error to a dedicated handler task (by name).
ErrorHandler {
/// The handler task.
handler: Signature,
},
}
impl ErrorStrategy {
pub fn stop() -> Self {
Self::StopOnError
}
pub fn continue_on_error() -> Self {
Self::ContinueOnError
}
pub fn retry(max_retries: u32) -> Self {
Self::RetryOnError {
max_retries,
delay: None,
}
}
pub fn retry_with_delay(max_retries: u32, delay: u64) -> Self {
Self::RetryOnError {
max_retries,
delay: Some(delay),
}
}
pub fn fallback(task: Signature) -> Self {
Self::Fallback { fallback: task }
}
pub fn error_handler(handler: Signature) -> Self {
Self::ErrorHandler { handler }
}
pub fn allows_continue(&self) -> bool {
!matches!(self, Self::StopOnError)
}
}
impl std::fmt::Display for ErrorStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::StopOnError => write!(f, "StopOnError"),
Self::ContinueOnError => write!(f, "ContinueOnError"),
Self::RetryOnError { max_retries, delay } => {
if let Some(d) = delay {
write!(f, "RetryOnError({} times, {}s delay)", max_retries, d)
} else {
write!(f, "RetryOnError({} times)", max_retries)
}
}
Self::Fallback { fallback } => write!(f, "Fallback({})", fallback.task),
Self::ErrorHandler { handler } => write!(f, "ErrorHandler({})", handler.task),
}
}
}
/// Cancellation request for a workflow, optionally annotated or scoped.
///
/// Built with `CancellationToken::new` plus the builder methods
/// `with_reason`, `cancel_tree` and `cancel_branch`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancellationToken {
/// Workflow the cancellation targets.
pub workflow_id: Uuid,
/// Optional human-readable reason; `None` unless set via `with_reason`.
pub reason: Option<String>,
/// Set to `true` by the `cancel_tree` builder; rendered as `(tree)` by `Display`.
/// NOTE(review): presumably means child workflows are cancelled too — confirm with the consumer.
pub cancel_tree: bool,
/// Optional branch id set via `cancel_branch`.
/// NOTE(review): presumably limits the cancellation to that branch — confirm.
pub branch_id: Option<Uuid>,
}
impl CancellationToken {
pub fn new(workflow_id: Uuid) -> Self {
Self {
workflow_id,
reason: None,
cancel_tree: false,
branch_id: None,
}
}
pub fn with_reason(mut self, reason: String) -> Self {
self.reason = Some(reason);
self
}
pub fn cancel_tree(mut self) -> Self {
self.cancel_tree = true;
self
}
pub fn cancel_branch(mut self, branch_id: Uuid) -> Self {
self.branch_id = Some(branch_id);
self
}
}
impl std::fmt::Display for CancellationToken {
    /// Renders `CancellationToken[workflow=ID]`, then ` reason=…`, ` (tree)`
    /// and ` branch=…` for whichever optional parts are present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CancellationToken[workflow={}]", self.workflow_id)?;
        if let Some(reason) = &self.reason {
            write!(f, " reason={}", reason)?;
        }
        if self.cancel_tree {
            f.write_str(" (tree)")?;
        }
        match self.branch_id {
            Some(branch) => write!(f, " branch={}", branch),
            None => Ok(()),
        }
    }
}
/// Retry configuration applied at the workflow level.
///
/// `calculate_delay` derives the wait before a retry from these fields:
/// `initial_delay` (default 1) alone when no `backoff_factor` is set,
/// otherwise `initial_delay * backoff_factor^attempt` capped at
/// `max_backoff` (default 300).
/// NOTE(review): delay units are not stated here; presumably seconds — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowRetryPolicy {
/// Maximum number of retries.
pub max_retries: u32,
/// Set to `true` by `failed_only()`.
/// NOTE(review): presumably restricts re-runs to failed tasks — confirm with the executor.
pub retry_failed_only: bool,
/// Exponential backoff multiplier; `None` means a constant delay.
pub backoff_factor: Option<f64>,
/// Upper bound on the computed delay; only consulted when `backoff_factor` is set.
pub max_backoff: Option<u64>,
/// Base delay; `calculate_delay` treats `None` as 1.
pub initial_delay: Option<u64>,
}
impl WorkflowRetryPolicy {
    /// Base delay used by `calculate_delay` when `initial_delay` is unset.
    const DEFAULT_INITIAL_DELAY: u64 = 1;
    /// Cap used by `calculate_delay` when backoff is enabled but `max_backoff` is unset.
    const DEFAULT_MAX_BACKOFF: u64 = 300;

    /// Creates a policy allowing `max_retries` retries, with no backoff, no
    /// initial delay, and `retry_failed_only = false`.
    pub fn new(max_retries: u32) -> Self {
        Self {
            max_retries,
            retry_failed_only: false,
            backoff_factor: None,
            max_backoff: None,
            initial_delay: None,
        }
    }

    /// Sets `retry_failed_only` to `true`.
    pub fn failed_only(mut self) -> Self {
        self.retry_failed_only = true;
        self
    }

    /// Enables exponential backoff with multiplier `factor`, capping every
    /// computed delay at `max_delay`.
    pub fn with_backoff(mut self, factor: f64, max_delay: u64) -> Self {
        self.backoff_factor = Some(factor);
        self.max_backoff = Some(max_delay);
        self
    }

    /// Sets the base delay used by `calculate_delay`.
    pub fn with_initial_delay(mut self, delay: u64) -> Self {
        self.initial_delay = Some(delay);
        self
    }

    /// Returns the delay to wait before retry number `attempt` (0-based).
    ///
    /// Without a backoff factor this is the base delay (`initial_delay`,
    /// default 1). With one, it is `base * factor^attempt`, capped at
    /// `max_backoff` (default 300) and truncated toward zero. A negative
    /// result (negative factor) saturates to 0.
    ///
    /// NOTE(review): units are not stated in this type; presumably seconds — confirm.
    pub fn calculate_delay(&self, attempt: u32) -> u64 {
        let base_delay = self.initial_delay.unwrap_or(Self::DEFAULT_INITIAL_DELAY);
        if let Some(factor) = self.backoff_factor {
            // `attempt as i32` would wrap attempts above `i32::MAX` to negative
            // exponents and collapse the delay toward 0. Saturate instead, so
            // very large attempts hit the cap like any other large attempt.
            let exponent = attempt.min(i32::MAX as u32) as i32;
            let delay = (base_delay as f64) * factor.powi(exponent);
            let max = self.max_backoff.unwrap_or(Self::DEFAULT_MAX_BACKOFF);
            delay.min(max as f64) as u64
        } else {
            base_delay
        }
    }
}
impl std::fmt::Display for WorkflowRetryPolicy {
    /// Renders `WorkflowRetryPolicy[max_retries=N]`, followed by
    /// ` (failed_only)` and ` backoff=F` when those options are set.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "WorkflowRetryPolicy[max_retries={}]", self.max_retries)?;
        if self.retry_failed_only {
            f.write_str(" (failed_only)")?;
        }
        self.backoff_factor
            .map_or(Ok(()), |factor| write!(f, " backoff={}", factor))
    }
}
/// Timeout configuration for a workflow run.
///
/// Built with `WorkflowTimeout::new(total)`, which sets `total_timeout` and
/// defaults `escalation` to `TimeoutEscalation::Cancel`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowTimeout {
/// Overall timeout. `Display` renders it with an `s` suffix.
/// NOTE(review): presumably seconds — confirm.
pub total_timeout: Option<u64>,
/// Optional per-stage timeout, set via `with_stage_timeout`.
pub stage_timeout: Option<u64>,
/// Action to take when a timeout fires; see `TimeoutEscalation`.
pub escalation: TimeoutEscalation,
}
/// What happens when a `WorkflowTimeout` is exceeded.
///
/// `WorkflowTimeout::new` selects `Cancel`. The behaviour of each variant is
/// implemented elsewhere; the descriptions below are inferred from the names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TimeoutEscalation {
/// Cancel the workflow (presumably).
Cancel,
/// Mark the workflow as failed (presumably).
Fail,
/// Continue with whatever partial results exist (presumably).
ContinuePartial,
}
impl WorkflowTimeout {
pub fn new(total_timeout: u64) -> Self {
Self {
total_timeout: Some(total_timeout),
stage_timeout: None,
escalation: TimeoutEscalation::Cancel,
}
}
pub fn with_stage_timeout(mut self, timeout: u64) -> Self {
self.stage_timeout = Some(timeout);
self
}
pub fn with_escalation(mut self, escalation: TimeoutEscalation) -> Self {
self.escalation = escalation;
self
}
}
impl std::fmt::Display for WorkflowTimeout {
    /// Renders `WorkflowTimeout[total=Ns stage=Ms escalation=E]`, omitting
    /// unset timeouts.
    ///
    /// A single space is written only between fields that are actually
    /// present, so an unset `total_timeout` does not leave a stray space
    /// right after `[`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WorkflowTimeout[")?;
        // Separator to emit before the next field; empty until a field has been written.
        let mut sep = "";
        if let Some(total) = self.total_timeout {
            write!(f, "total={}s", total)?;
            sep = " ";
        }
        if let Some(stage) = self.stage_timeout {
            write!(f, "{}stage={}s", sep, stage)?;
            sep = " ";
        }
        write!(f, "{}escalation={:?}]", sep, self.escalation)
    }
}
/// Pairs a task signature with a list of input items.
///
/// NOTE(review): presumably the task is run once per item — confirm with the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForEach {
/// Task signature associated with every item.
pub task: Signature,
/// Input items; `len()` / `is_empty()` report on this list.
pub items: Vec<serde_json::Value>,
/// Optional concurrency limit set via `with_concurrency`; `None` when unset.
pub concurrency: Option<usize>,
}
impl ForEach {
pub fn new(task: Signature, items: Vec<serde_json::Value>) -> Self {
Self {
task,
items,
concurrency: None,
}
}
pub fn with_concurrency(mut self, concurrency: usize) -> Self {
self.concurrency = Some(concurrency);
self
}
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
pub fn len(&self) -> usize {
self.items.len()
}
}
impl std::fmt::Display for ForEach {
    /// Renders `ForEach[task=NAME, N items]`, plus ` concurrency=C` when a
    /// limit is set.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.len();
        write!(f, "ForEach[task={}, {} items]", self.task.task, count)?;
        match self.concurrency {
            Some(limit) => write!(f, " concurrency={}", limit),
            None => Ok(()),
        }
    }
}
/// Loop descriptor pairing a `Condition` with a body task signature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WhileLoop {
/// Loop condition; rendered before `->` by `Display`.
pub condition: Condition,
/// Task signature forming the loop body.
pub body: Signature,
/// Iteration cap. `WhileLoop::new` sets `Some(1000)`; `unlimited()` sets `None`.
pub max_iterations: Option<u32>,
}
impl WhileLoop {
    /// Iteration cap applied by [`WhileLoop::new`]. Use
    /// [`WhileLoop::with_max_iterations`] to change it or
    /// [`WhileLoop::unlimited`] to remove it.
    pub const DEFAULT_MAX_ITERATIONS: u32 = 1000;

    /// Creates a loop descriptor from `condition` and `body`, capped at
    /// [`WhileLoop::DEFAULT_MAX_ITERATIONS`] iterations.
    pub fn new(condition: Condition, body: Signature) -> Self {
        Self {
            condition,
            body,
            // Bounded by default; callers must opt out explicitly via `unlimited()`.
            max_iterations: Some(Self::DEFAULT_MAX_ITERATIONS),
        }
    }

    /// Returns the loop with its iteration cap set to `max`.
    pub fn with_max_iterations(mut self, max: u32) -> Self {
        self.max_iterations = Some(max);
        self
    }

    /// Returns the loop with no iteration cap (`max_iterations = None`).
    pub fn unlimited(mut self) -> Self {
        self.max_iterations = None;
        self
    }
}
impl std::fmt::Display for WhileLoop {
    /// Renders `While[CONDITION -> BODY]`, plus ` max=N` when a cap is set.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let body_name = &self.body.task;
        write!(f, "While[{} -> {}]", self.condition, body_name)?;
        self.max_iterations
            .map_or(Ok(()), |limit| write!(f, " max={}", limit))
    }
}
/// Progress snapshot for a single workflow run.
///
/// `WorkflowState::new` starts at `Pending` with zeroed counters, no
/// timestamps and no intermediate results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowState {
/// Workflow this state belongs to.
pub workflow_id: Uuid,
/// Current lifecycle status.
pub status: WorkflowStatus,
/// Expected number of tasks; the denominator of `progress()`.
pub total_tasks: usize,
/// Count of completed tasks; incremented by `mark_completed`.
pub completed_tasks: usize,
/// Count of failed tasks; incremented by `mark_failed`. Not counted by `progress()`.
pub failed_tasks: usize,
/// Start timestamp, if recorded. NOTE(review): presumably Unix seconds — confirm with the writer.
pub start_time: Option<u64>,
/// End timestamp, if recorded. NOTE(review): presumably same unit as `start_time` — confirm.
pub end_time: Option<u64>,
/// Optional label of the current stage.
pub current_stage: Option<String>,
/// Keyed intermediate outputs, accessed via `set_result` / `get_result`.
pub intermediate_results: HashMap<String, serde_json::Value>,
}
/// Lifecycle status of a workflow run.
///
/// `WorkflowState::is_complete` treats `Success`, `Failed` and `Cancelled`
/// as terminal; `Pending`, `Running` and `Paused` are not.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WorkflowStatus {
/// Initial status assigned by `WorkflowState::new`.
Pending,
/// Non-terminal.
Running,
/// Terminal.
Success,
/// Terminal.
Failed,
/// Terminal.
Cancelled,
/// Non-terminal.
Paused,
}
impl WorkflowState {
    /// Creates a `Pending` state for `workflow_id` expecting `total_tasks`
    /// tasks, with zeroed counters, no timestamps, no current stage and no
    /// intermediate results.
    pub fn new(workflow_id: Uuid, total_tasks: usize) -> Self {
        Self {
            workflow_id,
            status: WorkflowStatus::Pending,
            total_tasks,
            completed_tasks: 0,
            failed_tasks: 0,
            start_time: None,
            end_time: None,
            current_stage: None,
            intermediate_results: HashMap::new(),
        }
    }

    /// Percentage of `total_tasks` that have completed, in `0.0..=100.0`.
    ///
    /// Only `completed_tasks` counts; failed tasks do not. A workflow with
    /// zero tasks reports 100%.
    pub fn progress(&self) -> f64 {
        if self.total_tasks == 0 {
            return 100.0;
        }
        let ratio = self.completed_tasks as f64 / self.total_tasks as f64;
        // `mark_completed` is not bounded by `total_tasks`, so clamp rather
        // than report nonsensical values such as 150%.
        (ratio * 100.0).min(100.0)
    }

    /// Returns `true` once the status is terminal: `Success`, `Failed` or `Cancelled`.
    pub fn is_complete(&self) -> bool {
        matches!(
            self.status,
            WorkflowStatus::Success | WorkflowStatus::Failed | WorkflowStatus::Cancelled
        )
    }

    /// Increments `completed_tasks` by one.
    pub fn mark_completed(&mut self) {
        self.completed_tasks += 1;
    }

    /// Increments `failed_tasks` by one.
    pub fn mark_failed(&mut self) {
        self.failed_tasks += 1;
    }

    /// Stores `value` under `key`, replacing any previous value for that key.
    pub fn set_result(&mut self, key: String, value: serde_json::Value) {
        self.intermediate_results.insert(key, value);
    }

    /// Returns the intermediate result stored under `key`, if any.
    pub fn get_result(&self, key: &str) -> Option<&serde_json::Value> {
        self.intermediate_results.get(key)
    }
}
impl std::fmt::Display for WorkflowState {
    /// Renders `WorkflowState[id=…, status=…, progress=P%]` with the progress
    /// to one decimal place, plus ` failed=N` when any task failed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pct = self.progress();
        write!(
            f,
            "WorkflowState[id={}, status={:?}, progress={:.1}%]",
            self.workflow_id, self.status, pct
        )?;
        match self.failed_tasks {
            0 => Ok(()),
            failures => write!(f, " failed={}", failures),
        }
    }
}