// punch_kernel/workflow.rs
1//! Multi-step agent workflow engine with DAG execution.
2//!
3//! The [`WorkflowEngine`] allows registering named workflows composed of
4//! sequential steps or DAG-structured steps with parallel fan-out, conditional
5//! branching, loops, and advanced error handling.
6//!
7//! ## Variable substitution
8//!
9//! Prompt templates support:
10//! - `{{input}}` / `{{previous_output}}` — current pipeline input
11//! - `{{step_name}}` — name of the current step
12//! - `{{step_N}}` — output of step N (1-indexed, sequential mode)
13//! - `{{some_step_name}}` — output of a step by name
14//! - `{{step_name.output}}` — explicit step output reference
15//! - `{{step_name.status}}` — step completion status
16//! - `{{step_name.duration_ms}}` — step duration
17//! - `{{loop.index}}` — current loop iteration
18//! - `{{loop.item}}` — current loop item (ForEach)
19//! - `{{step_name.output.field.nested}}` — JSON path into step output
20//! - `{{step_name.output | uppercase}}` — data transformation
21
22use std::collections::HashMap;
23use std::sync::Arc;
24use std::time::Instant;
25
26use chrono::{DateTime, Utc};
27use dashmap::DashMap;
28use serde::{Deserialize, Serialize};
29use tracing::{debug, error, info, instrument, warn};
30use uuid::Uuid;
31
32use punch_memory::MemorySubstrate;
33use punch_runtime::{FighterLoopParams, LlmDriver, run_fighter_loop, tools_for_capabilities};
34use punch_types::{FighterId, FighterManifest, ModelConfig, PunchError, PunchResult, WeightClass};
35
36use crate::workflow_conditions::{Condition, evaluate_condition};
37use crate::workflow_loops::{LoopConfig, LoopState, calculate_backoff, parse_foreach_items};
38use crate::workflow_validation::{ValidationError, topological_sort, validate_workflow};
39
40// ---------------------------------------------------------------------------
41// ID types
42// ---------------------------------------------------------------------------
43
44/// Unique identifier for a workflow definition.
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
46#[serde(transparent)]
47pub struct WorkflowId(pub Uuid);
48
49impl WorkflowId {
50    pub fn new() -> Self {
51        Self(Uuid::new_v4())
52    }
53}
54
55impl Default for WorkflowId {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl std::fmt::Display for WorkflowId {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        write!(f, "{}", self.0)
64    }
65}
66
67/// Unique identifier for a workflow run (execution instance).
68#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
69#[serde(transparent)]
70pub struct WorkflowRunId(pub Uuid);
71
72impl WorkflowRunId {
73    pub fn new() -> Self {
74        Self(Uuid::new_v4())
75    }
76}
77
78impl Default for WorkflowRunId {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl std::fmt::Display for WorkflowRunId {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        write!(f, "{}", self.0)
87    }
88}
89
90// ---------------------------------------------------------------------------
91// Workflow types
92// ---------------------------------------------------------------------------
93
94/// What to do when a workflow step fails.
95#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
96#[serde(rename_all = "snake_case")]
97#[derive(Default)]
98pub enum OnError {
99    /// Abort the entire workflow.
100    #[default]
101    FailWorkflow,
102    /// Skip the failed step and continue.
103    SkipStep,
104    /// Retry the step once, then fail if it fails again.
105    RetryOnce,
106    /// On error, run a fallback step instead.
107    Fallback { step: String },
108    /// Run an error handler step, then continue the workflow.
109    CatchAndContinue { error_handler: String },
110    /// Stop trying after N consecutive failures, with a cooldown.
111    CircuitBreaker {
112        max_failures: usize,
113        cooldown_secs: u64,
114    },
115}
116
/// Per-step execution status.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepStatus {
    /// Not yet scheduled to run.
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished with an error.
    Failed,
    /// Did not run (its condition was false, or an error policy skipped it).
    Skipped,
    /// Never ran because upstream dependencies were not satisfied.
    Cancelled,
}
128
129impl std::fmt::Display for StepStatus {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        match self {
132            Self::Pending => write!(f, "pending"),
133            Self::Running => write!(f, "running"),
134            Self::Completed => write!(f, "completed"),
135            Self::Failed => write!(f, "failed"),
136            Self::Skipped => write!(f, "skipped"),
137            Self::Cancelled => write!(f, "cancelled"),
138        }
139    }
140}
141
/// A single step within a sequential workflow (legacy format, still supported).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowStep {
    /// Human-readable name for this step.
    pub name: String,
    /// The fighter name to use for this step.
    pub fighter_name: String,
    /// Prompt template with variable substitution (see module docs for syntax).
    pub prompt_template: String,
    /// Maximum time in seconds for this step (default 120).
    pub timeout_secs: Option<u64>,
    /// Error handling strategy; defaults to [`OnError::FailWorkflow`]
    /// when omitted from serialized definitions.
    #[serde(default)]
    pub on_error: OnError,
}
157
/// A single step within a DAG workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DagWorkflowStep {
    /// Human-readable name for this step (must be unique within the workflow).
    pub name: String,
    /// The fighter name to use for this step.
    pub fighter_name: String,
    /// Prompt template with variable substitution (see module docs for syntax).
    pub prompt_template: String,
    /// Maximum time in seconds for this step (default 120).
    pub timeout_secs: Option<u64>,
    /// Error handling strategy; defaults to [`OnError::FailWorkflow`].
    #[serde(default)]
    pub on_error: OnError,
    /// Steps that must complete before this one runs. Empty means this is a
    /// root step and can run in the first wave.
    #[serde(default)]
    pub depends_on: Vec<String>,
    /// Optional condition — step is skipped if condition evaluates to false.
    #[serde(default)]
    pub condition: Option<Condition>,
    /// If condition is false, run this step instead (if/else branching).
    #[serde(default)]
    pub else_step: Option<String>,
    /// Optional loop configuration.
    #[serde(default)]
    pub loop_config: Option<LoopConfig>,
}
185
186impl DagWorkflowStep {
187    /// Extract the fallback step name from the on_error strategy, if any.
188    pub fn fallback_step(&self) -> Option<String> {
189        match &self.on_error {
190            OnError::Fallback { step } => Some(step.clone()),
191            OnError::CatchAndContinue { error_handler } => Some(error_handler.clone()),
192            _ => None,
193        }
194    }
195}
196
/// A workflow definition composed of sequential steps (legacy).
///
/// For parallel fan-out, branching, and loops use [`DagWorkflow`] instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Workflow {
    /// Unique identifier.
    pub id: WorkflowId,
    /// Human-readable name.
    pub name: String,
    /// Ordered steps to execute, first to last.
    pub steps: Vec<WorkflowStep>,
}
207
/// A DAG workflow definition with parallel execution support.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DagWorkflow {
    /// Unique identifier.
    pub id: WorkflowId,
    /// Human-readable name.
    pub name: String,
    /// DAG steps (order in vec doesn't matter — execution order is determined
    /// by each step's `depends_on` dependencies).
    pub steps: Vec<DagWorkflowStep>,
}
218
/// Status of a workflow run.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WorkflowRunStatus {
    /// Created but not yet started.
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Completed,
    /// Aborted with an error.
    Failed,
    /// Some branches succeeded, some failed.
    PartiallyCompleted,
}
230
231impl std::fmt::Display for WorkflowRunStatus {
232    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233        match self {
234            Self::Pending => write!(f, "pending"),
235            Self::Running => write!(f, "running"),
236            Self::Completed => write!(f, "completed"),
237            Self::Failed => write!(f, "failed"),
238            Self::PartiallyCompleted => write!(f, "partially_completed"),
239        }
240    }
241}
242
/// Result of executing a single workflow step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
    /// Name of the step.
    pub step_name: String,
    /// The response text from the fighter (empty for skipped/cancelled steps).
    pub response: String,
    /// Tokens consumed.
    pub tokens_used: u64,
    /// Duration in milliseconds.
    pub duration_ms: u64,
    /// Error message, if any.
    pub error: Option<String>,
    /// Per-step status. Records serialized before this field existed
    /// deserialize as [`StepStatus::Pending`] via the serde default.
    #[serde(default = "default_step_status")]
    pub status: StepStatus,
    /// When the step started executing (`None` for steps that never ran).
    #[serde(default)]
    pub started_at: Option<DateTime<Utc>>,
    /// When the step finished executing.
    #[serde(default)]
    pub completed_at: Option<DateTime<Utc>>,
}
266
/// Serde default for [`StepResult::status`], used when deserializing records
/// that lack the field.
fn default_step_status() -> StepStatus {
    StepStatus::Pending
}
270
/// A failed step result stored in the dead letter queue, preserving enough
/// context (input + error) to retry or inspect the failure later.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeadLetterEntry {
    /// The step name that failed.
    pub step_name: String,
    /// The error message.
    pub error: String,
    /// The input that was provided to the step.
    pub input: String,
    /// When the failure occurred.
    pub failed_at: DateTime<Utc>,
}
283
/// A single execution of a workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowRun {
    /// Unique run identifier.
    pub id: WorkflowRunId,
    /// The workflow that was executed.
    pub workflow_id: WorkflowId,
    /// Current status.
    pub status: WorkflowRunStatus,
    /// Results of each completed step.
    pub step_results: Vec<StepResult>,
    /// When the run started.
    pub started_at: DateTime<Utc>,
    /// When the run completed (or failed). `None` while still in flight.
    pub completed_at: Option<DateTime<Utc>>,
    /// Dead letter queue for failed steps. Defaults to empty for runs
    /// serialized before this field existed.
    #[serde(default)]
    pub dead_letters: Vec<DeadLetterEntry>,
    /// Execution trace showing which steps ran in parallel.
    #[serde(default)]
    pub execution_trace: Vec<ExecutionTraceEntry>,
}
306
/// An entry in the execution trace showing what happened at each "wave" of
/// execution (a batch of steps whose dependencies were all satisfied).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionTraceEntry {
    /// Steps that executed in this wave (parallel batch).
    pub steps: Vec<String>,
    /// When this wave started.
    pub started_at: DateTime<Utc>,
    /// When this wave completed.
    pub completed_at: Option<DateTime<Utc>>,
}
317
318// ---------------------------------------------------------------------------
319// Variable substitution
320// ---------------------------------------------------------------------------
321
322/// Replace template variables in a prompt string (sequential mode).
323///
324/// Supported variables:
325/// - `{{input}}` — the current input (original input or previous step's output)
326/// - `{{previous_output}}` — alias for `{{input}}`
327/// - `{{step_name}}` — the name of the current step
328/// - `{{step_1}}` / `{{step_N}}` — output of step N (1-indexed)
329/// - `{{some_step_name}}` — output of a step referenced by its name
330fn expand_variables(
331    template: &str,
332    current_input: &str,
333    step_name: &str,
334    step_results: &[StepResult],
335) -> String {
336    let mut result = template.to_string();
337
338    // {{input}} and {{previous_output}} both resolve to the current pipeline input
339    result = result.replace("{{input}}", current_input);
340    result = result.replace("{{previous_output}}", current_input);
341
342    // {{step_name}} resolves to the current step's name
343    result = result.replace("{{step_name}}", step_name);
344
345    // {{step_N}} resolves to the output of the Nth step (1-indexed)
346    for (i, sr) in step_results.iter().enumerate() {
347        let var = format!("{{{{step_{}}}}}", i + 1);
348        result = result.replace(&var, &sr.response);
349    }
350
351    // {{step_result_name}} resolves to the output of a step by name
352    for sr in step_results {
353        let var = format!("{{{{{}}}}}", sr.step_name);
354        result = result.replace(&var, &sr.response);
355    }
356
357    result
358}
359
/// Replace template variables in a prompt string (DAG mode).
///
/// Supports all the sequential variables plus:
/// - `{{step_name.output}}` — explicit output reference
/// - `{{step_name.status}}` — step status
/// - `{{step_name.duration_ms}}` — step duration
/// - `{{loop.index}}` — current loop iteration
/// - `{{loop.item}}` — current loop item
/// - `{{step_name.output.field.nested}}` — JSON path
/// - `{{step_name.output | uppercase}}` — transformations
///
/// Expansion runs in two passes: plain substring replacement for the fixed
/// variables, then a left-to-right scan that hands every remaining `{{...}}`
/// token to [`resolve_dag_variable`]. Unknown references are echoed back with
/// their braces intact so they stay visible in the rendered prompt.
pub fn expand_dag_variables(
    template: &str,
    current_input: &str,
    step_name: &str,
    step_results: &HashMap<String, StepResult>,
    loop_state: Option<&LoopState>,
) -> String {
    let mut result = template.to_string();

    // Pass 1: fixed variables via plain substring replacement.
    result = result.replace("{{input}}", current_input);
    result = result.replace("{{previous_output}}", current_input);
    result = result.replace("{{step_name}}", step_name);

    // Loop variables are only substituted while iterating; outside a loop the
    // tokens fall through to pass 2, which echoes them back unresolved.
    if let Some(ls) = loop_state {
        result = result.replace("{{loop.index}}", &ls.index.to_string());
        if let Some(ref item) = ls.item {
            result = result.replace("{{loop.item}}", item);
        }
    }

    // Pass 2: scan for the remaining {{...}} patterns and resolve each one.
    let mut output = String::with_capacity(result.len());
    let mut remaining = result.as_str();

    while let Some(start) = remaining.find("{{") {
        // Copy the literal text preceding the opening braces.
        output.push_str(&remaining[..start]);
        let after_start = &remaining[start + 2..];
        if let Some(end) = after_start.find("}}") {
            let var_content = &after_start[..end];
            let resolved = resolve_dag_variable(var_content, step_results);
            output.push_str(&resolved);
            remaining = &after_start[end + 2..];
        } else {
            // Unmatched "{{": keep it verbatim and continue scanning after it.
            output.push_str("{{");
            remaining = after_start;
        }
    }
    output.push_str(remaining);

    output
}
414
415/// Resolve a single variable expression like `step_name.output` or `step_name.output | uppercase`.
416fn resolve_dag_variable(var: &str, step_results: &HashMap<String, StepResult>) -> String {
417    // Check for pipe transformation: `expr | transform`
418    let (expr, transform) = if let Some(pipe_pos) = var.find(" | ") {
419        let expr = var[..pipe_pos].trim();
420        let transform = var[pipe_pos + 3..].trim();
421        (expr, Some(transform))
422    } else {
423        (var.trim(), None)
424    };
425
426    // Resolve the expression
427    let value = resolve_dag_expression(expr, step_results);
428
429    // Apply transformation if present
430    match transform {
431        Some("uppercase") => value.to_uppercase(),
432        Some("lowercase") => value.to_lowercase(),
433        Some("trim") => value.trim().to_string(),
434        Some("len") | Some("length") => value.len().to_string(),
435        Some(t) if t.starts_with("json_extract ") => {
436            let path = t
437                .strip_prefix("json_extract ")
438                .unwrap_or("")
439                .trim_matches('"');
440            json_path_extract(&value, path)
441        }
442        _ => value,
443    }
444}
445
446/// Resolve a dotted expression like `step_name.output.field.nested`.
447fn resolve_dag_expression(expr: &str, step_results: &HashMap<String, StepResult>) -> String {
448    let parts: Vec<&str> = expr.splitn(2, '.').collect();
449    if parts.len() < 2 {
450        // Plain step name reference
451        return step_results
452            .get(parts[0])
453            .map(|r| r.response.clone())
454            .unwrap_or_else(|| format!("{{{{{expr}}}}}"));
455    }
456
457    let step_name = parts[0];
458    let property = parts[1];
459
460    let step_result = match step_results.get(step_name) {
461        Some(r) => r,
462        None => return format!("{{{{{expr}}}}}"),
463    };
464
465    match property {
466        "output" => step_result.response.clone(),
467        "status" => step_result.status.to_string(),
468        "duration_ms" => step_result.duration_ms.to_string(),
469        "error" => step_result
470            .error
471            .clone()
472            .unwrap_or_else(|| "none".to_string()),
473        _ if property.starts_with("output.") => {
474            let json_path = property.strip_prefix("output.").unwrap_or("");
475            json_path_extract(&step_result.response, json_path)
476        }
477        _ => format!("{{{{{expr}}}}}"),
478    }
479}
480
481/// Extract a value from a JSON string using a dot-separated path.
482///
483/// Supports paths like `field`, `field.nested`, `$.key` (strips leading `$.`).
484fn json_path_extract(json_str: &str, path: &str) -> String {
485    let path = path.strip_prefix("$.").unwrap_or(path);
486    let parsed: serde_json::Value = match serde_json::from_str(json_str) {
487        Ok(v) => v,
488        Err(_) => return json_str.to_string(),
489    };
490
491    let mut current = &parsed;
492    for segment in path.split('.') {
493        if segment.is_empty() {
494            continue;
495        }
496        match current.get(segment) {
497            Some(v) => current = v,
498            None => return String::new(),
499        }
500    }
501
502    match current {
503        serde_json::Value::String(s) => s.clone(),
504        other => other.to_string(),
505    }
506}
507
508// ---------------------------------------------------------------------------
509// Circuit breaker state
510// ---------------------------------------------------------------------------
511
/// Tracks circuit breaker state per-step across workflow runs.
#[derive(Debug, Clone, Default)]
pub struct CircuitBreakerState {
    /// Number of consecutive failures.
    pub consecutive_failures: usize,
    /// When the circuit was last tripped (entered open state).
    pub last_trip_time: Option<Instant>,
}

impl CircuitBreakerState {
    /// Check if the circuit is currently open (blocking execution).
    ///
    /// Open means `max_failures` consecutive failures have accumulated and
    /// fewer than `cooldown_secs` seconds have passed since the last one.
    /// Once the cooldown elapses the circuit allows another attempt.
    pub fn is_open(&self, max_failures: usize, cooldown_secs: u64) -> bool {
        if self.consecutive_failures < max_failures {
            return false;
        }
        // Tripped: stay open until the cooldown window elapses. A missing
        // trip timestamp (only possible via manual construction) keeps the
        // circuit open.
        self.last_trip_time
            .map_or(true, |tripped| tripped.elapsed().as_secs() < cooldown_secs)
    }

    /// Record a failure, bumping the counter and refreshing the trip time.
    pub fn record_failure(&mut self) {
        self.consecutive_failures += 1;
        self.last_trip_time = Some(Instant::now());
    }

    /// Record a success, fully closing the circuit.
    pub fn record_success(&mut self) {
        *self = Self::default();
    }
}
546
547// ---------------------------------------------------------------------------
548// DAG Executor (testable without LLM)
549// ---------------------------------------------------------------------------
550
/// A step executor trait that allows testing the DAG engine without real LLM calls.
#[async_trait::async_trait]
pub trait StepExecutor: Send + Sync {
    /// Execute a single step and return its result.
    ///
    /// `step_results` holds the outputs of previously completed steps keyed by
    /// step name (used for variable substitution); `loop_state` carries the
    /// current iteration context when the step runs inside a loop.
    ///
    /// On failure, returns a human-readable error string; the DAG engine
    /// applies the step's `on_error` policy to decide how to proceed.
    async fn execute(
        &self,
        step: &DagWorkflowStep,
        input: &str,
        step_results: &HashMap<String, StepResult>,
        loop_state: Option<&LoopState>,
    ) -> Result<StepResult, String>;
}
563
564/// Execute a DAG workflow using the provided step executor.
565///
566/// This is the core DAG execution engine. Steps with no dependencies (roots) run
567/// first. When a step completes, any step whose dependencies are now all satisfied
568/// is scheduled. Steps with no mutual dependencies run concurrently using
569/// `tokio::task::JoinSet` for true multi-threaded parallelism.
570pub async fn execute_dag(
571    workflow_name: &str,
572    steps: &[DagWorkflowStep],
573    input: &str,
574    executor: Arc<dyn StepExecutor>,
575) -> DagExecutionResult {
576    // Validate first
577    let validation_errors = validate_workflow(steps);
578    if !validation_errors.is_empty() {
579        return DagExecutionResult {
580            status: WorkflowRunStatus::Failed,
581            step_results: HashMap::new(),
582            dead_letters: Vec::new(),
583            execution_trace: Vec::new(),
584            validation_errors,
585        };
586    }
587
588    // Get topological order
589    let topo_order = match topological_sort(steps) {
590        Ok(order) => order,
591        Err(_) => {
592            return DagExecutionResult {
593                status: WorkflowRunStatus::Failed,
594                step_results: HashMap::new(),
595                dead_letters: Vec::new(),
596                execution_trace: Vec::new(),
597                validation_errors: vec![ValidationError::CycleDetected {
598                    steps: steps.iter().map(|s| s.name.clone()).collect(),
599                }],
600            };
601        }
602    };
603
604    let step_map: HashMap<&str, &DagWorkflowStep> =
605        steps.iter().map(|s| (s.name.as_str(), s)).collect();
606
607    let mut completed: HashMap<String, StepResult> = HashMap::new();
608    let mut dead_letters: Vec<DeadLetterEntry> = Vec::new();
609    let mut execution_trace: Vec<ExecutionTraceEntry> = Vec::new();
610    let mut circuit_breakers: HashMap<String, CircuitBreakerState> = HashMap::new();
611    let mut skipped_steps: std::collections::HashSet<String> = std::collections::HashSet::new();
612    let mut failed_steps: std::collections::HashSet<String> = std::collections::HashSet::new();
613
614    // Process in waves: each wave contains steps whose dependencies are all satisfied
615    let mut remaining: Vec<String> = topo_order;
616
617    while !remaining.is_empty() {
618        // Find all steps that can run now (all deps satisfied)
619        let (ready, not_ready): (Vec<String>, Vec<String>) =
620            remaining.into_iter().partition(|name| {
621                let step = match step_map.get(name.as_str()) {
622                    Some(s) => s,
623                    None => return false,
624                };
625                step.depends_on.iter().all(|dep| {
626                    // A dependency is satisfied if it completed (not in failed_steps)
627                    // or was explicitly skipped/handled
628                    let is_done = completed.contains_key(dep) || skipped_steps.contains(dep);
629                    let is_blocking_failure = failed_steps.contains(dep);
630                    is_done && !is_blocking_failure
631                })
632            });
633
634        if ready.is_empty() {
635            // No progress possible — remaining steps have unmet deps (likely due to failures)
636            for name in &not_ready {
637                skipped_steps.insert(name.clone());
638                completed.insert(
639                    name.clone(),
640                    StepResult {
641                        step_name: name.clone(),
642                        response: String::new(),
643                        tokens_used: 0,
644                        duration_ms: 0,
645                        error: Some("cancelled: unmet dependencies".to_string()),
646                        status: StepStatus::Cancelled,
647                        started_at: None,
648                        completed_at: None,
649                    },
650                );
651            }
652            break;
653        }
654
655        remaining = not_ready;
656
657        let wave_start = Utc::now();
658        let wave_step_names: Vec<String> = ready.to_vec();
659
660        // Execute all ready steps concurrently using tokio::task::JoinSet
661        // for true multi-threaded parallelism.
662        let mut wave_results: Vec<(String, Result<StepResult, String>, Option<String>)> =
663            Vec::new();
664        let mut join_set: tokio::task::JoinSet<(
665            String,
666            Result<StepResult, String>,
667            Option<String>,
668        )> = tokio::task::JoinSet::new();
669
670        for step_name in &wave_step_names {
671            let step = match step_map.get(step_name.as_str()) {
672                Some(s) => (*s).clone(),
673                None => continue,
674            };
675
676            // Check condition
677            let should_run = match &step.condition {
678                Some(cond) => evaluate_condition(cond, &completed),
679                None => true,
680            };
681
682            if !should_run {
683                let else_step_name = step.else_step.clone();
684                wave_results.push((
685                    step_name.clone(),
686                    Ok(StepResult {
687                        step_name: step_name.clone(),
688                        response: String::new(),
689                        tokens_used: 0,
690                        duration_ms: 0,
691                        error: None,
692                        status: StepStatus::Skipped,
693                        started_at: Some(Utc::now()),
694                        completed_at: Some(Utc::now()),
695                    }),
696                    else_step_name,
697                ));
698                continue;
699            }
700
701            // Check circuit breaker
702            let cb_state = circuit_breakers
703                .entry(step_name.clone())
704                .or_default()
705                .clone();
706            if let OnError::CircuitBreaker {
707                max_failures,
708                cooldown_secs,
709            } = &step.on_error
710                && cb_state.is_open(*max_failures, *cooldown_secs)
711            {
712                wave_results.push((
713                    step_name.clone(),
714                    Ok(StepResult {
715                        step_name: step_name.clone(),
716                        response: String::new(),
717                        tokens_used: 0,
718                        duration_ms: 0,
719                        error: Some("circuit breaker open".to_string()),
720                        status: StepStatus::Failed,
721                        started_at: Some(Utc::now()),
722                        completed_at: Some(Utc::now()),
723                    }),
724                    None,
725                ));
726                continue;
727            }
728
729            let sn = step_name.clone();
730            let completed_snapshot = completed.clone();
731            let input_clone = input.to_string();
732            let executor_clone = Arc::clone(&executor);
733
734            join_set.spawn(async move {
735                let result = execute_step_with_loops(
736                    &step,
737                    &input_clone,
738                    &completed_snapshot,
739                    executor_clone.as_ref(),
740                )
741                .await;
742                (sn, result, None::<String>)
743            });
744        }
745
746        // Wait for all spawned tasks to complete
747        while let Some(join_result) = join_set.join_next().await {
748            match join_result {
749                Ok(task_result) => wave_results.push(task_result),
750                Err(join_err) => {
751                    // A JoinError means the task panicked or was cancelled
752                    error!(error = %join_err, "spawned step task failed unexpectedly");
753                }
754            }
755        }
756
757        // Process results
758        for (step_name, result, _else_step) in wave_results {
759            match result {
760                Ok(mut step_result) => {
761                    if step_result.status == StepStatus::Skipped {
762                        skipped_steps.insert(step_name.clone());
763                        debug!(step = %step_name, workflow = %workflow_name, "step skipped (condition false)");
764                    } else if step_result.error.is_some() {
765                        failed_steps.insert(step_name.clone());
766                        // Update circuit breaker
767                        circuit_breakers
768                            .entry(step_name.clone())
769                            .or_default()
770                            .record_failure();
771
772                        let step = step_map.get(step_name.as_str());
773                        if let Some(step) = step {
774                            match &step.on_error {
775                                OnError::Fallback { step: fb_step } => {
776                                    // Try to execute fallback
777                                    if let Some(fb) = step_map.get(fb_step.as_str()) {
778                                        let fb_result =
779                                            executor.execute(fb, input, &completed, None).await;
780                                        match fb_result {
781                                            Ok(fb_res) => {
782                                                step_result = fb_res;
783                                                step_result.step_name = step_name.clone();
784                                                failed_steps.remove(&step_name);
785                                            }
786                                            Err(fb_err) => {
787                                                dead_letters.push(DeadLetterEntry {
788                                                    step_name: step_name.clone(),
789                                                    error: fb_err,
790                                                    input: input.to_string(),
791                                                    failed_at: Utc::now(),
792                                                });
793                                            }
794                                        }
795                                    }
796                                }
797                                OnError::CatchAndContinue { error_handler } => {
798                                    // Run the error handler
799                                    if let Some(handler) = step_map.get(error_handler.as_str()) {
800                                        let _ = executor
801                                            .execute(handler, input, &completed, None)
802                                            .await;
803                                    }
804                                    // Continue anyway — mark as handled
805                                    failed_steps.remove(&step_name);
806                                }
807                                OnError::SkipStep => {
808                                    skipped_steps.insert(step_name.clone());
809                                    failed_steps.remove(&step_name);
810                                }
811                                OnError::FailWorkflow => {
812                                    dead_letters.push(DeadLetterEntry {
813                                        step_name: step_name.clone(),
814                                        error: step_result.error.clone().unwrap_or_default(),
815                                        input: input.to_string(),
816                                        failed_at: Utc::now(),
817                                    });
818                                }
819                                _ => {}
820                            }
821                        }
822                    } else {
823                        // Success
824                        circuit_breakers
825                            .entry(step_name.clone())
826                            .or_default()
827                            .record_success();
828                        info!(step = %step_name, workflow = %workflow_name, "DAG step completed");
829                    }
830                    completed.insert(step_name, step_result);
831                }
832                Err(e) => {
833                    failed_steps.insert(step_name.clone());
834                    circuit_breakers
835                        .entry(step_name.clone())
836                        .or_default()
837                        .record_failure();
838
839                    let mut step_result = StepResult {
840                        step_name: step_name.clone(),
841                        response: String::new(),
842                        tokens_used: 0,
843                        duration_ms: 0,
844                        error: Some(e.clone()),
845                        status: StepStatus::Failed,
846                        started_at: Some(Utc::now()),
847                        completed_at: Some(Utc::now()),
848                    };
849
850                    // Try error recovery strategies
851                    let step = step_map.get(step_name.as_str());
852                    if let Some(step) = step {
853                        match &step.on_error {
854                            OnError::Fallback { step: fb_step } => {
855                                if let Some(fb) = step_map.get(fb_step.as_str())
856                                    && let Ok(fb_res) =
857                                        executor.execute(fb, input, &completed, None).await
858                                {
859                                    step_result = fb_res;
860                                    step_result.step_name = step_name.clone();
861                                    step_result.error = None;
862                                    step_result.status = StepStatus::Completed;
863                                    failed_steps.remove(&step_name);
864                                }
865                            }
866                            OnError::CatchAndContinue { error_handler } => {
867                                if let Some(handler) = step_map.get(error_handler.as_str()) {
868                                    let _ =
869                                        executor.execute(handler, input, &completed, None).await;
870                                }
871                                failed_steps.remove(&step_name);
872                            }
873                            OnError::SkipStep => {
874                                step_result.status = StepStatus::Skipped;
875                                skipped_steps.insert(step_name.clone());
876                                failed_steps.remove(&step_name);
877                            }
878                            OnError::FailWorkflow => {
879                                dead_letters.push(DeadLetterEntry {
880                                    step_name: step_name.clone(),
881                                    error: e,
882                                    input: input.to_string(),
883                                    failed_at: Utc::now(),
884                                });
885                            }
886                            _ => {
887                                dead_letters.push(DeadLetterEntry {
888                                    step_name: step_name.clone(),
889                                    error: e,
890                                    input: input.to_string(),
891                                    failed_at: Utc::now(),
892                                });
893                            }
894                        }
895                    } else {
896                        dead_letters.push(DeadLetterEntry {
897                            step_name: step_name.clone(),
898                            error: e,
899                            input: input.to_string(),
900                            failed_at: Utc::now(),
901                        });
902                    }
903
904                    completed.insert(step_name, step_result);
905                }
906            }
907        }
908
909        execution_trace.push(ExecutionTraceEntry {
910            steps: wave_step_names,
911            started_at: wave_start,
912            completed_at: Some(Utc::now()),
913        });
914    }
915
916    // Determine final status
917    let has_failures = completed.values().any(|r| r.status == StepStatus::Failed);
918    let has_successes = completed
919        .values()
920        .any(|r| r.status == StepStatus::Completed);
921
922    let status = if has_failures && has_successes {
923        WorkflowRunStatus::PartiallyCompleted
924    } else if has_failures {
925        WorkflowRunStatus::Failed
926    } else {
927        WorkflowRunStatus::Completed
928    };
929
930    DagExecutionResult {
931        status,
932        step_results: completed,
933        dead_letters,
934        execution_trace,
935        validation_errors: Vec::new(),
936    }
937}
938
/// Execute a step, handling loop configurations.
///
/// Dispatches on `step.loop_config`:
/// - `None` — run the step once through `executor`.
/// - `ForEach` — iterate over items parsed from a prior step's output, running
///   the step once per item (bounded by `max_iterations`), and join the
///   per-iteration outputs with newlines.
/// - `While` — re-run the step while `condition` holds, up to `max_iterations`.
/// - `Retry` — re-run the step on failure with backoff, up to `max_retries`
///   extra attempts.
///
/// Loop bodies honor in-band control markers embedded in the step response:
/// `__BREAK__` stops the loop (the marker is stripped from the recorded
/// output); in `ForEach`, `__CONTINUE__` drops that iteration's output.
async fn execute_step_with_loops(
    step: &DagWorkflowStep,
    input: &str,
    completed: &HashMap<String, StepResult>,
    executor: &dyn StepExecutor,
) -> Result<StepResult, String> {
    match &step.loop_config {
        // No loop configured: a single plain execution.
        None => executor.execute(step, input, completed, None).await,
        Some(LoopConfig::ForEach {
            source_step,
            max_iterations,
        }) => {
            // Items come from the source step's recorded output; a missing
            // source step degrades to "[]", i.e. zero iterations.
            let source_output = completed
                .get(source_step)
                .map(|r| r.response.as_str())
                .unwrap_or("[]");
            let items = parse_foreach_items(source_output)?;
            // Cap the iteration count at both the configured maximum and the
            // number of items actually available.
            let max = (*max_iterations).min(items.len());

            let mut loop_state = LoopState::new();
            let start = Utc::now();
            let instant = Instant::now();

            for (i, item) in items.into_iter().take(max).enumerate() {
                // Expose the current index/item to the executor (surfaced to
                // prompt templates as {{loop.index}} / {{loop.item}}).
                loop_state.index = i;
                loop_state.item = Some(item);

                let result = executor
                    .execute(step, input, completed, Some(&loop_state))
                    .await;

                match result {
                    Ok(r) => {
                        // Check for break/continue signals in output
                        if r.response.contains("__BREAK__") {
                            // Keep this iteration's output (minus the marker),
                            // then stop looping.
                            loop_state.push_result(r.response.replace("__BREAK__", ""));
                            break;
                        }
                        if r.response.contains("__CONTINUE__") {
                            // Discard this iteration's output entirely.
                            continue;
                        }
                        loop_state.push_result(r.response);
                    }
                    // Any iteration error aborts the whole loop step.
                    Err(e) => return Err(e),
                }
            }

            // The step's overall output is all kept iteration outputs,
            // one per line.
            let combined = loop_state.accumulated_results.join("\n");
            Ok(StepResult {
                step_name: step.name.clone(),
                response: combined,
                tokens_used: 0,
                duration_ms: instant.elapsed().as_millis() as u64,
                error: None,
                status: StepStatus::Completed,
                started_at: Some(start),
                completed_at: Some(Utc::now()),
            })
        }
        Some(LoopConfig::While {
            condition,
            max_iterations,
        }) => {
            let mut loop_state = LoopState::new();
            let start = Utc::now();
            let instant = Instant::now();

            for i in 0..*max_iterations {
                // Evaluate the condition with current completed results
                // For while loops, we add the accumulated results as a synthetic step
                // (keyed by this step's own name, holding the latest iteration
                // output) so the condition can reference the loop's progress.
                let mut extended = completed.clone();
                if !loop_state.accumulated_results.is_empty() {
                    extended.insert(
                        step.name.clone(),
                        StepResult {
                            step_name: step.name.clone(),
                            response: loop_state
                                .accumulated_results
                                .last()
                                .cloned()
                                .unwrap_or_default(),
                            tokens_used: 0,
                            duration_ms: 0,
                            error: None,
                            status: StepStatus::Completed,
                            started_at: None,
                            completed_at: None,
                        },
                    );
                }

                // Condition is checked *before* each iteration, so a condition
                // that is false up front yields zero executions.
                if !evaluate_condition(condition, &extended) {
                    break;
                }

                loop_state.index = i;
                let result = executor
                    .execute(step, input, &extended, Some(&loop_state))
                    .await;

                match result {
                    Ok(r) => {
                        if r.response.contains("__BREAK__") {
                            loop_state.push_result(r.response.replace("__BREAK__", ""));
                            break;
                        }
                        // NOTE(review): unlike the ForEach arm, there is no
                        // `__CONTINUE__` handling here, so a response containing
                        // that marker is recorded verbatim — confirm intentional.
                        loop_state.push_result(r.response);
                    }
                    // Any iteration error aborts the whole loop step.
                    Err(e) => return Err(e),
                }
            }

            // Combined output: one line per completed iteration.
            let combined = loop_state.accumulated_results.join("\n");
            Ok(StepResult {
                step_name: step.name.clone(),
                response: combined,
                tokens_used: 0,
                duration_ms: instant.elapsed().as_millis() as u64,
                error: None,
                status: StepStatus::Completed,
                started_at: Some(start),
                completed_at: Some(Utc::now()),
            })
        }
        Some(LoopConfig::Retry {
            max_retries,
            backoff_ms,
            backoff_multiplier,
        }) => {
            let start = Utc::now();
            let instant = Instant::now();
            let mut last_error = String::new();

            // Attempt 0 is the initial try; up to `max_retries` retries follow.
            for attempt in 0..=*max_retries {
                if attempt > 0 {
                    // Back off before each retry (not before the first attempt).
                    let wait = calculate_backoff(attempt - 1, *backoff_ms, *backoff_multiplier);
                    tokio::time::sleep(std::time::Duration::from_millis(wait)).await;
                }

                match executor.execute(step, input, completed, None).await {
                    // First success wins; its result is returned unchanged.
                    Ok(r) => return Ok(r),
                    Err(e) => {
                        last_error = e;
                        warn!(step = %step.name, attempt = attempt + 1, "retry attempt failed");
                    }
                }
            }

            // All attempts exhausted: surface the failure as Ok(..) with
            // `StepStatus::Failed` and the last error, rather than Err, so the
            // caller can apply the step's `on_error` handling to the result.
            Ok(StepResult {
                step_name: step.name.clone(),
                response: String::new(),
                tokens_used: 0,
                duration_ms: instant.elapsed().as_millis() as u64,
                error: Some(last_error),
                status: StepStatus::Failed,
                started_at: Some(start),
                completed_at: Some(Utc::now()),
            })
        }
    }
}
1101
/// Result of executing a DAG workflow.
///
/// Produced by the DAG driver once all waves have run, or immediately (with
/// `validation_errors` populated) when the workflow fails validation.
#[derive(Debug, Clone)]
pub struct DagExecutionResult {
    /// Overall workflow status — `PartiallyCompleted` when some steps
    /// succeeded and others failed.
    pub status: WorkflowRunStatus,
    /// Per-step results keyed by step name.
    pub step_results: HashMap<String, StepResult>,
    /// Dead letter entries for failed steps whose error was not recovered.
    pub dead_letters: Vec<DeadLetterEntry>,
    /// Execution trace — one entry per executed wave of steps.
    pub execution_trace: Vec<ExecutionTraceEntry>,
    /// Validation errors (if any — non-empty means workflow didn't execute).
    pub validation_errors: Vec<ValidationError>,
}
1116
1117// ---------------------------------------------------------------------------
1118// WorkflowEngine
1119// ---------------------------------------------------------------------------
1120
/// Engine for registering and executing multi-step agent workflows.
///
/// All state is held in concurrent [`DashMap`]s, so the engine takes `&self`
/// everywhere and can be shared (e.g. behind an `Arc`) across tasks.
pub struct WorkflowEngine {
    /// Registered workflow definitions (sequential).
    workflows: DashMap<WorkflowId, Workflow>,
    /// Registered DAG workflow definitions (validated at registration time).
    dag_workflows: DashMap<WorkflowId, DagWorkflow>,
    /// Workflow execution runs, updated in place as execution progresses.
    runs: DashMap<WorkflowRunId, WorkflowRun>,
}
1130
1131impl WorkflowEngine {
1132    /// Create a new workflow engine.
1133    pub fn new() -> Self {
1134        Self {
1135            workflows: DashMap::new(),
1136            dag_workflows: DashMap::new(),
1137            runs: DashMap::new(),
1138        }
1139    }
1140
1141    /// Register a sequential workflow definition and return its ID.
1142    pub fn register_workflow(&self, workflow: Workflow) -> WorkflowId {
1143        let id = workflow.id;
1144        info!(workflow_id = %id, name = %workflow.name, "workflow registered");
1145        self.workflows.insert(id, workflow);
1146        id
1147    }
1148
1149    /// Register a DAG workflow definition and return its ID.
1150    ///
1151    /// Validates the workflow before registering. Returns an error with
1152    /// validation details if the workflow is invalid.
1153    pub fn register_dag_workflow(
1154        &self,
1155        workflow: DagWorkflow,
1156    ) -> Result<WorkflowId, Vec<ValidationError>> {
1157        let errors = validate_workflow(&workflow.steps);
1158        if !errors.is_empty() {
1159            return Err(errors);
1160        }
1161        let id = workflow.id;
1162        info!(workflow_id = %id, name = %workflow.name, "DAG workflow registered");
1163        self.dag_workflows.insert(id, workflow);
1164        Ok(id)
1165    }
1166
1167    /// Execute a sequential workflow with the given input string.
1168    #[instrument(skip(self, input, memory, driver, model_config), fields(%workflow_id))]
1169    pub async fn execute_workflow(
1170        &self,
1171        workflow_id: &WorkflowId,
1172        input: String,
1173        memory: Arc<MemorySubstrate>,
1174        driver: Arc<dyn LlmDriver>,
1175        model_config: &ModelConfig,
1176    ) -> PunchResult<WorkflowRunId> {
1177        let workflow = self
1178            .workflows
1179            .get(workflow_id)
1180            .ok_or_else(|| PunchError::Internal(format!("workflow {} not found", workflow_id)))?
1181            .clone();
1182
1183        let run_id = WorkflowRunId::new();
1184        let run = WorkflowRun {
1185            id: run_id,
1186            workflow_id: *workflow_id,
1187            status: WorkflowRunStatus::Running,
1188            step_results: Vec::new(),
1189            started_at: Utc::now(),
1190            completed_at: None,
1191            dead_letters: Vec::new(),
1192            execution_trace: Vec::new(),
1193        };
1194        self.runs.insert(run_id, run);
1195
1196        let mut current_input = input.clone();
1197        let mut step_results: Vec<StepResult> = Vec::new();
1198        let mut failed = false;
1199
1200        for step in &workflow.steps {
1201            let result = self
1202                .execute_single_step(
1203                    step,
1204                    &workflow.name,
1205                    &current_input,
1206                    &step_results,
1207                    &memory,
1208                    &driver,
1209                    model_config,
1210                )
1211                .await;
1212
1213            match result {
1214                Ok(step_result) => {
1215                    current_input = step_result.response.clone();
1216                    step_results.push(step_result);
1217                }
1218                Err(e) => {
1219                    let error_msg = format!("{e}");
1220                    match step.on_error {
1221                        OnError::SkipStep => {
1222                            warn!(step = %step.name, error = %error_msg, "step failed, skipping");
1223                            let skip_result = StepResult {
1224                                step_name: step.name.clone(),
1225                                response: String::new(),
1226                                tokens_used: 0,
1227                                duration_ms: 0,
1228                                error: Some(error_msg),
1229                                status: StepStatus::Skipped,
1230                                started_at: None,
1231                                completed_at: None,
1232                            };
1233                            step_results.push(skip_result);
1234                            continue;
1235                        }
1236                        OnError::RetryOnce => {
1237                            warn!(step = %step.name, error = %error_msg, "step failed, retrying once");
1238                            let retry_result = self
1239                                .execute_single_step(
1240                                    step,
1241                                    &workflow.name,
1242                                    &current_input,
1243                                    &step_results,
1244                                    &memory,
1245                                    &driver,
1246                                    model_config,
1247                                )
1248                                .await;
1249
1250                            match retry_result {
1251                                Ok(step_result) => {
1252                                    current_input = step_result.response.clone();
1253                                    step_results.push(step_result);
1254                                }
1255                                Err(retry_err) => {
1256                                    error!(step = %step.name, error = %retry_err, "step failed on retry");
1257                                    let fail_result = StepResult {
1258                                        step_name: step.name.clone(),
1259                                        response: String::new(),
1260                                        tokens_used: 0,
1261                                        duration_ms: 0,
1262                                        error: Some(format!("{retry_err}")),
1263                                        status: StepStatus::Failed,
1264                                        started_at: None,
1265                                        completed_at: None,
1266                                    };
1267                                    step_results.push(fail_result);
1268                                    failed = true;
1269                                    break;
1270                                }
1271                            }
1272                        }
1273                        OnError::FailWorkflow => {
1274                            error!(step = %step.name, error = %error_msg, "step failed, aborting workflow");
1275                            let fail_result = StepResult {
1276                                step_name: step.name.clone(),
1277                                response: String::new(),
1278                                tokens_used: 0,
1279                                duration_ms: 0,
1280                                error: Some(error_msg),
1281                                status: StepStatus::Failed,
1282                                started_at: None,
1283                                completed_at: None,
1284                            };
1285                            step_results.push(fail_result);
1286                            failed = true;
1287                            break;
1288                        }
1289                        _ => {
1290                            // Fallback/CatchAndContinue/CircuitBreaker in sequential mode
1291                            // just fail the workflow for now
1292                            let fail_result = StepResult {
1293                                step_name: step.name.clone(),
1294                                response: String::new(),
1295                                tokens_used: 0,
1296                                duration_ms: 0,
1297                                error: Some(error_msg),
1298                                status: StepStatus::Failed,
1299                                started_at: None,
1300                                completed_at: None,
1301                            };
1302                            step_results.push(fail_result);
1303                            failed = true;
1304                            break;
1305                        }
1306                    }
1307                }
1308            }
1309        }
1310
1311        // Update the run with results.
1312        if let Some(mut run) = self.runs.get_mut(&run_id) {
1313            run.step_results = step_results;
1314            run.status = if failed {
1315                WorkflowRunStatus::Failed
1316            } else {
1317                WorkflowRunStatus::Completed
1318            };
1319            run.completed_at = Some(Utc::now());
1320        }
1321
1322        Ok(run_id)
1323    }
1324
1325    /// Execute a single workflow step, creating a temporary fighter and running
1326    /// it through the fighter loop.
1327    #[allow(clippy::too_many_arguments)]
1328    async fn execute_single_step(
1329        &self,
1330        step: &WorkflowStep,
1331        workflow_name: &str,
1332        current_input: &str,
1333        step_results: &[StepResult],
1334        memory: &Arc<MemorySubstrate>,
1335        driver: &Arc<dyn LlmDriver>,
1336        model_config: &ModelConfig,
1337    ) -> PunchResult<StepResult> {
1338        let step_start = Instant::now();
1339        let started_at = Utc::now();
1340
1341        // Substitute variables in the prompt template.
1342        let prompt = expand_variables(
1343            &step.prompt_template,
1344            current_input,
1345            &step.name,
1346            step_results,
1347        );
1348
1349        // Create a temporary fighter for this step.
1350        let fighter_id = FighterId::new();
1351        let fighter_manifest = FighterManifest {
1352            name: step.fighter_name.clone(),
1353            description: format!("Workflow step: {}", step.name),
1354            model: model_config.clone(),
1355            system_prompt: format!(
1356                "You are executing step '{}' of workflow '{}'.",
1357                step.name, workflow_name
1358            ),
1359            capabilities: Vec::new(),
1360            weight_class: WeightClass::Middleweight,
1361            tenant_id: None,
1362        };
1363
1364        // Save the fighter and create a bout.
1365        if let Err(e) = memory
1366            .save_fighter(
1367                &fighter_id,
1368                &fighter_manifest,
1369                punch_types::FighterStatus::Idle,
1370            )
1371            .await
1372        {
1373            error!(error = %e, "failed to persist workflow fighter");
1374        }
1375
1376        let bout_id = memory.create_bout(&fighter_id).await.map_err(|e| {
1377            PunchError::Internal(format!(
1378                "failed to create bout for step '{}': {e}",
1379                step.name
1380            ))
1381        })?;
1382
1383        let available_tools = tools_for_capabilities(&fighter_manifest.capabilities);
1384        let timeout_secs = step.timeout_secs.unwrap_or(120);
1385
1386        let params = FighterLoopParams {
1387            manifest: fighter_manifest,
1388            user_message: prompt,
1389            bout_id,
1390            fighter_id,
1391            memory: Arc::clone(memory),
1392            driver: Arc::clone(driver),
1393            available_tools,
1394            mcp_tools: Vec::new(),
1395            max_iterations: Some(20),
1396            context_window: None,
1397            tool_timeout_secs: Some(timeout_secs),
1398            coordinator: None,
1399            approval_engine: None,
1400            sandbox: None,
1401            mcp_clients: None,
1402            model_routing: None,
1403            channel_notifier: None,
1404            user_content_parts: vec![],
1405            eco_mode: false,
1406        };
1407
1408        let loop_result = tokio::time::timeout(
1409            std::time::Duration::from_secs(timeout_secs),
1410            run_fighter_loop(params),
1411        )
1412        .await;
1413
1414        match loop_result {
1415            Ok(Ok(result)) => {
1416                let step_result = StepResult {
1417                    step_name: step.name.clone(),
1418                    response: result.response,
1419                    tokens_used: result.usage.total(),
1420                    duration_ms: step_start.elapsed().as_millis() as u64,
1421                    error: None,
1422                    status: StepStatus::Completed,
1423                    started_at: Some(started_at),
1424                    completed_at: Some(Utc::now()),
1425                };
1426                info!(step = %step.name, tokens = step_result.tokens_used, "workflow step completed");
1427                Ok(step_result)
1428            }
1429            Ok(Err(e)) => Err(e),
1430            Err(_) => Err(PunchError::Internal(format!(
1431                "step '{}' timed out after {}s",
1432                step.name, timeout_secs
1433            ))),
1434        }
1435    }
1436
1437    /// Get a workflow run by its ID.
1438    pub fn get_run(&self, run_id: &WorkflowRunId) -> Option<WorkflowRun> {
1439        self.runs.get(run_id).map(|r| r.clone())
1440    }
1441
1442    /// List all registered sequential workflows.
1443    pub fn list_workflows(&self) -> Vec<Workflow> {
1444        self.workflows.iter().map(|w| w.value().clone()).collect()
1445    }
1446
1447    /// List all registered DAG workflows.
1448    pub fn list_dag_workflows(&self) -> Vec<DagWorkflow> {
1449        self.dag_workflows
1450            .iter()
1451            .map(|w| w.value().clone())
1452            .collect()
1453    }
1454
1455    /// List all workflow runs.
1456    pub fn list_runs(&self) -> Vec<WorkflowRun> {
1457        self.runs.iter().map(|r| r.value().clone()).collect()
1458    }
1459
1460    /// List workflow runs filtered by workflow ID.
1461    pub fn list_runs_for_workflow(&self, workflow_id: &WorkflowId) -> Vec<WorkflowRun> {
1462        self.runs
1463            .iter()
1464            .filter(|r| r.value().workflow_id == *workflow_id)
1465            .map(|r| r.value().clone())
1466            .collect()
1467    }
1468
1469    /// Get a sequential workflow by its ID.
1470    pub fn get_workflow(&self, id: &WorkflowId) -> Option<Workflow> {
1471        self.workflows.get(id).map(|w| w.clone())
1472    }
1473
1474    /// Get a DAG workflow by its ID.
1475    pub fn get_dag_workflow(&self, id: &WorkflowId) -> Option<DagWorkflow> {
1476        self.dag_workflows.get(id).map(|w| w.clone())
1477    }
1478}
1479
1480impl Default for WorkflowEngine {
1481    fn default() -> Self {
1482        Self::new()
1483    }
1484}
1485
1486// ---------------------------------------------------------------------------
1487// Tests
1488// ---------------------------------------------------------------------------
1489
1490#[cfg(test)]
1491mod tests {
1492    use super::*;
1493    use std::sync::atomic::{AtomicUsize, Ordering};
1494    use std::time::Duration;
1495
    // A mock step executor for testing. Returns canned responses (or the
    // expanded prompt when none is registered) and can simulate failures.
    struct MockExecutor {
        /// Map of step name -> canned response returned on success.
        responses: HashMap<String, String>,
        /// Map of step name -> error message; listed steps always fail.
        failing_steps: HashMap<String, String>,
        /// Track execution count per step (how many times `execute` ran it).
        execution_counts: DashMap<String, AtomicUsize>,
    }
1505
1506    impl MockExecutor {
1507        fn new() -> Self {
1508            Self {
1509                responses: HashMap::new(),
1510                failing_steps: HashMap::new(),
1511                execution_counts: DashMap::new(),
1512            }
1513        }
1514
1515        fn with_response(mut self, step: &str, response: &str) -> Self {
1516            self.responses
1517                .insert(step.to_string(), response.to_string());
1518            self
1519        }
1520
1521        fn with_failure(mut self, step: &str, error: &str) -> Self {
1522            self.failing_steps
1523                .insert(step.to_string(), error.to_string());
1524            self
1525        }
1526
1527        #[allow(dead_code)]
1528        fn execution_count(&self, step: &str) -> usize {
1529            self.execution_counts
1530                .get(step)
1531                .map(|c| c.load(Ordering::Relaxed))
1532                .unwrap_or(0)
1533        }
1534    }
1535
1536    #[async_trait::async_trait]
1537    impl StepExecutor for MockExecutor {
1538        async fn execute(
1539            &self,
1540            step: &DagWorkflowStep,
1541            input: &str,
1542            step_results: &HashMap<String, StepResult>,
1543            loop_state: Option<&LoopState>,
1544        ) -> Result<StepResult, String> {
1545            // Track execution
1546            self.execution_counts
1547                .entry(step.name.clone())
1548                .or_insert_with(|| AtomicUsize::new(0))
1549                .fetch_add(1, Ordering::Relaxed);
1550
1551            // Check if step should fail
1552            if let Some(err) = self.failing_steps.get(&step.name) {
1553                return Err(err.clone());
1554            }
1555
1556            let prompt = expand_dag_variables(
1557                &step.prompt_template,
1558                input,
1559                &step.name,
1560                step_results,
1561                loop_state,
1562            );
1563
1564            let response = self.responses.get(&step.name).cloned().unwrap_or(prompt);
1565
1566            Ok(StepResult {
1567                step_name: step.name.clone(),
1568                response,
1569                tokens_used: 10,
1570                duration_ms: 5,
1571                error: None,
1572                status: StepStatus::Completed,
1573                started_at: Some(Utc::now()),
1574                completed_at: Some(Utc::now()),
1575            })
1576        }
1577    }
1578
    /// A mock executor that adds a delay to simulate real execution time.
    struct TimedMockExecutor {
        /// Sleep duration applied before every step completes, in milliseconds.
        delay_ms: u64,
    }
1583
1584    #[async_trait::async_trait]
1585    impl StepExecutor for TimedMockExecutor {
1586        async fn execute(
1587            &self,
1588            step: &DagWorkflowStep,
1589            _input: &str,
1590            _step_results: &HashMap<String, StepResult>,
1591            _loop_state: Option<&LoopState>,
1592        ) -> Result<StepResult, String> {
1593            tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
1594            Ok(StepResult {
1595                step_name: step.name.clone(),
1596                response: format!("done-{}", step.name),
1597                tokens_used: 10,
1598                duration_ms: self.delay_ms,
1599                error: None,
1600                status: StepStatus::Completed,
1601                started_at: Some(Utc::now()),
1602                completed_at: Some(Utc::now()),
1603            })
1604        }
1605    }
1606
    /// A mock executor that fails the first N attempts for a step.
    struct FailNTimesMockExecutor {
        // Number of leading attempts (per step) that return Err.
        fail_count: usize,
        // Per-step attempt counters, keyed by step name.
        attempts: DashMap<String, AtomicUsize>,
    }
1612
1613    impl FailNTimesMockExecutor {
1614        fn new(fail_count: usize) -> Self {
1615            Self {
1616                fail_count,
1617                attempts: DashMap::new(),
1618            }
1619        }
1620    }
1621
1622    #[async_trait::async_trait]
1623    impl StepExecutor for FailNTimesMockExecutor {
1624        async fn execute(
1625            &self,
1626            step: &DagWorkflowStep,
1627            _input: &str,
1628            _step_results: &HashMap<String, StepResult>,
1629            _loop_state: Option<&LoopState>,
1630        ) -> Result<StepResult, String> {
1631            let attempt = self
1632                .attempts
1633                .entry(step.name.clone())
1634                .or_insert_with(|| AtomicUsize::new(0))
1635                .fetch_add(1, Ordering::Relaxed);
1636
1637            if attempt < self.fail_count {
1638                return Err(format!("failure attempt {}", attempt + 1));
1639            }
1640
1641            Ok(StepResult {
1642                step_name: step.name.clone(),
1643                response: format!("success on attempt {}", attempt + 1),
1644                tokens_used: 10,
1645                duration_ms: 5,
1646                error: None,
1647                status: StepStatus::Completed,
1648                started_at: Some(Utc::now()),
1649                completed_at: Some(Utc::now()),
1650            })
1651        }
1652    }
1653
1654    fn dag_step(name: &str, deps: &[&str]) -> DagWorkflowStep {
1655        DagWorkflowStep {
1656            name: name.to_string(),
1657            fighter_name: "test".to_string(),
1658            prompt_template: "{{input}}".to_string(),
1659            timeout_secs: None,
1660            on_error: OnError::FailWorkflow,
1661            depends_on: deps.iter().map(|d| d.to_string()).collect(),
1662            condition: None,
1663            else_step: None,
1664            loop_config: None,
1665        }
1666    }
1667
1668    // ---- Existing sequential tests (preserved) ----
1669
1670    #[test]
1671    fn register_and_list_workflows() {
1672        let engine = WorkflowEngine::new();
1673
1674        let workflow = Workflow {
1675            id: WorkflowId::new(),
1676            name: "test-workflow".to_string(),
1677            steps: vec![
1678                WorkflowStep {
1679                    name: "step1".to_string(),
1680                    fighter_name: "analyzer".to_string(),
1681                    prompt_template: "Analyze: {{input}}".to_string(),
1682                    timeout_secs: None,
1683                    on_error: OnError::FailWorkflow,
1684                },
1685                WorkflowStep {
1686                    name: "step2".to_string(),
1687                    fighter_name: "summarizer".to_string(),
1688                    prompt_template: "Summarize the analysis: {{step1}}".to_string(),
1689                    timeout_secs: Some(60),
1690                    on_error: OnError::SkipStep,
1691                },
1692            ],
1693        };
1694
1695        let id = engine.register_workflow(workflow);
1696        let workflows = engine.list_workflows();
1697        assert_eq!(workflows.len(), 1);
1698        assert_eq!(workflows[0].name, "test-workflow");
1699        assert_eq!(workflows[0].steps.len(), 2);
1700
1701        let fetched = engine.get_workflow(&id).expect("workflow should exist");
1702        assert_eq!(fetched.name, "test-workflow");
1703    }
1704
1705    #[test]
1706    fn variable_substitution_basic() {
1707        let result = expand_variables(
1708            "Analyze {{input}} for step {{step_name}}",
1709            "hello world",
1710            "analysis",
1711            &[],
1712        );
1713        assert_eq!(result, "Analyze hello world for step analysis");
1714    }
1715
1716    #[test]
1717    fn variable_substitution_previous_output() {
1718        let result = expand_variables(
1719            "Continue from: {{previous_output}}",
1720            "step 1 output",
1721            "step2",
1722            &[],
1723        );
1724        assert_eq!(result, "Continue from: step 1 output");
1725    }
1726
1727    #[test]
1728    fn variable_substitution_step_refs() {
1729        let step_results = vec![
1730            StepResult {
1731                step_name: "analyze".to_string(),
1732                response: "analysis result".to_string(),
1733                tokens_used: 100,
1734                duration_ms: 500,
1735                error: None,
1736                status: StepStatus::Completed,
1737                started_at: None,
1738                completed_at: None,
1739            },
1740            StepResult {
1741                step_name: "review".to_string(),
1742                response: "review result".to_string(),
1743                tokens_used: 80,
1744                duration_ms: 400,
1745                error: None,
1746                status: StepStatus::Completed,
1747                started_at: None,
1748                completed_at: None,
1749            },
1750        ];
1751
1752        let result = expand_variables(
1753            "Step 1 said: {{step_1}}, Step 2 said: {{step_2}}",
1754            "current",
1755            "step3",
1756            &step_results,
1757        );
1758        assert_eq!(
1759            result,
1760            "Step 1 said: analysis result, Step 2 said: review result"
1761        );
1762
1763        let result = expand_variables(
1764            "Analysis: {{analyze}}, Review: {{review}}",
1765            "current",
1766            "step3",
1767            &step_results,
1768        );
1769        assert_eq!(result, "Analysis: analysis result, Review: review result");
1770    }
1771
1772    #[test]
1773    fn workflow_run_status_display() {
1774        assert_eq!(WorkflowRunStatus::Pending.to_string(), "pending");
1775        assert_eq!(WorkflowRunStatus::Running.to_string(), "running");
1776        assert_eq!(WorkflowRunStatus::Completed.to_string(), "completed");
1777        assert_eq!(WorkflowRunStatus::Failed.to_string(), "failed");
1778        assert_eq!(
1779            WorkflowRunStatus::PartiallyCompleted.to_string(),
1780            "partially_completed"
1781        );
1782    }
1783
1784    #[test]
1785    fn get_nonexistent_run_returns_none() {
1786        let engine = WorkflowEngine::new();
1787        let run_id = WorkflowRunId::new();
1788        assert!(engine.get_run(&run_id).is_none());
1789    }
1790
1791    #[test]
1792    fn get_nonexistent_workflow_returns_none() {
1793        let engine = WorkflowEngine::new();
1794        let id = WorkflowId::new();
1795        assert!(engine.get_workflow(&id).is_none());
1796    }
1797
1798    #[test]
1799    fn workflow_engine_default() {
1800        let engine = WorkflowEngine::default();
1801        assert!(engine.list_workflows().is_empty());
1802        assert!(engine.list_runs().is_empty());
1803    }
1804
1805    #[test]
1806    fn register_multiple_workflows() {
1807        let engine = WorkflowEngine::new();
1808
1809        for i in 0..5 {
1810            let workflow = Workflow {
1811                id: WorkflowId::new(),
1812                name: format!("workflow-{}", i),
1813                steps: vec![],
1814            };
1815            engine.register_workflow(workflow);
1816        }
1817
1818        assert_eq!(engine.list_workflows().len(), 5);
1819    }
1820
1821    #[test]
1822    fn register_workflow_returns_correct_id() {
1823        let engine = WorkflowEngine::new();
1824        let wf_id = WorkflowId::new();
1825        let workflow = Workflow {
1826            id: wf_id,
1827            name: "id-test".to_string(),
1828            steps: vec![],
1829        };
1830        let returned_id = engine.register_workflow(workflow);
1831        assert_eq!(returned_id, wf_id);
1832    }
1833
1834    #[test]
1835    fn workflow_id_display() {
1836        let id = WorkflowId::new();
1837        let s = format!("{}", id);
1838        assert!(!s.is_empty());
1839    }
1840
1841    #[test]
1842    fn workflow_run_id_display() {
1843        let id = WorkflowRunId::new();
1844        let s = format!("{}", id);
1845        assert!(!s.is_empty());
1846    }
1847
1848    #[test]
1849    fn workflow_id_default() {
1850        let id = WorkflowId::default();
1851        assert!(!id.0.is_nil());
1852    }
1853
1854    #[test]
1855    fn workflow_run_id_default() {
1856        let id = WorkflowRunId::default();
1857        assert!(!id.0.is_nil());
1858    }
1859
1860    #[test]
1861    fn variable_substitution_no_variables() {
1862        let result = expand_variables("plain text with no vars", "input", "step", &[]);
1863        assert_eq!(result, "plain text with no vars");
1864    }
1865
1866    #[test]
1867    fn variable_substitution_all_variables_at_once() {
1868        let step_results = vec![StepResult {
1869            step_name: "analysis".to_string(),
1870            response: "analyzed data".to_string(),
1871            tokens_used: 50,
1872            duration_ms: 100,
1873            error: None,
1874            status: StepStatus::Completed,
1875            started_at: None,
1876            completed_at: None,
1877        }];
1878
1879        let result = expand_variables(
1880            "Input: {{input}}, Prev: {{previous_output}}, Step: {{step_name}}, S1: {{step_1}}, Named: {{analysis}}",
1881            "my input",
1882            "current_step",
1883            &step_results,
1884        );
1885        assert_eq!(
1886            result,
1887            "Input: my input, Prev: my input, Step: current_step, S1: analyzed data, Named: analyzed data"
1888        );
1889    }
1890
1891    #[test]
1892    fn variable_substitution_empty_input() {
1893        let result = expand_variables("{{input}} is here", "", "step", &[]);
1894        assert_eq!(result, " is here");
1895    }
1896
1897    #[test]
1898    fn variable_substitution_multiple_same_var() {
1899        let result = expand_variables("{{input}} and {{input}} again", "hello", "step", &[]);
1900        assert_eq!(result, "hello and hello again");
1901    }
1902
1903    #[test]
1904    fn on_error_default_is_fail_workflow() {
1905        let on_error = OnError::default();
1906        assert!(matches!(on_error, OnError::FailWorkflow));
1907    }
1908
1909    #[test]
1910    fn list_runs_for_workflow_filters_correctly() {
1911        let engine = WorkflowEngine::new();
1912        let wf_id_1 = WorkflowId::new();
1913        let wf_id_2 = WorkflowId::new();
1914
1915        assert!(engine.list_runs_for_workflow(&wf_id_1).is_empty());
1916        assert!(engine.list_runs_for_workflow(&wf_id_2).is_empty());
1917    }
1918
1919    #[test]
1920    fn workflow_step_serialization() {
1921        let step = WorkflowStep {
1922            name: "test".to_string(),
1923            fighter_name: "fighter".to_string(),
1924            prompt_template: "Do {{input}}".to_string(),
1925            timeout_secs: Some(30),
1926            on_error: OnError::SkipStep,
1927        };
1928        let json = serde_json::to_string(&step).expect("serialize");
1929        let deserialized: WorkflowStep = serde_json::from_str(&json).expect("deserialize");
1930        assert_eq!(deserialized.name, "test");
1931        assert_eq!(deserialized.timeout_secs, Some(30));
1932    }
1933
1934    #[test]
1935    fn workflow_serialization_roundtrip() {
1936        let workflow = Workflow {
1937            id: WorkflowId::new(),
1938            name: "roundtrip".to_string(),
1939            steps: vec![WorkflowStep {
1940                name: "s1".to_string(),
1941                fighter_name: "f1".to_string(),
1942                prompt_template: "{{input}}".to_string(),
1943                timeout_secs: None,
1944                on_error: OnError::RetryOnce,
1945            }],
1946        };
1947        let json = serde_json::to_string(&workflow).expect("serialize");
1948        let deserialized: Workflow = serde_json::from_str(&json).expect("deserialize");
1949        assert_eq!(deserialized.name, "roundtrip");
1950        assert_eq!(deserialized.steps.len(), 1);
1951    }
1952
1953    #[test]
1954    fn step_result_with_error() {
1955        let sr = StepResult {
1956            step_name: "failing".to_string(),
1957            response: String::new(),
1958            tokens_used: 0,
1959            duration_ms: 0,
1960            error: Some("timeout".to_string()),
1961            status: StepStatus::Failed,
1962            started_at: None,
1963            completed_at: None,
1964        };
1965        assert!(sr.error.is_some());
1966        assert_eq!(sr.error.expect("error"), "timeout");
1967    }
1968
1969    #[test]
1970    fn variable_substitution_step_ref_by_number_out_of_range() {
1971        let step_results = vec![
1972            StepResult {
1973                step_name: "a".to_string(),
1974                response: "r1".to_string(),
1975                tokens_used: 0,
1976                duration_ms: 0,
1977                error: None,
1978                status: StepStatus::Completed,
1979                started_at: None,
1980                completed_at: None,
1981            },
1982            StepResult {
1983                step_name: "b".to_string(),
1984                response: "r2".to_string(),
1985                tokens_used: 0,
1986                duration_ms: 0,
1987                error: None,
1988                status: StepStatus::Completed,
1989                started_at: None,
1990                completed_at: None,
1991            },
1992        ];
1993        let result = expand_variables("{{step_5}}", "input", "step", &step_results);
1994        assert_eq!(result, "{{step_5}}");
1995    }
1996
1997    // ---- New DAG tests ----
1998
1999    #[tokio::test]
2000    async fn dag_linear_execution() {
2001        let steps = vec![
2002            dag_step("a", &[]),
2003            dag_step("b", &["a"]),
2004            dag_step("c", &["b"]),
2005        ];
2006        let executor = MockExecutor::new()
2007            .with_response("a", "result_a")
2008            .with_response("b", "result_b")
2009            .with_response("c", "result_c");
2010
2011        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2012        assert_eq!(result.status, WorkflowRunStatus::Completed);
2013        assert_eq!(result.step_results.len(), 3);
2014        assert_eq!(result.step_results["a"].response, "result_a");
2015        assert_eq!(result.step_results["b"].response, "result_b");
2016        assert_eq!(result.step_results["c"].response, "result_c");
2017    }
2018
2019    #[tokio::test]
2020    async fn dag_fan_out_execution() {
2021        let steps = vec![
2022            dag_step("root", &[]),
2023            dag_step("branch1", &["root"]),
2024            dag_step("branch2", &["root"]),
2025            dag_step("branch3", &["root"]),
2026        ];
2027        let executor = MockExecutor::new()
2028            .with_response("root", "root_out")
2029            .with_response("branch1", "b1_out")
2030            .with_response("branch2", "b2_out")
2031            .with_response("branch3", "b3_out");
2032
2033        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2034        assert_eq!(result.status, WorkflowRunStatus::Completed);
2035        assert_eq!(result.step_results.len(), 4);
2036        // All branches should have completed
2037        assert_eq!(result.step_results["branch1"].response, "b1_out");
2038        assert_eq!(result.step_results["branch2"].response, "b2_out");
2039        assert_eq!(result.step_results["branch3"].response, "b3_out");
2040    }
2041
2042    #[tokio::test]
2043    async fn dag_fan_in_execution() {
2044        let steps = vec![
2045            dag_step("a", &[]),
2046            dag_step("b", &[]),
2047            dag_step("c", &[]),
2048            dag_step("join", &["a", "b", "c"]),
2049        ];
2050        let executor = MockExecutor::new()
2051            .with_response("a", "ra")
2052            .with_response("b", "rb")
2053            .with_response("c", "rc")
2054            .with_response("join", "joined");
2055
2056        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2057        assert_eq!(result.status, WorkflowRunStatus::Completed);
2058        assert_eq!(result.step_results["join"].response, "joined");
2059        // a, b, c should have run in the same wave (first trace entry)
2060        assert_eq!(result.execution_trace.len(), 2);
2061        let first_wave = &result.execution_trace[0].steps;
2062        assert!(first_wave.contains(&"a".to_string()));
2063        assert!(first_wave.contains(&"b".to_string()));
2064        assert!(first_wave.contains(&"c".to_string()));
2065    }
2066
2067    #[tokio::test]
2068    async fn dag_diamond_execution() {
2069        let steps = vec![
2070            dag_step("root", &[]),
2071            dag_step("left", &["root"]),
2072            dag_step("right", &["root"]),
2073            dag_step("join", &["left", "right"]),
2074        ];
2075        let executor = MockExecutor::new()
2076            .with_response("root", "root_out")
2077            .with_response("left", "left_out")
2078            .with_response("right", "right_out")
2079            .with_response("join", "joined");
2080
2081        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2082        assert_eq!(result.status, WorkflowRunStatus::Completed);
2083        assert_eq!(result.step_results.len(), 4);
2084        // left and right should be in same wave
2085        let wave2 = &result.execution_trace[1].steps;
2086        assert!(wave2.contains(&"left".to_string()));
2087        assert!(wave2.contains(&"right".to_string()));
2088    }
2089
2090    #[tokio::test]
2091    async fn dag_parallel_actually_concurrent() {
2092        // Steps a, b, c have no deps, each takes 50ms.
2093        // If run sequentially: ~150ms. If parallel: ~50ms.
2094        let steps = vec![dag_step("a", &[]), dag_step("b", &[]), dag_step("c", &[])];
2095        let executor = TimedMockExecutor { delay_ms: 50 };
2096
2097        let start = Instant::now();
2098        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2099        let elapsed = start.elapsed();
2100
2101        assert_eq!(result.status, WorkflowRunStatus::Completed);
2102        assert_eq!(result.step_results.len(), 3);
2103        // Should complete in roughly 50ms (parallel), not 150ms (sequential)
2104        // Use generous bound to avoid flakiness
2105        assert!(
2106            elapsed.as_millis() < 120,
2107            "parallel execution took {}ms, expected ~50ms",
2108            elapsed.as_millis()
2109        );
2110    }
2111
2112    #[tokio::test]
2113    async fn dag_condition_if_success() {
2114        let mut steps = vec![dag_step("a", &[]), dag_step("b", &["a"])];
2115        steps[1].condition = Some(Condition::IfSuccess {
2116            step: "a".to_string(),
2117        });
2118        let executor = MockExecutor::new()
2119            .with_response("a", "ok")
2120            .with_response("b", "b_ran");
2121
2122        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2123        assert_eq!(result.step_results["b"].status, StepStatus::Completed);
2124        assert_eq!(result.step_results["b"].response, "b_ran");
2125    }
2126
2127    #[tokio::test]
2128    async fn dag_condition_skips_step() {
2129        let mut steps = vec![dag_step("a", &[]), dag_step("b", &["a"])];
2130        steps[1].condition = Some(Condition::IfFailure {
2131            step: "a".to_string(),
2132        });
2133        let executor = MockExecutor::new()
2134            .with_response("a", "ok")
2135            .with_response("b", "should_not_run");
2136
2137        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2138        assert_eq!(result.step_results["b"].status, StepStatus::Skipped);
2139    }
2140
2141    #[tokio::test]
2142    async fn dag_condition_if_output() {
2143        let mut steps = vec![dag_step("a", &[]), dag_step("b", &["a"])];
2144        steps[1].condition = Some(Condition::IfOutput {
2145            step: "a".to_string(),
2146            contains: "magic".to_string(),
2147        });
2148        let executor = MockExecutor::new()
2149            .with_response("a", "this has magic inside")
2150            .with_response("b", "b_ran");
2151
2152        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2153        assert_eq!(result.step_results["b"].status, StepStatus::Completed);
2154    }
2155
2156    #[tokio::test]
2157    async fn dag_condition_if_output_no_match() {
2158        let mut steps = vec![dag_step("a", &[]), dag_step("b", &["a"])];
2159        steps[1].condition = Some(Condition::IfOutput {
2160            step: "a".to_string(),
2161            contains: "magic".to_string(),
2162        });
2163        let executor = MockExecutor::new()
2164            .with_response("a", "no special word here")
2165            .with_response("b", "should_not_run");
2166
2167        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2168        assert_eq!(result.step_results["b"].status, StepStatus::Skipped);
2169    }
2170
2171    #[tokio::test]
2172    async fn dag_foreach_loop() {
2173        let mut steps = vec![dag_step("source", &[]), dag_step("process", &["source"])];
2174        steps[0].prompt_template = "{{input}}".to_string();
2175        steps[1].loop_config = Some(LoopConfig::ForEach {
2176            source_step: "source".to_string(),
2177            max_iterations: 100,
2178        });
2179        steps[1].prompt_template = "process item: {{loop.item}}".to_string();
2180
2181        let executor =
2182            MockExecutor::new().with_response("source", r#"["apple", "banana", "cherry"]"#);
2183
2184        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2185        assert_eq!(result.status, WorkflowRunStatus::Completed);
2186        let process_result = &result.step_results["process"];
2187        // Should have processed all 3 items
2188        assert!(
2189            process_result.response.contains("process item: apple"),
2190            "response: {}",
2191            process_result.response
2192        );
2193    }
2194
2195    #[tokio::test]
2196    async fn dag_while_loop() {
2197        let mut steps = vec![dag_step("counter", &[])];
2198        steps[0].loop_config = Some(LoopConfig::While {
2199            condition: Condition::Expression("true".to_string()),
2200            max_iterations: 5,
2201        });
2202
2203        let executor = MockExecutor::new().with_response("counter", "tick");
2204
2205        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2206        assert_eq!(result.status, WorkflowRunStatus::Completed);
2207        let counter_result = &result.step_results["counter"];
2208        // Should have 5 "tick" entries
2209        let ticks: Vec<&str> = counter_result.response.split('\n').collect();
2210        assert_eq!(ticks.len(), 5);
2211    }
2212
2213    #[tokio::test]
2214    async fn dag_retry_loop_succeeds_eventually() {
2215        let mut steps = vec![dag_step("flaky", &[])];
2216        steps[0].loop_config = Some(LoopConfig::Retry {
2217            max_retries: 3,
2218            backoff_ms: 1, // minimal backoff for tests
2219            backoff_multiplier: 1.0,
2220        });
2221
2222        // Fails first 2 times, succeeds on 3rd
2223        let executor = FailNTimesMockExecutor::new(2);
2224
2225        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2226        assert_eq!(result.status, WorkflowRunStatus::Completed);
2227        assert!(result.step_results["flaky"].error.is_none());
2228        assert!(
2229            result.step_results["flaky"]
2230                .response
2231                .contains("success on attempt 3")
2232        );
2233    }
2234
2235    #[tokio::test]
2236    async fn dag_retry_loop_exhausts_retries() {
2237        let mut steps = vec![dag_step("flaky", &[])];
2238        steps[0].loop_config = Some(LoopConfig::Retry {
2239            max_retries: 2,
2240            backoff_ms: 1,
2241            backoff_multiplier: 1.0,
2242        });
2243
2244        // Fails all attempts (need 4 failures to exhaust 1 attempt + 2 retries + 1 more)
2245        let executor = FailNTimesMockExecutor::new(10);
2246
2247        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2248        assert!(result.step_results["flaky"].error.is_some());
2249    }
2250
2251    #[tokio::test]
2252    async fn dag_step_failure_with_skip() {
2253        let mut steps = vec![
2254            dag_step("a", &[]),
2255            dag_step("b", &["a"]),
2256            dag_step("c", &["b"]),
2257        ];
2258        steps[1].on_error = OnError::SkipStep;
2259
2260        let executor = MockExecutor::new()
2261            .with_response("a", "ok")
2262            .with_failure("b", "b failed")
2263            .with_response("c", "c_ran");
2264
2265        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2266        // b failed but was skipped, c should still run
2267        // since b is in step_results (as skipped/failed), c's deps are met
2268        assert!(result.step_results.contains_key("c"));
2269    }
2270
2271    #[tokio::test]
2272    async fn dag_step_failure_cascades() {
2273        let steps = vec![
2274            dag_step("a", &[]),
2275            dag_step("b", &["a"]),
2276            dag_step("c", &["b"]),
2277        ];
2278
2279        let executor = MockExecutor::new()
2280            .with_response("a", "ok")
2281            .with_failure("b", "b failed")
2282            .with_response("c", "should_not_run");
2283
2284        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2285        assert!(result.step_results["b"].error.is_some());
2286        // c should be cancelled since b failed (FailWorkflow is default)
2287        assert_eq!(result.step_results["c"].status, StepStatus::Cancelled);
2288    }
2289
2290    #[tokio::test]
2291    async fn dag_empty_workflow() {
2292        let executor = MockExecutor::new();
2293        let result = execute_dag("test", &[], "input", Arc::new(executor)).await;
2294        assert_eq!(result.status, WorkflowRunStatus::Failed);
2295        assert!(!result.validation_errors.is_empty());
2296    }
2297
2298    #[tokio::test]
2299    async fn dag_single_step() {
2300        let steps = vec![dag_step("only", &[])];
2301        let executor = MockExecutor::new().with_response("only", "done");
2302
2303        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2304        assert_eq!(result.status, WorkflowRunStatus::Completed);
2305        assert_eq!(result.step_results.len(), 1);
2306        assert_eq!(result.step_results["only"].response, "done");
2307    }
2308
2309    #[tokio::test]
2310    async fn dag_all_steps_fail() {
2311        let steps = vec![dag_step("a", &[]), dag_step("b", &[])];
2312
2313        let executor = MockExecutor::new()
2314            .with_failure("a", "a failed")
2315            .with_failure("b", "b failed");
2316
2317        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2318        assert_eq!(result.status, WorkflowRunStatus::Failed);
2319        assert!(!result.dead_letters.is_empty());
2320    }
2321
2322    #[tokio::test]
2323    async fn dag_partial_completion() {
2324        let steps = vec![dag_step("good", &[]), dag_step("bad", &[])];
2325
2326        let executor = MockExecutor::new()
2327            .with_response("good", "ok")
2328            .with_failure("bad", "nope");
2329
2330        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2331        assert_eq!(result.status, WorkflowRunStatus::PartiallyCompleted);
2332    }
2333
2334    #[tokio::test]
2335    async fn dag_validation_rejects_cycle() {
2336        let steps = vec![dag_step("a", &["b"]), dag_step("b", &["a"])];
2337        let executor = MockExecutor::new();
2338        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2339        assert_eq!(result.status, WorkflowRunStatus::Failed);
2340        assert!(!result.validation_errors.is_empty());
2341    }
2342
2343    #[tokio::test]
2344    async fn dag_all_steps_skipped() {
2345        let mut steps = vec![dag_step("a", &[]), dag_step("b", &[])];
2346        steps[0].condition = Some(Condition::Expression("false".to_string()));
2347        steps[1].condition = Some(Condition::Expression("false".to_string()));
2348
2349        let executor = MockExecutor::new();
2350        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2351        // All skipped = no failures, no successes -> Completed
2352        assert_eq!(result.status, WorkflowRunStatus::Completed);
2353        assert_eq!(result.step_results["a"].status, StepStatus::Skipped);
2354        assert_eq!(result.step_results["b"].status, StepStatus::Skipped);
2355    }
2356
2357    // ---- DAG variable substitution tests ----
2358
2359    #[test]
2360    fn dag_variables_step_output() {
2361        let mut results = HashMap::new();
2362        results.insert(
2363            "analyze".to_string(),
2364            StepResult {
2365                step_name: "analyze".to_string(),
2366                response: "found 3 bugs".to_string(),
2367                tokens_used: 100,
2368                duration_ms: 500,
2369                error: None,
2370                status: StepStatus::Completed,
2371                started_at: None,
2372                completed_at: None,
2373            },
2374        );
2375
2376        let expanded = expand_dag_variables(
2377            "Result: {{analyze.output}}",
2378            "input",
2379            "next",
2380            &results,
2381            None,
2382        );
2383        assert_eq!(expanded, "Result: found 3 bugs");
2384    }
2385
2386    #[test]
2387    fn dag_variables_step_status() {
2388        let mut results = HashMap::new();
2389        results.insert(
2390            "build".to_string(),
2391            StepResult {
2392                step_name: "build".to_string(),
2393                response: "ok".to_string(),
2394                tokens_used: 50,
2395                duration_ms: 300,
2396                error: None,
2397                status: StepStatus::Completed,
2398                started_at: None,
2399                completed_at: None,
2400            },
2401        );
2402
2403        let expanded = expand_dag_variables(
2404            "Build status: {{build.status}}",
2405            "input",
2406            "deploy",
2407            &results,
2408            None,
2409        );
2410        assert_eq!(expanded, "Build status: completed");
2411    }
2412
2413    #[test]
2414    fn dag_variables_step_duration() {
2415        let mut results = HashMap::new();
2416        results.insert(
2417            "fetch".to_string(),
2418            StepResult {
2419                step_name: "fetch".to_string(),
2420                response: "data".to_string(),
2421                tokens_used: 10,
2422                duration_ms: 1234,
2423                error: None,
2424                status: StepStatus::Completed,
2425                started_at: None,
2426                completed_at: None,
2427            },
2428        );
2429
2430        let expanded = expand_dag_variables(
2431            "Fetch took {{fetch.duration_ms}}ms",
2432            "input",
2433            "next",
2434            &results,
2435            None,
2436        );
2437        assert_eq!(expanded, "Fetch took 1234ms");
2438    }
2439
2440    #[test]
2441    fn dag_variables_loop_state() {
2442        let results = HashMap::new();
2443        let mut loop_state = LoopState::new();
2444        loop_state.index = 2;
2445        loop_state.item = Some("banana".to_string());
2446
2447        let expanded = expand_dag_variables(
2448            "Item {{loop.index}}: {{loop.item}}",
2449            "input",
2450            "process",
2451            &results,
2452            Some(&loop_state),
2453        );
2454        assert_eq!(expanded, "Item 2: banana");
2455    }
2456
2457    #[test]
2458    fn dag_variables_json_path() {
2459        let mut results = HashMap::new();
2460        results.insert(
2461            "api".to_string(),
2462            StepResult {
2463                step_name: "api".to_string(),
2464                response: r#"{"user": {"name": "Alice", "age": 30}}"#.to_string(),
2465                tokens_used: 10,
2466                duration_ms: 100,
2467                error: None,
2468                status: StepStatus::Completed,
2469                started_at: None,
2470                completed_at: None,
2471            },
2472        );
2473
2474        let expanded = expand_dag_variables(
2475            "Name: {{api.output.user.name}}",
2476            "input",
2477            "next",
2478            &results,
2479            None,
2480        );
2481        assert_eq!(expanded, "Name: Alice");
2482    }
2483
2484    #[test]
2485    fn dag_variables_transform_uppercase() {
2486        let mut results = HashMap::new();
2487        results.insert(
2488            "greet".to_string(),
2489            StepResult {
2490                step_name: "greet".to_string(),
2491                response: "hello world".to_string(),
2492                tokens_used: 10,
2493                duration_ms: 50,
2494                error: None,
2495                status: StepStatus::Completed,
2496                started_at: None,
2497                completed_at: None,
2498            },
2499        );
2500
2501        let expanded = expand_dag_variables(
2502            "{{greet.output | uppercase}}",
2503            "input",
2504            "next",
2505            &results,
2506            None,
2507        );
2508        assert_eq!(expanded, "HELLO WORLD");
2509    }
2510
2511    #[test]
2512    fn dag_variables_transform_lowercase() {
2513        let mut results = HashMap::new();
2514        results.insert(
2515            "shout".to_string(),
2516            StepResult {
2517                step_name: "shout".to_string(),
2518                response: "LOUD NOISE".to_string(),
2519                tokens_used: 10,
2520                duration_ms: 50,
2521                error: None,
2522                status: StepStatus::Completed,
2523                started_at: None,
2524                completed_at: None,
2525            },
2526        );
2527
2528        let expanded = expand_dag_variables(
2529            "{{shout.output | lowercase}}",
2530            "input",
2531            "next",
2532            &results,
2533            None,
2534        );
2535        assert_eq!(expanded, "loud noise");
2536    }
2537
2538    #[test]
2539    fn dag_variables_transform_json_extract() {
2540        let mut results = HashMap::new();
2541        results.insert(
2542            "data".to_string(),
2543            StepResult {
2544                step_name: "data".to_string(),
2545                response: r#"{"key": "value123"}"#.to_string(),
2546                tokens_used: 10,
2547                duration_ms: 50,
2548                error: None,
2549                status: StepStatus::Completed,
2550                started_at: None,
2551                completed_at: None,
2552            },
2553        );
2554
2555        let expanded = expand_dag_variables(
2556            "{{data.output | json_extract \"$.key\"}}",
2557            "input",
2558            "next",
2559            &results,
2560            None,
2561        );
2562        assert_eq!(expanded, "value123");
2563    }
2564
2565    #[test]
2566    fn json_path_extract_simple() {
2567        let result = json_path_extract(r#"{"name": "Bob"}"#, "name");
2568        assert_eq!(result, "Bob");
2569    }
2570
2571    #[test]
2572    fn json_path_extract_nested() {
2573        let result = json_path_extract(r#"{"a": {"b": {"c": 42}}}"#, "a.b.c");
2574        assert_eq!(result, "42");
2575    }
2576
2577    #[test]
2578    fn json_path_extract_dollar_prefix() {
2579        let result = json_path_extract(r#"{"key": "val"}"#, "$.key");
2580        assert_eq!(result, "val");
2581    }
2582
2583    #[test]
2584    fn json_path_extract_missing_key() {
2585        let result = json_path_extract(r#"{"key": "val"}"#, "missing");
2586        assert_eq!(result, "");
2587    }
2588
2589    #[test]
2590    fn json_path_extract_invalid_json() {
2591        let result = json_path_extract("not json", "key");
2592        assert_eq!(result, "not json");
2593    }
2594
2595    // ---- Step status tests ----
2596
2597    #[test]
2598    fn step_status_display() {
2599        assert_eq!(StepStatus::Pending.to_string(), "pending");
2600        assert_eq!(StepStatus::Running.to_string(), "running");
2601        assert_eq!(StepStatus::Completed.to_string(), "completed");
2602        assert_eq!(StepStatus::Failed.to_string(), "failed");
2603        assert_eq!(StepStatus::Skipped.to_string(), "skipped");
2604        assert_eq!(StepStatus::Cancelled.to_string(), "cancelled");
2605    }
2606
2607    // ---- On error variant tests ----
2608
2609    #[test]
2610    fn on_error_fallback_serialization() {
2611        let on_error = OnError::Fallback {
2612            step: "backup".to_string(),
2613        };
2614        let json = serde_json::to_string(&on_error).expect("serialize");
2615        let deser: OnError = serde_json::from_str(&json).expect("deserialize");
2616        assert!(matches!(deser, OnError::Fallback { step } if step == "backup"));
2617    }
2618
2619    #[test]
2620    fn on_error_catch_and_continue_serialization() {
2621        let on_error = OnError::CatchAndContinue {
2622            error_handler: "handler".to_string(),
2623        };
2624        let json = serde_json::to_string(&on_error).expect("serialize");
2625        let deser: OnError = serde_json::from_str(&json).expect("deserialize");
2626        assert!(
2627            matches!(deser, OnError::CatchAndContinue { error_handler } if error_handler == "handler")
2628        );
2629    }
2630
2631    #[test]
2632    fn on_error_circuit_breaker_serialization() {
2633        let on_error = OnError::CircuitBreaker {
2634            max_failures: 5,
2635            cooldown_secs: 60,
2636        };
2637        let json = serde_json::to_string(&on_error).expect("serialize");
2638        let deser: OnError = serde_json::from_str(&json).expect("deserialize");
2639        assert!(matches!(
2640            deser,
2641            OnError::CircuitBreaker {
2642                max_failures: 5,
2643                cooldown_secs: 60
2644            }
2645        ));
2646    }
2647
2648    // ---- Circuit breaker tests ----
2649
2650    #[test]
2651    fn circuit_breaker_default_closed() {
2652        let cb = CircuitBreakerState::default();
2653        assert!(!cb.is_open(3, 60));
2654    }
2655
2656    #[test]
2657    fn circuit_breaker_opens_after_max_failures() {
2658        let mut cb = CircuitBreakerState::default();
2659        cb.record_failure();
2660        cb.record_failure();
2661        cb.record_failure();
2662        assert!(cb.is_open(3, 60));
2663    }
2664
2665    #[test]
2666    fn circuit_breaker_resets_on_success() {
2667        let mut cb = CircuitBreakerState::default();
2668        cb.record_failure();
2669        cb.record_failure();
2670        cb.record_success();
2671        assert!(!cb.is_open(3, 60));
2672        assert_eq!(cb.consecutive_failures, 0);
2673    }
2674
2675    // ---- DAG workflow registration tests ----
2676
2677    #[test]
2678    fn register_dag_workflow_valid() {
2679        let engine = WorkflowEngine::new();
2680        let wf = DagWorkflow {
2681            id: WorkflowId::new(),
2682            name: "test-dag".to_string(),
2683            steps: vec![dag_step("a", &[]), dag_step("b", &["a"])],
2684        };
2685        let result = engine.register_dag_workflow(wf);
2686        assert!(result.is_ok());
2687    }
2688
2689    #[test]
2690    fn register_dag_workflow_with_cycle_fails() {
2691        let engine = WorkflowEngine::new();
2692        let wf = DagWorkflow {
2693            id: WorkflowId::new(),
2694            name: "bad-dag".to_string(),
2695            steps: vec![dag_step("a", &["b"]), dag_step("b", &["a"])],
2696        };
2697        let result = engine.register_dag_workflow(wf);
2698        assert!(result.is_err());
2699    }
2700
2701    #[test]
2702    fn list_dag_workflows() {
2703        let engine = WorkflowEngine::new();
2704        let wf = DagWorkflow {
2705            id: WorkflowId::new(),
2706            name: "dag1".to_string(),
2707            steps: vec![dag_step("a", &[])],
2708        };
2709        engine.register_dag_workflow(wf).expect("should register");
2710        assert_eq!(engine.list_dag_workflows().len(), 1);
2711    }
2712
2713    #[test]
2714    fn get_dag_workflow() {
2715        let engine = WorkflowEngine::new();
2716        let id = WorkflowId::new();
2717        let wf = DagWorkflow {
2718            id,
2719            name: "dag1".to_string(),
2720            steps: vec![dag_step("a", &[])],
2721        };
2722        engine.register_dag_workflow(wf).expect("should register");
2723        let fetched = engine.get_dag_workflow(&id).expect("should exist");
2724        assert_eq!(fetched.name, "dag1");
2725    }
2726
2727    #[test]
2728    fn get_nonexistent_dag_workflow() {
2729        let engine = WorkflowEngine::new();
2730        assert!(engine.get_dag_workflow(&WorkflowId::new()).is_none());
2731    }
2732
2733    // ---- Dead letter queue tests ----
2734
2735    #[tokio::test]
2736    async fn dag_dead_letters_populated_on_failure() {
2737        let steps = vec![dag_step("a", &[])];
2738        let executor = MockExecutor::new().with_failure("a", "catastrophic failure");
2739
2740        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2741        assert!(!result.dead_letters.is_empty());
2742        assert_eq!(result.dead_letters[0].step_name, "a");
2743        assert_eq!(result.dead_letters[0].error, "catastrophic failure");
2744    }
2745
2746    // ---- Execution trace tests ----
2747
2748    #[tokio::test]
2749    async fn dag_execution_trace_records_waves() {
2750        let steps = vec![
2751            dag_step("a", &[]),
2752            dag_step("b", &["a"]),
2753            dag_step("c", &["b"]),
2754        ];
2755        let executor = MockExecutor::new()
2756            .with_response("a", "ok")
2757            .with_response("b", "ok")
2758            .with_response("c", "ok");
2759
2760        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2761        // 3 waves for a linear chain
2762        assert_eq!(result.execution_trace.len(), 3);
2763        assert_eq!(result.execution_trace[0].steps, vec!["a"]);
2764        assert_eq!(result.execution_trace[1].steps, vec!["b"]);
2765        assert_eq!(result.execution_trace[2].steps, vec!["c"]);
2766    }
2767
2768    // ---- DagWorkflowStep helper tests ----
2769
2770    #[test]
2771    fn dag_step_fallback_step_extraction() {
2772        let mut step = dag_step("test", &[]);
2773        assert!(step.fallback_step().is_none());
2774
2775        step.on_error = OnError::Fallback {
2776            step: "backup".to_string(),
2777        };
2778        assert_eq!(step.fallback_step(), Some("backup".to_string()));
2779
2780        step.on_error = OnError::CatchAndContinue {
2781            error_handler: "handler".to_string(),
2782        };
2783        assert_eq!(step.fallback_step(), Some("handler".to_string()));
2784    }
2785
2786    // ---- Serialization tests for new types ----
2787
2788    #[test]
2789    fn dag_workflow_serialization_roundtrip() {
2790        let wf = DagWorkflow {
2791            id: WorkflowId::new(),
2792            name: "test-dag".to_string(),
2793            steps: vec![dag_step("a", &[]), dag_step("b", &["a"])],
2794        };
2795        let json = serde_json::to_string(&wf).expect("serialize");
2796        let deser: DagWorkflow = serde_json::from_str(&json).expect("deserialize");
2797        assert_eq!(deser.name, "test-dag");
2798        assert_eq!(deser.steps.len(), 2);
2799    }
2800
2801    #[test]
2802    fn dag_workflow_step_with_condition_serialization() {
2803        let mut step = dag_step("test", &["dep1"]);
2804        step.condition = Some(Condition::IfSuccess {
2805            step: "dep1".to_string(),
2806        });
2807        step.else_step = Some("fallback".to_string());
2808        let json = serde_json::to_string(&step).expect("serialize");
2809        let deser: DagWorkflowStep = serde_json::from_str(&json).expect("deserialize");
2810        assert!(deser.condition.is_some());
2811        assert_eq!(deser.else_step, Some("fallback".to_string()));
2812    }
2813
2814    #[test]
2815    fn dead_letter_entry_serialization() {
2816        let entry = DeadLetterEntry {
2817            step_name: "failed_step".to_string(),
2818            error: "boom".to_string(),
2819            input: "test input".to_string(),
2820            failed_at: Utc::now(),
2821        };
2822        let json = serde_json::to_string(&entry).expect("serialize");
2823        let deser: DeadLetterEntry = serde_json::from_str(&json).expect("deserialize");
2824        assert_eq!(deser.step_name, "failed_step");
2825        assert_eq!(deser.error, "boom");
2826    }
2827
2828    #[test]
2829    fn execution_trace_entry_serialization() {
2830        let entry = ExecutionTraceEntry {
2831            steps: vec!["a".to_string(), "b".to_string()],
2832            started_at: Utc::now(),
2833            completed_at: Some(Utc::now()),
2834        };
2835        let json = serde_json::to_string(&entry).expect("serialize");
2836        let deser: ExecutionTraceEntry = serde_json::from_str(&json).expect("deserialize");
2837        assert_eq!(deser.steps.len(), 2);
2838    }
2839
2840    #[test]
2841    fn workflow_run_with_new_fields_serialization() {
2842        let run = WorkflowRun {
2843            id: WorkflowRunId::new(),
2844            workflow_id: WorkflowId::new(),
2845            status: WorkflowRunStatus::PartiallyCompleted,
2846            step_results: Vec::new(),
2847            started_at: Utc::now(),
2848            completed_at: None,
2849            dead_letters: vec![DeadLetterEntry {
2850                step_name: "x".to_string(),
2851                error: "err".to_string(),
2852                input: "in".to_string(),
2853                failed_at: Utc::now(),
2854            }],
2855            execution_trace: Vec::new(),
2856        };
2857        let json = serde_json::to_string(&run).expect("serialize");
2858        let deser: WorkflowRun = serde_json::from_str(&json).expect("deserialize");
2859        assert_eq!(deser.status, WorkflowRunStatus::PartiallyCompleted);
2860        assert_eq!(deser.dead_letters.len(), 1);
2861    }
2862
2863    #[test]
2864    fn step_result_with_new_fields() {
2865        let sr = StepResult {
2866            step_name: "test".to_string(),
2867            response: "ok".to_string(),
2868            tokens_used: 10,
2869            duration_ms: 100,
2870            error: None,
2871            status: StepStatus::Completed,
2872            started_at: Some(Utc::now()),
2873            completed_at: Some(Utc::now()),
2874        };
2875        let json = serde_json::to_string(&sr).expect("serialize");
2876        let deser: StepResult = serde_json::from_str(&json).expect("deserialize");
2877        assert_eq!(deser.status, StepStatus::Completed);
2878        assert!(deser.started_at.is_some());
2879    }
2880
2881    // ---- Fallback error handling test ----
2882
2883    #[tokio::test]
2884    async fn dag_fallback_on_error() {
2885        let mut steps = vec![dag_step("main", &[]), dag_step("backup", &[])];
2886        steps[0].on_error = OnError::Fallback {
2887            step: "backup".to_string(),
2888        };
2889
2890        let executor = MockExecutor::new()
2891            .with_failure("main", "main failed")
2892            .with_response("backup", "backup result");
2893
2894        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2895        // The main step should have used backup's result
2896        // In our implementation, the step result gets the backup response
2897        let main_result = &result.step_results["main"];
2898        assert_eq!(main_result.response, "backup result");
2899    }
2900
2901    #[tokio::test]
2902    async fn dag_catch_and_continue() {
2903        let mut steps = vec![
2904            dag_step("risky", &[]),
2905            dag_step("handler", &[]),
2906            dag_step("next", &["risky"]),
2907        ];
2908        steps[0].on_error = OnError::CatchAndContinue {
2909            error_handler: "handler".to_string(),
2910        };
2911
2912        let executor = MockExecutor::new()
2913            .with_failure("risky", "oops")
2914            .with_response("handler", "handled")
2915            .with_response("next", "continued");
2916
2917        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2918        // "next" should have run because CatchAndContinue removes the failure
2919        assert!(result.step_results.contains_key("next"));
2920    }
2921
2922    // ---- Parallel execution proof tests ----
2923
    /// A timed executor that records start/end times to prove concurrency.
    struct ConcurrencyProofExecutor {
        /// Artificial per-step latency injected into every `execute` call.
        delay_ms: u64,
        /// Track (step_name, start_instant, end_instant) for each execution.
        timings: Arc<tokio::sync::Mutex<Vec<(String, Instant, Instant)>>>,
    }
2930
2931    impl ConcurrencyProofExecutor {
2932        fn new(delay_ms: u64) -> Self {
2933            Self {
2934                delay_ms,
2935                timings: Arc::new(tokio::sync::Mutex::new(Vec::new())),
2936            }
2937        }
2938    }
2939
2940    #[async_trait::async_trait]
2941    impl StepExecutor for ConcurrencyProofExecutor {
2942        async fn execute(
2943            &self,
2944            step: &DagWorkflowStep,
2945            _input: &str,
2946            _step_results: &HashMap<String, StepResult>,
2947            _loop_state: Option<&LoopState>,
2948        ) -> Result<StepResult, String> {
2949            let start = Instant::now();
2950            tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
2951            let end = Instant::now();
2952
2953            self.timings
2954                .lock()
2955                .await
2956                .push((step.name.clone(), start, end));
2957
2958            Ok(StepResult {
2959                step_name: step.name.clone(),
2960                response: format!("done-{}", step.name),
2961                tokens_used: 10,
2962                duration_ms: self.delay_ms,
2963                error: None,
2964                status: StepStatus::Completed,
2965                started_at: Some(Utc::now()),
2966                completed_at: Some(Utc::now()),
2967            })
2968        }
2969    }
2970
2971    /// Prove 3 independent steps with 50ms sleep each complete in ~50-70ms (not 150ms).
2972    #[tokio::test]
2973    async fn dag_three_independent_steps_parallel_timing() {
2974        let steps = vec![dag_step("x", &[]), dag_step("y", &[]), dag_step("z", &[])];
2975        let executor = ConcurrencyProofExecutor::new(50);
2976        let timings = Arc::clone(&executor.timings);
2977
2978        let start = Instant::now();
2979        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2980        let elapsed = start.elapsed();
2981
2982        assert_eq!(result.status, WorkflowRunStatus::Completed);
2983        assert_eq!(result.step_results.len(), 3);
2984        // Parallel: should finish in ~50ms, not 150ms
2985        assert!(
2986            elapsed.as_millis() < 100,
2987            "3 independent 50ms steps took {}ms, should be ~50ms for parallel execution",
2988            elapsed.as_millis()
2989        );
2990
2991        // Verify that the steps overlapped in time
2992        let recorded = timings.lock().await;
2993        assert_eq!(recorded.len(), 3);
2994        // All should have started within a few ms of each other
2995        let starts: Vec<_> = recorded.iter().map(|(_, s, _)| *s).collect();
2996        let earliest = starts.iter().min().copied().expect("should have starts");
2997        for s in &starts {
2998            let diff = s.duration_since(earliest).as_millis();
2999            assert!(
3000                diff < 20,
3001                "start time spread {}ms too large for parallel execution",
3002                diff
3003            );
3004        }
3005    }
3006
3007    /// Fan-out: step A -> steps B,C,D in parallel -> step E waits for all.
3008    #[tokio::test]
3009    async fn dag_fan_out_fan_in_timing() {
3010        let steps = vec![
3011            dag_step("a", &[]),
3012            dag_step("b", &["a"]),
3013            dag_step("c", &["a"]),
3014            dag_step("d", &["a"]),
3015            dag_step("e", &["b", "c", "d"]),
3016        ];
3017        let executor = TimedMockExecutor { delay_ms: 30 };
3018
3019        let start = Instant::now();
3020        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3021        let elapsed = start.elapsed();
3022
3023        assert_eq!(result.status, WorkflowRunStatus::Completed);
3024        assert_eq!(result.step_results.len(), 5);
3025
3026        // 3 waves: A (30ms) + B,C,D parallel (30ms) + E (30ms) = ~90ms
3027        // Sequential would be 5*30 = 150ms
3028        assert!(
3029            elapsed.as_millis() < 130,
3030            "fan-out/fan-in took {}ms, expected ~90ms",
3031            elapsed.as_millis()
3032        );
3033
3034        // Verify execution trace shows 3 waves
3035        assert_eq!(result.execution_trace.len(), 3);
3036        // Wave 2 should have B, C, D
3037        let wave2 = &result.execution_trace[1].steps;
3038        assert_eq!(wave2.len(), 3);
3039    }
3040
3041    /// Fan-in: multiple parallel roots feed into one join step.
3042    #[tokio::test]
3043    async fn dag_fan_in_parallel_roots() {
3044        let steps = vec![
3045            dag_step("r1", &[]),
3046            dag_step("r2", &[]),
3047            dag_step("r3", &[]),
3048            dag_step("join", &["r1", "r2", "r3"]),
3049        ];
3050        let executor = MockExecutor::new()
3051            .with_response("r1", "out1")
3052            .with_response("r2", "out2")
3053            .with_response("r3", "out3")
3054            .with_response("join", "merged");
3055
3056        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3057        assert_eq!(result.status, WorkflowRunStatus::Completed);
3058        assert_eq!(result.step_results["join"].response, "merged");
3059        // r1, r2, r3 in wave 1, join in wave 2
3060        assert_eq!(result.execution_trace.len(), 2);
3061        assert_eq!(result.execution_trace[0].steps.len(), 3);
3062    }
3063
3064    /// Diamond dependency: A -> B,C -> D (D depends on both B and C).
3065    #[tokio::test]
3066    async fn dag_diamond_dependency_parallel() {
3067        let steps = vec![
3068            dag_step("a", &[]),
3069            dag_step("b", &["a"]),
3070            dag_step("c", &["a"]),
3071            dag_step("d", &["b", "c"]),
3072        ];
3073        let executor = TimedMockExecutor { delay_ms: 30 };
3074
3075        let start = Instant::now();
3076        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3077        let elapsed = start.elapsed();
3078
3079        assert_eq!(result.status, WorkflowRunStatus::Completed);
3080        // 3 waves: A, B+C parallel, D
3081        assert_eq!(result.execution_trace.len(), 3);
3082        // B and C should be in the same wave
3083        let wave2 = &result.execution_trace[1].steps;
3084        assert!(wave2.contains(&"b".to_string()));
3085        assert!(wave2.contains(&"c".to_string()));
3086        // Total should be ~90ms (3 waves * 30ms), not 120ms (4 sequential)
3087        assert!(
3088            elapsed.as_millis() < 120,
3089            "diamond took {}ms, expected ~90ms",
3090            elapsed.as_millis()
3091        );
3092    }
3093
3094    /// Conditional skipping in a DAG.
3095    #[tokio::test]
3096    async fn dag_conditional_skip_in_dag() {
3097        let mut steps = vec![
3098            dag_step("check", &[]),
3099            dag_step("true_branch", &["check"]),
3100            dag_step("false_branch", &["check"]),
3101        ];
3102        // true_branch runs only if check succeeds (it will)
3103        steps[1].condition = Some(Condition::IfSuccess {
3104            step: "check".to_string(),
3105        });
3106        // false_branch runs only if check fails (it won't)
3107        steps[2].condition = Some(Condition::IfFailure {
3108            step: "check".to_string(),
3109        });
3110
3111        let executor = MockExecutor::new()
3112            .with_response("check", "all good")
3113            .with_response("true_branch", "ran")
3114            .with_response("false_branch", "should_not_run");
3115
3116        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3117        assert_eq!(
3118            result.step_results["true_branch"].status,
3119            StepStatus::Completed
3120        );
3121        assert_eq!(
3122            result.step_results["false_branch"].status,
3123            StepStatus::Skipped
3124        );
3125    }
3126
3127    /// Loop execution within a DAG step (ForEach).
3128    #[tokio::test]
3129    async fn dag_loop_foreach_within_dag() {
3130        let mut steps = vec![
3131            dag_step("data", &[]),
3132            dag_step("process", &["data"]),
3133            dag_step("summary", &["process"]),
3134        ];
3135        steps[1].loop_config = Some(LoopConfig::ForEach {
3136            source_step: "data".to_string(),
3137            max_iterations: 10,
3138        });
3139        steps[1].prompt_template = "process: {{loop.item}}".to_string();
3140
3141        let executor = MockExecutor::new()
3142            .with_response("data", r#"["red", "green", "blue"]"#)
3143            .with_response("summary", "done");
3144
3145        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3146        assert_eq!(result.status, WorkflowRunStatus::Completed);
3147        let process_out = &result.step_results["process"].response;
3148        // Should contain output from all 3 loop iterations
3149        assert!(process_out.contains("process: red"));
3150        assert!(process_out.contains("process: green"));
3151        assert!(process_out.contains("process: blue"));
3152    }
3153
3154    /// Partial failure: one parallel branch fails, others succeed.
3155    #[tokio::test]
3156    async fn dag_partial_failure_parallel_branches() {
3157        let steps = vec![
3158            dag_step("root", &[]),
3159            dag_step("ok_branch", &["root"]),
3160            dag_step("fail_branch", &["root"]),
3161            dag_step("ok_branch2", &["root"]),
3162        ];
3163
3164        let executor = MockExecutor::new()
3165            .with_response("root", "start")
3166            .with_response("ok_branch", "success1")
3167            .with_failure("fail_branch", "branch failed")
3168            .with_response("ok_branch2", "success2");
3169
3170        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3171        assert_eq!(result.status, WorkflowRunStatus::PartiallyCompleted);
3172        assert_eq!(
3173            result.step_results["ok_branch"].status,
3174            StepStatus::Completed
3175        );
3176        assert_eq!(
3177            result.step_results["ok_branch2"].status,
3178            StepStatus::Completed
3179        );
3180        assert!(result.step_results["fail_branch"].error.is_some());
3181    }
3182
3183    /// Fallback step execution on failure.
3184    #[tokio::test]
3185    async fn dag_fallback_step_runs_on_failure() {
3186        let mut steps = vec![
3187            dag_step("primary", &[]),
3188            dag_step("fallback_handler", &[]),
3189            dag_step("downstream", &["primary"]),
3190        ];
3191        steps[0].on_error = OnError::Fallback {
3192            step: "fallback_handler".to_string(),
3193        };
3194
3195        let executor = MockExecutor::new()
3196            .with_failure("primary", "primary broke")
3197            .with_response("fallback_handler", "recovered via fallback")
3198            .with_response("downstream", "downstream ran");
3199
3200        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3201        // primary should have the fallback result
3202        let primary_result = &result.step_results["primary"];
3203        assert_eq!(primary_result.response, "recovered via fallback");
3204        // downstream should have run since fallback recovered
3205        assert!(result.step_results.contains_key("downstream"));
3206    }
3207
3208    /// Circuit breaker triggering after N failures.
3209    #[tokio::test]
3210    async fn dag_circuit_breaker_triggers() {
3211        let mut steps = vec![dag_step("cb_step", &[])];
3212        steps[0].on_error = OnError::CircuitBreaker {
3213            max_failures: 2,
3214            cooldown_secs: 300,
3215        };
3216
3217        // First run: fail twice to trip the breaker
3218        let executor1 = MockExecutor::new().with_failure("cb_step", "fail1");
3219        let result1 = execute_dag("test", &steps, "input", Arc::new(executor1)).await;
3220        assert!(result1.step_results["cb_step"].error.is_some());
3221
3222        // The circuit breaker state is per-run, so we test within a single run
3223        // with a step that has CircuitBreaker and fails. The breaker opens internally
3224        // after max_failures. Let's verify the circuit breaker state logic directly.
3225        let mut cb = CircuitBreakerState::default();
3226        cb.record_failure();
3227        assert!(!cb.is_open(2, 300), "should not be open after 1 failure");
3228        cb.record_failure();
3229        assert!(cb.is_open(2, 300), "should be open after 2 failures");
3230        // After cooldown, it should close — but since cooldown is 300s, it's still open
3231        assert!(cb.is_open(2, 300));
3232    }
3233
3234    /// Variable substitution works across parallel branches.
3235    #[tokio::test]
3236    async fn dag_variable_substitution_across_parallel_branches() {
3237        let mut steps = vec![
3238            dag_step("source_a", &[]),
3239            dag_step("source_b", &[]),
3240            dag_step("consumer", &["source_a", "source_b"]),
3241        ];
3242        steps[2].prompt_template = "A={{source_a.output}}, B={{source_b.output}}".to_string();
3243
3244        let executor = MockExecutor::new()
3245            .with_response("source_a", "value_from_a")
3246            .with_response("source_b", "value_from_b");
3247        // consumer doesn't have a fixed response, so it will use the expanded prompt
3248
3249        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3250        assert_eq!(result.status, WorkflowRunStatus::Completed);
3251        let consumer_out = &result.step_results["consumer"].response;
3252        assert!(
3253            consumer_out.contains("value_from_a"),
3254            "consumer should see source_a output, got: {consumer_out}"
3255        );
3256        assert!(
3257            consumer_out.contains("value_from_b"),
3258            "consumer should see source_b output, got: {consumer_out}"
3259        );
3260    }
3261
3262    /// Wide parallel fan-out with timing proof.
3263    #[tokio::test]
3264    async fn dag_wide_parallel_fan_out_timing() {
3265        // 10 independent steps each taking 30ms
3266        let steps: Vec<DagWorkflowStep> =
3267            (0..10).map(|i| dag_step(&format!("s{i}"), &[])).collect();
3268        let executor = TimedMockExecutor { delay_ms: 30 };
3269
3270        let start = Instant::now();
3271        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3272        let elapsed = start.elapsed();
3273
3274        assert_eq!(result.status, WorkflowRunStatus::Completed);
3275        assert_eq!(result.step_results.len(), 10);
3276        // All 10 should run in one wave (~30ms), not sequentially (~300ms)
3277        assert!(
3278            elapsed.as_millis() < 80,
3279            "10 parallel 30ms steps took {}ms, expected ~30ms",
3280            elapsed.as_millis()
3281        );
3282        assert_eq!(result.execution_trace.len(), 1);
3283        assert_eq!(result.execution_trace[0].steps.len(), 10);
3284    }
3285
3286    /// While loop with condition that eventually terminates.
3287    #[tokio::test]
3288    async fn dag_while_loop_with_condition() {
3289        let mut steps = vec![dag_step("looper", &[])];
3290        steps[0].loop_config = Some(LoopConfig::While {
3291            condition: Condition::Expression("true".to_string()),
3292            max_iterations: 3,
3293        });
3294
3295        let executor = MockExecutor::new().with_response("looper", "iteration");
3296
3297        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3298        assert_eq!(result.status, WorkflowRunStatus::Completed);
3299        let output = &result.step_results["looper"].response;
3300        // Should have 3 iterations
3301        let lines: Vec<&str> = output.split('\n').collect();
3302        assert_eq!(lines.len(), 3);
3303    }
3304
3305    /// Retry loop succeeds on second attempt.
3306    #[tokio::test]
3307    async fn dag_retry_succeeds_on_retry() {
3308        let mut steps = vec![dag_step("retry_step", &[])];
3309        steps[0].loop_config = Some(LoopConfig::Retry {
3310            max_retries: 2,
3311            backoff_ms: 1,
3312            backoff_multiplier: 1.0,
3313        });
3314
3315        let executor = FailNTimesMockExecutor::new(1);
3316
3317        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3318        assert_eq!(result.status, WorkflowRunStatus::Completed);
3319        assert!(result.step_results["retry_step"].error.is_none());
3320        assert!(
3321            result.step_results["retry_step"]
3322                .response
3323                .contains("success on attempt 2")
3324        );
3325    }
3326}