use crate::ir_inner::model::program::Program;
use crate::optimizer::{
registered_passes_for_profile, CostModelFamily, OptimizerError, OptimizerProfile, PassPhase,
PassScheduler, ProgramPassKind,
};
use std::sync::OnceLock;
/// Pass phases scheduled by the first scheduler run inside [`optimize`] ("phase 2").
const PHASE2_SELECTION: &[PassPhase] = &[
    PassPhase::ScalarAlgebra,
    PassPhase::Loop,
    PassPhase::Sync,
];

/// Pass phases scheduled by the second scheduler run inside [`optimize`] ("phase 4"),
/// which happens after the intermediate CSE/DCE cleanup.
const PHASE4_SELECTION: &[PassPhase] = &[
    PassPhase::ScalarAlgebra,
    PassPhase::Canonicalization,
    PassPhase::Cleanup,
    PassPhase::FusionCse,
    PassPhase::Memory,
];

/// Lazily built scheduler for [`PHASE2_SELECTION`].
///
/// The whole `Result` is stored, so a construction failure is cached as well and is not
/// retried on later calls.
static PHASE2_SCHEDULER: OnceLock<Result<PassScheduler, OptimizerError>> = OnceLock::new();

/// Lazily built scheduler for [`PHASE4_SELECTION`]; construction failures are cached the
/// same way as for [`PHASE2_SCHEDULER`].
static PHASE4_SCHEDULER: OnceLock<Result<PassScheduler, OptimizerError>> = OnceLock::new();
/// Builds a Release-profile pass scheduler restricted to the given `phases`.
///
/// A registered Release pass is kept only when all of the following hold:
/// - its metadata `phase` is listed in `phases`;
/// - its `cost_model_family` is neither `Megakernel` nor `Dataflow`;
/// - its `invalidates` list does not contain `"buffer_layout"`.
///
/// The returned scheduler has cost-monotone, effect-handler, linear-type and
/// shape-predicate enforcement all switched on.
///
/// # Errors
///
/// Propagates the error from `registered_passes_for_profile` or from
/// `PassScheduler::try_with_passes`.
fn pre_lowering_scheduler(phases: &'static [PassPhase]) -> Result<PassScheduler, OptimizerError> {
    let release_passes = registered_passes_for_profile(OptimizerProfile::Release)?;

    let mut selected: Vec<ProgramPassKind> = Vec::new();
    for candidate in release_passes {
        // Evaluate the predicate in its own scope so the metadata borrow ends before the push.
        let keep = {
            let meta = candidate.metadata();
            let in_requested_phase = phases.contains(&meta.phase);
            let family_allowed = meta.cost_model_family != CostModelFamily::Megakernel
                && meta.cost_model_family != CostModelFamily::Dataflow;
            let preserves_buffer_layout = !meta.invalidates.contains(&"buffer_layout");
            in_requested_phase && family_allowed && preserves_buffer_layout
        };
        if keep {
            selected.push(candidate);
        }
    }

    let scheduler = PassScheduler::try_with_passes(selected)?;
    let scheduler = scheduler
        .with_cost_monotone_enforcement(true)
        .with_effect_handler_enforcement(true)
        .with_linear_type_enforcement(true)
        .with_shape_predicate_enforcement(true);
    Ok(scheduler)
}
#[must_use]
#[inline]
pub fn optimize(program: Program) -> Program {
use crate::optimizer::passes::algebraic::canonicalize_engine;
use crate::optimizer::passes::algebraic::const_fold::ConstFold;
use crate::optimizer::passes::cleanup::region_inline_engine;
use crate::optimizer::passes::cleanup::rematerialize_cheap_let::RematerializeCheapLetPass;
let prepared =
region_inline_engine::run(canonicalize_engine::run(program)).reconcile_runnable_top_level();
let phase2_output = {
let phase2_scheduler =
PHASE2_SCHEDULER.get_or_init(|| pre_lowering_scheduler(PHASE2_SELECTION));
let phase2_input = prepared;
match phase2_scheduler {
Ok(phase2_scheduler) => match phase2_scheduler.run(phase2_input.clone()) {
Ok(output) => output,
Err(error) => {
tracing::error!(
error = %error,
"pre-lowering phase 2 did not converge. Fix: inspect the pass set for oscillating rewrites."
);
phase2_input
}
},
Err(error) => {
tracing::error!(
error = %error,
"pre-lowering phase 2 scheduler construction failed. Fix: repair optimizer pass metadata."
);
phase2_input
}
}
};
let cleaned = canonicalize_engine::run(region_inline_engine::run(
crate::optimizer::passes::fusion_cse::dce::engine::dce(
crate::optimizer::passes::fusion_cse::cse::engine::cse(phase2_output),
),
));
let phase4 = {
let scheduler = PHASE4_SCHEDULER.get_or_init(|| pre_lowering_scheduler(PHASE4_SELECTION));
let phase4_input = cleaned;
match scheduler {
Ok(scheduler) => match scheduler.run(phase4_input.clone()) {
Ok(output) => output,
Err(error) => {
tracing::error!(
error = %error,
"pre-lowering phase 4 did not converge after 50 iterations. Fix: inspect the phase for oscillating rewrites or raise the cap only with a convergence certificate."
);
phase4_input
}
},
Err(error) => {
tracing::error!(
error = %error,
"pre-lowering phase 4 scheduler construction failed. Fix: repair optimizer pass metadata."
);
phase4_input
}
}
};
let rematerialized = RematerializeCheapLetPass::transform(phase4).program;
let folded = ConstFold::transform(canonicalize_engine::run(rematerialized)).program;
let cleaned = canonicalize_engine::run(region_inline_engine::run(
crate::optimizer::passes::fusion_cse::dce::engine::dce(
crate::optimizer::passes::fusion_cse::cse::engine::cse(folded),
),
));
let refolded = ConstFold::transform(cleaned).program;
let stabilized = canonicalize_engine::run(region_inline_engine::run(
crate::optimizer::passes::fusion_cse::dce::engine::dce(
crate::optimizer::passes::fusion_cse::cse::engine::cse(refolded),
),
));
stabilized.reconcile_runnable_top_level()
}
/// Regression tests for the pre-lowering pipeline: top-level region wrapping, Release-profile
/// pass classification, scheduler enforcement flags, and preservation of `Var` let snapshots.
#[cfg(test)]
mod tests {
    use super::{optimize, pre_lowering_scheduler, PHASE2_SELECTION, PHASE4_SELECTION};
    use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
    use crate::optimizer::{registered_passes_for_profile, OptimizerProfile};

    /// A program built with `Program::wrapped` must still be top-level region-wrapped after
    /// `optimize`, which runs `region_inline_engine` several times.
    #[test]
    fn optimize_preserves_top_level_region_wrap_after_inline() {
        let program = Program::wrapped(
            vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)],
            [1, 1, 1],
            vec![Node::store("out", Expr::u32(0), Expr::u32(7))],
        );
        // Precondition: the fixture itself starts out wrapped.
        assert!(program.is_top_level_region_wrapped());
        let optimized = optimize(program);
        assert!(
            optimized.is_top_level_region_wrapped(),
            "Fix: optimize() must preserve top-level region-wrap invariant after region_inline"
        );
    }

    /// Every listed pass name must be registered in the Release profile.
    ///
    /// This checks only the profile registry. It does not check that the
    /// `pre_lowering_scheduler` filter keeps these passes.
    #[test]
    fn pre_lowering_release_profile_exposes_hot_abi_preserving_passes() {
        let names = registered_passes_for_profile(OptimizerProfile::Release)
            .expect("Fix: release optimizer profile must schedule classified passes")
            .into_iter()
            .map(|pass| pass.metadata().name)
            .collect::<std::collections::BTreeSet<_>>();
        for required in [
            "dead_store_elim",
            "read_only_load_hoist",
            "store_to_load_forward",
            "loop_licm",
            "loop_software_pipeline",
            "branch_value_hoist",
            "rematerialize_cheap_let",
        ] {
            assert!(
                names.contains(required),
                "Fix: concrete optimization pass `{required}` exists but is not classified into the Release profile"
            );
        }
    }

    /// Both phase schedulers must build, and each must have all four enforcement flags enabled
    /// (cost-monotone, effect-handler, linear-type, shape-predicate).
    #[test]
    fn pre_lowering_schedulers_enforce_cost_monotone_contract() {
        for phases in [PHASE2_SELECTION, PHASE4_SELECTION] {
            let scheduler = pre_lowering_scheduler(phases)
                .expect("Fix: pre-lowering scheduler must build for release phases");
            assert!(
                scheduler.cost_monotone_enforcement(),
                "Fix: backend-called pre_lowering::optimize must not land cost-up rewrites silently"
            );
            assert!(
                scheduler.effect_handler_enforcement(),
                "Fix: backend-called pre_lowering::optimize must not introduce new effects silently"
            );
            assert!(
                scheduler.linear_type_enforcement(),
                "Fix: backend-called pre_lowering::optimize must not introduce linear-type violations silently"
            );
            assert!(
                scheduler.shape_predicate_enforcement(),
                "Fix: backend-called pre_lowering::optimize must not introduce shape-predicate violations silently"
            );
        }
    }

    /// The `let tmp = s0` snapshot in the `op == 4` branch must survive optimization.
    ///
    /// That branch swaps `s0` and `s1` through `tmp`. `s0` is reassigned between the snapshot
    /// and `s1 = tmp`, so forwarding `s0` in place of `tmp` would read the new value of `s0`
    /// and break the swap.
    #[test]
    fn optimize_preserves_var_snapshot_before_source_reassign_in_loop_branch() {
        // Recursively searches `If`/`Loop`/`Block`/`Region` bodies for a `Let` binding named
        // `tmp` whose value is exactly `Var("s0")`.
        fn contains_tmp_snapshot(nodes: &[Node]) -> bool {
            nodes.iter().any(|node| match node {
                Node::Let {
                    name,
                    value: Expr::Var(source),
                } => name.as_str() == "tmp" && source.as_str() == "s0",
                Node::If {
                    then, otherwise, ..
                } => contains_tmp_snapshot(then) || contains_tmp_snapshot(otherwise),
                Node::Loop { body, .. } | Node::Block(body) => contains_tmp_snapshot(body),
                Node::Region { body, .. } => contains_tmp_snapshot(body),
                _ => false,
            })
        }
        // A small interpreter-style fixture: two registers (`s0`, `s1`), a loop over `pc`
        // (presumably a single iteration for `0..1`; verify the loop bound semantics), and
        // one `if` per opcode value. `op` is bound to the literal 4, so only the swap branch
        // is semantically live.
        let program = Program::wrapped(
            vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)],
            [1, 1, 1],
            vec![
                Node::let_bind("s0", Expr::u32(1)),
                Node::let_bind("s1", Expr::u32(2)),
                Node::Loop {
                    var: "pc".into(),
                    from: Expr::u32(0),
                    to: Expr::u32(1),
                    body: vec![
                        Node::let_bind("op", Expr::LitU32(4)),
                        Node::if_then(
                            Expr::eq(Expr::var("op"), Expr::u32(0)),
                            vec![
                                Node::assign("s1", Expr::var("s0")),
                                Node::assign("s0", Expr::u32(192)),
                            ],
                        ),
                        Node::if_then(
                            Expr::eq(Expr::var("op"), Expr::u32(1)),
                            vec![
                                Node::assign("s0", Expr::add(Expr::var("s0"), Expr::var("s1"))),
                                Node::assign("s1", Expr::u32(0)),
                            ],
                        ),
                        Node::if_then(
                            Expr::eq(Expr::var("op"), Expr::u32(2)),
                            vec![
                                Node::assign("s0", Expr::mul(Expr::var("s0"), Expr::var("s1"))),
                                Node::assign("s1", Expr::u32(0)),
                            ],
                        ),
                        Node::if_then(
                            Expr::eq(Expr::var("op"), Expr::u32(3)),
                            vec![Node::assign("s1", Expr::var("s0"))],
                        ),
                        // Swap branch: snapshot `s0`, overwrite it, then read the snapshot.
                        Node::if_then(
                            Expr::eq(Expr::var("op"), Expr::u32(4)),
                            vec![
                                Node::let_bind("tmp", Expr::var("s0")),
                                Node::assign("s0", Expr::var("s1")),
                                Node::assign("s1", Expr::var("tmp")),
                            ],
                        ),
                    ],
                },
                // `s1` is stored so the loop's result is observable and not dead code.
                Node::store("out", Expr::u32(0), Expr::var("s1")),
            ],
        );
        let optimized = optimize(program);
        assert!(
            contains_tmp_snapshot(optimized.entry()),
            "Fix: pre-lowering optimize must preserve Var Let snapshot boundaries when the source is reassigned later in the same control-flow scope"
        );
    }
}