1use std::collections::{HashMap, HashSet};
13
/// Variable names provided by the runtime rather than produced by a step.
///
/// `analyze` keeps references to these in a step's `all_refs`, but never turns
/// them into dependencies or reports them as unresolved.
const BUILTIN_VARS: &[&str] = &[
    "result",
    "step_name",
    "step_type",
    "flow_name",
    "persona_name",
    "unit_index",
    "step_index",
];

/// Returns `true` when `var` is exactly one of [`BUILTIN_VARS`] (case-sensitive).
fn is_builtin(var: &str) -> bool {
    BUILTIN_VARS.iter().any(|&builtin| builtin == var)
}
29
/// Collects the variable names referenced in `text`.
///
/// Two reference forms are recognised:
///
/// * `${anything}`: everything up to the next `}` is taken verbatim, with no
///   trimming. An empty `${}` is ignored, and a `${` with no closing brace is
///   not a reference.
/// * `$ident`: an ASCII letter or `_` followed by any run of ASCII letters,
///   digits or `_`.
///
/// A `$` followed by anything else (a digit, another `$`, a non-ASCII
/// character, or end of input) is skipped. The result is a set, so repeated
/// references collapse to one entry.
pub fn extract_refs(text: &str) -> HashSet<String> {
    let bytes = text.as_bytes();
    let mut found = HashSet::new();
    // Byte offset where the next search for `$` begins. It always lands just
    // after an ASCII byte or on the first non-identifier byte after ASCII
    // identifier bytes, so it is always a valid char boundary for slicing.
    let mut cursor = 0;

    while let Some(offset) = text[cursor..].find('$') {
        let after = cursor + offset + 1;
        // Unless a reference is consumed below, resume right after this `$`.
        cursor = after;

        match bytes.get(after) {
            Some(b'{') => {
                let body = after + 1;
                if let Some(len) = text[body..].find('}') {
                    let name = &text[body..body + len];
                    if !name.is_empty() {
                        found.insert(name.to_owned());
                    }
                    // Skip past the closing `}`.
                    cursor = body + len + 1;
                }
            }
            Some(&first) if first == b'_' || first.is_ascii_alphabetic() => {
                let end = bytes[after..]
                    .iter()
                    .position(|&b| !(b == b'_' || b.is_ascii_alphanumeric()))
                    .map_or(bytes.len(), |n| after + n);
                found.insert(text[after..end].to_owned());
                cursor = end;
            }
            _ => {}
        }
    }

    found
}
71
72pub fn use_tool_analysis_argument(
95 base: &str,
96 named_args: &[(String, String, String)],
97 step_names: &HashSet<&str>,
98) -> String {
99 let mut arg = base.to_string();
100 for (_name, value, kind) in named_args {
101 if kind == "reference" {
102 let dep = value.strip_suffix(".output").unwrap_or(value);
103 if step_names.contains(dep) {
104 arg.push_str(" ${");
105 arg.push_str(dep);
106 arg.push('}');
107 }
108 } else {
109 for r in extract_refs(value) {
110 if step_names.contains(r.as_str()) {
111 arg.push_str(" ${");
112 arg.push_str(&r);
113 arg.push('}');
114 }
115 }
116 }
117 }
118 arg
119}
120
/// Input description of one flow step, as consumed by [`analyze`].
#[derive(Debug, Clone)]
pub struct StepInfo {
    /// Step name. Other steps depend on this step when their text references
    /// it as `$name` or `${name}`.
    pub name: String,
    /// Step kind (the tests use `"step"` and `"use_tool"`); copied verbatim
    /// into the resulting [`StepDependency`].
    pub step_type: String,
    /// Prompt text that is scanned for variable references.
    pub user_prompt: String,
    /// Extra text scanned for references when non-empty, e.g. the string built
    /// by [`use_tool_analysis_argument`] for a `use_tool` step.
    pub argument: String,
}
132
/// Dependency information computed by [`analyze`] for a single step.
#[derive(Debug, Clone)]
pub struct StepDependency {
    /// Step name, copied from [`StepInfo::name`].
    pub name: String,
    /// Step kind, copied from [`StepInfo::step_type`].
    pub step_type: String,
    /// Names of steps in the same flow that this step references (sorted,
    /// no duplicates).
    pub depends_on: Vec<String>,
    /// Every reference found in the prompt and argument, sorted, including
    /// built-in variables and names that match no step.
    pub all_refs: Vec<String>,
    /// The subset of `all_refs` naming a step in the same flow, sorted.
    pub step_refs: Vec<String>,
    /// `true` when `depends_on` is empty.
    pub is_root: bool,
}
151
/// Result of [`analyze`]: per-step dependencies plus graph-wide summaries.
#[derive(Debug)]
pub struct DependencyGraph {
    /// One entry per input step, in input order.
    pub steps: Vec<StepDependency>,
    /// Groups of two or more steps that share the same dependency depth.
    pub parallel_groups: Vec<Vec<String>>,
    /// `(step name, reference)` pairs for non-built-in references that match
    /// no step in the flow.
    pub unresolved_refs: Vec<(String, String)>,
    /// Length, in edges, of the longest dependency chain (0 when no step has
    /// dependencies).
    pub max_depth: usize,
}
163
164pub fn analyze(steps: &[StepInfo]) -> DependencyGraph {
168 let step_names: HashSet<&str> = steps.iter().map(|s| s.name.as_str()).collect();
170
171 let mut deps: Vec<StepDependency> = Vec::new();
173 let mut unresolved: Vec<(String, String)> = Vec::new();
174
175 for step in steps {
176 let mut all_refs: HashSet<String> = extract_refs(&step.user_prompt);
178 if !step.argument.is_empty() {
179 all_refs.extend(extract_refs(&step.argument));
180 }
181
182 let mut step_refs: Vec<String> = Vec::new();
183 let mut depends_on: Vec<String> = Vec::new();
184
185 for r in &all_refs {
186 if is_builtin(r) {
187 continue;
188 }
189 if step_names.contains(r.as_str()) {
190 step_refs.push(r.clone());
191 depends_on.push(r.clone());
192 } else {
193 unresolved.push((step.name.clone(), r.clone()));
194 }
195 }
196
197 depends_on.sort();
198 depends_on.dedup();
199 step_refs.sort();
200
201 let mut all_refs_sorted: Vec<String> = all_refs.into_iter().collect();
202 all_refs_sorted.sort();
203
204 deps.push(StepDependency {
205 name: step.name.clone(),
206 step_type: step.step_type.clone(),
207 is_root: depends_on.is_empty(),
208 depends_on,
209 all_refs: all_refs_sorted,
210 step_refs,
211 });
212 }
213
214 let parallel_groups = find_parallel_groups(&deps);
216
217 let max_depth = calculate_max_depth(&deps);
219
220 DependencyGraph {
221 steps: deps,
222 parallel_groups,
223 unresolved_refs: unresolved,
224 max_depth,
225 }
226}
227
228fn find_parallel_groups(deps: &[StepDependency]) -> Vec<Vec<String>> {
231 let dep_map: HashMap<&str, &StepDependency> =
233 deps.iter().map(|d| (d.name.as_str(), d)).collect();
234
235 let mut depth_cache: HashMap<String, usize> = HashMap::new();
237 fn step_depth(
238 name: &str,
239 dep_map: &HashMap<&str, &StepDependency>,
240 cache: &mut HashMap<String, usize>,
241 ) -> usize {
242 if let Some(&cached) = cache.get(name) {
243 return cached;
244 }
245 let d = match dep_map.get(name) {
246 Some(d) => d,
247 None => return 0,
248 };
249 if d.depends_on.is_empty() {
250 cache.insert(name.to_string(), 0);
251 return 0;
252 }
253 let max_child = d
254 .depends_on
255 .iter()
256 .map(|dep| step_depth(dep, dep_map, cache))
257 .max()
258 .unwrap_or(0);
259 let result = max_child + 1;
260 cache.insert(name.to_string(), result);
261 result
262 }
263
264 for d in deps {
265 step_depth(&d.name, &dep_map, &mut depth_cache);
266 }
267
268 let mut by_depth: HashMap<usize, Vec<String>> = HashMap::new();
270 for d in deps {
271 let depth = depth_cache.get(&d.name).copied().unwrap_or(0);
272 by_depth.entry(depth).or_default().push(d.name.clone());
273 }
274
275 let mut groups: Vec<Vec<String>> = by_depth
277 .into_values()
278 .filter(|g| g.len() > 1)
279 .collect();
280 groups.sort_by_key(|g| g[0].clone());
281 groups
282}
283
284fn calculate_max_depth(deps: &[StepDependency]) -> usize {
286 let dep_map: HashMap<&str, &StepDependency> =
287 deps.iter().map(|d| (d.name.as_str(), d)).collect();
288
289 fn depth(
290 name: &str,
291 dep_map: &HashMap<&str, &StepDependency>,
292 cache: &mut HashMap<String, usize>,
293 ) -> usize {
294 if let Some(&cached) = cache.get(name) {
295 return cached;
296 }
297 let d = match dep_map.get(name) {
298 Some(d) => d,
299 None => return 0,
300 };
301 if d.depends_on.is_empty() {
302 cache.insert(name.to_string(), 0);
303 return 0;
304 }
305 let max_child = d
306 .depends_on
307 .iter()
308 .map(|dep| depth(dep, dep_map, cache))
309 .max()
310 .unwrap_or(0);
311 let result = max_child + 1;
312 cache.insert(name.to_string(), result);
313 result
314 }
315
316 let mut cache = HashMap::new();
317 deps.iter()
318 .map(|d| depth(&d.name, &dep_map, &mut cache))
319 .max()
320 .unwrap_or(0)
321}
322
// Unit tests for reference extraction, `use_tool` argument construction and
// dependency-graph analysis.
#[cfg(test)]
mod tests {
    use super::*;

    // `$name` form: both a built-in (`result`) and a step name are extracted.
    #[test]
    fn extract_refs_dollar_name() {
        let refs = extract_refs("Use $result from $Analyze");
        assert!(refs.contains("result"));
        assert!(refs.contains("Analyze"));
        assert_eq!(refs.len(), 2);
    }

    // `${name}` form.
    #[test]
    fn extract_refs_braced() {
        let refs = extract_refs("Given ${Extract} and ${Validate}");
        assert!(refs.contains("Extract"));
        assert!(refs.contains("Validate"));
        assert_eq!(refs.len(), 2);
    }

    // Both forms mixed in one string.
    #[test]
    fn extract_refs_mixed() {
        let refs = extract_refs("$result is ${Analyze} plus $flow_name");
        assert!(refs.contains("result"));
        assert!(refs.contains("Analyze"));
        assert!(refs.contains("flow_name"));
        assert_eq!(refs.len(), 3);
    }

    #[test]
    fn extract_refs_no_vars() {
        let refs = extract_refs("plain text with no variables");
        assert!(refs.is_empty());
    }

    // A lone `$` at the end of the input is not a reference.
    #[test]
    fn extract_refs_dollar_at_end() {
        let refs = extract_refs("trailing $");
        assert!(refs.is_empty());
    }

    // Steps without references are roots at depth 0 and share one parallel group.
    #[test]
    fn analyze_independent_steps() {
        let steps = vec![
            StepInfo {
                name: "A".into(),
                step_type: "step".into(),
                user_prompt: "Do task A".into(),
                argument: String::new(),
            },
            StepInfo {
                name: "B".into(),
                step_type: "step".into(),
                user_prompt: "Do task B".into(),
                argument: String::new(),
            },
        ];

        let graph = analyze(&steps);
        assert_eq!(graph.steps.len(), 2);
        assert!(graph.steps[0].is_root);
        assert!(graph.steps[1].is_root);
        assert_eq!(graph.max_depth, 0);
        assert_eq!(graph.parallel_groups.len(), 1);
        // The single group holds both steps.
        assert_eq!(graph.parallel_groups[0].len(), 2);
    }

    // Extract -> Analyze -> Report: each step depends on the previous one,
    // depth 2, and no two steps share a depth.
    #[test]
    fn analyze_linear_chain() {
        let steps = vec![
            StepInfo {
                name: "Extract".into(),
                step_type: "step".into(),
                user_prompt: "Extract entities".into(),
                argument: String::new(),
            },
            StepInfo {
                name: "Analyze".into(),
                step_type: "step".into(),
                user_prompt: "Analyze ${Extract}".into(),
                argument: String::new(),
            },
            StepInfo {
                name: "Report".into(),
                step_type: "step".into(),
                user_prompt: "Report on ${Analyze}".into(),
                argument: String::new(),
            },
        ];

        let graph = analyze(&steps);

        // Extract has no references.
        assert!(graph.steps[0].is_root);
        assert!(graph.steps[0].depends_on.is_empty());

        // Analyze references Extract.
        assert!(!graph.steps[1].is_root);
        assert_eq!(graph.steps[1].depends_on, vec!["Extract"]);

        // Report references Analyze.
        assert!(!graph.steps[2].is_root);
        assert_eq!(graph.steps[2].depends_on, vec!["Analyze"]);

        // Longest chain has two edges.
        assert_eq!(graph.max_depth, 2);

        // Every depth holds a single step, so there are no parallel groups.
        assert!(graph.parallel_groups.is_empty());
    }

    // Diamond A -> {B, C} -> D: B and C sit at the same depth and form a group.
    #[test]
    fn analyze_diamond_pattern() {
        let steps = vec![
            StepInfo {
                name: "A".into(),
                step_type: "step".into(),
                user_prompt: "Start".into(),
                argument: String::new(),
            },
            StepInfo {
                name: "B".into(),
                step_type: "step".into(),
                user_prompt: "Process ${A} path B".into(),
                argument: String::new(),
            },
            StepInfo {
                name: "C".into(),
                step_type: "step".into(),
                user_prompt: "Process ${A} path C".into(),
                argument: String::new(),
            },
            StepInfo {
                name: "D".into(),
                step_type: "step".into(),
                user_prompt: "Merge ${B} and ${C}".into(),
                argument: String::new(),
            },
        ];

        let graph = analyze(&steps);

        assert!(graph.steps[0].is_root);
        assert_eq!(graph.steps[1].depends_on, vec!["A"]);
        assert_eq!(graph.steps[2].depends_on, vec!["A"]);
        // depends_on is sorted.
        assert_eq!(graph.steps[3].depends_on, vec!["B", "C"]);
        assert!(!graph.parallel_groups.is_empty());
        let has_bc_group = graph.parallel_groups.iter().any(|g| {
            g.len() == 2 && g.contains(&"B".to_string()) && g.contains(&"C".to_string())
        });
        assert!(has_bc_group);

        // A (0) -> B/C (1) -> D (2).
        assert_eq!(graph.max_depth, 2);
    }

    // Built-ins stay in all_refs but never become dependencies or step refs.
    #[test]
    fn analyze_builtin_vars_excluded() {
        let steps = vec![
            StepInfo {
                name: "S1".into(),
                step_type: "step".into(),
                user_prompt: "Current step is $step_name in $flow_name".into(),
                argument: String::new(),
            },
        ];

        let graph = analyze(&steps);
        assert!(graph.steps[0].is_root);
        assert!(graph.steps[0].depends_on.is_empty());
        // all_refs records every reference, built-ins included.
        assert!(graph.steps[0].all_refs.contains(&"step_name".to_string()));
        assert!(graph.steps[0].all_refs.contains(&"flow_name".to_string()));
        // ...but they are not step references.
        assert!(graph.steps[0].step_refs.is_empty());
    }

    // A reference to no known step is reported as (step, ref).
    #[test]
    fn analyze_unresolved_refs() {
        let steps = vec![
            StepInfo {
                name: "S1".into(),
                step_type: "step".into(),
                user_prompt: "Use ${NonExistent} data".into(),
                argument: String::new(),
            },
        ];

        let graph = analyze(&steps);
        assert_eq!(graph.unresolved_refs.len(), 1);
        assert_eq!(graph.unresolved_refs[0], ("S1".to_string(), "NonExistent".to_string()));
    }

    // References in `argument` create dependencies just like prompt references.
    #[test]
    fn analyze_argument_refs() {
        let steps = vec![
            StepInfo {
                name: "Gather".into(),
                step_type: "step".into(),
                user_prompt: "Gather data".into(),
                argument: String::new(),
            },
            StepInfo {
                name: "Calc".into(),
                step_type: "use_tool".into(),
                user_prompt: "Calculate".into(),
                argument: "${Gather}".into(),
            },
        ];

        let graph = analyze(&steps);
        assert_eq!(graph.steps[1].depends_on, vec!["Gather"]);
    }

    // `Step.output` reference values are reduced to the step name.
    #[test]
    fn use_tool_reference_dotted_creates_dep() {
        let names: HashSet<&str> = ["ExtractUrl"].into_iter().collect();
        let na = vec![("url".into(), "ExtractUrl.output".into(), "reference".into())];
        let arg = use_tool_analysis_argument("", &na, &names);
        assert!(extract_refs(&arg).contains("ExtractUrl"));
    }

    // A bare step name as a reference value also creates a dependency.
    #[test]
    fn use_tool_reference_bare_creates_dep() {
        let names: HashSet<&str> = ["ExtractUrl"].into_iter().collect();
        let na = vec![("url".into(), "ExtractUrl".into(), "reference".into())];
        let arg = use_tool_analysis_argument("", &na, &names);
        assert!(extract_refs(&arg).contains("ExtractUrl"));
    }

    // `${Step}` interpolated inside a literal value creates a dependency.
    #[test]
    fn use_tool_literal_interpolation_creates_dep() {
        let names: HashSet<&str> = ["ExtractCompany"].into_iter().collect();
        let na = vec![("c".into(), "${ExtractCompany}".into(), "literal".into())];
        let arg = use_tool_analysis_argument("", &na, &names);
        assert!(extract_refs(&arg).contains("ExtractCompany"));
    }

    // A reference to something that is not a step (here `user_input`) is dropped.
    #[test]
    fn use_tool_flow_param_reference_is_not_a_step_dep() {
        let names: HashSet<&str> = ["ExtractUrl"].into_iter().collect();
        let na = vec![("src".into(), "user_input".into(), "reference".into())];
        let arg = use_tool_analysis_argument("base", &na, &names);
        assert!(!arg.contains("${user_input}"));
        assert!(!extract_refs(&arg).contains("user_input"));
    }

    // A literal with no references adds nothing.
    #[test]
    fn use_tool_literal_plain_value_no_dep() {
        let names: HashSet<&str> = ["ExtractUrl"].into_iter().collect();
        let na = vec![("mode".into(), "production".into(), "literal".into())];
        let arg = use_tool_analysis_argument("", &na, &names);
        assert!(extract_refs(&arg).is_empty());
    }

    // End to end: a use_tool step fed by two independent extract steps depends
    // on both and is not placed in the first wave of `crate::parallel`'s schedule.
    #[test]
    fn use_tool_multi_arg_orders_after_independent_sources() {
        let names: HashSet<&str> =
            ["ExtractCompany", "ExtractDomain", "GenerateRadar"].into_iter().collect();
        let na = vec![
            ("company".into(), "ExtractCompany.output".into(), "reference".into()),
            ("domain".into(), "ExtractDomain.output".into(), "reference".into()),
        ];
        let arg = use_tool_analysis_argument("", &na, &names);
        let steps = vec![
            StepInfo {
                name: "ExtractCompany".into(),
                step_type: "step".into(),
                user_prompt: "extract the company from ${user_input}".into(),
                argument: String::new(),
            },
            StepInfo {
                name: "ExtractDomain".into(),
                step_type: "step".into(),
                user_prompt: "extract the domain from ${user_input}".into(),
                argument: String::new(),
            },
            StepInfo {
                name: "GenerateRadar".into(),
                step_type: "use_tool".into(),
                user_prompt: String::new(),
                argument: arg,
            },
        ];
        let graph = analyze(&steps);
        let radar = graph.steps.iter().find(|s| s.name == "GenerateRadar").unwrap();
        assert!(radar.depends_on.contains(&"ExtractCompany".to_string()));
        assert!(radar.depends_on.contains(&"ExtractDomain".to_string()));
        assert!(!radar.is_root);

        let sched = crate::parallel::build_schedule(&graph);
        // GenerateRadar must wait for its sources, so it is not in wave 0.
        assert!(!sched.waves[0].steps.contains(&"GenerateRadar".to_string()));
    }

    #[test]
    fn analyze_empty_steps() {
        let graph = analyze(&[]);
        assert!(graph.steps.is_empty());
        assert!(graph.parallel_groups.is_empty());
        assert_eq!(graph.max_depth, 0);
    }

    // Steps with no references at all have max depth 0.
    #[test]
    fn max_depth_flat() {
        let steps = vec![
            StepInfo { name: "A".into(), step_type: "step".into(), user_prompt: "a".into(), argument: String::new() },
            StepInfo { name: "B".into(), step_type: "step".into(), user_prompt: "b".into(), argument: String::new() },
            StepInfo { name: "C".into(), step_type: "step".into(), user_prompt: "c".into(), argument: String::new() },
        ];
        assert_eq!(analyze(&steps).max_depth, 0);
    }
}