use crate::ir::{Expr, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
/// Optimizer pass that hoists leading `Let` bindings shared by both arms of an `If`.
///
/// When the `then` and `otherwise` arms of an `If` both start with `Let` nodes that bind the
/// same name to the same expression, and that expression is accepted by
/// `expr_is_observably_free`, the binding is emitted once in front of the `If` and removed
/// from both arms.
///
/// Registered through `vyre_pass` under the name `branch_value_hoist` with empty `requires`
/// and `invalidates` lists.
#[derive(Debug, Default)]
#[vyre_pass(
name = "branch_value_hoist",
requires = [],
invalidates = []
)]
pub struct BranchValueHoistPass;
impl BranchValueHoistPass {
    /// Decides whether running the pass could change `program`.
    ///
    /// Returns [`PassAnalysis::RUN`] when `node_map::any_descendant` reports at least one node
    /// under an entry node that satisfies `is_prefix_candidate`, and [`PassAnalysis::SKIP`]
    /// otherwise.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let has_candidate = program
            .entry()
            .iter()
            .any(|root| node_map::any_descendant(root, &mut is_prefix_candidate));

        if has_candidate {
            PassAnalysis::RUN
        } else {
            PassAnalysis::SKIP
        }
    }

    /// Rewrites the entry nodes of `program`, hoisting shared `Let` prefixes out of `If` arms.
    ///
    /// The returned [`PassResult`] carries the rewritten program and a `changed` flag that is
    /// `true` when at least one prefix was hoisted.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // Built before `program` is consumed so the rewritten entry can be attached to it.
        let scaffold = program.with_rewritten_entry(Vec::new());

        let mut changed = false;
        let mut entry: Vec<Node> = Vec::new();
        for node in program.into_entry_vec() {
            // A single top-level node may expand into several: hoisted `Let`s, then the `If`.
            entry.extend(hoist_prefix(node, &mut changed));
        }

        PassResult {
            program: scaffold.with_rewritten_entry(entry),
            changed,
        }
    }

    /// Returns the fingerprint of `program` as computed by `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Rewrites `node` and returns the nodes that should take its place.
///
/// Children are rewritten first through `node_map::map_children`. Because that callback must
/// yield exactly one node per child, a child that expands into several nodes is wrapped in a
/// `Node::Block`. If the rewritten node is itself an `If`, the `Let` prefix shared by both arms
/// is emitted ahead of it and `changed` is set to `true`.
///
/// Every other node is returned unchanged as a one-element vector.
///
/// NOTE(review): hoisting places the binding before the `If` (and before `cond`), in the
/// enclosing node sequence. This presumably relies on the IR not depending on arm-local
/// scoping or shadowing of the bound name — confirm against the IR's scoping rules.
fn hoist_prefix(node: Node, changed: &mut bool) -> Vec<Node> {
    let rewritten = node_map::map_children(node, &mut |child| {
        let mut expanded = hoist_prefix(child, changed);
        match expanded.len() {
            // Exactly one replacement: hand it back directly (the fallback is never reached).
            1 => expanded.pop().unwrap_or_else(|| Node::Block(Vec::new())),
            // Several replacements: keep them together in the child's single slot.
            _ => Node::Block(expanded),
        }
    });

    match rewritten {
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            let (mut out, then_rest, otherwise_rest) = extract_common_prefix(then, otherwise);
            if !out.is_empty() {
                *changed = true;
            }
            // The (possibly trimmed) `If` always follows whatever prefix was hoisted.
            out.push(Node::If {
                cond,
                then: then_rest,
                otherwise: otherwise_rest,
            });
            out
        }
        other => vec![other],
    }
}
/// Splits off the longest run of leading nodes that both arms share and that may be hoisted.
///
/// `then` and `otherwise` are compared position by position; the run ends at the first pair
/// rejected by `is_hoistable_let_pair`, or when either arm is exhausted.
///
/// Returns `(prefix, then_rest, otherwise_rest)`, where `prefix` holds the shared nodes (taken
/// from `then`; the equal copies in `otherwise` are discarded) and the other two vectors are
/// the arms with that prefix removed. `prefix` is empty when nothing can be hoisted.
fn extract_common_prefix(
    mut then: Vec<Node>,
    mut otherwise: Vec<Node>,
) -> (Vec<Node>, Vec<Node>, Vec<Node>) {
    // Measure the shared prefix first so each arm is shifted only once below, instead of once
    // per hoisted node as repeated `remove(0)` calls would do.
    let shared = then
        .iter()
        .zip(otherwise.iter())
        .take_while(|(t, o)| is_hoistable_let_pair(t, o))
        .count();

    let prefix: Vec<Node> = then.drain(..shared).collect();
    // The `otherwise` copies are equal to the ones kept in `prefix`; just drop them.
    drop(otherwise.drain(..shared));

    (prefix, then, otherwise)
}
/// Returns `true` when `a` and `b` are interchangeable `Let` nodes that may be hoisted.
///
/// Both nodes must be `Node::Let`, bind the same name to an equal value expression, and that
/// expression must be accepted by `expr_is_observably_free`. Any other combination of nodes
/// yields `false`.
fn is_hoistable_let_pair(a: &Node, b: &Node) -> bool {
    matches!(
        (a, b),
        (
            Node::Let {
                name: lhs_name,
                value: lhs_value,
            },
            Node::Let {
                name: rhs_name,
                value: rhs_value,
            },
        ) if lhs_name == rhs_name
            && lhs_value == rhs_value
            && expr_is_observably_free(lhs_value)
    )
}
/// Returns `true` when `expr` is built only from expression kinds this pass allows to be
/// hoisted.
///
/// Literals, variable references, buffer lengths and invocation/workgroup/local ids are
/// accepted. Unary, binary, cast, select and fused multiply-add expressions are accepted when
/// all of their operands are. Loads, atomics, calls, opaque expressions and every subgroup
/// expression are rejected.
///
/// The match is deliberately exhaustive (no wildcard arm), so adding an `Expr` variant forces
/// a decision here.
fn expr_is_observably_free(expr: &Expr) -> bool {
    match expr {
        // Leaf expressions that are always accepted.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::Var(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. } => true,

        // Compound expressions: accepted only if every operand is.
        Expr::UnOp { operand: inner, .. } => expr_is_observably_free(inner),
        Expr::Cast { value: inner, .. } => expr_is_observably_free(inner),
        Expr::BinOp {
            left: lhs,
            right: rhs,
            ..
        } => expr_is_observably_free(lhs) && expr_is_observably_free(rhs),
        Expr::Select {
            cond: test,
            true_val: on_true,
            false_val: on_false,
        } => {
            expr_is_observably_free(test)
                && expr_is_observably_free(on_true)
                && expr_is_observably_free(on_false)
        }
        Expr::Fma { a: x, b: y, c: z } => {
            expr_is_observably_free(x) && expr_is_observably_free(y) && expr_is_observably_free(z)
        }

        // Expression kinds that are never hoisted.
        Expr::Load { .. }
        | Expr::Atomic { .. }
        | Expr::Call { .. }
        | Expr::Opaque(_)
        | Expr::SubgroupBallot { .. }
        | Expr::SubgroupShuffle { .. }
        | Expr::SubgroupAdd { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize => false,
    }
}
/// Returns `true` when `node` is an `If` whose two arms both begin with a hoistable `Let` pair.
///
/// Only the first node of each arm is inspected; an `If` with an empty arm, or any non-`If`
/// node, is not a candidate.
fn is_prefix_candidate(node: &Node) -> bool {
    match node {
        Node::If {
            then, otherwise, ..
        } => then
            .first()
            .zip(otherwise.first())
            .map_or(false, |(t, o)| is_hoistable_let_pair(t, o)),
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Ident, Node};

    /// Read-write `u32` storage buffer named `buf` with a count of 4, used by every test program.
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    /// Builds a program over `buf()` whose entry is `entry`.
    fn program_with_entry(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// Builds an `If` on the variable `c` with the given arms.
    fn if_on_c(then: Vec<Node>, otherwise: Vec<Node>) -> Node {
        Node::If {
            cond: Expr::var("c"),
            then,
            otherwise,
        }
    }

    /// Runs the pass over a program whose entry is `entry`.
    fn run_pass(entry: Vec<Node>) -> PassResult {
        BranchValueHoistPass::transform(program_with_entry(entry))
    }

    /// Finds the first node list (searching through `Block` and `Region` bodies) that directly
    /// contains an `If`, so tests can inspect the `If` together with its siblings.
    fn find_if_with_siblings(nodes: &[Node]) -> Option<&[Node]> {
        if nodes.iter().any(|n| matches!(n, Node::If { .. })) {
            return Some(nodes);
        }
        nodes.iter().find_map(|node| match node {
            Node::Block(body) => find_if_with_siblings(body.as_slice()),
            Node::Region { body, .. } => find_if_with_siblings(body.as_ref().as_slice()),
            _ => None,
        })
    }

    #[test]
    fn hoists_single_common_let_prefix() {
        let common = Node::let_bind("x", Expr::u32(42));
        let then = vec![
            common.clone(),
            Node::store("buf", Expr::u32(0), Expr::var("x")),
        ];
        let otherwise = vec![common, Node::store("buf", Expr::u32(0), Expr::var("x"))];

        let result = run_pass(vec![if_on_c(then, otherwise)]);

        assert!(result.changed, "common Let prefix must be hoisted");
        let siblings = find_if_with_siblings(result.program.entry())
            .expect("hoisted Let + If must live as siblings somewhere in the entry tree");
        assert_eq!(siblings.len(), 2, "prefix Let then surviving If");
        assert!(matches!(&siblings[0], Node::Let { name, .. } if name.as_str() == "x"));
        assert!(matches!(&siblings[1], Node::If { .. }));
    }

    #[test]
    fn hoists_chain_of_common_lets() {
        let a = Node::let_bind("x", Expr::u32(1));
        // `y` depends on `x`, so both bindings must be hoisted together and in order.
        let b = Node::let_bind(
            "y",
            Expr::BinOp {
                op: crate::ir::BinOp::Add,
                left: Box::new(Expr::var("x")),
                right: Box::new(Expr::u32(2)),
            },
        );
        let then = vec![
            a.clone(),
            b.clone(),
            Node::store("buf", Expr::u32(0), Expr::var("y")),
        ];
        let otherwise = vec![a, b, Node::store("buf", Expr::u32(1), Expr::var("y"))];

        let result = run_pass(vec![if_on_c(then, otherwise)]);

        assert!(result.changed, "two-Let prefix must be hoisted in one pass");
        let siblings = find_if_with_siblings(result.program.entry())
            .expect("hoisted Lets + If must live as siblings somewhere in the entry tree");
        assert_eq!(siblings.len(), 3, "two Let prefix nodes then surviving If");
        assert!(matches!(&siblings[0], Node::Let { name, .. } if name.as_str() == "x"));
        assert!(matches!(&siblings[1], Node::Let { name, .. } if name.as_str() == "y"));
        assert!(matches!(&siblings[2], Node::If { .. }));
    }

    #[test]
    fn keeps_when_names_differ() {
        let then = vec![Node::let_bind("x", Expr::u32(1))];
        let otherwise = vec![Node::let_bind("y", Expr::u32(1))];

        let result = run_pass(vec![if_on_c(then, otherwise)]);

        assert!(!result.changed, "differing names must not hoist");
    }

    #[test]
    fn keeps_when_values_differ() {
        let then = vec![Node::let_bind("x", Expr::u32(1))];
        let otherwise = vec![Node::let_bind("x", Expr::u32(2))];

        let result = run_pass(vec![if_on_c(then, otherwise)]);

        assert!(!result.changed, "differing values must not hoist");
    }

    #[test]
    fn keeps_when_value_reads_memory() {
        let common = Node::let_bind(
            "x",
            Expr::Load {
                buffer: Ident::from("buf"),
                index: Box::new(Expr::u32(0)),
            },
        );

        let result = run_pass(vec![if_on_c(vec![common.clone()], vec![common])]);

        assert!(!result.changed, "Load-bearing prefix must not be hoisted");
    }

    #[test]
    fn keeps_when_value_is_atomic() {
        let common = Node::let_bind(
            "x",
            Expr::Atomic {
                op: crate::ir::AtomicOp::Add,
                buffer: Ident::from("buf"),
                index: Box::new(Expr::u32(0)),
                expected: None,
                value: Box::new(Expr::u32(1)),
                ordering: crate::ir::MemoryOrdering::Relaxed,
            },
        );

        let result = run_pass(vec![if_on_c(vec![common.clone()], vec![common])]);

        assert!(!result.changed, "Atomic prefix must not be hoisted");
    }

    #[test]
    fn keeps_when_prefix_is_store() {
        let common = Node::store("buf", Expr::u32(0), Expr::u32(7));

        let result = run_pass(vec![if_on_c(vec![common.clone()], vec![common])]);

        assert!(!result.changed, "Store prefix must not be hoisted");
    }

    #[test]
    fn extracts_only_the_common_prefix() {
        let common = Node::let_bind("x", Expr::u32(7));
        let then = vec![
            common.clone(),
            Node::store("buf", Expr::u32(0), Expr::u32(1)),
        ];
        let otherwise = vec![common, Node::store("buf", Expr::u32(0), Expr::u32(2))];

        let result = run_pass(vec![if_on_c(then, otherwise)]);

        assert!(result.changed, "leading common prefix must be hoisted");
        let siblings = find_if_with_siblings(result.program.entry())
            .expect("hoisted Let + If must live as siblings somewhere in the entry tree");
        let surviving_if = siblings
            .iter()
            .find(|n| matches!(n, Node::If { .. }))
            .expect("surviving If must remain after the hoist");
        if let Node::If {
            then, otherwise, ..
        } = surviving_if
        {
            assert_eq!(then.len(), 1, "non-prefix tail stays in then");
            assert_eq!(otherwise.len(), 1, "non-prefix tail stays in otherwise");
            assert!(matches!(&then[0], Node::Store { .. }));
            assert!(matches!(&otherwise[0], Node::Store { .. }));
        } else {
            panic!("expected If, got {surviving_if:?}");
        }
    }

    #[test]
    fn keeps_when_one_arm_is_empty() {
        let then = vec![Node::let_bind("x", Expr::u32(1))];

        let result = run_pass(vec![if_on_c(then, vec![])]);

        assert!(!result.changed, "empty otherwise has nothing to share");
    }

    #[test]
    fn analyze_skips_programs_with_no_branch() {
        let program = program_with_entry(vec![Node::store("buf", Expr::u32(0), Expr::u32(1))]);

        let verdict = BranchValueHoistPass::analyze(&program);

        if !matches!(verdict, PassAnalysis::SKIP) {
            panic!("expected SKIP, got {verdict:?}");
        }
    }
}