use crate::ir::Program;
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
/// Optimizer pass that rewrites a [`Program`] into canonical form.
///
/// Registered through `vyre_pass` under the name `"canonicalize"`:
/// - It declares no required passes.
/// - It declares that it invalidates `"fusion"` and `"const_fold"`.
///
/// NOTE(review): "invalidates" presumably means those passes' results must be
/// recomputed after this pass runs — confirm against the `vyre_pass` macro.
#[vyre_pass(name = "canonicalize", requires = [], invalidates = ["fusion", "const_fold"])]
pub struct Canonicalize;
impl Canonicalize {
    /// Decides whether this pass should run on a program.
    ///
    /// The program is never inspected; the answer is unconditionally
    /// [`PassAnalysis::RUN`].
    pub fn analyze(_program: &Program) -> PassAnalysis {
        PassAnalysis::RUN
    }

    /// Canonicalizes `program` and reports whether the rewrite altered it.
    ///
    /// The rewrite itself is delegated to `canonicalize_engine::run`. The
    /// `changed` flag is derived by comparing the program's fingerprint
    /// before and after that call, not by comparing the programs directly.
    /// Two distinct programs that happen to share a fingerprint would
    /// therefore be reported as unchanged.
    pub fn transform(program: Program) -> PassResult {
        // Capture the input's fingerprint before `program` is moved into the engine.
        let input_fingerprint = Self::fingerprint(&program);
        let program = super::canonicalize_engine::run(program);
        let changed = Self::fingerprint(&program) != input_fingerprint;
        PassResult { program, changed }
    }

    /// Returns the `u64` fingerprint of `program`, as computed by
    /// `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferDecl, DataType, Expr, Node, Program};

    /// Builds the fixture shared by these tests.
    ///
    /// - `a` is a 64-element `U32` read buffer at binding 0.
    /// - `out` is a 64-element `U32` output buffer at binding 1, with an
    ///   output byte range of `0..256`.
    /// - The body is a single store of `value` into `out[gid_x]`, with a
    ///   `[64, 1, 1]` workgroup.
    fn store_fixture(value: Expr) -> Program {
        let buffers = vec![
            BufferDecl::read("a", 0, DataType::U32).with_count(64),
            BufferDecl::output("out", 1, DataType::U32)
                .with_count(64)
                .with_output_byte_range(0..256),
        ];
        let body = vec![Node::store("out", Expr::gid_x(), value)];
        Program::wrapped(buffers, [64, 1, 1], body)
    }

    #[test]
    fn canonicalize_pass_runs_idempotently() {
        // Literal-first operand order; whatever the first pass produces,
        // a second pass over it must report no change.
        let literal_first = Expr::add(Expr::u32(1), Expr::load("a", Expr::gid_x()));
        let once = Canonicalize::transform(store_fixture(literal_first));
        let twice = Canonicalize::transform(once.program);
        assert!(!twice.changed, "Fix: canonicalize must be idempotent");
    }

    #[test]
    fn canonicalize_pass_skips_already_canonical() {
        // Load-first operand order is treated as the canonical form here.
        let load_first = Expr::add(Expr::load("a", Expr::gid_x()), Expr::u32(1));
        let result = Canonicalize::transform(store_fixture(load_first));
        assert!(
            !result.changed,
            "Fix: canonical-form input must not flip the changed flag"
        );
    }
}