Skip to main content

runkon_flow/dsl/
types.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use serde::{Deserialize, Serialize};
5
6// ---------------------------------------------------------------------------
7// AST types
8// ---------------------------------------------------------------------------
9
/// A complete workflow definition parsed from a `.wf` file.
///
/// Fields marked `#[serde(default)]` may be omitted when deserializing and
/// fall back to `None` / an empty `Vec`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowDef {
    /// Workflow identifier; also the fallback returned by `display_name()`.
    pub name: String,
    /// Optional human-readable title, preferred by `display_name()` when set.
    #[serde(default)]
    pub title: Option<String>,
    /// Description text for the workflow.
    pub description: String,
    /// When this workflow should run.
    pub trigger: WorkflowTrigger,
    // NOTE(review): the meaning of `targets` and `group` is not visible in this
    // file — confirm against the parser before documenting them further.
    #[serde(default)]
    pub targets: Vec<String>,
    #[serde(default)]
    pub group: Option<String>,
    /// Declared workflow inputs.
    pub inputs: Vec<InputDecl>,
    /// Main sequence of nodes.
    pub body: Vec<WorkflowNode>,
    /// Nodes of the workflow's `always` block.
    pub always: Vec<WorkflowNode>,
    /// Presumably the path of the `.wf` file this definition was parsed from — TODO confirm.
    pub source_path: String,
}
27
28impl WorkflowDef {
29    /// Returns the human-readable display name for this workflow.
30    /// Falls back to `name` if no `title` is set.
31    pub fn display_name(&self) -> &str {
32        self.title.as_deref().unwrap_or(&self.name)
33    }
34
35    /// Total number of nodes across body and always blocks.
36    pub fn total_nodes(&self) -> usize {
37        count_nodes(&self.body) + count_nodes(&self.always)
38    }
39
40    /// Number of top-level steps (body + always, non-recursive).
41    /// Better for user-facing progress display than `total_nodes()`.
42    pub fn top_level_steps(&self) -> usize {
43        self.body.len() + self.always.len()
44    }
45
46    /// Find the `max_iterations` of the do-while or while loop that owns
47    /// the step with the given name. Returns `None` if the step is not
48    /// inside a loop or the step name is not found.
49    pub fn max_iterations_for_step(&self, step_name: &str) -> Option<u32> {
50        fn search(nodes: &[WorkflowNode], name: &str) -> Option<u32> {
51            for node in nodes {
52                match node {
53                    WorkflowNode::DoWhile(n) => {
54                        if n.step == name {
55                            return Some(n.max_iterations);
56                        }
57                        if let Some(v) = search(&n.body, name) {
58                            return Some(v);
59                        }
60                    }
61                    WorkflowNode::While(n) => {
62                        if n.step == name {
63                            return Some(n.max_iterations);
64                        }
65                        if let Some(v) = search(&n.body, name) {
66                            return Some(v);
67                        }
68                    }
69                    _ => {
70                        if let Some(body) = node.body() {
71                            if let Some(v) = search(body, name) {
72                                return Some(v);
73                            }
74                        }
75                    }
76                }
77            }
78            None
79        }
80        search(&self.body, step_name).or_else(|| search(&self.always, step_name))
81    }
82
83    /// Collect all prompt snippet references across body and always blocks, sorted and deduplicated.
84    pub fn collect_all_snippet_refs(&self) -> Vec<String> {
85        let mut refs = collect_snippet_refs(&self.body);
86        refs.extend(collect_snippet_refs(&self.always));
87        refs.sort();
88        refs.dedup();
89        refs
90    }
91
92    /// Collect all output schema references across body and always blocks, sorted and deduplicated.
93    pub fn collect_all_schema_refs(&self) -> Vec<String> {
94        let mut refs = collect_schema_refs(&self.body);
95        refs.extend(collect_schema_refs(&self.always));
96        refs.sort();
97        refs.dedup();
98        refs
99    }
100
101    /// Collect all agent references across body and always blocks, sorted and deduplicated.
102    pub fn collect_all_agent_refs(&self) -> Vec<AgentRef> {
103        let mut refs = collect_agent_names(&self.body);
104        refs.extend(collect_agent_names(&self.always));
105        refs.sort();
106        refs.dedup();
107        refs
108    }
109
110    /// Collect all as_identity values referenced across body and always blocks, sorted and deduplicated.
111    pub fn collect_all_as_identities(&self) -> Vec<String> {
112        let mut names = collect_as_identities(&self.body);
113        names.extend(collect_as_identities(&self.always));
114        names.sort();
115        names.dedup();
116        names
117    }
118
119    /// Collect all plugin_dirs from call nodes across body and always blocks, sorted and deduplicated.
120    pub fn collect_all_plugin_dirs(&self) -> Vec<String> {
121        let mut dirs = collect_plugin_dirs(&self.body);
122        dirs.extend(collect_plugin_dirs(&self.always));
123        dirs.sort();
124        dirs.dedup();
125        dirs
126    }
127}
128
/// A structured parse warning produced when a `.wf` file fails to load.
///
/// Holds only the file name and the error text; no line/column position is recorded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowWarning {
    /// The filename (e.g. `bad.wf`) that failed to parse.
    pub file: String,
    /// Human-readable description of the parse error.
    pub message: String,
}
137
/// Trigger type for when a workflow should run.
///
/// Serialized in `snake_case` (`"manual"`, `"pr"`, `"scheduled"`). The
/// `Display` and `FromStr` impls use exactly the same spellings.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WorkflowTrigger {
    Manual,
    Pr,
    Scheduled,
}
146
147impl std::fmt::Display for WorkflowTrigger {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        match self {
150            Self::Manual => write!(f, "manual"),
151            Self::Pr => write!(f, "pr"),
152            Self::Scheduled => write!(f, "scheduled"),
153        }
154    }
155}
156
157impl std::str::FromStr for WorkflowTrigger {
158    type Err = String;
159    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
160        match s {
161            "manual" => Ok(Self::Manual),
162            "pr" => Ok(Self::Pr),
163            "scheduled" => Ok(Self::Scheduled),
164            _ => Err(format!("unknown trigger: {s}")),
165        }
166    }
167}
168
/// The type of a workflow input.
///
/// Serialized in `snake_case` (`"string"`, `"boolean"`); `Default` is `String`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum InputType {
    #[default]
    String,
    Boolean,
}
177
/// An input declaration for a workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputDecl {
    /// Input name.
    pub name: String,
    /// Whether the input is declared as required.
    pub required: bool,
    /// Declared default value, stored as a raw string regardless of `input_type`.
    pub default: Option<String>,
    /// Optional description of the input.
    pub description: Option<String>,
    /// Declared value type; falls back to `InputType::String` when absent.
    #[serde(default)]
    pub input_type: InputType,
}
188
/// A node in the workflow execution graph.
///
/// Serialized as an internally tagged enum. The variant name in `snake_case`
/// (e.g. `"call"`, `"call_workflow"`, `"do_while"`, `"for_each"`) is stored
/// under a `"type"` key next to the inner node's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WorkflowNode {
    /// Single agent invocation.
    Call(CallNode),
    /// Sub-workflow invocation.
    CallWorkflow(CallWorkflowNode),
    /// Conditional block.
    If(IfNode),
    /// Negated conditional block.
    Unless(UnlessNode),
    /// `while` loop.
    While(WhileNode),
    /// `do ... while` loop.
    DoWhile(DoWhileNode),
    /// Plain sequential grouping block.
    Do(DoNode),
    /// Block of agent calls (holds `AgentRef`s, not nested nodes).
    Parallel(ParallelNode),
    /// Gate step (e.g. human approval/review or quality gate).
    Gate(GateNode),
    /// `always` block.
    Always(AlwaysNode),
    /// Shell script step (no agent/LLM).
    Script(ScriptNode),
    /// Fan-out of a child workflow over a collection of items.
    ForEach(ForEachNode),
}
206
207impl WorkflowNode {
208    /// Returns the child body slice for block-node variants, or `None` for leaf nodes.
209    pub fn body(&self) -> Option<&[WorkflowNode]> {
210        match self {
211            WorkflowNode::If(n) => Some(&n.body),
212            WorkflowNode::Unless(n) => Some(&n.body),
213            WorkflowNode::While(n) => Some(&n.body),
214            WorkflowNode::DoWhile(n) => Some(&n.body),
215            WorkflowNode::Do(n) => Some(&n.body),
216            WorkflowNode::Always(n) => Some(&n.body),
217            _ => None,
218        }
219    }
220}
221
/// A foreach step node — fans out a child workflow over a collection of items.
///
/// When deserializing, `filter` and `inputs` fall back to empty maps if absent.
/// The child workflow named by `workflow` is also reported by
/// `collect_workflow_refs` so that cycle detection sees it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForEachNode {
    /// Step name used as the key in step_results and resume skip sets.
    pub name: String,
    /// The collection type to fan out over.
    pub over: ForeachOver,
    /// Raw scope key-value map passed to the provider's `parse_scope` method.
    pub scope: Option<HashMap<String, String>>,
    /// Generic filter map (required for workflow_run fan-outs, reserved for repos).
    #[serde(default)]
    pub filter: HashMap<String, String>,
    /// Whether to use dependency-ordered dispatch (tickets only).
    pub ordered: bool,
    /// What to do when a ticket cycle is detected (tickets + ordered only).
    pub on_cycle: OnCycle,
    /// Maximum number of child workflows to run concurrently.
    pub max_parallel: u32,
    /// Name of the child workflow to invoke for each item.
    pub workflow: String,
    /// Input map passed to each child workflow invocation.
    /// Values may contain `{{item.*}}` template references.
    #[serde(default)]
    pub inputs: HashMap<String, String>,
    /// How to handle a child workflow failure.
    pub on_child_fail: OnChildFail,
}
249
/// The collection type for a foreach step — the registered provider name.
///
/// A plain `String` alias, so the type system does not restrict which names are valid.
pub type ForeachOver = String;
252
/// What to do when a child workflow fails.
///
/// Serialized in `snake_case` (`"halt"`, `"continue"`, `"skip_dependents"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnChildFail {
    /// Cancel in-flight runs and fail the step immediately.
    Halt,
    /// Log the failure and keep dispatching remaining items.
    Continue,
    /// Mark the failed item's transitive dependents as skipped (ordered tickets only).
    SkipDependents,
}
264
/// What to do when a ticket dependency cycle is detected.
///
/// Serialized in `snake_case` (`"fail"`, `"warn"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnCycle {
    /// Abort with an error naming the cycle.
    Fail,
    /// Log a warning, break the back-edge, and continue.
    Warn,
}
274
/// A script step node — runs a shell script directly (no agent/LLM).
///
/// When deserializing, `env` falls back to an empty map and `retries` to `0` if absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptNode {
    /// Step name used as the step key in step_results and resume skip sets.
    pub name: String,
    /// Path to the script to run (supports `{{variable}}` substitution).
    /// Resolved in order: worktree dir → repo dir → `~/.claude/skills/`.
    pub run: String,
    /// Environment variable overrides (values support `{{variable}}` substitution).
    #[serde(default)]
    pub env: HashMap<String, String>,
    /// Optional timeout in seconds. If the script does not complete within this
    /// duration it is killed and the step is marked `TimedOut`.
    pub timeout: Option<u64>,
    /// Number of retry attempts after the first failure (0 = no retries).
    #[serde(default)]
    pub retries: u32,
    /// Action to take if all attempts fail.
    pub on_fail: Option<OnFail>,
    /// Named GitHub App bot identity to use for this script (matches `[github.apps.<name>]`).
    /// When set, the resolved installation token is injected as `GH_TOKEN` so the script
    /// uses that bot identity for all `gh` CLI calls.
    pub as_identity: Option<String>,
}
299
/// The action to take when all retries for a `call`, `script`, or `call workflow` step exhaust.
///
/// - `Agent`: invoke a fallback agent (existing behaviour).
/// - `Continue`: skip the step without marking the workflow failed.
///
/// Serialized adjacently tagged: `{"kind": "agent", "value": <AgentRef>}` or
/// `{"kind": "continue"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum OnFail {
    Agent(AgentRef),
    Continue,
}
310
/// Reference to an agent — either a short name or an explicit file path.
///
/// - `Name`: bare identifier (e.g. `plan`) resolved via the search order.
/// - `Path`: quoted string (e.g. `".claude/agents/plan.md"`) resolved directly
///   relative to the repository root.
///
/// Serialized adjacently tagged, e.g. `{"kind": "name", "value": "plan"}`.
/// The derived `Ord` compares the variant first and then the inner string,
/// so every `Name` sorts before every `Path`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum AgentRef {
    Name(String),
    Path(String),
}
322
323impl AgentRef {
324    /// Human-readable label for display and logging (the inner string value).
325    pub fn label(&self) -> &str {
326        match self {
327            Self::Name(s) | Self::Path(s) => s.as_str(),
328        }
329    }
330
331    /// Key used to store and look up results in `step_results`.
332    ///
333    /// - `Name` variants return the name as-is.
334    /// - `Path` variants return the file stem without extension
335    ///   (e.g. `"plan"` from `".claude/agents/plan.md"`), so that `if`/`while`
336    ///   conditions can reference path-based agents by their short name.
337    pub fn step_key(&self) -> String {
338        match self {
339            Self::Name(s) => s.clone(),
340            Self::Path(s) => Path::new(s)
341                .file_stem()
342                .and_then(|stem| stem.to_str())
343                .unwrap_or(s.as_str())
344                .to_string(),
345        }
346    }
347}
348
349impl std::fmt::Display for AgentRef {
350    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351        write!(f, "{}", self.label())
352    }
353}
354
/// A single agent invocation step (`call`).
///
/// When deserializing, `retries`, `with`, `plugin_dirs` and `max_turns` fall
/// back to `0` / empty / `None` if absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallNode {
    /// Agent to invoke.
    pub agent: AgentRef,
    /// Retry count for this call (defaults to 0).
    #[serde(default)]
    pub retries: u32,
    /// Action to take once all retries are exhausted.
    pub on_fail: Option<OnFail>,
    /// Optional output schema reference for structured output.
    pub output: Option<String>,
    /// Prompt snippet references to append to the agent prompt.
    #[serde(default)]
    pub with: Vec<String>,
    /// Named GitHub App bot identity to use for this call (matches `[github.apps.<name>]`).
    pub as_identity: Option<String>,
    /// Per-step plugin directories from the `.wf` file. Merged with repo-level
    /// `extra_plugin_dirs` at execution time to give this agent access to
    /// specialist plugins (e.g. `/usr/local/bsg/agent-architecture/planner`).
    #[serde(default)]
    pub plugin_dirs: Vec<String>,
    /// Optional per-step timeout (e.g. "5m", "30s", "1h"). If the step does not
    /// complete within this duration it is cancelled with `CancellationReason::Timeout`.
    pub timeout: Option<String>,
    /// Optional per-step host-enforced turn cap. Overrides the workflow-level default.
    /// `None` defers to `DEFAULT_MAX_TURNS` applied by the executor.
    #[serde(default)]
    pub max_turns: Option<u32>,
}
381
/// A sub-workflow invocation node.
///
/// When deserializing, `inputs` falls back to an empty map and `retries` to `0` if absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallWorkflowNode {
    /// Name of the workflow to invoke (also collected for cycle detection).
    pub workflow: String,
    /// Input values passed to the child workflow, as raw strings.
    #[serde(default)]
    pub inputs: HashMap<String, String>,
    /// Retry count for this invocation (defaults to 0).
    #[serde(default)]
    pub retries: u32,
    /// Action to take once all retries are exhausted.
    pub on_fail: Option<OnFail>,
    /// Named GitHub App bot identity inherited by child call nodes.
    pub as_identity: Option<String>,
}
394
/// A condition in an `if`/`unless` block.
///
/// Serialized internally tagged under `"kind"`, e.g.
/// `{"kind": "step_marker", "step": "review", "marker": "approved"}` or
/// `{"kind": "bool_input", "input": "dry_run"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum Condition {
    /// References a marker produced by a prior step: `step.marker`.
    StepMarker { step: String, marker: String },
    /// References a boolean input directly: `input_name`.
    BoolInput { input: String },
}
404
/// An `if` block: `body` guarded by `condition` (evaluated outside this file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IfNode {
    pub condition: Condition,
    pub body: Vec<WorkflowNode>,
}
410
/// An `unless` block: presumably runs `body` only when `condition` does not
/// hold. Evaluation lives outside this file — TODO confirm in the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnlessNode {
    pub condition: Condition,
    pub body: Vec<WorkflowNode>,
}
416
/// A `while` loop over `body`.
///
/// `step` and `marker` name the step output that the loop condition inspects;
/// `WorkflowDef::max_iterations_for_step` matches loops by `step`.
/// `max_iterations` is the iteration cap and `on_max_iter` selects the
/// outcome when that cap is reached.
/// NOTE(review): `stuck_after` semantics are not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WhileNode {
    pub step: String,
    pub marker: String,
    pub max_iterations: u32,
    pub stuck_after: Option<u32>,
    pub on_max_iter: OnMaxIter,
    pub body: Vec<WorkflowNode>,
}
426
/// A `do ... while` loop over `body`. Conventionally the body runs before the
/// condition is first checked — confirm in the executor.
///
/// Fields mirror `WhileNode`: `step`/`marker` name the inspected step output,
/// `max_iterations` is the cap and `on_max_iter` selects the outcome at the cap.
/// NOTE(review): `stuck_after` semantics are not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DoWhileNode {
    pub step: String,
    pub marker: String,
    pub max_iterations: u32,
    pub stuck_after: Option<u32>,
    pub on_max_iter: OnMaxIter,
    pub body: Vec<WorkflowNode>,
}
436
/// A plain sequential grouping block (`do { ... }`), with optional `output` and `with`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DoNode {
    /// Optional output schema reference for structured output.
    pub output: Option<String>,
    /// Prompt snippet references applied to all calls inside the block.
    #[serde(default)]
    pub with: Vec<String>,
    /// Nodes inside the block.
    pub body: Vec<WorkflowNode>,
}
447
/// Outcome selected when a `while`/`do-while` loop reaches `max_iterations`.
///
/// Serialized in `snake_case` (`"fail"`, `"continue"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnMaxIter {
    Fail,
    Continue,
}
454
/// A `parallel` block over a list of agent calls (`calls`).
///
/// Per-call settings live in the `call_*` maps, keyed by the call's index in
/// `calls` rendered as a string (e.g. `"0"`). When deserializing, `fail_fast`
/// falls back to `true` and every map/`with` falls back to empty if absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelNode {
    /// Defaults to `true` (via `default_true`) when absent.
    #[serde(default = "default_true")]
    pub fail_fast: bool,
    // NOTE(review): presumably the minimum number of calls that must succeed;
    // the semantics live in the executor — confirm there.
    pub min_success: Option<u32>,
    /// Agents invoked by this block.
    pub calls: Vec<AgentRef>,
    /// Block-level output schema reference (applies to all calls unless overridden).
    pub output: Option<String>,
    /// Per-call output schema overrides, keyed by index (as string) in `calls`.
    /// String keys are used because JSON object keys are always strings and serde_json
    /// cannot coerce them back to integer types on deserialization.
    #[serde(default)]
    pub call_outputs: HashMap<String, String>,
    /// Block-level prompt snippet references (applied to all calls).
    #[serde(default)]
    pub with: Vec<String>,
    /// Per-call prompt snippet additions, keyed by index (as string) in `calls`.
    #[serde(default)]
    pub call_with: HashMap<String, Vec<String>>,
    /// Per-call `if` conditions keyed by index (as string) in `calls`.
    /// Value is (step_name, marker_name). Run the call only if that marker is present.
    #[serde(default)]
    pub call_if: HashMap<String, (String, String)>,
    /// Per-call retry counts keyed by index (as string) in `calls`. 0 = no retries.
    #[serde(default)]
    pub call_retries: HashMap<String, u32>,
}

/// Serde default for `ParallelNode::fail_fast`: always `true`.
fn default_true() -> bool {
    true
}
486
/// How a gate's approval is evaluated; `Default` is `MinApprovals`.
///
/// Serialized in `snake_case` (`"min_approvals"`, `"review_decision"`).
/// NOTE(review): the evaluation logic for each mode lives outside this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ApprovalMode {
    #[default]
    MinApprovals,
    ReviewDecision,
}
494
/// Action taken when a quality gate fails (see `QualityGateConfig::on_fail_action`).
///
/// Serialized in `snake_case` (`"fail"`, `"continue"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnFailAction {
    Fail,
    Continue,
}
501
/// Configuration specific to `quality_gate` nodes.
///
/// Grouped into a single struct so non-quality-gate construction sites need
/// only `quality_gate: None` instead of three separate optional fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityGateConfig {
    /// Step key whose structured output is evaluated.
    pub source: String,
    /// Minimum confidence score (0-100) required to pass.
    pub threshold: u32,
    /// Action when the gate fails (score below threshold).
    /// Defaults to `OnFailAction::Fail` (via `default_on_fail`) when absent.
    #[serde(default = "default_on_fail")]
    pub on_fail_action: OnFailAction,
}

/// Serde default for `QualityGateConfig::on_fail_action`: `OnFailAction::Fail`.
fn default_on_fail() -> OnFailAction {
    OnFailAction::Fail
}
520
/// Specifies the set of options for a multi-select gate.
///
/// - `Static`: a literal key-value map of option strings defined in the workflow file.
/// - `StepRef`: a `"step.field"` reference resolved at runtime from a prior step's
///   structured output (the field must be a JSON object with string values).
///
/// Serialized `untagged`: a JSON object becomes `Static` and a JSON string
/// becomes `StepRef`. Variants are tried in declaration order when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum GateOptions {
    Static(HashMap<String, String>),
    /// Raw `"step.field"` dotted reference — resolved at execution time.
    StepRef(String),
}
533
/// A gate step node (e.g. human approval/review or quality gate).
///
/// `quality_gate` is `#[serde(flatten)]`ed, so its fields (`source`,
/// `threshold`, `on_fail_action`) appear at the same level as the gate's own
/// fields instead of under a nested key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateNode {
    /// Step name of the gate.
    pub name: String,
    /// Gate kind as a raw string; `QUALITY_GATE_TYPE` marks a quality gate.
    pub gate_type: String,
    /// Optional prompt text for the gate.
    pub prompt: Option<String>,
    /// Defaults to `1` (via `default_one`) when absent.
    #[serde(default = "default_one")]
    pub min_approvals: u32,
    /// Defaults to `ApprovalMode::MinApprovals` when absent.
    #[serde(default)]
    pub approval_mode: ApprovalMode,
    /// Timeout in seconds.
    pub timeout_secs: u64,
    /// Outcome selected when the timeout elapses.
    pub on_timeout: OnTimeout,
    /// Named GitHub App bot identity used for `gh` calls inside this gate.
    pub as_identity: Option<String>,
    /// Quality gate-specific configuration. Present only when `gate_type == QUALITY_GATE_TYPE`.
    #[serde(flatten)]
    pub quality_gate: Option<QualityGateConfig>,
    /// Optional multi-select options for human_approval / human_review gates.
    pub options: Option<GateOptions>,
}

/// Serde default for `GateNode::min_approvals`: `1`.
fn default_one() -> u32 {
    1
}

/// `GateNode::gate_type` value that identifies a quality gate.
pub const QUALITY_GATE_TYPE: &str = "quality_gate";
559
/// Outcome selected when a gate's `timeout_secs` elapses.
///
/// Serialized in `snake_case` (`"fail"`, `"continue"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnTimeout {
    Fail,
    Continue,
}
566
/// An `always { ... }` block node wrapping a list of child nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlwaysNode {
    pub body: Vec<WorkflowNode>,
}
571
572// ---------------------------------------------------------------------------
573// Tree-walking helpers
574// ---------------------------------------------------------------------------
575
576/// Count the total number of nodes in a node list (for display).
577pub(crate) fn count_nodes(nodes: &[WorkflowNode]) -> usize {
578    let mut count = 0;
579    for node in nodes {
580        count += 1;
581        match node {
582            WorkflowNode::Parallel(n) => count += n.calls.len(),
583            _ => {
584                if let Some(body) = node.body() {
585                    count += count_nodes(body);
586                }
587            }
588        }
589    }
590    count
591}
592
593/// Collect all agent references in a node tree (for validation before execution).
594pub fn collect_agent_names(nodes: &[WorkflowNode]) -> Vec<AgentRef> {
595    let mut refs = Vec::new();
596    for node in nodes {
597        match node {
598            WorkflowNode::Call(n) => {
599                refs.push(n.agent.clone());
600                if let Some(OnFail::Agent(ref a)) = n.on_fail {
601                    refs.push(a.clone());
602                }
603            }
604            WorkflowNode::CallWorkflow(n) => {
605                if let Some(OnFail::Agent(ref a)) = n.on_fail {
606                    refs.push(a.clone());
607                }
608            }
609            WorkflowNode::Script(n) => {
610                if let Some(OnFail::Agent(ref a)) = n.on_fail {
611                    refs.push(a.clone());
612                }
613            }
614            WorkflowNode::Parallel(n) => refs.extend(n.calls.iter().cloned()),
615            _ => {
616                if let Some(body) = node.body() {
617                    refs.extend(collect_agent_names(body));
618                }
619            }
620        }
621    }
622    refs
623}
624
625/// Collect all prompt snippet references (`with` values) from a node tree.
626pub(crate) fn collect_snippet_refs(nodes: &[WorkflowNode]) -> Vec<String> {
627    let mut refs = Vec::new();
628    for node in nodes {
629        match node {
630            WorkflowNode::Call(n) => refs.extend(n.with.iter().cloned()),
631            WorkflowNode::Parallel(n) => {
632                refs.extend(n.with.iter().cloned());
633                for extra in n.call_with.values() {
634                    refs.extend(extra.iter().cloned());
635                }
636            }
637            WorkflowNode::Do(n) => {
638                refs.extend(n.with.iter().cloned());
639                refs.extend(collect_snippet_refs(&n.body));
640            }
641            _ => {
642                if let Some(body) = node.body() {
643                    refs.extend(collect_snippet_refs(body));
644                }
645            }
646        }
647    }
648    refs
649}
650
651/// Collect all `call workflow` references in a node tree (for cycle detection).
652pub fn collect_workflow_refs(nodes: &[WorkflowNode]) -> Vec<String> {
653    let mut refs = Vec::new();
654    for node in nodes {
655        match node {
656            WorkflowNode::Call(_) | WorkflowNode::Gate(_) | WorkflowNode::Script(_) => {}
657            WorkflowNode::CallWorkflow(n) => refs.push(n.workflow.clone()),
658            WorkflowNode::If(n) => refs.extend(collect_workflow_refs(&n.body)),
659            WorkflowNode::Unless(n) => refs.extend(collect_workflow_refs(&n.body)),
660            WorkflowNode::While(n) => refs.extend(collect_workflow_refs(&n.body)),
661            WorkflowNode::DoWhile(n) => refs.extend(collect_workflow_refs(&n.body)),
662            WorkflowNode::Do(n) => refs.extend(collect_workflow_refs(&n.body)),
663            WorkflowNode::Parallel(_) => {} // parallel only contains agent calls
664            WorkflowNode::Always(n) => refs.extend(collect_workflow_refs(&n.body)),
665            // ForEach references a child workflow — include it for cycle detection
666            WorkflowNode::ForEach(n) => refs.push(n.workflow.clone()),
667        }
668    }
669    refs
670}
671
672/// Collect all output schema references (`output =` values) from a node tree.
673pub(crate) fn collect_schema_refs(nodes: &[WorkflowNode]) -> Vec<String> {
674    let mut refs = Vec::new();
675    for node in nodes {
676        match node {
677            WorkflowNode::Call(n) => {
678                if let Some(ref s) = n.output {
679                    refs.push(s.clone());
680                }
681            }
682            WorkflowNode::Do(n) => {
683                if let Some(ref s) = n.output {
684                    refs.push(s.clone());
685                }
686                refs.extend(collect_schema_refs(&n.body));
687            }
688            WorkflowNode::Parallel(n) => {
689                if let Some(ref s) = n.output {
690                    refs.push(s.clone());
691                }
692                refs.extend(n.call_outputs.values().cloned());
693            }
694            _ => {
695                if let Some(body) = node.body() {
696                    refs.extend(collect_schema_refs(body));
697                }
698            }
699        }
700    }
701    refs
702}
703
704/// Collect all as_identity values from a node tree.
705pub(crate) fn collect_as_identities(nodes: &[WorkflowNode]) -> Vec<String> {
706    let mut names = Vec::new();
707    for node in nodes {
708        match node {
709            WorkflowNode::Call(n) => {
710                if let Some(ref b) = n.as_identity {
711                    names.push(b.clone());
712                }
713            }
714            WorkflowNode::CallWorkflow(n) => {
715                if let Some(ref b) = n.as_identity {
716                    names.push(b.clone());
717                }
718            }
719            WorkflowNode::Gate(n) => {
720                if let Some(ref b) = n.as_identity {
721                    names.push(b.clone());
722                }
723            }
724            WorkflowNode::Script(n) => {
725                if let Some(ref b) = n.as_identity {
726                    names.push(b.clone());
727                }
728            }
729            _ => {
730                if let Some(body) = node.body() {
731                    names.extend(collect_as_identities(body));
732                }
733            }
734        }
735    }
736    names
737}
738
739/// Collect all per-step plugin_dirs from call nodes in a node tree.
740pub(crate) fn collect_plugin_dirs(nodes: &[WorkflowNode]) -> Vec<String> {
741    let mut dirs = Vec::new();
742    for node in nodes {
743        match node {
744            WorkflowNode::Call(n) => dirs.extend(n.plugin_dirs.iter().cloned()),
745            _ => {
746                if let Some(body) = node.body() {
747                    dirs.extend(collect_plugin_dirs(body));
748                }
749            }
750        }
751    }
752    dirs
753}
754
755#[cfg(test)]
756mod tests {
757    use std::collections::HashMap;
758
759    use super::*;
760
761    // ── helpers ──────────────────────────────────────────────────────────────
762
763    fn simple_wf(body: Vec<WorkflowNode>) -> WorkflowDef {
764        WorkflowDef {
765            name: "test_wf".to_string(),
766            title: None,
767            description: String::new(),
768            trigger: WorkflowTrigger::Manual,
769            targets: vec![],
770            group: None,
771            inputs: vec![],
772            body,
773            always: vec![],
774            source_path: "test.wf".to_string(),
775        }
776    }
777
778    fn call(agent: &str) -> WorkflowNode {
779        WorkflowNode::Call(CallNode {
780            agent: AgentRef::Name(agent.to_string()),
781            retries: 0,
782            on_fail: None,
783            output: None,
784            with: vec![],
785            as_identity: None,
786            plugin_dirs: vec![],
787            timeout: None,
788            max_turns: None,
789        })
790    }
791
792    fn call_with_output(agent: &str, output: &str) -> WorkflowNode {
793        WorkflowNode::Call(CallNode {
794            agent: AgentRef::Name(agent.to_string()),
795            output: Some(output.to_string()),
796            retries: 0,
797            on_fail: None,
798            with: vec![],
799            as_identity: None,
800            plugin_dirs: vec![],
801            timeout: None,
802            max_turns: None,
803        })
804    }
805
806    fn call_with_snippets(agent: &str, snippets: &[&str]) -> WorkflowNode {
807        WorkflowNode::Call(CallNode {
808            agent: AgentRef::Name(agent.to_string()),
809            with: snippets.iter().map(|s| s.to_string()).collect(),
810            retries: 0,
811            on_fail: None,
812            output: None,
813            as_identity: None,
814            plugin_dirs: vec![],
815            timeout: None,
816            max_turns: None,
817        })
818    }
819
820    fn call_with_plugin_dirs(agent: &str, dirs: &[&str]) -> WorkflowNode {
821        WorkflowNode::Call(CallNode {
822            agent: AgentRef::Name(agent.to_string()),
823            plugin_dirs: dirs.iter().map(|s| s.to_string()).collect(),
824            retries: 0,
825            on_fail: None,
826            output: None,
827            with: vec![],
828            as_identity: None,
829            timeout: None,
830            max_turns: None,
831        })
832    }
833
834    fn call_with_identity(agent: &str, identity: &str) -> WorkflowNode {
835        WorkflowNode::Call(CallNode {
836            agent: AgentRef::Name(agent.to_string()),
837            as_identity: Some(identity.to_string()),
838            retries: 0,
839            on_fail: None,
840            output: None,
841            with: vec![],
842            plugin_dirs: vec![],
843            timeout: None,
844            max_turns: None,
845        })
846    }
847
848    fn do_while_node(step: &str, max_iter: u32, body: Vec<WorkflowNode>) -> WorkflowNode {
849        WorkflowNode::DoWhile(DoWhileNode {
850            step: step.to_string(),
851            marker: "done".to_string(),
852            max_iterations: max_iter,
853            stuck_after: None,
854            on_max_iter: OnMaxIter::Fail,
855            body,
856        })
857    }
858
859    fn while_node(step: &str, max_iter: u32, body: Vec<WorkflowNode>) -> WorkflowNode {
860        WorkflowNode::While(WhileNode {
861            step: step.to_string(),
862            marker: "needs_revision".to_string(),
863            max_iterations: max_iter,
864            stuck_after: None,
865            on_max_iter: OnMaxIter::Fail,
866            body,
867        })
868    }
869
870    fn if_node(step: &str, marker: &str, body: Vec<WorkflowNode>) -> WorkflowNode {
871        WorkflowNode::If(IfNode {
872            condition: Condition::StepMarker {
873                step: step.to_string(),
874                marker: marker.to_string(),
875            },
876            body,
877        })
878    }
879
880    fn call_workflow(name: &str) -> WorkflowNode {
881        WorkflowNode::CallWorkflow(CallWorkflowNode {
882            workflow: name.to_string(),
883            inputs: HashMap::new(),
884            retries: 0,
885            on_fail: None,
886            as_identity: None,
887        })
888    }
889
890    fn script_node(name: &str, run: &str) -> WorkflowNode {
891        WorkflowNode::Script(ScriptNode {
892            name: name.to_string(),
893            run: run.to_string(),
894            env: HashMap::new(),
895            timeout: None,
896            retries: 0,
897            on_fail: None,
898            as_identity: None,
899        })
900    }
901
902    // ── WorkflowDef::display_name ─────────────────────────────────────────────
903
904    #[test]
905    fn display_name_returns_title_when_set() {
906        let mut wf = simple_wf(vec![]);
907        wf.title = Some("My Workflow".to_string());
908        assert_eq!(wf.display_name(), "My Workflow");
909    }
910
911    #[test]
912    fn display_name_falls_back_to_name_when_no_title() {
913        let wf = simple_wf(vec![]);
914        assert_eq!(wf.display_name(), "test_wf");
915    }
916
917    // ── WorkflowDef::total_nodes ──────────────────────────────────────────────
918
919    #[test]
920    fn total_nodes_flat_list() {
921        let wf = simple_wf(vec![call("a"), call("b"), call("c")]);
922        assert_eq!(wf.total_nodes(), 3);
923    }
924
925    #[test]
926    fn total_nodes_includes_nested_nodes() {
927        let nested = if_node("a", "done", vec![call("b"), call("c")]);
928        let wf = simple_wf(vec![call("a"), nested]);
929        assert_eq!(wf.total_nodes(), 4);
930    }
931
932    #[test]
933    fn total_nodes_includes_always_block() {
934        let mut wf = simple_wf(vec![call("a")]);
935        wf.always = vec![call("cleanup")];
936        assert_eq!(wf.total_nodes(), 2);
937    }
938
939    // ── WorkflowDef::top_level_steps ─────────────────────────────────────────
940
941    #[test]
942    fn top_level_steps_returns_only_direct_children() {
943        let nested = if_node("a", "done", vec![call("b"), call("c")]);
944        let wf = simple_wf(vec![call("a"), nested]);
945        assert_eq!(wf.top_level_steps(), 2);
946    }
947
948    #[test]
949    fn top_level_steps_includes_always_block() {
950        let mut wf = simple_wf(vec![call("a"), call("b")]);
951        wf.always = vec![call("cleanup")];
952        assert_eq!(wf.top_level_steps(), 3);
953    }
954
955    // ── WorkflowDef::max_iterations_for_step ─────────────────────────────────
956
957    #[test]
958    fn max_iterations_for_step_found_in_do_while() {
959        let wf = simple_wf(vec![do_while_node("reviewer", 5, vec![call("reviewer")])]);
960        assert_eq!(wf.max_iterations_for_step("reviewer"), Some(5));
961    }
962
963    #[test]
964    fn max_iterations_for_step_found_in_while() {
965        let wf = simple_wf(vec![
966            call("reviewer"),
967            while_node("reviewer", 3, vec![call("fix")]),
968        ]);
969        assert_eq!(wf.max_iterations_for_step("reviewer"), Some(3));
970    }
971
972    #[test]
973    fn max_iterations_for_step_not_found_returns_none() {
974        let wf = simple_wf(vec![call("a"), call("b")]);
975        assert_eq!(wf.max_iterations_for_step("a"), None);
976    }
977
978    #[test]
979    fn max_iterations_for_step_nested_loop() {
980        let inner = do_while_node("inner", 2, vec![call("inner")]);
981        let outer = while_node("outer", 10, vec![call("outer"), inner]);
982        let wf = simple_wf(vec![outer]);
983        assert_eq!(wf.max_iterations_for_step("inner"), Some(2));
984        assert_eq!(wf.max_iterations_for_step("outer"), Some(10));
985    }
986
987    // ── count_nodes ──────────────────────────────────────────────────────────
988
989    #[test]
990    fn count_nodes_flat_list() {
991        let nodes = vec![call("a"), call("b")];
992        assert_eq!(count_nodes(&nodes), 2);
993    }
994
995    #[test]
996    fn count_nodes_parallel_counts_calls() {
997        let parallel = WorkflowNode::Parallel(ParallelNode {
998            fail_fast: true,
999            min_success: None,
1000            calls: vec![
1001                AgentRef::Name("a".to_string()),
1002                AgentRef::Name("b".to_string()),
1003            ],
1004            output: None,
1005            call_outputs: HashMap::new(),
1006            with: vec![],
1007            call_with: HashMap::new(),
1008            call_if: HashMap::new(),
1009            call_retries: HashMap::new(),
1010        });
1011        let nodes = vec![parallel];
1012        assert_eq!(count_nodes(&nodes), 3); // 1 parallel node + 2 calls
1013    }
1014
1015    #[test]
1016    fn count_nodes_recursive_into_if_body() {
1017        let nested = if_node("a", "done", vec![call("b"), call("c")]);
1018        assert_eq!(count_nodes(&[nested]), 3); // if + 2 body
1019    }
1020
1021    // ── collect_agent_names ───────────────────────────────────────────────────
1022
1023    #[test]
1024    fn collect_agent_names_flat_call_nodes() {
1025        let nodes = vec![call("agent_a"), call("agent_b")];
1026        let refs = collect_agent_names(&nodes);
1027        let names: Vec<&str> = refs.iter().map(|r| r.label()).collect();
1028        assert!(names.contains(&"agent_a"));
1029        assert!(names.contains(&"agent_b"));
1030    }
1031
1032    #[test]
1033    fn collect_agent_names_deduplication_when_sorted() {
1034        let nodes = vec![call("agent_a"), call("agent_a"), call("agent_b")];
1035        let mut refs = collect_agent_names(&nodes);
1036        refs.sort();
1037        refs.dedup();
1038        assert_eq!(refs.len(), 2);
1039    }
1040
1041    #[test]
1042    fn collect_agent_names_parallel_node() {
1043        let parallel = WorkflowNode::Parallel(ParallelNode {
1044            fail_fast: true,
1045            min_success: None,
1046            calls: vec![
1047                AgentRef::Name("par_a".to_string()),
1048                AgentRef::Name("par_b".to_string()),
1049            ],
1050            output: None,
1051            call_outputs: HashMap::new(),
1052            with: vec![],
1053            call_with: HashMap::new(),
1054            call_if: HashMap::new(),
1055            call_retries: HashMap::new(),
1056        });
1057        let refs = collect_agent_names(&[parallel]);
1058        let names: Vec<&str> = refs.iter().map(|r| r.label()).collect();
1059        assert!(names.contains(&"par_a"));
1060        assert!(names.contains(&"par_b"));
1061    }
1062
1063    #[test]
1064    fn collect_all_agent_refs_deduplicates_and_sorts() {
1065        let wf = simple_wf(vec![call("z_agent"), call("a_agent"), call("z_agent")]);
1066        let refs = wf.collect_all_agent_refs();
1067        assert_eq!(refs.len(), 2);
1068        assert_eq!(refs[0].label(), "a_agent");
1069        assert_eq!(refs[1].label(), "z_agent");
1070    }
1071
1072    // ── collect_snippet_refs ──────────────────────────────────────────────────
1073
1074    #[test]
1075    fn collect_snippet_refs_from_call_with() {
1076        let nodes = vec![call_with_snippets("agent", &["ctx_a", "ctx_b"])];
1077        let refs = collect_snippet_refs(&nodes);
1078        assert!(refs.contains(&"ctx_a".to_string()));
1079        assert!(refs.contains(&"ctx_b".to_string()));
1080    }
1081
1082    #[test]
1083    fn collect_all_snippet_refs_deduplicates() {
1084        let wf = simple_wf(vec![
1085            call_with_snippets("a", &["shared"]),
1086            call_with_snippets("b", &["shared", "unique"]),
1087        ]);
1088        let refs = wf.collect_all_snippet_refs();
1089        assert_eq!(refs.iter().filter(|s| *s == "shared").count(), 1);
1090        assert_eq!(refs.len(), 2);
1091    }
1092
1093    // ── collect_workflow_refs ─────────────────────────────────────────────────
1094
1095    #[test]
1096    fn collect_workflow_refs_from_call_workflow() {
1097        let nodes = vec![call_workflow("child_wf"), call_workflow("other_wf")];
1098        let refs = collect_workflow_refs(&nodes);
1099        assert!(refs.contains(&"child_wf".to_string()));
1100        assert!(refs.contains(&"other_wf".to_string()));
1101    }
1102
1103    #[test]
1104    fn collect_workflow_refs_skips_call_nodes() {
1105        let nodes = vec![call("agent"), call_workflow("child_wf")];
1106        let refs = collect_workflow_refs(&nodes);
1107        assert_eq!(refs.len(), 1);
1108        assert_eq!(refs[0], "child_wf");
1109    }
1110
1111    // ── collect_schema_refs ───────────────────────────────────────────────────
1112
1113    #[test]
1114    fn collect_schema_refs_from_call_output() {
1115        let nodes = vec![call_with_output("agent", "my_schema")];
1116        let refs = collect_schema_refs(&nodes);
1117        assert!(refs.contains(&"my_schema".to_string()));
1118    }
1119
1120    #[test]
1121    fn collect_all_schema_refs_deduplicates() {
1122        let wf = simple_wf(vec![
1123            call_with_output("a", "schema"),
1124            call_with_output("b", "schema"),
1125        ]);
1126        let refs = wf.collect_all_schema_refs();
1127        assert_eq!(refs.iter().filter(|s| *s == "schema").count(), 1);
1128    }
1129
1130    // ── collect_as_identities ─────────────────────────────────────────────────
1131
1132    #[test]
1133    fn collect_as_identities_from_call_nodes() {
1134        let nodes = vec![call_with_identity("agent", "bot-app")];
1135        let names = collect_as_identities(&nodes);
1136        assert!(names.contains(&"bot-app".to_string()));
1137    }
1138
1139    #[test]
1140    fn collect_all_as_identities_deduplicates() {
1141        let wf = simple_wf(vec![
1142            call_with_identity("a", "bot"),
1143            call_with_identity("b", "bot"),
1144        ]);
1145        let names = wf.collect_all_as_identities();
1146        assert_eq!(names.iter().filter(|n| *n == "bot").count(), 1);
1147    }
1148
1149    // ── collect_plugin_dirs ───────────────────────────────────────────────────
1150
1151    #[test]
1152    fn collect_plugin_dirs_from_call_nodes() {
1153        let nodes = vec![call_with_plugin_dirs("agent", &["/opt/plugins"])];
1154        let dirs = collect_plugin_dirs(&nodes);
1155        assert!(dirs.contains(&"/opt/plugins".to_string()));
1156    }
1157
1158    #[test]
1159    fn collect_all_plugin_dirs_deduplicates() {
1160        let wf = simple_wf(vec![
1161            call_with_plugin_dirs("a", &["/opt/shared"]),
1162            call_with_plugin_dirs("b", &["/opt/shared", "/opt/unique"]),
1163        ]);
1164        let dirs = wf.collect_all_plugin_dirs();
1165        assert_eq!(dirs.iter().filter(|d| *d == "/opt/shared").count(), 1);
1166        assert_eq!(dirs.len(), 2);
1167    }
1168
1169    // ── AgentRef::step_key ────────────────────────────────────────────────────
1170
1171    #[test]
1172    fn agent_ref_name_step_key_returns_name() {
1173        let r = AgentRef::Name("my_agent".to_string());
1174        assert_eq!(r.step_key(), "my_agent");
1175    }
1176
1177    #[test]
1178    fn agent_ref_path_step_key_returns_file_stem() {
1179        let r = AgentRef::Path(".claude/agents/plan.md".to_string());
1180        assert_eq!(r.step_key(), "plan");
1181    }
1182
1183    #[test]
1184    fn agent_ref_label_returns_inner_string() {
1185        assert_eq!(AgentRef::Name("foo".to_string()).label(), "foo");
1186        assert_eq!(
1187            AgentRef::Path("bar/baz.md".to_string()).label(),
1188            "bar/baz.md"
1189        );
1190    }
1191
1192    // ── WorkflowTrigger serde ─────────────────────────────────────────────────
1193
1194    #[test]
1195    fn workflow_trigger_serde_round_trip() {
1196        for (variant, expected_json) in [
1197            (WorkflowTrigger::Manual, r#""manual""#),
1198            (WorkflowTrigger::Pr, r#""pr""#),
1199            (WorkflowTrigger::Scheduled, r#""scheduled""#),
1200        ] {
1201            let json = serde_json::to_string(&variant).unwrap();
1202            assert_eq!(json, expected_json, "display mismatch for {variant:?}");
1203            let back: WorkflowTrigger = serde_json::from_str(&json).unwrap();
1204            assert_eq!(back, variant);
1205        }
1206    }
1207
1208    // ── Enum serde round-trips ────────────────────────────────────────────────
1209
1210    #[test]
1211    fn on_max_iter_serde_round_trip() {
1212        let json = serde_json::to_string(&OnMaxIter::Continue).unwrap();
1213        assert_eq!(json, r#""continue""#);
1214        let back: OnMaxIter = serde_json::from_str(&json).unwrap();
1215        assert_eq!(back, OnMaxIter::Continue);
1216    }
1217
1218    #[test]
1219    fn on_timeout_serde_round_trip() {
1220        let json = serde_json::to_string(&OnTimeout::Fail).unwrap();
1221        let back: OnTimeout = serde_json::from_str(&json).unwrap();
1222        assert_eq!(back, OnTimeout::Fail);
1223    }
1224
1225    #[test]
1226    fn on_child_fail_serde_all_variants() {
1227        for variant in [
1228            OnChildFail::Halt,
1229            OnChildFail::Continue,
1230            OnChildFail::SkipDependents,
1231        ] {
1232            let json = serde_json::to_string(&variant).unwrap();
1233            let back: OnChildFail = serde_json::from_str(&json).unwrap();
1234            assert_eq!(back, variant);
1235        }
1236    }
1237
1238    #[test]
1239    fn on_cycle_serde_all_variants() {
1240        for variant in [OnCycle::Fail, OnCycle::Warn] {
1241            let json = serde_json::to_string(&variant).unwrap();
1242            let back: OnCycle = serde_json::from_str(&json).unwrap();
1243            assert_eq!(back, variant);
1244        }
1245    }
1246
1247    #[test]
1248    fn approval_mode_serde_all_variants() {
1249        for variant in [ApprovalMode::MinApprovals, ApprovalMode::ReviewDecision] {
1250            let json = serde_json::to_string(&variant).unwrap();
1251            let back: ApprovalMode = serde_json::from_str(&json).unwrap();
1252            assert_eq!(back, variant);
1253        }
1254    }
1255
1256    #[test]
1257    fn on_fail_action_serde_all_variants() {
1258        for variant in [OnFailAction::Fail, OnFailAction::Continue] {
1259            let json = serde_json::to_string(&variant).unwrap();
1260            let back: OnFailAction = serde_json::from_str(&json).unwrap();
1261            assert_eq!(back, variant);
1262        }
1263    }
1264
1265    #[test]
1266    fn on_fail_agent_variant_serde() {
1267        let val = OnFail::Agent(AgentRef::Name("fallback".to_string()));
1268        let json = serde_json::to_string(&val).unwrap();
1269        assert!(json.contains("agent"), "got: {json}");
1270        let back: OnFail = serde_json::from_str(&json).unwrap();
1271        assert_eq!(back, OnFail::Agent(AgentRef::Name("fallback".to_string())));
1272    }
1273
1274    #[test]
1275    fn on_fail_continue_variant_serde() {
1276        let json = serde_json::to_string(&OnFail::Continue).unwrap();
1277        let back: OnFail = serde_json::from_str(&json).unwrap();
1278        assert_eq!(back, OnFail::Continue);
1279    }
1280
1281    #[test]
1282    fn input_type_serde_all_variants() {
1283        assert_eq!(
1284            serde_json::to_string(&InputType::String).unwrap(),
1285            r#""string""#
1286        );
1287        assert_eq!(
1288            serde_json::to_string(&InputType::Boolean).unwrap(),
1289            r#""boolean""#
1290        );
1291    }
1292
1293    // ── Script node helper ────────────────────────────────────────────────────
1294
1295    #[test]
1296    fn script_node_collect_included_in_total() {
1297        let wf = simple_wf(vec![script_node("lint", "./scripts/lint.sh")]);
1298        assert_eq!(wf.total_nodes(), 1);
1299    }
1300}