use crate::ir::{Expr, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
/// Optimizer pass that removes `If` nodes whose condition is a boolean literal.
///
/// Each such `If` is replaced by a `Node::Block` that holds only the arm the
/// literal selects: `then` for `true`, `otherwise` for `false`. Only
/// `Expr::LitBool` conditions are matched. Integer literals such as `0` or `1`
/// are left untouched.
///
/// The pass declares `const_fold` as a prerequisite. Presumably this is so that
/// conditions folded down to `LitBool` are visible here; confirm against the
/// pass scheduler. It declares no invalidated passes.
#[derive(Debug, Default)]
#[vyre_pass(
name = "if_constant_branch_eliminate",
requires = ["const_fold"],
invalidates = []
)]
pub struct IfConstantBranchEliminatePass;
impl IfConstantBranchEliminatePass {
    /// Reports whether the pass has anything to rewrite in `program`.
    ///
    /// Each entry root is searched with `node_map::any_descendant` for an `If`
    /// whose condition is an `Expr::LitBool`.
    ///
    /// Returns [`PassAnalysis::RUN`] when one is found and
    /// [`PassAnalysis::SKIP`] otherwise.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let has_constant_if = program
            .entry()
            .iter()
            .any(|root| node_map::any_descendant(root, &mut is_constant_if));
        if has_constant_if {
            PassAnalysis::RUN
        } else {
            PassAnalysis::SKIP
        }
    }

    /// Collapses every constant-condition `If` in the program's entry body.
    ///
    /// The returned `changed` flag is `true` iff at least one `If` was replaced.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // Build the output shell (entry emptied) from a borrow first, because
        // `into_entry_vec` below consumes `program`.
        let shell = program.with_rewritten_entry(Vec::new());
        let mut any_collapsed = false;
        let mut rewritten = Vec::new();
        for root in program.into_entry_vec() {
            rewritten.push(eliminate_node(root, &mut any_collapsed));
        }
        PassResult {
            program: shell.with_rewritten_entry(rewritten),
            changed: any_collapsed,
        }
    }

    /// Fingerprint used by the pass framework; delegates to the shared
    /// `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Rewrites `node` bottom-up.
///
/// Every `If` whose condition is an `Expr::LitBool` is replaced with a
/// `Node::Block` containing the arm that literal selects.
///
/// Children are rewritten first (via `node_map::map_children`). Constant `If`s
/// nested inside an arm are therefore already collapsed when their parent is
/// examined.
///
/// `changed` is set to `true` whenever a replacement happens. It is never reset
/// to `false`, so one flag can accumulate across many calls.
fn eliminate_node(node: Node, changed: &mut bool) -> Node {
    let rewritten = node_map::map_children(node, &mut |child| eliminate_node(child, changed));
    match rewritten {
        Node::If {
            cond: Expr::LitBool(taken),
            then,
            otherwise,
        } => {
            *changed = true;
            // `true` keeps the then-arm, `false` keeps the else-arm; the other
            // arm is dropped because it can never execute.
            let surviving_arm = if taken { then } else { otherwise };
            Node::Block(surviving_arm)
        }
        untouched => untouched,
    }
}
/// Returns `true` when `node` is an `If` whose condition is a boolean literal.
///
/// This is exactly the shape `eliminate_node` collapses. Non-`bool` literals
/// (e.g. `LitU32`) do not qualify.
fn is_constant_if(node: &Node) -> bool {
    if let Node::If { cond, .. } = node {
        matches!(cond, Expr::LitBool(_))
    } else {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};

    /// The single 4-element read/write `u32` storage buffer every test program declares.
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    /// Wraps `entry` in a 1x1x1 program that declares only `buf`.
    fn program_with_entry(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// A store of `value` into `buf[index]`, used as a recognizable arm body.
    fn store(index: u32, value: u32) -> Node {
        Node::store("buf", Expr::u32(index), Expr::u32(value))
    }

    /// Runs the pass's `transform` on a program built from `entry`.
    fn run(entry: Vec<Node>) -> PassResult {
        IfConstantBranchEliminatePass::transform(program_with_entry(entry))
    }

    /// Runs the pass's `analyze` on a program built from `entry`.
    fn analysis_of(entry: Vec<Node>) -> PassAnalysis {
        IfConstantBranchEliminatePass::analyze(&program_with_entry(entry))
    }

    /// Counts `If` nodes reachable through `If`, `Loop`, `Block` and `Region` bodies.
    fn count_ifs(node: &Node) -> usize {
        match node {
            Node::If {
                then, otherwise, ..
            } => 1 + then.iter().chain(otherwise.iter()).map(count_ifs).sum::<usize>(),
            Node::Loop { body, .. } => body.iter().map(count_ifs).sum(),
            Node::Block(body) => body.iter().map(count_ifs).sum(),
            Node::Region { body, .. } => body.iter().map(count_ifs).sum(),
            _ => 0,
        }
    }

    /// Total `If` nodes left in the transformed program's entry.
    fn remaining_ifs(result: &PassResult) -> usize {
        result.program.entry().iter().map(count_ifs).sum()
    }

    /// Asserts the pass leaves `if cond { store }` alone, failing with `message`.
    fn assert_pass_does_not_fire(cond: Expr, message: &str) {
        let result = run(vec![Node::if_then(cond, vec![store(0, 7)])]);
        assert!(!result.changed, "{message}");
    }

    #[test]
    fn if_true_collapses_to_then_arm() {
        let result = run(vec![Node::if_then(Expr::bool(true), vec![store(0, 7)])]);
        assert!(result.changed);
        let total = remaining_ifs(&result);
        assert_eq!(total, 0, "if-true must collapse; got {total} If nodes");
    }

    #[test]
    fn if_false_collapses_to_otherwise_arm() {
        let result = run(vec![Node::If {
            cond: Expr::bool(false),
            then: vec![store(0, 7)],
            otherwise: vec![store(1, 8)],
        }]);
        assert!(result.changed);
        assert_eq!(remaining_ifs(&result), 0);
    }

    #[test]
    fn if_with_runtime_condition_kept() {
        assert_pass_does_not_fire(
            Expr::var("c"),
            "If with non-literal condition must be preserved",
        );
    }

    #[test]
    fn nested_constant_ifs_all_collapse() {
        let inner = Node::If {
            cond: Expr::bool(false),
            then: vec![store(0, 7)],
            otherwise: vec![store(1, 8)],
        };
        let result = run(vec![Node::if_then(Expr::bool(true), vec![inner])]);
        assert!(result.changed);
        let total = remaining_ifs(&result);
        assert_eq!(
            total, 0,
            "nested constant Ifs must all collapse; got {total} remaining"
        );
    }

    #[test]
    fn analyze_skips_program_with_no_constant_if() {
        assert_eq!(
            analysis_of(vec![Node::if_then(Expr::var("c"), vec![])]),
            PassAnalysis::SKIP
        );
    }

    #[test]
    fn analyze_runs_when_constant_if_present() {
        assert_eq!(
            analysis_of(vec![Node::if_then(Expr::bool(true), vec![])]),
            PassAnalysis::RUN
        );
    }

    #[test]
    fn if_u32_zero_is_not_matched() {
        assert_pass_does_not_fire(Expr::u32(0), "LitU32(0) is not LitBool; pass must not fire");
    }

    #[test]
    fn if_u32_one_is_not_matched() {
        assert_pass_does_not_fire(Expr::u32(1), "LitU32(1) is not LitBool; pass must not fire");
    }

    #[test]
    fn if_i32_zero_is_not_matched() {
        assert_pass_does_not_fire(Expr::i32(0), "LitI32(0) is not LitBool; pass must not fire");
    }

    #[test]
    fn if_i32_neg1_is_not_matched() {
        assert_pass_does_not_fire(
            Expr::i32(-1),
            "LitI32(-1) is not LitBool; pass must not fire",
        );
    }
}