#![allow(clippy::expect_used)]
use crate::ir::{BufferDecl, Expr, Node};
use crate::ir_inner::model::program::Program;
use crate::optimizer::{registered_passes, OptimizerError, ProgramPassKind};
use rustc_hash::{FxHashMap, FxHashSet};
use std::sync::OnceLock;
/// Iteration cap that [`PassScheduler::try_default`] stores in `max_iterations`.
///
/// NOTE(review): presumably bounds the number of rounds the scheduler makes
/// over its passes before giving up on reaching a fixpoint — confirm in `run`.
pub(super) const DEFAULT_MAX_ITERATIONS: usize = 50;
/// Orchestrates the optimizer's program passes.
///
/// Built by [`PassScheduler::try_default`] (or `Default`) from the registered
/// passes. The `with_*_enforcement` builders toggle optional checks whose
/// outcomes show up as the `*Reverted` variants of [`PassRunDecision`].
///
/// NOTE(review): the `OnceLock` caches are never populated in this file. They are
/// presumably filled lazily by the `queries`/`run` submodules and derived only
/// from `passes` — confirm there.
pub struct PassScheduler {
    /// Passes managed by this scheduler. `try_default` stores them in the order
    /// `registered_passes()` returns them.
    passes: Vec<ProgramPassKind>,
    /// Maps each pass's metadata name to its position in `passes`.
    pass_index: FxHashMap<&'static str, usize>,
    /// Indices into `passes`. `try_default` initializes this to `0..passes.len()`.
    execution_order: Vec<usize>,
    /// Set to `true` by `try_default`.
    /// NOTE(review): presumably lets the run loop skip re-validating pass
    /// requirements — confirm.
    requirements_prevalidated: bool,
    /// Iteration cap; `try_default` sets it to [`DEFAULT_MAX_ITERATIONS`].
    max_iterations: usize,
    /// Lazily computed invalidation adjacency data.
    /// NOTE(review): the `u32` words look like per-pass bitsets — confirm.
    invalidation_adjacency_cache: OnceLock<Vec<u32>>,
    /// Lazily computed map from a pass name to the set of pass names it
    /// (transitively, judging by the name — confirm) invalidates.
    invalidation_closure_cache: OnceLock<FxHashMap<&'static str, FxHashSet<&'static str>>>,
    /// Lazily computed map from a name to the pass indices it marks dirty
    /// (assumed from the field name — confirm).
    dirty_trigger_index_cache: OnceLock<FxHashMap<&'static str, Vec<usize>>>,
    /// Lazily computed per-pass initial "dirty" flags (assumed indexed like
    /// `passes` — confirm).
    initial_dirty_flags_cache: OnceLock<Vec<bool>>,
    /// Cost-monotonicity gate; `false` by default. Pairs with
    /// `PassRunDecision::CostReverted`.
    enforce_cost_monotone: bool,
    /// Effect-handler gate; `false` by default. Pairs with
    /// `PassRunDecision::EffectReverted`.
    enforce_effect_handlers: bool,
    /// Linear-type gate; `false` by default. Pairs with
    /// `PassRunDecision::LinearTypeReverted`.
    enforce_linear_types: bool,
    /// Shape-predicate gate; `false` by default. Pairs with
    /// `PassRunDecision::ShapePredicateReverted`.
    enforce_shape_predicates: bool,
}
/// Outcome of an optimizer run: a program plus the per-pass metric records.
///
/// NOTE(review): presumably produced by the scheduler's run entry point in the
/// `run` submodule — confirm.
#[derive(Debug)]
pub struct OptimizerRunReport {
    /// The program returned by the run (presumably the optimized result — confirm).
    pub program: Program,
    /// Metric records for the passes considered during the run. Judging by the
    /// `*Skipped` decisions, these may include passes that did not execute — confirm.
    pub passes: Vec<PassRunMetric>,
}
/// Measurements recorded for one pass in one scheduler iteration.
///
/// Each `*_before` / `*_after` pair captures the same quantity on the program
/// around the pass, as the field names indicate.
/// NOTE(review): the values are filled in outside this file; confirm in `run`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PassRunMetric {
    /// Scheduler iteration this record belongs to (0- or 1-based — confirm).
    pub iteration: usize,
    /// Metadata name of the pass.
    pub pass: &'static str,
    /// Whether the pass body actually executed (presumably `false` for the
    /// `*Skipped` decisions — confirm).
    pub ran: bool,
    /// Whether the pass's change was kept.
    /// NOTE(review): confirm how this interacts with the `*Reverted` decisions.
    pub changed: bool,
    /// How the scheduler classified this pass invocation.
    pub decision: PassRunDecision,
    /// Short refusal tag, presumably set when `decision` is `Refused` — confirm.
    pub refusal_kind: Option<&'static str>,
    /// Effect bitmask before/after the pass (bit meanings defined elsewhere).
    pub effect_bits_before: u32,
    pub effect_bits_after: u32,
    /// Linear-type violation counts before/after the pass.
    pub linear_type_violations_before: usize,
    pub linear_type_violations_after: usize,
    /// Shape-predicate violation counts before/after the pass.
    pub shape_predicate_violations_before: usize,
    pub shape_predicate_violations_after: usize,
    /// Time spent on the pass, in nanoseconds (per the `_ns` suffix).
    pub runtime_ns: u128,
    /// IR node counts before/after the pass.
    pub nodes_before: usize,
    pub nodes_after: usize,
    /// Statically declared storage size, in bytes, before/after the pass.
    pub static_storage_bytes_before: u64,
    pub static_storage_bytes_after: u64,
    /// Instruction-count estimate before/after the pass.
    pub instruction_count_before: u64,
    pub instruction_count_after: u64,
    /// Memory-operation count before/after the pass.
    pub memory_op_count_before: u64,
    pub memory_op_count_after: u64,
    /// Atomic-operation count before/after the pass.
    pub atomic_op_count_before: u64,
    pub atomic_op_count_after: u64,
    /// Control-flow construct count before/after the pass.
    pub control_flow_count_before: u64,
    pub control_flow_count_after: u64,
    /// Register-pressure estimate before/after the pass.
    pub register_pressure_before: u32,
    pub register_pressure_after: u32,
    /// IR heap allocation count before/after the pass (presumably from
    /// `estimate_ir_allocations` — confirm).
    pub ir_heap_allocations_before: usize,
    pub ir_heap_allocations_after: usize,
    /// IR heap payload bytes before/after the pass (presumably from
    /// `estimate_ir_allocations` — confirm).
    pub ir_heap_bytes_before: usize,
    pub ir_heap_bytes_after: usize,
}
/// Classification of a single pass invocation by the scheduler.
///
/// NOTE(review): the per-variant descriptions below are inferred from the variant
/// names and the scheduler's `enforce_*` flags. Confirm against `run`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PassRunDecision {
    /// Skipped because the pass was not marked dirty (presumed).
    CleanSkipped,
    /// Skipped based on an analysis result (presumed).
    AnalysisSkipped,
    /// The pass ran but reported no change.
    RanUnchanged,
    /// The pass ran and its change was kept.
    Changed,
    /// Change rolled back by the cost-monotonicity gate.
    CostReverted,
    /// Change rolled back by the effect-handler gate.
    EffectReverted,
    /// Change rolled back by the linear-type gate.
    LinearTypeReverted,
    /// Change rolled back by the shape-predicate gate.
    ShapePredicateReverted,
    /// The pass refused to run; see `PassRunMetric::refusal_kind`.
    Refused,
}
/// Running tally of the heap allocations attributed to an IR tree.
///
/// Every update saturates, so an enormous program pins the counters at
/// `usize::MAX` instead of wrapping around.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
struct IrAllocationEstimate {
    /// Number of allocations counted so far.
    allocations: usize,
    /// Sum of the payload sizes, in bytes, of those allocations.
    bytes: usize,
}
impl IrAllocationEstimate {
    /// Counts one buffer holding `len` values of type `T`.
    ///
    /// Only `len` elements are charged (not any spare capacity). A zero-length
    /// container still counts as one allocation.
    fn add_container<T>(&mut self, len: usize) {
        let payload = std::mem::size_of::<T>().saturating_mul(len);
        self.charge(payload);
    }
    /// Counts one boxed `T`: a single allocation of `size_of::<T>()` bytes.
    fn add_box<T>(&mut self) {
        self.charge(std::mem::size_of::<T>());
    }
    /// Records a single allocation carrying `payload` bytes.
    fn charge(&mut self, payload: usize) {
        self.allocations = self.allocations.saturating_add(1);
        self.bytes = self.bytes.saturating_add(payload);
    }
}
/// Estimates how many heap allocations, and how many payload bytes, the IR of
/// `program` occupies.
///
/// The buffer-declaration list and the entry node list are each charged as one
/// container. Two further allocations with no byte payload are charged for the
/// program itself. Finally, every entry node is walked recursively.
fn estimate_ir_allocations(program: &Program) -> IrAllocationEstimate {
    let mut tally = IrAllocationEstimate::default();
    tally.add_container::<BufferDecl>(program.buffers().len());
    let entry = program.entry();
    tally.add_container::<Node>(entry.len());
    // NOTE(review): two fixed allocations with no byte cost, presumably the
    // program's own shared wrappers — confirm against `Program`'s layout.
    tally.allocations = tally.allocations.saturating_add(2);
    for top_level in entry {
        estimate_node_allocations(top_level, &mut tally);
    }
    tally
}
/// Adds the heap allocations owned by `node`, and everything nested inside it,
/// to `estimate`.
///
/// Each nested statement list is charged as one `Node` container. Operands the
/// estimator treats as boxed are charged as one `Expr` box each. Every
/// contained expression and child node is then visited recursively, in source
/// order. An opaque node counts as one allocation of unknown size.
fn estimate_node_allocations(node: &Node, estimate: &mut IrAllocationEstimate) {
    match node {
        // Statements that own nothing this estimator tracks.
        Node::IndirectDispatch { .. }
        | Node::AllReduce { .. }
        | Node::AllGather { .. }
        | Node::ReduceScatter { .. }
        | Node::Broadcast { .. }
        | Node::AsyncWait { .. }
        | Node::Resume { .. }
        | Node::Return
        | Node::Barrier { .. } => {}
        Node::Opaque(_) => {
            // Size unknown: count the allocation, charge no bytes.
            estimate.allocations = estimate.allocations.saturating_add(1);
        }
        Node::Let { value: rhs, .. } | Node::Assign { value: rhs, .. } => {
            estimate_expr_allocations(rhs, estimate);
        }
        Node::Store {
            index: slot,
            value: stored,
            ..
        } => {
            estimate_expr_allocations(slot, estimate);
            estimate_expr_allocations(stored, estimate);
        }
        Node::Trap { address, .. } => {
            estimate.add_box::<Expr>();
            estimate_expr_allocations(address, estimate);
        }
        Node::AsyncLoad {
            offset: start,
            size: extent,
            ..
        }
        | Node::AsyncStore {
            offset: start,
            size: extent,
            ..
        } => {
            // Charged as two boxed operands: the offset and the size.
            estimate.add_box::<Expr>();
            estimate.add_box::<Expr>();
            estimate_expr_allocations(start, estimate);
            estimate_expr_allocations(extent, estimate);
        }
        Node::If {
            cond: guard,
            then: then_branch,
            otherwise: else_branch,
        } => {
            estimate_expr_allocations(guard, estimate);
            estimate.add_container::<Node>(then_branch.len());
            estimate.add_container::<Node>(else_branch.len());
            // The `then` branch is walked first, then the `otherwise` branch.
            for child in then_branch.iter() {
                estimate_node_allocations(child, estimate);
            }
            for child in else_branch.iter() {
                estimate_node_allocations(child, estimate);
            }
        }
        Node::Loop {
            from: lower,
            to: upper,
            body,
            ..
        } => {
            estimate_expr_allocations(lower, estimate);
            estimate_expr_allocations(upper, estimate);
            estimate.add_container::<Node>(body.len());
            for child in body {
                estimate_node_allocations(child, estimate);
            }
        }
        Node::Block(body) => {
            estimate.add_container::<Node>(body.len());
            for child in body {
                estimate_node_allocations(child, estimate);
            }
        }
        Node::Region { body, .. } => {
            estimate.add_container::<Node>(body.len());
            for child in body.iter() {
                estimate_node_allocations(child, estimate);
            }
        }
    }
}
fn estimate_expr_allocations(expr: &Expr, estimate: &mut IrAllocationEstimate) {
match expr {
Expr::Load { index, .. } => {
estimate.add_box::<Expr>();
estimate_expr_allocations(index, estimate);
}
Expr::BinOp { left, right, .. } => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(left, estimate);
estimate_expr_allocations(right, estimate);
}
Expr::UnOp { operand, .. }
| Expr::Cast { value: operand, .. }
| Expr::SubgroupBallot { cond: operand }
| Expr::SubgroupAdd { value: operand } => {
estimate.add_box::<Expr>();
estimate_expr_allocations(operand, estimate);
}
Expr::Call { args, .. } => {
estimate.add_container::<Expr>(args.len());
for arg in args {
estimate_expr_allocations(arg, estimate);
}
}
Expr::Select {
cond,
true_val,
false_val,
} => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(cond, estimate);
estimate_expr_allocations(true_val, estimate);
estimate_expr_allocations(false_val, estimate);
}
Expr::Fma { a, b, c } => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(a, estimate);
estimate_expr_allocations(b, estimate);
estimate_expr_allocations(c, estimate);
}
Expr::Atomic {
index,
expected,
value,
..
} => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(index, estimate);
if let Some(expected) = expected {
estimate.add_box::<Expr>();
estimate_expr_allocations(expected, estimate);
}
estimate_expr_allocations(value, estimate);
}
Expr::SubgroupShuffle { value, lane } => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(value, estimate);
estimate_expr_allocations(lane, estimate);
}
Expr::Opaque(_) => {
estimate.allocations = estimate.allocations.saturating_add(1);
}
Expr::LitU32(_)
| Expr::LitI32(_)
| Expr::LitF32(_)
| Expr::LitBool(_)
| Expr::Var(_)
| Expr::BufLen { .. }
| Expr::InvocationId { .. }
| Expr::WorkgroupId { .. }
| Expr::LocalId { .. }
| Expr::SubgroupLocalId
| Expr::SubgroupSize => {}
}
}
impl PassScheduler {
pub fn try_default() -> Result<Self, OptimizerError> {
let passes = registered_passes()?;
let pass_index = passes
.iter()
.enumerate()
.map(|(i, pass)| (pass.metadata().name, i))
.collect();
let execution_order = (0..passes.len()).collect();
Ok(Self {
passes,
pass_index,
execution_order,
requirements_prevalidated: true,
max_iterations: DEFAULT_MAX_ITERATIONS,
invalidation_adjacency_cache: OnceLock::new(),
invalidation_closure_cache: OnceLock::new(),
dirty_trigger_index_cache: OnceLock::new(),
initial_dirty_flags_cache: OnceLock::new(),
enforce_cost_monotone: false,
enforce_effect_handlers: false,
enforce_linear_types: false,
enforce_shape_predicates: false,
})
}
#[must_use]
pub fn with_cost_monotone_enforcement(mut self, enforce: bool) -> Self {
self.enforce_cost_monotone = enforce;
self
}
#[must_use]
pub fn cost_monotone_enforcement(&self) -> bool {
self.enforce_cost_monotone
}
#[must_use]
pub fn with_effect_handler_enforcement(mut self, enforce: bool) -> Self {
self.enforce_effect_handlers = enforce;
self
}
#[must_use]
pub fn effect_handler_enforcement(&self) -> bool {
self.enforce_effect_handlers
}
#[must_use]
pub fn with_linear_type_enforcement(mut self, enforce: bool) -> Self {
self.enforce_linear_types = enforce;
self
}
#[must_use]
pub fn linear_type_enforcement(&self) -> bool {
self.enforce_linear_types
}
#[must_use]
pub fn with_shape_predicate_enforcement(mut self, enforce: bool) -> Self {
self.enforce_shape_predicates = enforce;
self
}
#[must_use]
pub fn shape_predicate_enforcement(&self) -> bool {
self.enforce_shape_predicates
}
}
impl Default for PassScheduler {
    /// Builds the scheduler via [`PassScheduler::try_default`].
    ///
    /// If that fails, the error is logged at `error` level and a scheduler with
    /// no passes is returned instead of panicking.
    fn default() -> Self {
        Self::try_default().unwrap_or_else(|error| {
            tracing::error!(
                error = %error,
                "Fix: built-in optimizer pass metadata is invalid; defaulting to an empty scheduler."
            );
            Self::with_passes(Vec::new())
        })
    }
}
mod topo;
mod queries;
mod run;
pub(crate) use topo::schedule_pass_metadata_indices;
pub use topo::{schedule_passes, PassSchedulingError};
#[cfg(test)]
mod tests;