Skip to main content

axon/
step_deps.rs

//! Step dependency analysis — variable-based dependency graph between steps.
//!
//! Analyzes `$variable` / `${variable}` references in step prompts and
//! arguments to build a dependency graph. This enables:
//!   - Detection of which steps can potentially run in parallel
//!   - Validation that referenced variables are actually produced by a step
//!   - Execution plan visualization with dependency chains
//!
//! Built-in variables ($result, $step_name, $flow_name, etc.) are excluded
//! from dependency analysis because they are injected at runtime rather than
//! produced by steps.
11
12use std::collections::{HashMap, HashSet};
13
// ── Built-in variables (not produced by steps) ─────────────────────────────

/// Variables injected by the runtime rather than produced by any step.
/// A reference to one of these never creates a dependency edge.
const BUILTIN_VARS: &[&str] = &[
    "result",
    "step_name",
    "step_type",
    "flow_name",
    "persona_name",
    "unit_index",
    "step_index",
];

/// Returns `true` when `var` is exactly one of the runtime-injected names in
/// `BUILTIN_VARS` (case-sensitive).
fn is_builtin(var: &str) -> bool {
    BUILTIN_VARS.iter().any(|&builtin| builtin == var)
}
29
// ── Variable extraction ────────────────────────────────────────────────────

/// Collect every variable name referenced in `text`.
///
/// Two forms are recognised:
/// - `${name}` — everything between `${` and the next `}` is taken verbatim
///   (an empty `${}` is ignored; an unterminated `${` is skipped over).
/// - `$name` — `name` starts with an ASCII letter or `_` and continues over
///   ASCII alphanumerics and `_`.
///
/// A `$` followed by anything else (a digit, a space, end of input) is
/// ordinary text. Scanning works on bytes; every delimiter is ASCII, so all
/// slice boundaries fall on UTF-8 character boundaries.
pub fn extract_refs(text: &str) -> HashSet<String> {
    let bytes = text.as_bytes();
    let len = bytes.len();
    let mut found = HashSet::new();
    let mut pos = 0;

    while pos < len {
        // Only a `$` with at least one following byte can start a reference.
        if bytes[pos] != b'$' || pos + 1 >= len {
            pos += 1;
            continue;
        }

        let next = bytes[pos + 1];
        if next == b'{' {
            let body_start = pos + 2;
            match text[body_start..].find('}') {
                Some(offset) => {
                    let name = &text[body_start..body_start + offset];
                    if !name.is_empty() {
                        found.insert(name.to_owned());
                    }
                    // Resume just past the closing brace.
                    pos = body_start + offset + 1;
                }
                None => pos += 1,
            }
        } else if next == b'_' || next.is_ascii_alphabetic() {
            let name_start = pos + 1;
            let name_len = bytes[name_start..]
                .iter()
                .take_while(|&&b| b.is_ascii_alphanumeric() || b == b'_')
                .count();
            let name_end = name_start + name_len;
            found.insert(text[name_start..name_end].to_owned());
            pos = name_end;
        } else {
            pos += 1;
        }
    }

    found
}
71
72/// §Fase 61 — augment a step's dependency-analysis `argument` with the step
73/// references carried by a `use Tool(k = v)` call's keyword arguments.
74///
75/// [`analyze`] only scans `user_prompt` + `argument` for `${name}` references.
76/// A keyword-form tool call carries its values in `named_args`, NOT in the
77/// single-`on <arg>` string, so without this its data-dependencies are
78/// invisible: the scheduler classifies the call as a root and co-schedules it
79/// in the SAME wave as the steps it consumes — whose results are therefore
80/// absent from the call's pre-wave context snapshot, so the value resolves to
81/// empty (§60 reference) or to the literal name (pre-§60). This was the unifying
82/// root cause behind a multi-arg `use Tool` whose source steps are independent.
83///
84/// Each `named_args` entry is `(name, value, value_kind)`:
85/// - `"reference"` — the value names the producing step directly (`Extract` or
86///   `Extract.output`; the `.output` suffix maps to the step-name key, mirroring
87///   the runtime resolver in `exec_context`).
88/// - any other kind (`"literal"`) — the value may interpolate one via `${…}`.
89///
90/// We append the referenced names as `${name}` tokens to `base`, GATED to names
91/// that are real steps (`step_names`) so a bare flow-param reference adds no
92/// spurious dependency. The result feeds [`analyze`] as the step's `argument`;
93/// `extract_refs` then rediscovers them as ordinary dependency edges.
94pub fn use_tool_analysis_argument(
95    base: &str,
96    named_args: &[(String, String, String)],
97    step_names: &HashSet<&str>,
98) -> String {
99    let mut arg = base.to_string();
100    for (_name, value, kind) in named_args {
101        if kind == "reference" {
102            let dep = value.strip_suffix(".output").unwrap_or(value);
103            if step_names.contains(dep) {
104                arg.push_str(" ${");
105                arg.push_str(dep);
106                arg.push('}');
107            }
108        } else {
109            for r in extract_refs(value) {
110                if step_names.contains(r.as_str()) {
111                    arg.push_str(" ${");
112                    arg.push_str(&r);
113                    arg.push('}');
114                }
115            }
116        }
117    }
118    arg
119}
120
// ── Step info for analysis ─────────────────────────────────────────────────

/// Minimal step representation for dependency analysis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StepInfo {
    /// Step name. A `$name` / `${name}` reference to it from another step
    /// becomes a dependency edge.
    pub name: String,
    /// Step kind (e.g. `"step"`, `"use_tool"`); copied into the result as-is.
    pub step_type: String,
    /// Prompt text scanned for variable references.
    pub user_prompt: String,
    /// For tool/memory steps: the argument expression. Also scanned for
    /// references when non-empty.
    pub argument: String,
}

// ── Dependency analysis result ─────────────────────────────────────────────

/// Analysis result for a single step.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StepDependency {
    /// Step name.
    pub name: String,
    /// Step type.
    pub step_type: String,
    /// Names of the steps this step depends on (via variable references);
    /// as produced by [`analyze`], sorted and free of duplicates.
    pub depends_on: Vec<String>,
    /// All variable references found (including builtins), sorted.
    pub all_refs: Vec<String>,
    /// Non-builtin references that name a step in the unit, sorted.
    pub step_refs: Vec<String>,
    /// `true` when `depends_on` is empty, i.e. the step has no step
    /// dependencies and can run first.
    pub is_root: bool,
}

/// Full dependency graph for a unit's steps.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DependencyGraph {
    /// One entry per analysed step, in input order.
    pub steps: Vec<StepDependency>,
    /// Groups of two or more steps that share a dependency depth and can
    /// therefore potentially run in parallel.
    pub parallel_groups: Vec<Vec<String>>,
    /// `(step_name, variable)` pairs for non-builtin references that do not
    /// name any step in the unit (for example flow parameters such as
    /// `user_input`). Position in the step list is not considered.
    pub unresolved_refs: Vec<(String, String)>,
    /// Maximum depth of the dependency chain (0 when every step is a root).
    pub max_depth: usize,
}
163
164// ── Analysis ───────────────────────────────────────────────────────────────
165
166/// Analyze dependencies between steps in a unit.
167pub fn analyze(steps: &[StepInfo]) -> DependencyGraph {
168    // 1. Build the set of step names (these are the "producers")
169    let step_names: HashSet<&str> = steps.iter().map(|s| s.name.as_str()).collect();
170
171    // 2. For each step, extract variable refs and resolve dependencies
172    let mut deps: Vec<StepDependency> = Vec::new();
173    let mut unresolved: Vec<(String, String)> = Vec::new();
174
175    for step in steps {
176        // Scan both user_prompt and argument for references
177        let mut all_refs: HashSet<String> = extract_refs(&step.user_prompt);
178        if !step.argument.is_empty() {
179            all_refs.extend(extract_refs(&step.argument));
180        }
181
182        let mut step_refs: Vec<String> = Vec::new();
183        let mut depends_on: Vec<String> = Vec::new();
184
185        for r in &all_refs {
186            if is_builtin(r) {
187                continue;
188            }
189            if step_names.contains(r.as_str()) {
190                step_refs.push(r.clone());
191                depends_on.push(r.clone());
192            } else {
193                unresolved.push((step.name.clone(), r.clone()));
194            }
195        }
196
197        depends_on.sort();
198        depends_on.dedup();
199        step_refs.sort();
200
201        let mut all_refs_sorted: Vec<String> = all_refs.into_iter().collect();
202        all_refs_sorted.sort();
203
204        deps.push(StepDependency {
205            name: step.name.clone(),
206            step_type: step.step_type.clone(),
207            is_root: depends_on.is_empty(),
208            depends_on,
209            all_refs: all_refs_sorted,
210            step_refs,
211        });
212    }
213
214    // 3. Detect parallel groups (steps with no mutual dependencies)
215    let parallel_groups = find_parallel_groups(&deps);
216
217    // 4. Calculate max depth
218    let max_depth = calculate_max_depth(&deps);
219
220    DependencyGraph {
221        steps: deps,
222        parallel_groups,
223        unresolved_refs: unresolved,
224        max_depth,
225    }
226}
227
228/// Find groups of steps that can potentially execute in parallel.
229/// Steps at the same depth level with no mutual dependencies form a group.
230fn find_parallel_groups(deps: &[StepDependency]) -> Vec<Vec<String>> {
231    // Build transitive dependency sets via depth calculation
232    let dep_map: HashMap<&str, &StepDependency> =
233        deps.iter().map(|d| (d.name.as_str(), d)).collect();
234
235    // Calculate depth for each step
236    let mut depth_cache: HashMap<String, usize> = HashMap::new();
237    fn step_depth(
238        name: &str,
239        dep_map: &HashMap<&str, &StepDependency>,
240        cache: &mut HashMap<String, usize>,
241    ) -> usize {
242        if let Some(&cached) = cache.get(name) {
243            return cached;
244        }
245        let d = match dep_map.get(name) {
246            Some(d) => d,
247            None => return 0,
248        };
249        if d.depends_on.is_empty() {
250            cache.insert(name.to_string(), 0);
251            return 0;
252        }
253        let max_child = d
254            .depends_on
255            .iter()
256            .map(|dep| step_depth(dep, dep_map, cache))
257            .max()
258            .unwrap_or(0);
259        let result = max_child + 1;
260        cache.insert(name.to_string(), result);
261        result
262    }
263
264    for d in deps {
265        step_depth(&d.name, &dep_map, &mut depth_cache);
266    }
267
268    // Group steps by depth level
269    let mut by_depth: HashMap<usize, Vec<String>> = HashMap::new();
270    for d in deps {
271        let depth = depth_cache.get(&d.name).copied().unwrap_or(0);
272        by_depth.entry(depth).or_default().push(d.name.clone());
273    }
274
275    // Return only groups with more than one step (actual parallelism)
276    let mut groups: Vec<Vec<String>> = by_depth
277        .into_values()
278        .filter(|g| g.len() > 1)
279        .collect();
280    groups.sort_by_key(|g| g[0].clone());
281    groups
282}
283
284/// Calculate the maximum dependency chain depth.
285fn calculate_max_depth(deps: &[StepDependency]) -> usize {
286    let dep_map: HashMap<&str, &StepDependency> =
287        deps.iter().map(|d| (d.name.as_str(), d)).collect();
288
289    fn depth(
290        name: &str,
291        dep_map: &HashMap<&str, &StepDependency>,
292        cache: &mut HashMap<String, usize>,
293    ) -> usize {
294        if let Some(&cached) = cache.get(name) {
295            return cached;
296        }
297        let d = match dep_map.get(name) {
298            Some(d) => d,
299            None => return 0,
300        };
301        if d.depends_on.is_empty() {
302            cache.insert(name.to_string(), 0);
303            return 0;
304        }
305        let max_child = d
306            .depends_on
307            .iter()
308            .map(|dep| depth(dep, dep_map, cache))
309            .max()
310            .unwrap_or(0);
311        let result = max_child + 1;
312        cache.insert(name.to_string(), result);
313        result
314    }
315
316    let mut cache = HashMap::new();
317    deps.iter()
318        .map(|d| depth(&d.name, &dep_map, &mut cache))
319        .max()
320        .unwrap_or(0)
321}
322
// ── Tests ──────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Compact `StepInfo` constructor so each test reads as a step table.
    fn step(name: &str, step_type: &str, user_prompt: &str, argument: &str) -> StepInfo {
        StepInfo {
            name: name.to_string(),
            step_type: step_type.to_string(),
            user_prompt: user_prompt.to_string(),
            argument: argument.to_string(),
        }
    }

    /// Owned name set, for comparing against `extract_refs` output.
    fn names(items: &[&str]) -> HashSet<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// `(name, value, value_kind)` triple as carried by `use Tool(k = v)`.
    fn named(name: &str, value: &str, kind: &str) -> (String, String, String) {
        (name.to_string(), value.to_string(), kind.to_string())
    }

    #[test]
    fn extract_refs_dollar_name() {
        assert_eq!(
            extract_refs("Use $result from $Analyze"),
            names(&["result", "Analyze"])
        );
    }

    #[test]
    fn extract_refs_braced() {
        assert_eq!(
            extract_refs("Given ${Extract} and ${Validate}"),
            names(&["Extract", "Validate"])
        );
    }

    #[test]
    fn extract_refs_mixed() {
        assert_eq!(
            extract_refs("$result is ${Analyze} plus $flow_name"),
            names(&["result", "Analyze", "flow_name"])
        );
    }

    #[test]
    fn extract_refs_no_vars() {
        assert!(extract_refs("plain text with no variables").is_empty());
    }

    #[test]
    fn extract_refs_dollar_at_end() {
        assert!(extract_refs("trailing $").is_empty());
    }

    #[test]
    fn analyze_independent_steps() {
        let graph = analyze(&[
            step("A", "step", "Do task A", ""),
            step("B", "step", "Do task B", ""),
        ]);
        assert_eq!(graph.steps.len(), 2);
        assert!(graph.steps.iter().all(|s| s.is_root));
        assert_eq!(graph.max_depth, 0);
        // Both independent → a single parallel group holding both.
        assert_eq!(graph.parallel_groups.len(), 1);
        assert_eq!(graph.parallel_groups[0].len(), 2);
    }

    #[test]
    fn analyze_linear_chain() {
        let graph = analyze(&[
            step("Extract", "step", "Extract entities", ""),
            step("Analyze", "step", "Analyze ${Extract}", ""),
            step("Report", "step", "Report on ${Analyze}", ""),
        ]);
        let (extract, analyzed, report) = (&graph.steps[0], &graph.steps[1], &graph.steps[2]);

        // Extract is the sole root.
        assert!(extract.is_root);
        assert!(extract.depends_on.is_empty());

        // Analyze ← Extract, Report ← Analyze.
        assert!(!analyzed.is_root);
        assert_eq!(analyzed.depends_on, vec!["Extract"]);
        assert!(!report.is_root);
        assert_eq!(report.depends_on, vec!["Analyze"]);

        // Extract → Analyze → Report is two edges long.
        assert_eq!(graph.max_depth, 2);

        // Strictly sequential: nothing can share a level.
        assert!(graph.parallel_groups.is_empty());
    }

    #[test]
    fn analyze_diamond_pattern() {
        // A → B, A → C, B+C → D
        let graph = analyze(&[
            step("A", "step", "Start", ""),
            step("B", "step", "Process ${A} path B", ""),
            step("C", "step", "Process ${A} path C", ""),
            step("D", "step", "Merge ${B} and ${C}", ""),
        ]);

        assert!(graph.steps[0].is_root);
        assert_eq!(graph.steps[1].depends_on, vec!["A"]);
        assert_eq!(graph.steps[2].depends_on, vec!["A"]);
        assert_eq!(graph.steps[3].depends_on, vec!["B", "C"]);

        // B and C sit on the same level and may run together.
        let bc = ["B".to_string(), "C".to_string()];
        assert!(!graph.parallel_groups.is_empty());
        assert!(graph
            .parallel_groups
            .iter()
            .any(|g| g.len() == 2 && bc.iter().all(|n| g.contains(n))));

        // Longest chain: A→B→D (or A→C→D).
        assert_eq!(graph.max_depth, 2);
    }

    #[test]
    fn analyze_builtin_vars_excluded() {
        let graph = analyze(&[step(
            "S1",
            "step",
            "Current step is $step_name in $flow_name",
            "",
        )]);
        let s1 = &graph.steps[0];
        assert!(s1.is_root);
        assert!(s1.depends_on.is_empty());
        // Builtins still show up in `all_refs`…
        for builtin in ["step_name", "flow_name"] {
            assert!(s1.all_refs.iter().any(|r| r == builtin));
        }
        // …but never count as step-produced.
        assert!(s1.step_refs.is_empty());
    }

    #[test]
    fn analyze_unresolved_refs() {
        let graph = analyze(&[step("S1", "step", "Use ${NonExistent} data", "")]);
        assert_eq!(
            graph.unresolved_refs,
            vec![("S1".to_string(), "NonExistent".to_string())]
        );
    }

    #[test]
    fn analyze_argument_refs() {
        let graph = analyze(&[
            step("Gather", "step", "Gather data", ""),
            step("Calc", "use_tool", "Calculate", "${Gather}"),
        ]);
        assert_eq!(graph.steps[1].depends_on, vec!["Gather"]);
    }

    // ── §Fase 61 — use Tool(k = v) named-arg dependencies ───────────────────

    #[test]
    fn use_tool_reference_dotted_creates_dep() {
        let steps: HashSet<&str> = ["ExtractUrl"].into_iter().collect();
        let arg = use_tool_analysis_argument(
            "",
            &[named("url", "ExtractUrl.output", "reference")],
            &steps,
        );
        assert!(extract_refs(&arg).contains("ExtractUrl"));
    }

    #[test]
    fn use_tool_reference_bare_creates_dep() {
        let steps: HashSet<&str> = ["ExtractUrl"].into_iter().collect();
        let arg =
            use_tool_analysis_argument("", &[named("url", "ExtractUrl", "reference")], &steps);
        assert!(extract_refs(&arg).contains("ExtractUrl"));
    }

    #[test]
    fn use_tool_literal_interpolation_creates_dep() {
        let steps: HashSet<&str> = ["ExtractCompany"].into_iter().collect();
        let arg = use_tool_analysis_argument(
            "",
            &[named("c", "${ExtractCompany}", "literal")],
            &steps,
        );
        assert!(extract_refs(&arg).contains("ExtractCompany"));
    }

    #[test]
    fn use_tool_flow_param_reference_is_not_a_step_dep() {
        // `src = user_input` — a flow-param reference is valid (§60.c type-checks
        // it) but is NOT a step, so it must add no synthetic dependency token.
        let steps: HashSet<&str> = ["ExtractUrl"].into_iter().collect();
        let arg = use_tool_analysis_argument(
            "base",
            &[named("src", "user_input", "reference")],
            &steps,
        );
        assert!(!arg.contains("${user_input}"));
        assert!(!extract_refs(&arg).contains("user_input"));
    }

    #[test]
    fn use_tool_literal_plain_value_no_dep() {
        let steps: HashSet<&str> = ["ExtractUrl"].into_iter().collect();
        let arg =
            use_tool_analysis_argument("", &[named("mode", "production", "literal")], &steps);
        assert!(extract_refs(&arg).is_empty());
    }

    #[test]
    fn use_tool_multi_arg_orders_after_independent_sources() {
        // The Kivi repro: two independent extraction steps (roots that depend
        // only on the flow-param) + a `use Tool` consuming both via `.output`.
        // Pre-§61 the call was a root → same wave as its sources → snapshot
        // race. Post-§61 the dependency is visible → a strictly later wave.
        let step_names: HashSet<&str> = ["ExtractCompany", "ExtractDomain", "GenerateRadar"]
            .into_iter()
            .collect();
        let arg = use_tool_analysis_argument(
            "",
            &[
                named("company", "ExtractCompany.output", "reference"),
                named("domain", "ExtractDomain.output", "reference"),
            ],
            &step_names,
        );
        let graph = analyze(&[
            step("ExtractCompany", "step", "extract the company from ${user_input}", ""),
            step("ExtractDomain", "step", "extract the domain from ${user_input}", ""),
            step("GenerateRadar", "use_tool", "", &arg),
        ]);

        let radar = graph
            .steps
            .iter()
            .find(|s| s.name == "GenerateRadar")
            .unwrap();
        for source in ["ExtractCompany", "ExtractDomain"] {
            assert!(radar.depends_on.iter().any(|d| d == source));
        }
        assert!(!radar.is_root);

        // The schedule must NOT place GenerateRadar in wave 0 with the extractors.
        let sched = crate::parallel::build_schedule(&graph);
        assert!(!sched.waves[0].steps.contains(&"GenerateRadar".to_string()));
    }

    #[test]
    fn analyze_empty_steps() {
        let graph = analyze(&[]);
        assert!(graph.steps.is_empty());
        assert!(graph.parallel_groups.is_empty());
        assert_eq!(graph.max_depth, 0);
    }

    #[test]
    fn max_depth_flat() {
        let flat = [
            step("A", "step", "a", ""),
            step("B", "step", "b", ""),
            step("C", "step", "c", ""),
        ];
        assert_eq!(analyze(&flat).max_depth, 0);
    }
}