1#![allow(clippy::needless_range_loop)]
2#![allow(dead_code)]
29
30use std::fmt;
31use std::time::Instant;
32
/// Failure modes reported by the scheduler and by [`ParallelOrchestrator::execute`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OrchestratorError {
    /// The dependency graph has at least one cycle, so no complete ordering exists.
    CycleDetected {
        /// Human-readable summary of the stages that could not be scheduled.
        description: String,
    },
    /// A stage or dependency index is not below the number of stages being scheduled.
    InvalidStageIndex {
        /// The offending index.
        index: usize,
        /// The number of stages the index was checked against.
        total_stages: usize,
    },
    /// A scheduled stage index has no matching entry in the stage slice passed to `execute`.
    ExecutionIndexOutOfRange {
        /// The scheduled stage index that could not be resolved.
        index: usize,
        /// Length of the stage slice that was supplied.
        stages_len: usize,
    },
}
58
59impl fmt::Display for OrchestratorError {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 match self {
62 OrchestratorError::CycleDetected { description } => {
63 write!(f, "dependency cycle detected: {description}")
64 }
65 OrchestratorError::InvalidStageIndex {
66 index,
67 total_stages,
68 } => {
69 write!(
70 f,
71 "invalid stage index {index} (total stages: {total_stages})"
72 )
73 }
74 OrchestratorError::ExecutionIndexOutOfRange { index, stages_len } => {
75 write!(
76 f,
77 "execution index {index} out of range (stages slice length: {stages_len})"
78 )
79 }
80 }
81 }
82}
83
84impl std::error::Error for OrchestratorError {}
85
/// One unit of per-timestep work that the orchestrator can schedule.
///
/// The `Send + Sync` bound lets stages be boxed as `dyn SolverStage` and moved or
/// shared across threads. NOTE(review): presumably this is in anticipation of
/// concurrent wave execution; `ParallelOrchestrator::execute` currently runs
/// stages one after another on the calling thread.
pub trait SolverStage: Sync + Send {
    /// Display name for this stage.
    fn name(&self) -> &str;

    /// Advances the stage by one step of size `dt`.
    ///
    /// `dt` is forwarded unchanged from `ParallelOrchestrator::execute`.
    /// NOTE(review): units (e.g. seconds) are not fixed anywhere here — confirm.
    fn step(&mut self, dt: f64);

    /// Relative cost estimate for this stage.
    ///
    /// `ParallelOrchestrator` does not currently consult this value.
    fn estimated_cost(&self) -> f64;
}
107
/// Dependency edges for one stage: `stage_idx` may only run after every stage
/// listed in `depends_on` has run.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StageDependency {
    /// Index of the dependent stage.
    pub stage_idx: usize,
    /// Indices of the stages that must complete first.
    pub depends_on: Vec<usize>,
}
118
/// Execution plan produced by [`topological_sort`]: an ordered list of waves.
///
/// Each inner vector holds stage indices whose dependencies are all satisfied
/// by earlier waves, so stages within one wave do not depend on each other.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PipelineSchedule {
    /// Waves in execution order; each wave lists stage indices in ascending order.
    pub waves: Vec<Vec<usize>>,
}
131
132impl PipelineSchedule {
133 pub fn num_waves(&self) -> usize {
135 self.waves.len()
136 }
137
138 pub fn num_stages(&self) -> usize {
140 self.waves.iter().map(|w| w.len()).sum()
141 }
142}
143
144#[derive(Debug, Clone)]
175pub struct ParallelOrchestrator {
176 stage_names: Vec<String>,
178 dependencies: Vec<StageDependency>,
180 timings: Vec<f64>,
182}
183
184impl ParallelOrchestrator {
185 pub fn new() -> Self {
187 Self {
188 stage_names: Vec::new(),
189 dependencies: Vec::new(),
190 timings: Vec::new(),
191 }
192 }
193
194 pub fn add_stage(&mut self, name: &str, depends_on: &[usize]) -> usize {
203 let idx = self.stage_names.len();
204 self.stage_names.push(name.to_string());
205 self.dependencies.push(StageDependency {
206 stage_idx: idx,
207 depends_on: depends_on.to_vec(),
208 });
209 self.timings.push(0.0);
210 idx
211 }
212
213 pub fn num_stages(&self) -> usize {
215 self.stage_names.len()
216 }
217
218 pub fn stage_names(&self) -> &[String] {
220 &self.stage_names
221 }
222
223 pub fn compute_schedule(&self) -> Result<PipelineSchedule, OrchestratorError> {
228 let waves = topological_sort(self.stage_names.len(), &self.dependencies)?;
229 Ok(PipelineSchedule { waves })
230 }
231
232 pub fn execute(
241 &mut self,
242 stages: &mut [Box<dyn SolverStage>],
243 dt: f64,
244 ) -> Result<(), OrchestratorError> {
245 let schedule = self.compute_schedule()?;
246
247 for wave in &schedule.waves {
248 for &stage_idx in wave {
252 if stage_idx >= stages.len() {
253 return Err(OrchestratorError::ExecutionIndexOutOfRange {
254 index: stage_idx,
255 stages_len: stages.len(),
256 });
257 }
258 let start = Instant::now();
259 stages[stage_idx].step(dt);
260 let elapsed = start.elapsed().as_secs_f64();
261 if stage_idx < self.timings.len() {
262 self.timings[stage_idx] += elapsed;
263 }
264 }
265 }
266
267 Ok(())
268 }
269
270 pub fn timings(&self) -> &[f64] {
272 &self.timings
273 }
274
275 pub fn total_time(&self) -> f64 {
277 self.timings.iter().sum()
278 }
279
280 pub fn reset_timings(&mut self) {
282 for t in &mut self.timings {
283 *t = 0.0;
284 }
285 }
286}
287
288impl Default for ParallelOrchestrator {
289 fn default() -> Self {
290 Self::new()
291 }
292}
293
294pub fn topological_sort(
306 n: usize,
307 deps: &[StageDependency],
308) -> Result<Vec<Vec<usize>>, OrchestratorError> {
309 if n == 0 {
310 return Ok(Vec::new());
311 }
312
313 for dep in deps {
315 if dep.stage_idx >= n {
316 return Err(OrchestratorError::InvalidStageIndex {
317 index: dep.stage_idx,
318 total_stages: n,
319 });
320 }
321 for &d in &dep.depends_on {
322 if d >= n {
323 return Err(OrchestratorError::InvalidStageIndex {
324 index: d,
325 total_stages: n,
326 });
327 }
328 }
329 }
330
331 let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
334 let mut in_degree: Vec<usize> = vec![0; n];
335
336 for dep in deps {
337 for &d in &dep.depends_on {
338 adjacency[d].push(dep.stage_idx);
339 in_degree[dep.stage_idx] += 1;
340 }
341 }
342
343 let mut waves: Vec<Vec<usize>> = Vec::new();
345 let mut current_wave: Vec<usize> = Vec::new();
346
347 for i in 0..n {
349 if in_degree[i] == 0 {
350 current_wave.push(i);
351 }
352 }
353
354 let mut processed = 0usize;
355
356 while !current_wave.is_empty() {
357 current_wave.sort_unstable();
359 processed += current_wave.len();
360
361 let mut next_wave: Vec<usize> = Vec::new();
362 for &stage in ¤t_wave {
363 for &neighbor in &adjacency[stage] {
364 in_degree[neighbor] -= 1;
365 if in_degree[neighbor] == 0 {
366 next_wave.push(neighbor);
367 }
368 }
369 }
370
371 waves.push(std::mem::take(&mut current_wave));
372 current_wave = next_wave;
373 }
374
375 if processed != n {
376 let remaining: Vec<usize> = (0..n).filter(|&i| in_degree[i] > 0).collect();
378 let names: Vec<String> = remaining
379 .iter()
380 .filter_map(|&i| {
381 deps.iter()
382 .find(|d| d.stage_idx == i)
383 .map(|_| format!("stage {i}"))
384 })
385 .collect();
386 let description = if names.is_empty() {
387 format!("cycle involving {remaining:?}")
388 } else {
389 format!("cycle involving: {}", names.join(", "))
390 };
391 return Err(OrchestratorError::CycleDetected { description });
392 }
393
394 Ok(waves)
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared, ordered record of stage names, appended to by `TestStage::step`.
    type ExecutionLog = Arc<Mutex<Vec<String>>>;

    /// Creates an empty execution log.
    fn new_log() -> ExecutionLog {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Returns the position of the first entry equal to `name` in `log`.
    ///
    /// # Panics
    ///
    /// Panics if `name` was never recorded.
    fn position_of(log: &[String], name: &str) -> usize {
        log.iter()
            .position(|entry| entry == name)
            .unwrap_or_else(|| panic!("{name} should be in log"))
    }

    /// Stage that appends its name to a shared log each time it steps.
    struct TestStage {
        stage_name: String,
        cost: f64,
        execution_log: ExecutionLog,
    }

    impl TestStage {
        fn new(name: &str, cost: f64, log: &ExecutionLog) -> Self {
            Self {
                stage_name: name.to_string(),
                cost,
                execution_log: Arc::clone(log),
            }
        }
    }

    impl SolverStage for TestStage {
        fn name(&self) -> &str {
            &self.stage_name
        }

        fn step(&mut self, _dt: f64) {
            // Each test owns its log, so a poisoned lock can only follow a panic
            // in that same test; skipping the push is then harmless.
            if let Ok(mut log) = self.execution_log.lock() {
                log.push(self.stage_name.clone());
            }
        }

        fn estimated_cost(&self) -> f64 {
            self.cost
        }
    }

    /// Stage whose `step` burns CPU time roughly proportional to `spin_iters`.
    struct TimedStage {
        stage_name: String,
        spin_iters: u64,
    }

    impl TimedStage {
        fn new(name: &str, spin_iters: u64) -> Self {
            Self {
                stage_name: name.to_string(),
                spin_iters,
            }
        }
    }

    impl SolverStage for TimedStage {
        fn name(&self) -> &str {
            &self.stage_name
        }

        fn step(&mut self, _dt: f64) {
            let mut acc = 0u64;
            for i in 0..self.spin_iters {
                // `black_box` on every term stops the optimizer from replacing the
                // loop with a closed-form sum. Otherwise release-mode timings would
                // not scale with `spin_iters`, and the ordering assertions become flaky.
                acc = acc.wrapping_add(std::hint::black_box(i));
            }
            std::hint::black_box(acc);
        }

        fn estimated_cost(&self) -> f64 {
            self.spin_iters as f64
        }
    }

    #[test]
    fn test_linear_pipeline() {
        let mut orch = ParallelOrchestrator::new();
        let a = orch.add_stage("A", &[]);
        let b = orch.add_stage("B", &[a]);
        let _c = orch.add_stage("C", &[b]);

        let schedule = orch.compute_schedule().expect("scheduling should succeed");
        assert_eq!(schedule.waves.len(), 3, "linear pipeline needs 3 waves");
        assert_eq!(schedule.waves[0], vec![0]);
        assert_eq!(schedule.waves[1], vec![1]);
        assert_eq!(schedule.waves[2], vec![2]);

        let log = new_log();
        let mut stages: Vec<Box<dyn SolverStage>> = vec![
            Box::new(TestStage::new("A", 1.0, &log)),
            Box::new(TestStage::new("B", 1.0, &log)),
            Box::new(TestStage::new("C", 1.0, &log)),
        ];

        orch.execute(&mut stages, 0.016)
            .expect("execute should succeed");

        let recorded = log.lock().expect("lock should not be poisoned");
        assert_eq!(&*recorded, &["A", "B", "C"]);
    }

    #[test]
    fn test_diamond_dependency() {
        let mut orch = ParallelOrchestrator::new();
        let a = orch.add_stage("A", &[]);
        let b = orch.add_stage("B", &[a]);
        let c = orch.add_stage("C", &[a]);
        let _d = orch.add_stage("D", &[b, c]);

        let schedule = orch.compute_schedule().expect("scheduling should succeed");
        assert_eq!(schedule.waves.len(), 3, "diamond needs 3 waves");
        assert_eq!(schedule.waves[0], vec![0], "wave 0 has A");
        let mut wave1 = schedule.waves[1].clone();
        wave1.sort_unstable();
        assert_eq!(wave1, vec![1, 2], "wave 1 has B and C");
        assert_eq!(schedule.waves[2], vec![3], "wave 2 has D");
    }

    #[test]
    fn test_cycle_detection() {
        let deps = vec![
            StageDependency {
                stage_idx: 0,
                depends_on: vec![2],
            },
            StageDependency {
                stage_idx: 1,
                depends_on: vec![0],
            },
            StageDependency {
                stage_idx: 2,
                depends_on: vec![1],
            },
        ];

        match topological_sort(3, &deps) {
            Err(OrchestratorError::CycleDetected { description }) => {
                assert!(
                    description.contains("cycle"),
                    "error should mention cycle: {description}"
                );
            }
            other => panic!("expected CycleDetected, got {other:?}"),
        }
    }

    #[test]
    fn test_empty_pipeline() {
        let orch = ParallelOrchestrator::new();
        let schedule = orch
            .compute_schedule()
            .expect("empty schedule should succeed");
        assert!(schedule.waves.is_empty(), "empty pipeline has no waves");
        assert_eq!(schedule.num_waves(), 0);
        assert_eq!(schedule.num_stages(), 0);
    }

    #[test]
    fn test_single_stage() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("only", &[]);

        let schedule = orch
            .compute_schedule()
            .expect("single stage should succeed");
        assert_eq!(schedule.waves.len(), 1);
        assert_eq!(schedule.waves[0], vec![0]);

        let log = new_log();
        let mut stages: Vec<Box<dyn SolverStage>> =
            vec![Box::new(TestStage::new("only", 5.0, &log))];

        orch.execute(&mut stages, 1.0)
            .expect("execute should succeed");

        let recorded = log.lock().expect("lock should not be poisoned");
        assert_eq!(&*recorded, &["only"]);
    }

    #[test]
    fn test_timing_accumulation() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("fast", &[]);
        orch.add_stage("slow", &[]);

        let mut stages: Vec<Box<dyn SolverStage>> = vec![
            Box::new(TimedStage::new("fast", 1_000)),
            Box::new(TimedStage::new("slow", 1_000_000)),
        ];

        for _ in 0..3 {
            orch.execute(&mut stages, 0.01)
                .expect("execute should succeed");
        }

        let timings = orch.timings();
        assert_eq!(timings.len(), 2);
        assert!(
            timings[0] > 0.0,
            "fast stage should have positive timing: {}",
            timings[0]
        );
        assert!(
            timings[1] > 0.0,
            "slow stage should have positive timing: {}",
            timings[1]
        );

        let total = orch.total_time();
        let sum = timings[0] + timings[1];
        assert!(
            (total - sum).abs() < 1e-15,
            "total {total} should equal sum {sum}"
        );
        assert!(
            timings[1] > timings[0],
            "slow stage ({}) should take longer than fast stage ({})",
            timings[1],
            timings[0]
        );
    }

    #[test]
    fn test_invalid_stage_index() {
        let deps = vec![StageDependency {
            stage_idx: 0,
            depends_on: vec![5],
        }];

        assert_eq!(
            topological_sort(2, &deps),
            Err(OrchestratorError::InvalidStageIndex {
                index: 5,
                total_stages: 2,
            })
        );
    }

    #[test]
    fn test_all_independent() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("X", &[]);
        orch.add_stage("Y", &[]);
        orch.add_stage("Z", &[]);

        let schedule = orch.compute_schedule().expect("should succeed");
        assert_eq!(
            schedule.waves.len(),
            1,
            "all-independent stages fit in one wave"
        );
        assert_eq!(schedule.waves[0], vec![0, 1, 2]);
    }

    #[test]
    fn test_wide_fan_out_fan_in() {
        let mut orch = ParallelOrchestrator::new();
        let root = orch.add_stage("root", &[]);
        let mid: Vec<usize> = (0..4)
            .map(|i| orch.add_stage(&format!("mid_{i}"), &[root]))
            .collect();
        let _sink = orch.add_stage("sink", &mid);

        let schedule = orch.compute_schedule().expect("should succeed");
        assert_eq!(schedule.waves.len(), 3);
        assert_eq!(schedule.waves[0], vec![0]);
        assert_eq!(schedule.waves[1], vec![1, 2, 3, 4]);
        assert_eq!(schedule.waves[2], vec![5]);
    }

    #[test]
    fn test_diamond_execution_order() {
        let mut orch = ParallelOrchestrator::new();
        let a = orch.add_stage("A", &[]);
        let b = orch.add_stage("B", &[a]);
        let c = orch.add_stage("C", &[a]);
        let _d = orch.add_stage("D", &[b, c]);

        let log = new_log();
        let mut stages: Vec<Box<dyn SolverStage>> = vec![
            Box::new(TestStage::new("A", 1.0, &log)),
            Box::new(TestStage::new("B", 2.0, &log)),
            Box::new(TestStage::new("C", 1.5, &log)),
            Box::new(TestStage::new("D", 3.0, &log)),
        ];

        orch.execute(&mut stages, 0.01)
            .expect("execute should succeed");

        let recorded = log.lock().expect("lock should not be poisoned");
        let pos_a = position_of(&recorded, "A");
        let pos_b = position_of(&recorded, "B");
        let pos_c = position_of(&recorded, "C");
        let pos_d = position_of(&recorded, "D");

        assert!(pos_a < pos_b, "A must run before B");
        assert!(pos_a < pos_c, "A must run before C");
        assert!(pos_b < pos_d, "B must run before D");
        assert!(pos_c < pos_d, "C must run before D");
    }

    #[test]
    fn test_reset_timings() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("A", &[]);

        let mut stages: Vec<Box<dyn SolverStage>> = vec![Box::new(TimedStage::new("A", 100_000))];

        orch.execute(&mut stages, 0.01)
            .expect("execute should succeed");
        assert!(orch.total_time() > 0.0);

        orch.reset_timings();
        assert!(
            orch.total_time().abs() < 1e-15,
            "timings should be zero after reset"
        );
    }

    #[test]
    fn test_self_dependency_cycle() {
        let deps = vec![StageDependency {
            stage_idx: 0,
            depends_on: vec![0],
        }];

        let result = topological_sort(1, &deps);
        assert!(
            matches!(result, Err(OrchestratorError::CycleDetected { .. })),
            "self-dependency should be a cycle, got {result:?}"
        );
    }

    #[test]
    fn test_pipeline_schedule_helpers() {
        let schedule = PipelineSchedule {
            waves: vec![vec![0, 1], vec![2], vec![3, 4, 5]],
        };
        assert_eq!(schedule.num_waves(), 3);
        assert_eq!(schedule.num_stages(), 6);
    }

    #[test]
    fn test_error_display() {
        let err = OrchestratorError::CycleDetected {
            description: "A -> B -> A".to_string(),
        };
        let msg = format!("{err}");
        assert!(msg.contains("cycle"));
        assert!(msg.contains("A -> B -> A"));

        let err2 = OrchestratorError::InvalidStageIndex {
            index: 10,
            total_stages: 3,
        };
        let msg2 = format!("{err2}");
        assert!(msg2.contains("10"));
        assert!(msg2.contains("3"));

        let err3 = OrchestratorError::ExecutionIndexOutOfRange {
            index: 5,
            stages_len: 2,
        };
        let msg3 = format!("{err3}");
        assert!(msg3.contains("5"));
        assert!(msg3.contains("2"));
    }

    #[test]
    fn test_default_orchestrator() {
        let orch = ParallelOrchestrator::default();
        assert_eq!(orch.num_stages(), 0);
        assert!(orch.timings().is_empty());
        assert!(orch.total_time().abs() < 1e-15);
    }
}