use crate::ir::Program;
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
/// Optimizer pass that flattens `Node::Region` nodes so that the region's
/// body nodes take the region's place in the node list.
///
/// Registered through `vyre_pass` under the name `region_inline`. It declares
/// no required passes and lists `cse` and `dce` in `invalidates`, since
/// inlining rewrites the node list those passes operate on.
#[vyre_pass(name = "region_inline", requires = [], invalidates = ["cse", "dce"])]
pub struct RegionInlinePass;
impl RegionInlinePass {
    /// Decides whether this pass has any work to do on `program`.
    ///
    /// Returns `PassAnalysis::RUN` when at least one node in the program's
    /// top-level entry list is a `Node::Region`, and `PassAnalysis::SKIP`
    /// otherwise. Only the entry list itself is scanned; nodes nested inside
    /// other nodes are not visited here.
    // NOTE(review): if the IR allows a Region nested under another node (e.g. a
    // control-flow body) with no top-level Region, this returns SKIP — confirm
    // that matches what region_inline_engine expects.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let has_top_level_region = program
            .entry()
            .iter()
            .any(|node| matches!(node, crate::ir::Node::Region { .. }));

        if has_top_level_region {
            PassAnalysis::RUN
        } else {
            PassAnalysis::SKIP
        }
    }

    /// Runs the region-inlining engine (`super::region_inline_engine::run`)
    /// on `program` and returns the rewritten program.
    ///
    /// `changed` is derived by comparing `fingerprint_program` of the input
    /// with that of the engine's output. It is therefore `true` exactly when
    /// the two fingerprints differ.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // Fingerprint before handing ownership of `program` to the engine.
        let fingerprint_before = fingerprint_program(&program);
        let program = super::region_inline_engine::run(program);
        let changed = fingerprint_program(&program) != fingerprint_before;
        PassResult { changed, program }
    }

    /// Returns the fingerprint of `program` as computed by `fingerprint_program`.
    /// This is the same function `transform` uses for change detection.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Expr, Node};

    /// Builds a `Node::Region` from the `test_gen` generator around `body`.
    fn region_of(body: Vec<Node>) -> Node {
        Node::Region {
            generator: "test_gen".into(),
            source_region: None,
            body: body.into(),
        }
    }

    /// Builds a program via `Program::wrapped` with a `[1, 1, 1]` size
    /// argument and the given entry nodes.
    fn wrapped_program(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![], [1, 1, 1], entry)
    }

    #[test]
    fn region_inline_analyze_skips_without_regions() {
        let entry = vec![Node::let_bind("x", Expr::u32(1))];
        let program = Program::new_raw(vec![], [1, 1, 1], entry);
        assert_eq!(RegionInlinePass::analyze(&program), PassAnalysis::SKIP);
    }

    #[test]
    fn region_inline_analyze_runs_with_regions() {
        let region = region_of(vec![Node::let_bind("x", Expr::u32(1))]);
        let program = wrapped_program(vec![region]);
        assert_eq!(RegionInlinePass::analyze(&program), PassAnalysis::RUN);
    }

    #[test]
    fn region_inline_transform_flattens_regions() {
        let region = region_of(vec![Node::let_bind("x", Expr::u32(1))]);
        let result = RegionInlinePass::transform(wrapped_program(vec![region]));
        assert!(result.changed, "Region inline failed to detect change");

        let entry = result.program.entry();
        assert!(
            entry.iter().all(|node| !matches!(node, Node::Region { .. })),
            "Region inline should have removed all Region nodes"
        );
        assert_eq!(entry.len(), 1);
        assert!(matches!(entry[0], Node::Let { .. }));
    }
}