#![allow(clippy::expect_used)]
use crate::ir::{BufferDecl, Expr, Node};
use crate::ir_inner::model::program::Program;
use crate::optimizer::{registered_passes, OptimizerError, ProgramPassKind};
use rustc_hash::FxHashMap;
#[cfg(test)]
use rustc_hash::FxHashSet;
use std::sync::OnceLock;
/// Value that `PassScheduler::try_default` stores in the scheduler's
/// `max_iterations` field.
pub(super) const DEFAULT_MAX_ITERATIONS: usize = 50;
/// Holds the optimizer passes together with the bookkeeping kept alongside
/// them (name lookup, execution order, iteration cap, caches, and a
/// cost-monotonicity switch).
///
/// See `PassScheduler::try_default` for the initial value of every field.
pub struct PassScheduler {
/// Passes as returned by `registered_passes()`.
passes: Vec<ProgramPassKind>,
/// Maps each pass's metadata name to its position in `passes`.
pass_index: FxHashMap<&'static str, usize>,
/// Indices into `passes`; initialised to the identity order `0..passes.len()`.
execution_order: Vec<usize>,
/// Initialised to `DEFAULT_MAX_ITERATIONS`.
/// NOTE(review): presumably caps the fixed-point loop in the `run` submodule; confirm.
max_iterations: usize,
/// Lazily initialised cache, empty after construction.
/// NOTE(review): contents are computed outside this chunk; presumably the
/// direct pass-invalidation relation; confirm.
invalidation_adjacency_cache: OnceLock<Vec<u32>>,
/// Lazily initialised cache, empty after construction.
/// NOTE(review): contents are computed outside this chunk; presumably the
/// transitive closure of the invalidation relation; confirm.
invalidation_closure_cache: OnceLock<Vec<u32>>,
/// Flag set through `with_cost_monotone_enforcement` and read through
/// `cost_monotone_enforcement`; `false` after `try_default`.
/// NOTE(review): how the flag affects a run is decided outside this chunk.
enforce_cost_monotone: bool,
}
/// Bundles a [`Program`] with a list of per-pass [`PassRunMetric`] records.
///
/// NOTE(review): values of this type are built outside this chunk; presumably
/// `program` is the optimized output and `passes` is in execution order; confirm.
#[derive(Debug)]
pub struct OptimizerRunReport {
/// The program carried by this report.
pub program: Program,
/// Metric records, one `PassRunMetric` per entry.
pub passes: Vec<PassRunMetric>,
}
/// Measurement record for one pass, made of identifying fields plus paired
/// `*_before` / `*_after` snapshots.
///
/// NOTE(review): the code that fills these fields is outside this chunk
/// (presumably the `run` submodule). The per-field notes below follow the
/// field names and types only; confirm them against the populating code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PassRunMetric {
/// Scheduler iteration this record belongs to (base index not visible here).
pub iteration: usize,
/// Static name identifying the pass.
pub pass: &'static str,
/// Presumably whether the pass was executed rather than skipped; confirm.
pub ran: bool,
/// Presumably whether the pass reported a modification to the program; confirm.
pub changed: bool,
/// Elapsed time of the pass; nanoseconds going by the field name.
pub runtime_ns: u128,
/// Node count snapshot taken before the pass.
pub nodes_before: usize,
/// Node count snapshot taken after the pass.
pub nodes_after: usize,
/// Static storage size snapshot (bytes, by name) before the pass.
pub static_storage_bytes_before: u64,
/// Static storage size snapshot (bytes, by name) after the pass.
pub static_storage_bytes_after: u64,
/// Instruction count snapshot before the pass.
pub instruction_count_before: u64,
/// Instruction count snapshot after the pass.
pub instruction_count_after: u64,
/// Memory-operation count snapshot before the pass.
pub memory_op_count_before: u64,
/// Memory-operation count snapshot after the pass.
pub memory_op_count_after: u64,
/// Atomic-operation count snapshot before the pass.
pub atomic_op_count_before: u64,
/// Atomic-operation count snapshot after the pass.
pub atomic_op_count_after: u64,
/// Control-flow count snapshot before the pass.
pub control_flow_count_before: u64,
/// Control-flow count snapshot after the pass.
pub control_flow_count_after: u64,
/// Register-pressure snapshot before the pass.
pub register_pressure_before: u32,
/// Register-pressure snapshot after the pass.
pub register_pressure_after: u32,
/// IR heap allocation count before the pass.
/// NOTE(review): presumably `IrAllocationEstimate::allocations`; confirm.
pub ir_heap_allocations_before: usize,
/// IR heap allocation count after the pass.
pub ir_heap_allocations_after: usize,
/// IR heap byte total before the pass.
/// NOTE(review): presumably `IrAllocationEstimate::bytes`; confirm.
pub ir_heap_bytes_before: usize,
/// IR heap byte total after the pass.
pub ir_heap_bytes_after: usize,
}
/// Running tally of heap allocations attributed to an IR tree.
///
/// The helper methods below only ever increase the counters and saturate at
/// `usize::MAX` instead of overflowing.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
struct IrAllocationEstimate {
    /// Number of distinct heap allocations counted so far.
    allocations: usize,
    /// Total payload size, in bytes, of the counted allocations.
    bytes: usize,
}

impl IrAllocationEstimate {
    /// Records one backing buffer holding `len` elements of type `T`.
    ///
    /// The allocation itself is counted even when `len` is zero; in that case
    /// only the byte total stays unchanged.
    fn add_container<T>(&mut self, len: usize) {
        let element_size = std::mem::size_of::<T>();
        let payload = element_size.saturating_mul(len);
        self.bytes = self.bytes.saturating_add(payload);
        self.allocations = self.allocations.saturating_add(1);
    }

    /// Records a single boxed `T`: one allocation of `size_of::<T>()` bytes.
    fn add_box<T>(&mut self) {
        // A box is accounted for exactly like a one-element container.
        self.add_container::<T>(1);
    }
}
/// Estimates the heap allocations backing `program`'s IR.
///
/// Counts one container for the buffer declarations and one for the entry
/// node list, walks every entry node recursively through
/// `estimate_node_allocations`, and adds a fixed two extra allocations.
fn estimate_ir_allocations(program: &Program) -> IrAllocationEstimate {
    let mut totals = IrAllocationEstimate::default();

    // Top-level containers owned by the program.
    totals.add_container::<Node>(program.entry().len());
    totals.add_container::<BufferDecl>(program.buffers().len());

    // Everything reachable from the entry nodes.
    for entry_node in program.entry() {
        estimate_node_allocations(entry_node, &mut totals);
    }

    // NOTE(review): fixed overhead of two allocations with no byte cost; the
    // reason is not visible in this file (presumably shared wrappers inside
    // `Program`); confirm.
    totals.allocations = totals.allocations.saturating_add(2);

    totals
}
/// Adds the heap allocations owned by `node`, including everything nested
/// beneath it, to `estimate`.
///
/// Child node lists are counted as one container each, boxed expressions as
/// one box each, and opaque nodes as a single allocation with no byte cost.
fn estimate_node_allocations(node: &Node, estimate: &mut IrAllocationEstimate) {
    // Accounts for an expression that lives behind its own box, then descends
    // into the expression itself.
    fn boxed_expr(expr: &Expr, estimate: &mut IrAllocationEstimate) {
        estimate.add_box::<Expr>();
        estimate_expr_allocations(expr, estimate);
    }

    match node {
        // Nodes that own no heap data tracked by this estimate.
        Node::IndirectDispatch { .. }
        | Node::AsyncWait { .. }
        | Node::Resume { .. }
        | Node::Return
        | Node::Barrier { .. } => {}

        Node::Opaque(_) => {
            estimate.allocations = estimate.allocations.saturating_add(1);
        }

        Node::Let { value, .. } | Node::Assign { value, .. } => {
            estimate_expr_allocations(value, estimate);
        }

        Node::Store { index, value, .. } => {
            estimate_expr_allocations(index, estimate);
            estimate_expr_allocations(value, estimate);
        }

        Node::AsyncLoad { offset, size, .. } | Node::AsyncStore { offset, size, .. } => {
            boxed_expr(offset, estimate);
            boxed_expr(size, estimate);
        }

        Node::Trap { address, .. } => {
            boxed_expr(address, estimate);
        }

        Node::If {
            cond,
            then,
            otherwise,
        } => {
            estimate_expr_allocations(cond, estimate);

            estimate.add_container::<Node>(then.len());
            for child in then.iter() {
                estimate_node_allocations(child, estimate);
            }

            estimate.add_container::<Node>(otherwise.len());
            for child in otherwise.iter() {
                estimate_node_allocations(child, estimate);
            }
        }

        Node::Loop { from, to, body, .. } => {
            estimate_expr_allocations(from, estimate);
            estimate_expr_allocations(to, estimate);
            estimate.add_container::<Node>(body.len());
            for child in body {
                estimate_node_allocations(child, estimate);
            }
        }

        Node::Block(body) => {
            estimate.add_container::<Node>(body.len());
            for child in body {
                estimate_node_allocations(child, estimate);
            }
        }

        Node::Region { body, .. } => {
            estimate.add_container::<Node>(body.len());
            for child in body.iter() {
                estimate_node_allocations(child, estimate);
            }
        }
    }
}
fn estimate_expr_allocations(expr: &Expr, estimate: &mut IrAllocationEstimate) {
match expr {
Expr::Load { index, .. } => {
estimate.add_box::<Expr>();
estimate_expr_allocations(index, estimate);
}
Expr::BinOp { left, right, .. } => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(left, estimate);
estimate_expr_allocations(right, estimate);
}
Expr::UnOp { operand, .. }
| Expr::Cast { value: operand, .. }
| Expr::SubgroupBallot { cond: operand }
| Expr::SubgroupAdd { value: operand } => {
estimate.add_box::<Expr>();
estimate_expr_allocations(operand, estimate);
}
Expr::Call { args, .. } => {
estimate.add_container::<Expr>(args.len());
for arg in args {
estimate_expr_allocations(arg, estimate);
}
}
Expr::Select {
cond,
true_val,
false_val,
} => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(cond, estimate);
estimate_expr_allocations(true_val, estimate);
estimate_expr_allocations(false_val, estimate);
}
Expr::Fma { a, b, c } => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(a, estimate);
estimate_expr_allocations(b, estimate);
estimate_expr_allocations(c, estimate);
}
Expr::Atomic {
index,
expected,
value,
..
} => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(index, estimate);
if let Some(expected) = expected {
estimate.add_box::<Expr>();
estimate_expr_allocations(expected, estimate);
}
estimate_expr_allocations(value, estimate);
}
Expr::SubgroupShuffle { value, lane } => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(value, estimate);
estimate_expr_allocations(lane, estimate);
}
Expr::Opaque(_) => {
estimate.allocations = estimate.allocations.saturating_add(1);
}
Expr::LitU32(_)
| Expr::LitI32(_)
| Expr::LitF32(_)
| Expr::LitBool(_)
| Expr::Var(_)
| Expr::BufLen { .. }
| Expr::InvocationId { .. }
| Expr::WorkgroupId { .. }
| Expr::LocalId { .. }
| Expr::SubgroupLocalId
| Expr::SubgroupSize => {}
}
}
impl PassScheduler {
pub fn try_default() -> Result<Self, OptimizerError> {
let passes = registered_passes()?;
let pass_index = passes
.iter()
.enumerate()
.map(|(i, pass)| (pass.metadata().name, i))
.collect();
let execution_order = (0..passes.len()).collect();
Ok(Self {
passes,
pass_index,
execution_order,
max_iterations: DEFAULT_MAX_ITERATIONS,
invalidation_adjacency_cache: OnceLock::new(),
invalidation_closure_cache: OnceLock::new(),
enforce_cost_monotone: false,
})
}
#[must_use]
pub fn with_cost_monotone_enforcement(mut self, enforce: bool) -> Self {
self.enforce_cost_monotone = enforce;
self
}
#[must_use]
pub fn cost_monotone_enforcement(&self) -> bool {
self.enforce_cost_monotone
}
}
impl Default for PassScheduler {
    /// Infallible counterpart of [`PassScheduler::try_default`].
    ///
    /// # Panics
    ///
    /// Panics when `try_default` returns an error, including that error in
    /// the panic message.
    fn default() -> Self {
        match Self::try_default() {
            Ok(scheduler) => scheduler,
            Err(error) => panic!(
                "Fix: built-in optimizer pass metadata is invalid; this is a vyre-foundation bug: {error}"
            ),
        }
    }
}
// Private submodules of the scheduler. Only `schedule_passes` and
// `PassSchedulingError` from `topo` are re-exported from this file.
mod topo;
mod queries;
mod run;
pub use topo::{schedule_passes, PassSchedulingError};
// Compiled only in test builds.
#[cfg(test)]
mod tests;