use crate::ir::{Expr, Node, Program};
use crate::optimizer::{vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
/// Optimizer pass that folds an `if` whose only statement is another `if` into a
/// single `if` over the conjunction of both conditions:
///
/// ```text
/// if a { if b { body } }   ==>   if a && b { body }
/// ```
///
/// The merge only fires when neither `if` has an `otherwise` branch and both
/// conditions contain no loads, atomics, calls, opaque expressions, or subgroup
/// collectives (see `is_pure_bool_expr` in this module). Nested chains of any
/// depth are collapsed by the recursive rewrite in `rewrite_node`.
#[derive(Debug, Default)]
#[vyre_pass(
name = "branch_coalesce",
// NOTE(review): the pass removes `If` nodes, which changes what `program.stats()`
// reports; confirm no cached analysis needs to be listed in `invalidates`.
requires = [],
invalidates = [],
phase = "cleanup",
boundary_class = "abi_preserving",
cost_model_family = "fusion"
)]
pub struct BranchCoalesce;
impl BranchCoalesce {
    /// Decides whether this pass has any work to do on `program`.
    ///
    /// Returns `PassAnalysis::SKIP` straight away when the program statistics report
    /// no `If` nodes. Otherwise it walks each entry node with
    /// `node_map::any_descendant`, using `is_coalesceable_if` as the predicate, and
    /// returns `PassAnalysis::RUN` as soon as one candidate is found.
    #[must_use]
    fn analyze_impl(program: &Program) -> PassAnalysis {
        // Cheap pre-filter: `&&` short-circuits, so `If`-free programs skip the tree walk.
        let may_contain_if = program
            .stats()
            .has_any_node_kind(crate::ir::stats::NODE_KIND_IF);
        let has_candidate = may_contain_if
            && program
                .entry()
                .iter()
                .any(|root| node_map::any_descendant(root, &mut is_coalesceable_if));
        if has_candidate {
            PassAnalysis::RUN
        } else {
            PassAnalysis::SKIP
        }
    }

    /// Rewrites every entry node of `program`, merging eligible nested `if` pairs.
    ///
    /// The returned `changed` flag is `true` exactly when at least one pair was merged.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        let mut changed = false;
        let rewritten = program.map_entry(|roots| {
            roots
                .into_iter()
                .map(|root| rewrite_node(root, &mut changed))
                .collect()
        });
        PassResult {
            program: rewritten,
            changed,
        }
    }
}
/// Recursively rewrites `node` bottom-up.
///
/// Nested nodes are rewritten first, through `node_map::map_children` and then
/// `node_map::map_body`. After that, `coalesce_if` is applied to the rebuilt node,
/// so an inner chain has already collapsed by the time its parent is examined.
/// `*changed` is set to `true` by `coalesce_if` whenever a merge happens.
///
/// NOTE(review): if `map_children` and `map_body` both visit an `If`'s bodies, each
/// body is rewritten twice (harmless but redundant work). Confirm they cover
/// disjoint children.
fn rewrite_node(node: Node, changed: &mut bool) -> Node {
    let with_children = node_map::map_children(node, &mut |child| rewrite_node(child, changed));
    let with_body = node_map::map_body(with_children, &mut |stmts| {
        stmts
            .into_iter()
            .map(|stmt| rewrite_node(stmt, changed))
            .collect()
    });
    coalesce_if(with_body, changed)
}
fn coalesce_if(node: Node, changed: &mut bool) -> Node {
let Node::If {
cond: outer_cond,
then,
otherwise,
} = node
else {
return node_unchanged_helper(node);
};
if !otherwise.is_empty() || then.len() != 1 {
return Node::If {
cond: outer_cond,
then,
otherwise,
};
}
let mut then_iter = then.into_iter();
let inner = then_iter
.next()
.unwrap_or_else(|| unreachable!("then.len() == 1 by guard above"));
let Node::If {
cond: inner_cond,
then: inner_then,
otherwise: inner_otherwise,
} = inner
else {
return Node::If {
cond: outer_cond,
then: vec![inner],
otherwise,
};
};
if !inner_otherwise.is_empty() {
return Node::If {
cond: outer_cond,
then: vec![Node::If {
cond: inner_cond,
then: inner_then,
otherwise: inner_otherwise,
}],
otherwise,
};
}
if !is_pure_bool_expr(&outer_cond) || !is_pure_bool_expr(&inner_cond) {
return Node::If {
cond: outer_cond,
then: vec![Node::If {
cond: inner_cond,
then: inner_then,
otherwise: inner_otherwise,
}],
otherwise,
};
}
*changed = true;
Node::If {
cond: Expr::and(outer_cond, inner_cond),
then: inner_then,
otherwise,
}
}
/// Returns `node` exactly as given.
///
/// Used by `coalesce_if` on its "nothing to merge" path, so that path reads as an
/// explicit pass-through.
fn node_unchanged_helper(node: Node) -> Node {
    std::convert::identity(node)
}
/// Reports whether `node` is a nested `if` pair that can be merged into one `if`.
///
/// Accepts exactly the shape `if outer { if inner { .. } }` where:
/// - neither `if` has an `otherwise` branch,
/// - the outer `then` contains the single nested `If` and nothing else, and
/// - both `outer` and `inner` satisfy `is_pure_bool_expr`.
///
/// Takes `&Node` so it can be passed to `node_map::any_descendant` as a predicate.
fn is_coalesceable_if(node: &Node) -> bool {
    let Node::If {
        cond: outer_cond,
        then,
        otherwise,
    } = node
    else {
        return false;
    };
    if !otherwise.is_empty() {
        return false;
    }
    // A one-element slice pattern replaces the separate `len() == 1` check and `[0]` index.
    match then.as_slice() {
        [Node::If {
            cond: inner_cond,
            otherwise: inner_otherwise,
            ..
        }] => {
            inner_otherwise.is_empty()
                && is_pure_bool_expr(outer_cond)
                && is_pure_bool_expr(inner_cond)
        }
        _ => false,
    }
}
/// Returns `true` when `expr` contains no loads, atomics, calls, opaque nodes, or
/// subgroup collectives.
///
/// - Accepted leaves: `bool`/`u32`/`i32`/`f32` literals, variables, the
///   invocation/workgroup/local/subgroup-lane builtins, and `BufLen`.
/// - Operator nodes (`UnOp`, `Cast`, `BinOp`, `Select`, `Fma`) are accepted when every
///   operand is.
/// - Rejected: `Load`, `Atomic`, `Call`, `Opaque`, `SubgroupBallot`, `SubgroupShuffle`,
///   and `SubgroupAdd`.
///
/// NOTE(review): despite its name, this does not check that `expr` has `bool` type. It
/// also does not consider whether an operator can fault (e.g. integer division by zero).
/// That matters if `Expr::and` evaluates both sides once `coalesce_if` merges the
/// conditions; confirm how `Expr::and` is lowered.
fn is_pure_bool_expr(expr: &Expr) -> bool {
    match expr {
        // Leaves with no side effects and no memory access.
        Expr::LitBool(_)
        | Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::Var(_)
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize
        | Expr::BufLen { .. } => true,

        // Operators: pure exactly when all of their operands are.
        Expr::UnOp { operand, .. } => is_pure_bool_expr(operand),
        Expr::Cast { value, .. } => is_pure_bool_expr(value),
        Expr::BinOp { left, right, .. } => is_pure_bool_expr(left) && is_pure_bool_expr(right),
        Expr::Select {
            cond: picker,
            true_val: when_true,
            false_val: when_false,
        } => {
            is_pure_bool_expr(picker)
                && is_pure_bool_expr(when_true)
                && is_pure_bool_expr(when_false)
        }
        Expr::Fma {
            a: mul_lhs,
            b: mul_rhs,
            c: addend,
        } => is_pure_bool_expr(mul_lhs) && is_pure_bool_expr(mul_rhs) && is_pure_bool_expr(addend),

        // Memory access, calls, opaque nodes, and cross-lane subgroup collectives.
        Expr::Load { .. }
        | Expr::Atomic { .. }
        | Expr::Call { .. }
        | Expr::Opaque(_)
        | Expr::SubgroupBallot { .. }
        | Expr::SubgroupShuffle { .. }
        | Expr::SubgroupAdd { .. } => false,
    }
}
#[cfg(test)]
mod tests;