use crate::ir::{Expr, Node, Program};
use crate::optimizer::program_soa::ProgramFacts;
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use rustc_hash::FxHashMap;
#[derive(Debug, Default)]
#[vyre_pass(
name = "reaching_def_propagate",
requires = ["const_fold"],
invalidates = []
)]
pub struct ReachingDefPropagatePass;
impl ReachingDefPropagatePass {
#[must_use]
pub fn analyze(program: &Program) -> PassAnalysis {
let facts = ProgramFacts::build(program);
if collect_propagatable_lets_with_values(&facts, program).is_empty() {
PassAnalysis::SKIP
} else {
PassAnalysis::RUN
}
}
#[must_use]
pub fn transform(program: Program) -> PassResult {
let facts = ProgramFacts::build(&program);
let propagations = collect_propagatable_lets_with_values(&facts, &program);
if propagations.is_empty() {
return PassResult {
program,
changed: false,
};
}
let scaffold = program.with_rewritten_entry(Vec::new());
let mut changed = false;
let entry: Vec<Node> = program
.into_entry_vec()
.into_iter()
.map(|n| substitute_node(n, &propagations, &mut changed))
.collect();
PassResult {
program: scaffold.with_rewritten_entry(entry),
changed,
}
}
#[must_use]
pub fn fingerprint(program: &Program) -> u64 {
fingerprint_program(program)
}
}
fn substitute_node(node: Node, propagations: &FxHashMap<String, Expr>, changed: &mut bool) -> Node {
match node {
Node::Let { name, value } => {
let new_value = substitute_expr(value, propagations, changed);
Node::Let {
name,
value: new_value,
}
}
Node::Assign { name, value } => Node::Assign {
name,
value: substitute_expr(value, propagations, changed),
},
Node::Store {
buffer,
index,
value,
} => Node::Store {
buffer,
index: substitute_expr(index, propagations, changed),
value: substitute_expr(value, propagations, changed),
},
Node::If {
cond,
then,
otherwise,
} => Node::If {
cond: substitute_expr(cond, propagations, changed),
then: then
.into_iter()
.map(|n| substitute_node(n, propagations, changed))
.collect(),
otherwise: otherwise
.into_iter()
.map(|n| substitute_node(n, propagations, changed))
.collect(),
},
Node::Loop {
var,
from,
to,
body,
} => Node::Loop {
var,
from: substitute_expr(from, propagations, changed),
to: substitute_expr(to, propagations, changed),
body: body
.into_iter()
.map(|n| substitute_node(n, propagations, changed))
.collect(),
},
Node::Block(body) => Node::Block(
body.into_iter()
.map(|n| substitute_node(n, propagations, changed))
.collect(),
),
Node::Region {
generator,
source_region,
body,
} => {
let body_vec: Vec<Node> = match std::sync::Arc::try_unwrap(body) {
Ok(v) => v,
Err(arc) => (*arc).clone(),
};
Node::Region {
generator,
source_region,
body: std::sync::Arc::new(
body_vec
.into_iter()
.map(|n| substitute_node(n, propagations, changed))
.collect(),
),
}
}
Node::AsyncLoad {
source,
destination,
offset,
size,
tag,
} => Node::AsyncLoad {
source,
destination,
tag,
offset: Box::new(substitute_expr(*offset, propagations, changed)),
size: Box::new(substitute_expr(*size, propagations, changed)),
},
Node::AsyncStore {
source,
destination,
offset,
size,
tag,
} => Node::AsyncStore {
source,
destination,
tag,
offset: Box::new(substitute_expr(*offset, propagations, changed)),
size: Box::new(substitute_expr(*size, propagations, changed)),
},
Node::Trap { address, tag } => Node::Trap {
address: Box::new(substitute_expr(*address, propagations, changed)),
tag,
},
other => other,
}
}
fn substitute_expr(expr: Expr, propagations: &FxHashMap<String, Expr>, changed: &mut bool) -> Expr {
match expr {
Expr::Var(ref name) => {
if let Some(literal) = propagations.get(name.as_str()) {
*changed = true;
literal.clone()
} else {
expr
}
}
Expr::Load { buffer, index } => Expr::Load {
buffer,
index: Box::new(substitute_expr(*index, propagations, changed)),
},
Expr::BinOp { op, left, right } => Expr::BinOp {
op,
left: Box::new(substitute_expr(*left, propagations, changed)),
right: Box::new(substitute_expr(*right, propagations, changed)),
},
Expr::UnOp { op, operand } => Expr::UnOp {
op,
operand: Box::new(substitute_expr(*operand, propagations, changed)),
},
Expr::Call { op_id, args } => Expr::Call {
op_id,
args: args
.into_iter()
.map(|a| substitute_expr(a, propagations, changed))
.collect(),
},
Expr::Select {
cond,
true_val,
false_val,
} => Expr::Select {
cond: Box::new(substitute_expr(*cond, propagations, changed)),
true_val: Box::new(substitute_expr(*true_val, propagations, changed)),
false_val: Box::new(substitute_expr(*false_val, propagations, changed)),
},
Expr::Cast { target, value } => Expr::Cast {
target,
value: Box::new(substitute_expr(*value, propagations, changed)),
},
Expr::Fma { a, b, c } => Expr::Fma {
a: Box::new(substitute_expr(*a, propagations, changed)),
b: Box::new(substitute_expr(*b, propagations, changed)),
c: Box::new(substitute_expr(*c, propagations, changed)),
},
Expr::Atomic {
op,
buffer,
index,
expected,
value,
ordering,
} => Expr::Atomic {
op,
buffer,
index: Box::new(substitute_expr(*index, propagations, changed)),
expected: expected.map(|e| Box::new(substitute_expr(*e, propagations, changed))),
value: Box::new(substitute_expr(*value, propagations, changed)),
ordering,
},
Expr::SubgroupBallot { cond } => Expr::SubgroupBallot {
cond: Box::new(substitute_expr(*cond, propagations, changed)),
},
Expr::SubgroupShuffle { value, lane } => Expr::SubgroupShuffle {
value: Box::new(substitute_expr(*value, propagations, changed)),
lane: Box::new(substitute_expr(*lane, propagations, changed)),
},
Expr::SubgroupAdd { value } => Expr::SubgroupAdd {
value: Box::new(substitute_expr(*value, propagations, changed)),
},
other => other,
}
}
fn collect_propagatable_lets_with_values(
facts: &ProgramFacts,
program: &Program,
) -> FxHashMap<String, Expr> {
let mut candidates: rustc_hash::FxHashSet<String> = rustc_hash::FxHashSet::default();
for (_, name) in facts.lets() {
if !facts.is_name_rebound(name.as_str()) {
candidates.insert(name.as_str().to_owned());
}
}
if candidates.is_empty() {
return FxHashMap::default();
}
let mut out: FxHashMap<String, Expr> = FxHashMap::default();
for node in program.entry() {
scan_for_literal_lets(node, &candidates, &mut out);
}
out
}
fn scan_for_literal_lets(
node: &Node,
candidates: &rustc_hash::FxHashSet<String>,
out: &mut FxHashMap<String, Expr>,
) {
match node {
Node::Let { name, value } => {
if candidates.contains(name.as_str()) && is_literal(value) {
out.insert(name.as_str().to_owned(), value.clone());
}
}
Node::If {
then, otherwise, ..
} => {
for n in then {
scan_for_literal_lets(n, candidates, out);
}
for n in otherwise {
scan_for_literal_lets(n, candidates, out);
}
}
Node::Loop { body, .. } => {
for n in body {
scan_for_literal_lets(n, candidates, out);
}
}
Node::Block(body) => {
for n in body {
scan_for_literal_lets(n, candidates, out);
}
}
Node::Region { body, .. } => {
for n in body.iter() {
scan_for_literal_lets(n, candidates, out);
}
}
_ => {}
}
}
fn is_literal(expr: &Expr) -> bool {
matches!(
expr,
Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitF32(_) | Expr::LitBool(_)
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Ident, Node};
fn buf() -> BufferDecl {
BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
}
fn program(entry: Vec<Node>) -> Program {
Program::wrapped(vec![buf()], [1, 1, 1], entry)
}
fn run(program: Program) -> PassResult {
let facts = ProgramFacts::build(&program);
let propagations = collect_propagatable_lets_with_values(&facts, &program);
if propagations.is_empty() {
return PassResult {
program,
changed: false,
};
}
let scaffold = program.with_rewritten_entry(Vec::new());
let mut changed = false;
let entry: Vec<Node> = program
.into_entry_vec()
.into_iter()
.map(|n| substitute_node(n, &propagations, &mut changed))
.collect();
PassResult {
program: scaffold.with_rewritten_entry(entry),
changed,
}
}
fn count_var_reads(nodes: &[Node], target: &str) -> usize {
let facts = ProgramFacts::build(&Program::wrapped(vec![buf()], [1, 1, 1], nodes.to_vec()));
facts
.var_reads()
.iter()
.filter(|(_, n)| n.as_str() == target)
.count()
}
#[test]
fn propagates_literal_across_sibling_arms() {
let entry = vec![
Node::let_bind("x", Expr::u32(7)),
Node::If {
cond: Expr::var("c"),
then: vec![Node::store("buf", Expr::u32(0), Expr::var("x"))],
otherwise: vec![Node::store("buf", Expr::u32(1), Expr::var("x"))],
},
];
let result = run(program(entry));
assert!(result.changed, "literal must propagate to both arms");
let entry = result.program.entry().to_vec();
assert_eq!(
count_var_reads(&entry, "x"),
0,
"no Var(x) reads remain after propagation"
);
}
#[test]
fn keeps_literal_when_name_is_assigned() {
let entry = vec![
Node::let_bind("x", Expr::u32(7)),
Node::Assign {
name: Ident::from("x"),
value: Expr::u32(99),
},
Node::store("buf", Expr::u32(0), Expr::var("x")),
];
let result = run(program(entry));
assert!(!result.changed);
}
#[test]
fn keeps_literal_when_name_is_shadowed() {
let entry = vec![
Node::let_bind("x", Expr::u32(7)),
Node::Block(vec![
Node::let_bind("x", Expr::u32(99)),
Node::store("buf", Expr::u32(0), Expr::var("x")),
]),
];
let result = run(program(entry));
assert!(!result.changed);
}
#[test]
fn keeps_literal_when_name_is_loop_var() {
let entry = vec![
Node::let_bind("i", Expr::u32(7)),
Node::Loop {
var: Ident::from("i"),
from: Expr::u32(0),
to: Expr::u32(4),
body: vec![Node::store("buf", Expr::var("i"), Expr::u32(1))],
},
];
let result = run(program(entry));
assert!(!result.changed);
}
#[test]
fn keeps_let_with_non_literal_value() {
let entry = vec![
Node::let_bind(
"x",
Expr::BinOp {
op: crate::ir::BinOp::Add,
left: Box::new(Expr::u32(1)),
right: Box::new(Expr::u32(2)),
},
),
Node::store("buf", Expr::u32(0), Expr::var("x")),
];
let result = run(program(entry));
assert!(!result.changed);
}
#[test]
fn propagates_into_loop_body() {
let entry = vec![
Node::let_bind("x", Expr::u32(7)),
Node::Loop {
var: Ident::from("i"),
from: Expr::u32(0),
to: Expr::u32(4),
body: vec![Node::store("buf", Expr::var("i"), Expr::var("x"))],
},
];
let result = run(program(entry));
assert!(result.changed);
let entry = result.program.entry().to_vec();
assert_eq!(count_var_reads(&entry, "x"), 0);
}
#[test]
fn analyze_skips_program_with_no_eligible_lets() {
let entry = vec![Node::store("buf", Expr::u32(0), Expr::u32(1))];
match ReachingDefPropagatePass::analyze(&program(entry)) {
PassAnalysis::SKIP => {}
other => panic!("expected SKIP, got {other:?}"),
}
}
#[test]
fn transform_matches_helper_api() {
let entry = vec![
Node::let_bind("x", Expr::u32(13)),
Node::store("buf", Expr::u32(0), Expr::var("x")),
];
let p1 = run(program(entry.clone()));
let p2 = ReachingDefPropagatePass::transform(program(entry));
assert_eq!(p1.changed, p2.changed);
}
}