Skip to main content

axon/
parallel.rs

1//! Parallel step scheduler — depth-based wave execution with threads.
2//!
3//! Organizes steps into execution waves based on dependency depth (from D15's
4//! `DependencyGraph`). Steps at the same depth have no mutual dependencies
5//! and can safely execute in parallel.
6//!
7//! Execution model:
8//!   Wave 0: all root steps (no dependencies) — execute in parallel
9//!   Wave 1: steps that depend only on wave-0 steps — execute in parallel
10//!   Wave N: steps at depth N — execute after waves 0..N-1 complete
11//!
12//! Between waves, results synchronize back into the shared context so that
13//! the next wave can read them via `${StepName}` interpolation.
14//!
//! Thread model: uses `std::thread::scope` for safe, borrow-friendly parallelism
//! within each wave, so step closures can borrow caller state without `Arc` or a `'static` bound.
17
18use std::collections::HashMap;
19
20use crate::step_deps::{DependencyGraph, StepDependency};
21
22// ── Schedule structures ───────────────────────────────────────────────────
23
/// One wave of a schedule: the group of steps that may run side by side.
#[derive(Clone, Debug)]
pub struct Wave {
    /// Dependency depth this wave represents; root steps sit at depth 0.
    pub depth: usize,
    /// Names of the steps belonging to this wave, in alphabetical order.
    pub steps: Vec<String>,
    /// `true` when the wave holds more than one step, i.e. there is
    /// something to actually run in parallel.
    pub is_parallel: bool,
}
34
35/// Execution schedule — ordered sequence of waves derived from dependency analysis.
36#[derive(Debug, Clone)]
37pub struct Schedule {
38    /// Waves in execution order (depth 0 first).
39    pub waves: Vec<Wave>,
40    /// Total number of steps across all waves.
41    pub total_steps: usize,
42    /// Number of waves with actual parallelism (more than 1 step).
43    pub parallel_waves: usize,
44    /// Maximum parallelism (largest wave size).
45    pub max_parallelism: usize,
46}
47
48impl Schedule {
49    /// Check if the schedule contains any parallelizable waves.
50    pub fn has_parallelism(&self) -> bool {
51        self.parallel_waves > 0
52    }
53
54    /// Get the wave index for a given step name.
55    pub fn wave_of(&self, step_name: &str) -> Option<usize> {
56        for (i, wave) in self.waves.iter().enumerate() {
57            if wave.steps.iter().any(|s| s == step_name) {
58                return Some(i);
59            }
60        }
61        None
62    }
63
64    /// Format the schedule as a compact summary string.
65    pub fn summary(&self) -> String {
66        if self.waves.is_empty() {
67            return "empty schedule".to_string();
68        }
69        let wave_desc: Vec<String> = self
70            .waves
71            .iter()
72            .map(|w| {
73                if w.is_parallel {
74                    format!("[{}]", w.steps.join(" | "))
75                } else {
76                    w.steps[0].clone()
77                }
78            })
79            .collect();
80        format!(
81            "{} → {} waves, {} parallel",
82            wave_desc.join(" → "),
83            self.waves.len(),
84            self.parallel_waves,
85        )
86    }
87}
88
89// ── Schedule builder ──────────────────────────────────────────────────────
90
91/// Build an execution schedule from a dependency graph.
92pub fn build_schedule(graph: &DependencyGraph) -> Schedule {
93    if graph.steps.is_empty() {
94        return Schedule {
95            waves: Vec::new(),
96            total_steps: 0,
97            parallel_waves: 0,
98            max_parallelism: 0,
99        };
100    }
101
102    // Calculate depth for each step
103    let depths = calculate_depths(&graph.steps);
104
105    // Group steps by depth level
106    let max_depth = depths.values().copied().max().unwrap_or(0);
107    let mut waves: Vec<Wave> = Vec::new();
108
109    for d in 0..=max_depth {
110        let mut steps: Vec<String> = depths
111            .iter()
112            .filter(|(_, &dep)| dep == d)
113            .map(|(name, _)| name.clone())
114            .collect();
115        if steps.is_empty() {
116            continue;
117        }
118        steps.sort();
119        let is_parallel = steps.len() > 1;
120        waves.push(Wave {
121            depth: d,
122            steps,
123            is_parallel,
124        });
125    }
126
127    let total_steps = graph.steps.len();
128    let parallel_waves = waves.iter().filter(|w| w.is_parallel).count();
129    let max_parallelism = waves.iter().map(|w| w.steps.len()).max().unwrap_or(0);
130
131    Schedule {
132        waves,
133        total_steps,
134        parallel_waves,
135        max_parallelism,
136    }
137}
138
139/// Calculate depth for each step via transitive dependency resolution.
140fn calculate_depths(deps: &[StepDependency]) -> HashMap<String, usize> {
141    let dep_map: HashMap<&str, &StepDependency> =
142        deps.iter().map(|d| (d.name.as_str(), d)).collect();
143    let mut cache: HashMap<String, usize> = HashMap::new();
144
145    fn step_depth(
146        name: &str,
147        dep_map: &HashMap<&str, &StepDependency>,
148        cache: &mut HashMap<String, usize>,
149    ) -> usize {
150        if let Some(&cached) = cache.get(name) {
151            return cached;
152        }
153        let d = match dep_map.get(name) {
154            Some(d) => d,
155            None => return 0,
156        };
157        if d.depends_on.is_empty() {
158            cache.insert(name.to_string(), 0);
159            return 0;
160        }
161        let max_child = d
162            .depends_on
163            .iter()
164            .map(|dep| step_depth(dep, dep_map, cache))
165            .max()
166            .unwrap_or(0);
167        let result = max_child + 1;
168        cache.insert(name.to_string(), result);
169        result
170    }
171
172    for d in deps {
173        step_depth(&d.name, &dep_map, &mut cache);
174    }
175
176    cache
177}
178
179// ── Wave executor ─────────────────────────────────────────────────────────
180
/// Outcome produced by running one step as part of a wave.
#[derive(Clone, Debug)]
pub struct WaveStepResult {
    /// Name of the step this result belongs to.
    pub step_name: String,
    /// Text produced by the step.
    pub output: String,
    /// Whether the step finished successfully.
    pub success: bool,
}
188
189/// Execute a wave of steps in parallel using scoped threads.
190///
191/// The `execute_fn` closure is called once per step, receiving the step name.
192/// It must be `Send + Sync` since it runs across threads.
193///
194/// Returns results for all steps in the wave (order not guaranteed for parallel).
195pub fn execute_wave<F>(wave: &Wave, execute_fn: F) -> Vec<WaveStepResult>
196where
197    F: Fn(&str) -> WaveStepResult + Send + Sync,
198{
199    if !wave.is_parallel || wave.steps.len() <= 1 {
200        // Sequential execution — no threads needed
201        return wave.steps.iter().map(|s| execute_fn(s)).collect();
202    }
203
204    // Parallel execution with scoped threads
205    let mut results: Vec<WaveStepResult> = Vec::with_capacity(wave.steps.len());
206
207    std::thread::scope(|scope| {
208        let handles: Vec<_> = wave
209            .steps
210            .iter()
211            .map(|step_name| {
212                let func = &execute_fn;
213                scope.spawn(move || func(step_name))
214            })
215            .collect();
216
217        for handle in handles {
218            match handle.join() {
219                Ok(result) => results.push(result),
220                Err(_) => results.push(WaveStepResult {
221                    step_name: "unknown".to_string(),
222                    output: "thread panicked".to_string(),
223                    success: false,
224                }),
225            }
226        }
227    });
228
229    results
230}
231
232// ── Tests ─────────────────────────────────────────────────────────────────
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::step_deps::{analyze, StepInfo};

    /// Fixture: a `"step"`-typed `StepInfo` with the given name and prompt
    /// and an empty argument.
    fn step(name: &'static str, prompt: &'static str) -> StepInfo {
        StepInfo {
            name: name.into(),
            step_type: "step".into(),
            user_prompt: prompt.into(),
            argument: String::new(),
        }
    }

    // ── Schedule building ─────────────────────────────────────────

    #[test]
    fn schedule_empty() {
        let graph = analyze(&[]);
        let sched = build_schedule(&graph);

        assert!(sched.waves.is_empty());
        assert_eq!(sched.total_steps, 0);
        assert_eq!(sched.parallel_waves, 0);
        assert!(!sched.has_parallelism());
    }

    #[test]
    fn schedule_single_step() {
        let steps = vec![step("A", "do A")];
        let graph = analyze(&steps);
        let sched = build_schedule(&graph);

        assert_eq!(sched.waves.len(), 1);
        assert_eq!(sched.waves[0].steps, vec!["A"]);
        assert!(!sched.waves[0].is_parallel);
        assert_eq!(sched.parallel_waves, 0);
        assert!(!sched.has_parallelism());
    }

    #[test]
    fn schedule_linear_chain() {
        // A → B → C: every step waits on the previous one.
        let steps = vec![
            step("A", "do A"),
            step("B", "use $A"),
            step("C", "use $B"),
        ];
        let graph = analyze(&steps);
        let sched = build_schedule(&graph);

        assert_eq!(sched.waves.len(), 3);
        assert_eq!(sched.waves[0].steps, vec!["A"]);
        assert_eq!(sched.waves[1].steps, vec!["B"]);
        assert_eq!(sched.waves[2].steps, vec!["C"]);
        assert_eq!(sched.parallel_waves, 0);
        assert!(!sched.has_parallelism());
    }

    #[test]
    fn schedule_diamond_pattern() {
        // A feeds B and C; D combines B and C.
        let steps = vec![
            step("A", "start"),
            step("B", "use $A path1"),
            step("C", "use $A path2"),
            step("D", "combine $B and $C"),
        ];
        let graph = analyze(&steps);
        let sched = build_schedule(&graph);

        assert_eq!(sched.waves.len(), 3);
        assert_eq!(sched.waves[0].steps, vec!["A"]); // depth 0
        assert_eq!(sched.waves[1].steps, vec!["B", "C"]); // depth 1 — the parallel wave
        assert_eq!(sched.waves[2].steps, vec!["D"]); // depth 2
        assert!(sched.waves[1].is_parallel);
        assert_eq!(sched.parallel_waves, 1);
        assert_eq!(sched.max_parallelism, 2);
        assert!(sched.has_parallelism());
    }

    #[test]
    fn schedule_all_independent() {
        // A, B and C reference nothing, so they share a single wave.
        let steps = vec![
            step("A", "do A"),
            step("B", "do B"),
            step("C", "do C"),
        ];
        let graph = analyze(&steps);
        let sched = build_schedule(&graph);

        assert_eq!(sched.waves.len(), 1);
        assert_eq!(sched.waves[0].steps, vec!["A", "B", "C"]);
        assert!(sched.waves[0].is_parallel);
        assert_eq!(sched.max_parallelism, 3);
    }

    #[test]
    fn schedule_wide_diamond() {
        // Root fans out to B, C and D, which all feed E.
        let steps = vec![
            step("Root", "start"),
            step("B", "$Root b"),
            step("C", "$Root c"),
            step("D", "$Root d"),
            step("E", "$B $C $D"),
        ];
        let graph = analyze(&steps);
        let sched = build_schedule(&graph);

        assert_eq!(sched.waves.len(), 3);
        assert_eq!(sched.waves[0].steps, vec!["Root"]);
        assert_eq!(sched.waves[1].steps, vec!["B", "C", "D"]);
        assert!(sched.waves[1].is_parallel);
        assert_eq!(sched.waves[2].steps, vec!["E"]);
        assert_eq!(sched.max_parallelism, 3);
    }

    // ── Wave of ───────────────────────────────────────────────────

    #[test]
    fn wave_of_lookup() {
        let steps = vec![step("A", "start"), step("B", "$A")];
        let graph = analyze(&steps);
        let sched = build_schedule(&graph);

        assert_eq!(sched.wave_of("A"), Some(0));
        assert_eq!(sched.wave_of("B"), Some(1));
        assert_eq!(sched.wave_of("Z"), None);
    }

    // ── Summary ───────────────────────────────────────────────────

    #[test]
    fn schedule_summary_format() {
        let steps = vec![
            step("A", "start"),
            step("B", "$A b"),
            step("C", "$A c"),
            step("D", "$B $C"),
        ];
        let graph = analyze(&steps);
        let sched = build_schedule(&graph);
        let summary = sched.summary();

        for fragment in ["A", "B | C", "D", "3 waves", "1 parallel"] {
            assert!(summary.contains(fragment));
        }
    }

    // ── Wave execution ────────────────────────────────────────────

    #[test]
    fn execute_wave_sequential() {
        let wave = Wave {
            depth: 0,
            steps: vec!["A".into()],
            is_parallel: false,
        };

        let results = execute_wave(&wave, |name| WaveStepResult {
            step_name: name.to_string(),
            output: format!("result_{name}"),
            success: true,
        });

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].step_name, "A");
        assert_eq!(results[0].output, "result_A");
    }

    #[test]
    fn execute_wave_parallel() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let wave = Wave {
            depth: 1,
            steps: vec!["B".into(), "C".into(), "D".into()],
            is_parallel: true,
        };

        let calls = AtomicUsize::new(0);

        let results = execute_wave(&wave, |name| {
            calls.fetch_add(1, Ordering::SeqCst);
            // Stand-in for real work.
            std::thread::sleep(std::time::Duration::from_millis(10));
            WaveStepResult {
                step_name: name.to_string(),
                output: format!("done_{name}"),
                success: true,
            }
        });

        // Every step ran exactly once.
        assert_eq!(results.len(), 3);
        assert_eq!(calls.load(Ordering::SeqCst), 3);

        // Compare sorted names so the check does not rely on result order.
        let mut names: Vec<String> = results.iter().map(|r| r.step_name.clone()).collect();
        names.sort();
        assert_eq!(names, vec!["B", "C", "D"]);
    }

    #[test]
    fn execute_wave_thread_safety() {
        use std::sync::{Arc, Mutex};

        let wave = Wave {
            depth: 0,
            steps: vec!["X".into(), "Y".into()],
            is_parallel: true,
        };

        let seen = Arc::new(Mutex::new(Vec::<String>::new()));

        let results = execute_wave(&wave, |name| {
            seen.lock().unwrap().push(name.to_string());
            WaveStepResult {
                step_name: name.to_string(),
                output: "ok".to_string(),
                success: true,
            }
        });

        assert_eq!(results.len(), 2);
        let entries = seen.lock().unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains(&"X".to_string()));
        assert!(entries.contains(&"Y".to_string()));
    }
}