Skip to main content

punch_kernel/
workflow_validation.rs

1//! Pre-execution validation for workflow DAGs.
2//!
3//! Performs cycle detection, unreachable step detection, missing dependency
4//! detection, variable reference validation, and depth/breadth limits.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7
8use serde::{Deserialize, Serialize};
9
10use crate::workflow::DagWorkflowStep;
11
/// Maximum allowed depth of a DAG (longest path from any root to any leaf).
///
/// Depth is counted in steps, so a workflow made of a single step has depth 1.
/// `validate_workflow` reports an error when the depth is strictly greater
/// than this value.
pub const MAX_DAG_DEPTH: usize = 100;

/// Maximum allowed number of steps in a workflow.
///
/// Despite the "breadth" name, this bounds the total step count rather than
/// the width of any single level. `validate_workflow` reports an error when a
/// workflow has more steps than this value.
pub const MAX_DAG_BREADTH: usize = 1000;
17
18/// A validation error found in a workflow definition.
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
20#[serde(rename_all = "snake_case")]
21pub enum ValidationError {
22    /// The DAG contains a cycle involving these steps.
23    CycleDetected { steps: Vec<String> },
24    /// A step declares a dependency that doesn't exist.
25    MissingDependency { step: String, missing_dep: String },
26    /// A step is unreachable (no path from any root).
27    UnreachableStep { step: String },
28    /// A variable reference points to a non-existent step.
29    InvalidVariableRef { step: String, variable: String },
30    /// Duplicate step names found.
31    DuplicateStepName { name: String },
32    /// The workflow has no steps.
33    EmptyWorkflow,
34    /// The DAG exceeds the maximum depth limit.
35    ExceedsMaxDepth { depth: usize, limit: usize },
36    /// The DAG exceeds the maximum breadth limit.
37    ExceedsMaxBreadth { breadth: usize, limit: usize },
38    /// An else_step references a non-existent step.
39    InvalidElseStep { step: String, else_step: String },
40    /// A fallback step references a non-existent step.
41    InvalidFallbackStep { step: String, fallback: String },
42}
43
44impl std::fmt::Display for ValidationError {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        match self {
47            Self::CycleDetected { steps } => {
48                write!(f, "cycle detected involving steps: {}", steps.join(" -> "))
49            }
50            Self::MissingDependency { step, missing_dep } => {
51                write!(
52                    f,
53                    "step '{step}' depends on non-existent step '{missing_dep}'"
54                )
55            }
56            Self::UnreachableStep { step } => {
57                write!(f, "step '{step}' is unreachable from any root step")
58            }
59            Self::InvalidVariableRef { step, variable } => {
60                write!(f, "step '{step}' references unknown variable '{variable}'")
61            }
62            Self::DuplicateStepName { name } => {
63                write!(f, "duplicate step name: '{name}'")
64            }
65            Self::EmptyWorkflow => write!(f, "workflow has no steps"),
66            Self::ExceedsMaxDepth { depth, limit } => {
67                write!(f, "DAG depth {depth} exceeds limit {limit}")
68            }
69            Self::ExceedsMaxBreadth { breadth, limit } => {
70                write!(f, "workflow has {breadth} steps, exceeding limit {limit}")
71            }
72            Self::InvalidElseStep { step, else_step } => {
73                write!(
74                    f,
75                    "step '{step}' has else_step '{else_step}' which doesn't exist"
76                )
77            }
78            Self::InvalidFallbackStep { step, fallback } => {
79                write!(
80                    f,
81                    "step '{step}' has fallback '{fallback}' which doesn't exist"
82                )
83            }
84        }
85    }
86}
87
88/// Validate a workflow DAG, returning a list of all errors found.
89///
90/// Returns an empty vec if the workflow is valid.
91pub fn validate_workflow(steps: &[DagWorkflowStep]) -> Vec<ValidationError> {
92    let mut errors = Vec::new();
93
94    // Empty workflow check
95    if steps.is_empty() {
96        errors.push(ValidationError::EmptyWorkflow);
97        return errors;
98    }
99
100    // Breadth check
101    if steps.len() > MAX_DAG_BREADTH {
102        errors.push(ValidationError::ExceedsMaxBreadth {
103            breadth: steps.len(),
104            limit: MAX_DAG_BREADTH,
105        });
106    }
107
108    // Build name set and check for duplicates
109    let mut name_set: HashSet<&str> = HashSet::new();
110    for step in steps {
111        if !name_set.insert(&step.name) {
112            errors.push(ValidationError::DuplicateStepName {
113                name: step.name.clone(),
114            });
115        }
116    }
117
118    // Missing dependency check
119    for step in steps {
120        for dep in &step.depends_on {
121            if !name_set.contains(dep.as_str()) {
122                errors.push(ValidationError::MissingDependency {
123                    step: step.name.clone(),
124                    missing_dep: dep.clone(),
125                });
126            }
127        }
128    }
129
130    // Else step check
131    for step in steps {
132        if let Some(ref else_step) = step.else_step
133            && !name_set.contains(else_step.as_str())
134        {
135            errors.push(ValidationError::InvalidElseStep {
136                step: step.name.clone(),
137                else_step: else_step.clone(),
138            });
139        }
140    }
141
142    // Fallback step check
143    for step in steps {
144        if let Some(ref fallback) = step.fallback_step()
145            && !name_set.contains(fallback.as_str())
146        {
147            errors.push(ValidationError::InvalidFallbackStep {
148                step: step.name.clone(),
149                fallback: fallback.clone(),
150            });
151        }
152    }
153
154    // Cycle detection via topological sort (Kahn's algorithm)
155    let cycle_result = topological_sort(steps);
156    match cycle_result {
157        Ok(sorted) => {
158            // Check depth
159            let depth = compute_dag_depth(steps, &sorted);
160            if depth > MAX_DAG_DEPTH {
161                errors.push(ValidationError::ExceedsMaxDepth {
162                    depth,
163                    limit: MAX_DAG_DEPTH,
164                });
165            }
166
167            // Unreachable step detection
168            let reachable = find_reachable_steps(steps);
169            for step in steps {
170                if !reachable.contains(step.name.as_str()) {
171                    errors.push(ValidationError::UnreachableStep {
172                        step: step.name.clone(),
173                    });
174                }
175            }
176        }
177        Err(cycle_steps) => {
178            errors.push(ValidationError::CycleDetected { steps: cycle_steps });
179        }
180    }
181
182    // Variable reference validation
183    errors.extend(validate_variable_refs(steps, &name_set));
184
185    errors
186}
187
188/// Perform topological sort using Kahn's algorithm.
189///
190/// Returns `Ok(sorted_names)` or `Err(cycle_participants)`.
191pub fn topological_sort(steps: &[DagWorkflowStep]) -> Result<Vec<String>, Vec<String>> {
192    let mut in_degree: HashMap<&str, usize> = HashMap::new();
193    let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
194
195    for step in steps {
196        in_degree.entry(&step.name).or_insert(0);
197        adjacency.entry(&step.name).or_default();
198        for dep in &step.depends_on {
199            if let Some(dep_step) = steps.iter().find(|s| s.name == *dep) {
200                adjacency
201                    .entry(&dep_step.name)
202                    .or_default()
203                    .push(&step.name);
204                *in_degree.entry(&step.name).or_insert(0) += 1;
205            }
206        }
207    }
208
209    let mut queue: VecDeque<&str> = in_degree
210        .iter()
211        .filter(|&(_, &deg)| deg == 0)
212        .map(|(&name, _)| name)
213        .collect();
214
215    let mut sorted = Vec::new();
216
217    while let Some(node) = queue.pop_front() {
218        sorted.push(node.to_string());
219        if let Some(neighbors) = adjacency.get(node) {
220            for &neighbor in neighbors {
221                if let Some(deg) = in_degree.get_mut(neighbor) {
222                    *deg -= 1;
223                    if *deg == 0 {
224                        queue.push_back(neighbor);
225                    }
226                }
227            }
228        }
229    }
230
231    if sorted.len() == steps.len() {
232        Ok(sorted)
233    } else {
234        // Find cycle participants: all nodes not in the sorted list
235        let sorted_set: HashSet<&str> = sorted.iter().map(|s| s.as_str()).collect();
236        let cycle_nodes: Vec<String> = steps
237            .iter()
238            .filter(|s| !sorted_set.contains(s.name.as_str()))
239            .map(|s| s.name.clone())
240            .collect();
241        Err(cycle_nodes)
242    }
243}
244
245/// Compute the longest path in the DAG (depth).
246fn compute_dag_depth(steps: &[DagWorkflowStep], topo_order: &[String]) -> usize {
247    let mut depth: HashMap<&str, usize> = HashMap::new();
248
249    for name in topo_order {
250        let step = steps.iter().find(|s| s.name == *name);
251        let max_dep_depth = match step {
252            Some(s) => s
253                .depends_on
254                .iter()
255                .filter_map(|d| depth.get(d.as_str()))
256                .copied()
257                .max()
258                .unwrap_or(0),
259            None => 0,
260        };
261        depth.insert(name, max_dep_depth + 1);
262    }
263
264    depth.values().copied().max().unwrap_or(0)
265}
266
267/// Find all steps reachable from root steps (those with no dependencies).
268fn find_reachable_steps(steps: &[DagWorkflowStep]) -> HashSet<&str> {
269    let step_map: HashMap<&str, &DagWorkflowStep> =
270        steps.iter().map(|s| (s.name.as_str(), s)).collect();
271
272    // Build forward adjacency (dep -> dependents)
273    let mut forward: HashMap<&str, Vec<&str>> = HashMap::new();
274    for step in steps {
275        forward.entry(&step.name).or_default();
276        for dep in &step.depends_on {
277            forward.entry(dep.as_str()).or_default().push(&step.name);
278        }
279    }
280
281    // Root steps have no dependencies
282    let roots: Vec<&str> = steps
283        .iter()
284        .filter(|s| s.depends_on.is_empty())
285        .map(|s| s.name.as_str())
286        .collect();
287
288    let mut reachable: HashSet<&str> = HashSet::new();
289    let mut queue: VecDeque<&str> = roots.into_iter().collect();
290
291    while let Some(node) = queue.pop_front() {
292        if reachable.insert(node) {
293            if let Some(neighbors) = forward.get(node) {
294                for &n in neighbors {
295                    if !reachable.contains(n) {
296                        queue.push_back(n);
297                    }
298                }
299            }
300            // Also follow else_step links
301            if let Some(step) = step_map.get(node)
302                && let Some(ref else_step) = step.else_step
303                && !reachable.contains(else_step.as_str())
304            {
305                queue.push_back(else_step);
306            }
307        }
308    }
309
310    reachable
311}
312
313/// Validate that variable references like `{{step_name.output}}` point to real steps.
314fn validate_variable_refs(
315    steps: &[DagWorkflowStep],
316    name_set: &HashSet<&str>,
317) -> Vec<ValidationError> {
318    let mut errors = Vec::new();
319
320    for step in steps {
321        let template = &step.prompt_template;
322        // Find all {{...}} patterns
323        let mut pos = 0;
324        while let Some(start) = template[pos..].find("{{") {
325            let abs_start = pos + start + 2;
326            if let Some(end) = template[abs_start..].find("}}") {
327                let var_content = &template[abs_start..abs_start + end];
328                // Check if it references a step output (step_name.output, step_name.status, etc.)
329                if let Some(dot_pos) = var_content.find('.') {
330                    let ref_step = &var_content[..dot_pos];
331                    // Skip built-in variables
332                    if ref_step != "loop" && ref_step != "step" && !name_set.contains(ref_step) {
333                        errors.push(ValidationError::InvalidVariableRef {
334                            step: step.name.clone(),
335                            variable: var_content.to_string(),
336                        });
337                    }
338                }
339                pos = abs_start + end + 2;
340            } else {
341                break;
342            }
343        }
344    }
345
346    errors
347}
348
349// ---------------------------------------------------------------------------
350// Tests
351// ---------------------------------------------------------------------------
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::{DagWorkflowStep, OnError};

    /// Build a minimal step named `name` that depends on `deps`.
    ///
    /// The template is `{{input}}`, which contains no step reference, so it
    /// never triggers variable-reference errors on its own.
    fn step(name: &str, deps: &[&str]) -> DagWorkflowStep {
        DagWorkflowStep {
            name: name.to_string(),
            fighter_name: "test".to_string(),
            prompt_template: "{{input}}".to_string(),
            timeout_secs: None,
            on_error: OnError::FailWorkflow,
            depends_on: deps.iter().map(|d| d.to_string()).collect(),
            condition: None,
            else_step: None,
            loop_config: None,
        }
    }

    #[test]
    fn validate_empty_workflow() {
        let errors = validate_workflow(&[]);
        assert_eq!(errors.len(), 1);
        assert!(matches!(errors[0], ValidationError::EmptyWorkflow));
    }

    #[test]
    fn validate_single_step() {
        let steps = vec![step("root", &[])];
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "errors: {errors:?}");
    }

    #[test]
    fn validate_linear_chain() {
        let steps = vec![step("a", &[]), step("b", &["a"]), step("c", &["b"])];
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "errors: {errors:?}");
    }

    #[test]
    fn validate_fan_out() {
        let steps = vec![
            step("root", &[]),
            step("b1", &["root"]),
            step("b2", &["root"]),
            step("b3", &["root"]),
        ];
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "errors: {errors:?}");
    }

    #[test]
    fn validate_fan_in() {
        let steps = vec![
            step("a", &[]),
            step("b", &[]),
            step("c", &[]),
            step("join", &["a", "b", "c"]),
        ];
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "errors: {errors:?}");
    }

    #[test]
    fn validate_diamond() {
        let steps = vec![
            step("root", &[]),
            step("left", &["root"]),
            step("right", &["root"]),
            step("join", &["left", "right"]),
        ];
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "errors: {errors:?}");
    }

    #[test]
    fn detect_cycle_simple() {
        let steps = vec![step("a", &["b"]), step("b", &["a"])];
        let errors = validate_workflow(&steps);
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::CycleDetected { .. })));
    }

    #[test]
    fn detect_cycle_three_way() {
        let steps = vec![step("a", &["c"]), step("b", &["a"]), step("c", &["b"])];
        let errors = validate_workflow(&steps);
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::CycleDetected { .. })));
    }

    #[test]
    fn detect_missing_dependency() {
        let steps = vec![step("a", &[]), step("b", &["nonexistent"])];
        let errors = validate_workflow(&steps);
        assert!(errors.iter().any(|e| matches!(
            e,
            ValidationError::MissingDependency {
                step,
                missing_dep
            } if step == "b" && missing_dep == "nonexistent"
        )));
    }

    #[test]
    fn detect_duplicate_step_name() {
        let steps = vec![step("dup", &[]), step("dup", &[])];
        let errors = validate_workflow(&steps);
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::DuplicateStepName { name } if name == "dup")));
    }

    #[test]
    fn detect_invalid_variable_ref() {
        let mut steps = vec![step("a", &[])];
        steps[0].prompt_template = "Use {{nonexistent.output}}".to_string();
        let errors = validate_workflow(&steps);
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::InvalidVariableRef { .. })));
    }

    #[test]
    fn valid_variable_ref_not_flagged() {
        let mut steps = vec![step("a", &[]), step("b", &["a"])];
        steps[1].prompt_template = "Use {{a.output}}".to_string();
        let errors = validate_workflow(&steps);
        assert!(
            errors.is_empty(),
            "should not flag valid refs, got: {errors:?}"
        );
    }

    #[test]
    fn loop_variable_not_flagged() {
        let mut steps = vec![step("a", &[])];
        steps[0].prompt_template = "Item {{loop.item}} at {{loop.index}}".to_string();
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "loop vars should be ignored: {errors:?}");
    }

    #[test]
    fn topological_sort_linear() {
        let steps = vec![step("a", &[]), step("b", &["a"]), step("c", &["b"])];
        let sorted = topological_sort(&steps).expect("should sort");
        let a_pos = sorted.iter().position(|s| s == "a").expect("a");
        let b_pos = sorted.iter().position(|s| s == "b").expect("b");
        let c_pos = sorted.iter().position(|s| s == "c").expect("c");
        assert!(a_pos < b_pos);
        assert!(b_pos < c_pos);
    }

    #[test]
    fn topological_sort_diamond() {
        let steps = vec![
            step("root", &[]),
            step("left", &["root"]),
            step("right", &["root"]),
            step("join", &["left", "right"]),
        ];
        let sorted = topological_sort(&steps).expect("should sort");
        let root_pos = sorted.iter().position(|s| s == "root").expect("root");
        let left_pos = sorted.iter().position(|s| s == "left").expect("left");
        let right_pos = sorted.iter().position(|s| s == "right").expect("right");
        let join_pos = sorted.iter().position(|s| s == "join").expect("join");
        assert!(root_pos < left_pos);
        assert!(root_pos < right_pos);
        assert!(left_pos < join_pos);
        assert!(right_pos < join_pos);
    }

    #[test]
    fn topological_sort_cycle_returns_err() {
        let steps = vec![step("a", &["b"]), step("b", &["a"])];
        let result = topological_sort(&steps);
        assert!(result.is_err());
        let cycle = result.expect_err("cycle");
        assert!(cycle.contains(&"a".to_string()));
        assert!(cycle.contains(&"b".to_string()));
    }

    #[test]
    fn validation_error_display() {
        let err = ValidationError::CycleDetected {
            steps: vec!["a".to_string(), "b".to_string()],
        };
        let display = format!("{err}");
        assert!(display.contains("cycle detected"));
        assert!(display.contains("a -> b"));
    }

    #[test]
    fn validation_error_serialization() {
        let err = ValidationError::MissingDependency {
            step: "s1".to_string(),
            missing_dep: "s2".to_string(),
        };
        let json = serde_json::to_string(&err).expect("serialize");
        let deser: ValidationError = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(err, deser);
    }

    #[test]
    fn detect_invalid_else_step() {
        let mut steps = vec![step("a", &[])];
        steps[0].else_step = Some("nonexistent".to_string());
        let errors = validate_workflow(&steps);
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::InvalidElseStep { .. })));
    }

    #[test]
    fn valid_else_step_not_flagged() {
        let mut steps = vec![step("a", &[]), step("b", &[])];
        steps[0].else_step = Some("b".to_string());
        let errors = validate_workflow(&steps);
        assert!(
            !errors
                .iter()
                .any(|e| matches!(e, ValidationError::InvalidElseStep { .. })),
            "valid else_step should not be flagged: {errors:?}"
        );
    }

    /// Build a linear chain `s0 <- s1 <- ... ` containing `len` steps.
    fn chain(len: usize) -> Vec<DagWorkflowStep> {
        (0..len)
            .map(|i| {
                if i == 0 {
                    step("s0", &[])
                } else {
                    let prev = format!("s{}", i - 1);
                    step(&format!("s{i}"), &[prev.as_str()])
                }
            })
            .collect()
    }

    #[test]
    fn chain_at_max_depth_is_accepted() {
        let steps = chain(MAX_DAG_DEPTH);
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "errors: {errors:?}");
    }

    #[test]
    fn detect_exceeds_max_depth() {
        // One step longer than the limit allows.
        let steps = chain(MAX_DAG_DEPTH + 1);
        let errors = validate_workflow(&steps);
        assert!(errors.iter().any(|e| matches!(
            e,
            ValidationError::ExceedsMaxDepth { depth, limit }
                if *depth == MAX_DAG_DEPTH + 1 && *limit == MAX_DAG_DEPTH
        )));
    }

    #[test]
    fn detect_exceeds_max_breadth() {
        // The limit is on the total number of steps, so independent roots suffice.
        let steps: Vec<DagWorkflowStep> = (0..=MAX_DAG_BREADTH)
            .map(|i| step(&format!("r{i}"), &[]))
            .collect();
        let errors = validate_workflow(&steps);
        assert!(errors.iter().any(|e| matches!(
            e,
            ValidationError::ExceedsMaxBreadth { breadth, limit }
                if *breadth == MAX_DAG_BREADTH + 1 && *limit == MAX_DAG_BREADTH
        )));
    }

    #[test]
    fn step_variable_not_flagged() {
        let mut steps = vec![step("a", &[])];
        steps[0].prompt_template = "Running {{step.name}}".to_string();
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "step vars should be ignored: {errors:?}");
    }

    #[test]
    fn unterminated_placeholder_not_flagged() {
        let mut steps = vec![step("a", &[])];
        steps[0].prompt_template = "Use {{nonexistent.output".to_string();
        let errors = validate_workflow(&steps);
        assert!(
            errors.is_empty(),
            "unterminated placeholder should be ignored: {errors:?}"
        );
    }
}