#![allow(clippy::needless_range_loop)]
#![allow(dead_code)]
use std::fmt;
use std::time::Instant;
/// Errors produced while scheduling or executing a stage pipeline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OrchestratorError {
    /// The dependency graph contains a cycle, so no valid execution order exists.
    CycleDetected {
        /// Free-form, human-readable description of the unschedulable stages.
        description: String,
    },
    /// A dependency record referenced a stage index outside `0..total_stages`.
    InvalidStageIndex {
        /// The offending stage index.
        index: usize,
        /// Number of stages the scheduler was told about.
        total_stages: usize,
    },
    /// The schedule referenced a stage that has no entry in the stage slice
    /// supplied for execution.
    ExecutionIndexOutOfRange {
        /// The offending stage index.
        index: usize,
        /// Length of the stage slice that was supplied.
        stages_len: usize,
    },
}

impl fmt::Display for OrchestratorError {
    /// Renders a one-line, lowercase, user-facing message for the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::CycleDetected { ref description } => {
                write!(f, "dependency cycle detected: {description}")
            }
            Self::InvalidStageIndex {
                index,
                total_stages,
            } => write!(
                f,
                "invalid stage index {index} (total stages: {total_stages})"
            ),
            Self::ExecutionIndexOutOfRange { index, stages_len } => write!(
                f,
                "execution index {index} out of range (stages slice length: {stages_len})"
            ),
        }
    }
}

impl std::error::Error for OrchestratorError {}
/// A unit of work that the orchestrator steps once per
/// [`ParallelOrchestrator::execute`] call.
///
/// Implementors must be `Send + Sync`.
pub trait SolverStage: Send + Sync {
    /// Human-readable name of this stage.
    fn name(&self) -> &str;
    /// Advances the stage by one step of size `dt`.
    ///
    /// `ParallelOrchestrator::execute` forwards its own `dt` argument here
    /// unchanged. NOTE(review): the unit of `dt` is not fixed by this trait —
    /// presumably seconds; confirm with callers.
    fn step(&mut self, dt: f64);
    /// Estimated relative cost of one `step` call.
    ///
    /// NOTE(review): not consulted by the orchestrator code in this file;
    /// presumably intended for load-balancing stages within a wave — confirm.
    fn estimated_cost(&self) -> f64;
}
/// Dependency record for one stage: the stage at `stage_idx` is scheduled in a
/// later wave than every stage listed in `depends_on`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StageDependency {
    /// Index of the stage this record describes.
    pub stage_idx: usize,
    /// Indices of the stages that must run before `stage_idx`.
    pub depends_on: Vec<usize>,
}
/// Execution plan expressed as a sequence of waves of stage indices.
///
/// `ParallelOrchestrator::compute_schedule` builds one from
/// `topological_sort`; waves are meant to run in order, and the stages inside
/// a single wave do not depend on each other.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineSchedule {
    /// Stage indices grouped by wave, in execution order.
    pub waves: Vec<Vec<usize>>,
}

impl PipelineSchedule {
    /// Number of waves (sequential steps) in the plan.
    pub fn num_waves(&self) -> usize {
        let Self { waves } = self;
        waves.len()
    }

    /// Total number of stage entries across all waves.
    pub fn num_stages(&self) -> usize {
        self.waves
            .iter()
            .fold(0, |count, wave| count + wave.len())
    }
}
/// Schedules and runs [`SolverStage`]s according to their declared
/// dependencies, accumulating per-stage wall-clock timings.
///
/// A stage is identified by the index returned from
/// [`add_stage`](Self::add_stage); that same index selects the stage object in
/// the slice passed to [`execute`](Self::execute). The schedule groups
/// independent stages into waves, but `execute` currently runs every stage
/// sequentially on the calling thread, wave by wave.
#[derive(Debug, Clone)]
pub struct ParallelOrchestrator {
    /// Name of each registered stage, indexed by stage index.
    stage_names: Vec<String>,
    /// Dependency record of each registered stage, indexed by stage index.
    dependencies: Vec<StageDependency>,
    /// Accumulated `step` wall-clock time in seconds, indexed by stage index.
    timings: Vec<f64>,
}

impl ParallelOrchestrator {
    /// Creates an orchestrator with no registered stages.
    pub fn new() -> Self {
        Self {
            stage_names: Vec::new(),
            dependencies: Vec::new(),
            timings: Vec::new(),
        }
    }

    /// Registers a stage named `name` and returns its index.
    ///
    /// `depends_on` lists the indices of stages that must run before this one.
    /// The indices are not checked here (they may even refer to stages that
    /// are added later); they are validated when a schedule is computed.
    pub fn add_stage(&mut self, name: &str, depends_on: &[usize]) -> usize {
        let idx = self.stage_names.len();
        self.stage_names.push(name.to_string());
        self.dependencies.push(StageDependency {
            stage_idx: idx,
            depends_on: depends_on.to_vec(),
        });
        self.timings.push(0.0);
        idx
    }

    /// Number of registered stages.
    pub fn num_stages(&self) -> usize {
        self.stage_names.len()
    }

    /// Names of the registered stages, indexed by stage index.
    pub fn stage_names(&self) -> &[String] {
        &self.stage_names
    }

    /// Computes the wave schedule for the registered stages.
    ///
    /// # Errors
    ///
    /// Propagates [`OrchestratorError::InvalidStageIndex`] or
    /// [`OrchestratorError::CycleDetected`] from [`topological_sort`].
    pub fn compute_schedule(&self) -> Result<PipelineSchedule, OrchestratorError> {
        let waves = topological_sort(self.stage_names.len(), &self.dependencies)?;
        Ok(PipelineSchedule { waves })
    }

    /// Steps every scheduled stage once with time step `dt`, in schedule
    /// order, and adds each stage's measured wall-clock time to its timing slot.
    ///
    /// `stages[i]` is the implementation of the stage registered with index
    /// `i`. Entries beyond [`num_stages`](Self::num_stages) are never stepped.
    ///
    /// # Errors
    ///
    /// * Any error from [`compute_schedule`](Self::compute_schedule).
    /// * [`OrchestratorError::ExecutionIndexOutOfRange`] if `stages` is too
    ///   short to cover every scheduled stage; the reported index is the first
    ///   uncovered one in execution order.
    ///
    /// All checks happen before any stage runs, so on error no stage has been
    /// stepped and the timings are unchanged.
    pub fn execute(
        &mut self,
        stages: &mut [Box<dyn SolverStage>],
        dt: f64,
    ) -> Result<(), OrchestratorError> {
        let schedule = self.compute_schedule()?;

        // Validate the whole plan up front: failing halfway through would leave
        // some stages advanced by `dt` and others not.
        if let Some(index) = schedule
            .waves
            .iter()
            .flatten()
            .copied()
            .find(|&idx| idx >= stages.len())
        {
            return Err(OrchestratorError::ExecutionIndexOutOfRange {
                index,
                stages_len: stages.len(),
            });
        }

        for &stage_idx in schedule.waves.iter().flatten() {
            let start = Instant::now();
            stages[stage_idx].step(dt);
            let elapsed = start.elapsed().as_secs_f64();
            // `timings` has one slot per registered stage and scheduled indices
            // are always below that count; `get_mut` keeps this panic-free.
            if let Some(total) = self.timings.get_mut(stage_idx) {
                *total += elapsed;
            }
        }
        Ok(())
    }

    /// Accumulated wall-clock seconds spent in each stage's `step`, indexed by
    /// stage index.
    pub fn timings(&self) -> &[f64] {
        &self.timings
    }

    /// Sum of all per-stage timings, in seconds.
    pub fn total_time(&self) -> f64 {
        self.timings.iter().sum()
    }

    /// Resets every per-stage timing to zero; registered stages are kept.
    pub fn reset_timings(&mut self) {
        self.timings.fill(0.0);
    }
}

impl Default for ParallelOrchestrator {
    /// Equivalent to [`ParallelOrchestrator::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Computes a wave-based (level-by-level) topological ordering of stages `0..n`.
///
/// This is Kahn's algorithm processed one frontier at a time. Wave 0 holds
/// every stage with no dependencies. Each later wave holds the stages whose
/// dependencies all appear in earlier waves. Indices inside a wave are sorted
/// ascending, so the result is deterministic.
///
/// Stages in `0..n` that have no entry in `deps` are treated as having no
/// dependencies. Several entries with the same `stage_idx` are allowed, and
/// their `depends_on` lists are combined.
///
/// # Errors
///
/// * [`OrchestratorError::InvalidStageIndex`] if any `stage_idx` or
///   `depends_on` entry is `>= n`. This includes `n == 0`, where every index
///   is invalid.
/// * [`OrchestratorError::CycleDetected`] if the dependency graph has a cycle,
///   including a stage that depends on itself. The description lists every
///   stage that could not be scheduled: the cycle members plus anything that
///   transitively depends on them.
pub fn topological_sort(
    n: usize,
    deps: &[StageDependency],
) -> Result<Vec<Vec<usize>>, OrchestratorError> {
    // Validate every index first (record order, `stage_idx` before its
    // `depends_on` list) so the graph construction below can index freely.
    for dep in deps {
        let invalid = std::iter::once(dep.stage_idx)
            .chain(dep.depends_on.iter().copied())
            .find(|&idx| idx >= n);
        if let Some(index) = invalid {
            return Err(OrchestratorError::InvalidStageIndex {
                index,
                total_stages: n,
            });
        }
    }

    // adjacency[d] lists the stages that depend on `d`.
    // in_degree[s] counts the unresolved dependencies of `s`.
    // A duplicated edge is counted twice and released twice, so the two stay
    // consistent.
    let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
    let mut in_degree: Vec<usize> = vec![0; n];
    for dep in deps {
        for &d in &dep.depends_on {
            adjacency[d].push(dep.stage_idx);
            in_degree[dep.stage_idx] += 1;
        }
    }

    let mut waves: Vec<Vec<usize>> = Vec::new();
    let mut current_wave: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
    let mut processed = 0usize;
    while !current_wave.is_empty() {
        current_wave.sort_unstable();
        processed += current_wave.len();
        let mut next_wave: Vec<usize> = Vec::new();
        for &stage in &current_wave {
            for &neighbor in &adjacency[stage] {
                in_degree[neighbor] -= 1;
                if in_degree[neighbor] == 0 {
                    next_wave.push(neighbor);
                }
            }
        }
        waves.push(std::mem::replace(&mut current_wave, next_wave));
    }

    if processed != n {
        // Every stage that never reached in-degree 0 is stuck behind a cycle.
        // Such a stage is always named by some `deps` entry, because it had an
        // incoming edge.
        let stuck: Vec<String> = (0..n)
            .filter(|&i| in_degree[i] > 0)
            .map(|i| format!("stage {i}"))
            .collect();
        return Err(OrchestratorError::CycleDetected {
            description: format!("cycle involving: {}", stuck.join(", ")),
        });
    }
    Ok(waves)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared, ordered record of which stages were stepped.
    type ExecLog = Arc<Mutex<Vec<String>>>;

    /// Stage that counts its invocations and appends its name to a shared log.
    struct TestStage {
        stage_name: String,
        cost: f64,
        call_count: u64,
        execution_log: ExecLog,
    }

    impl TestStage {
        fn new(name: &str, cost: f64, log: ExecLog) -> Self {
            Self {
                stage_name: name.to_string(),
                cost,
                call_count: 0,
                execution_log: log,
            }
        }
    }

    impl SolverStage for TestStage {
        fn name(&self) -> &str {
            &self.stage_name
        }

        fn step(&mut self, _dt: f64) {
            self.call_count += 1;
            if let Ok(mut log) = self.execution_log.lock() {
                log.push(self.stage_name.clone());
            }
        }

        fn estimated_cost(&self) -> f64 {
            self.cost
        }
    }

    /// Stage that burns CPU for `spin_iters` iterations so that its measured
    /// `step` time scales with the iteration count.
    struct TimedStage {
        stage_name: String,
        spin_iters: u64,
    }

    impl TimedStage {
        fn new(name: &str, spin_iters: u64) -> Self {
            Self {
                stage_name: name.to_string(),
                spin_iters,
            }
        }
    }

    impl SolverStage for TimedStage {
        fn name(&self) -> &str {
            &self.stage_name
        }

        fn step(&mut self, _dt: f64) {
            let mut acc = 0u64;
            for i in 0..self.spin_iters {
                // `black_box` on every term prevents LLVM from replacing the
                // loop with its closed-form sum. Without it, optimized test
                // builds make the "slow" stage as fast as the "fast" one and
                // the timing comparisons become flaky.
                acc = acc.wrapping_add(std::hint::black_box(i));
            }
            std::hint::black_box(acc);
        }

        fn estimated_cost(&self) -> f64 {
            self.spin_iters as f64
        }
    }

    /// Index of the first log entry equal to `name`; panics if it is absent.
    fn position_of(log: &[String], name: &str) -> usize {
        log.iter()
            .position(|s| s == name)
            .unwrap_or_else(|| panic!("{name} should be in log"))
    }

    #[test]
    fn test_linear_pipeline() {
        let mut orch = ParallelOrchestrator::new();
        let a = orch.add_stage("A", &[]);
        let b = orch.add_stage("B", &[a]);
        let _c = orch.add_stage("C", &[b]);

        let schedule = orch.compute_schedule().expect("scheduling should succeed");
        assert_eq!(schedule.waves.len(), 3, "linear pipeline needs 3 waves");
        assert_eq!(schedule.waves[0], vec![0]);
        assert_eq!(schedule.waves[1], vec![1]);
        assert_eq!(schedule.waves[2], vec![2]);

        let log: ExecLog = Arc::new(Mutex::new(Vec::new()));
        let mut stages: Vec<Box<dyn SolverStage>> = vec![
            Box::new(TestStage::new("A", 1.0, Arc::clone(&log))),
            Box::new(TestStage::new("B", 1.0, Arc::clone(&log))),
            Box::new(TestStage::new("C", 1.0, Arc::clone(&log))),
        ];
        orch.execute(&mut stages, 0.016)
            .expect("execute should succeed");

        let recorded = log.lock().expect("lock should not be poisoned");
        assert_eq!(&*recorded, &["A", "B", "C"]);
    }

    #[test]
    fn test_diamond_dependency() {
        let mut orch = ParallelOrchestrator::new();
        let a = orch.add_stage("A", &[]);
        let b = orch.add_stage("B", &[a]);
        let c = orch.add_stage("C", &[a]);
        let _d = orch.add_stage("D", &[b, c]);

        let schedule = orch.compute_schedule().expect("scheduling should succeed");
        assert_eq!(schedule.waves.len(), 3, "diamond needs 3 waves");
        assert_eq!(schedule.waves[0], vec![0], "wave 0 has A");
        let mut wave1 = schedule.waves[1].clone();
        wave1.sort_unstable();
        assert_eq!(wave1, vec![1, 2], "wave 1 has B and C");
        assert_eq!(schedule.waves[2], vec![3], "wave 2 has D");
    }

    #[test]
    fn test_cycle_detection() {
        let deps = vec![
            StageDependency {
                stage_idx: 0,
                depends_on: vec![2],
            },
            StageDependency {
                stage_idx: 1,
                depends_on: vec![0],
            },
            StageDependency {
                stage_idx: 2,
                depends_on: vec![1],
            },
        ];
        let result = topological_sort(3, &deps);
        assert!(result.is_err(), "cycle should produce an error");
        match result {
            Err(OrchestratorError::CycleDetected { description }) => {
                assert!(
                    description.contains("cycle"),
                    "error should mention cycle: {description}"
                );
            }
            other => panic!("expected CycleDetected, got {other:?}"),
        }
    }

    #[test]
    fn test_empty_pipeline() {
        let orch = ParallelOrchestrator::new();
        let schedule = orch
            .compute_schedule()
            .expect("empty schedule should succeed");
        assert!(schedule.waves.is_empty(), "empty pipeline has no waves");
        assert_eq!(schedule.num_waves(), 0);
        assert_eq!(schedule.num_stages(), 0);
    }

    #[test]
    fn test_single_stage() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("only", &[]);

        let schedule = orch
            .compute_schedule()
            .expect("single stage should succeed");
        assert_eq!(schedule.waves.len(), 1);
        assert_eq!(schedule.waves[0], vec![0]);

        let log: ExecLog = Arc::new(Mutex::new(Vec::new()));
        let mut stages: Vec<Box<dyn SolverStage>> =
            vec![Box::new(TestStage::new("only", 5.0, Arc::clone(&log)))];
        orch.execute(&mut stages, 1.0)
            .expect("execute should succeed");

        let recorded = log.lock().expect("lock should not be poisoned");
        assert_eq!(&*recorded, &["only"]);
    }

    #[test]
    fn test_timing_accumulation() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("fast", &[]);
        orch.add_stage("slow", &[]);
        let mut stages: Vec<Box<dyn SolverStage>> = vec![
            Box::new(TimedStage::new("fast", 1_000)),
            Box::new(TimedStage::new("slow", 1_000_000)),
        ];
        for _ in 0..3 {
            orch.execute(&mut stages, 0.01)
                .expect("execute should succeed");
        }

        let timings = orch.timings();
        assert_eq!(timings.len(), 2);
        assert!(
            timings[0] > 0.0,
            "fast stage should have positive timing: {}",
            timings[0]
        );
        assert!(
            timings[1] > 0.0,
            "slow stage should have positive timing: {}",
            timings[1]
        );

        let total = orch.total_time();
        let sum = timings[0] + timings[1];
        assert!(
            (total - sum).abs() < 1e-15,
            "total {total} should equal sum {sum}"
        );
        assert!(
            timings[1] > timings[0],
            "slow stage ({}) should take longer than fast stage ({})",
            timings[1],
            timings[0]
        );
    }

    #[test]
    fn test_invalid_stage_index() {
        let deps = vec![StageDependency {
            stage_idx: 0,
            depends_on: vec![5],
        }];
        let result = topological_sort(2, &deps);
        assert!(result.is_err());
        match result {
            Err(OrchestratorError::InvalidStageIndex {
                index: 5,
                total_stages: 2,
            }) => {}
            other => panic!("expected InvalidStageIndex, got {other:?}"),
        }
    }

    #[test]
    fn test_all_independent() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("X", &[]);
        orch.add_stage("Y", &[]);
        orch.add_stage("Z", &[]);

        let schedule = orch.compute_schedule().expect("should succeed");
        assert_eq!(
            schedule.waves.len(),
            1,
            "all-independent stages fit in one wave"
        );
        assert_eq!(schedule.waves[0], vec![0, 1, 2]);
    }

    #[test]
    fn test_wide_fan_out_fan_in() {
        let mut orch = ParallelOrchestrator::new();
        let root = orch.add_stage("root", &[]);
        let mid: Vec<usize> = (0..4)
            .map(|i| orch.add_stage(&format!("mid_{i}"), &[root]))
            .collect();
        let _sink = orch.add_stage("sink", &mid);

        let schedule = orch.compute_schedule().expect("should succeed");
        assert_eq!(schedule.waves.len(), 3);
        assert_eq!(schedule.waves[0], vec![0]);
        assert_eq!(schedule.waves[1], vec![1, 2, 3, 4]);
        assert_eq!(schedule.waves[2], vec![5]);
    }

    #[test]
    fn test_diamond_execution_order() {
        let mut orch = ParallelOrchestrator::new();
        let a = orch.add_stage("A", &[]);
        let b = orch.add_stage("B", &[a]);
        let c = orch.add_stage("C", &[a]);
        let _d = orch.add_stage("D", &[b, c]);

        let log: ExecLog = Arc::new(Mutex::new(Vec::new()));
        let mut stages: Vec<Box<dyn SolverStage>> = vec![
            Box::new(TestStage::new("A", 1.0, Arc::clone(&log))),
            Box::new(TestStage::new("B", 2.0, Arc::clone(&log))),
            Box::new(TestStage::new("C", 1.5, Arc::clone(&log))),
            Box::new(TestStage::new("D", 3.0, Arc::clone(&log))),
        ];
        orch.execute(&mut stages, 0.01)
            .expect("execute should succeed");

        let recorded = log.lock().expect("lock should not be poisoned");
        let pos_a = position_of(&recorded, "A");
        let pos_b = position_of(&recorded, "B");
        let pos_c = position_of(&recorded, "C");
        let pos_d = position_of(&recorded, "D");
        assert!(pos_a < pos_b, "A must run before B");
        assert!(pos_a < pos_c, "A must run before C");
        assert!(pos_b < pos_d, "B must run before D");
        assert!(pos_c < pos_d, "C must run before D");
    }

    #[test]
    fn test_reset_timings() {
        let mut orch = ParallelOrchestrator::new();
        orch.add_stage("A", &[]);
        let mut stages: Vec<Box<dyn SolverStage>> =
            vec![Box::new(TimedStage::new("A", 100_000))];
        orch.execute(&mut stages, 0.01)
            .expect("execute should succeed");
        assert!(orch.total_time() > 0.0);

        orch.reset_timings();
        assert!(
            orch.total_time().abs() < 1e-15,
            "timings should be zero after reset"
        );
    }

    #[test]
    fn test_self_dependency_cycle() {
        let deps = vec![StageDependency {
            stage_idx: 0,
            depends_on: vec![0],
        }];
        let result = topological_sort(1, &deps);
        assert!(result.is_err(), "self-dependency should be a cycle");
        match result {
            Err(OrchestratorError::CycleDetected { .. }) => {}
            other => panic!("expected CycleDetected, got {other:?}"),
        }
    }

    #[test]
    fn test_pipeline_schedule_helpers() {
        let schedule = PipelineSchedule {
            waves: vec![vec![0, 1], vec![2], vec![3, 4, 5]],
        };
        assert_eq!(schedule.num_waves(), 3);
        assert_eq!(schedule.num_stages(), 6);
    }

    #[test]
    fn test_error_display() {
        let err = OrchestratorError::CycleDetected {
            description: "A -> B -> A".to_string(),
        };
        let msg = format!("{err}");
        assert!(msg.contains("cycle"));
        assert!(msg.contains("A -> B -> A"));

        let err2 = OrchestratorError::InvalidStageIndex {
            index: 10,
            total_stages: 3,
        };
        let msg2 = format!("{err2}");
        assert!(msg2.contains("10"));
        assert!(msg2.contains("3"));

        let err3 = OrchestratorError::ExecutionIndexOutOfRange {
            index: 5,
            stages_len: 2,
        };
        let msg3 = format!("{err3}");
        assert!(msg3.contains("5"));
        assert!(msg3.contains("2"));
    }

    #[test]
    fn test_default_orchestrator() {
        let orch = ParallelOrchestrator::default();
        assert_eq!(orch.num_stages(), 0);
        assert!(orch.timings().is_empty());
        assert!(orch.total_time().abs() < 1e-15);
    }
}