use crate::ir::{Expr, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
/// Optimizer pass that merges directly nested `if`s into a single conditional.
///
/// `if a { if b { body } }` is rewritten to `if a && b { body }` (the combined
/// condition is built with `Expr::and(a, b)`) when:
/// - neither `if` has an else-arm,
/// - the outer then-arm contains nothing but the inner `if`, and
/// - both conditions are accepted by `is_pure_bool_expr`.
///
/// The pass is registered under the name `branch_coalesce`. Its `requires` and
/// `invalidates` lists are both empty.
#[derive(Debug, Default)]
#[vyre_pass(
name = "branch_coalesce",
requires = [],
invalidates = []
)]
pub struct BranchCoalesce;
impl BranchCoalesce {
    /// Decide whether this pass has any work to do on `program`.
    ///
    /// Yields `PassAnalysis::RUN` when some entry node, or something
    /// `node_map::any_descendant` reaches beneath it, is an `If` pair accepted
    /// by `is_coalesceable_if`. Otherwise yields `PassAnalysis::SKIP`.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let has_candidate = program
            .entry()
            .iter()
            .any(|root| node_map::any_descendant(root, &mut is_coalesceable_if));
        if has_candidate {
            PassAnalysis::RUN
        } else {
            PassAnalysis::SKIP
        }
    }

    /// Rewrite every entry node with `rewrite_node` and report whether
    /// anything was merged.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // `into_entry_vec` consumes `program`, so first derive an empty-entry
        // copy to pour the rewritten body back into.
        let shell = program.with_rewritten_entry(Vec::new());
        let mut changed = false;
        let mut rewritten: Vec<Node> = Vec::new();
        for root in program.into_entry_vec() {
            rewritten.push(rewrite_node(root, &mut changed));
        }
        PassResult {
            program: shell.with_rewritten_entry(rewritten),
            changed,
        }
    }

    /// Program fingerprint for this pass; delegates to `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Recursively rewrite `node` bottom-up, then try to coalesce `node` itself.
///
/// Nested nodes are rewritten first (via `node_map::map_children` and
/// `node_map::map_body`). So by the time `coalesce_if` looks at an outer `If`,
/// any chain underneath has already been collapsed, which lets a multi-level
/// nest fold in one pass. `*changed` becomes `true` whenever a merge fires.
///
/// NOTE(review): confirm that `map_children` and `map_body` visit disjoint parts
/// of a node. If both cover the same statement list, those subtrees are
/// rewritten twice.
fn rewrite_node(node: Node, changed: &mut bool) -> Node {
    let with_children = node_map::map_children(node, &mut |child| rewrite_node(child, changed));
    let with_bodies = node_map::map_body(with_children, &mut |stmts| {
        stmts
            .into_iter()
            .map(|stmt| rewrite_node(stmt, changed))
            .collect()
    });
    coalesce_if(with_bodies, changed)
}
fn coalesce_if(node: Node, changed: &mut bool) -> Node {
let Node::If {
cond: outer_cond,
then,
otherwise,
} = node
else {
return node_unchanged_helper(node);
};
if !otherwise.is_empty() || then.len() != 1 {
return Node::If {
cond: outer_cond,
then,
otherwise,
};
}
let mut then_iter = then.into_iter();
let inner = then_iter.next().expect("then.len() == 1 by guard above");
let Node::If {
cond: inner_cond,
then: inner_then,
otherwise: inner_otherwise,
} = inner
else {
return Node::If {
cond: outer_cond,
then: vec![inner],
otherwise,
};
};
if !inner_otherwise.is_empty() {
return Node::If {
cond: outer_cond,
then: vec![Node::If {
cond: inner_cond,
then: inner_then,
otherwise: inner_otherwise,
}],
otherwise,
};
}
if !is_pure_bool_expr(&outer_cond) || !is_pure_bool_expr(&inner_cond) {
return Node::If {
cond: outer_cond,
then: vec![Node::If {
cond: inner_cond,
then: inner_then,
otherwise: inner_otherwise,
}],
otherwise,
};
}
*changed = true;
Node::If {
cond: Expr::and(outer_cond, inner_cond),
then: inner_then,
otherwise,
}
}
/// Identity pass-through; `coalesce_if` returns non-`If` nodes through this.
fn node_unchanged_helper(node: Node) -> Node {
    std::convert::identity(node)
}
/// Whether `node` has the exact shape `coalesce_if` can merge.
///
/// The required shape is an `If` with an empty else-arm whose then-arm consists
/// solely of another `If`, also with an empty else-arm. In addition, both the
/// outer and inner conditions must pass `is_pure_bool_expr`.
fn is_coalesceable_if(node: &Node) -> bool {
    let Node::If {
        cond: outer,
        then,
        otherwise,
    } = node
    else {
        return false;
    };
    if !otherwise.is_empty() {
        return false;
    }
    match then.as_slice() {
        [Node::If {
            cond: inner,
            otherwise: inner_else,
            ..
        }] => inner_else.is_empty() && is_pure_bool_expr(outer) && is_pure_bool_expr(inner),
        _ => false,
    }
}
/// Whether `expr` may be folded into a combined `if` condition.
///
/// `is_coalesceable_if` uses this to vet both conditions of a nested pair.
/// Despite the name, no type check is performed: integer and float literals are
/// accepted just like boolean ones.
///
/// - Leaves (literals, variables, invocation/workgroup/local ids and the
///   subgroup id/size builtins) are accepted.
/// - Unary, binary, cast and select expressions are accepted iff every operand is.
/// - Loads, buffer lengths, atomics, calls, opaque expressions, FMA and subgroup
///   ballot/shuffle/add are conservatively rejected.
///
/// The match is deliberately exhaustive, so any new `Expr` variant must be
/// classified here.
fn is_pure_bool_expr(expr: &Expr) -> bool {
    match expr {
        // Leaves.
        Expr::LitBool(_)
        | Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::Var(_)
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize => true,
        // Operators: pure exactly when all operands are.
        Expr::UnOp { operand, .. } => is_pure_bool_expr(operand),
        Expr::Cast { value, .. } => is_pure_bool_expr(value),
        Expr::BinOp { left, right, .. } => is_pure_bool_expr(left) && is_pure_bool_expr(right),
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => is_pure_bool_expr(cond) && is_pure_bool_expr(true_val) && is_pure_bool_expr(false_val),
        // Rejected outright.
        Expr::Load { .. }
        | Expr::BufLen { .. }
        | Expr::Atomic { .. }
        | Expr::Call { .. }
        | Expr::Opaque(_)
        | Expr::Fma { .. }
        | Expr::SubgroupBallot { .. }
        | Expr::SubgroupShuffle { .. }
        | Expr::SubgroupAdd { .. } => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};

    /// Read-write `u32` storage buffer named `buf` that the test bodies store into.
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    /// Wrap `entry` into a program that declares `buf`.
    fn program_with_entry(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// Count every `Node::If` in `node`, descending into both `If` arms and into
    /// `Loop`, `Block` and `Region` bodies.
    fn count_ifs(node: &Node) -> usize {
        match node {
            Node::If {
                then, otherwise, ..
            } => {
                1 + then.iter().map(count_ifs).sum::<usize>()
                    + otherwise.iter().map(count_ifs).sum::<usize>()
            }
            Node::Loop { body, .. } | Node::Block(body) => body.iter().map(count_ifs).sum(),
            Node::Region { body, .. } => body.iter().map(count_ifs).sum(),
            _ => 0,
        }
    }

    /// Total number of `Node::If`s across the program's entry.
    fn total_ifs(program: &Program) -> usize {
        program.entry().iter().map(count_ifs).sum()
    }

    /// Condition of the first `If` met in a pre-order walk of `entry`.
    ///
    /// The walk descends through `Region`, `Block` and `Loop` bodies, but not
    /// into `If` arms.
    fn first_if_cond(entry: &[Node]) -> Option<&Expr> {
        for node in entry {
            match node {
                Node::If { cond, .. } => return Some(cond),
                Node::Region { body, .. } => {
                    if let Some(c) = first_if_cond(body.as_ref()) {
                        return Some(c);
                    }
                }
                Node::Block(body) | Node::Loop { body, .. } => {
                    if let Some(c) = first_if_cond(body) {
                        return Some(c);
                    }
                }
                _ => {}
            }
        }
        None
    }

    /// Assert the pass leaves `entry` alone.
    ///
    /// Checks that `analyze` reports `SKIP`, that `transform` reports no change,
    /// and that the `If` count is preserved. This also pins agreement between
    /// the two entry points.
    fn assert_not_coalesced(entry: Vec<Node>, why: &str) {
        let program = program_with_entry(entry);
        assert_eq!(
            BranchCoalesce::analyze(&program),
            PassAnalysis::SKIP,
            "analyze: {why}"
        );
        let before = total_ifs(&program);
        let result = BranchCoalesce::transform(program);
        assert!(!result.changed, "transform: {why}");
        assert_eq!(total_ifs(&result.program), before, "structure: {why}");
    }

    #[test]
    fn coalesces_nested_if_with_two_pure_conds() {
        let entry = vec![Node::if_then(
            Expr::var("c1"),
            vec![Node::if_then(
                Expr::var("c2"),
                vec![Node::store("buf", Expr::u32(0), Expr::u32(7))],
            )],
        )];
        let program = program_with_entry(entry);
        let result = BranchCoalesce::transform(program);
        assert!(result.changed);
        assert_eq!(
            total_ifs(&result.program),
            1,
            "two nested Ifs collapse into one"
        );
        let cond = first_if_cond(result.program.entry()).expect("Fix: must have an If");
        assert_eq!(cond, &Expr::and(Expr::var("c1"), Expr::var("c2")));
    }

    #[test]
    fn does_not_coalesce_when_outer_has_sibling() {
        assert_not_coalesced(
            vec![Node::if_then(
                Expr::var("c1"),
                vec![
                    Node::store("buf", Expr::u32(0), Expr::u32(7)),
                    Node::if_then(
                        Expr::var("c2"),
                        vec![Node::store("buf", Expr::u32(1), Expr::u32(8))],
                    ),
                ],
            )],
            "must not hoist sibling Store into combined branch",
        );
    }

    #[test]
    fn does_not_coalesce_when_outer_has_otherwise() {
        assert_not_coalesced(
            vec![Node::if_then_else(
                Expr::var("c1"),
                vec![Node::if_then(
                    Expr::var("c2"),
                    vec![Node::store("buf", Expr::u32(0), Expr::u32(7))],
                )],
                vec![Node::store("buf", Expr::u32(0), Expr::u32(9))],
            )],
            "outer else-arm must be preserved",
        );
    }

    #[test]
    fn does_not_coalesce_when_inner_has_otherwise() {
        assert_not_coalesced(
            vec![Node::if_then(
                Expr::var("c1"),
                vec![Node::if_then_else(
                    Expr::var("c2"),
                    vec![Node::store("buf", Expr::u32(0), Expr::u32(7))],
                    vec![Node::store("buf", Expr::u32(0), Expr::u32(9))],
                )],
            )],
            "inner else-arm must be preserved",
        );
    }

    #[test]
    fn does_not_coalesce_when_outer_cond_loads_memory() {
        assert_not_coalesced(
            vec![Node::if_then(
                Expr::eq(Expr::load("buf", Expr::u32(0)), Expr::u32(0)),
                vec![Node::if_then(
                    Expr::var("c2"),
                    vec![Node::store("buf", Expr::u32(0), Expr::u32(7))],
                )],
            )],
            "outer cond reads memory; conjoining could change ordering",
        );
    }

    #[test]
    fn does_not_coalesce_when_inner_cond_loads_memory() {
        assert_not_coalesced(
            vec![Node::if_then(
                Expr::var("c1"),
                vec![Node::if_then(
                    Expr::eq(Expr::load("buf", Expr::u32(0)), Expr::u32(0)),
                    vec![Node::store("buf", Expr::u32(0), Expr::u32(7))],
                )],
            )],
            "inner cond reads memory; conjoining could change ordering",
        );
    }

    #[test]
    fn coalesces_three_level_nesting_in_one_pass() {
        let entry = vec![Node::if_then(
            Expr::var("c1"),
            vec![Node::if_then(
                Expr::var("c2"),
                vec![Node::if_then(
                    Expr::var("c3"),
                    vec![Node::store("buf", Expr::u32(0), Expr::u32(1))],
                )],
            )],
        )];
        let program = program_with_entry(entry);
        let result = BranchCoalesce::transform(program);
        assert!(result.changed);
        assert_eq!(
            total_ifs(&result.program),
            1,
            "three nested Ifs collapse into one"
        );
        let cond = first_if_cond(result.program.entry()).expect("Fix: must have an If");
        let expected = Expr::and(Expr::var("c1"), Expr::and(Expr::var("c2"), Expr::var("c3")));
        assert_eq!(cond, &expected);
    }

    #[test]
    fn coalesces_pure_inner_pair_under_impure_outer_cond() {
        // The load-guarded outer If must survive, but the pure pair beneath it
        // still folds; analyze must agree that there is work to do.
        let outer_cond = Expr::eq(Expr::load("buf", Expr::u32(0)), Expr::u32(0));
        let entry = vec![Node::if_then(
            outer_cond.clone(),
            vec![Node::if_then(
                Expr::var("c2"),
                vec![Node::if_then(
                    Expr::var("c3"),
                    vec![Node::store("buf", Expr::u32(0), Expr::u32(7))],
                )],
            )],
        )];
        let program = program_with_entry(entry);
        assert_eq!(BranchCoalesce::analyze(&program), PassAnalysis::RUN);
        let result = BranchCoalesce::transform(program);
        assert!(result.changed);
        assert_eq!(
            total_ifs(&result.program),
            2,
            "only the pure inner pair collapses"
        );
        let cond = first_if_cond(result.program.entry()).expect("Fix: must have an If");
        assert_eq!(cond, &outer_cond);
    }

    #[test]
    fn analyze_skips_program_with_no_coalesceable_pair() {
        let entry = vec![Node::if_then(
            Expr::var("c1"),
            vec![Node::store("buf", Expr::u32(0), Expr::u32(1))],
        )];
        let program = program_with_entry(entry);
        assert_eq!(BranchCoalesce::analyze(&program), PassAnalysis::SKIP);
    }

    #[test]
    fn analyze_runs_when_coalesceable_pair_present() {
        let entry = vec![Node::if_then(
            Expr::var("c1"),
            vec![Node::if_then(
                Expr::var("c2"),
                vec![Node::store("buf", Expr::u32(0), Expr::u32(7))],
            )],
        )];
        let program = program_with_entry(entry);
        assert_eq!(BranchCoalesce::analyze(&program), PassAnalysis::RUN);
    }

    #[test]
    fn coalesces_inside_loop_body() {
        let entry = vec![Node::loop_for(
            "i",
            Expr::u32(0),
            Expr::u32(8),
            vec![Node::if_then(
                Expr::var("c1"),
                vec![Node::if_then(
                    Expr::var("c2"),
                    vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
                )],
            )],
        )];
        let program = program_with_entry(entry);
        let result = BranchCoalesce::transform(program);
        assert!(result.changed);
        assert_eq!(
            total_ifs(&result.program),
            1,
            "nested If inside Loop coalesces"
        );
    }
}