#![allow(clippy::expect_used)]
use crate::ir::{BufferDecl, Expr, Node};
use crate::ir_inner::model::program::Program;
use crate::optimizer::{registered_passes, OptimizerError, ProgramPassKind};
use rustc_hash::{FxHashMap, FxHashSet};
use std::sync::OnceLock;
/// Iteration cap that `PassScheduler::try_default` stores in the scheduler's
/// `max_iterations` field.
pub(super) const DEFAULT_MAX_ITERATIONS: usize = 50;
/// Owns the optimizer passes together with scheduling state, write-once
/// caches, and opt-in enforcement flags.
///
/// `PassScheduler::try_default` builds one from `registered_passes()`; the
/// `with_*_enforcement` builder methods toggle the `enforce_*` flags.
pub struct PassScheduler {
/// The passes themselves; `try_default` fills this from `registered_passes()`.
passes: Vec<ProgramPassKind>,
/// Maps each pass's metadata name to its position in `passes`.
pass_index: FxHashMap<&'static str, usize>,
/// Research traces keyed by a static string.
// NOTE(review): presumably keyed by pass name — confirm against `research_trace_for`.
research_traces: FxHashMap<&'static str, PassResearchTrace>,
/// `try_default` initialises this to `0..passes.len()`.
// NOTE(review): presumably indices into `passes` in the order they run — confirm.
execution_order: Vec<usize>,
/// Set to `true` by `try_default`.
// NOTE(review): presumably records that pass requirements were already validated — confirm where it is read.
requirements_prevalidated: bool,
/// Iteration cap; `try_default` uses `DEFAULT_MAX_ITERATIONS`.
max_iterations: usize,
// The four `OnceLock` caches below are created empty by `try_default`.
// What populates them is not visible in this part of the file.
invalidation_adjacency_cache: OnceLock<Vec<u32>>,
invalidation_closure_cache: OnceLock<FxHashMap<&'static str, FxHashSet<&'static str>>>,
dirty_trigger_index_cache: OnceLock<FxHashMap<&'static str, Vec<usize>>>,
initial_dirty_flags_cache: OnceLock<Vec<bool>>,
/// Set by `with_cost_monotone_enforcement`; `false` after `try_default`.
enforce_cost_monotone: bool,
/// Set by `with_effect_handler_enforcement`; `false` after `try_default`.
enforce_effect_handlers: bool,
/// Set by `with_linear_type_enforcement`; `false` after `try_default`.
enforce_linear_types: bool,
/// Set by `with_shape_predicate_enforcement`; `false` after `try_default`.
enforce_shape_predicates: bool,
}
/// Schema version stamped onto every trace built by
/// [`PassResearchTrace::try_new`] and required by
/// [`PassResearchTrace::is_complete`].
pub(crate) const PASS_RESEARCH_TRACE_SCHEMA_VERSION: u32 = 1;

/// A schema version plus three static identifiers: research basis key,
/// baseline id, and proof artifact id.
///
/// Carried optionally by `PassRunMetric::research_trace`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PassResearchTrace {
    /// Version of the trace layout; `try_new` always writes
    /// `PASS_RESEARCH_TRACE_SCHEMA_VERSION`.
    pub schema_version: u32,
    /// Research basis identifier; must be non-blank for a complete trace.
    pub research_basis_key: &'static str,
    /// Baseline identifier; must be non-blank for a complete trace.
    pub baseline_id: &'static str,
    /// Proof artifact identifier; must be non-blank for a complete trace.
    pub proof_artifact_id: &'static str,
}

impl PassResearchTrace {
    /// Builds a trace stamped with the current schema version.
    ///
    /// # Errors
    ///
    /// Arguments are checked in parameter order; the first one that is empty
    /// or whitespace-only yields its matching `PassResearchTraceError`
    /// variant.
    pub fn try_new(
        research_basis_key: &'static str,
        baseline_id: &'static str,
        proof_artifact_id: &'static str,
    ) -> Result<Self, PassResearchTraceError> {
        // Declaration order matters: `find` reports the first blank entry.
        let required = [
            (
                research_basis_key,
                PassResearchTraceError::MissingResearchBasisKey,
            ),
            (baseline_id, PassResearchTraceError::MissingBaselineId),
            (
                proof_artifact_id,
                PassResearchTraceError::MissingProofArtifactId,
            ),
        ];
        match required.iter().find(|(text, _)| text.trim().is_empty()) {
            Some(&(_, missing)) => Err(missing),
            None => Ok(Self {
                schema_version: PASS_RESEARCH_TRACE_SCHEMA_VERSION,
                research_basis_key,
                baseline_id,
                proof_artifact_id,
            }),
        }
    }

    /// Returns `true` when the schema version is current and none of the
    /// three identifiers is empty or whitespace-only.
    ///
    /// The fields are public, so a value built by struct literal can fail
    /// this check even though `try_new` never produces such a value.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        if self.schema_version != PASS_RESEARCH_TRACE_SCHEMA_VERSION {
            return false;
        }
        [
            self.research_basis_key,
            self.baseline_id,
            self.proof_artifact_id,
        ]
        .iter()
        .all(|identifier| !identifier.trim().is_empty())
    }
}

/// Reason [`PassResearchTrace::try_new`] rejected its arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PassResearchTraceError {
    /// `research_basis_key` was empty or whitespace-only.
    MissingResearchBasisKey,
    /// `baseline_id` was empty or whitespace-only.
    MissingBaselineId,
    /// `proof_artifact_id` was empty or whitespace-only.
    MissingProofArtifactId,
}

impl std::fmt::Display for PassResearchTraceError {
    /// Writes a fixed message naming the missing field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::MissingResearchBasisKey => "pass research trace is missing research_basis_key",
            Self::MissingBaselineId => "pass research trace is missing baseline_id",
            Self::MissingProofArtifactId => "pass research trace is missing proof_artifact_id",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PassResearchTraceError {}
/// Pairs a `Program` with the list of per-pass metrics recorded for it.
// NOTE(review): presumably the result of a scheduler run, with `program`
// being the optimized output — confirm in the run module.
#[derive(Debug)]
pub struct OptimizerRunReport {
/// Program carried by the report.
pub program: Program,
/// Recorded pass metrics, one `PassRunMetric` per entry.
pub passes: Vec<PassRunMetric>,
}
/// Measurements recorded for one pass entry in an `OptimizerRunReport`.
///
/// Most counters come in `*_before` / `*_after` pairs.
// NOTE(review): the code that fills these fields is outside this section, so
// the notes below describe only what names and types show; confirm details
// against the run module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PassRunMetric {
// Identity of the record. `pass` is presumably the pass's metadata name and
// `iteration` the scheduler iteration it belongs to — confirm.
pub iteration: usize,
pub pass: &'static str,
/// Research trace attached to this record, if any.
pub research_trace: Option<PassResearchTrace>,
// Outcome flags and the classified decision (see `PassRunDecision`).
pub ran: bool,
pub changed: bool,
pub decision: PassRunDecision,
/// Static label for a refusal, when one applies.
// NOTE(review): presumably only `Some` when `decision` is `Refused` — confirm.
pub refusal_kind: Option<&'static str>,
// Static lists of analysis names associated with the pass.
pub required_analyses: &'static [&'static str],
pub declared_invalidations: &'static [&'static str],
// Fact-substrate bookkeeping flags; semantics are defined by the code that sets them.
pub fact_substrate_reused: bool,
pub fact_substrate_recomputed: bool,
pub fact_substrate_invalidated: bool,
// Before/after snapshots related to the scheduler's optional enforcement
// checks (effects, linear types, shape predicates) — presumably; confirm.
pub effect_bits_before: u32,
pub effect_bits_after: u32,
pub linear_type_violations_before: usize,
pub linear_type_violations_after: usize,
pub shape_predicate_violations_before: usize,
pub shape_predicate_violations_after: usize,
/// Elapsed time in nanoseconds.
pub runtime_ns: u128,
// Before/after program size and cost counters.
pub nodes_before: usize,
pub nodes_after: usize,
pub static_storage_bytes_before: u64,
pub static_storage_bytes_after: u64,
pub instruction_count_before: u64,
pub instruction_count_after: u64,
pub memory_op_count_before: u64,
pub memory_op_count_after: u64,
pub atomic_op_count_before: u64,
pub atomic_op_count_after: u64,
pub control_flow_count_before: u64,
pub control_flow_count_after: u64,
pub register_pressure_before: u32,
pub register_pressure_after: u32,
// Before/after IR heap figures.
// NOTE(review): presumably taken from `estimate_ir_allocations` — confirm.
pub ir_heap_allocations_before: usize,
pub ir_heap_allocations_after: usize,
pub ir_heap_bytes_before: usize,
pub ir_heap_bytes_after: usize,
}
/// Classification stored in `PassRunMetric::decision`.
// NOTE(review): the conditions that select each variant are decided outside
// this section; the notes below are inferred from names and should be
// confirmed against the run module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PassRunDecision {
// Skip outcomes — presumably the pass was not executed.
CleanSkipped,
AnalysisSkipped,
// Presumably: the pass executed and did / did not modify the program.
RanUnchanged,
Changed,
// Presumably: the pass's result was rolled back by the matching optional
// enforcement check (cost monotonicity, effect handlers, linear types,
// shape predicates).
CostReverted,
EffectReverted,
LinearTypeReverted,
ShapePredicateReverted,
// Presumably paired with `PassRunMetric::refusal_kind`.
Refused,
}
/// Running tally of estimated IR heap usage: an allocation count and a byte
/// total.
///
/// Both counters use saturating arithmetic, so they clamp at `usize::MAX`
/// instead of overflowing.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
struct IrAllocationEstimate {
    /// Number of allocations counted so far.
    allocations: usize,
    /// Number of payload bytes counted so far.
    bytes: usize,
}

impl IrAllocationEstimate {
    /// Counts one allocation holding `len` elements of `T`.
    ///
    /// The allocation is counted even when `len` is zero; only the byte total
    /// depends on `len`.
    fn add_container<T>(&mut self, len: usize) {
        let payload = std::mem::size_of::<T>().saturating_mul(len);
        self.allocations = self.allocations.saturating_add(1);
        self.bytes = self.bytes.saturating_add(payload);
    }

    /// Counts one allocation holding a single `T`.
    fn add_box<T>(&mut self) {
        // A box is accounted exactly like a one-element container.
        self.add_container::<T>(1);
    }
}
/// Estimates the heap allocations held by `program`'s IR.
///
/// Counts the buffer-declaration and entry-node containers, then walks every
/// entry node with `estimate_node_allocations`.
fn estimate_ir_allocations(program: &Program) -> IrAllocationEstimate {
    let mut totals = IrAllocationEstimate::default();

    // Top-level containers: buffer declarations and the entry node list.
    totals.add_container::<BufferDecl>(program.buffers().len());
    totals.add_container::<Node>(program.entry().len());

    // Two further allocations are counted without any byte contribution.
    // NOTE(review): presumably shared wrappers around the two containers
    // above — confirm against `Program`'s storage layout.
    totals.allocations = totals.allocations.saturating_add(2);

    for top_level in program.entry() {
        estimate_node_allocations(top_level, &mut totals);
    }
    totals
}
/// Adds the estimated heap allocations owned by `node` to `estimate`,
/// recursing into nested node lists and expressions.
///
/// All updates are saturating additions, so the final totals do not depend on
/// the order in which children are visited.
fn estimate_node_allocations(node: &Node, estimate: &mut IrAllocationEstimate) {
    /// Accounts for an expression counted as its own box: one `Expr`-sized
    /// allocation plus whatever the expression itself owns.
    fn boxed_expr(expr: &Expr, tally: &mut IrAllocationEstimate) {
        tally.add_box::<Expr>();
        estimate_expr_allocations(expr, tally);
    }

    match node {
        // Expressions here are counted without a box of their own.
        Node::Let { value, .. } | Node::Assign { value, .. } => {
            estimate_expr_allocations(value, estimate);
        }
        Node::Store { index, value, .. } => {
            estimate_expr_allocations(index, estimate);
            estimate_expr_allocations(value, estimate);
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            estimate_expr_allocations(cond, estimate);
            // Each branch is its own node container.
            estimate.add_container::<Node>(then.len());
            for child in then.iter() {
                estimate_node_allocations(child, estimate);
            }
            estimate.add_container::<Node>(otherwise.len());
            for child in otherwise.iter() {
                estimate_node_allocations(child, estimate);
            }
        }
        Node::Loop { from, to, body, .. } => {
            estimate_expr_allocations(from, estimate);
            estimate_expr_allocations(to, estimate);
            estimate.add_container::<Node>(body.len());
            for child in body {
                estimate_node_allocations(child, estimate);
            }
        }
        Node::AsyncLoad { offset, size, .. } | Node::AsyncStore { offset, size, .. } => {
            boxed_expr(offset, estimate);
            boxed_expr(size, estimate);
        }
        Node::Trap { address, .. } => {
            boxed_expr(address, estimate);
        }
        Node::Block(body) => {
            estimate.add_container::<Node>(body.len());
            for child in body {
                estimate_node_allocations(child, estimate);
            }
        }
        Node::Region { body, .. } => {
            estimate.add_container::<Node>(body.len());
            for child in body.iter() {
                estimate_node_allocations(child, estimate);
            }
        }
        // Opaque nodes count as one allocation of unknown size.
        Node::Opaque(_) => {
            estimate.allocations = estimate.allocations.saturating_add(1);
        }
        // These variants contribute nothing to the estimate.
        Node::IndirectDispatch { .. }
        | Node::AllReduce { .. }
        | Node::AllGather { .. }
        | Node::ReduceScatter { .. }
        | Node::Broadcast { .. }
        | Node::AsyncWait { .. }
        | Node::Resume { .. }
        | Node::Return
        | Node::Barrier { .. } => {}
    }
}
fn estimate_expr_allocations(expr: &Expr, estimate: &mut IrAllocationEstimate) {
match expr {
Expr::Load { index, .. } => {
estimate.add_box::<Expr>();
estimate_expr_allocations(index, estimate);
}
Expr::BinOp { left, right, .. } => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(left, estimate);
estimate_expr_allocations(right, estimate);
}
Expr::UnOp { operand, .. }
| Expr::Cast { value: operand, .. }
| Expr::SubgroupBallot { cond: operand }
| Expr::SubgroupAdd { value: operand } => {
estimate.add_box::<Expr>();
estimate_expr_allocations(operand, estimate);
}
Expr::Call { args, .. } => {
estimate.add_container::<Expr>(args.len());
for arg in args {
estimate_expr_allocations(arg, estimate);
}
}
Expr::Select {
cond,
true_val,
false_val,
} => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(cond, estimate);
estimate_expr_allocations(true_val, estimate);
estimate_expr_allocations(false_val, estimate);
}
Expr::Fma { a, b, c } => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(a, estimate);
estimate_expr_allocations(b, estimate);
estimate_expr_allocations(c, estimate);
}
Expr::Atomic {
index,
expected,
value,
..
} => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(index, estimate);
if let Some(expected) = expected {
estimate.add_box::<Expr>();
estimate_expr_allocations(expected, estimate);
}
estimate_expr_allocations(value, estimate);
}
Expr::SubgroupShuffle { value, lane } => {
estimate.add_box::<Expr>();
estimate.add_box::<Expr>();
estimate_expr_allocations(value, estimate);
estimate_expr_allocations(lane, estimate);
}
Expr::Opaque(_) => {
estimate.allocations = estimate.allocations.saturating_add(1);
}
Expr::LitU32(_)
| Expr::LitI32(_)
| Expr::LitF32(_)
| Expr::LitBool(_)
| Expr::Var(_)
| Expr::BufLen { .. }
| Expr::InvocationId { .. }
| Expr::WorkgroupId { .. }
| Expr::LocalId { .. }
| Expr::SubgroupLocalId
| Expr::SubgroupSize => {}
}
}
impl PassScheduler {
pub fn try_default() -> Result<Self, OptimizerError> {
let passes = registered_passes()?;
let pass_index = passes
.iter()
.enumerate()
.map(|(i, pass)| (pass.metadata().name, i))
.collect();
let execution_order = (0..passes.len()).collect();
Ok(Self {
passes,
pass_index,
research_traces: FxHashMap::default(),
execution_order,
requirements_prevalidated: true,
max_iterations: DEFAULT_MAX_ITERATIONS,
invalidation_adjacency_cache: OnceLock::new(),
invalidation_closure_cache: OnceLock::new(),
dirty_trigger_index_cache: OnceLock::new(),
initial_dirty_flags_cache: OnceLock::new(),
enforce_cost_monotone: false,
enforce_effect_handlers: false,
enforce_linear_types: false,
enforce_shape_predicates: false,
})
}
#[must_use]
pub fn with_cost_monotone_enforcement(mut self, enforce: bool) -> Self {
self.enforce_cost_monotone = enforce;
self
}
#[must_use]
pub fn cost_monotone_enforcement(&self) -> bool {
self.enforce_cost_monotone
}
#[must_use]
pub fn with_effect_handler_enforcement(mut self, enforce: bool) -> Self {
self.enforce_effect_handlers = enforce;
self
}
#[must_use]
pub fn effect_handler_enforcement(&self) -> bool {
self.enforce_effect_handlers
}
#[must_use]
pub fn with_linear_type_enforcement(mut self, enforce: bool) -> Self {
self.enforce_linear_types = enforce;
self
}
#[must_use]
pub fn linear_type_enforcement(&self) -> bool {
self.enforce_linear_types
}
#[must_use]
pub fn with_shape_predicate_enforcement(mut self, enforce: bool) -> Self {
self.enforce_shape_predicates = enforce;
self
}
#[must_use]
pub fn shape_predicate_enforcement(&self) -> bool {
self.enforce_shape_predicates
}
}
impl Default for PassScheduler {
    /// Builds the scheduler via [`PassScheduler::try_default`].
    ///
    /// If that fails, the error is logged through `tracing::error!` and a
    /// scheduler built with `with_passes(Vec::new())` is returned instead, so
    /// this never panics on invalid pass metadata.
    fn default() -> Self {
        Self::try_default().unwrap_or_else(|error| {
            tracing::error!(
                error = %error,
                "Fix: built-in optimizer pass metadata is invalid; defaulting to an empty scheduler."
            );
            Self::with_passes(Vec::new())
        })
    }
}
mod topo;
mod queries;
mod run;
pub(crate) use topo::schedule_pass_metadata_indices;
pub use topo::{schedule_passes, PassSchedulingError};
#[cfg(test)]
mod research_trace_contract_tests {
    use super::{PassResearchTrace, PassResearchTraceError, PASS_RESEARCH_TRACE_SCHEMA_VERSION};

    /// Blanking any one identifier is rejected with that identifier's own
    /// error variant.
    #[test]
    fn pass_research_trace_requires_all_identifiers() {
        const KEY: &str = "research/mlir-pass-replay";
        const BASELINE: &str = "baseline/vyre-default";
        const ARTIFACT: &str = "artifact/optimizer-proof";

        let cases = [
            (
                ("", BASELINE, ARTIFACT),
                PassResearchTraceError::MissingResearchBasisKey,
            ),
            (
                (KEY, "", ARTIFACT),
                PassResearchTraceError::MissingBaselineId,
            ),
            (
                (KEY, BASELINE, ""),
                PassResearchTraceError::MissingProofArtifactId,
            ),
        ];
        for ((key, baseline, artifact), expected) in cases {
            assert_eq!(
                PassResearchTrace::try_new(key, baseline, artifact).unwrap_err(),
                expected
            );
        }
    }

    /// A fully specified trace is complete, and the run module's source still
    /// contains the line that attaches the trace to each metric.
    #[test]
    fn scheduler_metric_construction_carries_research_trace() {
        // Source-text check: this is the exact wiring line expected in run.rs.
        const METRIC_WIRING: &str = "research_trace: self.research_trace_for(metadata.name)";

        let trace = PassResearchTrace::try_new(
            "research/mlir-pass-replay",
            "baseline/vyre-default",
            "artifact/vx-336-optimizer-proof",
        )
        .unwrap();
        assert_eq!(trace.schema_version, PASS_RESEARCH_TRACE_SCHEMA_VERSION);
        assert!(trace.is_complete());

        let run_source = include_str!("run.rs");
        assert!(run_source.contains(METRIC_WIRING));
    }
}
#[cfg(test)]
mod tests;