1use std::collections::HashMap;
23use std::sync::Arc;
24use std::time::Instant;
25
26use chrono::{DateTime, Utc};
27use dashmap::DashMap;
28use serde::{Deserialize, Serialize};
29use tracing::{debug, error, info, instrument, warn};
30use uuid::Uuid;
31
32use punch_memory::MemorySubstrate;
33use punch_runtime::{FighterLoopParams, LlmDriver, run_fighter_loop, tools_for_capabilities};
34use punch_types::{FighterId, FighterManifest, ModelConfig, PunchError, PunchResult, WeightClass};
35
36use crate::workflow_conditions::{Condition, evaluate_condition};
37use crate::workflow_loops::{LoopConfig, LoopState, calculate_backoff, parse_foreach_items};
38use crate::workflow_validation::{ValidationError, topological_sort, validate_workflow};
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
46#[serde(transparent)]
47pub struct WorkflowId(pub Uuid);
48
49impl WorkflowId {
50 pub fn new() -> Self {
51 Self(Uuid::new_v4())
52 }
53}
54
55impl Default for WorkflowId {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl std::fmt::Display for WorkflowId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
69#[serde(transparent)]
70pub struct WorkflowRunId(pub Uuid);
71
72impl WorkflowRunId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4())
75 }
76}
77
78impl Default for WorkflowRunId {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl std::fmt::Display for WorkflowRunId {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "{}", self.0)
87 }
88}
89
/// Policy applied when a workflow step fails.
///
/// Serialized in `snake_case`; deserializes to [`OnError::FailWorkflow`]
/// when absent (see `#[serde(default)]` on the step structs).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum OnError {
    /// Abort the whole workflow on the first failure (the default).
    #[default]
    FailWorkflow,
    /// Record the step as skipped and keep executing subsequent steps.
    SkipStep,
    /// Retry the failed step exactly once before treating it as failed.
    RetryOnce,
    /// Execute the named step instead and adopt its result.
    Fallback { step: String },
    /// Execute the named error-handler step, then continue the workflow.
    CatchAndContinue { error_handler: String },
    /// Stop dispatching the step after `max_failures` consecutive
    /// failures until `cooldown_secs` have elapsed (see `CircuitBreakerState`).
    CircuitBreaker {
        max_failures: usize,
        cooldown_secs: u64,
    },
}
116
/// Lifecycle state of an individual workflow step.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepStatus {
    /// Not yet started (also the serde default for old payloads).
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished with an error.
    Failed,
    /// Deliberately not executed (condition false or `OnError::SkipStep`).
    Skipped,
    /// Never executed because its dependencies could not be satisfied.
    Cancelled,
}
128
129impl std::fmt::Display for StepStatus {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 match self {
132 Self::Pending => write!(f, "pending"),
133 Self::Running => write!(f, "running"),
134 Self::Completed => write!(f, "completed"),
135 Self::Failed => write!(f, "failed"),
136 Self::Skipped => write!(f, "skipped"),
137 Self::Cancelled => write!(f, "cancelled"),
138 }
139 }
140}
141
/// A single step in a linear (sequential) workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowStep {
    /// Step name; also usable as a `{{name}}` variable by later steps.
    pub name: String,
    /// Name given to the ephemeral fighter that executes this step.
    pub fighter_name: String,
    /// Prompt with `{{...}}` placeholders expanded by `expand_variables`.
    pub prompt_template: String,
    /// Per-step timeout in seconds; execution defaults to 120s when unset.
    pub timeout_secs: Option<u64>,
    /// Failure policy (defaults to `OnError::FailWorkflow`).
    #[serde(default)]
    pub on_error: OnError,
}
157
/// A step in a DAG workflow: execution parameters plus dependency edges,
/// an optional gating condition, and an optional loop configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DagWorkflowStep {
    /// Unique step name; referenced by other steps' `depends_on`.
    pub name: String,
    /// Name given to the fighter that executes this step.
    pub fighter_name: String,
    /// Prompt with `{{...}}` placeholders expanded by `expand_dag_variables`.
    pub prompt_template: String,
    /// Per-step timeout in seconds, if any.
    pub timeout_secs: Option<u64>,
    /// Failure policy (defaults to `OnError::FailWorkflow`).
    #[serde(default)]
    pub on_error: OnError,
    /// Names of steps that must finish before this one becomes ready.
    #[serde(default)]
    pub depends_on: Vec<String>,
    /// Optional gate: when it evaluates to false the step is skipped.
    #[serde(default)]
    pub condition: Option<Condition>,
    /// Step intended to run when `condition` is false.
    /// NOTE(review): `execute_dag` captures this but never executes it —
    /// confirm whether else-branches are handled elsewhere.
    #[serde(default)]
    pub else_step: Option<String>,
    /// Optional foreach / while / retry looping behavior.
    #[serde(default)]
    pub loop_config: Option<LoopConfig>,
}
185
186impl DagWorkflowStep {
187 pub fn fallback_step(&self) -> Option<String> {
189 match &self.on_error {
190 OnError::Fallback { step } => Some(step.clone()),
191 OnError::CatchAndContinue { error_handler } => Some(error_handler.clone()),
192 _ => None,
193 }
194 }
195}
196
/// A linear workflow: steps run strictly in order, each step's output
/// becoming the next step's `{{previous_output}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Workflow {
    /// Stable identifier used for registration and run bookkeeping.
    pub id: WorkflowId,
    /// Human-readable name (used in logs and system prompts).
    pub name: String,
    /// Steps executed sequentially.
    pub steps: Vec<WorkflowStep>,
}
207
/// A DAG workflow: steps with explicit dependencies, executed in parallel
/// waves by `execute_dag`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DagWorkflow {
    /// Stable identifier used for registration.
    pub id: WorkflowId,
    /// Human-readable name (used in logs).
    pub name: String,
    /// Steps; ordering is determined by `depends_on`, not list position.
    pub steps: Vec<DagWorkflowStep>,
}
218
/// Overall status of a workflow run.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WorkflowRunStatus {
    /// Created but not yet started.
    Pending,
    /// Currently executing.
    Running,
    /// All executed steps succeeded.
    Completed,
    /// At least one step failed and none succeeded (or the run aborted).
    Failed,
    /// Mixed outcome: some steps succeeded and some failed (DAG runs).
    PartiallyCompleted,
}
230
231impl std::fmt::Display for WorkflowRunStatus {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 Self::Pending => write!(f, "pending"),
235 Self::Running => write!(f, "running"),
236 Self::Completed => write!(f, "completed"),
237 Self::Failed => write!(f, "failed"),
238 Self::PartiallyCompleted => write!(f, "partially_completed"),
239 }
240 }
241}
242
/// Outcome of executing one workflow step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
    /// Name of the step that produced this result.
    pub step_name: String,
    /// The step's textual output (empty on failure/skip/cancel).
    pub response: String,
    /// Tokens consumed by the step (0 when unknown, e.g. loop aggregates).
    pub tokens_used: u64,
    /// Wall-clock execution time in milliseconds.
    pub duration_ms: u64,
    /// Error message when the step did not complete cleanly.
    pub error: Option<String>,
    /// Terminal status; serialized payloads missing this field
    /// deserialize as `Pending` for backward compatibility.
    #[serde(default = "default_step_status")]
    pub status: StepStatus,
    /// UTC timestamp when execution started, if recorded.
    #[serde(default)]
    pub started_at: Option<DateTime<Utc>>,
    /// UTC timestamp when execution finished, if recorded.
    #[serde(default)]
    pub completed_at: Option<DateTime<Utc>>,
}

/// Serde default for `StepResult::status` (backward compatibility shim).
fn default_step_status() -> StepStatus {
    StepStatus::Pending
}
270
/// Record of a step failure that could not be recovered, retained on the
/// run for post-mortem inspection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeadLetterEntry {
    /// Step that failed.
    pub step_name: String,
    /// The failure message.
    pub error: String,
    /// The workflow input in effect when the step failed.
    pub input: String,
    /// When the failure was recorded.
    pub failed_at: DateTime<Utc>,
}
283
/// A single execution of a workflow, with per-step results and trace data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowRun {
    /// Unique id of this run.
    pub id: WorkflowRunId,
    /// Workflow definition this run executed.
    pub workflow_id: WorkflowId,
    /// Current (or final) status of the run.
    pub status: WorkflowRunStatus,
    /// Results in step execution order.
    pub step_results: Vec<StepResult>,
    /// When the run started.
    pub started_at: DateTime<Utc>,
    /// When the run finished; `None` while still running.
    pub completed_at: Option<DateTime<Utc>>,
    /// Unrecoverable step failures collected during the run.
    #[serde(default)]
    pub dead_letters: Vec<DeadLetterEntry>,
    /// Wave-by-wave record of which steps were dispatched together.
    #[serde(default)]
    pub execution_trace: Vec<ExecutionTraceEntry>,
}
306
/// One "wave" of DAG execution: the steps dispatched concurrently and the
/// wall-clock window in which the wave ran.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionTraceEntry {
    /// Names of the steps dispatched in this wave.
    pub steps: Vec<String>,
    /// When the wave started.
    pub started_at: DateTime<Utc>,
    /// When the wave finished, if it did.
    pub completed_at: Option<DateTime<Utc>>,
}
317
318fn expand_variables(
331 template: &str,
332 current_input: &str,
333 step_name: &str,
334 step_results: &[StepResult],
335) -> String {
336 let mut result = template.to_string();
337
338 result = result.replace("{{input}}", current_input);
340 result = result.replace("{{previous_output}}", current_input);
341
342 result = result.replace("{{step_name}}", step_name);
344
345 for (i, sr) in step_results.iter().enumerate() {
347 let var = format!("{{{{step_{}}}}}", i + 1);
348 result = result.replace(&var, &sr.response);
349 }
350
351 for sr in step_results {
353 let var = format!("{{{{{}}}}}", sr.step_name);
354 result = result.replace(&var, &sr.response);
355 }
356
357 result
358}
359
360pub fn expand_dag_variables(
371 template: &str,
372 current_input: &str,
373 step_name: &str,
374 step_results: &HashMap<String, StepResult>,
375 loop_state: Option<&LoopState>,
376) -> String {
377 let mut result = template.to_string();
378
379 result = result.replace("{{input}}", current_input);
381 result = result.replace("{{previous_output}}", current_input);
382 result = result.replace("{{step_name}}", step_name);
383
384 if let Some(ls) = loop_state {
386 result = result.replace("{{loop.index}}", &ls.index.to_string());
387 if let Some(ref item) = ls.item {
388 result = result.replace("{{loop.item}}", item);
389 }
390 }
391
392 let mut output = String::with_capacity(result.len());
395 let mut remaining = result.as_str();
396
397 while let Some(start) = remaining.find("{{") {
398 output.push_str(&remaining[..start]);
399 let after_start = &remaining[start + 2..];
400 if let Some(end) = after_start.find("}}") {
401 let var_content = &after_start[..end];
402 let resolved = resolve_dag_variable(var_content, step_results);
403 output.push_str(&resolved);
404 remaining = &after_start[end + 2..];
405 } else {
406 output.push_str("{{");
407 remaining = after_start;
408 }
409 }
410 output.push_str(remaining);
411
412 output
413}
414
415fn resolve_dag_variable(var: &str, step_results: &HashMap<String, StepResult>) -> String {
417 let (expr, transform) = if let Some(pipe_pos) = var.find(" | ") {
419 let expr = var[..pipe_pos].trim();
420 let transform = var[pipe_pos + 3..].trim();
421 (expr, Some(transform))
422 } else {
423 (var.trim(), None)
424 };
425
426 let value = resolve_dag_expression(expr, step_results);
428
429 match transform {
431 Some("uppercase") => value.to_uppercase(),
432 Some("lowercase") => value.to_lowercase(),
433 Some("trim") => value.trim().to_string(),
434 Some("len") | Some("length") => value.len().to_string(),
435 Some(t) if t.starts_with("json_extract ") => {
436 let path = t
437 .strip_prefix("json_extract ")
438 .unwrap_or("")
439 .trim_matches('"');
440 json_path_extract(&value, path)
441 }
442 _ => value,
443 }
444}
445
446fn resolve_dag_expression(expr: &str, step_results: &HashMap<String, StepResult>) -> String {
448 let parts: Vec<&str> = expr.splitn(2, '.').collect();
449 if parts.len() < 2 {
450 return step_results
452 .get(parts[0])
453 .map(|r| r.response.clone())
454 .unwrap_or_else(|| format!("{{{{{expr}}}}}"));
455 }
456
457 let step_name = parts[0];
458 let property = parts[1];
459
460 let step_result = match step_results.get(step_name) {
461 Some(r) => r,
462 None => return format!("{{{{{expr}}}}}"),
463 };
464
465 match property {
466 "output" => step_result.response.clone(),
467 "status" => step_result.status.to_string(),
468 "duration_ms" => step_result.duration_ms.to_string(),
469 "error" => step_result
470 .error
471 .clone()
472 .unwrap_or_else(|| "none".to_string()),
473 _ if property.starts_with("output.") => {
474 let json_path = property.strip_prefix("output.").unwrap_or("");
475 json_path_extract(&step_result.response, json_path)
476 }
477 _ => format!("{{{{{expr}}}}}"),
478 }
479}
480
481fn json_path_extract(json_str: &str, path: &str) -> String {
485 let path = path.strip_prefix("$.").unwrap_or(path);
486 let parsed: serde_json::Value = match serde_json::from_str(json_str) {
487 Ok(v) => v,
488 Err(_) => return json_str.to_string(),
489 };
490
491 let mut current = &parsed;
492 for segment in path.split('.') {
493 if segment.is_empty() {
494 continue;
495 }
496 match current.get(segment) {
497 Some(v) => current = v,
498 None => return String::new(),
499 }
500 }
501
502 match current {
503 serde_json::Value::String(s) => s.clone(),
504 other => other.to_string(),
505 }
506}
507
/// Tracks consecutive failures for a single step so repeated failures can
/// temporarily short-circuit further attempts (circuit-breaker pattern).
#[derive(Debug, Clone, Default)]
pub struct CircuitBreakerState {
    /// Failures recorded since the last success.
    pub consecutive_failures: usize,
    /// When the most recent failure was recorded, if any.
    pub last_trip_time: Option<Instant>,
}

impl CircuitBreakerState {
    /// Returns `true` when the breaker is "open": at least `max_failures`
    /// consecutive failures have occurred and fewer than `cooldown_secs`
    /// whole seconds have elapsed since the last recorded failure.
    pub fn is_open(&self, max_failures: usize, cooldown_secs: u64) -> bool {
        if self.consecutive_failures < max_failures {
            return false;
        }
        // No recorded trip time with enough failures counts as open until
        // a success resets the breaker.
        self.last_trip_time
            .map_or(true, |trip| trip.elapsed().as_secs() < cooldown_secs)
    }

    /// Records one failure and refreshes the trip timestamp.
    pub fn record_failure(&mut self) {
        self.consecutive_failures += 1;
        self.last_trip_time = Some(Instant::now());
    }

    /// Resets the breaker after a successful execution.
    pub fn record_success(&mut self) {
        self.consecutive_failures = 0;
        self.last_trip_time = None;
    }
}
546
/// Abstraction over how a single DAG step is executed, letting
/// `execute_dag` be driven by real fighter runs or by test doubles.
#[async_trait::async_trait]
pub trait StepExecutor: Send + Sync {
    /// Executes `step` with the workflow `input`, the results of the steps
    /// completed so far (keyed by step name), and the current loop
    /// iteration state when the step runs inside a loop.
    ///
    /// Returns the step's result, or the failure message as a `String`.
    async fn execute(
        &self,
        step: &DagWorkflowStep,
        input: &str,
        step_results: &HashMap<String, StepResult>,
        loop_state: Option<&LoopState>,
    ) -> Result<StepResult, String>;
}
563
/// Executes a DAG workflow in parallel "waves".
///
/// Flow:
/// 1. Validate the step graph; bail out with `Failed` on any error.
/// 2. Topologically sort; a cycle also fails the run.
/// 3. Repeatedly partition remaining steps into *ready* (all dependencies
///    done and none failed-and-blocking) and *not ready*; run each ready
///    wave concurrently via a `JoinSet`, then apply per-step `OnError`
///    policies to the results.
/// 4. Derive the overall status: failures + successes =>
///    `PartiallyCompleted`, only failures => `Failed`, otherwise
///    `Completed`.
///
/// NOTE(review): steps cancelled for unmet dependencies get status
/// `Cancelled`, which the final-status computation does not count as a
/// failure — a run where everything was cancelled reports `Completed`.
/// Confirm whether that is intended.
pub async fn execute_dag(
    workflow_name: &str,
    steps: &[DagWorkflowStep],
    input: &str,
    executor: Arc<dyn StepExecutor>,
) -> DagExecutionResult {
    // Structural validation (duplicate names, bad references, ...).
    let validation_errors = validate_workflow(steps);
    if !validation_errors.is_empty() {
        return DagExecutionResult {
            status: WorkflowRunStatus::Failed,
            step_results: HashMap::new(),
            dead_letters: Vec::new(),
            execution_trace: Vec::new(),
            validation_errors,
        };
    }

    // Topological order; an Err means the dependency graph has a cycle.
    let topo_order = match topological_sort(steps) {
        Ok(order) => order,
        Err(_) => {
            return DagExecutionResult {
                status: WorkflowRunStatus::Failed,
                step_results: HashMap::new(),
                dead_letters: Vec::new(),
                execution_trace: Vec::new(),
                validation_errors: vec![ValidationError::CycleDetected {
                    steps: steps.iter().map(|s| s.name.clone()).collect(),
                }],
            };
        }
    };

    // Name -> step lookup used throughout scheduling.
    let step_map: HashMap<&str, &DagWorkflowStep> =
        steps.iter().map(|s| (s.name.as_str(), s)).collect();

    let mut completed: HashMap<String, StepResult> = HashMap::new();
    let mut dead_letters: Vec<DeadLetterEntry> = Vec::new();
    let mut execution_trace: Vec<ExecutionTraceEntry> = Vec::new();
    let mut circuit_breakers: HashMap<String, CircuitBreakerState> = HashMap::new();
    let mut skipped_steps: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut failed_steps: std::collections::HashSet<String> = std::collections::HashSet::new();

    let mut remaining: Vec<String> = topo_order;

    while !remaining.is_empty() {
        // A step is ready when every dependency has finished (completed or
        // skipped) and no dependency is a blocking failure.
        let (ready, not_ready): (Vec<String>, Vec<String>) =
            remaining.into_iter().partition(|name| {
                let step = match step_map.get(name.as_str()) {
                    Some(s) => s,
                    None => return false,
                };
                step.depends_on.iter().all(|dep| {
                    let is_done = completed.contains_key(dep) || skipped_steps.contains(dep);
                    let is_blocking_failure = failed_steps.contains(dep);
                    is_done && !is_blocking_failure
                })
            });

        if ready.is_empty() {
            // Nothing can make progress: cancel everything still pending
            // so dependents are visible in the final result map.
            for name in &not_ready {
                skipped_steps.insert(name.clone());
                completed.insert(
                    name.clone(),
                    StepResult {
                        step_name: name.clone(),
                        response: String::new(),
                        tokens_used: 0,
                        duration_ms: 0,
                        error: Some("cancelled: unmet dependencies".to_string()),
                        status: StepStatus::Cancelled,
                        started_at: None,
                        completed_at: None,
                    },
                );
            }
            break;
        }

        remaining = not_ready;

        let wave_start = Utc::now();
        let wave_step_names: Vec<String> = ready.to_vec();

        // (step name, execution result, optional else-step name).
        let mut wave_results: Vec<(String, Result<StepResult, String>, Option<String>)> =
            Vec::new();
        let mut join_set: tokio::task::JoinSet<(
            String,
            Result<StepResult, String>,
            Option<String>,
        )> = tokio::task::JoinSet::new();

        for step_name in &wave_step_names {
            let step = match step_map.get(step_name.as_str()) {
                Some(s) => (*s).clone(),
                None => continue,
            };

            // Conditional gate: a false condition records a Skipped result
            // instead of running the step.
            let should_run = match &step.condition {
                Some(cond) => evaluate_condition(cond, &completed),
                None => true,
            };

            if !should_run {
                let else_step_name = step.else_step.clone();
                wave_results.push((
                    step_name.clone(),
                    Ok(StepResult {
                        step_name: step_name.clone(),
                        response: String::new(),
                        tokens_used: 0,
                        duration_ms: 0,
                        error: None,
                        status: StepStatus::Skipped,
                        started_at: Some(Utc::now()),
                        completed_at: Some(Utc::now()),
                    }),
                    else_step_name,
                ));
                continue;
            }

            // Circuit breaker: if the step's breaker is open, record a
            // synthetic failure without dispatching.
            let cb_state = circuit_breakers
                .entry(step_name.clone())
                .or_default()
                .clone();
            if let OnError::CircuitBreaker {
                max_failures,
                cooldown_secs,
            } = &step.on_error
                && cb_state.is_open(*max_failures, *cooldown_secs)
            {
                wave_results.push((
                    step_name.clone(),
                    Ok(StepResult {
                        step_name: step_name.clone(),
                        response: String::new(),
                        tokens_used: 0,
                        duration_ms: 0,
                        error: Some("circuit breaker open".to_string()),
                        status: StepStatus::Failed,
                        started_at: Some(Utc::now()),
                        completed_at: Some(Utc::now()),
                    }),
                    None,
                ));
                continue;
            }

            // Dispatch concurrently; each task gets a snapshot of the
            // results completed before this wave started.
            let sn = step_name.clone();
            let completed_snapshot = completed.clone();
            let input_clone = input.to_string();
            let executor_clone = Arc::clone(&executor);

            join_set.spawn(async move {
                let result = execute_step_with_loops(
                    &step,
                    &input_clone,
                    &completed_snapshot,
                    executor_clone.as_ref(),
                )
                .await;
                (sn, result, None::<String>)
            });
        }

        // Collect every spawned task; panicked/aborted tasks are logged
        // and dropped (their step simply produces no result this wave).
        while let Some(join_result) = join_set.join_next().await {
            match join_result {
                Ok(task_result) => wave_results.push(task_result),
                Err(join_err) => {
                    error!(error = %join_err, "spawned step task failed unexpectedly");
                }
            }
        }

        // Apply per-step error policies to the wave's results.
        // NOTE(review): `_else_step` is captured for skipped conditionals
        // but never executed here — confirm whether else-branch execution
        // is implemented elsewhere or missing.
        for (step_name, result, _else_step) in wave_results {
            match result {
                Ok(mut step_result) => {
                    if step_result.status == StepStatus::Skipped {
                        skipped_steps.insert(step_name.clone());
                        debug!(step = %step_name, workflow = %workflow_name, "step skipped (condition false)");
                    } else if step_result.error.is_some() {
                        // Executed but reported an error (e.g. exhausted
                        // retries, breaker open): treat as failure first,
                        // then apply the configured recovery policy.
                        failed_steps.insert(step_name.clone());
                        circuit_breakers
                            .entry(step_name.clone())
                            .or_default()
                            .record_failure();

                        let step = step_map.get(step_name.as_str());
                        if let Some(step) = step {
                            match &step.on_error {
                                OnError::Fallback { step: fb_step } => {
                                    // Run the fallback step; on success its
                                    // result replaces the failed one.
                                    if let Some(fb) = step_map.get(fb_step.as_str()) {
                                        let fb_result =
                                            executor.execute(fb, input, &completed, None).await;
                                        match fb_result {
                                            Ok(fb_res) => {
                                                step_result = fb_res;
                                                step_result.step_name = step_name.clone();
                                                failed_steps.remove(&step_name);
                                            }
                                            Err(fb_err) => {
                                                dead_letters.push(DeadLetterEntry {
                                                    step_name: step_name.clone(),
                                                    error: fb_err,
                                                    input: input.to_string(),
                                                    failed_at: Utc::now(),
                                                });
                                            }
                                        }
                                    }
                                }
                                OnError::CatchAndContinue { error_handler } => {
                                    // Best-effort handler; its own result
                                    // is deliberately ignored.
                                    if let Some(handler) = step_map.get(error_handler.as_str()) {
                                        let _ = executor
                                            .execute(handler, input, &completed, None)
                                            .await;
                                    }
                                    failed_steps.remove(&step_name);
                                }
                                OnError::SkipStep => {
                                    skipped_steps.insert(step_name.clone());
                                    failed_steps.remove(&step_name);
                                }
                                OnError::FailWorkflow => {
                                    dead_letters.push(DeadLetterEntry {
                                        step_name: step_name.clone(),
                                        error: step_result.error.clone().unwrap_or_default(),
                                        input: input.to_string(),
                                        failed_at: Utc::now(),
                                    });
                                }
                                _ => {}
                            }
                        }
                    } else {
                        // Clean success: reset the breaker.
                        circuit_breakers
                            .entry(step_name.clone())
                            .or_default()
                            .record_success();
                        info!(step = %step_name, workflow = %workflow_name, "DAG step completed");
                    }
                    completed.insert(step_name, step_result);
                }
                Err(e) => {
                    // Executor returned a hard error (no StepResult):
                    // synthesize a Failed result, then apply the policy.
                    failed_steps.insert(step_name.clone());
                    circuit_breakers
                        .entry(step_name.clone())
                        .or_default()
                        .record_failure();

                    let mut step_result = StepResult {
                        step_name: step_name.clone(),
                        response: String::new(),
                        tokens_used: 0,
                        duration_ms: 0,
                        error: Some(e.clone()),
                        status: StepStatus::Failed,
                        started_at: Some(Utc::now()),
                        completed_at: Some(Utc::now()),
                    };

                    let step = step_map.get(step_name.as_str());
                    if let Some(step) = step {
                        match &step.on_error {
                            OnError::Fallback { step: fb_step } => {
                                if let Some(fb) = step_map.get(fb_step.as_str())
                                    && let Ok(fb_res) =
                                        executor.execute(fb, input, &completed, None).await
                                {
                                    step_result = fb_res;
                                    step_result.step_name = step_name.clone();
                                    step_result.error = None;
                                    step_result.status = StepStatus::Completed;
                                    failed_steps.remove(&step_name);
                                }
                            }
                            OnError::CatchAndContinue { error_handler } => {
                                if let Some(handler) = step_map.get(error_handler.as_str()) {
                                    let _ =
                                        executor.execute(handler, input, &completed, None).await;
                                }
                                failed_steps.remove(&step_name);
                            }
                            OnError::SkipStep => {
                                step_result.status = StepStatus::Skipped;
                                skipped_steps.insert(step_name.clone());
                                failed_steps.remove(&step_name);
                            }
                            OnError::FailWorkflow => {
                                dead_letters.push(DeadLetterEntry {
                                    step_name: step_name.clone(),
                                    error: e,
                                    input: input.to_string(),
                                    failed_at: Utc::now(),
                                });
                            }
                            _ => {
                                // RetryOnce / CircuitBreaker offer no
                                // recovery at this point: dead-letter.
                                dead_letters.push(DeadLetterEntry {
                                    step_name: step_name.clone(),
                                    error: e,
                                    input: input.to_string(),
                                    failed_at: Utc::now(),
                                });
                            }
                        }
                    } else {
                        dead_letters.push(DeadLetterEntry {
                            step_name: step_name.clone(),
                            error: e,
                            input: input.to_string(),
                            failed_at: Utc::now(),
                        });
                    }

                    completed.insert(step_name, step_result);
                }
            }
        }

        // Record the wave in the execution trace.
        execution_trace.push(ExecutionTraceEntry {
            steps: wave_step_names,
            started_at: wave_start,
            completed_at: Some(Utc::now()),
        });
    }

    // Derive the overall run status from the per-step statuses.
    let has_failures = completed.values().any(|r| r.status == StepStatus::Failed);
    let has_successes = completed
        .values()
        .any(|r| r.status == StepStatus::Completed);

    let status = if has_failures && has_successes {
        WorkflowRunStatus::PartiallyCompleted
    } else if has_failures {
        WorkflowRunStatus::Failed
    } else {
        WorkflowRunStatus::Completed
    };

    DagExecutionResult {
        status,
        step_results: completed,
        dead_letters,
        execution_trace,
        validation_errors: Vec::new(),
    }
}
938
/// Executes one DAG step, honoring its optional loop configuration.
///
/// - `None`: single plain execution.
/// - `ForEach`: iterates items parsed from a source step's output, feeding
///   each item via `LoopState`; responses are joined with newlines.
/// - `While`: re-executes while a condition over (the last iteration's
///   output merged into) the completed results holds, up to
///   `max_iterations`.
/// - `Retry`: retries on executor error with configurable backoff; the
///   exhausted case returns a `Failed` StepResult (not an `Err`).
///
/// Sentinels in a response: `__BREAK__` stops the loop (keeping the
/// response minus the marker); `__CONTINUE__` (ForEach only) discards the
/// iteration's output entirely.
async fn execute_step_with_loops(
    step: &DagWorkflowStep,
    input: &str,
    completed: &HashMap<String, StepResult>,
    executor: &dyn StepExecutor,
) -> Result<StepResult, String> {
    match &step.loop_config {
        None => executor.execute(step, input, completed, None).await,
        Some(LoopConfig::ForEach {
            source_step,
            max_iterations,
        }) => {
            // Missing source step degrades to an empty item list.
            let source_output = completed
                .get(source_step)
                .map(|r| r.response.as_str())
                .unwrap_or("[]");
            let items = parse_foreach_items(source_output)?;
            let max = (*max_iterations).min(items.len());

            let mut loop_state = LoopState::new();
            let start = Utc::now();
            let instant = Instant::now();

            for (i, item) in items.into_iter().take(max).enumerate() {
                loop_state.index = i;
                loop_state.item = Some(item);

                let result = executor
                    .execute(step, input, completed, Some(&loop_state))
                    .await;

                match result {
                    Ok(r) => {
                        if r.response.contains("__BREAK__") {
                            // Keep the final output (sans marker) and stop.
                            loop_state.push_result(r.response.replace("__BREAK__", ""));
                            break;
                        }
                        if r.response.contains("__CONTINUE__") {
                            // Discard this iteration's output.
                            continue;
                        }
                        loop_state.push_result(r.response);
                    }
                    Err(e) => return Err(e),
                }
            }

            let combined = loop_state.accumulated_results.join("\n");
            Ok(StepResult {
                step_name: step.name.clone(),
                response: combined,
                tokens_used: 0,
                duration_ms: instant.elapsed().as_millis() as u64,
                error: None,
                status: StepStatus::Completed,
                started_at: Some(start),
                completed_at: Some(Utc::now()),
            })
        }
        Some(LoopConfig::While {
            condition,
            max_iterations,
        }) => {
            let mut loop_state = LoopState::new();
            let start = Utc::now();
            let instant = Instant::now();

            for i in 0..*max_iterations {
                // Expose the previous iteration's output under this step's
                // own name so the condition can reference it.
                let mut extended = completed.clone();
                if !loop_state.accumulated_results.is_empty() {
                    extended.insert(
                        step.name.clone(),
                        StepResult {
                            step_name: step.name.clone(),
                            response: loop_state
                                .accumulated_results
                                .last()
                                .cloned()
                                .unwrap_or_default(),
                            tokens_used: 0,
                            duration_ms: 0,
                            error: None,
                            status: StepStatus::Completed,
                            started_at: None,
                            completed_at: None,
                        },
                    );
                }

                // Condition checked before each iteration (including the
                // first, against the untouched results).
                if !evaluate_condition(condition, &extended) {
                    break;
                }

                loop_state.index = i;
                let result = executor
                    .execute(step, input, &extended, Some(&loop_state))
                    .await;

                match result {
                    Ok(r) => {
                        if r.response.contains("__BREAK__") {
                            loop_state.push_result(r.response.replace("__BREAK__", ""));
                            break;
                        }
                        // NOTE(review): unlike ForEach, `__CONTINUE__` is
                        // not handled here — confirm whether that is
                        // intentional.
                        loop_state.push_result(r.response);
                    }
                    Err(e) => return Err(e),
                }
            }

            let combined = loop_state.accumulated_results.join("\n");
            Ok(StepResult {
                step_name: step.name.clone(),
                response: combined,
                tokens_used: 0,
                duration_ms: instant.elapsed().as_millis() as u64,
                error: None,
                status: StepStatus::Completed,
                started_at: Some(start),
                completed_at: Some(Utc::now()),
            })
        }
        Some(LoopConfig::Retry {
            max_retries,
            backoff_ms,
            backoff_multiplier,
        }) => {
            let start = Utc::now();
            let instant = Instant::now();
            let mut last_error = String::new();

            // attempt 0 is the initial try; attempts 1..=max_retries are
            // retries, each preceded by a backoff sleep.
            for attempt in 0..=*max_retries {
                if attempt > 0 {
                    let wait = calculate_backoff(attempt - 1, *backoff_ms, *backoff_multiplier);
                    tokio::time::sleep(std::time::Duration::from_millis(wait)).await;
                }

                match executor.execute(step, input, completed, None).await {
                    Ok(r) => return Ok(r),
                    Err(e) => {
                        last_error = e;
                        warn!(step = %step.name, attempt = attempt + 1, "retry attempt failed");
                    }
                }
            }

            // All attempts exhausted: surface a Failed result (Ok-wrapped
            // so the caller's OnError policy can still apply).
            Ok(StepResult {
                step_name: step.name.clone(),
                response: String::new(),
                tokens_used: 0,
                duration_ms: instant.elapsed().as_millis() as u64,
                error: Some(last_error),
                status: StepStatus::Failed,
                started_at: Some(start),
                completed_at: Some(Utc::now()),
            })
        }
    }
}
1101
/// Aggregate outcome of `execute_dag`.
#[derive(Debug, Clone)]
pub struct DagExecutionResult {
    /// Overall status derived from per-step statuses.
    pub status: WorkflowRunStatus,
    /// Final result of every step, keyed by step name.
    pub step_results: HashMap<String, StepResult>,
    /// Unrecoverable failures collected during execution.
    pub dead_letters: Vec<DeadLetterEntry>,
    /// Wave-by-wave execution record.
    pub execution_trace: Vec<ExecutionTraceEntry>,
    /// Validation problems; non-empty only when the workflow never ran.
    pub validation_errors: Vec<ValidationError>,
}
1116
/// In-memory registry and runner for workflows.
///
/// State is held in concurrent maps (`DashMap`), so the engine can be
/// shared across tasks behind an `Arc` without additional locking.
pub struct WorkflowEngine {
    /// Registered linear workflows, keyed by id.
    workflows: DashMap<WorkflowId, Workflow>,
    /// Registered DAG workflows, keyed by id.
    dag_workflows: DashMap<WorkflowId, DagWorkflow>,
    /// All runs started by this engine, keyed by run id.
    runs: DashMap<WorkflowRunId, WorkflowRun>,
}
1130
1131impl WorkflowEngine {
1132 pub fn new() -> Self {
1134 Self {
1135 workflows: DashMap::new(),
1136 dag_workflows: DashMap::new(),
1137 runs: DashMap::new(),
1138 }
1139 }
1140
1141 pub fn register_workflow(&self, workflow: Workflow) -> WorkflowId {
1143 let id = workflow.id;
1144 info!(workflow_id = %id, name = %workflow.name, "workflow registered");
1145 self.workflows.insert(id, workflow);
1146 id
1147 }
1148
1149 pub fn register_dag_workflow(
1154 &self,
1155 workflow: DagWorkflow,
1156 ) -> Result<WorkflowId, Vec<ValidationError>> {
1157 let errors = validate_workflow(&workflow.steps);
1158 if !errors.is_empty() {
1159 return Err(errors);
1160 }
1161 let id = workflow.id;
1162 info!(workflow_id = %id, name = %workflow.name, "DAG workflow registered");
1163 self.dag_workflows.insert(id, workflow);
1164 Ok(id)
1165 }
1166
    /// Runs a registered linear workflow to completion.
    ///
    /// Creates a `WorkflowRun` record up front (status `Running`), executes
    /// the steps sequentially — each step's response becoming the next
    /// step's input — applies the per-step `OnError` policy on failure,
    /// and finalizes the run record. Returns the run id; fetch results via
    /// `get_run`.
    ///
    /// # Errors
    /// Returns `PunchError::Internal` when `workflow_id` is not registered.
    /// Step failures do NOT produce an `Err` here — they are recorded on
    /// the run, whose final status becomes `Failed`.
    #[instrument(skip(self, input, memory, driver, model_config), fields(%workflow_id))]
    pub async fn execute_workflow(
        &self,
        workflow_id: &WorkflowId,
        input: String,
        memory: Arc<MemorySubstrate>,
        driver: Arc<dyn LlmDriver>,
        model_config: &ModelConfig,
    ) -> PunchResult<WorkflowRunId> {
        let workflow = self
            .workflows
            .get(workflow_id)
            .ok_or_else(|| PunchError::Internal(format!("workflow {} not found", workflow_id)))?
            .clone();

        // Register the run before executing so it is observable while
        // still in progress.
        let run_id = WorkflowRunId::new();
        let run = WorkflowRun {
            id: run_id,
            workflow_id: *workflow_id,
            status: WorkflowRunStatus::Running,
            step_results: Vec::new(),
            started_at: Utc::now(),
            completed_at: None,
            dead_letters: Vec::new(),
            execution_trace: Vec::new(),
        };
        self.runs.insert(run_id, run);

        let mut current_input = input.clone();
        let mut step_results: Vec<StepResult> = Vec::new();
        let mut failed = false;

        for step in &workflow.steps {
            let result = self
                .execute_single_step(
                    step,
                    &workflow.name,
                    &current_input,
                    &step_results,
                    &memory,
                    &driver,
                    model_config,
                )
                .await;

            match result {
                Ok(step_result) => {
                    // Chain outputs: this step's response feeds the next.
                    current_input = step_result.response.clone();
                    step_results.push(step_result);
                }
                Err(e) => {
                    let error_msg = format!("{e}");
                    match step.on_error {
                        OnError::SkipStep => {
                            // Record the skip and continue; the NEXT step
                            // still receives the previous step's output.
                            warn!(step = %step.name, error = %error_msg, "step failed, skipping");
                            let skip_result = StepResult {
                                step_name: step.name.clone(),
                                response: String::new(),
                                tokens_used: 0,
                                duration_ms: 0,
                                error: Some(error_msg),
                                status: StepStatus::Skipped,
                                started_at: None,
                                completed_at: None,
                            };
                            step_results.push(skip_result);
                            continue;
                        }
                        OnError::RetryOnce => {
                            // Exactly one retry; a second failure aborts
                            // the workflow.
                            warn!(step = %step.name, error = %error_msg, "step failed, retrying once");
                            let retry_result = self
                                .execute_single_step(
                                    step,
                                    &workflow.name,
                                    &current_input,
                                    &step_results,
                                    &memory,
                                    &driver,
                                    model_config,
                                )
                                .await;

                            match retry_result {
                                Ok(step_result) => {
                                    current_input = step_result.response.clone();
                                    step_results.push(step_result);
                                }
                                Err(retry_err) => {
                                    error!(step = %step.name, error = %retry_err, "step failed on retry");
                                    let fail_result = StepResult {
                                        step_name: step.name.clone(),
                                        response: String::new(),
                                        tokens_used: 0,
                                        duration_ms: 0,
                                        error: Some(format!("{retry_err}")),
                                        status: StepStatus::Failed,
                                        started_at: None,
                                        completed_at: None,
                                    };
                                    step_results.push(fail_result);
                                    failed = true;
                                    break;
                                }
                            }
                        }
                        OnError::FailWorkflow => {
                            error!(step = %step.name, error = %error_msg, "step failed, aborting workflow");
                            let fail_result = StepResult {
                                step_name: step.name.clone(),
                                response: String::new(),
                                tokens_used: 0,
                                duration_ms: 0,
                                error: Some(error_msg),
                                status: StepStatus::Failed,
                                started_at: None,
                                completed_at: None,
                            };
                            step_results.push(fail_result);
                            failed = true;
                            break;
                        }
                        _ => {
                            // DAG-only policies (Fallback, CatchAndContinue,
                            // CircuitBreaker) have no linear equivalent:
                            // treat as a hard failure.
                            let fail_result = StepResult {
                                step_name: step.name.clone(),
                                response: String::new(),
                                tokens_used: 0,
                                duration_ms: 0,
                                error: Some(error_msg),
                                status: StepStatus::Failed,
                                started_at: None,
                                completed_at: None,
                            };
                            step_results.push(fail_result);
                            failed = true;
                            break;
                        }
                    }
                }
            }
        }

        // Finalize the run record.
        if let Some(mut run) = self.runs.get_mut(&run_id) {
            run.step_results = step_results;
            run.status = if failed {
                WorkflowRunStatus::Failed
            } else {
                WorkflowRunStatus::Completed
            };
            run.completed_at = Some(Utc::now());
        }

        Ok(run_id)
    }
1324
    /// Executes one linear-workflow step by spinning up an ephemeral
    /// fighter and running the fighter loop with the expanded prompt.
    ///
    /// The step's `timeout_secs` (default 120) bounds BOTH the overall
    /// fighter loop (via `tokio::time::timeout`) and each tool call
    /// (`tool_timeout_secs`).
    ///
    /// # Errors
    /// Propagates bout-creation failures, fighter-loop errors, and a
    /// timeout as `PunchError::Internal`. A failure to persist the fighter
    /// record is deliberately logged and ignored (best effort).
    #[allow(clippy::too_many_arguments)]
    async fn execute_single_step(
        &self,
        step: &WorkflowStep,
        workflow_name: &str,
        current_input: &str,
        step_results: &[StepResult],
        memory: &Arc<MemorySubstrate>,
        driver: &Arc<dyn LlmDriver>,
        model_config: &ModelConfig,
    ) -> PunchResult<StepResult> {
        let step_start = Instant::now();
        let started_at = Utc::now();

        // Substitute {{input}}, {{step_N}}, named-step variables, etc.
        let prompt = expand_variables(
            &step.prompt_template,
            current_input,
            &step.name,
            step_results,
        );

        // Ephemeral fighter scoped to this single step execution.
        let fighter_id = FighterId::new();
        let fighter_manifest = FighterManifest {
            name: step.fighter_name.clone(),
            description: format!("Workflow step: {}", step.name),
            model: model_config.clone(),
            system_prompt: format!(
                "You are executing step '{}' of workflow '{}'.",
                step.name, workflow_name
            ),
            capabilities: Vec::new(),
            weight_class: WeightClass::Middleweight,
            tenant_id: None,
        };

        // Best effort: execution proceeds even if persistence fails.
        if let Err(e) = memory
            .save_fighter(
                &fighter_id,
                &fighter_manifest,
                punch_types::FighterStatus::Idle,
            )
            .await
        {
            error!(error = %e, "failed to persist workflow fighter");
        }

        let bout_id = memory.create_bout(&fighter_id).await.map_err(|e| {
            PunchError::Internal(format!(
                "failed to create bout for step '{}': {e}",
                step.name
            ))
        })?;

        // Manifest has no capabilities, so this resolves the baseline
        // tool set for an empty capability list.
        let available_tools = tools_for_capabilities(&fighter_manifest.capabilities);
        let timeout_secs = step.timeout_secs.unwrap_or(120);

        let params = FighterLoopParams {
            manifest: fighter_manifest,
            user_message: prompt,
            bout_id,
            fighter_id,
            memory: Arc::clone(memory),
            driver: Arc::clone(driver),
            available_tools,
            mcp_tools: Vec::new(),
            max_iterations: Some(20),
            context_window: None,
            tool_timeout_secs: Some(timeout_secs),
            coordinator: None,
            approval_engine: None,
            sandbox: None,
            mcp_clients: None,
            model_routing: None,
            channel_notifier: None,
            user_content_parts: vec![],
            eco_mode: false,
        };

        // Bound the whole fighter loop by the step timeout.
        let loop_result = tokio::time::timeout(
            std::time::Duration::from_secs(timeout_secs),
            run_fighter_loop(params),
        )
        .await;

        match loop_result {
            Ok(Ok(result)) => {
                let step_result = StepResult {
                    step_name: step.name.clone(),
                    response: result.response,
                    tokens_used: result.usage.total(),
                    duration_ms: step_start.elapsed().as_millis() as u64,
                    error: None,
                    status: StepStatus::Completed,
                    started_at: Some(started_at),
                    completed_at: Some(Utc::now()),
                };
                info!(step = %step.name, tokens = step_result.tokens_used, "workflow step completed");
                Ok(step_result)
            }
            Ok(Err(e)) => Err(e),
            Err(_) => Err(PunchError::Internal(format!(
                "step '{}' timed out after {}s",
                step.name, timeout_secs
            ))),
        }
    }
1436
1437 pub fn get_run(&self, run_id: &WorkflowRunId) -> Option<WorkflowRun> {
1439 self.runs.get(run_id).map(|r| r.clone())
1440 }
1441
1442 pub fn list_workflows(&self) -> Vec<Workflow> {
1444 self.workflows.iter().map(|w| w.value().clone()).collect()
1445 }
1446
1447 pub fn list_dag_workflows(&self) -> Vec<DagWorkflow> {
1449 self.dag_workflows
1450 .iter()
1451 .map(|w| w.value().clone())
1452 .collect()
1453 }
1454
1455 pub fn list_runs(&self) -> Vec<WorkflowRun> {
1457 self.runs.iter().map(|r| r.value().clone()).collect()
1458 }
1459
1460 pub fn list_runs_for_workflow(&self, workflow_id: &WorkflowId) -> Vec<WorkflowRun> {
1462 self.runs
1463 .iter()
1464 .filter(|r| r.value().workflow_id == *workflow_id)
1465 .map(|r| r.value().clone())
1466 .collect()
1467 }
1468
1469 pub fn get_workflow(&self, id: &WorkflowId) -> Option<Workflow> {
1471 self.workflows.get(id).map(|w| w.clone())
1472 }
1473
1474 pub fn get_dag_workflow(&self, id: &WorkflowId) -> Option<DagWorkflow> {
1476 self.dag_workflows.get(id).map(|w| w.clone())
1477 }
1478}
1479
1480impl Default for WorkflowEngine {
1481 fn default() -> Self {
1482 Self::new()
1483 }
1484}
1485
1486#[cfg(test)]
1491mod tests {
1492 use super::*;
1493 use std::sync::atomic::{AtomicUsize, Ordering};
1494 use std::time::Duration;
1495
    /// Test double for `StepExecutor` with scripted per-step outcomes.
    struct MockExecutor {
        // Canned response text keyed by step name.
        responses: HashMap<String, String>,
        // Steps forced to fail, mapped to their error message.
        failing_steps: HashMap<String, String>,
        // Per-step execution counters (DashMap: `execute` takes `&self` and
        // may be called concurrently across waves).
        execution_counts: DashMap<String, AtomicUsize>,
    }
1505
1506 impl MockExecutor {
1507 fn new() -> Self {
1508 Self {
1509 responses: HashMap::new(),
1510 failing_steps: HashMap::new(),
1511 execution_counts: DashMap::new(),
1512 }
1513 }
1514
1515 fn with_response(mut self, step: &str, response: &str) -> Self {
1516 self.responses
1517 .insert(step.to_string(), response.to_string());
1518 self
1519 }
1520
1521 fn with_failure(mut self, step: &str, error: &str) -> Self {
1522 self.failing_steps
1523 .insert(step.to_string(), error.to_string());
1524 self
1525 }
1526
1527 #[allow(dead_code)]
1528 fn execution_count(&self, step: &str) -> usize {
1529 self.execution_counts
1530 .get(step)
1531 .map(|c| c.load(Ordering::Relaxed))
1532 .unwrap_or(0)
1533 }
1534 }
1535
1536 #[async_trait::async_trait]
1537 impl StepExecutor for MockExecutor {
1538 async fn execute(
1539 &self,
1540 step: &DagWorkflowStep,
1541 input: &str,
1542 step_results: &HashMap<String, StepResult>,
1543 loop_state: Option<&LoopState>,
1544 ) -> Result<StepResult, String> {
1545 self.execution_counts
1547 .entry(step.name.clone())
1548 .or_insert_with(|| AtomicUsize::new(0))
1549 .fetch_add(1, Ordering::Relaxed);
1550
1551 if let Some(err) = self.failing_steps.get(&step.name) {
1553 return Err(err.clone());
1554 }
1555
1556 let prompt = expand_dag_variables(
1557 &step.prompt_template,
1558 input,
1559 &step.name,
1560 step_results,
1561 loop_state,
1562 );
1563
1564 let response = self.responses.get(&step.name).cloned().unwrap_or(prompt);
1565
1566 Ok(StepResult {
1567 step_name: step.name.clone(),
1568 response,
1569 tokens_used: 10,
1570 duration_ms: 5,
1571 error: None,
1572 status: StepStatus::Completed,
1573 started_at: Some(Utc::now()),
1574 completed_at: Some(Utc::now()),
1575 })
1576 }
1577 }
1578
    /// Executor that sleeps before succeeding, for wall-clock-based tests.
    struct TimedMockExecutor {
        // Artificial per-step latency in milliseconds.
        delay_ms: u64,
    }
1583
1584 #[async_trait::async_trait]
1585 impl StepExecutor for TimedMockExecutor {
1586 async fn execute(
1587 &self,
1588 step: &DagWorkflowStep,
1589 _input: &str,
1590 _step_results: &HashMap<String, StepResult>,
1591 _loop_state: Option<&LoopState>,
1592 ) -> Result<StepResult, String> {
1593 tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
1594 Ok(StepResult {
1595 step_name: step.name.clone(),
1596 response: format!("done-{}", step.name),
1597 tokens_used: 10,
1598 duration_ms: self.delay_ms,
1599 error: None,
1600 status: StepStatus::Completed,
1601 started_at: Some(Utc::now()),
1602 completed_at: Some(Utc::now()),
1603 })
1604 }
1605 }
1606
    /// Executor that fails the first `fail_count` attempts of every step,
    /// then succeeds — used by the retry-loop tests.
    struct FailNTimesMockExecutor {
        // How many initial attempts (per step) return Err.
        fail_count: usize,
        // Attempt counter per step name.
        attempts: DashMap<String, AtomicUsize>,
    }
1612
1613 impl FailNTimesMockExecutor {
1614 fn new(fail_count: usize) -> Self {
1615 Self {
1616 fail_count,
1617 attempts: DashMap::new(),
1618 }
1619 }
1620 }
1621
1622 #[async_trait::async_trait]
1623 impl StepExecutor for FailNTimesMockExecutor {
1624 async fn execute(
1625 &self,
1626 step: &DagWorkflowStep,
1627 _input: &str,
1628 _step_results: &HashMap<String, StepResult>,
1629 _loop_state: Option<&LoopState>,
1630 ) -> Result<StepResult, String> {
1631 let attempt = self
1632 .attempts
1633 .entry(step.name.clone())
1634 .or_insert_with(|| AtomicUsize::new(0))
1635 .fetch_add(1, Ordering::Relaxed);
1636
1637 if attempt < self.fail_count {
1638 return Err(format!("failure attempt {}", attempt + 1));
1639 }
1640
1641 Ok(StepResult {
1642 step_name: step.name.clone(),
1643 response: format!("success on attempt {}", attempt + 1),
1644 tokens_used: 10,
1645 duration_ms: 5,
1646 error: None,
1647 status: StepStatus::Completed,
1648 started_at: Some(Utc::now()),
1649 completed_at: Some(Utc::now()),
1650 })
1651 }
1652 }
1653
1654 fn dag_step(name: &str, deps: &[&str]) -> DagWorkflowStep {
1655 DagWorkflowStep {
1656 name: name.to_string(),
1657 fighter_name: "test".to_string(),
1658 prompt_template: "{{input}}".to_string(),
1659 timeout_secs: None,
1660 on_error: OnError::FailWorkflow,
1661 depends_on: deps.iter().map(|d| d.to_string()).collect(),
1662 condition: None,
1663 else_step: None,
1664 loop_config: None,
1665 }
1666 }
1667
    // Registering a workflow makes it listable and fetchable by id.
    #[test]
    fn register_and_list_workflows() {
        let engine = WorkflowEngine::new();

        let workflow = Workflow {
            id: WorkflowId::new(),
            name: "test-workflow".to_string(),
            steps: vec![
                WorkflowStep {
                    name: "step1".to_string(),
                    fighter_name: "analyzer".to_string(),
                    prompt_template: "Analyze: {{input}}".to_string(),
                    timeout_secs: None,
                    on_error: OnError::FailWorkflow,
                },
                WorkflowStep {
                    name: "step2".to_string(),
                    fighter_name: "summarizer".to_string(),
                    prompt_template: "Summarize the analysis: {{step1}}".to_string(),
                    timeout_secs: Some(60),
                    on_error: OnError::SkipStep,
                },
            ],
        };

        let id = engine.register_workflow(workflow);
        let workflows = engine.list_workflows();
        assert_eq!(workflows.len(), 1);
        assert_eq!(workflows[0].name, "test-workflow");
        assert_eq!(workflows[0].steps.len(), 2);

        let fetched = engine.get_workflow(&id).expect("workflow should exist");
        assert_eq!(fetched.name, "test-workflow");
    }

    // {{input}} and {{step_name}} expand to the input and current step name.
    #[test]
    fn variable_substitution_basic() {
        let result = expand_variables(
            "Analyze {{input}} for step {{step_name}}",
            "hello world",
            "analysis",
            &[],
        );
        assert_eq!(result, "Analyze hello world for step analysis");
    }

    // {{previous_output}} expands to the current input string.
    #[test]
    fn variable_substitution_previous_output() {
        let result = expand_variables(
            "Continue from: {{previous_output}}",
            "step 1 output",
            "step2",
            &[],
        );
        assert_eq!(result, "Continue from: step 1 output");
    }

    // Prior steps can be referenced positionally ({{step_N}}) or by name.
    #[test]
    fn variable_substitution_step_refs() {
        let step_results = vec![
            StepResult {
                step_name: "analyze".to_string(),
                response: "analysis result".to_string(),
                tokens_used: 100,
                duration_ms: 500,
                error: None,
                status: StepStatus::Completed,
                started_at: None,
                completed_at: None,
            },
            StepResult {
                step_name: "review".to_string(),
                response: "review result".to_string(),
                tokens_used: 80,
                duration_ms: 400,
                error: None,
                status: StepStatus::Completed,
                started_at: None,
                completed_at: None,
            },
        ];

        let result = expand_variables(
            "Step 1 said: {{step_1}}, Step 2 said: {{step_2}}",
            "current",
            "step3",
            &step_results,
        );
        assert_eq!(
            result,
            "Step 1 said: analysis result, Step 2 said: review result"
        );

        let result = expand_variables(
            "Analysis: {{analyze}}, Review: {{review}}",
            "current",
            "step3",
            &step_results,
        );
        assert_eq!(result, "Analysis: analysis result, Review: review result");
    }

    // Display impl of WorkflowRunStatus uses snake_case strings.
    #[test]
    fn workflow_run_status_display() {
        assert_eq!(WorkflowRunStatus::Pending.to_string(), "pending");
        assert_eq!(WorkflowRunStatus::Running.to_string(), "running");
        assert_eq!(WorkflowRunStatus::Completed.to_string(), "completed");
        assert_eq!(WorkflowRunStatus::Failed.to_string(), "failed");
        assert_eq!(
            WorkflowRunStatus::PartiallyCompleted.to_string(),
            "partially_completed"
        );
    }

    // Unknown run id yields None rather than an error.
    #[test]
    fn get_nonexistent_run_returns_none() {
        let engine = WorkflowEngine::new();
        let run_id = WorkflowRunId::new();
        assert!(engine.get_run(&run_id).is_none());
    }

    // Unknown workflow id yields None rather than an error.
    #[test]
    fn get_nonexistent_workflow_returns_none() {
        let engine = WorkflowEngine::new();
        let id = WorkflowId::new();
        assert!(engine.get_workflow(&id).is_none());
    }

    // Default engine starts with no workflows and no runs.
    #[test]
    fn workflow_engine_default() {
        let engine = WorkflowEngine::default();
        assert!(engine.list_workflows().is_empty());
        assert!(engine.list_runs().is_empty());
    }

    // Multiple registrations accumulate independently.
    #[test]
    fn register_multiple_workflows() {
        let engine = WorkflowEngine::new();

        for i in 0..5 {
            let workflow = Workflow {
                id: WorkflowId::new(),
                name: format!("workflow-{}", i),
                steps: vec![],
            };
            engine.register_workflow(workflow);
        }

        assert_eq!(engine.list_workflows().len(), 5);
    }

    // register_workflow echoes back the workflow's own id.
    #[test]
    fn register_workflow_returns_correct_id() {
        let engine = WorkflowEngine::new();
        let wf_id = WorkflowId::new();
        let workflow = Workflow {
            id: wf_id,
            name: "id-test".to_string(),
            steps: vec![],
        };
        let returned_id = engine.register_workflow(workflow);
        assert_eq!(returned_id, wf_id);
    }

    // Display for WorkflowId produces a non-empty (UUID) string.
    #[test]
    fn workflow_id_display() {
        let id = WorkflowId::new();
        let s = format!("{}", id);
        assert!(!s.is_empty());
    }

    // Display for WorkflowRunId produces a non-empty (UUID) string.
    #[test]
    fn workflow_run_id_display() {
        let id = WorkflowRunId::new();
        let s = format!("{}", id);
        assert!(!s.is_empty());
    }

    // Default WorkflowId is a freshly generated (non-nil) UUID.
    #[test]
    fn workflow_id_default() {
        let id = WorkflowId::default();
        assert!(!id.0.is_nil());
    }

    // Default WorkflowRunId is a freshly generated (non-nil) UUID.
    #[test]
    fn workflow_run_id_default() {
        let id = WorkflowRunId::default();
        assert!(!id.0.is_nil());
    }

    // Templates without placeholders pass through untouched.
    #[test]
    fn variable_substitution_no_variables() {
        let result = expand_variables("plain text with no vars", "input", "step", &[]);
        assert_eq!(result, "plain text with no vars");
    }

    // Every supported variable kind expands correctly in one template.
    #[test]
    fn variable_substitution_all_variables_at_once() {
        let step_results = vec![StepResult {
            step_name: "analysis".to_string(),
            response: "analyzed data".to_string(),
            tokens_used: 50,
            duration_ms: 100,
            error: None,
            status: StepStatus::Completed,
            started_at: None,
            completed_at: None,
        }];

        let result = expand_variables(
            "Input: {{input}}, Prev: {{previous_output}}, Step: {{step_name}}, S1: {{step_1}}, Named: {{analysis}}",
            "my input",
            "current_step",
            &step_results,
        );
        assert_eq!(
            result,
            "Input: my input, Prev: my input, Step: current_step, S1: analyzed data, Named: analyzed data"
        );
    }

    // Empty input expands {{input}} to the empty string.
    #[test]
    fn variable_substitution_empty_input() {
        let result = expand_variables("{{input}} is here", "", "step", &[]);
        assert_eq!(result, " is here");
    }

    // Repeated occurrences of the same variable are all replaced.
    #[test]
    fn variable_substitution_multiple_same_var() {
        let result = expand_variables("{{input}} and {{input}} again", "hello", "step", &[]);
        assert_eq!(result, "hello and hello again");
    }

    // OnError defaults to the strictest policy: fail the whole workflow.
    #[test]
    fn on_error_default_is_fail_workflow() {
        let on_error = OnError::default();
        assert!(matches!(on_error, OnError::FailWorkflow));
    }

    // Filtering runs by workflow id returns nothing for unknown ids.
    #[test]
    fn list_runs_for_workflow_filters_correctly() {
        let engine = WorkflowEngine::new();
        let wf_id_1 = WorkflowId::new();
        let wf_id_2 = WorkflowId::new();

        assert!(engine.list_runs_for_workflow(&wf_id_1).is_empty());
        assert!(engine.list_runs_for_workflow(&wf_id_2).is_empty());
    }

    // WorkflowStep survives a serde JSON round trip.
    #[test]
    fn workflow_step_serialization() {
        let step = WorkflowStep {
            name: "test".to_string(),
            fighter_name: "fighter".to_string(),
            prompt_template: "Do {{input}}".to_string(),
            timeout_secs: Some(30),
            on_error: OnError::SkipStep,
        };
        let json = serde_json::to_string(&step).expect("serialize");
        let deserialized: WorkflowStep = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deserialized.name, "test");
        assert_eq!(deserialized.timeout_secs, Some(30));
    }

    // Whole Workflow survives a serde JSON round trip.
    #[test]
    fn workflow_serialization_roundtrip() {
        let workflow = Workflow {
            id: WorkflowId::new(),
            name: "roundtrip".to_string(),
            steps: vec![WorkflowStep {
                name: "s1".to_string(),
                fighter_name: "f1".to_string(),
                prompt_template: "{{input}}".to_string(),
                timeout_secs: None,
                on_error: OnError::RetryOnce,
            }],
        };
        let json = serde_json::to_string(&workflow).expect("serialize");
        let deserialized: Workflow = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deserialized.name, "roundtrip");
        assert_eq!(deserialized.steps.len(), 1);
    }

    // StepResult can carry a failure message alongside Failed status.
    #[test]
    fn step_result_with_error() {
        let sr = StepResult {
            step_name: "failing".to_string(),
            response: String::new(),
            tokens_used: 0,
            duration_ms: 0,
            error: Some("timeout".to_string()),
            status: StepStatus::Failed,
            started_at: None,
            completed_at: None,
        };
        assert!(sr.error.is_some());
        assert_eq!(sr.error.expect("error"), "timeout");
    }

    // Out-of-range positional refs are left unexpanded in the output.
    #[test]
    fn variable_substitution_step_ref_by_number_out_of_range() {
        let step_results = vec![
            StepResult {
                step_name: "a".to_string(),
                response: "r1".to_string(),
                tokens_used: 0,
                duration_ms: 0,
                error: None,
                status: StepStatus::Completed,
                started_at: None,
                completed_at: None,
            },
            StepResult {
                step_name: "b".to_string(),
                response: "r2".to_string(),
                tokens_used: 0,
                duration_ms: 0,
                error: None,
                status: StepStatus::Completed,
                started_at: None,
                completed_at: None,
            },
        ];
        let result = expand_variables("{{step_5}}", "input", "step", &step_results);
        assert_eq!(result, "{{step_5}}");
    }
1996
    // A simple chain a -> b -> c executes every step in order.
    #[tokio::test]
    async fn dag_linear_execution() {
        let steps = vec![
            dag_step("a", &[]),
            dag_step("b", &["a"]),
            dag_step("c", &["b"]),
        ];
        let executor = MockExecutor::new()
            .with_response("a", "result_a")
            .with_response("b", "result_b")
            .with_response("c", "result_c");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert_eq!(result.step_results.len(), 3);
        assert_eq!(result.step_results["a"].response, "result_a");
        assert_eq!(result.step_results["b"].response, "result_b");
        assert_eq!(result.step_results["c"].response, "result_c");
    }

    // One root fanning out to three branches completes all four steps.
    #[tokio::test]
    async fn dag_fan_out_execution() {
        let steps = vec![
            dag_step("root", &[]),
            dag_step("branch1", &["root"]),
            dag_step("branch2", &["root"]),
            dag_step("branch3", &["root"]),
        ];
        let executor = MockExecutor::new()
            .with_response("root", "root_out")
            .with_response("branch1", "b1_out")
            .with_response("branch2", "b2_out")
            .with_response("branch3", "b3_out");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert_eq!(result.step_results.len(), 4);
        assert_eq!(result.step_results["branch1"].response, "b1_out");
        assert_eq!(result.step_results["branch2"].response, "b2_out");
        assert_eq!(result.step_results["branch3"].response, "b3_out");
    }

    // Three independent steps joining into one run as two waves:
    // {a, b, c} then {join}.
    #[tokio::test]
    async fn dag_fan_in_execution() {
        let steps = vec![
            dag_step("a", &[]),
            dag_step("b", &[]),
            dag_step("c", &[]),
            dag_step("join", &["a", "b", "c"]),
        ];
        let executor = MockExecutor::new()
            .with_response("a", "ra")
            .with_response("b", "rb")
            .with_response("c", "rc")
            .with_response("join", "joined");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert_eq!(result.step_results["join"].response, "joined");
        assert_eq!(result.execution_trace.len(), 2);
        let first_wave = &result.execution_trace[0].steps;
        assert!(first_wave.contains(&"a".to_string()));
        assert!(first_wave.contains(&"b".to_string()));
        assert!(first_wave.contains(&"c".to_string()));
    }

    // Diamond shape: left and right run together in the second wave.
    #[tokio::test]
    async fn dag_diamond_execution() {
        let steps = vec![
            dag_step("root", &[]),
            dag_step("left", &["root"]),
            dag_step("right", &["root"]),
            dag_step("join", &["left", "right"]),
        ];
        let executor = MockExecutor::new()
            .with_response("root", "root_out")
            .with_response("left", "left_out")
            .with_response("right", "right_out")
            .with_response("join", "joined");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert_eq!(result.step_results.len(), 4);
        let wave2 = &result.execution_trace[1].steps;
        assert!(wave2.contains(&"left".to_string()));
        assert!(wave2.contains(&"right".to_string()));
    }

    // Three 50ms steps with no deps finish in ~50ms, proving they run
    // concurrently rather than serially (which would take ~150ms).
    #[tokio::test]
    async fn dag_parallel_actually_concurrent() {
        let steps = vec![dag_step("a", &[]), dag_step("b", &[]), dag_step("c", &[])];
        let executor = TimedMockExecutor { delay_ms: 50 };

        let start = Instant::now();
        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        let elapsed = start.elapsed();

        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert_eq!(result.step_results.len(), 3);
        // 120ms bound leaves headroom above the 50ms delay for scheduling.
        assert!(
            elapsed.as_millis() < 120,
            "parallel execution took {}ms, expected ~50ms",
            elapsed.as_millis()
        );
    }

    // IfSuccess condition on a succeeded dependency lets the step run.
    #[tokio::test]
    async fn dag_condition_if_success() {
        let mut steps = vec![dag_step("a", &[]), dag_step("b", &["a"])];
        steps[1].condition = Some(Condition::IfSuccess {
            step: "a".to_string(),
        });
        let executor = MockExecutor::new()
            .with_response("a", "ok")
            .with_response("b", "b_ran");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.step_results["b"].status, StepStatus::Completed);
        assert_eq!(result.step_results["b"].response, "b_ran");
    }

    // IfFailure condition on a succeeded dependency skips the step.
    #[tokio::test]
    async fn dag_condition_skips_step() {
        let mut steps = vec![dag_step("a", &[]), dag_step("b", &["a"])];
        steps[1].condition = Some(Condition::IfFailure {
            step: "a".to_string(),
        });
        let executor = MockExecutor::new()
            .with_response("a", "ok")
            .with_response("b", "should_not_run");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.step_results["b"].status, StepStatus::Skipped);
    }

    // IfOutput condition runs the step when the substring matches.
    #[tokio::test]
    async fn dag_condition_if_output() {
        let mut steps = vec![dag_step("a", &[]), dag_step("b", &["a"])];
        steps[1].condition = Some(Condition::IfOutput {
            step: "a".to_string(),
            contains: "magic".to_string(),
        });
        let executor = MockExecutor::new()
            .with_response("a", "this has magic inside")
            .with_response("b", "b_ran");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.step_results["b"].status, StepStatus::Completed);
    }

    // IfOutput condition skips the step when the substring is absent.
    #[tokio::test]
    async fn dag_condition_if_output_no_match() {
        let mut steps = vec![dag_step("a", &[]), dag_step("b", &["a"])];
        steps[1].condition = Some(Condition::IfOutput {
            step: "a".to_string(),
            contains: "magic".to_string(),
        });
        let executor = MockExecutor::new()
            .with_response("a", "no special word here")
            .with_response("b", "should_not_run");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.step_results["b"].status, StepStatus::Skipped);
    }

    // ForEach loop iterates over the JSON array emitted by the source step,
    // exposing each element as {{loop.item}}.
    #[tokio::test]
    async fn dag_foreach_loop() {
        let mut steps = vec![dag_step("source", &[]), dag_step("process", &["source"])];
        steps[0].prompt_template = "{{input}}".to_string();
        steps[1].loop_config = Some(LoopConfig::ForEach {
            source_step: "source".to_string(),
            max_iterations: 100,
        });
        steps[1].prompt_template = "process item: {{loop.item}}".to_string();

        let executor =
            MockExecutor::new().with_response("source", r#"["apple", "banana", "cherry"]"#);

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Completed);
        let process_result = &result.step_results["process"];
        assert!(
            process_result.response.contains("process item: apple"),
            "response: {}",
            process_result.response
        );
    }

    // While loop with an always-true condition stops at max_iterations,
    // accumulating one response line per iteration.
    #[tokio::test]
    async fn dag_while_loop() {
        let mut steps = vec![dag_step("counter", &[])];
        steps[0].loop_config = Some(LoopConfig::While {
            condition: Condition::Expression("true".to_string()),
            max_iterations: 5,
        });

        let executor = MockExecutor::new().with_response("counter", "tick");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Completed);
        let counter_result = &result.step_results["counter"];
        let ticks: Vec<&str> = counter_result.response.split('\n').collect();
        assert_eq!(ticks.len(), 5);
    }

    // Retry loop: two failures then success fits inside 3 retries.
    #[tokio::test]
    async fn dag_retry_loop_succeeds_eventually() {
        let mut steps = vec![dag_step("flaky", &[])];
        steps[0].loop_config = Some(LoopConfig::Retry {
            max_retries: 3,
            backoff_ms: 1,
            backoff_multiplier: 1.0,
        });

        let executor = FailNTimesMockExecutor::new(2);

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert!(result.step_results["flaky"].error.is_none());
        assert!(
            result.step_results["flaky"]
                .response
                .contains("success on attempt 3")
        );
    }

    // Retry loop: persistent failure leaves the step with an error.
    #[tokio::test]
    async fn dag_retry_loop_exhausts_retries() {
        let mut steps = vec![dag_step("flaky", &[])];
        steps[0].loop_config = Some(LoopConfig::Retry {
            max_retries: 2,
            backoff_ms: 1,
            backoff_multiplier: 1.0,
        });

        let executor = FailNTimesMockExecutor::new(10);

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert!(result.step_results["flaky"].error.is_some());
    }

    // SkipStep on a failing middle step still lets the dependent run.
    #[tokio::test]
    async fn dag_step_failure_with_skip() {
        let mut steps = vec![
            dag_step("a", &[]),
            dag_step("b", &["a"]),
            dag_step("c", &["b"]),
        ];
        steps[1].on_error = OnError::SkipStep;

        let executor = MockExecutor::new()
            .with_response("a", "ok")
            .with_failure("b", "b failed")
            .with_response("c", "c_ran");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert!(result.step_results.contains_key("c"));
    }

    // Default FailWorkflow policy cancels downstream dependents.
    #[tokio::test]
    async fn dag_step_failure_cascades() {
        let steps = vec![
            dag_step("a", &[]),
            dag_step("b", &["a"]),
            dag_step("c", &["b"]),
        ];

        let executor = MockExecutor::new()
            .with_response("a", "ok")
            .with_failure("b", "b failed")
            .with_response("c", "should_not_run");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert!(result.step_results["b"].error.is_some());
        assert_eq!(result.step_results["c"].status, StepStatus::Cancelled);
    }

    // An empty step list fails validation rather than "completing".
    #[tokio::test]
    async fn dag_empty_workflow() {
        let executor = MockExecutor::new();
        let result = execute_dag("test", &[], "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Failed);
        assert!(!result.validation_errors.is_empty());
    }

    // A single dependency-free step completes normally.
    #[tokio::test]
    async fn dag_single_step() {
        let steps = vec![dag_step("only", &[])];
        let executor = MockExecutor::new().with_response("only", "done");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert_eq!(result.step_results.len(), 1);
        assert_eq!(result.step_results["only"].response, "done");
    }

    // When every step fails the run is Failed and dead letters are recorded.
    #[tokio::test]
    async fn dag_all_steps_fail() {
        let steps = vec![dag_step("a", &[]), dag_step("b", &[])];

        let executor = MockExecutor::new()
            .with_failure("a", "a failed")
            .with_failure("b", "b failed");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Failed);
        assert!(!result.dead_letters.is_empty());
    }

    // A mix of success and failure yields PartiallyCompleted.
    #[tokio::test]
    async fn dag_partial_completion() {
        let steps = vec![dag_step("good", &[]), dag_step("bad", &[])];

        let executor = MockExecutor::new()
            .with_response("good", "ok")
            .with_failure("bad", "nope");

        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::PartiallyCompleted);
    }

    // A dependency cycle is caught by validation before execution.
    #[tokio::test]
    async fn dag_validation_rejects_cycle() {
        let steps = vec![dag_step("a", &["b"]), dag_step("b", &["a"])];
        let executor = MockExecutor::new();
        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Failed);
        assert!(!result.validation_errors.is_empty());
    }

    // All-skipped (false conditions) still counts as a Completed run.
    #[tokio::test]
    async fn dag_all_steps_skipped() {
        let mut steps = vec![dag_step("a", &[]), dag_step("b", &[])];
        steps[0].condition = Some(Condition::Expression("false".to_string()));
        steps[1].condition = Some(Condition::Expression("false".to_string()));

        let executor = MockExecutor::new();
        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert_eq!(result.step_results["a"].status, StepStatus::Skipped);
        assert_eq!(result.step_results["b"].status, StepStatus::Skipped);
    }
2356
    // {{<step>.output}} expands to the referenced step's response text.
    #[test]
    fn dag_variables_step_output() {
        let mut results = HashMap::new();
        results.insert(
            "analyze".to_string(),
            StepResult {
                step_name: "analyze".to_string(),
                response: "found 3 bugs".to_string(),
                tokens_used: 100,
                duration_ms: 500,
                error: None,
                status: StepStatus::Completed,
                started_at: None,
                completed_at: None,
            },
        );

        let expanded = expand_dag_variables(
            "Result: {{analyze.output}}",
            "input",
            "next",
            &results,
            None,
        );
        assert_eq!(expanded, "Result: found 3 bugs");
    }

    // {{<step>.status}} expands to the status' display string.
    #[test]
    fn dag_variables_step_status() {
        let mut results = HashMap::new();
        results.insert(
            "build".to_string(),
            StepResult {
                step_name: "build".to_string(),
                response: "ok".to_string(),
                tokens_used: 50,
                duration_ms: 300,
                error: None,
                status: StepStatus::Completed,
                started_at: None,
                completed_at: None,
            },
        );

        let expanded = expand_dag_variables(
            "Build status: {{build.status}}",
            "input",
            "deploy",
            &results,
            None,
        );
        assert_eq!(expanded, "Build status: completed");
    }

    // {{<step>.duration_ms}} expands to the recorded duration.
    #[test]
    fn dag_variables_step_duration() {
        let mut results = HashMap::new();
        results.insert(
            "fetch".to_string(),
            StepResult {
                step_name: "fetch".to_string(),
                response: "data".to_string(),
                tokens_used: 10,
                duration_ms: 1234,
                error: None,
                status: StepStatus::Completed,
                started_at: None,
                completed_at: None,
            },
        );

        let expanded = expand_dag_variables(
            "Fetch took {{fetch.duration_ms}}ms",
            "input",
            "next",
            &results,
            None,
        );
        assert_eq!(expanded, "Fetch took 1234ms");
    }

    // {{loop.index}} / {{loop.item}} read from the supplied LoopState.
    #[test]
    fn dag_variables_loop_state() {
        let results = HashMap::new();
        let mut loop_state = LoopState::new();
        loop_state.index = 2;
        loop_state.item = Some("banana".to_string());

        let expanded = expand_dag_variables(
            "Item {{loop.index}}: {{loop.item}}",
            "input",
            "process",
            &results,
            Some(&loop_state),
        );
        assert_eq!(expanded, "Item 2: banana");
    }

    // Dotted paths after .output drill into the step's JSON response.
    #[test]
    fn dag_variables_json_path() {
        let mut results = HashMap::new();
        results.insert(
            "api".to_string(),
            StepResult {
                step_name: "api".to_string(),
                response: r#"{"user": {"name": "Alice", "age": 30}}"#.to_string(),
                tokens_used: 10,
                duration_ms: 100,
                error: None,
                status: StepStatus::Completed,
                started_at: None,
                completed_at: None,
            },
        );

        let expanded = expand_dag_variables(
            "Name: {{api.output.user.name}}",
            "input",
            "next",
            &results,
            None,
        );
        assert_eq!(expanded, "Name: Alice");
    }

    // The `| uppercase` transform is applied after expansion.
    #[test]
    fn dag_variables_transform_uppercase() {
        let mut results = HashMap::new();
        results.insert(
            "greet".to_string(),
            StepResult {
                step_name: "greet".to_string(),
                response: "hello world".to_string(),
                tokens_used: 10,
                duration_ms: 50,
                error: None,
                status: StepStatus::Completed,
                started_at: None,
                completed_at: None,
            },
        );

        let expanded = expand_dag_variables(
            "{{greet.output | uppercase}}",
            "input",
            "next",
            &results,
            None,
        );
        assert_eq!(expanded, "HELLO WORLD");
    }

    // The `| lowercase` transform is applied after expansion.
    #[test]
    fn dag_variables_transform_lowercase() {
        let mut results = HashMap::new();
        results.insert(
            "shout".to_string(),
            StepResult {
                step_name: "shout".to_string(),
                response: "LOUD NOISE".to_string(),
                tokens_used: 10,
                duration_ms: 50,
                error: None,
                status: StepStatus::Completed,
                started_at: None,
                completed_at: None,
            },
        );

        let expanded = expand_dag_variables(
            "{{shout.output | lowercase}}",
            "input",
            "next",
            &results,
            None,
        );
        assert_eq!(expanded, "loud noise");
    }

    // The `| json_extract "$.path"` transform pulls a value out of JSON.
    #[test]
    fn dag_variables_transform_json_extract() {
        let mut results = HashMap::new();
        results.insert(
            "data".to_string(),
            StepResult {
                step_name: "data".to_string(),
                response: r#"{"key": "value123"}"#.to_string(),
                tokens_used: 10,
                duration_ms: 50,
                error: None,
                status: StepStatus::Completed,
                started_at: None,
                completed_at: None,
            },
        );

        let expanded = expand_dag_variables(
            "{{data.output | json_extract \"$.key\"}}",
            "input",
            "next",
            &results,
            None,
        );
        assert_eq!(expanded, "value123");
    }

    // Single-key extraction from a flat JSON object.
    #[test]
    fn json_path_extract_simple() {
        let result = json_path_extract(r#"{"name": "Bob"}"#, "name");
        assert_eq!(result, "Bob");
    }

    // Dotted paths traverse nested objects; numbers stringify.
    #[test]
    fn json_path_extract_nested() {
        let result = json_path_extract(r#"{"a": {"b": {"c": 42}}}"#, "a.b.c");
        assert_eq!(result, "42");
    }

    // A leading "$." prefix is accepted and stripped.
    #[test]
    fn json_path_extract_dollar_prefix() {
        let result = json_path_extract(r#"{"key": "val"}"#, "$.key");
        assert_eq!(result, "val");
    }

    // A missing key yields the empty string.
    #[test]
    fn json_path_extract_missing_key() {
        let result = json_path_extract(r#"{"key": "val"}"#, "missing");
        assert_eq!(result, "");
    }

    // Unparseable input is returned unchanged.
    #[test]
    fn json_path_extract_invalid_json() {
        let result = json_path_extract("not json", "key");
        assert_eq!(result, "not json");
    }
2594
2595 #[test]
2598 fn step_status_display() {
2599 assert_eq!(StepStatus::Pending.to_string(), "pending");
2600 assert_eq!(StepStatus::Running.to_string(), "running");
2601 assert_eq!(StepStatus::Completed.to_string(), "completed");
2602 assert_eq!(StepStatus::Failed.to_string(), "failed");
2603 assert_eq!(StepStatus::Skipped.to_string(), "skipped");
2604 assert_eq!(StepStatus::Cancelled.to_string(), "cancelled");
2605 }
2606
2607 #[test]
2610 fn on_error_fallback_serialization() {
2611 let on_error = OnError::Fallback {
2612 step: "backup".to_string(),
2613 };
2614 let json = serde_json::to_string(&on_error).expect("serialize");
2615 let deser: OnError = serde_json::from_str(&json).expect("deserialize");
2616 assert!(matches!(deser, OnError::Fallback { step } if step == "backup"));
2617 }
2618
2619 #[test]
2620 fn on_error_catch_and_continue_serialization() {
2621 let on_error = OnError::CatchAndContinue {
2622 error_handler: "handler".to_string(),
2623 };
2624 let json = serde_json::to_string(&on_error).expect("serialize");
2625 let deser: OnError = serde_json::from_str(&json).expect("deserialize");
2626 assert!(
2627 matches!(deser, OnError::CatchAndContinue { error_handler } if error_handler == "handler")
2628 );
2629 }
2630
2631 #[test]
2632 fn on_error_circuit_breaker_serialization() {
2633 let on_error = OnError::CircuitBreaker {
2634 max_failures: 5,
2635 cooldown_secs: 60,
2636 };
2637 let json = serde_json::to_string(&on_error).expect("serialize");
2638 let deser: OnError = serde_json::from_str(&json).expect("deserialize");
2639 assert!(matches!(
2640 deser,
2641 OnError::CircuitBreaker {
2642 max_failures: 5,
2643 cooldown_secs: 60
2644 }
2645 ));
2646 }
2647
2648 #[test]
2651 fn circuit_breaker_default_closed() {
2652 let cb = CircuitBreakerState::default();
2653 assert!(!cb.is_open(3, 60));
2654 }
2655
2656 #[test]
2657 fn circuit_breaker_opens_after_max_failures() {
2658 let mut cb = CircuitBreakerState::default();
2659 cb.record_failure();
2660 cb.record_failure();
2661 cb.record_failure();
2662 assert!(cb.is_open(3, 60));
2663 }
2664
2665 #[test]
2666 fn circuit_breaker_resets_on_success() {
2667 let mut cb = CircuitBreakerState::default();
2668 cb.record_failure();
2669 cb.record_failure();
2670 cb.record_success();
2671 assert!(!cb.is_open(3, 60));
2672 assert_eq!(cb.consecutive_failures, 0);
2673 }
2674
2675 #[test]
2678 fn register_dag_workflow_valid() {
2679 let engine = WorkflowEngine::new();
2680 let wf = DagWorkflow {
2681 id: WorkflowId::new(),
2682 name: "test-dag".to_string(),
2683 steps: vec![dag_step("a", &[]), dag_step("b", &["a"])],
2684 };
2685 let result = engine.register_dag_workflow(wf);
2686 assert!(result.is_ok());
2687 }
2688
2689 #[test]
2690 fn register_dag_workflow_with_cycle_fails() {
2691 let engine = WorkflowEngine::new();
2692 let wf = DagWorkflow {
2693 id: WorkflowId::new(),
2694 name: "bad-dag".to_string(),
2695 steps: vec![dag_step("a", &["b"]), dag_step("b", &["a"])],
2696 };
2697 let result = engine.register_dag_workflow(wf);
2698 assert!(result.is_err());
2699 }
2700
2701 #[test]
2702 fn list_dag_workflows() {
2703 let engine = WorkflowEngine::new();
2704 let wf = DagWorkflow {
2705 id: WorkflowId::new(),
2706 name: "dag1".to_string(),
2707 steps: vec![dag_step("a", &[])],
2708 };
2709 engine.register_dag_workflow(wf).expect("should register");
2710 assert_eq!(engine.list_dag_workflows().len(), 1);
2711 }
2712
2713 #[test]
2714 fn get_dag_workflow() {
2715 let engine = WorkflowEngine::new();
2716 let id = WorkflowId::new();
2717 let wf = DagWorkflow {
2718 id,
2719 name: "dag1".to_string(),
2720 steps: vec![dag_step("a", &[])],
2721 };
2722 engine.register_dag_workflow(wf).expect("should register");
2723 let fetched = engine.get_dag_workflow(&id).expect("should exist");
2724 assert_eq!(fetched.name, "dag1");
2725 }
2726
2727 #[test]
2728 fn get_nonexistent_dag_workflow() {
2729 let engine = WorkflowEngine::new();
2730 assert!(engine.get_dag_workflow(&WorkflowId::new()).is_none());
2731 }
2732
2733 #[tokio::test]
2736 async fn dag_dead_letters_populated_on_failure() {
2737 let steps = vec![dag_step("a", &[])];
2738 let executor = MockExecutor::new().with_failure("a", "catastrophic failure");
2739
2740 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2741 assert!(!result.dead_letters.is_empty());
2742 assert_eq!(result.dead_letters[0].step_name, "a");
2743 assert_eq!(result.dead_letters[0].error, "catastrophic failure");
2744 }
2745
2746 #[tokio::test]
2749 async fn dag_execution_trace_records_waves() {
2750 let steps = vec![
2751 dag_step("a", &[]),
2752 dag_step("b", &["a"]),
2753 dag_step("c", &["b"]),
2754 ];
2755 let executor = MockExecutor::new()
2756 .with_response("a", "ok")
2757 .with_response("b", "ok")
2758 .with_response("c", "ok");
2759
2760 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2761 assert_eq!(result.execution_trace.len(), 3);
2763 assert_eq!(result.execution_trace[0].steps, vec!["a"]);
2764 assert_eq!(result.execution_trace[1].steps, vec!["b"]);
2765 assert_eq!(result.execution_trace[2].steps, vec!["c"]);
2766 }
2767
2768 #[test]
2771 fn dag_step_fallback_step_extraction() {
2772 let mut step = dag_step("test", &[]);
2773 assert!(step.fallback_step().is_none());
2774
2775 step.on_error = OnError::Fallback {
2776 step: "backup".to_string(),
2777 };
2778 assert_eq!(step.fallback_step(), Some("backup".to_string()));
2779
2780 step.on_error = OnError::CatchAndContinue {
2781 error_handler: "handler".to_string(),
2782 };
2783 assert_eq!(step.fallback_step(), Some("handler".to_string()));
2784 }
2785
2786 #[test]
2789 fn dag_workflow_serialization_roundtrip() {
2790 let wf = DagWorkflow {
2791 id: WorkflowId::new(),
2792 name: "test-dag".to_string(),
2793 steps: vec![dag_step("a", &[]), dag_step("b", &["a"])],
2794 };
2795 let json = serde_json::to_string(&wf).expect("serialize");
2796 let deser: DagWorkflow = serde_json::from_str(&json).expect("deserialize");
2797 assert_eq!(deser.name, "test-dag");
2798 assert_eq!(deser.steps.len(), 2);
2799 }
2800
2801 #[test]
2802 fn dag_workflow_step_with_condition_serialization() {
2803 let mut step = dag_step("test", &["dep1"]);
2804 step.condition = Some(Condition::IfSuccess {
2805 step: "dep1".to_string(),
2806 });
2807 step.else_step = Some("fallback".to_string());
2808 let json = serde_json::to_string(&step).expect("serialize");
2809 let deser: DagWorkflowStep = serde_json::from_str(&json).expect("deserialize");
2810 assert!(deser.condition.is_some());
2811 assert_eq!(deser.else_step, Some("fallback".to_string()));
2812 }
2813
2814 #[test]
2815 fn dead_letter_entry_serialization() {
2816 let entry = DeadLetterEntry {
2817 step_name: "failed_step".to_string(),
2818 error: "boom".to_string(),
2819 input: "test input".to_string(),
2820 failed_at: Utc::now(),
2821 };
2822 let json = serde_json::to_string(&entry).expect("serialize");
2823 let deser: DeadLetterEntry = serde_json::from_str(&json).expect("deserialize");
2824 assert_eq!(deser.step_name, "failed_step");
2825 assert_eq!(deser.error, "boom");
2826 }
2827
2828 #[test]
2829 fn execution_trace_entry_serialization() {
2830 let entry = ExecutionTraceEntry {
2831 steps: vec!["a".to_string(), "b".to_string()],
2832 started_at: Utc::now(),
2833 completed_at: Some(Utc::now()),
2834 };
2835 let json = serde_json::to_string(&entry).expect("serialize");
2836 let deser: ExecutionTraceEntry = serde_json::from_str(&json).expect("deserialize");
2837 assert_eq!(deser.steps.len(), 2);
2838 }
2839
2840 #[test]
2841 fn workflow_run_with_new_fields_serialization() {
2842 let run = WorkflowRun {
2843 id: WorkflowRunId::new(),
2844 workflow_id: WorkflowId::new(),
2845 status: WorkflowRunStatus::PartiallyCompleted,
2846 step_results: Vec::new(),
2847 started_at: Utc::now(),
2848 completed_at: None,
2849 dead_letters: vec![DeadLetterEntry {
2850 step_name: "x".to_string(),
2851 error: "err".to_string(),
2852 input: "in".to_string(),
2853 failed_at: Utc::now(),
2854 }],
2855 execution_trace: Vec::new(),
2856 };
2857 let json = serde_json::to_string(&run).expect("serialize");
2858 let deser: WorkflowRun = serde_json::from_str(&json).expect("deserialize");
2859 assert_eq!(deser.status, WorkflowRunStatus::PartiallyCompleted);
2860 assert_eq!(deser.dead_letters.len(), 1);
2861 }
2862
2863 #[test]
2864 fn step_result_with_new_fields() {
2865 let sr = StepResult {
2866 step_name: "test".to_string(),
2867 response: "ok".to_string(),
2868 tokens_used: 10,
2869 duration_ms: 100,
2870 error: None,
2871 status: StepStatus::Completed,
2872 started_at: Some(Utc::now()),
2873 completed_at: Some(Utc::now()),
2874 };
2875 let json = serde_json::to_string(&sr).expect("serialize");
2876 let deser: StepResult = serde_json::from_str(&json).expect("deserialize");
2877 assert_eq!(deser.status, StepStatus::Completed);
2878 assert!(deser.started_at.is_some());
2879 }
2880
2881 #[tokio::test]
2884 async fn dag_fallback_on_error() {
2885 let mut steps = vec![dag_step("main", &[]), dag_step("backup", &[])];
2886 steps[0].on_error = OnError::Fallback {
2887 step: "backup".to_string(),
2888 };
2889
2890 let executor = MockExecutor::new()
2891 .with_failure("main", "main failed")
2892 .with_response("backup", "backup result");
2893
2894 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2895 let main_result = &result.step_results["main"];
2898 assert_eq!(main_result.response, "backup result");
2899 }
2900
2901 #[tokio::test]
2902 async fn dag_catch_and_continue() {
2903 let mut steps = vec![
2904 dag_step("risky", &[]),
2905 dag_step("handler", &[]),
2906 dag_step("next", &["risky"]),
2907 ];
2908 steps[0].on_error = OnError::CatchAndContinue {
2909 error_handler: "handler".to_string(),
2910 };
2911
2912 let executor = MockExecutor::new()
2913 .with_failure("risky", "oops")
2914 .with_response("handler", "handled")
2915 .with_response("next", "continued");
2916
2917 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
2918 assert!(result.step_results.contains_key("next"));
2920 }
2921
2922 struct ConcurrencyProofExecutor {
2926 delay_ms: u64,
2927 timings: Arc<tokio::sync::Mutex<Vec<(String, Instant, Instant)>>>,
2929 }
2930
2931 impl ConcurrencyProofExecutor {
2932 fn new(delay_ms: u64) -> Self {
2933 Self {
2934 delay_ms,
2935 timings: Arc::new(tokio::sync::Mutex::new(Vec::new())),
2936 }
2937 }
2938 }
2939
    #[async_trait::async_trait]
    impl StepExecutor for ConcurrencyProofExecutor {
        // Sleeps for `delay_ms`, logs the (start, end) window around the sleep
        // into the shared timing vec, and reports the step as completed.
        // Overlapping windows across steps are the evidence of parallelism.
        async fn execute(
            &self,
            step: &DagWorkflowStep,
            _input: &str,
            _step_results: &HashMap<String, StepResult>,
            _loop_state: Option<&LoopState>,
        ) -> Result<StepResult, String> {
            let start = Instant::now();
            // Async sleep so concurrent executions can interleave on one runtime.
            tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
            let end = Instant::now();

            self.timings
                .lock()
                .await
                .push((step.name.clone(), start, end));

            Ok(StepResult {
                step_name: step.name.clone(),
                response: format!("done-{}", step.name),
                tokens_used: 10,
                duration_ms: self.delay_ms,
                error: None,
                status: StepStatus::Completed,
                started_at: Some(Utc::now()),
                completed_at: Some(Utc::now()),
            })
        }
    }
2970
    #[tokio::test]
    async fn dag_three_independent_steps_parallel_timing() {
        // Three roots with no dependencies should be scheduled concurrently.
        let steps = vec![dag_step("x", &[]), dag_step("y", &[]), dag_step("z", &[])];
        let executor = ConcurrencyProofExecutor::new(50);
        // Keep a handle to the timing log; the executor is moved into execute_dag.
        let timings = Arc::clone(&executor.timings);

        let start = Instant::now();
        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        let elapsed = start.elapsed();

        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert_eq!(result.step_results.len(), 3);
        // Wall-clock bound: ~50ms if parallel, ~150ms if serial. NOTE(review):
        // wall-clock thresholds like this can be flaky on loaded CI hosts.
        assert!(
            elapsed.as_millis() < 100,
            "3 independent 50ms steps took {}ms, should be ~50ms for parallel execution",
            elapsed.as_millis()
        );

        // Cross-check with recorded per-step start times: all three must have
        // begun within a small window of the earliest start.
        let recorded = timings.lock().await;
        assert_eq!(recorded.len(), 3);
        let starts: Vec<_> = recorded.iter().map(|(_, s, _)| *s).collect();
        let earliest = starts.iter().min().copied().expect("should have starts");
        for s in &starts {
            let diff = s.duration_since(earliest).as_millis();
            assert!(
                diff < 20,
                "start time spread {}ms too large for parallel execution",
                diff
            );
        }
    }
3006
    #[tokio::test]
    async fn dag_fan_out_fan_in_timing() {
        // a -> {b, c, d} -> e: three waves, with the middle wave parallel.
        let steps = vec![
            dag_step("a", &[]),
            dag_step("b", &["a"]),
            dag_step("c", &["a"]),
            dag_step("d", &["a"]),
            dag_step("e", &["b", "c", "d"]),
        ];
        let executor = TimedMockExecutor { delay_ms: 30 };

        let start = Instant::now();
        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        let elapsed = start.elapsed();

        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert_eq!(result.step_results.len(), 5);

        // ~90ms (3 waves x 30ms) when the middle wave runs in parallel;
        // ~150ms if all five steps ran serially. NOTE(review): wall-clock
        // thresholds can be flaky on loaded CI hosts.
        assert!(
            elapsed.as_millis() < 130,
            "fan-out/fan-in took {}ms, expected ~90ms",
            elapsed.as_millis()
        );

        // Wave structure must be [a], [b, c, d], [e].
        assert_eq!(result.execution_trace.len(), 3);
        let wave2 = &result.execution_trace[1].steps;
        assert_eq!(wave2.len(), 3);
    }
3040
3041 #[tokio::test]
3043 async fn dag_fan_in_parallel_roots() {
3044 let steps = vec![
3045 dag_step("r1", &[]),
3046 dag_step("r2", &[]),
3047 dag_step("r3", &[]),
3048 dag_step("join", &["r1", "r2", "r3"]),
3049 ];
3050 let executor = MockExecutor::new()
3051 .with_response("r1", "out1")
3052 .with_response("r2", "out2")
3053 .with_response("r3", "out3")
3054 .with_response("join", "merged");
3055
3056 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3057 assert_eq!(result.status, WorkflowRunStatus::Completed);
3058 assert_eq!(result.step_results["join"].response, "merged");
3059 assert_eq!(result.execution_trace.len(), 2);
3061 assert_eq!(result.execution_trace[0].steps.len(), 3);
3062 }
3063
    #[tokio::test]
    async fn dag_diamond_dependency_parallel() {
        // Diamond: a -> {b, c} -> d. The middle pair must share a wave.
        let steps = vec![
            dag_step("a", &[]),
            dag_step("b", &["a"]),
            dag_step("c", &["a"]),
            dag_step("d", &["b", "c"]),
        ];
        let executor = TimedMockExecutor { delay_ms: 30 };

        let start = Instant::now();
        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        let elapsed = start.elapsed();

        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert_eq!(result.execution_trace.len(), 3);
        let wave2 = &result.execution_trace[1].steps;
        assert!(wave2.contains(&"b".to_string()));
        assert!(wave2.contains(&"c".to_string()));
        // ~90ms (3 waves x 30ms) when b and c run in parallel; ~120ms serial.
        // NOTE(review): wall-clock thresholds can be flaky on loaded CI hosts.
        assert!(
            elapsed.as_millis() < 120,
            "diamond took {}ms, expected ~90ms",
            elapsed.as_millis()
        );
    }
3093
3094 #[tokio::test]
3096 async fn dag_conditional_skip_in_dag() {
3097 let mut steps = vec![
3098 dag_step("check", &[]),
3099 dag_step("true_branch", &["check"]),
3100 dag_step("false_branch", &["check"]),
3101 ];
3102 steps[1].condition = Some(Condition::IfSuccess {
3104 step: "check".to_string(),
3105 });
3106 steps[2].condition = Some(Condition::IfFailure {
3108 step: "check".to_string(),
3109 });
3110
3111 let executor = MockExecutor::new()
3112 .with_response("check", "all good")
3113 .with_response("true_branch", "ran")
3114 .with_response("false_branch", "should_not_run");
3115
3116 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3117 assert_eq!(
3118 result.step_results["true_branch"].status,
3119 StepStatus::Completed
3120 );
3121 assert_eq!(
3122 result.step_results["false_branch"].status,
3123 StepStatus::Skipped
3124 );
3125 }
3126
3127 #[tokio::test]
3129 async fn dag_loop_foreach_within_dag() {
3130 let mut steps = vec![
3131 dag_step("data", &[]),
3132 dag_step("process", &["data"]),
3133 dag_step("summary", &["process"]),
3134 ];
3135 steps[1].loop_config = Some(LoopConfig::ForEach {
3136 source_step: "data".to_string(),
3137 max_iterations: 10,
3138 });
3139 steps[1].prompt_template = "process: {{loop.item}}".to_string();
3140
3141 let executor = MockExecutor::new()
3142 .with_response("data", r#"["red", "green", "blue"]"#)
3143 .with_response("summary", "done");
3144
3145 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3146 assert_eq!(result.status, WorkflowRunStatus::Completed);
3147 let process_out = &result.step_results["process"].response;
3148 assert!(process_out.contains("process: red"));
3150 assert!(process_out.contains("process: green"));
3151 assert!(process_out.contains("process: blue"));
3152 }
3153
3154 #[tokio::test]
3156 async fn dag_partial_failure_parallel_branches() {
3157 let steps = vec![
3158 dag_step("root", &[]),
3159 dag_step("ok_branch", &["root"]),
3160 dag_step("fail_branch", &["root"]),
3161 dag_step("ok_branch2", &["root"]),
3162 ];
3163
3164 let executor = MockExecutor::new()
3165 .with_response("root", "start")
3166 .with_response("ok_branch", "success1")
3167 .with_failure("fail_branch", "branch failed")
3168 .with_response("ok_branch2", "success2");
3169
3170 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3171 assert_eq!(result.status, WorkflowRunStatus::PartiallyCompleted);
3172 assert_eq!(
3173 result.step_results["ok_branch"].status,
3174 StepStatus::Completed
3175 );
3176 assert_eq!(
3177 result.step_results["ok_branch2"].status,
3178 StepStatus::Completed
3179 );
3180 assert!(result.step_results["fail_branch"].error.is_some());
3181 }
3182
3183 #[tokio::test]
3185 async fn dag_fallback_step_runs_on_failure() {
3186 let mut steps = vec![
3187 dag_step("primary", &[]),
3188 dag_step("fallback_handler", &[]),
3189 dag_step("downstream", &["primary"]),
3190 ];
3191 steps[0].on_error = OnError::Fallback {
3192 step: "fallback_handler".to_string(),
3193 };
3194
3195 let executor = MockExecutor::new()
3196 .with_failure("primary", "primary broke")
3197 .with_response("fallback_handler", "recovered via fallback")
3198 .with_response("downstream", "downstream ran");
3199
3200 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3201 let primary_result = &result.step_results["primary"];
3203 assert_eq!(primary_result.response, "recovered via fallback");
3204 assert!(result.step_results.contains_key("downstream"));
3206 }
3207
3208 #[tokio::test]
3210 async fn dag_circuit_breaker_triggers() {
3211 let mut steps = vec![dag_step("cb_step", &[])];
3212 steps[0].on_error = OnError::CircuitBreaker {
3213 max_failures: 2,
3214 cooldown_secs: 300,
3215 };
3216
3217 let executor1 = MockExecutor::new().with_failure("cb_step", "fail1");
3219 let result1 = execute_dag("test", &steps, "input", Arc::new(executor1)).await;
3220 assert!(result1.step_results["cb_step"].error.is_some());
3221
3222 let mut cb = CircuitBreakerState::default();
3226 cb.record_failure();
3227 assert!(!cb.is_open(2, 300), "should not be open after 1 failure");
3228 cb.record_failure();
3229 assert!(cb.is_open(2, 300), "should be open after 2 failures");
3230 assert!(cb.is_open(2, 300));
3232 }
3233
3234 #[tokio::test]
3236 async fn dag_variable_substitution_across_parallel_branches() {
3237 let mut steps = vec![
3238 dag_step("source_a", &[]),
3239 dag_step("source_b", &[]),
3240 dag_step("consumer", &["source_a", "source_b"]),
3241 ];
3242 steps[2].prompt_template = "A={{source_a.output}}, B={{source_b.output}}".to_string();
3243
3244 let executor = MockExecutor::new()
3245 .with_response("source_a", "value_from_a")
3246 .with_response("source_b", "value_from_b");
3247 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3250 assert_eq!(result.status, WorkflowRunStatus::Completed);
3251 let consumer_out = &result.step_results["consumer"].response;
3252 assert!(
3253 consumer_out.contains("value_from_a"),
3254 "consumer should see source_a output, got: {consumer_out}"
3255 );
3256 assert!(
3257 consumer_out.contains("value_from_b"),
3258 "consumer should see source_b output, got: {consumer_out}"
3259 );
3260 }
3261
    #[tokio::test]
    async fn dag_wide_parallel_fan_out_timing() {
        // Ten independent roots must all land in a single parallel wave.
        let steps: Vec<DagWorkflowStep> =
            (0..10).map(|i| dag_step(&format!("s{i}"), &[])).collect();
        let executor = TimedMockExecutor { delay_ms: 30 };

        let start = Instant::now();
        let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
        let elapsed = start.elapsed();

        assert_eq!(result.status, WorkflowRunStatus::Completed);
        assert_eq!(result.step_results.len(), 10);
        // ~30ms if all ten run concurrently; ~300ms if serial. NOTE(review):
        // wall-clock thresholds can be flaky on loaded CI hosts.
        assert!(
            elapsed.as_millis() < 80,
            "10 parallel 30ms steps took {}ms, expected ~30ms",
            elapsed.as_millis()
        );
        assert_eq!(result.execution_trace.len(), 1);
        assert_eq!(result.execution_trace[0].steps.len(), 10);
    }
3285
3286 #[tokio::test]
3288 async fn dag_while_loop_with_condition() {
3289 let mut steps = vec![dag_step("looper", &[])];
3290 steps[0].loop_config = Some(LoopConfig::While {
3291 condition: Condition::Expression("true".to_string()),
3292 max_iterations: 3,
3293 });
3294
3295 let executor = MockExecutor::new().with_response("looper", "iteration");
3296
3297 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3298 assert_eq!(result.status, WorkflowRunStatus::Completed);
3299 let output = &result.step_results["looper"].response;
3300 let lines: Vec<&str> = output.split('\n').collect();
3302 assert_eq!(lines.len(), 3);
3303 }
3304
3305 #[tokio::test]
3307 async fn dag_retry_succeeds_on_retry() {
3308 let mut steps = vec![dag_step("retry_step", &[])];
3309 steps[0].loop_config = Some(LoopConfig::Retry {
3310 max_retries: 2,
3311 backoff_ms: 1,
3312 backoff_multiplier: 1.0,
3313 });
3314
3315 let executor = FailNTimesMockExecutor::new(1);
3316
3317 let result = execute_dag("test", &steps, "input", Arc::new(executor)).await;
3318 assert_eq!(result.status, WorkflowRunStatus::Completed);
3319 assert!(result.step_results["retry_step"].error.is_none());
3320 assert!(
3321 result.step_results["retry_step"]
3322 .response
3323 .contains("success on attempt 2")
3324 );
3325 }
3326}