use std::sync::Arc;
use rustc_hash::FxHashSet;
use vyre_foundation::ir::{BufferAccess, Expr, Ident, Node, Program};
/// Runs loop-invariant code motion over `program` and returns the rewritten program.
///
/// Only loads from buffers declared `BufferAccess::ReadOnly` are candidates for hoisting.
/// When the entry point consists of exactly one `Node::Region`, that region's body is
/// rewritten and the region is rebuilt with its original `generator` and `source_region`;
/// any other entry is rewritten as a flat statement list.
pub fn apply_licm(program: &Program) -> Program {
    let read_only: FxHashSet<Ident> = program
        .buffers()
        .iter()
        .filter_map(|buf| match buf.access {
            BufferAccess::ReadOnly => Some(Ident::new(buf.name.clone())),
            _ => None,
        })
        .collect();

    let entry = program.entry();
    let new_entry = if let [Node::Region {
        generator,
        source_region,
        body,
    }] = entry
    {
        vec![Node::Region {
            generator: generator.clone(),
            source_region: source_region.clone(),
            body: Arc::new(rewrite_scope(body.as_slice(), &read_only)),
        }]
    } else {
        rewrite_scope(entry, &read_only)
    };
    program.with_rewritten_entry(new_entry)
}
/// Recursively applies LICM to a statement list and returns the rewritten list.
///
/// Only the prefix reported by `super::encode::reachable_prefix_len` is visited; statements
/// after it are not copied to the output. Loop bodies are rewritten innermost-first, and the
/// bindings `split_invariants` selects are emitted immediately before the rebuilt loop.
/// `If`, `Block` and `Region` bodies are rewritten in place; every other node is cloned.
fn rewrite_scope(body: &[Node], read_only: &FxHashSet<Ident>) -> Vec<Node> {
    let live = &body[..super::encode::reachable_prefix_len(body)];
    let mut rewritten: Vec<Node> = Vec::with_capacity(live.len());
    for node in live {
        // Loops may expand into several nodes (hoisted lets + the loop), so handle them first.
        if let Node::Loop {
            var,
            from,
            to,
            body: loop_body,
        } = node
        {
            let inner = rewrite_scope(loop_body, read_only);
            let (hoisted, kept) = split_invariants(var, &inner, read_only);
            rewritten.extend(hoisted);
            rewritten.push(Node::loop_for(var.clone(), from.clone(), to.clone(), kept));
            continue;
        }
        rewritten.push(match node {
            Node::If {
                cond,
                then,
                otherwise,
            } => Node::if_then_else(
                cond.clone(),
                rewrite_scope(then, read_only),
                rewrite_scope(otherwise, read_only),
            ),
            Node::Block(stmts) => Node::Block(rewrite_scope(stmts, read_only)),
            Node::Region {
                generator,
                source_region,
                body: region_body,
            } => Node::Region {
                generator: generator.clone(),
                source_region: source_region.clone(),
                body: Arc::new(rewrite_scope(region_body.as_slice(), read_only)),
            },
            leaf => leaf.clone(),
        });
    }
    rewritten
}
fn split_invariants(
iter_var: &Ident,
body: &[Node],
read_only: &FxHashSet<Ident>,
) -> (Vec<Node>, Vec<Node>) {
let mut hoisted: Vec<Node> = Vec::new();
let mut kept: Vec<Node> = Vec::new();
let mut hoisted_names: FxHashSet<Ident> = FxHashSet::default();
let mut local_unhoisted: FxHashSet<Ident> = FxHashSet::default();
let mut still_safe = true;
for node in body {
if !still_safe {
kept.push(node.clone());
continue;
}
match node {
Node::Let { name, value } => {
if name == iter_var {
kept.push(node.clone());
local_unhoisted.insert(name.clone());
continue;
}
if expr_is_invariant(value, iter_var, &hoisted_names, &local_unhoisted, read_only) {
hoisted.push(Node::let_bind(name.clone(), value.clone()));
hoisted_names.insert(name.clone());
} else {
kept.push(node.clone());
local_unhoisted.insert(name.clone());
}
}
Node::Store { .. }
| Node::Assign { .. }
| Node::Trap { .. }
| Node::AsyncLoad { .. }
| Node::AsyncStore { .. }
| Node::AsyncWait { .. }
| Node::Barrier { .. }
| Node::IndirectDispatch { .. }
| Node::Resume { .. }
| Node::Opaque(_) => {
still_safe = false;
kept.push(node.clone());
}
_ => kept.push(node.clone()),
}
}
(hoisted, kept)
}
/// Reports whether `expr` yields the same value on every iteration of the loop driven by
/// `iter_var`.
///
/// * Literals, the subgroup/invocation/workgroup/local id builtins and `BufLen` are always
///   invariant.
/// * A `Var` is invariant unless it names `iter_var` or appears in `local_unhoisted`.
/// * A `Load` is invariant only if its buffer is in `read_only` and its index is invariant.
/// * `UnOp`, `BinOp`, `Select` and `Fma` are invariant when every operand is.
/// * Any other expression form is conservatively reported as variant.
// `hoisted` is only threaded through the recursion and never read directly; it is kept so
// the signature stays stable for callers, which is what this lint would otherwise flag.
#[allow(clippy::only_used_in_recursion)]
fn expr_is_invariant(
    expr: &Expr,
    iter_var: &Ident,
    hoisted: &FxHashSet<Ident>,
    local_unhoisted: &FxHashSet<Ident>,
    read_only: &FxHashSet<Ident>,
) -> bool {
    let sub_invariant =
        |sub: &Expr| expr_is_invariant(sub, iter_var, hoisted, local_unhoisted, read_only);
    match expr {
        Expr::Var(name) => name != iter_var && !local_unhoisted.contains(name),
        Expr::Load { buffer, index } => read_only.contains(buffer) && sub_invariant(index),
        Expr::UnOp { operand, .. } => sub_invariant(operand),
        Expr::BinOp { left, right, .. } => sub_invariant(left) && sub_invariant(right),
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => sub_invariant(cond) && sub_invariant(true_val) && sub_invariant(false_val),
        Expr::Fma { a, b, c } => sub_invariant(a) && sub_invariant(b) && sub_invariant(c),
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::BufLen { .. } => true,
        _ => false,
    }
}