use std::sync::Arc;
use rustc_hash::FxHashSet;
use crate::ir::{Expr, Ident, Node, Program};
/// Prefix placed in front of every identifier rewritten by [`ArmRenamer`].
///
/// The full renamed form is `{FUSION_ARM_PREFIX}{arm_idx}_{original}`. The
/// same prefix is also used to recognise identifiers that were already
/// renamed, so they are returned unchanged instead of being prefixed again.
const FUSION_ARM_PREFIX: &str = "__vyre_fuse_a";
/// Selects which identifiers an [`ArmRenamer`] is allowed to rewrite.
#[derive(Clone, Copy)]
pub(super) enum RenameScope<'a> {
/// Every identifier is a rename candidate.
All,
/// Only identifiers contained in the borrowed set are rename candidates;
/// all other identifiers are returned unchanged.
MultiplyDeclared(&'a FxHashSet<Ident>),
}
/// Rewrites identifiers inside one arm's nodes and expressions so that they
/// carry an arm-specific prefix.
pub(super) struct ArmRenamer<'a> {
// Index of the arm; embedded in every renamed identifier.
arm_idx: usize,
// Which identifiers are eligible for renaming.
scope: RenameScope<'a>,
}
impl<'a> ArmRenamer<'a> {
    /// Creates a renamer for arm `arm_idx` that treats every identifier as a
    /// rename candidate.
    pub(super) fn isolated(arm_idx: usize) -> Self {
        let scope = RenameScope::All;
        Self { arm_idx, scope }
    }

    /// Creates a renamer for arm `arm_idx` that only rewrites identifiers
    /// contained in `multiply_declared`.
    pub(super) fn shared(arm_idx: usize, multiply_declared: &'a FxHashSet<Ident>) -> Self {
        let scope = RenameScope::MultiplyDeclared(multiply_declared);
        Self { arm_idx, scope }
    }

    /// Renames `node` and appends the result to `out`.
    ///
    /// In shared (`MultiplyDeclared`) mode, a region whose generator equals
    /// `Program::ROOT_REGION_GENERATOR` is unwrapped: each of its children is
    /// renamed and appended individually. Every other node, and every node in
    /// isolated mode, is renamed and appended as a single entry.
    pub(super) fn push_entry_node(&self, out: &mut Vec<Node>, node: &Node) {
        let unwrap_root = matches!(self.scope, RenameScope::MultiplyDeclared(_));
        match node {
            Node::Region {
                generator, body, ..
            } if unwrap_root && generator.as_str() == Program::ROOT_REGION_GENERATOR => {
                out.extend(body.iter().map(|child| self.node(child)));
            }
            _ => out.push(self.node(node)),
        }
    }

    /// Returns the arm-qualified form of `name`, or a clone of `name` when it
    /// must stay untouched.
    ///
    /// A name stays untouched when it already starts with the fusion prefix or
    /// when the scope is `MultiplyDeclared` and the set does not contain it.
    fn ident(&self, name: &Ident) -> Ident {
        let raw = name.as_str();
        // Already qualified by an earlier pass: never stack a second prefix.
        if raw.starts_with(FUSION_ARM_PREFIX) {
            return name.clone();
        }
        if let RenameScope::MultiplyDeclared(set) = self.scope {
            if !set.contains(name) {
                return name.clone();
            }
        }
        Ident::from(format!("{FUSION_ARM_PREFIX}{}_{}", self.arm_idx, raw))
    }

    /// Renames every node of a slice, preserving order.
    fn nodes(&self, nodes: &[Node]) -> Vec<Node> {
        let mut renamed = Vec::with_capacity(nodes.len());
        for child in nodes {
            renamed.push(self.node(child));
        }
        renamed
    }

    /// Renames an expression and boxes the result.
    fn boxed(&self, expr: &Expr) -> Box<Expr> {
        Box::new(self.expr(expr))
    }

    /// Rebuilds `node` with every identifier passed through [`Self::ident`]
    /// and every nested expression/node rewritten recursively.
    ///
    /// Buffer handles, operators, groups and similar payloads are copied or
    /// cloned unchanged.
    fn node(&self, node: &Node) -> Node {
        match node {
            // Bindings and assignments.
            Node::Let { name, value } => {
                let name = self.ident(name);
                let value = self.expr(value);
                Node::Let { name, value }
            }
            Node::Assign { name, value } => {
                let name = self.ident(name);
                let value = self.expr(value);
                Node::Assign { name, value }
            }
            Node::Store {
                buffer,
                index,
                value,
            } => {
                let buffer = buffer.clone();
                let index = self.expr(index);
                let value = self.expr(value);
                Node::Store {
                    buffer,
                    index,
                    value,
                }
            }

            // Structured control flow.
            Node::If {
                cond,
                then,
                otherwise,
            } => {
                let cond = self.expr(cond);
                let then = self.nodes(then);
                let otherwise = self.nodes(otherwise);
                Node::If {
                    cond,
                    then,
                    otherwise,
                }
            }
            Node::Loop {
                var,
                from,
                to,
                body,
            } => {
                let var = self.ident(var);
                let from = self.expr(from);
                let to = self.expr(to);
                let body = self.nodes(body);
                Node::Loop {
                    var,
                    from,
                    to,
                    body,
                }
            }
            Node::Block(body) => {
                let renamed_body = self.nodes(body);
                Node::Block(renamed_body)
            }
            Node::Region {
                generator,
                source_region,
                body,
            } => {
                let generator = generator.clone();
                let source_region = source_region.clone();
                let body = Arc::new(self.nodes(body));
                Node::Region {
                    generator,
                    source_region,
                    body,
                }
            }

            // Tagged operations: the tag goes through identifier renaming.
            Node::AsyncLoad {
                source,
                destination,
                offset,
                size,
                tag,
            } => {
                let source = source.clone();
                let destination = destination.clone();
                let offset = self.boxed(offset);
                let size = self.boxed(size);
                let tag = self.ident(tag);
                Node::AsyncLoad {
                    source,
                    destination,
                    offset,
                    size,
                    tag,
                }
            }
            Node::AsyncStore {
                source,
                destination,
                offset,
                size,
                tag,
            } => {
                let source = source.clone();
                let destination = destination.clone();
                let offset = self.boxed(offset);
                let size = self.boxed(size);
                let tag = self.ident(tag);
                Node::AsyncStore {
                    source,
                    destination,
                    offset,
                    size,
                    tag,
                }
            }
            Node::AsyncWait { tag } => {
                let tag = self.ident(tag);
                Node::AsyncWait { tag }
            }
            Node::Trap { address, tag } => {
                let address = self.boxed(address);
                let tag = self.ident(tag);
                Node::Trap { address, tag }
            }
            Node::Resume { tag } => {
                let tag = self.ident(tag);
                Node::Resume { tag }
            }

            // Nodes without identifiers or expressions: rebuilt field by field.
            Node::IndirectDispatch {
                count_buffer,
                count_offset,
            } => {
                let count_buffer = count_buffer.clone();
                let count_offset = *count_offset;
                Node::IndirectDispatch {
                    count_buffer,
                    count_offset,
                }
            }
            Node::AllReduce { buffer, op, group } => {
                let buffer = buffer.clone();
                let op = *op;
                let group = *group;
                Node::AllReduce { buffer, op, group }
            }
            Node::AllGather {
                input,
                output,
                group,
            } => {
                let input = input.clone();
                let output = output.clone();
                let group = *group;
                Node::AllGather {
                    input,
                    output,
                    group,
                }
            }
            Node::ReduceScatter {
                input,
                output,
                op,
                group,
            } => {
                let input = input.clone();
                let output = output.clone();
                let op = *op;
                let group = *group;
                Node::ReduceScatter {
                    input,
                    output,
                    op,
                    group,
                }
            }
            Node::Broadcast {
                buffer,
                root,
                group,
            } => {
                let buffer = buffer.clone();
                let root = *root;
                let group = *group;
                Node::Broadcast {
                    buffer,
                    root,
                    group,
                }
            }
            Node::Return => Node::Return,
            Node::Barrier { ordering } => Node::barrier_with_ordering(*ordering),
            Node::Opaque(extension) => Node::Opaque(Arc::clone(extension)),
        }
    }

    /// Rebuilds `expr` with every variable reference passed through
    /// [`Self::ident`] and every sub-expression rewritten recursively.
    ///
    /// Literals, built-in id queries and opaque expressions are cloned as-is.
    fn expr(&self, expr: &Expr) -> Expr {
        match expr {
            // Leaves that carry no identifier: plain clone.
            Expr::LitU32(_)
            | Expr::LitI32(_)
            | Expr::LitF32(_)
            | Expr::LitBool(_)
            | Expr::InvocationId { .. }
            | Expr::WorkgroupId { .. }
            | Expr::LocalId { .. }
            | Expr::SubgroupLocalId
            | Expr::SubgroupSize
            | Expr::Opaque(_) => expr.clone(),

            // The only place an identifier appears directly in an expression.
            Expr::Var(name) => {
                let renamed = self.ident(name);
                Expr::Var(renamed)
            }

            // Buffer accesses: the buffer handle itself is not renamed.
            Expr::Load { buffer, index } => {
                let buffer = buffer.clone();
                let index = self.boxed(index);
                Expr::Load { buffer, index }
            }
            Expr::BufLen { buffer } => {
                let buffer = buffer.clone();
                Expr::BufLen { buffer }
            }
            Expr::Atomic {
                op,
                buffer,
                index,
                expected,
                value,
                ordering,
            } => {
                let op = *op;
                let buffer = buffer.clone();
                let index = self.boxed(index);
                let expected = expected.as_ref().map(|inner| self.boxed(inner));
                let value = self.boxed(value);
                let ordering = *ordering;
                Expr::Atomic {
                    op,
                    buffer,
                    index,
                    expected,
                    value,
                    ordering,
                }
            }

            // Operators and calls: recurse into operands.
            Expr::BinOp { op, left, right } => {
                let op = *op;
                let left = self.boxed(left);
                let right = self.boxed(right);
                Expr::BinOp { op, left, right }
            }
            Expr::UnOp { op, operand } => {
                let op = op.clone();
                let operand = self.boxed(operand);
                Expr::UnOp { op, operand }
            }
            Expr::Call { op_id, args } => {
                let op_id = op_id.clone();
                let args = args.iter().map(|arg| self.expr(arg)).collect();
                Expr::Call { op_id, args }
            }
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                let cond = self.boxed(cond);
                let true_val = self.boxed(true_val);
                let false_val = self.boxed(false_val);
                Expr::Select {
                    cond,
                    true_val,
                    false_val,
                }
            }
            Expr::Cast { target, value } => {
                let target = target.clone();
                let value = self.boxed(value);
                Expr::Cast { target, value }
            }
            Expr::Fma { a, b, c } => {
                let a = self.boxed(a);
                let b = self.boxed(b);
                let c = self.boxed(c);
                Expr::Fma { a, b, c }
            }

            // Subgroup operations.
            Expr::SubgroupBallot { cond } => {
                let cond = self.boxed(cond);
                Expr::SubgroupBallot { cond }
            }
            Expr::SubgroupShuffle { value, lane } => {
                let value = self.boxed(value);
                let lane = self.boxed(lane);
                Expr::SubgroupShuffle { value, lane }
            }
            Expr::SubgroupAdd { value } => {
                let value = self.boxed(value);
                Expr::SubgroupAdd { value }
            }
        }
    }
}
/// Renames `node` with an isolated renamer for arm `arm_idx` (every
/// identifier is a rename candidate) and appends the renamed node to `out`.
pub(super) fn push_alpha_renamed_arm_entry_node(out: &mut Vec<Node>, node: &Node, arm_idx: usize) {
    let renamer = ArmRenamer::isolated(arm_idx);
    renamer.push_entry_node(out, node);
}
/// Returns the names that are declared in at least two of the given arms.
///
/// Each arm contributes at most once per name, no matter how many times the
/// name is declared inside that arm, because declarations are first gathered
/// into a per-arm set before being tallied.
pub(super) fn multiply_declared_names(arm_entries: &[&[Node]]) -> FxHashSet<Ident> {
    // name -> number of distinct arms that declare it
    let mut arms_declaring = rustc_hash::FxHashMap::<Ident, usize>::default();

    for arm in arm_entries.iter() {
        let mut declared_here: FxHashSet<Ident> = FxHashSet::default();
        arm.iter()
            .for_each(|node| collect_declared_names(node, &mut declared_here));

        for ident in declared_here {
            *arms_declaring.entry(ident).or_default() += 1;
        }
    }

    arms_declaring
        .into_iter()
        .filter(|(_, arm_count)| *arm_count > 1)
        .map(|(ident, _)| ident)
        .collect()
}
/// Inserts into `out` every name declared by `node` or by any node nested
/// inside it.
///
/// A declaration is a `Let` name or a `Loop` variable. The walk descends into
/// loop bodies, both branches of `If`, blocks and regions; every other node
/// kind contributes nothing.
///
/// NOTE(review): identifiers used as `tag` on async/trap/resume nodes are not
/// counted here, although `ArmRenamer` does pass those tags through its
/// identifier renaming. Confirm that leaving them out of the declared set is
/// intended.
fn collect_declared_names(node: &Node, out: &mut FxHashSet<Ident>) {
    /// Runs the collection over each node of a child list, in order.
    fn descend(children: &[Node], out: &mut FxHashSet<Ident>) {
        for child in children {
            collect_declared_names(child, out);
        }
    }

    match node {
        Node::Let { name, .. } => {
            out.insert(name.clone());
        }
        Node::Loop { var, body, .. } => {
            out.insert(var.clone());
            descend(body, out);
        }
        Node::If {
            then, otherwise, ..
        } => {
            descend(then, out);
            descend(otherwise, out);
        }
        Node::Block(body) => descend(body, out),
        Node::Region { body, .. } => descend(body, out),
        _ => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::Ident;

    /// Builds an identifier set from a fixed list of string literals.
    fn names<const N: usize>(items: [&str; N]) -> FxHashSet<Ident> {
        let mut set = FxHashSet::default();
        for item in items {
            set.insert(Ident::from(item));
        }
        set
    }

    /// Shorthand for a `Let` node binding `name` to `value`.
    fn let_node(name: &str, value: Expr) -> Node {
        Node::Let {
            name: Ident::from(name),
            value,
        }
    }

    /// An unqualified name gains exactly one arm prefix.
    #[test]
    fn bare_temp_is_prefixed_once() {
        let renamer = ArmRenamer::isolated(0);
        let renamed = renamer.ident(&Ident::from("cmp_5"));
        assert_eq!(renamed.as_str(), "__vyre_fuse_a0_cmp_5");
        assert!(renamed.as_str().starts_with(FUSION_ARM_PREFIX));
    }

    /// A name that already carries the fusion prefix survives further
    /// renaming passes unchanged, whatever the arm index.
    #[test]
    fn already_qualified_temp_is_not_re_prefixed() {
        let first_pass = ArmRenamer::isolated(0).ident(&Ident::from("cmp_5"));

        let second_pass = ArmRenamer::isolated(1).ident(&first_pass);
        assert_eq!(
            second_pass.as_str(),
            first_pass.as_str(),
            "fusion temp must not accumulate a second arm prefix"
        );

        let third_pass = ArmRenamer::isolated(2).ident(&second_pass);
        assert_eq!(third_pass.as_str(), first_pass.as_str());
    }

    /// The same bare name renamed in two arms yields two different names.
    #[test]
    fn distinct_bare_temps_in_distinct_arms_stay_distinct() {
        let bare = Ident::from("x");
        let in_arm0 = ArmRenamer::isolated(0).ident(&bare);
        let in_arm1 = ArmRenamer::isolated(1).ident(&bare);
        assert_ne!(in_arm0.as_str(), in_arm1.as_str());
    }

    /// In shared mode a name outside the multiply-declared set is untouched in
    /// every arm, so a producer and a consumer keep referring to the same name.
    #[test]
    fn single_declared_name_is_left_unrenamed_in_every_arm() {
        let multiply_declared = names([]);
        let linked = Ident::from("__cmp_5");

        let producer = ArmRenamer::shared(1, &multiply_declared).ident(&linked);
        let consumer = ArmRenamer::shared(0, &multiply_declared).ident(&linked);

        assert_eq!(producer.as_str(), "__cmp_5");
        assert_eq!(consumer.as_str(), producer.as_str());
    }

    /// In shared mode a name inside the multiply-declared set is qualified
    /// per arm.
    #[test]
    fn multiply_declared_name_is_renamed_in_shared_mode() {
        let multiply_declared = names(["acc"]);
        let bare = Ident::from("acc");

        let in_arm0 = ArmRenamer::shared(0, &multiply_declared).ident(&bare);
        let in_arm1 = ArmRenamer::shared(1, &multiply_declared).ident(&bare);

        assert_eq!(in_arm0.as_str(), "__vyre_fuse_a0_acc");
        assert_ne!(in_arm0.as_str(), in_arm1.as_str());
    }

    /// Only names declared by two or more arms are reported; a name that is
    /// declared in one arm and merely referenced in another is not.
    #[test]
    fn multiply_declared_names_counts_arms_not_occurrences() {
        let arm0 = vec![
            let_node("acc", Expr::u32(0)),
            let_node("__use", Expr::Var(Ident::from("__cmp_5"))),
        ];
        let arm1 = vec![
            let_node("acc", Expr::u32(1)),
            let_node("__cmp_5", Expr::u32(2)),
        ];

        let entries = [arm0.as_slice(), arm1.as_slice()];
        let multi = multiply_declared_names(&entries);

        assert!(
            multi.contains(&Ident::from("acc")),
            "a name declared in two arms must be renamed to avoid collision"
        );
        assert!(
            !multi.contains(&Ident::from("__cmp_5")),
            "a name declared in exactly one arm is unique and must stay linked"
        );
        assert!(!multi.contains(&Ident::from("__use")));
    }
}