use std::sync::Arc;
use vyre_foundation::ir::{Expr, Node, Program};
/// Runs dead-branch elimination over the program's entry nodes.
///
/// If the entry consists of exactly one `Node::Region`, the region wrapper
/// (`generator` and `source_region`) is kept and only its body is rewritten.
/// Otherwise the entry node list itself is rewritten. The result is built with
/// `Program::with_rewritten_entry`.
pub fn apply_dead_branch(program: &Program) -> Program {
    // `rewrite_scope` only needs a borrowed slice, so the entry body is passed
    // by reference instead of deep-cloning it first. One match both selects the
    // nodes to rewrite and rebuilds the region wrapper.
    let new_entry = match program.entry() {
        [Node::Region {
            generator,
            source_region,
            body,
        }] => vec![Node::Region {
            generator: generator.clone(),
            source_region: source_region.clone(),
            body: Arc::new(rewrite_scope(body.as_slice())),
        }],
        entry => rewrite_scope(entry),
    };
    program.with_rewritten_entry(new_entry)
}
/// Simplifies one statement scope and returns the rewritten node list.
///
/// Only the prefix of `body` reported by `super::encode::reachable_prefix_len`
/// is visited. Nodes after that prefix are not emitted. Within the prefix:
///
/// - `If` whose condition `const_branch` can decide is replaced by the
///   rewritten taken branch, spliced directly into this scope.
/// - Any other `If` has both branches rewritten. When the condition is
///   atomic-free (`expr_no_atomic`) and the rewritten branches are equal, the
///   shared branch is spliced in (nothing is emitted when both are empty).
///   Otherwise an `If` is rebuilt.
/// - `Loop` is removed when its bounds are literals with `from >= to`, or when
///   `from_to_structurally_equal` holds. It is also removed when its rewritten
///   body is empty and both bounds are atomic-free. Otherwise it is rebuilt
///   with the rewritten body.
/// - `Block` is flattened: its rewritten contents are spliced into this scope.
/// - `Region` is kept, with its body rewritten.
/// - Every other node is cloned unchanged.
///
/// NOTE(review): splicing block/branch contents into the parent scope assumes
/// the IR does not depend on those nodes for variable scoping — confirm.
fn rewrite_scope(body: &[Node]) -> Vec<Node> {
    let reachable = &body[..super::encode::reachable_prefix_len(body)];
    let mut rewritten: Vec<Node> = Vec::with_capacity(reachable.len());

    for node in reachable {
        match node {
            Node::If {
                cond,
                then,
                otherwise,
            } => match const_branch(cond, then, otherwise) {
                Some(taken) => rewritten.extend(rewrite_scope(taken)),
                None => {
                    let then_nodes = rewrite_scope(then);
                    let else_nodes = rewrite_scope(otherwise);
                    // Two empty branches compare equal, so this also covers the
                    // "both branches vanished" case: extending by nothing drops the If.
                    if expr_no_atomic(cond) && then_nodes == else_nodes {
                        rewritten.extend(then_nodes);
                    } else {
                        rewritten.push(Node::if_then_else(cond.clone(), then_nodes, else_nodes));
                    }
                }
            },
            Node::Loop {
                var,
                from,
                to,
                body,
            } => {
                let literal_range_empty = match (from, to) {
                    (Expr::LitU32(lo), Expr::LitU32(hi)) => lo >= hi,
                    (Expr::LitI32(lo), Expr::LitI32(hi)) => lo >= hi,
                    _ => false,
                };
                if !literal_range_empty && !from_to_structurally_equal(from, to) {
                    let loop_body = rewrite_scope(body);
                    let bounds_pure = expr_no_atomic(from) && expr_no_atomic(to);
                    if !(loop_body.is_empty() && bounds_pure) {
                        rewritten.push(Node::loop_for(
                            var.clone(),
                            from.clone(),
                            to.clone(),
                            loop_body,
                        ));
                    }
                }
            }
            // Flatten the block; an empty rewrite contributes nothing.
            Node::Block(children) => rewritten.extend(rewrite_scope(children)),
            Node::Region {
                generator,
                source_region,
                body,
            } => rewritten.push(Node::Region {
                generator: generator.clone(),
                source_region: source_region.clone(),
                body: Arc::new(rewrite_scope(body.as_slice())),
            }),
            untouched => rewritten.push(untouched.clone()),
        }
    }

    rewritten
}
/// Reports whether two loop bounds are the same leaf expression.
///
/// Only literals (`LitU32`, `LitI32`), variables, `BufLen`, and the
/// invocation/workgroup/local id builtins are compared. Any other pairing,
/// including composite expressions, yields `false`.
fn from_to_structurally_equal(from: &Expr, to: &Expr) -> bool {
    match (from, to) {
        (Expr::LitU32(lhs), Expr::LitU32(rhs)) if lhs == rhs => true,
        (Expr::LitI32(lhs), Expr::LitI32(rhs)) if lhs == rhs => true,
        (Expr::Var(lhs), Expr::Var(rhs)) if lhs == rhs => true,
        (Expr::BufLen { buffer: lhs }, Expr::BufLen { buffer: rhs }) if lhs == rhs => true,
        (Expr::InvocationId { axis: lhs }, Expr::InvocationId { axis: rhs }) if lhs == rhs => true,
        (Expr::WorkgroupId { axis: lhs }, Expr::WorkgroupId { axis: rhs }) if lhs == rhs => true,
        (Expr::LocalId { axis: lhs }, Expr::LocalId { axis: rhs }) if lhs == rhs => true,
        _ => false,
    }
}
/// Returns `true` when `expr` contains no `Atomic` and no `Opaque`
/// sub-expression.
///
/// This pass uses it as the purity check before discarding an `If` condition
/// or loop bounds. Composite variants are checked recursively. Any variant not
/// matched explicitly below is treated as atomic-free.
///
/// NOTE(review): because of the catch-all `true` arm, a composite `Expr`
/// variant not listed here would not be searched for nested atomics — confirm
/// this list covers every variant with sub-expressions.
fn expr_no_atomic(expr: &Expr) -> bool {
    match expr {
        Expr::Atomic { .. } | Expr::Opaque(_) => false,
        Expr::UnOp { operand, .. } => expr_no_atomic(operand),
        Expr::Cast { value, .. } => expr_no_atomic(value),
        Expr::Load { index, .. } => expr_no_atomic(index),
        Expr::SubgroupAdd { value } => expr_no_atomic(value),
        Expr::SubgroupBallot { cond } => expr_no_atomic(cond),
        Expr::SubgroupShuffle { value, lane } => expr_no_atomic(value) && expr_no_atomic(lane),
        Expr::BinOp { left, right, .. } => expr_no_atomic(left) && expr_no_atomic(right),
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => expr_no_atomic(cond) && expr_no_atomic(true_val) && expr_no_atomic(false_val),
        Expr::Fma { a, b, c } => expr_no_atomic(a) && expr_no_atomic(b) && expr_no_atomic(c),
        Expr::Call { args, .. } => !args.iter().any(|arg| !expr_no_atomic(arg)),
        _ => true,
    }
}
/// Selects the statically taken branch of an `If` whose condition is a literal.
///
/// `LitBool` is taken at face value. `LitU32` and `LitI32` count as true when
/// non-zero. Returns `None` for any non-literal condition.
fn const_branch<'a>(cond: &Expr, then: &'a [Node], otherwise: &'a [Node]) -> Option<&'a [Node]> {
    let truthy = match cond {
        Expr::LitBool(flag) => *flag,
        Expr::LitU32(value) => *value != 0,
        Expr::LitI32(value) => *value != 0,
        _ => return None,
    };
    Some(if truthy { then } else { otherwise })
}