vyre-foundation 0.4.1

Foundation layer: IR, type system, memory model, wire format. Zero application semantics. Part of the vyre GPU compiler.
Documentation
//! Common-subexpression elimination — registered ProgramPass.
//!
//! The engine itself lives at `super::engine`; this module hooks it
//! into the scheduler's fixpoint loop and invalidation tracking.

use super::engine;
use crate::ir::Program;
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};

#[vyre_pass(name = "cse", requires = ["canonicalize"], invalidates = ["fusion"])]
/// Built-in common-subexpression-elimination (CSE) pass.
///
/// Registered through `vyre_pass` under the name `"cse"`. The attribute
/// lists `canonicalize` in `requires` and `fusion` in `invalidates`.
/// NOTE(review): the exact scheduling meaning of `requires` / `invalidates`
/// is defined by the `vyre_pass` macro and scheduler — confirm there.
///
/// The actual rewriting is delegated to `super::engine::cse`; this type only
/// supplies the analyze / transform / fingerprint hooks.
pub struct CsePass;

impl CsePass {
    /// Decide whether CSE should run on `program`.
    ///
    /// Returns [`PassAnalysis::SKIP`] when the program's entry contains no
    /// nodes at all, and [`PassAnalysis::RUN`] otherwise. This is only a
    /// cheap emptiness check on the entry node list. It does not look inside
    /// the nodes to see whether any expressions are actually present.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        if program.entry().is_empty() {
            PassAnalysis::SKIP
        } else {
            PassAnalysis::RUN
        }
    }

    /// Run CSE over the program entry and report whether anything changed.
    ///
    /// The program is handed to the CSE engine by value. `changed` is set
    /// when [`fingerprint_program`] of the output differs from that of the
    /// input.
    /// NOTE(review): change detection relies on a 64-bit fingerprint, so a
    /// fingerprint collision would make a real rewrite report `changed = false`.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // Fingerprint before handing ownership of `program` to the engine.
        let before = fingerprint_program(&program);
        let optimized = engine::cse(program);
        let after = fingerprint_program(&optimized);
        PassResult {
            changed: after != before,
            program: optimized,
        }
    }

    /// Fingerprint this pass's visible input.
    ///
    /// CSE reads the whole program, so this is simply
    /// [`fingerprint_program`] of `program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Expr, Node, Program};

    /// Build a program that binds the identical `x + y` expression twice,
    /// giving CSE exactly one redundancy to eliminate.
    fn redundant_program() -> Program {
        let heavy_expr = Expr::add(Expr::var("x"), Expr::var("y"));
        let first = Node::let_bind("first", heavy_expr.clone());
        let second = Node::let_bind("second", heavy_expr);
        Program::new_raw(vec![], [1, 1, 1], vec![first, second])
    }

    #[test]
    fn cse_analyze_skips_empty() {
        let empty = Program::new_raw(vec![], [1, 1, 1], vec![]);
        assert_eq!(CsePass::analyze(&empty), PassAnalysis::SKIP);
    }

    #[test]
    fn cse_analyze_runs_on_non_empty() {
        assert_eq!(CsePass::analyze(&redundant_program()), PassAnalysis::RUN);
    }

    #[test]
    fn cse_transform_detects_changes() {
        let result = CsePass::transform(redundant_program());

        assert!(
            result.changed,
            "CSE failed to detect change on redundant expressions"
        );

        let entry = result.program.entry();
        assert_eq!(entry.len(), 2);
        let Node::Let { value, .. } = &entry[1] else {
            panic!("Expected Let node");
        };
        assert!(
            matches!(value, Expr::Var(v) if v.as_ref() == "first"),
            "CSE should have replaced the second binding with a reference to the first"
        );
    }

    #[test]
    fn cse_transform_reaches_fixpoint() {
        // The scheduler re-runs passes until none report a change, so running
        // CSE again on already-deduplicated IR must be reported as a no-op.
        let once = CsePass::transform(redundant_program());
        let twice = CsePass::transform(once.program);
        assert!(
            !twice.changed,
            "CSE reported a change on IR it had already deduplicated"
        );
    }
}