use crate::ir::{Expr, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
/// Optimizer pass registered as `rematerialize_cheap_let`.
///
/// Removes `Let` bindings whose value is a cheap leaf expression (see
/// `is_cheap_leaf`) and substitutes that value at every later use in the
/// same node sequence. The pass declares no required analyses and
/// invalidates none.
#[derive(Debug, Default)]
#[vyre_pass(
name = "rematerialize_cheap_let",
requires = [],
invalidates = []
)]
pub struct RematerializeCheapLetPass;
impl RematerializeCheapLetPass {
    /// Decides whether the pass is worth running on `program`.
    ///
    /// Returns `PassAnalysis::RUN` as soon as any `Let` with a cheap leaf
    /// value exists anywhere under the entry sequence, and
    /// `PassAnalysis::SKIP` otherwise. The check is a pre-filter only: it
    /// does not verify that the binding can actually be inlined.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let mut has_candidate = false;
        scan_for_candidate(program.entry(), &mut has_candidate);
        if !has_candidate {
            return PassAnalysis::SKIP;
        }
        PassAnalysis::RUN
    }

    /// Rewrites the entry sequence of `program`, inlining cheap `Let`s.
    ///
    /// `changed` in the returned `PassResult` is `true` when at least one
    /// binding was removed.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // The empty-entry copy has to be taken first: `into_entry_vec`
        // consumes `program`. NOTE(review): presumably the copy keeps the
        // non-entry parts of the program (buffers, workgroup size) — confirm.
        let shell = program.with_rewritten_entry(Vec::new());
        let mut changed = false;
        let rewritten = rewrite_sequence(program.into_entry_vec(), &mut changed);
        let program = shell.with_rewritten_entry(rewritten);
        PassResult { program, changed }
    }

    /// Returns the fingerprint of `program` as computed by
    /// `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
fn rewrite_sequence(nodes: Vec<Node>, changed: &mut bool) -> Vec<Node> {
let mut nodes: Vec<Node> = nodes
.into_iter()
.map(|n| recurse_into_children(n, changed))
.collect();
let mut i = 0;
while i < nodes.len() {
let take_value = match &nodes[i] {
Node::Let { name, value } if is_cheap_leaf(value) => {
let name = name.clone();
let tail = &nodes[i + 1..];
if tail.iter().any(|n| node_reassigns(n, &name)) {
None
} else {
Some((name, value.clone()))
}
}
_ => None,
};
if let Some((name, value)) = take_value {
nodes.remove(i);
for node in &mut nodes[i..] {
substitute_var_in_node(node, &name, &value);
}
*changed = true;
} else {
i += 1;
}
}
nodes
}
/// Applies `rewrite_sequence` to every child sequence owned by `node`.
///
/// `If`, `Loop`, `Block` and `Region` are rebuilt around their rewritten
/// bodies; every other node kind is returned unchanged. Expressions held
/// directly by the node (conditions, loop bounds) are not touched here.
fn recurse_into_children(node: Node, changed: &mut bool) -> Node {
    match node {
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            let then = rewrite_sequence(then, changed);
            let otherwise = rewrite_sequence(otherwise, changed);
            Node::If {
                cond,
                then,
                otherwise,
            }
        }
        Node::Loop {
            var,
            from,
            to,
            body,
        } => {
            let body = rewrite_sequence(body, changed);
            Node::Loop {
                var,
                from,
                to,
                body,
            }
        }
        Node::Block(children) => Node::Block(rewrite_sequence(children, changed)),
        Node::Region {
            generator,
            source_region,
            body,
        } => {
            // Take the body by value when this is the only owner; otherwise
            // work on a private copy so other holders are unaffected.
            let owned = std::sync::Arc::try_unwrap(body).unwrap_or_else(|shared| (*shared).clone());
            let rewritten = rewrite_sequence(owned, changed);
            Node::Region {
                generator,
                source_region,
                body: std::sync::Arc::new(rewritten),
            }
        }
        leaf => leaf,
    }
}
/// Replaces every `Var(name)` reachable from `node` with a clone of `value`.
///
/// The substitution descends into all expressions the node exposes and into
/// nested bodies. It does not reason about shadowing: callers must first
/// establish that `name` is not rebound inside `node` (see `node_reassigns`).
///
/// `Return`, `Barrier`, `IndirectDispatch`, `AsyncWait`, `Resume` and
/// `Opaque` nodes are left untouched.
fn substitute_var_in_node(node: &mut Node, name: &str, value: &Expr) {
    match node {
        // Only the bound/assigned expression is rewritten, never the target.
        Node::Let { value: v, .. } | Node::Assign { value: v, .. } => {
            substitute_var_in_expr(v, name, value);
        }
        Node::Store {
            index, value: v, ..
        } => {
            substitute_var_in_expr(index, name, value);
            substitute_var_in_expr(v, name, value);
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            substitute_var_in_expr(cond, name, value);
            for n in then {
                substitute_var_in_node(n, name, value);
            }
            for n in otherwise {
                substitute_var_in_node(n, name, value);
            }
        }
        Node::Loop { from, to, body, .. } => {
            substitute_var_in_expr(from, name, value);
            substitute_var_in_expr(to, name, value);
            for n in body {
                substitute_var_in_node(n, name, value);
            }
        }
        Node::Block(body) => {
            for n in body {
                substitute_var_in_node(n, name, value);
            }
        }
        Node::Region { body, .. } => {
            // Copy-on-write: mutate in place when this is the only owner of
            // the body, otherwise clone it first so other holders of the Arc
            // keep seeing the unmodified nodes.
            for n in std::sync::Arc::make_mut(body) {
                substitute_var_in_node(n, name, value);
            }
        }
        Node::AsyncLoad { offset, size, .. } | Node::AsyncStore { offset, size, .. } => {
            substitute_var_in_expr(offset, name, value);
            substitute_var_in_expr(size, name, value);
        }
        Node::Trap { address, .. } => {
            substitute_var_in_expr(address, name, value);
        }
        Node::Return
        | Node::Barrier { .. }
        | Node::IndirectDispatch { .. }
        | Node::AsyncWait { .. }
        | Node::Resume { .. }
        | Node::Opaque(_) => {}
    }
}
/// Replaces every `Var(name)` inside `expr` (including `expr` itself) with a
/// clone of `value`.
///
/// Literals, builtin ids, `BufLen`, `Opaque` and variables with a different
/// name are left as they are; all other expression kinds are traversed
/// through their operand sub-expressions.
fn substitute_var_in_expr(expr: &mut Expr, name: &str, value: &Expr) {
    // Recurse into one operand with the same (name, value) pair.
    let descend = |child: &mut Expr| substitute_var_in_expr(child, name, value);

    match expr {
        Expr::Var(ident) => {
            if ident.as_str() == name {
                *expr = value.clone();
            }
        }

        // Leaves that cannot contain a variable reference.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize
        | Expr::Opaque(_) => {}

        // Single-operand expressions.
        Expr::Load { index, .. } => descend(index),
        Expr::UnOp { operand, .. } => descend(operand),
        Expr::Cast { value: inner, .. } => descend(inner),
        Expr::SubgroupBallot { cond } => descend(cond),
        Expr::SubgroupAdd { value: inner } => descend(inner),

        // Multi-operand expressions.
        Expr::BinOp { left, right, .. } => {
            descend(left);
            descend(right);
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            descend(cond);
            descend(true_val);
            descend(false_val);
        }
        Expr::Fma { a, b, c } => {
            descend(a);
            descend(b);
            descend(c);
        }
        Expr::Call { args, .. } => {
            for arg in args {
                descend(arg);
            }
        }
        Expr::Atomic {
            index,
            expected,
            value: operand,
            ..
        } => {
            descend(index);
            // `expected` is optional; only rewrite it when present.
            if let Some(cmp) = expected.as_deref_mut() {
                descend(cmp);
            }
            descend(operand);
        }
        Expr::SubgroupShuffle { value: inner, lane } => {
            descend(inner);
            descend(lane);
        }
    }
}
/// Returns `true` when `expr` is one of the leaf expressions this pass is
/// willing to duplicate at every use site: a literal, a variable reference,
/// a buffer length, or one of the invocation/workgroup/local/subgroup ids.
///
/// Anything with operands (loads, arithmetic, calls, atomics, ...) and
/// `Opaque` expressions yield `false`.
fn is_cheap_leaf(expr: &Expr) -> bool {
    match expr {
        // Literals.
        Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitF32(_) | Expr::LitBool(_) => true,
        // Plain variable reference.
        Expr::Var(_) => true,
        // Buffer length and builtin ids.
        Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize => true,
        _ => false,
    }
}
/// Reports whether `node`, or anything nested inside it, rebinds `name`.
///
/// A rebinding is an `Assign` or `Let` targeting `name`, or a `Loop` whose
/// induction variable is `name`. `If`, `Loop`, `Block` and `Region` bodies
/// are searched recursively; all remaining node kinds report `false`.
fn node_reassigns(node: &Node, name: &str) -> bool {
    // True when any node of a nested sequence rebinds `name`.
    let any_in = |seq: &[Node]| seq.iter().any(|child| node_reassigns(child, name));

    match node {
        Node::Assign { name: target, .. } => target.as_str() == name,
        Node::Let { name: bound, .. } => bound.as_str() == name,
        Node::Store { .. } => false,
        Node::If {
            then, otherwise, ..
        } => any_in(then) || any_in(otherwise),
        // The loop variable itself counts as a rebinding of `name`.
        Node::Loop { var, body, .. } => var.as_str() == name || any_in(body),
        Node::Block(body) => any_in(body),
        Node::Region { body, .. } => any_in(body),
        Node::Return
        | Node::Barrier { .. }
        | Node::IndirectDispatch { .. }
        | Node::AsyncLoad { .. }
        | Node::AsyncStore { .. }
        | Node::AsyncWait { .. }
        | Node::Trap { .. }
        | Node::Resume { .. }
        | Node::Opaque(_) => false,
    }
}
/// Sets `*found` to `true` if `nodes`, or any nested body, contains a `Let`
/// whose value is a cheap leaf.
///
/// The flag is only ever raised, never cleared, and the walk stops as soon
/// as it is set (including when it was already set on entry).
fn scan_for_candidate(nodes: &[Node], found: &mut bool) {
    for node in nodes {
        if *found {
            break;
        }
        match node {
            Node::Let { value, .. } => {
                if is_cheap_leaf(value) {
                    *found = true;
                }
            }
            Node::If {
                then, otherwise, ..
            } => {
                // A hit in `then` makes the call on `otherwise` return at once.
                for branch in [then, otherwise] {
                    scan_for_candidate(branch, found);
                }
            }
            Node::Loop { body, .. } => scan_for_candidate(body, found),
            Node::Block(body) => scan_for_candidate(body, found),
            Node::Region { body, .. } => scan_for_candidate(body, found),
            _ => {}
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};

    /// The single read-write `u32` storage buffer (4 elements) every test uses.
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    /// Wraps `entry` into a 1x1x1 program that declares `buf`.
    fn program(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// Runs the pass under test on a program built from `entry`.
    fn run(entry: Vec<Node>) -> PassResult {
        RematerializeCheapLetPass::transform(program(entry))
    }

    /// Builds an IR identifier from a string literal.
    fn ident(text: &str) -> crate::ir::Ident {
        crate::ir::Ident::from(text)
    }

    /// Returns the stored value of a `Store` node, panicking on anything else.
    fn stored_value(node: &Node) -> &Expr {
        match node {
            Node::Store { value, .. } => value,
            other => panic!("expected Store, got {other:?}"),
        }
    }

    /// Finds the outermost sequence that directly contains a user-written
    /// node (`Store`, `Let`, `If` or `Loop`), looking through `Block` and
    /// `Region` wrappers in order.
    fn find_user_siblings(nodes: &[Node]) -> Option<&[Node]> {
        let is_user_node = |n: &Node| {
            matches!(
                n,
                Node::Store { .. } | Node::Let { .. } | Node::If { .. } | Node::Loop { .. }
            )
        };
        if nodes.iter().any(is_user_node) {
            return Some(nodes);
        }
        nodes.iter().find_map(|node| match node {
            Node::Block(body) => find_user_siblings(body.as_slice()),
            Node::Region { body, .. } => find_user_siblings(body.as_ref().as_slice()),
            _ => None,
        })
    }

    #[test]
    fn inlines_literal_into_single_use() {
        let result = run(vec![
            Node::let_bind("z", Expr::u32(0)),
            Node::store("buf", Expr::u32(0), Expr::var("z")),
        ]);
        assert!(result.changed, "single-use literal Let must inline");

        let siblings = find_user_siblings(result.program.entry()).expect("user body present");
        assert_eq!(siblings.len(), 1, "Let dropped, only Store remains");
        assert_eq!(
            *stored_value(&siblings[0]),
            Expr::LitU32(0),
            "literal substituted at use site"
        );
    }

    #[test]
    fn inlines_literal_into_many_uses() {
        let result = run(vec![
            Node::let_bind("z", Expr::u32(7)),
            Node::store("buf", Expr::u32(0), Expr::var("z")),
            Node::store("buf", Expr::u32(1), Expr::var("z")),
            Node::store("buf", Expr::u32(2), Expr::var("z")),
        ]);
        assert!(
            result.changed,
            "literal Let must inline regardless of use count"
        );

        let siblings = find_user_siblings(result.program.entry()).expect("user body present");
        assert_eq!(siblings.len(), 3, "Let dropped, three Stores remain");
        for node in siblings {
            assert_eq!(*stored_value(node), Expr::LitU32(7));
        }
    }

    #[test]
    fn inlines_invocation_id() {
        let result = run(vec![
            Node::let_bind("gid", Expr::InvocationId { axis: 0 }),
            Node::store("buf", Expr::var("gid"), Expr::u32(1)),
        ]);
        assert!(result.changed, "InvocationId Let must inline");

        let siblings = find_user_siblings(result.program.entry()).expect("user body present");
        assert_eq!(siblings.len(), 1, "Let dropped");
        match &siblings[0] {
            Node::Store { index, .. } => {
                assert_eq!(*index, Expr::InvocationId { axis: 0 });
            }
            other => panic!("expected Store, got {other:?}"),
        }
    }

    #[test]
    fn keeps_load_let() {
        let loaded = Expr::Load {
            buffer: ident("buf"),
            index: Box::new(Expr::u32(0)),
        };
        let result = run(vec![
            Node::let_bind("v", loaded),
            Node::store("buf", Expr::u32(1), Expr::var("v")),
            Node::store("buf", Expr::u32(2), Expr::var("v")),
        ]);
        assert!(!result.changed, "Load value must not inline");
    }

    #[test]
    fn keeps_binop_let() {
        let sum = Expr::BinOp {
            op: crate::ir::BinOp::Add,
            left: Box::new(Expr::u32(1)),
            right: Box::new(Expr::u32(2)),
        };
        let result = run(vec![
            Node::let_bind("v", sum),
            Node::store("buf", Expr::u32(0), Expr::var("v")),
        ]);
        assert!(!result.changed, "BinOp value must not inline");
    }

    #[test]
    fn keeps_let_when_name_is_reassigned() {
        let result = run(vec![
            Node::let_bind("z", Expr::u32(0)),
            Node::Assign {
                name: ident("z"),
                value: Expr::u32(99),
            },
            Node::store("buf", Expr::u32(0), Expr::var("z")),
        ]);
        assert!(!result.changed, "reassigned Let must not be rematerialized");
    }

    #[test]
    fn keeps_let_when_loop_rebinds_name() {
        let shadowing_loop = Node::Loop {
            var: ident("i"),
            from: Expr::u32(0),
            to: Expr::u32(4),
            body: vec![Node::store("buf", Expr::var("i"), Expr::u32(1))],
        };
        let result = run(vec![Node::let_bind("i", Expr::u32(99)), shadowing_loop]);
        assert!(
            !result.changed,
            "loop-shadowed Let must not be rematerialized"
        );
    }

    #[test]
    fn inlines_into_nested_if() {
        let branch = Node::If {
            cond: Expr::var("c"),
            then: vec![Node::store("buf", Expr::u32(0), Expr::var("z"))],
            otherwise: vec![Node::store("buf", Expr::u32(1), Expr::var("z"))],
        };
        let result = run(vec![Node::let_bind("z", Expr::u32(5)), branch]);
        assert!(result.changed, "nested-If use must be inlined");

        let siblings = find_user_siblings(result.program.entry()).expect("user body present");
        match &siblings[0] {
            Node::If {
                then, otherwise, ..
            } => {
                // Both arms must have received the literal.
                assert_eq!(*stored_value(&then[0]), Expr::LitU32(5));
                assert_eq!(*stored_value(&otherwise[0]), Expr::LitU32(5));
            }
            other => panic!("expected If, got {other:?}"),
        }
    }

    #[test]
    fn analyze_skips_program_with_no_cheap_let() {
        let store_only = program(vec![Node::store("buf", Expr::u32(0), Expr::u32(1))]);
        match RematerializeCheapLetPass::analyze(&store_only) {
            PassAnalysis::SKIP => {}
            other => panic!("expected SKIP, got {other:?}"),
        }
    }
}