use crate::ir::Program;
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
/// Optimizer pass that rewrites a [`Program`] into canonical form.
///
/// The rewrite itself is delegated to `canonicalize_engine::run`; see
/// `Canonicalize::transform`. The pass is skipped for programs whose stats
/// report no expression-bearing node kinds (see `analyze_impl`).
///
/// Pass metadata is declared through the `vyre_pass` attribute below:
/// name `"canonicalize"`, no required passes, invalidates `"fusion"`, phase
/// `"canonicalization"`, boundary class `"abi_preserving"`, cost-model family
/// `"scalar"`.
///
/// NOTE(review): the exact semantics of each `vyre_pass` key are defined by
/// the macro. For example, `invalidates` presumably marks prior `fusion`
/// results as stale after this pass changes the program. Confirm against the
/// macro's documentation.
#[vyre_pass(
name = "canonicalize",
requires = [],
invalidates = ["fusion"],
phase = "canonicalization",
boundary_class = "abi_preserving",
cost_model_family = "scalar"
)]
pub struct Canonicalize;
impl Canonicalize {
    /// Decides whether canonicalization is worth running on `program`.
    ///
    /// Returns [`PassAnalysis::RUN`] when the program's stats contain at least
    /// one node kind covered by `NODE_KIND_EXPRESSION_BEARING_MASK`, and
    /// [`PassAnalysis::SKIP`] otherwise.
    // NOTE(review): presumably dispatched from the `ProgramPass::analyze`
    // implementation generated by `vyre_pass`; keep this name and signature.
    fn analyze_impl(program: &Program) -> PassAnalysis {
        let has_expression_nodes = program
            .stats()
            .has_any_node_kind(crate::ir::stats::NODE_KIND_EXPRESSION_BEARING_MASK);

        if has_expression_nodes {
            PassAnalysis::RUN
        } else {
            PassAnalysis::SKIP
        }
    }

    /// Runs the canonicalization engine on `program` and reports whether it
    /// was modified.
    ///
    /// `changed` is `true` exactly when the program fingerprint taken before
    /// the rewrite differs from the one taken afterwards. Change detection is
    /// therefore only as precise as `fingerprint_program`.
    pub fn transform(program: Program) -> PassResult {
        // The engine consumes the program, so fingerprint the input first.
        let fingerprint_in = fingerprint_program(&program);

        let program = super::canonicalize_engine::run(program);
        let changed = fingerprint_in != fingerprint_program(&program);

        PassResult { program, changed }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferDecl, DataType, Expr, Node, Program};

    /// Builds a 64-wide program that stores `value` into `out[gid_x]`.
    ///
    /// The program declares a 64-element `u32` input `a` (binding 0) and a
    /// 64-element `u32` output `out` (binding 1) with a 256-byte output range.
    fn single_store_program(value: Expr) -> Program {
        let buffers = vec![
            BufferDecl::read("a", 0, DataType::U32).with_count(64),
            BufferDecl::output("out", 1, DataType::U32)
                .with_count(64)
                .with_output_byte_range(0..256),
        ];
        let body = vec![Node::store("out", Expr::gid_x(), value)];
        Program::wrapped(buffers, [64, 1, 1], body)
    }

    #[test]
    fn analyze_skips_program_with_no_expression_bearing_nodes() {
        let bare = Program::wrapped(Vec::new(), [1, 1, 1], vec![Node::Return]);
        let verdict = crate::optimizer::ProgramPass::analyze(&Canonicalize, &bare);
        match verdict {
            PassAnalysis::SKIP => {}
            other => panic!("expected SKIP for expression-free program, got {other:?}"),
        }
    }

    #[test]
    fn canonicalize_pass_runs_idempotently() {
        // Literal on the left of a commutative add: a candidate for reordering.
        let input = single_store_program(Expr::add(
            Expr::u32(1),
            Expr::load("a", Expr::gid_x()),
        ));

        let once = Canonicalize::transform(input);
        let twice = Canonicalize::transform(once.program);

        assert!(!twice.changed, "Fix: canonicalize must be idempotent");
    }

    #[test]
    fn canonicalize_pass_skips_already_canonical() {
        let input = single_store_program(Expr::add(
            Expr::load("a", Expr::gid_x()),
            Expr::u32(1),
        ));

        let outcome = Canonicalize::transform(input);

        assert!(
            !outcome.changed,
            "Fix: canonical-form input must not flip the changed flag"
        );
    }
}