1use std::collections::HashMap;
19
20use crate::step_deps::{DependencyGraph, StepDependency};
21
/// A batch of steps that sit at the same dependency depth.
///
/// `build_schedule` produces one `Wave` per non-empty depth level; values
/// constructed by hand may carry any combination of fields.
#[derive(Debug, Clone)]
pub struct Wave {
    /// Dependency depth shared by every step in this wave (0 = no dependencies).
    pub depth: usize,
    /// Names of the steps in this wave; `build_schedule` stores them sorted.
    pub steps: Vec<String>,
    /// Whether the steps may run concurrently; `build_schedule` sets this
    /// when the wave holds more than one step.
    pub is_parallel: bool,
}
34
35#[derive(Debug, Clone)]
37pub struct Schedule {
38 pub waves: Vec<Wave>,
40 pub total_steps: usize,
42 pub parallel_waves: usize,
44 pub max_parallelism: usize,
46}
47
48impl Schedule {
49 pub fn has_parallelism(&self) -> bool {
51 self.parallel_waves > 0
52 }
53
54 pub fn wave_of(&self, step_name: &str) -> Option<usize> {
56 for (i, wave) in self.waves.iter().enumerate() {
57 if wave.steps.iter().any(|s| s == step_name) {
58 return Some(i);
59 }
60 }
61 None
62 }
63
64 pub fn summary(&self) -> String {
66 if self.waves.is_empty() {
67 return "empty schedule".to_string();
68 }
69 let wave_desc: Vec<String> = self
70 .waves
71 .iter()
72 .map(|w| {
73 if w.is_parallel {
74 format!("[{}]", w.steps.join(" | "))
75 } else {
76 w.steps[0].clone()
77 }
78 })
79 .collect();
80 format!(
81 "{} → {} waves, {} parallel",
82 wave_desc.join(" → "),
83 self.waves.len(),
84 self.parallel_waves,
85 )
86 }
87}
88
89pub fn build_schedule(graph: &DependencyGraph) -> Schedule {
93 if graph.steps.is_empty() {
94 return Schedule {
95 waves: Vec::new(),
96 total_steps: 0,
97 parallel_waves: 0,
98 max_parallelism: 0,
99 };
100 }
101
102 let depths = calculate_depths(&graph.steps);
104
105 let max_depth = depths.values().copied().max().unwrap_or(0);
107 let mut waves: Vec<Wave> = Vec::new();
108
109 for d in 0..=max_depth {
110 let mut steps: Vec<String> = depths
111 .iter()
112 .filter(|(_, &dep)| dep == d)
113 .map(|(name, _)| name.clone())
114 .collect();
115 if steps.is_empty() {
116 continue;
117 }
118 steps.sort();
119 let is_parallel = steps.len() > 1;
120 waves.push(Wave {
121 depth: d,
122 steps,
123 is_parallel,
124 });
125 }
126
127 let total_steps = graph.steps.len();
128 let parallel_waves = waves.iter().filter(|w| w.is_parallel).count();
129 let max_parallelism = waves.iter().map(|w| w.steps.len()).max().unwrap_or(0);
130
131 Schedule {
132 waves,
133 total_steps,
134 parallel_waves,
135 max_parallelism,
136 }
137}
138
139fn calculate_depths(deps: &[StepDependency]) -> HashMap<String, usize> {
141 let dep_map: HashMap<&str, &StepDependency> =
142 deps.iter().map(|d| (d.name.as_str(), d)).collect();
143 let mut cache: HashMap<String, usize> = HashMap::new();
144
145 fn step_depth(
146 name: &str,
147 dep_map: &HashMap<&str, &StepDependency>,
148 cache: &mut HashMap<String, usize>,
149 ) -> usize {
150 if let Some(&cached) = cache.get(name) {
151 return cached;
152 }
153 let d = match dep_map.get(name) {
154 Some(d) => d,
155 None => return 0,
156 };
157 if d.depends_on.is_empty() {
158 cache.insert(name.to_string(), 0);
159 return 0;
160 }
161 let max_child = d
162 .depends_on
163 .iter()
164 .map(|dep| step_depth(dep, dep_map, cache))
165 .max()
166 .unwrap_or(0);
167 let result = max_child + 1;
168 cache.insert(name.to_string(), result);
169 result
170 }
171
172 for d in deps {
173 step_depth(&d.name, &dep_map, &mut cache);
174 }
175
176 cache
177}
178
/// Outcome of running a single step as part of a wave.
#[derive(Debug, Clone)]
pub struct WaveStepResult {
    /// Name of the step this result belongs to.
    pub step_name: String,
    /// Text produced by the step (or a failure description).
    pub output: String,
    /// Whether the step completed successfully.
    pub success: bool,
}
188
189pub fn execute_wave<F>(wave: &Wave, execute_fn: F) -> Vec<WaveStepResult>
196where
197 F: Fn(&str) -> WaveStepResult + Send + Sync,
198{
199 if !wave.is_parallel || wave.steps.len() <= 1 {
200 return wave.steps.iter().map(|s| execute_fn(s)).collect();
202 }
203
204 let mut results: Vec<WaveStepResult> = Vec::with_capacity(wave.steps.len());
206
207 std::thread::scope(|scope| {
208 let handles: Vec<_> = wave
209 .steps
210 .iter()
211 .map(|step_name| {
212 let func = &execute_fn;
213 scope.spawn(move || func(step_name))
214 })
215 .collect();
216
217 for handle in handles {
218 match handle.join() {
219 Ok(result) => results.push(result),
220 Err(_) => results.push(WaveStepResult {
221 step_name: "unknown".to_string(),
222 output: "thread panicked".to_string(),
223 success: false,
224 }),
225 }
226 }
227 });
228
229 results
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::step_deps::{analyze, StepInfo};

    /// Builds a `"step"`-typed [`StepInfo`] with the given name and prompt
    /// and an empty argument.
    fn step(name: &'static str, prompt: &'static str) -> StepInfo {
        StepInfo {
            name: name.into(),
            step_type: "step".into(),
            user_prompt: prompt.into(),
            argument: String::new(),
        }
    }

    /// Analyzes `steps` and builds the schedule for the resulting graph.
    fn schedule_of(steps: Vec<StepInfo>) -> Schedule {
        let graph = analyze(&steps);
        build_schedule(&graph)
    }

    // The scheduling tests below rely on `analyze` turning a `$Name`
    // reference inside a prompt into a dependency on step `Name`.

    #[test]
    fn schedule_empty() {
        let graph = analyze(&[]);
        let sched = build_schedule(&graph);

        assert!(sched.waves.is_empty());
        assert_eq!(sched.total_steps, 0);
        assert_eq!(sched.parallel_waves, 0);
        assert!(!sched.has_parallelism());
    }

    #[test]
    fn schedule_single_step() {
        let sched = schedule_of(vec![step("A", "do A")]);

        assert_eq!(sched.waves.len(), 1);
        assert_eq!(sched.waves[0].steps, vec!["A"]);
        assert!(!sched.waves[0].is_parallel);
        assert_eq!(sched.parallel_waves, 0);
        assert!(!sched.has_parallelism());
    }

    #[test]
    fn schedule_linear_chain() {
        // A <- B <- C
        let sched = schedule_of(vec![
            step("A", "do A"),
            step("B", "use $A"),
            step("C", "use $B"),
        ]);

        assert_eq!(sched.waves.len(), 3);
        assert_eq!(sched.waves[0].steps, vec!["A"]);
        assert_eq!(sched.waves[1].steps, vec!["B"]);
        assert_eq!(sched.waves[2].steps, vec!["C"]);
        assert_eq!(sched.parallel_waves, 0);
        assert!(!sched.has_parallelism());
    }

    #[test]
    fn schedule_diamond_pattern() {
        // A fans out to B and C, which join again in D.
        let sched = schedule_of(vec![
            step("A", "start"),
            step("B", "use $A path1"),
            step("C", "use $A path2"),
            step("D", "combine $B and $C"),
        ]);

        assert_eq!(sched.waves.len(), 3);
        assert_eq!(sched.waves[0].steps, vec!["A"]);
        assert_eq!(sched.waves[1].steps, vec!["B", "C"]);
        assert_eq!(sched.waves[2].steps, vec!["D"]);
        assert!(sched.waves[1].is_parallel);
        assert_eq!(sched.parallel_waves, 1);
        assert_eq!(sched.max_parallelism, 2);
        assert!(sched.has_parallelism());
    }

    #[test]
    fn schedule_all_independent() {
        let sched = schedule_of(vec![
            step("A", "do A"),
            step("B", "do B"),
            step("C", "do C"),
        ]);

        assert_eq!(sched.waves.len(), 1);
        assert_eq!(sched.waves[0].steps, vec!["A", "B", "C"]);
        assert!(sched.waves[0].is_parallel);
        assert_eq!(sched.max_parallelism, 3);
    }

    #[test]
    fn schedule_wide_diamond() {
        // Root fans out to B, C and D, which all feed E.
        let sched = schedule_of(vec![
            step("Root", "start"),
            step("B", "$Root b"),
            step("C", "$Root c"),
            step("D", "$Root d"),
            step("E", "$B $C $D"),
        ]);

        assert_eq!(sched.waves.len(), 3);
        assert_eq!(sched.waves[0].steps, vec!["Root"]);
        assert_eq!(sched.waves[1].steps, vec!["B", "C", "D"]);
        assert!(sched.waves[1].is_parallel);
        assert_eq!(sched.waves[2].steps, vec!["E"]);
        assert_eq!(sched.max_parallelism, 3);
    }

    #[test]
    fn wave_of_lookup() {
        let sched = schedule_of(vec![step("A", "start"), step("B", "$A")]);

        assert_eq!(sched.wave_of("A"), Some(0));
        assert_eq!(sched.wave_of("B"), Some(1));
        assert_eq!(sched.wave_of("Z"), None);
    }

    #[test]
    fn schedule_summary_format() {
        let sched = schedule_of(vec![
            step("A", "start"),
            step("B", "$A b"),
            step("C", "$A c"),
            step("D", "$B $C"),
        ]);
        let summary = sched.summary();

        assert!(summary.contains("A"));
        assert!(summary.contains("B | C"));
        assert!(summary.contains("D"));
        assert!(summary.contains("3 waves"));
        assert!(summary.contains("1 parallel"));
    }

    #[test]
    fn execute_wave_sequential() {
        let wave = Wave {
            depth: 0,
            steps: vec!["A".into()],
            is_parallel: false,
        };

        let results = execute_wave(&wave, |name| WaveStepResult {
            step_name: name.to_string(),
            output: format!("result_{name}"),
            success: true,
        });

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].step_name, "A");
        assert_eq!(results[0].output, "result_A");
    }

    #[test]
    fn execute_wave_parallel() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let wave = Wave {
            depth: 1,
            steps: vec!["B".into(), "C".into(), "D".into()],
            is_parallel: true,
        };

        // Counts how many times the callback was invoked across all threads.
        let counter = AtomicUsize::new(0);

        let results = execute_wave(&wave, |name| {
            counter.fetch_add(1, Ordering::SeqCst);
            std::thread::sleep(std::time::Duration::from_millis(10));
            WaveStepResult {
                step_name: name.to_string(),
                output: format!("done_{name}"),
                success: true,
            }
        });

        assert_eq!(results.len(), 3);
        assert_eq!(counter.load(Ordering::SeqCst), 3);

        let mut names: Vec<String> = results.iter().map(|r| r.step_name.clone()).collect();
        names.sort();
        assert_eq!(names, vec!["B", "C", "D"]);
    }

    #[test]
    fn execute_wave_thread_safety() {
        use std::sync::{Arc, Mutex};

        let wave = Wave {
            depth: 0,
            steps: vec!["X".into(), "Y".into()],
            is_parallel: true,
        };

        // Shared log written from the worker threads.
        let log = Arc::new(Mutex::new(Vec::<String>::new()));

        let results = execute_wave(&wave, |name| {
            log.lock().unwrap().push(name.to_string());
            WaveStepResult {
                step_name: name.to_string(),
                output: "ok".to_string(),
                success: true,
            }
        });

        assert_eq!(results.len(), 2);
        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains(&"X".to_string()));
        assert!(entries.contains(&"Y".to_string()));
    }
}