use super::*;
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::optimizer::passes::const_fold::ConstFold;
use crate::optimizer::passes::fusion::Fusion;
use crate::optimizer::passes::normalize_atomics::NormalizeAtomicsPass;
use crate::optimizer::passes::strength_reduce::StrengthReduce;
use crate::optimizer::{PassAnalysis, PassMetadata, PassResult, ProgramPass};
use std::sync::Arc;
/// Builds a minimal program used as the baseline input throughout these tests.
///
/// It declares one read-write `u32` buffer named `out` with a single element,
/// dispatches `[1, 1, 1]`, and its entry stores the constant `42` into `out[0]`.
fn trivial_program() -> Program {
    let out_buffer = BufferDecl::read_write("out", 0, DataType::U32).with_count(1);
    let store_42 = Node::store("out", Expr::u32(0), Expr::u32(42));
    Program::wrapped(vec![out_buffer], [1, 1, 1], vec![store_42])
}
/// A scheduler holding only `ConstFold` must converge on the trivial program.
#[test]
fn single_pass_converges() {
    let scheduler = PassScheduler::with_passes(vec![ProgramPassKind::new(ConstFold)]);
    // `expect` prints the `OptimizerError` on failure. The previous
    // `assert!(result.is_ok())` only reported a bare `false`.
    let _optimized = scheduler
        .run(trivial_program())
        .expect("Fix: a single const_fold pass must converge on the trivial program");
}
/// `run_with_metrics` must emit one metric for the single `const_fold` pass.
/// That metric must carry non-zero node counts and IR heap estimates, and its
/// after-count must match the program that is returned.
#[test]
fn run_with_metrics_reports_pass_runtime_and_ir_size() {
    let passes = vec![ProgramPassKind::new(ConstFold)];
    let scheduler = PassScheduler::with_passes(passes);
    let report = scheduler
        .run_with_metrics(trivial_program())
        .expect("Fix: metrics run should converge");

    assert_eq!(report.passes.len(), 1);
    let const_fold = &report.passes[0];
    assert_eq!(const_fold.pass, "const_fold");
    assert!(
        const_fold.ran,
        "const_fold should run on the first dirty iteration"
    );

    // Both node counts must be populated for a non-empty program.
    assert!(const_fold.nodes_before > 0);
    assert!(const_fold.nodes_after > 0);

    // The heap-pressure estimates are part of the metric contract.
    assert!(
        const_fold.ir_heap_allocations_before > 0,
        "metrics must include IR heap allocation pressure"
    );
    assert!(
        const_fold.ir_heap_bytes_before > 0,
        "metrics must include estimated IR heap bytes"
    );

    let returned_nodes = report.program.stats().node_count;
    assert_eq!(
        returned_nodes, const_fold.nodes_after,
        "metric after-count must describe the returned program"
    );
}
/// With an iteration budget of zero, the scheduler must report
/// `OptimizerError::MaxIterations`, even though the program contains foldable
/// work (`1 + 2`).
#[test]
fn max_iterations_caps_execution() {
    let foldable = Expr::add(Expr::u32(1), Expr::u32(2));
    let program = Program::wrapped(
        vec![BufferDecl::read_write("out", 0, DataType::U32).with_count(1)],
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), foldable)],
    );
    let scheduler =
        PassScheduler::with_passes(vec![ProgramPassKind::new(ConstFold)]).with_max_iterations(0);

    let outcome = scheduler.run(program);
    let hit_cap = matches!(outcome, Err(OptimizerError::MaxIterations { .. }));
    assert!(
        hit_cap,
        "zero iterations should immediately hit max: {:?}",
        outcome
    );
}
/// `3 + 4` is folded by `ConstFold`. With a budget of two iterations the
/// scheduler should finish successfully: one iteration for the fold and one
/// to observe that nothing else changes.
#[test]
fn idempotent_pass_converges_in_two_iterations() {
    let program = Program::wrapped(
        vec![BufferDecl::read_write("out", 0, DataType::U32).with_count(1)],
        [1, 1, 1],
        vec![Node::store(
            "out",
            Expr::u32(0),
            Expr::add(Expr::u32(3), Expr::u32(4)),
        )],
    );
    let scheduler =
        PassScheduler::with_passes(vec![ProgramPassKind::new(ConstFold)]).with_max_iterations(2);
    // `expect` reports the actual `OptimizerError` (for example `MaxIterations`)
    // on failure. `assert!(result.is_ok())` would have discarded it.
    let _optimized = scheduler
        .run(program)
        .expect("Fix: const_fold should converge within 2 iterations");
}
/// A scheduler holding `ConstFold` followed by `StrengthReduce` must converge
/// on the trivial program.
#[test]
fn multiple_passes_execute() {
    let scheduler = PassScheduler::with_passes(vec![
        ProgramPassKind::new(ConstFold),
        ProgramPassKind::new(StrengthReduce),
    ]);
    // Surface the `OptimizerError` on failure instead of a bare `is_ok()` false.
    let _optimized = scheduler
        .run(trivial_program())
        .expect("Fix: const_fold + strength_reduce must converge on the trivial program");
}
/// `with_max_iterations` must store the requested budget as given.
#[test]
fn with_max_iterations_is_configurable() {
    let scheduler = PassScheduler::with_passes(vec![ProgramPassKind::new(ConstFold)]);
    let scheduler = scheduler.with_max_iterations(100);
    assert_eq!(scheduler.max_iterations, 100);
}
/// `PassScheduler::default()` must be populated with the built-in pass set
/// (at least nine passes).
#[test]
fn default_scheduler_uses_registered_passes() {
    let registered = PassScheduler::default().passes.len();
    assert!(
        registered >= 9,
        "must include at least 9 built-in passes, got {}",
        registered
    );
}
/// Asking for the dependents of a pass name that is not registered yields nothing.
#[test]
fn transitive_dependents_unknown_pass_returns_empty() {
    let scheduler = PassScheduler::with_passes(vec![ProgramPassKind::new(ConstFold)]);
    let dependents = scheduler.transitive_dependents("nonexistent");
    assert!(dependents.is_empty());
}
/// `reaches` must return `false` whenever either endpoint is an unknown pass name.
#[test]
fn reaches_unknown_pass_returns_false() {
    let scheduler = PassScheduler::with_passes(vec![ProgramPassKind::new(ConstFold)]);
    // Try the unknown name first as the source, then as the target.
    for (from, to) in [("nonexistent", "const_fold"), ("const_fold", "nonexistent")] {
        assert!(!scheduler.reaches(from, to));
    }
}
/// A pass is always considered to commute with itself.
#[test]
fn pair_commutes_same_pass_is_true() {
    let name = "const_fold";
    let scheduler = PassScheduler::with_passes(vec![ProgramPassKind::new(ConstFold)]);
    assert!(scheduler.pair_commutes(name, name));
}
/// `mark_invalidated_passes` must mark the named pass dirty. Per this test's
/// expectation it must also mark passes that depend on it (here,
/// `strength_reduce` when `const_fold` is invalidated).
#[test]
fn invalidation_marks_named_pass_and_requirement_dependents_dirty() {
    let scheduler = PassScheduler::with_passes(vec![
        ProgramPassKind::new(ConstFold),
        ProgramPassKind::new(StrengthReduce),
        ProgramPassKind::new(NormalizeAtomicsPass),
        ProgramPassKind::new(Fusion),
    ]);
    let mut dirty = FxHashSet::default();

    // Case 1: invalidating a pass by name schedules that same pass.
    scheduler.mark_invalidated_passes(&["fusion"], &mut dirty);
    let fusion_rerun = dirty.contains("fusion");
    assert!(fusion_rerun, "pass-name invalidation must rerun that pass");

    // Case 2: start from an empty set so only const_fold's fallout is observed.
    dirty.clear();
    scheduler.mark_invalidated_passes(&["const_fold"], &mut dirty);
    assert!(dirty.contains("const_fold"));
    let dependent_rerun = dirty.contains("strength_reduce");
    assert!(
        dependent_rerun,
        "passes requiring an invalidated pass/capability must rerun"
    );
}
/// Configurable stub pass for driving the scheduler in tests.
///
/// `analyze` always requests a run. When `changes` is `true`, `transform`
/// appends a `Node::barrier()` to the entry and reports a change, which grows
/// the program. Otherwise it hands the program back unchanged.
#[derive(Debug)]
struct TestPass {
    // Name / requires / invalidates reported to the scheduler.
    metadata: PassMetadata,
    // Whether `transform` rewrites the program.
    changes: bool,
}
impl crate::optimizer::private::Sealed for TestPass {}
impl ProgramPass for TestPass {
    fn metadata(&self) -> PassMetadata {
        self.metadata
    }
    fn analyze(&self, _program: &Program) -> PassAnalysis {
        // Unconditionally ask to be run.
        PassAnalysis::RUN
    }
    fn transform(&self, program: Program) -> PassResult {
        if !self.changes {
            return PassResult::unchanged(program);
        }
        // `into_entry_vec` consumes its receiver, so work on a clone and keep
        // `program` for `with_rewritten_entry`.
        let mut nodes = program.clone().into_entry_vec();
        nodes.push(Node::barrier());
        PassResult {
            program: program.with_rewritten_entry(nodes),
            changed: true,
        }
    }
    fn fingerprint(&self, _program: &Program) -> u64 {
        // Constant fingerprint: this stub does not distinguish programs.
        0
    }
}
/// Stub pass that only rewrites an expression. It replaces the stored value of
/// the first `Node::Store` it finds (see `rewrite_first_store_value`) and leaves
/// the node structure alone.
///
/// NOTE(review): it reports `changed: true` whenever any store exists, even if
/// that value is already `Expr::u32(43)`.
#[derive(Debug)]
struct ExprOnlyPass {
    // Name / requires / invalidates reported to the scheduler.
    metadata: PassMetadata,
}
impl crate::optimizer::private::Sealed for ExprOnlyPass {}
impl ProgramPass for ExprOnlyPass {
    fn metadata(&self) -> PassMetadata {
        self.metadata
    }
    fn analyze(&self, _program: &Program) -> PassAnalysis {
        // Unconditionally ask to be run.
        PassAnalysis::RUN
    }
    fn transform(&self, program: Program) -> PassResult {
        let mut nodes = program.clone().into_entry_vec();
        if !rewrite_first_store_value(&mut nodes) {
            // No store anywhere: discard the scratch copy.
            return PassResult::unchanged(program);
        }
        PassResult {
            program: program.with_rewritten_entry(nodes),
            changed: true,
        }
    }
    fn fingerprint(&self, _program: &Program) -> u64 {
        // Constant fingerprint: this stub does not distinguish programs.
        0
    }
}
/// Stub pass named `skip_pass` whose `analyze` always returns
/// `PassAnalysis::SKIP`. Its `transform` is a no-op that returns the program
/// unchanged.
#[derive(Debug)]
struct SkipPass;
impl crate::optimizer::private::Sealed for SkipPass {}
impl ProgramPass for SkipPass {
    fn metadata(&self) -> PassMetadata {
        // No requirements and no invalidations.
        const SKIP_METADATA: PassMetadata = PassMetadata {
            name: "skip_pass",
            requires: &[],
            invalidates: &[],
        };
        SKIP_METADATA
    }
    fn analyze(&self, _program: &Program) -> PassAnalysis {
        PassAnalysis::SKIP
    }
    fn transform(&self, program: Program) -> PassResult {
        PassResult::unchanged(program)
    }
    fn fingerprint(&self, _program: &Program) -> u64 {
        // Constant fingerprint: this stub does not distinguish programs.
        0
    }
}
/// Depth-first search for the first `Node::Store` in `nodes`, recursing into
/// `If` (then, then otherwise), `Loop`, `Block` and `Region` bodies. That
/// store's value is overwritten with `Expr::u32(43)`.
///
/// Returns `true` as soon as a store is rewritten and `false` if none exists.
/// `Arc::make_mut` on a `Region` body clones it when the `Arc` is shared, even
/// if no store is ultimately found there.
fn rewrite_first_store_value(nodes: &mut [Node]) -> bool {
    // `any` short-circuits, so nothing after the first rewritten store is visited.
    nodes.iter_mut().any(|node| match node {
        Node::Store { value, .. } => {
            *value = Expr::u32(43);
            true
        }
        Node::If {
            then, otherwise, ..
        } => rewrite_first_store_value(then) || rewrite_first_store_value(otherwise),
        Node::Loop { body, .. } | Node::Block(body) => rewrite_first_store_value(body),
        Node::Region { body, .. } => rewrite_first_store_value(Arc::make_mut(body)),
        _ => false,
    })
}
/// A later pass (`rewrite`) invalidates an earlier one (`prepare`) that a
/// subsequent pass (`consume`) requires. `run_once` must still succeed and
/// queue both `prepare` and `consume` for the next iteration, rather than
/// treating `consume` as unschedulable.
#[test]
fn invalidating_prior_requirement_does_not_break_current_iteration() {
    let prepare = PassMetadata {
        name: "prepare",
        requires: &[],
        invalidates: &[],
    };
    let rewrite = PassMetadata {
        name: "rewrite",
        requires: &[],
        invalidates: &["prepare"],
    };
    let consume = PassMetadata {
        name: "consume",
        requires: &["prepare"],
        invalidates: &[],
    };
    let scheduler = PassScheduler::with_passes(vec![
        ProgramPassKind::new(TestPass {
            metadata: prepare,
            changes: false,
        }),
        ProgramPassKind::new(TestPass {
            metadata: rewrite,
            changes: true,
        }),
        ProgramPassKind::new(TestPass {
            metadata: consume,
            changes: false,
        }),
    ]);

    // Start with every registered pass dirty.
    let dirty = scheduler
        .passes
        .iter()
        .map(|registered| registered.metadata().name)
        .collect();

    let (_program, changed, changed_by, next_dirty) = scheduler
        .run_once(trivial_program(), &dirty)
        .expect("Fix: invalidating a prior requirement must queue a rerun, not make later passes unschedulable");

    assert!(changed);
    assert_eq!(changed_by, Some("rewrite"));
    for queued in ["prepare", "consume"] {
        assert!(next_dirty.contains(queued));
    }
}
/// An expression-only rewrite keeps `node_count` the same but must still be
/// recorded as a change in the first metric. The second metric must show that
/// nothing changed.
#[test]
fn run_with_metrics_tracks_expression_only_rewrites() {
    let expr_only = ExprOnlyPass {
        metadata: PassMetadata {
            name: "expr_only",
            requires: &[],
            invalidates: &["value_numbering"],
        },
    };
    let scheduler = PassScheduler::with_passes(vec![ProgramPassKind::new(expr_only)]);
    let report = scheduler
        .run_with_metrics(trivial_program())
        .expect("Fix: metrics run must converge for expression-only rewrites");

    assert_eq!(report.passes.len(), 2);
    let (first, second) = (&report.passes[0], &report.passes[1]);

    assert_eq!(first.pass, "expr_only");
    assert!(
        first.changed,
        "expression-only rewrites keep node_count stable but still changed the program and must invalidate downstream facts"
    );
    assert_eq!(
        first.nodes_before, first.nodes_after,
        "the regression target is a same-node-count expression rewrite"
    );
    assert!(
        !second.changed,
        "the second iteration must observe convergence after the expression rewrite landed"
    );
}
/// Cost-monotone enforcement is off unless explicitly enabled. With it off, a
/// rewrite that grows the program (`TestPass` appending a barrier) must survive
/// into the returned program.
#[test]
fn cost_monotone_disabled_by_default_keeps_cost_up_rewrites() {
    let scheduler = PassScheduler::with_passes(vec![ProgramPassKind::new(TestPass {
        metadata: PassMetadata {
            name: "cost_up_default_off",
            requires: &[],
            invalidates: &[],
        },
        changes: true,
    })]);
    assert!(
        !scheduler.cost_monotone_enforcement(),
        "cost-monotone enforcement must default to OFF — flipping the default would change the \
         optimizer's observable behavior on every consumer that constructs PassScheduler::default()"
    );
    // The message above is about `PassScheduler::default()`, so check that
    // constructor directly, not only `with_passes`.
    assert!(
        !PassScheduler::default().cost_monotone_enforcement(),
        "PassScheduler::default() must leave cost-monotone enforcement OFF"
    );

    let pre = trivial_program();
    let pre_nodes = pre.stats().node_count;
    // `run` returns the optimized `Program` itself, not a metrics report.
    let optimized = scheduler.run(pre).expect("Fix: scheduler must converge");
    let post_nodes = optimized.stats().node_count;
    assert!(
        post_nodes > pre_nodes,
        "with the gate disabled, the cost-up rewrite must land — got post_nodes={} pre_nodes={}",
        post_nodes,
        pre_nodes
    );
}
/// With cost-monotone enforcement enabled, a rewrite that grows the program
/// (`TestPass` appending a barrier) must be reverted. The returned program
/// must therefore have the input's node count.
#[test]
fn cost_monotone_enabled_reverts_cost_up_rewrites() {
    let scheduler = PassScheduler::with_passes(vec![ProgramPassKind::new(TestPass {
        metadata: PassMetadata {
            name: "cost_up_with_gate",
            requires: &[],
            invalidates: &[],
        },
        changes: true,
    })])
    .with_cost_monotone_enforcement(true);
    assert!(scheduler.cost_monotone_enforcement());

    let pre = trivial_program();
    let pre_nodes = pre.stats().node_count;
    // Only the node count of `pre` is needed afterwards, so the program is
    // moved into `run` instead of cloned.
    let optimized = scheduler
        .run(pre)
        .expect("Fix: scheduler must converge even when the gate reverts a cost-up rewrite");
    let post_nodes = optimized.stats().node_count;
    assert_eq!(
        post_nodes,
        pre_nodes,
        "the gate must revert any pass that increases node_count without an explicit refusal — \
         observed post_nodes={} pre_nodes={}",
        post_nodes,
        pre_nodes
    );
}
/// With enforcement enabled, a pass that reports no change must leave the
/// program's node count untouched.
///
/// NOTE(review): despite the test name, the configured `TestPass` has
/// `changes: false`. This therefore exercises the no-op path, not an actual
/// cost-reducing rewrite.
#[test]
fn cost_monotone_enabled_keeps_monotone_down_rewrites() {
    let noop = TestPass {
        metadata: PassMetadata {
            name: "noop_with_gate",
            requires: &[],
            invalidates: &[],
        },
        changes: false,
    };
    let scheduler = PassScheduler::with_passes(vec![ProgramPassKind::new(noop)])
        .with_cost_monotone_enforcement(true);

    let input = trivial_program();
    let input_nodes = input.stats().node_count;
    let output = scheduler.run(input).expect("Fix: scheduler must converge");
    assert_eq!(
        output.stats().node_count,
        input_nodes,
        "the gate must NOT mutate Programs that the pass left unchanged"
    );
}
/// The scheduler's `pass_index` must be keyed by `&'static str`, for both the
/// built-in pass set and a stress set of 20 synthetic passes. The stress
/// scheduler must also run one iteration, contain every name, and never report
/// a pass as reaching itself.
#[test]
fn scheduler_lookup_tables_use_static_str_keys() {
    // Compile-time check: only accepts maps keyed by `&'static str`.
    fn assert_static_str_map(_: &FxHashMap<&'static str, usize>) {}

    let builtin = PassScheduler::try_default().expect("Fix: built-in passes must be valid");
    assert_static_str_map(&builtin.pass_index);

    // Leak the synthetic names so they can be used as `&'static str` pass names.
    // This is bounded (20 short strings) and test-only.
    let names: Vec<&'static str> = (0..20)
        .map(|i| -> &'static str { Box::leak(format!("stress_pass_{i}").into_boxed_str()) })
        .collect();

    let stress_passes: Vec<_> = names
        .iter()
        .map(|&name| {
            ProgramPassKind::new(TestPass {
                metadata: PassMetadata {
                    name,
                    requires: &[],
                    invalidates: &[],
                },
                changes: false,
            })
        })
        .collect();
    let scheduler20 = PassScheduler::with_passes(stress_passes);
    assert_static_str_map(&scheduler20.pass_index);

    // Run one iteration with every stress pass dirty.
    let dirty: FxHashSet<&'static str> = names.iter().copied().collect();
    let _stress_run = scheduler20
        .run_once(trivial_program(), &dirty)
        .expect("Fix: stress scheduler must run");

    for name in names.iter().copied() {
        assert!(
            !scheduler20.reaches(name, name),
            "a pass must not reach itself"
        );
        assert!(
            scheduler20.pass_index.contains_key(name),
            "pass_index must contain {name}"
        );
    }
}
/// When the cost-monotone gate reverts a cost-up rewrite, the recorded metric
/// must describe the program after the gate. The pass still counts as having
/// run, but the metric reports no change and equal before/after node counts.
#[test]
fn cost_monotone_enabled_metrics_reflect_post_revert_state() {
    let cost_up = TestPass {
        metadata: PassMetadata {
            name: "cost_up_metric_check",
            requires: &[],
            invalidates: &[],
        },
        changes: true,
    };
    let scheduler = PassScheduler::with_passes(vec![ProgramPassKind::new(cost_up)])
        .with_cost_monotone_enforcement(true);
    let report = scheduler
        .run_with_metrics(trivial_program())
        .expect("Fix: metrics run must converge");

    assert_eq!(report.passes.len(), 1);
    let only = &report.passes[0];
    assert!(
        only.ran,
        "the pass must have actually been called by the scheduler"
    );
    assert!(
        !only.changed,
        "after gate-revert, the metric's `changed` flag must reflect that no change landed; \
         got changed={}",
        only.changed
    );
    assert_eq!(
        only.nodes_after, only.nodes_before,
        "after gate-revert, nodes_before must equal nodes_after — the metric describes the \
         post-gate Program shape, not the rejected rewrite"
    );
}
/// When every pass returns `SKIP`, the scheduler must return the very same
/// entry allocation (pointer-equal `Arc`), not a rebuilt copy.
#[test]
fn scheduler_preserves_program_identity_when_pass_skips() {
    let scheduler = PassScheduler::with_passes(vec![ProgramPassKind::new(SkipPass)]);
    let program = trivial_program();
    let entry_before = Arc::clone(program.entry_arc());

    let optimized = scheduler
        .run(program)
        .expect("Fix: scheduler must converge when all passes SKIP");
    let same_allocation = Arc::ptr_eq(&entry_before, optimized.entry_arc());
    assert!(
        same_allocation,
        "scheduler must preserve entry Arc identity when a pass returns SKIP; \
         reconcile_runnable_top_level must not allocate a fresh Vec or Arc"
    );
}