use super::engine;
use crate::ir::Program;
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
/// Common-subexpression elimination (CSE) optimizer pass.
///
/// Registered through the `vyre_pass` attribute under the name `"cse"`. It
/// declares a requirement on the `canonicalize` pass and declares that it
/// invalidates the `fusion` pass. The actual rewrite is delegated to
/// `engine::cse`.
#[vyre_pass(name = "cse", requires = ["canonicalize"], invalidates = ["fusion"])]
pub struct CsePass;
impl CsePass {
    /// Decides whether this pass should run on `program`.
    ///
    /// Returns [`PassAnalysis::RUN`] when the program's entry contains at
    /// least one node, and [`PassAnalysis::SKIP`] when the entry is empty
    /// (there is nothing to deduplicate).
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        if !program.entry().is_empty() {
            return PassAnalysis::RUN;
        }
        PassAnalysis::SKIP
    }

    /// Runs common-subexpression elimination over `program`.
    ///
    /// The program is handed to `engine::cse`. The returned
    /// [`PassResult::changed`] flag is `true` exactly when the fingerprint of
    /// the rewritten program differs from the fingerprint taken beforehand.
    /// NOTE(review): if `fingerprint_program` is a lossy hash, a collision
    /// would make a real rewrite look unchanged. Confirm this is acceptable.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        let fingerprint_before = Self::fingerprint(&program);
        let rewritten = engine::cse(program);
        // Compute the flag while `rewritten` is still borrowable; it is moved
        // into the result struct below.
        let changed = Self::fingerprint(&rewritten) != fingerprint_before;
        PassResult {
            changed,
            program: rewritten,
        }
    }

    /// Structural fingerprint of `program`, as computed by
    /// `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Expr, Node, Program};

    /// A program with an empty entry gives CSE nothing to do, so analysis
    /// must report SKIP.
    #[test]
    fn cse_analyze_skips_empty() {
        let nothing = Program::new_raw(vec![], [1, 1, 1], vec![]);
        let verdict = CsePass::analyze(&nothing);
        assert_eq!(verdict, PassAnalysis::SKIP);
    }

    /// Two `let` bindings of the identical `x + y` expression: CSE must report
    /// a change and rewrite the second binding into a read of the first.
    #[test]
    fn cse_transform_detects_changes() {
        let shared = Expr::add(Expr::var("x"), Expr::var("y"));
        let body = vec![
            Node::let_bind("first", shared.clone()),
            Node::let_bind("second", shared),
        ];
        let outcome = CsePass::transform(Program::new_raw(vec![], [1, 1, 1], body));
        assert!(
            outcome.changed,
            "CSE failed to detect change on redundant expressions"
        );

        let entry = outcome.program.entry();
        assert_eq!(entry.len(), 2);
        match &entry[1] {
            Node::Let { value, .. } => assert!(
                matches!(value, Expr::Var(v) if v.as_ref() == "first"),
                "CSE should have replaced the second binding with a reference to the first"
            ),
            _ => panic!("Expected Let node"),
        }
    }
}