use crate::ir::{Expr, Ident, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use rustc_hash::FxHashSet;
/// Loop-invariant code motion pass for `Let` bindings.
///
/// `transform` walks the program entry and moves `Let` nodes found directly
/// in a loop body to just before that loop when their value neither reads a
/// name written inside the loop nor contains an expression classified as
/// observable by `expr_is_observably_free`.
///
/// The pass is registered under the name `"loop_licm"` with empty `requires`
/// and `invalidates` lists.
#[derive(Debug, Default)]
#[vyre_pass(
name = "loop_licm",
requires = [],
invalidates = []
)]
pub struct LoopLicm;
impl LoopLicm {
    /// Decides whether running [`LoopLicm::transform`] is worthwhile.
    ///
    /// Returns `PassAnalysis::RUN` when any entry node contains a loop with a
    /// candidate `Let` (per `has_hoistable_let_in_any_loop`), otherwise
    /// `PassAnalysis::SKIP`.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let found_candidate = program
            .entry()
            .iter()
            .any(has_hoistable_let_in_any_loop);
        if !found_candidate {
            return PassAnalysis::SKIP;
        }
        PassAnalysis::RUN
    }

    /// Rewrites the program entry with invariant `Let`s moved above their loops.
    ///
    /// `PassResult::changed` is `true` when at least one `Let` was moved.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // Build an empty-entry copy first: `into_entry_vec` consumes `program`,
        // and the copy is what the rewritten entry gets attached to.
        let shell = program.with_rewritten_entry(Vec::new());
        let mut changed = false;
        let rewritten: Vec<Node> = hoist_in_body(program.into_entry_vec(), &mut changed);
        let program = shell.with_rewritten_entry(rewritten);
        PassResult { program, changed }
    }

    /// Fingerprints `program` by delegating to `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Rewrites `body` so that hoistable `Let`s sit directly before the loop they
/// were found in, recursing into `Loop`, `If`, `Block` and `Region` children.
///
/// A loop's body is rewritten before the loop itself is split, so a `Let`
/// lifted out of an inner loop lands in the enclosing loop's body and is
/// considered again there. `changed` is set to `true` whenever a `Let` moves;
/// it is never reset to `false`. All other node kinds pass through untouched.
fn hoist_in_body(body: Vec<Node>, changed: &mut bool) -> Vec<Node> {
    let mut rewritten: Vec<Node> = Vec::with_capacity(body.len());
    for node in body {
        let replacement = match node {
            Node::Loop {
                var,
                from,
                to,
                body: loop_body,
            } => {
                let processed = hoist_in_body(loop_body, changed);
                let (lifted, remaining) = split_invariant_lets(&var, processed, changed);
                // Lifted bindings go immediately ahead of the loop they left.
                rewritten.extend(lifted);
                Node::Loop {
                    var,
                    from,
                    to,
                    body: remaining,
                }
            }
            Node::If {
                cond,
                then,
                otherwise,
            } => Node::If {
                cond,
                then: hoist_in_body(then, changed),
                otherwise: hoist_in_body(otherwise, changed),
            },
            Node::Block(children) => Node::Block(hoist_in_body(children, changed)),
            Node::Region {
                generator,
                source_region,
                body,
            } => {
                // Take the Vec out of the Arc when uniquely owned; clone it otherwise.
                let owned =
                    std::sync::Arc::try_unwrap(body).unwrap_or_else(|shared| (*shared).clone());
                Node::Region {
                    generator,
                    source_region,
                    body: std::sync::Arc::new(hoist_in_body(owned, changed)),
                }
            }
            untouched => untouched,
        };
        rewritten.push(replacement);
    }
    rewritten
}
/// Splits a loop body into `(hoisted, kept)`.
///
/// `hoisted` holds the top-level `Let` nodes of `body` that may run once
/// before the loop instead of once per iteration; `kept` holds every other
/// node in its original order. A `Let` is hoisted only when all of these hold:
///
/// * its name has no other write site in the loop body: no `Assign` targets
///   it and no second `Let` binds it (nested bodies included);
/// * its value reads neither `loop_var` nor any name still written in the loop;
/// * its value passes `expr_is_observably_free`.
///
/// Once a `Let` is hoisted its name stops counting as loop-written, so later
/// `Let`s that depend only on already-hoisted names can follow it out.
/// `changed` is set to `true` when at least one `Let` is hoisted.
fn split_invariant_lets(
    loop_var: &Ident,
    body: Vec<Node>,
    changed: &mut bool,
) -> (Vec<Node>, Vec<Node>) {
    let mut mutated: FxHashSet<Ident> = FxHashSet::default();
    mutated.insert(loop_var.clone());
    collect_assigned_and_let_bound_names(&body, &mut mutated);

    // A binding such as `let sum = 0` that is later `Assign`ed in the loop is
    // a per-iteration reset. Moving it out would let the value carry over
    // between iterations, so names with extra write sites are never hoisted.
    let mut bound: FxHashSet<Ident> = FxHashSet::default();
    let mut rewritten: FxHashSet<Ident> = FxHashSet::default();
    collect_rewritten_names(&body, &mut bound, &mut rewritten);

    let mut hoisted: Vec<Node> = Vec::new();
    let mut kept: Vec<Node> = Vec::with_capacity(body.len());
    for node in body {
        match node {
            Node::Let { name, value }
                if !rewritten.contains(&name)
                    && !expr_references_any(&value, &mutated)
                    && expr_is_observably_free(&value) =>
            {
                *changed = true;
                // The sole write to `name` now happens before the loop.
                mutated.remove(&name);
                hoisted.push(Node::let_bind(name.as_str(), value));
            }
            other => kept.push(other),
        }
    }
    (hoisted, kept)
}

/// Adds to `rewritten` every name that is the target of an `Assign`, or is
/// bound by more than one `Let`, anywhere in `nodes` (nested bodies included).
/// `bound` tracks the `Let` names seen so far.
fn collect_rewritten_names(
    nodes: &[Node],
    bound: &mut FxHashSet<Ident>,
    rewritten: &mut FxHashSet<Ident>,
) {
    for node in nodes {
        match node {
            Node::Let { name, .. } => {
                // `insert` returns false when the name was already bound.
                if !bound.insert(name.clone()) {
                    rewritten.insert(name.clone());
                }
            }
            Node::Assign { name, .. } => {
                rewritten.insert(name.clone());
            }
            Node::If {
                then, otherwise, ..
            } => {
                collect_rewritten_names(then, bound, rewritten);
                collect_rewritten_names(otherwise, bound, rewritten);
            }
            Node::Loop { body, .. } | Node::Block(body) => {
                collect_rewritten_names(body, bound, rewritten);
            }
            Node::Region { body, .. } => {
                collect_rewritten_names(body, bound, rewritten);
            }
            _ => {}
        }
    }
}
/// Inserts into `out` the name of every `Let` and `Assign` found in `nodes`,
/// including those nested inside `If`, `Loop`, `Block` and `Region` bodies.
///
/// Loop variables of nested loops are not inserted; only their bodies are
/// scanned. Other node kinds contribute nothing.
fn collect_assigned_and_let_bound_names(nodes: &[Node], out: &mut FxHashSet<Ident>) {
    // Explicit worklist of node slices still to scan; visiting order does not
    // matter because the result is a set.
    let mut pending: Vec<&[Node]> = vec![nodes];
    while let Some(batch) = pending.pop() {
        for node in batch {
            match node {
                Node::Let { name, .. } | Node::Assign { name, .. } => {
                    out.insert(name.clone());
                }
                Node::If {
                    then, otherwise, ..
                } => {
                    pending.push(then);
                    pending.push(otherwise);
                }
                Node::Loop { body, .. } | Node::Block(body) => pending.push(body),
                Node::Region { body, .. } => pending.push(body),
                _ => {}
            }
        }
    }
}
/// Returns `true` when `expr` contains an `Expr::Var` whose name is in `mutated`.
///
/// Sub-expressions are searched recursively. For `Load` only the index
/// expression is inspected. Literals, builtin ids, `BufLen` and `Opaque` are
/// treated as referencing no variable.
fn expr_references_any(expr: &Expr, mutated: &FxHashSet<Ident>) -> bool {
    match expr {
        // Leaves that name no variable.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize
        | Expr::Opaque(_) => false,

        Expr::Var(ident) => mutated.contains(ident),

        // Single-child expressions.
        Expr::Load { index, .. } => expr_references_any(index, mutated),
        Expr::UnOp { operand, .. } => expr_references_any(operand, mutated),
        Expr::Cast { value, .. } => expr_references_any(value, mutated),
        Expr::SubgroupShuffle { value, .. } => expr_references_any(value, mutated),
        Expr::SubgroupAdd { value } => expr_references_any(value, mutated),
        Expr::SubgroupBallot { cond } => expr_references_any(cond, mutated),

        // Multi-child expressions: stop at the first child that hits.
        Expr::BinOp { left, right, .. } => {
            if expr_references_any(left, mutated) {
                return true;
            }
            expr_references_any(right, mutated)
        }
        Expr::Fma { a, b, c } => {
            if expr_references_any(a, mutated) {
                return true;
            }
            if expr_references_any(b, mutated) {
                return true;
            }
            expr_references_any(c, mutated)
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            if expr_references_any(cond, mutated) {
                return true;
            }
            if expr_references_any(true_val, mutated) {
                return true;
            }
            expr_references_any(false_val, mutated)
        }
        Expr::Call { args, .. } => {
            for arg in args.iter() {
                if expr_references_any(arg, mutated) {
                    return true;
                }
            }
            false
        }
        Expr::Atomic {
            index,
            expected,
            value,
            ..
        } => {
            if expr_references_any(index, mutated) {
                return true;
            }
            // `expected` is optional; it is only searched when present.
            if let Some(exp) = expected.as_deref() {
                if expr_references_any(exp, mutated) {
                    return true;
                }
            }
            expr_references_any(value, mutated)
        }
    }
}
/// Returns `true` when `expr` is built only from literals, variables,
/// invocation/workgroup/local ids, and `BinOp` / `UnOp` / `Cast` / `Fma` /
/// `Select` combinations of those.
///
/// Any occurrence of `Load`, `BufLen`, `Atomic`, `Call`, `Opaque` or a
/// subgroup expression makes the whole expression not free.
fn expr_is_observably_free(expr: &Expr) -> bool {
    match expr {
        // Never eligible, whatever their operands are.
        Expr::Load { .. }
        | Expr::BufLen { .. }
        | Expr::Atomic { .. }
        | Expr::Call { .. }
        | Expr::Opaque(_)
        | Expr::SubgroupShuffle { .. }
        | Expr::SubgroupAdd { .. }
        | Expr::SubgroupBallot { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize => false,

        // Always eligible leaves.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::Var(_)
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. } => true,

        // Composite expressions are eligible only when every operand is.
        Expr::UnOp { operand, .. } => expr_is_observably_free(operand),
        Expr::Cast { value, .. } => expr_is_observably_free(value),
        Expr::BinOp { left, right, .. } => {
            if !expr_is_observably_free(left) {
                return false;
            }
            expr_is_observably_free(right)
        }
        Expr::Fma { a, b, c } => {
            if !expr_is_observably_free(a) {
                return false;
            }
            if !expr_is_observably_free(b) {
                return false;
            }
            expr_is_observably_free(c)
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            if !expr_is_observably_free(cond) {
                return false;
            }
            if !expr_is_observably_free(true_val) {
                return false;
            }
            expr_is_observably_free(false_val)
        }
    }
}
/// Returns `true` when `node` is, or contains, a loop whose body directly
/// holds a `Let` that looks hoistable.
///
/// A `Let` looks hoistable when its value reads neither the loop variable nor
/// any name bound by `Let` or `Assign` inside that loop (its own name
/// included), and the value passes `expr_is_observably_free`. This is the same
/// value check `split_invariant_lets` applies to the first `Let` it moves, so
/// a `false` result means the transform would not move anything. A `true`
/// result may still over-approximate.
///
/// `If`, `Block` and `Region` children are searched recursively; other node
/// kinds yield `false`.
fn has_hoistable_let_in_any_loop(node: &Node) -> bool {
    match node {
        Node::Loop { var, body, .. } => {
            // Built once per loop and shared by every `Let` in the body.
            let mut mutated: FxHashSet<Ident> = FxHashSet::default();
            mutated.insert(var.clone());
            collect_assigned_and_let_bound_names(body, &mut mutated);
            body.iter().any(|child| {
                if let Node::Let { value, .. } = child {
                    if !expr_references_any(value, &mutated) && expr_is_observably_free(value) {
                        return true;
                    }
                }
                // Nested loops (directly or under If/Block/Region) count too.
                has_hoistable_let_in_any_loop(child)
            })
        }
        Node::If {
            then, otherwise, ..
        } => {
            then.iter().any(has_hoistable_let_in_any_loop)
                || otherwise.iter().any(has_hoistable_let_in_any_loop)
        }
        Node::Block(body) => body.iter().any(has_hoistable_let_in_any_loop),
        Node::Region { body, .. } => body.iter().any(has_hoistable_let_in_any_loop),
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};

    /// Read-write `u32` storage buffer named `"buf"` with four elements.
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    /// Wraps `entry` in a program that declares `buf()` and uses a 1x1x1 size triple.
    fn program(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// Counts the `Let` nodes in `node`, descending into nested bodies.
    fn count_lets(node: &Node) -> usize {
        match node {
            Node::Let { .. } => 1,
            Node::If {
                then, otherwise, ..
            } => {
                then.iter().map(count_lets).sum::<usize>()
                    + otherwise.iter().map(count_lets).sum::<usize>()
            }
            Node::Loop { body, .. } | Node::Block(body) => body.iter().map(count_lets).sum(),
            Node::Region { body, .. } => body.iter().map(count_lets).sum(),
            _ => 0,
        }
    }

    /// Counts the `Let` nodes inside the first `Loop` that appears directly in
    /// `entry`; returns 0 when `entry` has no top-level `Loop`.
    fn count_lets_in_loop_body(entry: &[Node]) -> usize {
        for n in entry {
            if let Node::Loop { body, .. } = n {
                return body.iter().map(count_lets).sum();
            }
        }
        0
    }

    #[test]
    fn hoists_pure_let_above_loop() {
        let entry = vec![Node::loop_for(
            "i",
            Expr::u32(0),
            Expr::u32(8),
            vec![
                Node::let_bind("k", Expr::u32(7)),
                Node::store("buf", Expr::var("i"), Expr::var("k")),
            ],
        )];
        let result = LoopLicm::transform(program(entry));
        assert!(result.changed);
        let entry = result.program.entry();
        assert_eq!(count_lets(&entry[0]), 1, "hoisted Let lives at outer scope");
        assert_eq!(
            count_lets_in_loop_body(entry),
            0,
            "loop body no longer holds the Let"
        );
    }

    #[test]
    fn does_not_hoist_let_that_references_loop_var() {
        let entry = vec![Node::loop_for(
            "i",
            Expr::u32(0),
            Expr::u32(8),
            vec![
                Node::let_bind("idx_plus_one", Expr::add(Expr::var("i"), Expr::u32(1))),
                Node::store("buf", Expr::var("idx_plus_one"), Expr::u32(0)),
            ],
        )];
        let result = LoopLicm::transform(program(entry));
        assert!(
            !result.changed,
            "Let depends on loop var; must stay in loop body"
        );
    }

    #[test]
    fn does_not_hoist_let_that_loads_buffer() {
        let entry = vec![Node::loop_for(
            "i",
            Expr::u32(0),
            Expr::u32(8),
            vec![
                Node::let_bind("snap", Expr::load("buf", Expr::u32(0))),
                Node::store("buf", Expr::var("i"), Expr::var("snap")),
            ],
        )];
        let result = LoopLicm::transform(program(entry));
        assert!(
            !result.changed,
            "Load must not be hoisted; ordering matters"
        );
    }

    #[test]
    fn does_not_hoist_let_whose_dependency_is_assigned_in_loop() {
        let entry = vec![
            Node::let_bind("acc", Expr::u32(0)),
            Node::loop_for(
                "i",
                Expr::u32(0),
                Expr::u32(8),
                vec![
                    Node::let_bind("tmp", Expr::add(Expr::var("acc"), Expr::u32(1))),
                    Node::assign("acc", Expr::var("tmp")),
                ],
            ),
        ];
        let result = LoopLicm::transform(program(entry));
        assert!(
            !result.changed,
            "Let depends on a name Assign'd in the loop; cannot hoist"
        );
    }

    #[test]
    fn hoists_multiple_independent_pure_lets_in_order() {
        let entry = vec![Node::loop_for(
            "i",
            Expr::u32(0),
            Expr::u32(8),
            vec![
                Node::let_bind("a", Expr::u32(1)),
                Node::let_bind("b", Expr::u32(2)),
                Node::store(
                    "buf",
                    Expr::var("i"),
                    Expr::add(Expr::var("a"), Expr::var("b")),
                ),
            ],
        )];
        let result = LoopLicm::transform(program(entry));
        assert!(result.changed);
        let entry = result.program.entry();
        let total_outer_lets: usize = entry.iter().take(2).map(count_lets).sum();
        assert_eq!(total_outer_lets, 2, "both invariant Lets hoisted");
        assert_eq!(count_lets_in_loop_body(entry), 0);
    }

    #[test]
    fn analyze_skips_program_with_no_loops() {
        let entry = vec![Node::let_bind("a", Expr::u32(1))];
        assert_eq!(LoopLicm::analyze(&program(entry)), PassAnalysis::SKIP);
    }

    #[test]
    fn analyze_runs_when_loop_has_hoistable_let() {
        let entry = vec![Node::loop_for(
            "i",
            Expr::u32(0),
            Expr::u32(8),
            vec![Node::let_bind("k", Expr::u32(7))],
        )];
        assert_eq!(LoopLicm::analyze(&program(entry)), PassAnalysis::RUN);
    }

    #[test]
    fn nested_loop_hoists_inner_invariant_all_the_way_out() {
        let entry = vec![Node::loop_for(
            "i",
            Expr::u32(0),
            Expr::u32(2),
            vec![Node::loop_for(
                "j",
                Expr::u32(0),
                Expr::u32(4),
                vec![
                    Node::let_bind("k", Expr::u32(7)),
                    Node::store("buf", Expr::var("j"), Expr::var("k")),
                ],
            )],
        )];
        let result = LoopLicm::transform(program(entry));
        assert!(result.changed);
        let entry = result.program.entry();
        assert_eq!(
            entry.len(),
            1,
            "Program::wrapped wraps the entry in a Region"
        );
        let Node::Region {
            body: region_body, ..
        } = &entry[0]
        else {
            panic!("Fix: entry must be the Region wrapper");
        };
        assert!(
            region_body.len() >= 2,
            "Region body holds the hoisted Let and the surviving outer Loop"
        );
        assert!(matches!(&region_body[0], Node::Let { name, .. } if name == "k"));
        let Node::Loop {
            body: outer_body, ..
        } = &region_body[1]
        else {
            panic!("Fix: second Region-body node must be the outer Loop");
        };
        assert_eq!(
            outer_body.len(),
            1,
            "outer Loop body holds only the inner Loop"
        );
        let Node::Loop {
            body: inner_body, ..
        } = &outer_body[0]
        else {
            panic!("Fix: outer Loop body's child must be the inner Loop");
        };
        assert_eq!(inner_body.len(), 1, "inner Loop body holds only the Store");
        assert!(matches!(&inner_body[0], Node::Store { .. }));
    }
}