// File: oxiphysics_core/parallel_orchestrator.rs
#![allow(clippy::needless_range_loop)]
// Copyright 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! Parallel solver orchestration for multi-stage physics pipelines.
//!
//! This module provides a dependency-aware scheduler that organizes solver stages
//! into parallel waves via topological sorting. Stages within the same wave have
//! no mutual dependencies and can conceptually execute in parallel, while stages
//! in later waves depend on earlier ones.
//!
//! # Architecture
//!
//! The orchestrator works in three phases:
//! 1. **Registration** - stages and their dependencies are added via [`ParallelOrchestrator::add_stage`].
//! 2. **Scheduling** - [`ParallelOrchestrator::compute_schedule`] performs a topological sort
//!    to group independent stages into waves (Kahn's algorithm with cycle detection).
//! 3. **Execution** - [`ParallelOrchestrator::execute`] runs each wave sequentially,
//!    executing stages within each wave. Per-stage wall-clock timings are accumulated.
//!
//! # Future: rayon-based parallelism
//!
//! Currently stages within a wave run sequentially. When rayon is added as a
//! workspace dependency, intra-wave parallelism can be enabled behind a feature
//! flag (e.g. `parallel-rayon`) by replacing the sequential loop with
//! `rayon::scope` or `rayon::join`.

#![allow(dead_code)]

use std::fmt;
use std::time::Instant;

// ─── Error ──────────────────────────────────────────────────────────────────

/// Errors produced while scheduling or executing an orchestrated pipeline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OrchestratorError {
    /// The stage graph contains a dependency cycle.
    CycleDetected {
        /// Human-readable description of the cycle.
        description: String,
    },
    /// A dependency referenced a stage index that does not exist.
    InvalidStageIndex {
        /// The invalid index that was referenced.
        index: usize,
        /// The total number of registered stages.
        total_stages: usize,
    },
    /// A scheduled stage index did not fit the stages slice given to `execute`.
    ExecutionIndexOutOfRange {
        /// The invalid index.
        index: usize,
        /// The length of the stages slice provided.
        stages_len: usize,
    },
}

impl fmt::Display for OrchestratorError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::CycleDetected { description } => {
                write!(f, "dependency cycle detected: {description}")
            }
            Self::InvalidStageIndex { index, total_stages } => {
                write!(f, "invalid stage index {index} (total stages: {total_stages})")
            }
            Self::ExecutionIndexOutOfRange { index, stages_len } => {
                write!(
                    f,
                    "execution index {index} out of range (stages slice length: {stages_len})"
                )
            }
        }
    }
}

impl std::error::Error for OrchestratorError {}

// ─── SolverStage trait ──────────────────────────────────────────────────────

/// One stage of a physics solver pipeline.
///
/// A stage performs a slice of the simulation step (e.g. broadphase collision,
/// constraint solving, integration). Each stage reports an estimated cost so
/// the scheduler can load-balance, and must be `Send + Sync` so that future
/// parallel execution remains possible.
pub trait SolverStage: Send + Sync {
    /// Returns the human-readable name of this stage.
    fn name(&self) -> &str;

    /// Advances the stage by `dt` seconds of simulation time.
    fn step(&mut self, dt: f64);

    /// Returns an estimated computational cost (arbitrary units) for load balancing.
    ///
    /// Higher values indicate more expensive stages; the orchestrator may use
    /// this to reorder stages within a wave for better load distribution.
    fn estimated_cost(&self) -> f64;
}

// ─── StageDependency ────────────────────────────────────────────────────────

/// Ordering constraints attached to a single stage.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StageDependency {
    /// The index of the stage this dependency record belongs to.
    pub stage_idx: usize,
    /// Indices of stages that must complete before this stage can run.
    pub depends_on: Vec<usize>,
}

// ─── PipelineSchedule ───────────────────────────────────────────────────────

/// A computed execution schedule produced by topological sorting.
///
/// Stages are grouped into *waves*: every stage in a wave is independent of
/// the others in that wave and may conceptually run in parallel, while the
/// waves themselves must run in order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineSchedule {
    /// Groups of stage indices that can run in parallel within each wave.
    pub waves: Vec<Vec<usize>>,
}

impl PipelineSchedule {
    /// Returns the total number of waves.
    pub fn num_waves(&self) -> usize {
        self.waves.len()
    }

    /// Returns the total number of stages across all waves.
    pub fn num_stages(&self) -> usize {
        self.waves.iter().fold(0, |total, wave| total + wave.len())
    }
}

144// ─── ParallelOrchestrator ───────────────────────────────────────────────────
145
146/// Orchestrates the execution of multiple solver stages respecting dependencies.
147///
148/// # Example
149///
150/// ```no_run
151/// use oxiphysics_core::parallel_orchestrator::{ParallelOrchestrator, SolverStage};
152///
153/// struct SimpleStage { name: String, cost: f64, step_count: u64 }
154///
155/// impl SolverStage for SimpleStage {
156///     fn name(&self) -> &str { &self.name }
157///     fn step(&mut self, _dt: f64) { self.step_count += 1; }
158///     fn estimated_cost(&self) -> f64 { self.cost }
159/// }
160///
161/// let mut orch = ParallelOrchestrator::new();
162/// let a = orch.add_stage("broadphase", &[]);
163/// let b = orch.add_stage("narrowphase", &[a]);
164/// let c = orch.add_stage("solver", &[b]);
165///
166/// let mut stages: Vec<Box<dyn SolverStage>> = vec![
167///     Box::new(SimpleStage { name: "broadphase".into(), cost: 1.0, step_count: 0 }),
168///     Box::new(SimpleStage { name: "narrowphase".into(), cost: 2.0, step_count: 0 }),
169///     Box::new(SimpleStage { name: "solver".into(), cost: 3.0, step_count: 0 }),
170/// ];
171///
172/// orch.execute(&mut stages, 0.016).expect("execution failed");
173/// ```
174#[derive(Debug, Clone)]
175pub struct ParallelOrchestrator {
176    /// Registered stage names (index = stage id).
177    stage_names: Vec<String>,
178    /// Dependency records for each stage.
179    dependencies: Vec<StageDependency>,
180    /// Accumulated wall-clock timings per stage (seconds).
181    timings: Vec<f64>,
182}
183
184impl ParallelOrchestrator {
185    /// Creates a new empty orchestrator.
186    pub fn new() -> Self {
187        Self {
188            stage_names: Vec::new(),
189            dependencies: Vec::new(),
190            timings: Vec::new(),
191        }
192    }
193
194    /// Registers a new stage with the given name and dependency list.
195    ///
196    /// Returns the index of the newly added stage, which can be used as a
197    /// dependency for later stages.
198    ///
199    /// # Arguments
200    /// * `name` - Human-readable stage name.
201    /// * `depends_on` - Indices of stages that must run before this one.
202    pub fn add_stage(&mut self, name: &str, depends_on: &[usize]) -> usize {
203        let idx = self.stage_names.len();
204        self.stage_names.push(name.to_string());
205        self.dependencies.push(StageDependency {
206            stage_idx: idx,
207            depends_on: depends_on.to_vec(),
208        });
209        self.timings.push(0.0);
210        idx
211    }
212
213    /// Returns the number of registered stages.
214    pub fn num_stages(&self) -> usize {
215        self.stage_names.len()
216    }
217
218    /// Returns the registered stage names.
219    pub fn stage_names(&self) -> &[String] {
220        &self.stage_names
221    }
222
223    /// Computes a [`PipelineSchedule`] by topologically sorting stages into waves.
224    ///
225    /// Returns an error if the dependency graph contains a cycle or references
226    /// an invalid stage index.
227    pub fn compute_schedule(&self) -> Result<PipelineSchedule, OrchestratorError> {
228        let waves = topological_sort(self.stage_names.len(), &self.dependencies)?;
229        Ok(PipelineSchedule { waves })
230    }
231
232    /// Executes all registered stages in dependency order.
233    ///
234    /// Stages within the same wave are currently run sequentially. The timings
235    /// for each stage are accumulated across repeated calls to `execute`.
236    ///
237    /// # Errors
238    /// Returns an error if scheduling fails (cycle or invalid index) or if a
239    /// scheduled stage index is out of range for the provided `stages` slice.
240    pub fn execute(
241        &mut self,
242        stages: &mut [Box<dyn SolverStage>],
243        dt: f64,
244    ) -> Result<(), OrchestratorError> {
245        let schedule = self.compute_schedule()?;
246
247        for wave in &schedule.waves {
248            // NOTE: future rayon parallelism would replace this sequential loop
249            // with rayon::scope or par_iter over the wave indices, using unsafe
250            // cell or split_at_mut to grant mutable access to disjoint elements.
251            for &stage_idx in wave {
252                if stage_idx >= stages.len() {
253                    return Err(OrchestratorError::ExecutionIndexOutOfRange {
254                        index: stage_idx,
255                        stages_len: stages.len(),
256                    });
257                }
258                let start = Instant::now();
259                stages[stage_idx].step(dt);
260                let elapsed = start.elapsed().as_secs_f64();
261                if stage_idx < self.timings.len() {
262                    self.timings[stage_idx] += elapsed;
263                }
264            }
265        }
266
267        Ok(())
268    }
269
270    /// Returns per-stage accumulated wall-clock timings in seconds.
271    pub fn timings(&self) -> &[f64] {
272        &self.timings
273    }
274
275    /// Returns the total accumulated wall-clock time across all stages.
276    pub fn total_time(&self) -> f64 {
277        self.timings.iter().sum()
278    }
279
280    /// Resets all accumulated timings to zero.
281    pub fn reset_timings(&mut self) {
282        for t in &mut self.timings {
283            *t = 0.0;
284        }
285    }
286}
287
288impl Default for ParallelOrchestrator {
289    fn default() -> Self {
290        Self::new()
291    }
292}
293
294// ─── Topological sort ───────────────────────────────────────────────────────
295
296/// Performs a topological sort of `n` stages using Kahn's algorithm.
297///
298/// Stages are grouped into waves: each wave contains stages whose dependencies
299/// have all been satisfied by previous waves. This is a BFS-layered variant of
300/// Kahn's algorithm.
301///
302/// # Errors
303/// - [`OrchestratorError::InvalidStageIndex`] if any dependency index is >= `n`.
304/// - [`OrchestratorError::CycleDetected`] if the graph contains a cycle.
305pub fn topological_sort(
306    n: usize,
307    deps: &[StageDependency],
308) -> Result<Vec<Vec<usize>>, OrchestratorError> {
309    if n == 0 {
310        return Ok(Vec::new());
311    }
312
313    // Validate all dependency indices.
314    for dep in deps {
315        if dep.stage_idx >= n {
316            return Err(OrchestratorError::InvalidStageIndex {
317                index: dep.stage_idx,
318                total_stages: n,
319            });
320        }
321        for &d in &dep.depends_on {
322            if d >= n {
323                return Err(OrchestratorError::InvalidStageIndex {
324                    index: d,
325                    total_stages: n,
326                });
327            }
328        }
329    }
330
331    // Build adjacency list and in-degree count.
332    // Edge: dependency -> stage (dependency must come first).
333    let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
334    let mut in_degree: Vec<usize> = vec![0; n];
335
336    for dep in deps {
337        for &d in &dep.depends_on {
338            adjacency[d].push(dep.stage_idx);
339            in_degree[dep.stage_idx] += 1;
340        }
341    }
342
343    // BFS-layered Kahn's algorithm.
344    let mut waves: Vec<Vec<usize>> = Vec::new();
345    let mut current_wave: Vec<usize> = Vec::new();
346
347    // Seed with all stages that have no dependencies.
348    for i in 0..n {
349        if in_degree[i] == 0 {
350            current_wave.push(i);
351        }
352    }
353
354    let mut processed = 0usize;
355
356    while !current_wave.is_empty() {
357        // Sort the wave for deterministic output.
358        current_wave.sort_unstable();
359        processed += current_wave.len();
360
361        let mut next_wave: Vec<usize> = Vec::new();
362        for &stage in &current_wave {
363            for &neighbor in &adjacency[stage] {
364                in_degree[neighbor] -= 1;
365                if in_degree[neighbor] == 0 {
366                    next_wave.push(neighbor);
367                }
368            }
369        }
370
371        waves.push(std::mem::take(&mut current_wave));
372        current_wave = next_wave;
373    }
374
375    if processed != n {
376        // Some stages were not processed => cycle exists.
377        let remaining: Vec<usize> = (0..n).filter(|&i| in_degree[i] > 0).collect();
378        let names: Vec<String> = remaining
379            .iter()
380            .filter_map(|&i| {
381                deps.iter()
382                    .find(|d| d.stage_idx == i)
383                    .map(|_| format!("stage {i}"))
384            })
385            .collect();
386        let description = if names.is_empty() {
387            format!("cycle involving {remaining:?}")
388        } else {
389            format!("cycle involving: {}", names.join(", "))
390        };
391        return Err(OrchestratorError::CycleDetected { description });
392    }
393
394    Ok(waves)
395}
396
// ─── Tests ──────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// A simple test stage that tracks how many times `step` was called
    /// and records the order of execution into a shared log.
    struct TestStage {
        stage_name: String,
        cost: f64,
        call_count: u64,
        // Shared across all stages in a test so relative ordering is observable.
        execution_log: Arc<std::sync::Mutex<Vec<String>>>,
    }

    impl TestStage {
        fn new(name: &str, cost: f64, log: Arc<std::sync::Mutex<Vec<String>>>) -> Self {
            Self {
                stage_name: name.to_string(),
                cost,
                call_count: 0,
                execution_log: log,
            }
        }
    }

    impl SolverStage for TestStage {
        fn name(&self) -> &str {
            &self.stage_name
        }

        fn step(&mut self, _dt: f64) {
            self.call_count += 1;
            if let Ok(mut log) = self.execution_log.lock() {
                log.push(self.stage_name.clone());
            }
        }

        fn estimated_cost(&self) -> f64 {
            self.cost
        }
    }

    /// A stage that does a small busy-wait to produce measurable timing.
    struct TimedStage {
        stage_name: String,
        spin_iters: u64,
    }

    impl TimedStage {
        fn new(name: &str, spin_iters: u64) -> Self {
            Self {
                stage_name: name.to_string(),
                spin_iters,
            }
        }
    }

    impl SolverStage for TimedStage {
        fn name(&self) -> &str {
            &self.stage_name
        }

        fn step(&mut self, _dt: f64) {
            // Busy-spin to burn some CPU time.
            let mut acc = 0u64;
            for i in 0..self.spin_iters {
                acc = acc.wrapping_add(i);
            }
            // Prevent the optimizer from eliding the loop.
            std::hint::black_box(acc);
        }

        fn estimated_cost(&self) -> f64 {
            self.spin_iters as f64
        }
    }

    // ── Linear pipeline: A -> B -> C ────────────────────────────────────

    #[test]
    fn test_linear_pipeline() {
        let mut orch = ParallelOrchestrator::new();
        let a = orch.add_stage("A", &[]);
        let b = orch.add_stage("B", &[a]);
        let _c = orch.add_stage("C", &[b]);

        let schedule = orch.compute_schedule().expect("scheduling should succeed");
        assert_eq!(schedule.waves.len(), 3, "linear pipeline needs 3 waves");
        assert_eq!(schedule.waves[0], vec![0]);
        assert_eq!(schedule.waves[1], vec![1]);
        assert_eq!(schedule.waves[2], vec![2]);

        // Verify execution order.
        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stages: Vec<Box<dyn SolverStage>> = vec![
            Box::new(TestStage::new("A", 1.0, Arc::clone(&log))),
            Box::new(TestStage::new("B", 1.0, Arc::clone(&log))),
            Box::new(TestStage::new("C", 1.0, Arc::clone(&log))),
        ];

        orch.execute(&mut stages, 0.016)
            .expect("execute should succeed");

        let recorded = log.lock().expect("lock should not be poisoned");
        assert_eq!(&*recorded, &["A", "B", "C"]);
    }

    // ── Diamond dependency: A -> B, A -> C, B -> D, C -> D ──────────────

    #[test]
    fn test_diamond_dependency() {
        let mut orch = ParallelOrchestrator::new();
        let a = orch.add_stage("A", &[]);
        let b = orch.add_stage("B", &[a]);
        let c = orch.add_stage("C", &[a]);
        let _d = orch.add_stage("D", &[b, c]);

        let schedule = orch.compute_schedule().expect("scheduling should succeed");
        assert_eq!(schedule.waves.len(), 3, "diamond needs 3 waves");
        assert_eq!(schedule.waves[0], vec![0], "wave 0 has A");
        // B and C should be in the same wave (sorted).
        let mut wave1 = schedule.waves[1].clone();
        wave1.sort_unstable();
        assert_eq!(wave1, vec![1, 2], "wave 1 has B and C");
        assert_eq!(schedule.waves[2], vec![3], "wave 2 has D");
    }

    // ── Cycle detection ─────────────────────────────────────────────────

    #[test]
    fn test_cycle_detection() {
        // A -> B -> C -> A (cycle)
        let deps = vec![
            StageDependency {
                stage_idx: 0,
                depends_on: vec![2],
            },
            StageDependency {
                stage_idx: 1,
                depends_on: vec![0],
            },
            StageDependency {
                stage_idx: 2,
                depends_on: vec![1],
            },
        ];

        let result = topological_sort(3, &deps);
        assert!(result.is_err(), "cycle should produce an error");
        match result {
            Err(OrchestratorError::CycleDetected { description }) => {
                assert!(
                    description.contains("cycle"),
                    "error should mention cycle: {description}"
                );
            }
            other => panic!("expected CycleDetected, got {other:?}"),
        }
    }

    // ── Empty pipeline ──────────────────────────────────────────────────

    #[test]
    fn test_empty_pipeline() {
        let orch = ParallelOrchestrator::new();
        let schedule = orch
            .compute_schedule()
            .expect("empty schedule should succeed");
        assert!(schedule.waves.is_empty(), "empty pipeline has no waves");
        assert_eq!(schedule.num_waves(), 0);
        assert_eq!(schedule.num_stages(), 0);
    }

    // ── Single stage ────────────────────────────────────────────────────

    #[test]
    fn test_single_stage() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("only", &[]);

        let schedule = orch
            .compute_schedule()
            .expect("single stage should succeed");
        assert_eq!(schedule.waves.len(), 1);
        assert_eq!(schedule.waves[0], vec![0]);

        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stages: Vec<Box<dyn SolverStage>> =
            vec![Box::new(TestStage::new("only", 5.0, Arc::clone(&log)))];

        orch.execute(&mut stages, 1.0)
            .expect("execute should succeed");

        let recorded = log.lock().expect("lock should not be poisoned");
        assert_eq!(&*recorded, &["only"]);
    }

    // ── Timing accumulation ─────────────────────────────────────────────

    #[test]
    fn test_timing_accumulation() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("fast", &[]);
        orch.add_stage("slow", &[]);

        let mut stages: Vec<Box<dyn SolverStage>> = vec![
            Box::new(TimedStage::new("fast", 1_000)),
            Box::new(TimedStage::new("slow", 1_000_000)),
        ];

        // Run multiple times to accumulate.
        for _ in 0..3 {
            orch.execute(&mut stages, 0.01)
                .expect("execute should succeed");
        }

        let timings = orch.timings();
        assert_eq!(timings.len(), 2);
        // Both stages should have recorded some positive time.
        assert!(
            timings[0] > 0.0,
            "fast stage should have positive timing: {}",
            timings[0]
        );
        assert!(
            timings[1] > 0.0,
            "slow stage should have positive timing: {}",
            timings[1]
        );
        // Total should equal sum.
        let total = orch.total_time();
        let sum = timings[0] + timings[1];
        assert!(
            (total - sum).abs() < 1e-15,
            "total {total} should equal sum {sum}"
        );
        // The slow stage should generally take more time than the fast one.
        // (Not a strict assertion since OS scheduling can vary, but 1000x
        // difference in iterations should be enough.)
        assert!(
            timings[1] > timings[0],
            "slow stage ({}) should take longer than fast stage ({})",
            timings[1],
            timings[0]
        );
    }

    // ── Invalid stage index ─────────────────────────────────────────────

    #[test]
    fn test_invalid_stage_index() {
        let deps = vec![StageDependency {
            stage_idx: 0,
            depends_on: vec![5], // 5 is invalid for n=2
        }];

        let result = topological_sort(2, &deps);
        assert!(result.is_err());
        match result {
            Err(OrchestratorError::InvalidStageIndex {
                index: 5,
                total_stages: 2,
            }) => {} // expected
            other => panic!("expected InvalidStageIndex, got {other:?}"),
        }
    }

    // ── Multiple independent stages in one wave ─────────────────────────

    #[test]
    fn test_all_independent() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("X", &[]);
        orch.add_stage("Y", &[]);
        orch.add_stage("Z", &[]);

        let schedule = orch.compute_schedule().expect("should succeed");
        assert_eq!(
            schedule.waves.len(),
            1,
            "all-independent stages fit in one wave"
        );
        assert_eq!(schedule.waves[0], vec![0, 1, 2]);
    }

    // ── Wide diamond (fan-out + fan-in) ─────────────────────────────────

    #[test]
    fn test_wide_fan_out_fan_in() {
        // Root -> 4 parallel stages -> sink
        let mut orch = ParallelOrchestrator::new();
        let root = orch.add_stage("root", &[]);
        let mid: Vec<usize> = (0..4)
            .map(|i| orch.add_stage(&format!("mid_{i}"), &[root]))
            .collect();
        let _sink = orch.add_stage("sink", &mid);

        let schedule = orch.compute_schedule().expect("should succeed");
        assert_eq!(schedule.waves.len(), 3);
        assert_eq!(schedule.waves[0], vec![0]); // root
        assert_eq!(schedule.waves[1], vec![1, 2, 3, 4]); // mid_0..mid_3
        assert_eq!(schedule.waves[2], vec![5]); // sink
    }

    // ── Execution with diamond verifies correct ordering ────────────────

    #[test]
    fn test_diamond_execution_order() {
        let mut orch = ParallelOrchestrator::new();
        let a = orch.add_stage("A", &[]);
        let b = orch.add_stage("B", &[a]);
        let c = orch.add_stage("C", &[a]);
        let _d = orch.add_stage("D", &[b, c]);

        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stages: Vec<Box<dyn SolverStage>> = vec![
            Box::new(TestStage::new("A", 1.0, Arc::clone(&log))),
            Box::new(TestStage::new("B", 2.0, Arc::clone(&log))),
            Box::new(TestStage::new("C", 1.5, Arc::clone(&log))),
            Box::new(TestStage::new("D", 3.0, Arc::clone(&log))),
        ];

        orch.execute(&mut stages, 0.01)
            .expect("execute should succeed");

        let recorded = log.lock().expect("lock should not be poisoned");
        // A must come before B, C. D must come after both B and C.
        let pos_a = recorded
            .iter()
            .position(|s| s == "A")
            .expect("A should be in log");
        let pos_b = recorded
            .iter()
            .position(|s| s == "B")
            .expect("B should be in log");
        let pos_c = recorded
            .iter()
            .position(|s| s == "C")
            .expect("C should be in log");
        let pos_d = recorded
            .iter()
            .position(|s| s == "D")
            .expect("D should be in log");

        assert!(pos_a < pos_b, "A must run before B");
        assert!(pos_a < pos_c, "A must run before C");
        assert!(pos_b < pos_d, "B must run before D");
        assert!(pos_c < pos_d, "C must run before D");
    }

    // ── Reset timings ───────────────────────────────────────────────────

    #[test]
    fn test_reset_timings() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("A", &[]);

        let mut stages: Vec<Box<dyn SolverStage>> = vec![Box::new(TimedStage::new("A", 100_000))];

        orch.execute(&mut stages, 0.01)
            .expect("execute should succeed");
        assert!(orch.total_time() > 0.0);

        orch.reset_timings();
        assert!(
            orch.total_time().abs() < 1e-15,
            "timings should be zero after reset"
        );
    }

    // ── Self-dependency is a cycle ──────────────────────────────────────

    #[test]
    fn test_self_dependency_cycle() {
        let deps = vec![StageDependency {
            stage_idx: 0,
            depends_on: vec![0],
        }];

        let result = topological_sort(1, &deps);
        assert!(result.is_err(), "self-dependency should be a cycle");
        match result {
            Err(OrchestratorError::CycleDetected { .. }) => {} // expected
            other => panic!("expected CycleDetected, got {other:?}"),
        }
    }

    // ── PipelineSchedule helpers ────────────────────────────────────────

    #[test]
    fn test_pipeline_schedule_helpers() {
        let schedule = PipelineSchedule {
            waves: vec![vec![0, 1], vec![2], vec![3, 4, 5]],
        };
        assert_eq!(schedule.num_waves(), 3);
        assert_eq!(schedule.num_stages(), 6);
    }

    // ── OrchestratorError display ───────────────────────────────────────

    #[test]
    fn test_error_display() {
        let err = OrchestratorError::CycleDetected {
            description: "A -> B -> A".to_string(),
        };
        let msg = format!("{err}");
        assert!(msg.contains("cycle"));
        assert!(msg.contains("A -> B -> A"));

        let err2 = OrchestratorError::InvalidStageIndex {
            index: 10,
            total_stages: 3,
        };
        let msg2 = format!("{err2}");
        assert!(msg2.contains("10"));
        assert!(msg2.contains("3"));

        let err3 = OrchestratorError::ExecutionIndexOutOfRange {
            index: 5,
            stages_len: 2,
        };
        let msg3 = format!("{err3}");
        assert!(msg3.contains("5"));
        assert!(msg3.contains("2"));
    }

    // ── Default trait impl ──────────────────────────────────────────────

    #[test]
    fn test_default_orchestrator() {
        let orch = ParallelOrchestrator::default();
        assert_eq!(orch.num_stages(), 0);
        assert!(orch.timings().is_empty());
        assert!(orch.total_time().abs() < 1e-15);
    }
}