1use std::collections::{HashMap, HashSet, VecDeque};
7
8use serde::{Deserialize, Serialize};
9
10use crate::workflow::DagWorkflowStep;
11
/// Maximum allowed depth of a workflow DAG.
///
/// Depth is measured in steps along the longest dependency chain (a single
/// root step has depth 1). `validate_workflow` reports
/// `ValidationError::ExceedsMaxDepth` when the computed depth is strictly
/// greater than this value.
pub const MAX_DAG_DEPTH: usize = 100;
14
/// Maximum number of steps a workflow may contain.
///
/// Despite the name, `validate_workflow` compares this against the total step
/// count (`steps.len()`), not against the width of any single DAG level, and
/// reports `ValidationError::ExceedsMaxBreadth` when the count is strictly
/// greater than this value.
pub const MAX_DAG_BREADTH: usize = 1000;
17
/// A single problem found while validating a DAG workflow definition.
///
/// Serialized with `snake_case` variant names (see the `serde` attribute).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ValidationError {
    /// The dependency graph could not be fully ordered; `steps` holds the
    /// names of the steps that could not be placed in a topological order.
    CycleDetected { steps: Vec<String> },
    /// `step` lists `missing_dep` in its dependencies, but no step with that
    /// name exists in the workflow.
    MissingDependency { step: String, missing_dep: String },
    /// `step` cannot be reached from any root step (a step with no
    /// dependencies).
    UnreachableStep { step: String },
    /// The prompt template of `step` contains a `{{prefix.field}}` reference
    /// (`variable` is the text between the braces) whose prefix is neither a
    /// known step name nor one of the reserved prefixes `loop` / `step`.
    InvalidVariableRef { step: String, variable: String },
    /// More than one step uses `name`.
    DuplicateStepName { name: String },
    /// The workflow contains no steps at all.
    EmptyWorkflow,
    /// The DAG depth `depth` is greater than the allowed `limit`.
    ExceedsMaxDepth { depth: usize, limit: usize },
    /// The workflow has `breadth` steps, which is greater than the allowed
    /// `limit`.
    ExceedsMaxBreadth { breadth: usize, limit: usize },
    /// `step` names `else_step` as its else-branch, but no such step exists.
    InvalidElseStep { step: String, else_step: String },
    /// `step` names `fallback` as its fallback step, but no such step exists.
    InvalidFallbackStep { step: String, fallback: String },
}
43
44impl std::fmt::Display for ValidationError {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 Self::CycleDetected { steps } => {
48 write!(f, "cycle detected involving steps: {}", steps.join(" -> "))
49 }
50 Self::MissingDependency { step, missing_dep } => {
51 write!(
52 f,
53 "step '{step}' depends on non-existent step '{missing_dep}'"
54 )
55 }
56 Self::UnreachableStep { step } => {
57 write!(f, "step '{step}' is unreachable from any root step")
58 }
59 Self::InvalidVariableRef { step, variable } => {
60 write!(f, "step '{step}' references unknown variable '{variable}'")
61 }
62 Self::DuplicateStepName { name } => {
63 write!(f, "duplicate step name: '{name}'")
64 }
65 Self::EmptyWorkflow => write!(f, "workflow has no steps"),
66 Self::ExceedsMaxDepth { depth, limit } => {
67 write!(f, "DAG depth {depth} exceeds limit {limit}")
68 }
69 Self::ExceedsMaxBreadth { breadth, limit } => {
70 write!(f, "workflow has {breadth} steps, exceeding limit {limit}")
71 }
72 Self::InvalidElseStep { step, else_step } => {
73 write!(
74 f,
75 "step '{step}' has else_step '{else_step}' which doesn't exist"
76 )
77 }
78 Self::InvalidFallbackStep { step, fallback } => {
79 write!(
80 f,
81 "step '{step}' has fallback '{fallback}' which doesn't exist"
82 )
83 }
84 }
85 }
86}
87
88pub fn validate_workflow(steps: &[DagWorkflowStep]) -> Vec<ValidationError> {
92 let mut errors = Vec::new();
93
94 if steps.is_empty() {
96 errors.push(ValidationError::EmptyWorkflow);
97 return errors;
98 }
99
100 if steps.len() > MAX_DAG_BREADTH {
102 errors.push(ValidationError::ExceedsMaxBreadth {
103 breadth: steps.len(),
104 limit: MAX_DAG_BREADTH,
105 });
106 }
107
108 let mut name_set: HashSet<&str> = HashSet::new();
110 for step in steps {
111 if !name_set.insert(&step.name) {
112 errors.push(ValidationError::DuplicateStepName {
113 name: step.name.clone(),
114 });
115 }
116 }
117
118 for step in steps {
120 for dep in &step.depends_on {
121 if !name_set.contains(dep.as_str()) {
122 errors.push(ValidationError::MissingDependency {
123 step: step.name.clone(),
124 missing_dep: dep.clone(),
125 });
126 }
127 }
128 }
129
130 for step in steps {
132 if let Some(ref else_step) = step.else_step
133 && !name_set.contains(else_step.as_str())
134 {
135 errors.push(ValidationError::InvalidElseStep {
136 step: step.name.clone(),
137 else_step: else_step.clone(),
138 });
139 }
140 }
141
142 for step in steps {
144 if let Some(ref fallback) = step.fallback_step()
145 && !name_set.contains(fallback.as_str())
146 {
147 errors.push(ValidationError::InvalidFallbackStep {
148 step: step.name.clone(),
149 fallback: fallback.clone(),
150 });
151 }
152 }
153
154 let cycle_result = topological_sort(steps);
156 match cycle_result {
157 Ok(sorted) => {
158 let depth = compute_dag_depth(steps, &sorted);
160 if depth > MAX_DAG_DEPTH {
161 errors.push(ValidationError::ExceedsMaxDepth {
162 depth,
163 limit: MAX_DAG_DEPTH,
164 });
165 }
166
167 let reachable = find_reachable_steps(steps);
169 for step in steps {
170 if !reachable.contains(step.name.as_str()) {
171 errors.push(ValidationError::UnreachableStep {
172 step: step.name.clone(),
173 });
174 }
175 }
176 }
177 Err(cycle_steps) => {
178 errors.push(ValidationError::CycleDetected { steps: cycle_steps });
179 }
180 }
181
182 errors.extend(validate_variable_refs(steps, &name_set));
184
185 errors
186}
187
188pub fn topological_sort(steps: &[DagWorkflowStep]) -> Result<Vec<String>, Vec<String>> {
192 let mut in_degree: HashMap<&str, usize> = HashMap::new();
193 let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
194
195 for step in steps {
196 in_degree.entry(&step.name).or_insert(0);
197 adjacency.entry(&step.name).or_default();
198 for dep in &step.depends_on {
199 if let Some(dep_step) = steps.iter().find(|s| s.name == *dep) {
200 adjacency
201 .entry(&dep_step.name)
202 .or_default()
203 .push(&step.name);
204 *in_degree.entry(&step.name).or_insert(0) += 1;
205 }
206 }
207 }
208
209 let mut queue: VecDeque<&str> = in_degree
210 .iter()
211 .filter(|&(_, °)| deg == 0)
212 .map(|(&name, _)| name)
213 .collect();
214
215 let mut sorted = Vec::new();
216
217 while let Some(node) = queue.pop_front() {
218 sorted.push(node.to_string());
219 if let Some(neighbors) = adjacency.get(node) {
220 for &neighbor in neighbors {
221 if let Some(deg) = in_degree.get_mut(neighbor) {
222 *deg -= 1;
223 if *deg == 0 {
224 queue.push_back(neighbor);
225 }
226 }
227 }
228 }
229 }
230
231 if sorted.len() == steps.len() {
232 Ok(sorted)
233 } else {
234 let sorted_set: HashSet<&str> = sorted.iter().map(|s| s.as_str()).collect();
236 let cycle_nodes: Vec<String> = steps
237 .iter()
238 .filter(|s| !sorted_set.contains(s.name.as_str()))
239 .map(|s| s.name.clone())
240 .collect();
241 Err(cycle_nodes)
242 }
243}
244
245fn compute_dag_depth(steps: &[DagWorkflowStep], topo_order: &[String]) -> usize {
247 let mut depth: HashMap<&str, usize> = HashMap::new();
248
249 for name in topo_order {
250 let step = steps.iter().find(|s| s.name == *name);
251 let max_dep_depth = match step {
252 Some(s) => s
253 .depends_on
254 .iter()
255 .filter_map(|d| depth.get(d.as_str()))
256 .copied()
257 .max()
258 .unwrap_or(0),
259 None => 0,
260 };
261 depth.insert(name, max_dep_depth + 1);
262 }
263
264 depth.values().copied().max().unwrap_or(0)
265}
266
267fn find_reachable_steps(steps: &[DagWorkflowStep]) -> HashSet<&str> {
269 let step_map: HashMap<&str, &DagWorkflowStep> =
270 steps.iter().map(|s| (s.name.as_str(), s)).collect();
271
272 let mut forward: HashMap<&str, Vec<&str>> = HashMap::new();
274 for step in steps {
275 forward.entry(&step.name).or_default();
276 for dep in &step.depends_on {
277 forward.entry(dep.as_str()).or_default().push(&step.name);
278 }
279 }
280
281 let roots: Vec<&str> = steps
283 .iter()
284 .filter(|s| s.depends_on.is_empty())
285 .map(|s| s.name.as_str())
286 .collect();
287
288 let mut reachable: HashSet<&str> = HashSet::new();
289 let mut queue: VecDeque<&str> = roots.into_iter().collect();
290
291 while let Some(node) = queue.pop_front() {
292 if reachable.insert(node) {
293 if let Some(neighbors) = forward.get(node) {
294 for &n in neighbors {
295 if !reachable.contains(n) {
296 queue.push_back(n);
297 }
298 }
299 }
300 if let Some(step) = step_map.get(node)
302 && let Some(ref else_step) = step.else_step
303 && !reachable.contains(else_step.as_str())
304 {
305 queue.push_back(else_step);
306 }
307 }
308 }
309
310 reachable
311}
312
313fn validate_variable_refs(
315 steps: &[DagWorkflowStep],
316 name_set: &HashSet<&str>,
317) -> Vec<ValidationError> {
318 let mut errors = Vec::new();
319
320 for step in steps {
321 let template = &step.prompt_template;
322 let mut pos = 0;
324 while let Some(start) = template[pos..].find("{{") {
325 let abs_start = pos + start + 2;
326 if let Some(end) = template[abs_start..].find("}}") {
327 let var_content = &template[abs_start..abs_start + end];
328 if let Some(dot_pos) = var_content.find('.') {
330 let ref_step = &var_content[..dot_pos];
331 if ref_step != "loop" && ref_step != "step" && !name_set.contains(ref_step) {
333 errors.push(ValidationError::InvalidVariableRef {
334 step: step.name.clone(),
335 variable: var_content.to_string(),
336 });
337 }
338 }
339 pos = abs_start + end + 2;
340 } else {
341 break;
342 }
343 }
344 }
345
346 errors
347}
348
// Unit tests for workflow validation, topological sorting and the
// `ValidationError` type.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::{DagWorkflowStep, OnError};

    // Builds a minimal step named `name` that depends on `deps`. The prompt
    // template `{{input}}` contains no `.`, so it never triggers the
    // variable-reference check.
    fn step(name: &str, deps: &[&str]) -> DagWorkflowStep {
        DagWorkflowStep {
            name: name.to_string(),
            fighter_name: "test".to_string(),
            prompt_template: "{{input}}".to_string(),
            timeout_secs: None,
            on_error: OnError::FailWorkflow,
            depends_on: deps.iter().map(|d| d.to_string()).collect(),
            condition: None,
            else_step: None,
            loop_config: None,
        }
    }

    // An empty workflow yields exactly one error: `EmptyWorkflow`.
    #[test]
    fn validate_empty_workflow() {
        let errors = validate_workflow(&[]);
        assert_eq!(errors.len(), 1);
        assert!(matches!(errors[0], ValidationError::EmptyWorkflow));
    }

    // A lone root step is a valid workflow.
    #[test]
    fn validate_single_step() {
        let steps = vec![step("root", &[])];
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "errors: {errors:?}");
    }

    // a -> b -> c
    #[test]
    fn validate_linear_chain() {
        let steps = vec![step("a", &[]), step("b", &["a"]), step("c", &["b"])];
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "errors: {errors:?}");
    }

    // One root feeding three independent branches.
    #[test]
    fn validate_fan_out() {
        let steps = vec![
            step("root", &[]),
            step("b1", &["root"]),
            step("b2", &["root"]),
            step("b3", &["root"]),
        ];
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "errors: {errors:?}");
    }

    // Three roots joined by a single step.
    #[test]
    fn validate_fan_in() {
        let steps = vec![
            step("a", &[]),
            step("b", &[]),
            step("c", &[]),
            step("join", &["a", "b", "c"]),
        ];
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "errors: {errors:?}");
    }

    // root -> {left, right} -> join
    #[test]
    fn validate_diamond() {
        let steps = vec![
            step("root", &[]),
            step("left", &["root"]),
            step("right", &["root"]),
            step("join", &["left", "right"]),
        ];
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "errors: {errors:?}");
    }

    // Two steps depending on each other form a cycle.
    #[test]
    fn detect_cycle_simple() {
        let steps = vec![step("a", &["b"]), step("b", &["a"])];
        let errors = validate_workflow(&steps);
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::CycleDetected { .. })));
    }

    // a -> b -> c -> a
    #[test]
    fn detect_cycle_three_way() {
        let steps = vec![step("a", &["c"]), step("b", &["a"]), step("c", &["b"])];
        let errors = validate_workflow(&steps);
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::CycleDetected { .. })));
    }

    // A dependency on an unknown step is reported with both names.
    #[test]
    fn detect_missing_dependency() {
        let steps = vec![step("a", &[]), step("b", &["nonexistent"])];
        let errors = validate_workflow(&steps);
        assert!(errors.iter().any(|e| matches!(
            e,
            ValidationError::MissingDependency {
                step,
                missing_dep
            } if step == "b" && missing_dep == "nonexistent"
        )));
    }

    // Reusing a step name is reported with that name.
    #[test]
    fn detect_duplicate_step_name() {
        let steps = vec![step("dup", &[]), step("dup", &[])];
        let errors = validate_workflow(&steps);
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::DuplicateStepName { name } if name == "dup")));
    }

    // `{{nonexistent.output}}` refers to a step that does not exist.
    #[test]
    fn detect_invalid_variable_ref() {
        let mut steps = vec![step("a", &[])];
        steps[0].prompt_template = "Use {{nonexistent.output}}".to_string();
        let errors = validate_workflow(&steps);
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::InvalidVariableRef { .. })));
    }

    // `{{a.output}}` refers to an existing step and must be accepted.
    #[test]
    fn valid_variable_ref_not_flagged() {
        let mut steps = vec![step("a", &[]), step("b", &["a"])];
        steps[1].prompt_template = "Use {{a.output}}".to_string();
        let errors = validate_workflow(&steps);
        assert!(
            errors.is_empty(),
            "should not flag valid refs, got: {errors:?}"
        );
    }

    // The reserved `loop.` prefix is never treated as a step reference.
    #[test]
    fn loop_variable_not_flagged() {
        let mut steps = vec![step("a", &[])];
        steps[0].prompt_template = "Item {{loop.item}} at {{loop.index}}".to_string();
        let errors = validate_workflow(&steps);
        assert!(errors.is_empty(), "loop vars should be ignored: {errors:?}");
    }

    // In a chain, every step must come after its dependency.
    #[test]
    fn topological_sort_linear() {
        let steps = vec![step("a", &[]), step("b", &["a"]), step("c", &["b"])];
        let sorted = topological_sort(&steps).expect("should sort");
        let a_pos = sorted.iter().position(|s| s == "a").expect("a");
        let b_pos = sorted.iter().position(|s| s == "b").expect("b");
        let c_pos = sorted.iter().position(|s| s == "c").expect("c");
        assert!(a_pos < b_pos);
        assert!(b_pos < c_pos);
    }

    // Only relative order is asserted; `left` and `right` may appear in
    // either order.
    #[test]
    fn topological_sort_diamond() {
        let steps = vec![
            step("root", &[]),
            step("left", &["root"]),
            step("right", &["root"]),
            step("join", &["left", "right"]),
        ];
        let sorted = topological_sort(&steps).expect("should sort");
        let root_pos = sorted.iter().position(|s| s == "root").expect("root");
        let left_pos = sorted.iter().position(|s| s == "left").expect("left");
        let right_pos = sorted.iter().position(|s| s == "right").expect("right");
        let join_pos = sorted.iter().position(|s| s == "join").expect("join");
        assert!(root_pos < left_pos);
        assert!(root_pos < right_pos);
        assert!(left_pos < join_pos);
        assert!(right_pos < join_pos);
    }

    // On a cycle the error value lists the steps that could not be ordered.
    #[test]
    fn topological_sort_cycle_returns_err() {
        let steps = vec![step("a", &["b"]), step("b", &["a"])];
        let result = topological_sort(&steps);
        assert!(result.is_err());
        let cycle = result.expect_err("cycle");
        assert!(cycle.contains(&"a".to_string()));
        assert!(cycle.contains(&"b".to_string()));
    }

    // The cycle message joins the step names with " -> ".
    #[test]
    fn validation_error_display() {
        let err = ValidationError::CycleDetected {
            steps: vec!["a".to_string(), "b".to_string()],
        };
        let display = format!("{err}");
        assert!(display.contains("cycle detected"));
        assert!(display.contains("a -> b"));
    }

    // Errors survive a JSON round trip unchanged.
    #[test]
    fn validation_error_serialization() {
        let err = ValidationError::MissingDependency {
            step: "s1".to_string(),
            missing_dep: "s2".to_string(),
        };
        let json = serde_json::to_string(&err).expect("serialize");
        let deser: ValidationError = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(err, deser);
    }

    // An `else_step` naming an unknown step is reported.
    #[test]
    fn detect_invalid_else_step() {
        let mut steps = vec![step("a", &[])];
        steps[0].else_step = Some("nonexistent".to_string());
        let errors = validate_workflow(&steps);
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::InvalidElseStep { .. })));
    }

    // An `else_step` naming an existing step must not raise `InvalidElseStep`
    // (other error kinds are not asserted here).
    #[test]
    fn valid_else_step_not_flagged() {
        let mut steps = vec![step("a", &[]), step("b", &[])];
        steps[0].else_step = Some("b".to_string());
        let errors = validate_workflow(&steps);
        assert!(
            !errors
                .iter()
                .any(|e| matches!(e, ValidationError::InvalidElseStep { .. })),
            "valid else_step should not be flagged: {errors:?}"
        );
    }
}