use std::sync::Arc;
use rustc_hash::FxHashMap;
use vyre_foundation::ir::{Expr, Ident, Node, Program};
use super::cse_via_encoded::CanonicalLookup;
use super::expr_arena::ExprArenaEncoding;
pub fn apply_cross_scope_cse(
program: &Program,
arena: &ExprArenaEncoding,
canonical: &[u32],
) -> Program {
apply_cross_scope_cse_with_lookup(program, arena, canonical)
}
pub fn apply_cross_scope_cse_with_lookup<C: CanonicalLookup + ?Sized>(
program: &Program,
arena: &ExprArenaEncoding,
canonical: &C,
) -> Program {
if canonical.is_empty() || arena.expr_count == 0 {
return program.clone();
}
let body: Vec<Node> = match program.entry() {
[Node::Region { body, .. }] => body.as_ref().clone(),
entry => entry.to_vec(),
};
let mut walker = CseWalker {
arena,
canonical,
node_index: 1, next_let_id: 0,
};
let new_body = walker.rewrite_scope(&body);
let new_entry = match program.entry() {
[Node::Region {
generator,
source_region,
..
}] => vec![Node::Region {
generator: generator.clone(),
source_region: source_region.clone(),
body: Arc::new(new_body),
}],
_ => new_body,
};
program.with_rewritten_entry(new_entry)
}
struct CseWalker<'a, C: CanonicalLookup + ?Sized> {
arena: &'a ExprArenaEncoding,
canonical: &'a C,
node_index: usize,
next_let_id: u32,
}
struct Occurrence {
canon: u32,
expr: Expr,
}
impl<C: CanonicalLookup + ?Sized> CseWalker<'_, C> {
fn rewrite_scope(&mut self, body: &[Node]) -> Vec<Node> {
let prefix_len = super::encode::reachable_prefix_len(body);
let saved_index = self.node_index;
let mut occurrences: Vec<Occurrence> = Vec::new();
for node in &body[..prefix_len] {
self.collect_node(node, &mut occurrences);
}
let mut counts: FxHashMap<u32, (u32, Expr)> = FxHashMap::default();
let mut order: Vec<u32> = Vec::new();
for occ in occurrences {
counts
.entry(occ.canon)
.and_modify(|(c, _)| *c += 1)
.or_insert_with(|| {
order.push(occ.canon);
(1, occ.expr)
});
}
let mut hoist_plan: FxHashMap<u32, Ident> = FxHashMap::default();
let mut hoist_lets: Vec<Node> = Vec::new();
for canon in &order {
let (count, expr) = match counts.get(canon).cloned() {
Some(p) => p,
None => continue,
};
if count < 2 {
continue;
}
if !is_hoist_worthy(&expr) {
continue;
}
if !expr_no_atomic(&expr) {
continue;
}
let name = self.fresh_name();
hoist_lets.push(Node::let_bind(name.clone(), expr));
hoist_plan.insert(*canon, name);
}
self.node_index = saved_index;
let mut out: Vec<Node> = Vec::with_capacity(prefix_len + hoist_lets.len());
out.extend(hoist_lets);
for node in &body[..prefix_len] {
out.push(self.rewrite_node(node, &hoist_plan));
}
out
}
fn collect_node(&mut self, node: &Node, occs: &mut Vec<Occurrence>) {
let idx = self.node_index;
self.node_index += 1;
let top_ids = self
.arena
.node_top_level_exprs
.get(idx)
.cloned()
.unwrap_or_default();
match node {
Node::Let { value, .. } | Node::Assign { value, .. } => {
self.record(&top_ids, 0, value, occs);
}
Node::Store { index, value, .. } => {
self.record(&top_ids, 0, index, occs);
self.record(&top_ids, 1, value, occs);
}
Node::If {
cond,
then,
otherwise,
} => {
self.record(&top_ids, 0, cond, occs);
self.advance_through_scope(then);
self.advance_through_scope(otherwise);
}
Node::Loop { from, to, body, .. } => {
self.record(&top_ids, 0, from, occs);
self.record(&top_ids, 1, to, occs);
self.advance_through_scope(body);
}
Node::Block(body) => self.advance_through_scope(body),
Node::Region { body, .. } => self.advance_through_scope(body.as_slice()),
Node::AsyncLoad { offset, size, .. } | Node::AsyncStore { offset, size, .. } => {
self.record(&top_ids, 0, offset, occs);
self.record(&top_ids, 1, size, occs);
}
Node::Trap { address, .. } => self.record(&top_ids, 0, address, occs),
_ => {}
}
}
fn record(&self, top_ids: &[u32], slot: usize, expr: &Expr, occs: &mut Vec<Occurrence>) {
let arena_id = match top_ids.get(slot).copied() {
Some(id) => id,
None => return,
};
let canon = self.canonical.canonical_of(arena_id);
occs.push(Occurrence {
canon,
expr: expr.clone(),
});
}
fn advance_through_scope(&mut self, body: &[Node]) {
let prefix_len = super::encode::reachable_prefix_len(body);
for node in &body[..prefix_len] {
self.advance_through_node(node);
}
}
fn advance_through_node(&mut self, node: &Node) {
self.node_index += 1;
match node {
Node::If {
then, otherwise, ..
} => {
self.advance_through_scope(then);
self.advance_through_scope(otherwise);
}
Node::Loop { body, .. } => self.advance_through_scope(body),
Node::Block(body) => self.advance_through_scope(body),
Node::Region { body, .. } => self.advance_through_scope(body.as_slice()),
_ => {}
}
}
fn fresh_name(&mut self) -> Ident {
let id = self.next_let_id;
self.next_let_id += 1;
Ident::new(Arc::from(format!("__cse_{id}")))
}
fn rewrite_node(&mut self, node: &Node, plan: &FxHashMap<u32, Ident>) -> Node {
let idx = self.node_index;
self.node_index += 1;
let top_ids: Vec<u32> = self
.arena
.node_top_level_exprs
.get(idx)
.cloned()
.unwrap_or_default();
match node {
Node::Let { name, value } => {
let new_value = self.rewrite_top(&top_ids, 0, value, plan);
Node::let_bind(name.clone(), new_value)
}
Node::Assign { name, value } => {
let new_value = self.rewrite_top(&top_ids, 0, value, plan);
Node::assign(name.clone(), new_value)
}
Node::Store {
buffer,
index,
value,
} => {
let new_index = self.rewrite_top(&top_ids, 0, index, plan);
let new_value = self.rewrite_top(&top_ids, 1, value, plan);
Node::store(buffer.clone(), new_index, new_value)
}
Node::If {
cond,
then,
otherwise,
} => {
let new_cond = self.rewrite_top(&top_ids, 0, cond, plan);
let new_then = self.rewrite_scope(then);
let new_otherwise = self.rewrite_scope(otherwise);
Node::if_then_else(new_cond, new_then, new_otherwise)
}
Node::Loop {
var,
from,
to,
body,
} => {
let new_from = self.rewrite_top(&top_ids, 0, from, plan);
let new_to = self.rewrite_top(&top_ids, 1, to, plan);
let new_body = self.rewrite_scope(body);
Node::loop_for(var.clone(), new_from, new_to, new_body)
}
Node::Block(body) => Node::Block(self.rewrite_scope(body)),
Node::Region {
generator,
source_region,
body,
} => Node::Region {
generator: generator.clone(),
source_region: source_region.clone(),
body: Arc::new(self.rewrite_scope(body.as_slice())),
},
other => other.clone(),
}
}
fn rewrite_top(
&self,
top_ids: &[u32],
slot: usize,
expr: &Expr,
plan: &FxHashMap<u32, Ident>,
) -> Expr {
let arena_id = match top_ids.get(slot).copied() {
Some(id) => id,
None => return expr.clone(),
};
let canon = self.canonical.canonical_of(arena_id);
if let Some(name) = plan.get(&canon) {
return Expr::var(name.clone());
}
expr.clone()
}
}
fn is_hoist_worthy(expr: &Expr) -> bool {
!matches!(
expr,
Expr::LitU32(_)
| Expr::LitI32(_)
| Expr::LitF32(_)
| Expr::LitBool(_)
| Expr::Var(_)
| Expr::BufLen { .. }
| Expr::InvocationId { .. }
| Expr::WorkgroupId { .. }
| Expr::LocalId { .. }
| Expr::SubgroupLocalId
| Expr::SubgroupSize
)
}
fn expr_no_atomic(expr: &Expr) -> bool {
match expr {
Expr::Atomic { .. } => false,
Expr::BinOp { left, right, .. } => expr_no_atomic(left) && expr_no_atomic(right),
Expr::UnOp { operand, .. } => expr_no_atomic(operand),
Expr::Select {
cond,
true_val,
false_val,
} => expr_no_atomic(cond) && expr_no_atomic(true_val) && expr_no_atomic(false_val),
Expr::Fma { a, b, c } => expr_no_atomic(a) && expr_no_atomic(b) && expr_no_atomic(c),
Expr::Load { index, .. } => expr_no_atomic(index),
Expr::Cast { value, .. } => expr_no_atomic(value),
Expr::Call { args, .. } => args.iter().all(expr_no_atomic),
Expr::SubgroupBallot { cond } => expr_no_atomic(cond),
Expr::SubgroupShuffle { value, lane } => expr_no_atomic(value) && expr_no_atomic(lane),
Expr::SubgroupAdd { value } => expr_no_atomic(value),
Expr::Opaque(_) => false,
_ => true,
}
}