#![allow(unused_imports)]
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::VecDeque;

use rustc_hash::FxHashMap;

use crate::optimizer::{PassMetadata, ProgramPassRegistration};
/// Errors reported while ordering optimizer passes according to their `requires` lists.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[non_exhaustive]
pub enum PassSchedulingError {
/// A pass lists a `requires` entry that does not name any pass in the set being scheduled.
#[error("optimizer pass `{pass}` requires unknown pass `{missing}`.")]
UnknownRequire {
/// Name of the pass whose `requires` list contains the unknown entry.
pass: &'static str,
/// The required pass name that could not be found.
missing: &'static str,
},
/// The `requires` relation cannot be satisfied because it contains a cycle.
#[error("optimizer pass dependency cycle among {pass_ids:?}. Fix: {fix}")]
Cycle {
/// Sorted names of the passes that could not be scheduled. As produced by this file's
/// scheduler, this is every unscheduled pass: cycle members and passes that depend on them.
pass_ids: Vec<&'static str>,
/// Human-readable suggestion for resolving the cycle.
fix: &'static str,
},
/// Two passes in the set being scheduled share the same name.
#[error("duplicate pass id `{id}`.")]
DuplicateId {
/// The name that appears more than once.
id: &'static str,
},
}
pub fn schedule_passes(
passes: &[&'static ProgramPassRegistration],
) -> Result<Vec<&'static ProgramPassRegistration>, PassSchedulingError> {
let mut metadata = Vec::with_capacity(passes.len());
metadata.extend(passes.iter().map(|pass| pass.metadata));
let order = schedule_pass_metadata_indices(&metadata)?;
let mut scheduled = Vec::with_capacity(order.len());
scheduled.extend(order.into_iter().map(|index| passes[index]));
Ok(scheduled)
}
pub(super) fn schedule_pass_metadata_indices(
passes: &[PassMetadata],
) -> Result<Vec<usize>, PassSchedulingError> {
let n = passes.len();
let mut by_id = FxHashMap::with_capacity_and_hasher(n, Default::default());
for (i, pass) in passes.iter().enumerate() {
if by_id.insert(pass.name, i).is_some() {
return Err(PassSchedulingError::DuplicateId { id: pass.name });
}
}
let mut indegree = vec![0usize; n];
let mut dependents = Vec::with_capacity(n);
dependents.resize_with(n, Vec::new);
for (i, pass) in passes.iter().enumerate() {
for required in pass.requires {
if let Some(&req_i) = by_id.get(required) {
if !dependents[req_i].contains(&i) {
dependents[req_i].push(i);
indegree[i] += 1;
}
} else {
return Err(PassSchedulingError::UnknownRequire {
pass: pass.name,
missing: required,
});
}
}
}
for children in &mut dependents {
children.sort_unstable_by_key(|&child| passes[child].name);
}
let mut initial_ready = Vec::with_capacity(n);
initial_ready.extend(
indegree
.iter()
.enumerate()
.filter_map(|(id, °ree)| (degree == 0).then_some(id)),
);
initial_ready.sort_unstable_by_key(|&id| passes[id].name);
let mut ready = VecDeque::from(initial_ready);
let mut ordered = Vec::with_capacity(n);
while let Some(id) = ready.pop_front() {
ordered.push(id);
for &child in &dependents[id] {
indegree[child] -= 1;
if indegree[child] == 0 {
let child_name = passes[child].name;
let pos = ready
.iter()
.position(|&existing| child_name < passes[existing].name)
.unwrap_or(ready.len());
ready.insert(pos, child);
}
}
}
if ordered.len() != n {
let mut pass_ids = Vec::with_capacity(n - ordered.len());
pass_ids.extend(
indegree
.into_iter()
.enumerate()
.filter_map(|(id, degree)| (degree > 0).then_some(passes[id].name)),
);
pass_ids.sort_unstable();
return Err(PassSchedulingError::Cycle {
pass_ids,
fix: "Break the cycle by removing one of these `requires` entries.",
});
}
Ok(ordered)
}