1use std::collections::HashMap;
2use std::path::Path;
3
4use serde::{Deserialize, Serialize};
5
/// A parsed workflow definition.
///
/// Executable content lives in two node lists, `body` and `always`; the helper methods on
/// this type walk both lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowDef {
    /// Workflow identifier; also the fallback returned by `display_name` when `title` is unset.
    pub name: String,
    /// Optional human-readable title, preferred by `display_name`.
    /// `None` when absent from serialized input.
    #[serde(default)]
    pub title: Option<String>,
    /// Free-form description text.
    pub description: String,
    /// How the workflow is triggered (`manual`, `pr` or `scheduled`).
    pub trigger: WorkflowTrigger,
    /// Target names; empty when absent from serialized input.
    /// NOTE(review): what a "target" denotes is not visible in this file — confirm.
    #[serde(default)]
    pub targets: Vec<String>,
    /// Optional group label; `None` when absent from serialized input.
    #[serde(default)]
    pub group: Option<String>,
    /// Declared workflow inputs.
    pub inputs: Vec<InputDecl>,
    /// Main node sequence.
    pub body: Vec<WorkflowNode>,
    /// Nodes of the `always` section.
    /// Presumably executed after `body` regardless of outcome — confirm in the executor.
    pub always: Vec<WorkflowNode>,
    /// Presumably the path of the file this definition was loaded from (tests use `"test.wf"`).
    pub source_path: String,
}
27
28impl WorkflowDef {
29 pub fn display_name(&self) -> &str {
32 self.title.as_deref().unwrap_or(&self.name)
33 }
34
35 pub fn total_nodes(&self) -> usize {
37 count_nodes(&self.body) + count_nodes(&self.always)
38 }
39
40 pub fn top_level_steps(&self) -> usize {
43 self.body.len() + self.always.len()
44 }
45
46 pub fn max_iterations_for_step(&self, step_name: &str) -> Option<u32> {
50 fn search(nodes: &[WorkflowNode], name: &str) -> Option<u32> {
51 for node in nodes {
52 match node {
53 WorkflowNode::DoWhile(n) => {
54 if n.step == name {
55 return Some(n.max_iterations);
56 }
57 if let Some(v) = search(&n.body, name) {
58 return Some(v);
59 }
60 }
61 WorkflowNode::While(n) => {
62 if n.step == name {
63 return Some(n.max_iterations);
64 }
65 if let Some(v) = search(&n.body, name) {
66 return Some(v);
67 }
68 }
69 _ => {
70 if let Some(body) = node.body() {
71 if let Some(v) = search(body, name) {
72 return Some(v);
73 }
74 }
75 }
76 }
77 }
78 None
79 }
80 search(&self.body, step_name).or_else(|| search(&self.always, step_name))
81 }
82
83 pub fn collect_all_snippet_refs(&self) -> Vec<String> {
85 let mut refs = collect_snippet_refs(&self.body);
86 refs.extend(collect_snippet_refs(&self.always));
87 refs.sort();
88 refs.dedup();
89 refs
90 }
91
92 pub fn collect_all_schema_refs(&self) -> Vec<String> {
94 let mut refs = collect_schema_refs(&self.body);
95 refs.extend(collect_schema_refs(&self.always));
96 refs.sort();
97 refs.dedup();
98 refs
99 }
100
101 pub fn collect_all_agent_refs(&self) -> Vec<AgentRef> {
103 let mut refs = collect_agent_names(&self.body);
104 refs.extend(collect_agent_names(&self.always));
105 refs.sort();
106 refs.dedup();
107 refs
108 }
109
110 pub fn collect_all_as_identities(&self) -> Vec<String> {
112 let mut names = collect_as_identities(&self.body);
113 names.extend(collect_as_identities(&self.always));
114 names.sort();
115 names.dedup();
116 names
117 }
118
119 pub fn collect_all_plugin_dirs(&self) -> Vec<String> {
121 let mut dirs = collect_plugin_dirs(&self.body);
122 dirs.extend(collect_plugin_dirs(&self.always));
123 dirs.sort();
124 dirs.dedup();
125 dirs
126 }
127}
128
/// A warning message associated with a workflow file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowWarning {
    /// The file the warning refers to (presumably a path — confirm with the producer).
    pub file: String,
    /// Human-readable warning text.
    pub message: String,
}
137
/// How a workflow is triggered.
///
/// Serialized in snake_case (`"manual"`, `"pr"`, `"scheduled"`). These are the same strings
/// produced by the `Display` impl and accepted by the `FromStr` impl.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WorkflowTrigger {
    /// Serialized as `"manual"`.
    Manual,
    /// Serialized as `"pr"`.
    Pr,
    /// Serialized as `"scheduled"`.
    Scheduled,
}
146
147impl std::fmt::Display for WorkflowTrigger {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 Self::Manual => write!(f, "manual"),
151 Self::Pr => write!(f, "pr"),
152 Self::Scheduled => write!(f, "scheduled"),
153 }
154 }
155}
156
157impl std::str::FromStr for WorkflowTrigger {
158 type Err = String;
159 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
160 match s {
161 "manual" => Ok(Self::Manual),
162 "pr" => Ok(Self::Pr),
163 "scheduled" => Ok(Self::Scheduled),
164 _ => Err(format!("unknown trigger: {s}")),
165 }
166 }
167}
168
/// Value type of a declared workflow input.
///
/// Serialized in snake_case (`"string"`, `"boolean"`). Defaults to [`InputType::String`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum InputType {
    /// Free-form string input; the default.
    #[default]
    String,
    /// Boolean input.
    Boolean,
}
177
/// Declaration of a single workflow input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputDecl {
    /// Input name.
    pub name: String,
    /// Whether the input is marked as required.
    pub required: bool,
    /// Optional default value, stored as a string regardless of `input_type`.
    pub default: Option<String>,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Value type; [`InputType::String`] when absent from serialized input.
    #[serde(default)]
    pub input_type: InputType,
}
188
/// A single node in a workflow tree.
///
/// Internally tagged: the variant name is stored in a `type` field using snake_case.
/// Examples are `"call"`, `"call_workflow"`, `"do_while"` and `"for_each"`.
/// Container variants expose their children through [`WorkflowNode::body`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WorkflowNode {
    /// Invoke a single agent.
    Call(CallNode),
    /// Invoke another workflow by name.
    CallWorkflow(CallWorkflowNode),
    /// Conditional block.
    If(IfNode),
    /// Negated conditional block.
    Unless(UnlessNode),
    /// Loop keyed on a step/marker pair (`while`).
    While(WhileNode),
    /// Loop keyed on a step/marker pair (`do_while`).
    DoWhile(DoWhileNode),
    /// Grouped block with optional shared `output` / `with`.
    Do(DoNode),
    /// Fan-out to several agents.
    Parallel(ParallelNode),
    /// Gate step (see `GateNode`).
    Gate(GateNode),
    /// Nested `always` block.
    Always(AlwaysNode),
    /// Script step.
    Script(ScriptNode),
    /// Runs a child workflow per item (see `ForEachNode`).
    ForEach(ForEachNode),
}
206
207impl WorkflowNode {
208 pub fn body(&self) -> Option<&[WorkflowNode]> {
210 match self {
211 WorkflowNode::If(n) => Some(&n.body),
212 WorkflowNode::Unless(n) => Some(&n.body),
213 WorkflowNode::While(n) => Some(&n.body),
214 WorkflowNode::DoWhile(n) => Some(&n.body),
215 WorkflowNode::Do(n) => Some(&n.body),
216 WorkflowNode::Always(n) => Some(&n.body),
217 _ => None,
218 }
219 }
220}
221
/// A node that runs the child workflow `workflow` for each item described by `over`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForEachNode {
    /// Step name of this node.
    pub name: String,
    /// What to iterate over (see [`ForeachOver`]).
    pub over: ForeachOver,
    /// Optional string-to-string scope map.
    /// NOTE(review): how it narrows iteration is not visible here — confirm.
    pub scope: Option<HashMap<String, String>>,
    /// String-to-string filter map; empty when absent from serialized input.
    #[serde(default)]
    pub filter: HashMap<String, String>,
    /// Presumably whether items must run in dependency/declaration order — confirm.
    pub ordered: bool,
    /// Policy when a cycle is encountered (`fail` or `warn`); cycle detection itself lives elsewhere.
    pub on_cycle: OnCycle,
    /// Presumably the maximum number of child runs in flight at once — confirm.
    pub max_parallel: u32,
    /// Name of the child workflow; reported by `collect_workflow_refs`.
    pub workflow: String,
    /// Inputs passed to each child run (templating, if any, not visible here);
    /// empty when absent from serialized input.
    #[serde(default)]
    pub inputs: HashMap<String, String>,
    /// Policy applied when a child run fails.
    pub on_child_fail: OnChildFail,
}

/// Description of what a `for_each` node iterates over; currently a plain string.
pub type ForeachOver = String;
252
/// Failure policy for child runs of a [`ForEachNode`] (its `on_child_fail` field).
///
/// Serialized in snake_case (`"halt"`, `"continue"`, `"skip_dependents"`).
/// NOTE(review): the exact runtime semantics of each variant live in the executor — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnChildFail {
    /// Presumably stop the fan-out on the first child failure.
    Halt,
    /// Presumably keep running the remaining children.
    Continue,
    /// Presumably skip only children that depend on the failed one.
    SkipDependents,
}
264
/// Policy used by a [`ForEachNode`] (its `on_cycle` field) when a cycle is detected.
///
/// Serialized in snake_case (`"fail"`, `"warn"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnCycle {
    /// Treat a cycle as an error.
    Fail,
    /// Presumably report the cycle and carry on — confirm in the executor.
    Warn,
}
274
/// A step that runs a script.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptNode {
    /// Step name.
    pub name: String,
    /// What to run (tests use a script path such as `"./scripts/lint.sh"`).
    pub run: String,
    /// Extra environment variables; empty when absent from serialized input.
    #[serde(default)]
    pub env: HashMap<String, String>,
    /// Optional timeout. NOTE(review): unit is not visible here (seconds?) — confirm.
    pub timeout: Option<u64>,
    /// Retry count; `0` when absent from serialized input.
    #[serde(default)]
    pub retries: u32,
    /// Optional failure handler; an `Agent` handler is reported by `collect_agent_names`.
    pub on_fail: Option<OnFail>,
    /// Optional identity name; reported by `collect_as_identities`.
    pub as_identity: Option<String>,
}
299
/// Failure handler attached to call, call-workflow and script steps.
///
/// Adjacently tagged: serialized as `{"kind": "...", "value": ...}`. `kind` is snake_case;
/// `value` is present only for [`OnFail::Agent`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum OnFail {
    /// Hand the failure to this agent.
    Agent(AgentRef),
    /// Presumably ignore the failure and proceed — confirm in the executor.
    Continue,
}
310
/// Reference to an agent, either by bare name or by file path.
///
/// Adjacently tagged: serialized as `{"kind": "name" | "path", "value": "<string>"}`.
/// Derives `Ord` so reference lists can be sorted and de-duplicated.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum AgentRef {
    /// Agent referenced by name.
    Name(String),
    /// Agent referenced by path (tests use e.g. `".claude/agents/plan.md"`).
    Path(String),
}
322
323impl AgentRef {
324 pub fn label(&self) -> &str {
326 match self {
327 Self::Name(s) | Self::Path(s) => s.as_str(),
328 }
329 }
330
331 pub fn step_key(&self) -> String {
338 match self {
339 Self::Name(s) => s.clone(),
340 Self::Path(s) => Path::new(s)
341 .file_stem()
342 .and_then(|stem| stem.to_str())
343 .unwrap_or(s.as_str())
344 .to_string(),
345 }
346 }
347}
348
349impl std::fmt::Display for AgentRef {
350 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351 write!(f, "{}", self.label())
352 }
353}
354
/// A step that invokes a single agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallNode {
    /// The agent to invoke.
    pub agent: AgentRef,
    /// Retry count; `0` when absent from serialized input.
    #[serde(default)]
    pub retries: u32,
    /// Optional failure handler; an `Agent` handler is reported by `collect_agent_names`.
    pub on_fail: Option<OnFail>,
    /// Optional output-schema name; reported by `collect_schema_refs`.
    pub output: Option<String>,
    /// Snippet names attached to this call; reported by `collect_snippet_refs`.
    /// Empty when absent from serialized input.
    #[serde(default)]
    pub with: Vec<String>,
    /// Optional identity name; reported by `collect_as_identities`.
    pub as_identity: Option<String>,
    /// Plugin directories for this call; reported by `collect_plugin_dirs`.
    /// Empty when absent from serialized input.
    #[serde(default)]
    pub plugin_dirs: Vec<String>,
    /// Optional timeout as a string. NOTE(review): accepted format (e.g. "5m") not visible here — confirm.
    pub timeout: Option<String>,
    /// Optional cap on agent turns; `None` when absent from serialized input.
    #[serde(default)]
    pub max_turns: Option<u32>,
}
381
/// A step that invokes another workflow by name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallWorkflowNode {
    /// Name of the called workflow; reported by `collect_workflow_refs`.
    pub workflow: String,
    /// Inputs passed to the called workflow; empty when absent from serialized input.
    #[serde(default)]
    pub inputs: HashMap<String, String>,
    /// Retry count; `0` when absent from serialized input.
    #[serde(default)]
    pub retries: u32,
    /// Optional failure handler; an `Agent` handler is reported by `collect_agent_names`.
    pub on_fail: Option<OnFail>,
    /// Optional identity name; reported by `collect_as_identities`.
    pub as_identity: Option<String>,
}
394
/// Condition evaluated by `if` / `unless` nodes.
///
/// Internally tagged: the variant is stored in a `kind` field (`"step_marker"` or `"bool_input"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum Condition {
    /// Presumably true when step `step` emitted `marker` — confirm in the evaluator.
    StepMarker {
        /// Step name to inspect.
        step: String,
        /// Marker name to look for.
        marker: String,
    },
    /// Presumably true when the named boolean input is set — confirm in the evaluator.
    BoolInput {
        /// Input name.
        input: String,
    },
}
404
/// Conditional block: `body` is guarded by `condition`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IfNode {
    /// Guard condition.
    pub condition: Condition,
    /// Nested nodes.
    pub body: Vec<WorkflowNode>,
}
410
/// Negated conditional block: `body` is guarded by the negation of `condition`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnlessNode {
    /// Guard condition (negated).
    pub condition: Condition,
    /// Nested nodes.
    pub body: Vec<WorkflowNode>,
}
416
/// `while` loop keyed on a step/marker pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WhileNode {
    /// Step whose marker controls the loop; `WorkflowDef::max_iterations_for_step` matches on it.
    pub step: String,
    /// Marker name checked on `step`.
    pub marker: String,
    /// Iteration cap, returned by `WorkflowDef::max_iterations_for_step`.
    pub max_iterations: u32,
    /// Optional stuck threshold. NOTE(review): exact semantics not visible here — confirm.
    pub stuck_after: Option<u32>,
    /// Policy when `max_iterations` is reached.
    pub on_max_iter: OnMaxIter,
    /// Loop body.
    pub body: Vec<WorkflowNode>,
}
426
/// `do_while` loop keyed on a step/marker pair; same fields as [`WhileNode`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DoWhileNode {
    /// Step whose marker controls the loop; `WorkflowDef::max_iterations_for_step` matches on it.
    pub step: String,
    /// Marker name checked on `step`.
    pub marker: String,
    /// Iteration cap, returned by `WorkflowDef::max_iterations_for_step`.
    pub max_iterations: u32,
    /// Optional stuck threshold. NOTE(review): exact semantics not visible here — confirm.
    pub stuck_after: Option<u32>,
    /// Policy when `max_iterations` is reached.
    pub on_max_iter: OnMaxIter,
    /// Loop body.
    pub body: Vec<WorkflowNode>,
}
436
/// Grouped block carrying an optional output schema and snippet list alongside its body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DoNode {
    /// Optional output-schema name; reported by `collect_schema_refs`.
    pub output: Option<String>,
    /// Snippet names; reported by `collect_snippet_refs`. Empty when absent from serialized input.
    #[serde(default)]
    pub with: Vec<String>,
    /// Nested nodes.
    pub body: Vec<WorkflowNode>,
}
447
/// Policy applied when a loop hits `max_iterations`.
///
/// Serialized in snake_case (`"fail"`, `"continue"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnMaxIter {
    /// Treat hitting the cap as a failure.
    Fail,
    /// Presumably leave the loop and proceed — confirm in the executor.
    Continue,
}
454
/// Fan-out node that calls several agents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelNode {
    /// `true` when absent from serialized input (see `default_true`).
    /// Presumably aborts remaining calls on the first failure — confirm.
    #[serde(default = "default_true")]
    pub fail_fast: bool,
    /// Optional minimum number of successful calls — enforcement lives in the executor.
    pub min_success: Option<u32>,
    /// Agents to call; each one counts as a node in `count_nodes`.
    pub calls: Vec<AgentRef>,
    /// Optional output-schema name shared by the calls; reported by `collect_schema_refs`.
    pub output: Option<String>,
    /// Per-call output-schema overrides; values are reported by `collect_schema_refs`.
    /// Keys are presumably agent labels — confirm. Empty when absent.
    #[serde(default)]
    pub call_outputs: HashMap<String, String>,
    /// Snippet names shared by all calls; reported by `collect_snippet_refs`. Empty when absent.
    #[serde(default)]
    pub with: Vec<String>,
    /// Extra per-call snippet names; values are reported by `collect_snippet_refs`. Empty when absent.
    #[serde(default)]
    pub call_with: HashMap<String, Vec<String>>,
    /// Per-call conditions as string pairs — presumably `(step, marker)`, confirm. Empty when absent.
    #[serde(default)]
    pub call_if: HashMap<String, (String, String)>,
    /// Per-call retry counts. Empty when absent.
    #[serde(default)]
    pub call_retries: HashMap<String, u32>,
}

/// Serde default for `ParallelNode::fail_fast`.
fn default_true() -> bool {
    true
}
486
/// How a gate decides it has been approved.
///
/// Serialized in snake_case (`"min_approvals"`, `"review_decision"`).
/// Defaults to [`ApprovalMode::MinApprovals`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ApprovalMode {
    /// Presumably counts approvals against `GateNode::min_approvals`; the default.
    #[default]
    MinApprovals,
    /// Presumably defers to an external review decision — confirm.
    ReviewDecision,
}
494
/// Action taken when a quality gate does not pass (see [`QualityGateConfig`]).
///
/// Serialized in snake_case (`"fail"`, `"continue"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnFailAction {
    /// Treat the miss as a failure.
    Fail,
    /// Presumably proceed anyway — confirm in the executor.
    Continue,
}
501
/// Quality-gate settings, flattened into [`GateNode`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityGateConfig {
    /// Where the measured value comes from. NOTE(review): format not visible here — confirm.
    pub source: String,
    /// Threshold compared against the measured value (comparison direction lives elsewhere).
    pub threshold: u32,
    /// Action when the gate is not met; [`OnFailAction::Fail`] when absent (see `default_on_fail`).
    #[serde(default = "default_on_fail")]
    pub on_fail_action: OnFailAction,
}

/// Serde default for `QualityGateConfig::on_fail_action`.
fn default_on_fail() -> OnFailAction {
    OnFailAction::Fail
}
520
/// Options offered by a gate, given either inline or by reference to a step.
///
/// Untagged: on deserialization serde tries `Static` (a string map) first, then `StepRef`
/// (a bare string).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum GateOptions {
    /// Inline string-to-string option map.
    Static(HashMap<String, String>),
    /// Name of a step; presumably its output supplies the options — confirm.
    StepRef(String),
}
533
/// A gate step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateNode {
    /// Step name.
    pub name: String,
    /// Gate kind as a string; [`QUALITY_GATE_TYPE`] is one known value.
    pub gate_type: String,
    /// Optional prompt text.
    pub prompt: Option<String>,
    /// Required approvals; `1` when absent (see `default_one`).
    #[serde(default = "default_one")]
    pub min_approvals: u32,
    /// Approval mode; [`ApprovalMode::MinApprovals`] when absent.
    #[serde(default)]
    pub approval_mode: ApprovalMode,
    /// Timeout, in seconds per the field name.
    pub timeout_secs: u64,
    /// Policy applied on timeout.
    pub on_timeout: OnTimeout,
    /// Optional identity name; reported by `collect_as_identities`.
    pub as_identity: Option<String>,
    /// Quality-gate settings, flattened: their keys (`source`, `threshold`, `on_fail_action`)
    /// sit directly on the gate object. Presumably `None` when those keys are absent — verify
    /// against serde's `flatten` + `Option` behavior.
    #[serde(flatten)]
    pub quality_gate: Option<QualityGateConfig>,
    /// Optional gate options.
    pub options: Option<GateOptions>,
}

/// Serde default for `GateNode::min_approvals`.
fn default_one() -> u32 {
    1
}

/// `GateNode::gate_type` value that presumably marks a quality gate (pairs with `quality_gate`).
pub const QUALITY_GATE_TYPE: &str = "quality_gate";
559
/// Policy applied when a gate times out.
///
/// Serialized in snake_case (`"fail"`, `"continue"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnTimeout {
    /// Treat the timeout as a failure.
    Fail,
    /// Presumably proceed past the gate — confirm in the executor.
    Continue,
}
566
/// A nested `always` block; its children are returned by `WorkflowNode::body`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlwaysNode {
    /// Nested nodes.
    pub body: Vec<WorkflowNode>,
}
571
572pub(crate) fn count_nodes(nodes: &[WorkflowNode]) -> usize {
578 let mut count = 0;
579 for node in nodes {
580 count += 1;
581 match node {
582 WorkflowNode::Parallel(n) => count += n.calls.len(),
583 _ => {
584 if let Some(body) = node.body() {
585 count += count_nodes(body);
586 }
587 }
588 }
589 }
590 count
591}
592
593pub fn collect_agent_names(nodes: &[WorkflowNode]) -> Vec<AgentRef> {
595 let mut refs = Vec::new();
596 for node in nodes {
597 match node {
598 WorkflowNode::Call(n) => {
599 refs.push(n.agent.clone());
600 if let Some(OnFail::Agent(ref a)) = n.on_fail {
601 refs.push(a.clone());
602 }
603 }
604 WorkflowNode::CallWorkflow(n) => {
605 if let Some(OnFail::Agent(ref a)) = n.on_fail {
606 refs.push(a.clone());
607 }
608 }
609 WorkflowNode::Script(n) => {
610 if let Some(OnFail::Agent(ref a)) = n.on_fail {
611 refs.push(a.clone());
612 }
613 }
614 WorkflowNode::Parallel(n) => refs.extend(n.calls.iter().cloned()),
615 _ => {
616 if let Some(body) = node.body() {
617 refs.extend(collect_agent_names(body));
618 }
619 }
620 }
621 }
622 refs
623}
624
625pub(crate) fn collect_snippet_refs(nodes: &[WorkflowNode]) -> Vec<String> {
627 let mut refs = Vec::new();
628 for node in nodes {
629 match node {
630 WorkflowNode::Call(n) => refs.extend(n.with.iter().cloned()),
631 WorkflowNode::Parallel(n) => {
632 refs.extend(n.with.iter().cloned());
633 for extra in n.call_with.values() {
634 refs.extend(extra.iter().cloned());
635 }
636 }
637 WorkflowNode::Do(n) => {
638 refs.extend(n.with.iter().cloned());
639 refs.extend(collect_snippet_refs(&n.body));
640 }
641 _ => {
642 if let Some(body) = node.body() {
643 refs.extend(collect_snippet_refs(body));
644 }
645 }
646 }
647 }
648 refs
649}
650
651pub fn collect_workflow_refs(nodes: &[WorkflowNode]) -> Vec<String> {
653 let mut refs = Vec::new();
654 for node in nodes {
655 match node {
656 WorkflowNode::Call(_) | WorkflowNode::Gate(_) | WorkflowNode::Script(_) => {}
657 WorkflowNode::CallWorkflow(n) => refs.push(n.workflow.clone()),
658 WorkflowNode::If(n) => refs.extend(collect_workflow_refs(&n.body)),
659 WorkflowNode::Unless(n) => refs.extend(collect_workflow_refs(&n.body)),
660 WorkflowNode::While(n) => refs.extend(collect_workflow_refs(&n.body)),
661 WorkflowNode::DoWhile(n) => refs.extend(collect_workflow_refs(&n.body)),
662 WorkflowNode::Do(n) => refs.extend(collect_workflow_refs(&n.body)),
663 WorkflowNode::Parallel(_) => {} WorkflowNode::Always(n) => refs.extend(collect_workflow_refs(&n.body)),
665 WorkflowNode::ForEach(n) => refs.push(n.workflow.clone()),
667 }
668 }
669 refs
670}
671
672pub(crate) fn collect_schema_refs(nodes: &[WorkflowNode]) -> Vec<String> {
674 let mut refs = Vec::new();
675 for node in nodes {
676 match node {
677 WorkflowNode::Call(n) => {
678 if let Some(ref s) = n.output {
679 refs.push(s.clone());
680 }
681 }
682 WorkflowNode::Do(n) => {
683 if let Some(ref s) = n.output {
684 refs.push(s.clone());
685 }
686 refs.extend(collect_schema_refs(&n.body));
687 }
688 WorkflowNode::Parallel(n) => {
689 if let Some(ref s) = n.output {
690 refs.push(s.clone());
691 }
692 refs.extend(n.call_outputs.values().cloned());
693 }
694 _ => {
695 if let Some(body) = node.body() {
696 refs.extend(collect_schema_refs(body));
697 }
698 }
699 }
700 }
701 refs
702}
703
704pub(crate) fn collect_as_identities(nodes: &[WorkflowNode]) -> Vec<String> {
706 let mut names = Vec::new();
707 for node in nodes {
708 match node {
709 WorkflowNode::Call(n) => {
710 if let Some(ref b) = n.as_identity {
711 names.push(b.clone());
712 }
713 }
714 WorkflowNode::CallWorkflow(n) => {
715 if let Some(ref b) = n.as_identity {
716 names.push(b.clone());
717 }
718 }
719 WorkflowNode::Gate(n) => {
720 if let Some(ref b) = n.as_identity {
721 names.push(b.clone());
722 }
723 }
724 WorkflowNode::Script(n) => {
725 if let Some(ref b) = n.as_identity {
726 names.push(b.clone());
727 }
728 }
729 _ => {
730 if let Some(body) = node.body() {
731 names.extend(collect_as_identities(body));
732 }
733 }
734 }
735 }
736 names
737}
738
739pub(crate) fn collect_plugin_dirs(nodes: &[WorkflowNode]) -> Vec<String> {
741 let mut dirs = Vec::new();
742 for node in nodes {
743 match node {
744 WorkflowNode::Call(n) => dirs.extend(n.plugin_dirs.iter().cloned()),
745 _ => {
746 if let Some(body) = node.body() {
747 dirs.extend(collect_plugin_dirs(body));
748 }
749 }
750 }
751 }
752 dirs
753}
754
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    // ------------------------------------------------------------------
    // Fixture builders
    // ------------------------------------------------------------------

    /// A manual-trigger workflow named `test_wf` with the given `body` and an empty `always`.
    fn simple_wf(body: Vec<WorkflowNode>) -> WorkflowDef {
        WorkflowDef {
            name: "test_wf".to_string(),
            title: None,
            description: String::new(),
            trigger: WorkflowTrigger::Manual,
            targets: vec![],
            group: None,
            inputs: vec![],
            body,
            always: vec![],
            source_path: "test.wf".to_string(),
        }
    }

    /// A `CallNode` for `agent` with every optional field empty. The `call_*` helpers below
    /// override single fields via struct-update syntax.
    fn call_node(agent: &str) -> CallNode {
        CallNode {
            agent: AgentRef::Name(agent.to_string()),
            retries: 0,
            on_fail: None,
            output: None,
            with: vec![],
            as_identity: None,
            plugin_dirs: vec![],
            timeout: None,
            max_turns: None,
        }
    }

    fn call(agent: &str) -> WorkflowNode {
        WorkflowNode::Call(call_node(agent))
    }

    fn call_with_output(agent: &str, output: &str) -> WorkflowNode {
        WorkflowNode::Call(CallNode {
            output: Some(output.to_string()),
            ..call_node(agent)
        })
    }

    fn call_with_snippets(agent: &str, snippets: &[&str]) -> WorkflowNode {
        WorkflowNode::Call(CallNode {
            with: snippets.iter().map(|s| s.to_string()).collect(),
            ..call_node(agent)
        })
    }

    fn call_with_plugin_dirs(agent: &str, dirs: &[&str]) -> WorkflowNode {
        WorkflowNode::Call(CallNode {
            plugin_dirs: dirs.iter().map(|s| s.to_string()).collect(),
            ..call_node(agent)
        })
    }

    fn call_with_identity(agent: &str, identity: &str) -> WorkflowNode {
        WorkflowNode::Call(CallNode {
            as_identity: Some(identity.to_string()),
            ..call_node(agent)
        })
    }

    /// A `Parallel` node fanning out to the named agents, with all optional maps empty.
    fn parallel(agents: &[&str]) -> WorkflowNode {
        WorkflowNode::Parallel(ParallelNode {
            fail_fast: true,
            min_success: None,
            calls: agents
                .iter()
                .map(|a| AgentRef::Name(a.to_string()))
                .collect(),
            output: None,
            call_outputs: HashMap::new(),
            with: vec![],
            call_with: HashMap::new(),
            call_if: HashMap::new(),
            call_retries: HashMap::new(),
        })
    }

    fn do_while_node(step: &str, max_iter: u32, body: Vec<WorkflowNode>) -> WorkflowNode {
        WorkflowNode::DoWhile(DoWhileNode {
            step: step.to_string(),
            marker: "done".to_string(),
            max_iterations: max_iter,
            stuck_after: None,
            on_max_iter: OnMaxIter::Fail,
            body,
        })
    }

    fn while_node(step: &str, max_iter: u32, body: Vec<WorkflowNode>) -> WorkflowNode {
        WorkflowNode::While(WhileNode {
            step: step.to_string(),
            marker: "needs_revision".to_string(),
            max_iterations: max_iter,
            stuck_after: None,
            on_max_iter: OnMaxIter::Fail,
            body,
        })
    }

    fn if_node(step: &str, marker: &str, body: Vec<WorkflowNode>) -> WorkflowNode {
        WorkflowNode::If(IfNode {
            condition: Condition::StepMarker {
                step: step.to_string(),
                marker: marker.to_string(),
            },
            body,
        })
    }

    fn call_workflow(name: &str) -> WorkflowNode {
        WorkflowNode::CallWorkflow(CallWorkflowNode {
            workflow: name.to_string(),
            inputs: HashMap::new(),
            retries: 0,
            on_fail: None,
            as_identity: None,
        })
    }

    fn script_node(name: &str, run: &str) -> WorkflowNode {
        WorkflowNode::Script(ScriptNode {
            name: name.to_string(),
            run: run.to_string(),
            env: HashMap::new(),
            timeout: None,
            retries: 0,
            on_fail: None,
            as_identity: None,
        })
    }

    /// A `ForEach` node that runs `workflow` per item, with all optional maps empty.
    fn for_each(workflow: &str) -> WorkflowNode {
        WorkflowNode::ForEach(ForEachNode {
            name: "fan_out".to_string(),
            over: "items".to_string(),
            scope: None,
            filter: HashMap::new(),
            ordered: false,
            on_cycle: OnCycle::Fail,
            max_parallel: 1,
            workflow: workflow.to_string(),
            inputs: HashMap::new(),
            on_child_fail: OnChildFail::Halt,
        })
    }

    // ------------------------------------------------------------------
    // WorkflowDef::display_name
    // ------------------------------------------------------------------

    #[test]
    fn display_name_returns_title_when_set() {
        let mut wf = simple_wf(vec![]);
        wf.title = Some("My Workflow".to_string());
        assert_eq!(wf.display_name(), "My Workflow");
    }

    #[test]
    fn display_name_falls_back_to_name_when_no_title() {
        let wf = simple_wf(vec![]);
        assert_eq!(wf.display_name(), "test_wf");
    }

    // ------------------------------------------------------------------
    // WorkflowDef::total_nodes / top_level_steps
    // ------------------------------------------------------------------

    #[test]
    fn total_nodes_flat_list() {
        let wf = simple_wf(vec![call("a"), call("b"), call("c")]);
        assert_eq!(wf.total_nodes(), 3);
    }

    #[test]
    fn total_nodes_includes_nested_nodes() {
        let nested = if_node("a", "done", vec![call("b"), call("c")]);
        let wf = simple_wf(vec![call("a"), nested]);
        assert_eq!(wf.total_nodes(), 4);
    }

    #[test]
    fn total_nodes_includes_always_block() {
        let mut wf = simple_wf(vec![call("a")]);
        wf.always = vec![call("cleanup")];
        assert_eq!(wf.total_nodes(), 2);
    }

    #[test]
    fn top_level_steps_returns_only_direct_children() {
        let nested = if_node("a", "done", vec![call("b"), call("c")]);
        let wf = simple_wf(vec![call("a"), nested]);
        assert_eq!(wf.top_level_steps(), 2);
    }

    #[test]
    fn top_level_steps_includes_always_block() {
        let mut wf = simple_wf(vec![call("a"), call("b")]);
        wf.always = vec![call("cleanup")];
        assert_eq!(wf.top_level_steps(), 3);
    }

    // ------------------------------------------------------------------
    // WorkflowDef::max_iterations_for_step
    // ------------------------------------------------------------------

    #[test]
    fn max_iterations_for_step_found_in_do_while() {
        let wf = simple_wf(vec![do_while_node("reviewer", 5, vec![call("reviewer")])]);
        assert_eq!(wf.max_iterations_for_step("reviewer"), Some(5));
    }

    #[test]
    fn max_iterations_for_step_found_in_while() {
        let wf = simple_wf(vec![
            call("reviewer"),
            while_node("reviewer", 3, vec![call("fix")]),
        ]);
        assert_eq!(wf.max_iterations_for_step("reviewer"), Some(3));
    }

    #[test]
    fn max_iterations_for_step_not_found_returns_none() {
        let wf = simple_wf(vec![call("a"), call("b")]);
        assert_eq!(wf.max_iterations_for_step("a"), None);
    }

    #[test]
    fn max_iterations_for_step_nested_loop() {
        let inner = do_while_node("inner", 2, vec![call("inner")]);
        let outer = while_node("outer", 10, vec![call("outer"), inner]);
        let wf = simple_wf(vec![outer]);
        assert_eq!(wf.max_iterations_for_step("inner"), Some(2));
        assert_eq!(wf.max_iterations_for_step("outer"), Some(10));
    }

    #[test]
    fn max_iterations_for_step_searches_always_block() {
        let mut wf = simple_wf(vec![call("a")]);
        wf.always = vec![while_node("cleanup", 4, vec![call("fix")])];
        assert_eq!(wf.max_iterations_for_step("cleanup"), Some(4));
    }

    // ------------------------------------------------------------------
    // WorkflowNode::body / count_nodes
    // ------------------------------------------------------------------

    #[test]
    fn body_is_none_for_leaf_nodes() {
        assert!(call("a").body().is_none());
        assert!(parallel(&["a"]).body().is_none());
        assert!(for_each("child").body().is_none());
        assert!(script_node("lint", "./lint.sh").body().is_none());
        assert_eq!(
            if_node("a", "done", vec![call("b")]).body().map(|b| b.len()),
            Some(1)
        );
    }

    #[test]
    fn count_nodes_flat_list() {
        let nodes = vec![call("a"), call("b")];
        assert_eq!(count_nodes(&nodes), 2);
    }

    #[test]
    fn count_nodes_parallel_counts_calls() {
        let nodes = vec![parallel(&["a", "b"])];
        // One for the parallel node itself plus one per call.
        assert_eq!(count_nodes(&nodes), 3);
    }

    #[test]
    fn count_nodes_recursive_into_if_body() {
        let nested = if_node("a", "done", vec![call("b"), call("c")]);
        // The `if` node itself plus its two children.
        assert_eq!(count_nodes(&[nested]), 3);
    }

    // ------------------------------------------------------------------
    // Agent references
    // ------------------------------------------------------------------

    #[test]
    fn collect_agent_names_flat_call_nodes() {
        let nodes = vec![call("agent_a"), call("agent_b")];
        let refs = collect_agent_names(&nodes);
        let names: Vec<&str> = refs.iter().map(|r| r.label()).collect();
        assert!(names.contains(&"agent_a"));
        assert!(names.contains(&"agent_b"));
    }

    #[test]
    fn collect_agent_names_deduplication_when_sorted() {
        let nodes = vec![call("agent_a"), call("agent_a"), call("agent_b")];
        let mut refs = collect_agent_names(&nodes);
        refs.sort();
        refs.dedup();
        assert_eq!(refs.len(), 2);
    }

    #[test]
    fn collect_agent_names_parallel_node() {
        let refs = collect_agent_names(&[parallel(&["par_a", "par_b"])]);
        let names: Vec<&str> = refs.iter().map(|r| r.label()).collect();
        assert!(names.contains(&"par_a"));
        assert!(names.contains(&"par_b"));
    }

    #[test]
    fn collect_agent_names_includes_on_fail_agent_after_primary() {
        let node = WorkflowNode::Call(CallNode {
            on_fail: Some(OnFail::Agent(AgentRef::Name("fallback".to_string()))),
            ..call_node("primary")
        });
        let refs = collect_agent_names(&[node]);
        assert_eq!(
            refs,
            vec![
                AgentRef::Name("primary".to_string()),
                AgentRef::Name("fallback".to_string()),
            ]
        );
    }

    #[test]
    fn collect_all_agent_refs_deduplicates_and_sorts() {
        let wf = simple_wf(vec![call("z_agent"), call("a_agent"), call("z_agent")]);
        let refs = wf.collect_all_agent_refs();
        assert_eq!(refs.len(), 2);
        assert_eq!(refs[0].label(), "a_agent");
        assert_eq!(refs[1].label(), "z_agent");
    }

    // ------------------------------------------------------------------
    // Snippet references
    // ------------------------------------------------------------------

    #[test]
    fn collect_snippet_refs_from_call_with() {
        let nodes = vec![call_with_snippets("agent", &["ctx_a", "ctx_b"])];
        let refs = collect_snippet_refs(&nodes);
        assert!(refs.contains(&"ctx_a".to_string()));
        assert!(refs.contains(&"ctx_b".to_string()));
    }

    #[test]
    fn collect_all_snippet_refs_deduplicates() {
        let wf = simple_wf(vec![
            call_with_snippets("a", &["shared"]),
            call_with_snippets("b", &["shared", "unique"]),
        ]);
        let refs = wf.collect_all_snippet_refs();
        assert_eq!(refs.iter().filter(|s| *s == "shared").count(), 1);
        assert_eq!(refs.len(), 2);
    }

    // ------------------------------------------------------------------
    // Workflow references
    // ------------------------------------------------------------------

    #[test]
    fn collect_workflow_refs_from_call_workflow() {
        let nodes = vec![call_workflow("child_wf"), call_workflow("other_wf")];
        let refs = collect_workflow_refs(&nodes);
        assert!(refs.contains(&"child_wf".to_string()));
        assert!(refs.contains(&"other_wf".to_string()));
    }

    #[test]
    fn collect_workflow_refs_skips_call_nodes() {
        let nodes = vec![call("agent"), call_workflow("child_wf")];
        let refs = collect_workflow_refs(&nodes);
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0], "child_wf");
    }

    #[test]
    fn collect_workflow_refs_recurses_and_includes_for_each() {
        let nodes = vec![
            if_node("a", "done", vec![call_workflow("nested_wf")]),
            for_each("per_item_wf"),
        ];
        let refs = collect_workflow_refs(&nodes);
        assert_eq!(
            refs,
            vec!["nested_wf".to_string(), "per_item_wf".to_string()]
        );
    }

    // ------------------------------------------------------------------
    // Schema references
    // ------------------------------------------------------------------

    #[test]
    fn collect_schema_refs_from_call_output() {
        let nodes = vec![call_with_output("agent", "my_schema")];
        let refs = collect_schema_refs(&nodes);
        assert!(refs.contains(&"my_schema".to_string()));
    }

    #[test]
    fn collect_all_schema_refs_deduplicates() {
        let wf = simple_wf(vec![
            call_with_output("a", "schema"),
            call_with_output("b", "schema"),
        ]);
        let refs = wf.collect_all_schema_refs();
        assert_eq!(refs.iter().filter(|s| *s == "schema").count(), 1);
    }

    // ------------------------------------------------------------------
    // Identities
    // ------------------------------------------------------------------

    #[test]
    fn collect_as_identities_from_call_nodes() {
        let nodes = vec![call_with_identity("agent", "bot-app")];
        let names = collect_as_identities(&nodes);
        assert!(names.contains(&"bot-app".to_string()));
    }

    #[test]
    fn collect_all_as_identities_deduplicates() {
        let wf = simple_wf(vec![
            call_with_identity("a", "bot"),
            call_with_identity("b", "bot"),
        ]);
        let names = wf.collect_all_as_identities();
        assert_eq!(names.iter().filter(|n| *n == "bot").count(), 1);
    }

    // ------------------------------------------------------------------
    // Plugin directories
    // ------------------------------------------------------------------

    #[test]
    fn collect_plugin_dirs_from_call_nodes() {
        let nodes = vec![call_with_plugin_dirs("agent", &["/opt/plugins"])];
        let dirs = collect_plugin_dirs(&nodes);
        assert!(dirs.contains(&"/opt/plugins".to_string()));
    }

    #[test]
    fn collect_all_plugin_dirs_deduplicates() {
        let wf = simple_wf(vec![
            call_with_plugin_dirs("a", &["/opt/shared"]),
            call_with_plugin_dirs("b", &["/opt/shared", "/opt/unique"]),
        ]);
        let dirs = wf.collect_all_plugin_dirs();
        assert_eq!(dirs.iter().filter(|d| *d == "/opt/shared").count(), 1);
        assert_eq!(dirs.len(), 2);
    }

    // ------------------------------------------------------------------
    // AgentRef
    // ------------------------------------------------------------------

    #[test]
    fn agent_ref_name_step_key_returns_name() {
        let r = AgentRef::Name("my_agent".to_string());
        assert_eq!(r.step_key(), "my_agent");
    }

    #[test]
    fn agent_ref_path_step_key_returns_file_stem() {
        let r = AgentRef::Path(".claude/agents/plan.md".to_string());
        assert_eq!(r.step_key(), "plan");
    }

    #[test]
    fn agent_ref_label_returns_inner_string() {
        assert_eq!(AgentRef::Name("foo".to_string()).label(), "foo");
        assert_eq!(
            AgentRef::Path("bar/baz.md".to_string()).label(),
            "bar/baz.md"
        );
    }

    // ------------------------------------------------------------------
    // Serde / Display / FromStr round trips
    // ------------------------------------------------------------------

    #[test]
    fn workflow_trigger_serde_round_trip() {
        for (variant, expected_json) in [
            (WorkflowTrigger::Manual, r#""manual""#),
            (WorkflowTrigger::Pr, r#""pr""#),
            (WorkflowTrigger::Scheduled, r#""scheduled""#),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, expected_json, "serde mismatch for {variant:?}");
            let back: WorkflowTrigger = serde_json::from_str(&json).unwrap();
            assert_eq!(back, variant);
        }
    }

    #[test]
    fn workflow_trigger_display_and_from_str_round_trip() {
        for variant in [
            WorkflowTrigger::Manual,
            WorkflowTrigger::Pr,
            WorkflowTrigger::Scheduled,
        ] {
            let text = variant.to_string();
            let parsed: WorkflowTrigger = text.parse().unwrap();
            assert_eq!(parsed, variant);
        }
        assert_eq!(
            "nightly".parse::<WorkflowTrigger>(),
            Err("unknown trigger: nightly".to_string())
        );
    }

    #[test]
    fn on_max_iter_serde_round_trip() {
        let json = serde_json::to_string(&OnMaxIter::Continue).unwrap();
        assert_eq!(json, r#""continue""#);
        let back: OnMaxIter = serde_json::from_str(&json).unwrap();
        assert_eq!(back, OnMaxIter::Continue);
    }

    #[test]
    fn on_timeout_serde_round_trip() {
        let json = serde_json::to_string(&OnTimeout::Fail).unwrap();
        let back: OnTimeout = serde_json::from_str(&json).unwrap();
        assert_eq!(back, OnTimeout::Fail);
    }

    #[test]
    fn on_child_fail_serde_all_variants() {
        for variant in [
            OnChildFail::Halt,
            OnChildFail::Continue,
            OnChildFail::SkipDependents,
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            let back: OnChildFail = serde_json::from_str(&json).unwrap();
            assert_eq!(back, variant);
        }
    }

    #[test]
    fn on_cycle_serde_all_variants() {
        for variant in [OnCycle::Fail, OnCycle::Warn] {
            let json = serde_json::to_string(&variant).unwrap();
            let back: OnCycle = serde_json::from_str(&json).unwrap();
            assert_eq!(back, variant);
        }
    }

    #[test]
    fn approval_mode_serde_all_variants() {
        for variant in [ApprovalMode::MinApprovals, ApprovalMode::ReviewDecision] {
            let json = serde_json::to_string(&variant).unwrap();
            let back: ApprovalMode = serde_json::from_str(&json).unwrap();
            assert_eq!(back, variant);
        }
    }

    #[test]
    fn on_fail_action_serde_all_variants() {
        for variant in [OnFailAction::Fail, OnFailAction::Continue] {
            let json = serde_json::to_string(&variant).unwrap();
            let back: OnFailAction = serde_json::from_str(&json).unwrap();
            assert_eq!(back, variant);
        }
    }

    #[test]
    fn on_fail_agent_variant_serde() {
        let val = OnFail::Agent(AgentRef::Name("fallback".to_string()));
        let json = serde_json::to_string(&val).unwrap();
        assert!(json.contains("agent"), "got: {json}");
        let back: OnFail = serde_json::from_str(&json).unwrap();
        assert_eq!(back, OnFail::Agent(AgentRef::Name("fallback".to_string())));
    }

    #[test]
    fn on_fail_continue_variant_serde() {
        let json = serde_json::to_string(&OnFail::Continue).unwrap();
        let back: OnFail = serde_json::from_str(&json).unwrap();
        assert_eq!(back, OnFail::Continue);
    }

    #[test]
    fn input_type_serde_all_variants() {
        assert_eq!(
            serde_json::to_string(&InputType::String).unwrap(),
            r#""string""#
        );
        assert_eq!(
            serde_json::to_string(&InputType::Boolean).unwrap(),
            r#""boolean""#
        );
    }

    // ------------------------------------------------------------------
    // Script nodes
    // ------------------------------------------------------------------

    #[test]
    fn script_node_collect_included_in_total() {
        let wf = simple_wf(vec![script_node("lint", "./scripts/lint.sh")]);
        assert_eq!(wf.total_nodes(), 1);
    }
}