use crate::pass::Pass;
use rlx_ir::op::*;
use rlx_ir::*;
use std::collections::HashMap;
use crate::graph_rewrite::Rewriter;
/// Fuses `MatMul -> Add(bias)` (and, when the add feeds a single GELU/SiLU, that
/// activation too) into one `Op::FusedMatMulBiasAct` node.
pub struct FuseMatMulBiasAct;
/// Reports whether `act` may be folded into the epilogue of a fused matmul + bias node.
///
/// Only SiLU and GELU qualify; every other activation stays a separate node.
fn fusible_mm_bias_epilogue_activation(act: Activation) -> bool {
    matches!(act, Activation::Silu | Activation::Gelu)
}
impl Pass for FuseMatMulBiasAct {
    fn name(&self) -> &str {
        "fuse_matmul_bias_act"
    }

    /// Rewrites `Add(MatMul(a, b), bias)` into `FusedMatMulBiasAct(a, b, bias)`.
    ///
    /// Fusion requires a bias of rank <= 1 and a matmul whose only user is the add. A
    /// GELU/SiLU that is the add's only user is absorbed as well.
    ///
    /// The fused node yields only the final value. An intermediate that is a graph output
    /// is therefore never absorbed:
    /// - a MatMul that is a graph output blocks the fusion entirely;
    /// - an Add that is a graph output blocks absorbing the activation. Otherwise the graph
    ///   would return the activated value where the plain sum was expected.
    fn run(&self, graph: Graph) -> Graph {
        let mut is_output: HashMap<NodeId, ()> = HashMap::new();
        for &oid in &graph.outputs {
            is_output.insert(oid, ());
        }
        let mut rw = Rewriter::new(&graph.name);
        // Nodes already folded into an earlier fused node; they are not re-emitted.
        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
        for node in graph.nodes() {
            if fused_away.contains_key(&node.id) {
                continue;
            }
            if matches!(node.op, Op::MatMul) && !is_output.contains_key(&node.id) {
                let mm_id = node.id;
                let mm_users: Vec<_> = graph.users(mm_id);
                if mm_users.len() == 1 {
                    let add_node = graph.node(mm_users[0]);
                    if let Op::Binary(BinaryOp::Add) = &add_node.op {
                        // The bias is whichever add operand is not the matmul result.
                        let bias_id = if add_node.inputs[0] == mm_id {
                            add_node.inputs[1]
                        } else {
                            add_node.inputs[0]
                        };
                        let bias_shape = graph.shape(bias_id);
                        if bias_shape.rank() <= 1 {
                            let add_id = add_node.id;
                            let add_users = graph.users(add_id);
                            let mut activation = None;
                            let mut act_id = None;
                            // The pre-activation sum may only disappear when nothing else
                            // observes it: a single user and not a graph output.
                            if add_users.len() == 1 && !is_output.contains_key(&add_id) {
                                let act_node = graph.node(add_users[0]);
                                if let Op::Activation(a) = &act_node.op
                                    && fusible_mm_bias_epilogue_activation(*a)
                                {
                                    activation = Some(*a);
                                    act_id = Some(act_node.id);
                                }
                            }
                            let out_shape = if let Some(aid) = act_id {
                                graph.shape(aid).clone()
                            } else {
                                add_node.shape.clone()
                            };
                            // Make sure every operand has a counterpart in the new graph
                            // before the fused node references it.
                            rw.ensure_mapped(&graph, &[node.inputs[0], node.inputs[1], bias_id]);
                            let fused_id = rw.add_fused(
                                Op::FusedMatMulBiasAct { activation },
                                &[node.inputs[0], node.inputs[1], bias_id],
                                out_shape,
                            );
                            rw.replace(mm_id, fused_id);
                            rw.replace(add_id, fused_id);
                            fused_away.insert(add_id, ());
                            if let Some(aid) = act_id {
                                rw.replace(aid, fused_id);
                                fused_away.insert(aid, ());
                            }
                            continue;
                        }
                    }
                }
            }
            rw.copy_node(node);
        }
        rw.finish(&graph.outputs)
    }
}
/// Fuses a residual `Add` that feeds a `LayerNorm` into one `Op::FusedResidualLN` node.
pub struct FuseResidualLN;
impl Pass for FuseResidualLN {
fn name(&self) -> &str {
"fuse_residual_ln"
}
fn run(&self, graph: Graph) -> Graph {
let mut is_output: HashMap<NodeId, ()> = HashMap::new();
for &oid in &graph.outputs {
is_output.insert(oid, ());
}
let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
for node in graph.nodes() {
if let Op::LayerNorm { .. } = &node.op {
let ln_input_id = node.inputs[0];
let ln_input = graph.node(ln_input_id);
if matches!(ln_input.op, Op::Binary(BinaryOp::Add))
&& graph.use_count(ln_input_id) == 1
&& !is_output.contains_key(&ln_input_id)
{
fused_away.insert(ln_input_id, ());
}
}
}
let mut rw = Rewriter::new(&graph.name);
for node in graph.nodes() {
if fused_away.contains_key(&node.id) {
continue;
}
if let Op::LayerNorm { eps, .. } = &node.op {
let ln_input_id = node.inputs[0];
let ln_input = graph.node(ln_input_id);
if matches!(ln_input.op, Op::Binary(BinaryOp::Add))
&& fused_away.contains_key(&ln_input_id)
{
let (x_id, residual_id) = (ln_input.inputs[0], ln_input.inputs[1]);
let gamma_id = node.inputs[1];
let beta_id = node.inputs[2];
let fused_id = rw.add_fused(
Op::FusedResidualLN {
has_bias: false,
eps: *eps,
},
&[x_id, residual_id, gamma_id, beta_id],
node.shape.clone(),
);
rw.replace(ln_input_id, fused_id);
rw.replace(node.id, fused_id);
continue;
}
}
rw.copy_node(node);
}
rw.finish(&graph.outputs)
}
}
/// Fuses a residual `Add` that feeds an `RmsNorm` into one `Op::FusedResidualRmsNorm` node.
pub struct FuseResidualRmsNorm;
impl Pass for FuseResidualRmsNorm {
fn name(&self) -> &str {
"fuse_residual_rms_norm"
}
fn run(&self, graph: Graph) -> Graph {
let mut is_output: HashMap<NodeId, ()> = HashMap::new();
for &oid in &graph.outputs {
is_output.insert(oid, ());
}
let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
for node in graph.nodes() {
if let Op::RmsNorm { .. } = &node.op {
let rn_input_id = node.inputs[0];
let rn_input = graph.node(rn_input_id);
if matches!(rn_input.op, Op::Binary(BinaryOp::Add))
&& graph.use_count(rn_input_id) == 1
&& !is_output.contains_key(&rn_input_id)
{
fused_away.insert(rn_input_id, ());
}
}
}
let mut rw = Rewriter::new(&graph.name);
for node in graph.nodes() {
if fused_away.contains_key(&node.id) {
continue;
}
if let Op::RmsNorm { eps, .. } = &node.op {
let rn_input_id = node.inputs[0];
let rn_input = graph.node(rn_input_id);
if matches!(rn_input.op, Op::Binary(BinaryOp::Add))
&& fused_away.contains_key(&rn_input_id)
{
let (x_id, residual_id) = (rn_input.inputs[0], rn_input.inputs[1]);
let gamma_id = node.inputs[1];
let beta_id = node.inputs[2];
let fused_id = rw.add_fused(
Op::FusedResidualRmsNorm {
has_bias: false,
eps: *eps,
},
&[x_id, residual_id, gamma_id, beta_id],
node.shape.clone(),
);
rw.replace(rn_input_id, fused_id);
rw.replace(node.id, fused_id);
continue;
}
}
rw.copy_node(node);
}
rw.finish(&graph.outputs)
}
}
/// Folds a `Reshape` that flattens leading dimensions of an `RmsNorm` result into the norm
/// itself, by giving the norm the flattened output shape.
pub struct FuseRmsNormReshape;
/// Local alias for `rlx_ir::shape::leading_flatten_shape`.
///
/// Returns `Some(shape)` when that helper accepts `new_shape` as a leading-dimension flatten
/// of `in_shape`, `None` otherwise. The exact acceptance rules live in `rlx_ir`.
fn leading_flatten_shape(in_shape: &Shape, new_shape: &[i64]) -> Option<Shape> {
    rlx_ir::shape::leading_flatten_shape(in_shape, new_shape)
}
/// Returns the id of the first node, in graph order, that lists `id` among its inputs.
///
/// Uniqueness is *not* verified; callers check `graph.use_count(id) == 1` beforehand.
fn sole_consumer(graph: &Graph, id: NodeId) -> Option<NodeId> {
    graph
        .nodes()
        .iter()
        .find_map(|candidate| candidate.inputs.contains(&id).then_some(candidate.id))
}
impl Pass for FuseRmsNormReshape {
fn name(&self) -> &str {
"fuse_rms_norm_reshape"
}
fn run(&self, graph: Graph) -> Graph {
let mut is_output: HashMap<NodeId, ()> = HashMap::new();
for &oid in &graph.outputs {
is_output.insert(oid, ());
}
let mut flat_shape: HashMap<NodeId, Shape> = HashMap::new();
let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
for node in graph.nodes() {
if let Op::RmsNorm { .. } = &node.op {
if graph.use_count(node.id) != 1 || is_output.contains_key(&node.id) {
continue;
}
let Some(reshape_id) = sole_consumer(&graph, node.id) else {
continue;
};
if is_output.contains_key(&reshape_id) {
continue;
}
let reshape = graph.node(reshape_id);
if let Op::Reshape { new_shape } = &reshape.op {
if let Some(flat) = leading_flatten_shape(&node.shape, new_shape) {
flat_shape.insert(node.id, flat);
fused_away.insert(reshape_id, ());
}
}
}
}
let mut rw = Rewriter::new(&graph.name);
for node in graph.nodes() {
if fused_away.contains_key(&node.id) {
continue;
}
if let Op::RmsNorm { axis, eps, .. } = &node.op {
if let Some(flat) = flat_shape.get(&node.id) {
let Some(reshape_id) = sole_consumer(&graph, node.id) else {
rw.copy_node(node);
continue;
};
let fused_id = rw.add_fused(
Op::RmsNorm {
axis: *axis,
eps: *eps,
},
&node.inputs,
flat.clone(),
);
rw.replace(node.id, fused_id);
rw.replace(reshape_id, fused_id);
continue;
}
}
rw.copy_node(node);
}
rw.finish(&graph.outputs)
}
}
/// Merges the two projections of `MatMul(x, Wu) * Silu(MatMul(x, Wg))` into one matmul
/// against the column-concatenated weights, followed by `FusedSwiGLU`.
pub struct FuseSwiGLUDualMatmul;
impl FuseSwiGLUDualMatmul {
    /// Matches `mul_node` against `up * silu(gate)`, in either operand order. Here `up` and
    /// `gate` are `MatMul`s sharing the same left-hand input.
    ///
    /// Returns `(mul_id, gate_mm_id, up_mm_id, shared_input_id, silu_id)` on success.
    ///
    /// The rewrite deletes both matmuls and the SiLU. Each of them must therefore:
    /// - have exactly one use;
    /// - not be a graph output.
    ///
    /// The weights must be rank-2 with identical static shapes and the same dtype. The
    /// rewrite concatenates them along axis 1 into a tensor typed like `Wu`, and reads their
    /// dims with `unwrap_static`.
    fn match_dual_swiglu(
        graph: &Graph,
        mul_node: &Node,
    ) -> Option<(NodeId, NodeId, NodeId, NodeId, NodeId)> {
        if !matches!(mul_node.op, Op::Binary(BinaryOp::Mul)) {
            return None;
        }
        let lhs = graph.node(mul_node.inputs[0]);
        let rhs = graph.node(mul_node.inputs[1]);
        let (up_mm, silu_id, silu_node) = if matches!(rhs.op, Op::Activation(Activation::Silu)) {
            (lhs, mul_node.inputs[1], rhs)
        } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
            (rhs, mul_node.inputs[0], lhs)
        } else {
            return None;
        };
        if !matches!(up_mm.op, Op::MatMul) {
            return None;
        }
        let gate_mm = graph.node(silu_node.inputs[0]);
        if !matches!(gate_mm.op, Op::MatMul) {
            return None;
        }
        if up_mm.inputs[0] != gate_mm.inputs[0] {
            return None;
        }
        // Every node the rewrite removes must be private to this pattern; otherwise its
        // other users (or the graph outputs) would reference a node that no longer exists.
        for id in [silu_id, gate_mm.id, up_mm.id] {
            if graph.use_count(id) != 1 || graph.outputs.contains(&id) {
                return None;
            }
        }
        if !Self::weights_concatenable(graph, up_mm.inputs[1], gate_mm.inputs[1]) {
            return None;
        }
        Some((mul_node.id, gate_mm.id, up_mm.id, up_mm.inputs[0], silu_id))
    }

    /// Returns `true` when `wu` and `wg` are rank-2, share a dtype, and have identical
    /// fully static shapes.
    fn weights_concatenable(graph: &Graph, wu: NodeId, wg: NodeId) -> bool {
        let (su, sg) = (graph.shape(wu), graph.shape(wg));
        if su.rank() != 2 || sg.rank() != 2 || su.dtype() != sg.dtype() {
            return false;
        }
        (0..2).all(|i| {
            let (du, dg) = (su.dim(i), sg.dim(i));
            du.is_static() && dg.is_static() && du.unwrap_static() == dg.unwrap_static()
        })
    }
}
impl Pass for FuseSwiGLUDualMatmul {
    fn name(&self) -> &str {
        "fuse_swiglu_dual_matmul"
    }

    /// Rewrites every matched `up * silu(gate)` matmul pair into
    /// `FusedSwiGLU(MatMul(x, Concat([Wu, Wg], axis = 1)))`, with the up half first
    /// (`gate_first: false`). The graph is returned untouched when nothing matches.
    fn run(&self, graph: Graph) -> Graph {
        // mul id -> (gate matmul, up matmul, shared input)
        let mut plan: HashMap<NodeId, (NodeId, NodeId, NodeId)> = HashMap::new();
        // Nodes absorbed into a fused SwiGLU; they are not re-emitted.
        let mut consumed: HashMap<NodeId, ()> = HashMap::new();
        for node in graph.nodes() {
            let Some((mul_id, gate_mm, up_mm, shared_input, silu_id)) =
                Self::match_dual_swiglu(&graph, node)
            else {
                continue;
            };
            plan.insert(mul_id, (gate_mm, up_mm, shared_input));
            for absorbed in [gate_mm, up_mm, silu_id] {
                consumed.insert(absorbed, ());
            }
        }
        if plan.is_empty() {
            return graph;
        }
        let mut rw = Rewriter::new(&graph.name);
        for node in graph.nodes() {
            if consumed.contains_key(&node.id) {
                continue;
            }
            let Some(&(gate_mm, up_mm, input_id)) = plan.get(&node.id) else {
                rw.copy_node(node);
                continue;
            };
            let gate = graph.node(gate_mm);
            let up = graph.node(up_mm);
            let (w_gate, w_up) = (gate.inputs[1], up.inputs[1]);
            rw.ensure_mapped(&graph, &[input_id, w_gate, w_up]);

            // Column-concatenate [Wu | Wg] into a single (K, N_up + N_gate) weight.
            let up_w_shape = graph.shape(w_up);
            let gate_w_shape = graph.shape(w_gate);
            let k = up_w_shape.dim(0).unwrap_static();
            let n_up = up_w_shape.dim(1).unwrap_static();
            let n_gate = gate_w_shape.dim(1).unwrap_static();
            debug_assert_eq!(up_w_shape.dim(0), gate_w_shape.dim(0));
            let combined_n = n_up + n_gate;
            let concat_w = rw.add_fused(
                Op::Concat { axis: 1 },
                &[w_up, w_gate],
                Shape::new(&[k, combined_n], up_w_shape.dtype()),
            );

            // One matmul producing both halves: same dims as `up`, widened last axis.
            let out_rank = up.shape.rank();
            let mut mm_dims: Vec<Dim> = (0..out_rank).map(|axis| up.shape.dim(axis)).collect();
            mm_dims[out_rank - 1] = Dim::Static(combined_n);
            let cat_shape = Shape::from_dims(&mm_dims, up.shape.dtype());
            let mapped_input = rw.map(input_id);
            let cat_id = rw
                .new_graph
                .add_node(Op::MatMul, vec![mapped_input, concat_w], cat_shape);

            let fused_id = rw.new_graph.add_node(
                Op::FusedSwiGLU {
                    cast_to: None,
                    gate_first: false,
                },
                vec![cat_id],
                node.shape.clone(),
            );
            rw.replace(node.id, fused_id);
        }
        rw.finish(&graph.outputs)
    }
}
/// Merges `MatMul`s that share the same left-hand input into one `MatMul` against the
/// column-concatenated weights, followed by one `Narrow` per original matmul.
pub struct FuseSharedInputMatMul;
impl Pass for FuseSharedInputMatMul {
    fn name(&self) -> &str {
        "fuse_shared_input_matmul"
    }

    /// Groups `MatMul`s by their left-hand input. A group of two or more is fused when its
    /// weights:
    /// - are all rank-2;
    /// - share a static `K` (dim 0);
    /// - have a static `N` (dim 1);
    /// - share a dtype.
    ///
    /// A fused group becomes `MatMul(x, Concat(weights, axis = 1))`, and each original
    /// matmul becomes a `Narrow` of that result along the last axis.
    ///
    /// Static dims are required because the rewrite reads them with `unwrap_static`. A
    /// common dtype is required because the weights are concatenated into one tensor typed
    /// like the first weight.
    fn run(&self, graph: Graph) -> Graph {
        struct FuseGroup {
            input_id: NodeId,
            matmul_ids: Vec<NodeId>,
        }
        let mut input_to_matmuls: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
        for node in graph.nodes() {
            if matches!(node.op, Op::MatMul) {
                input_to_matmuls
                    .entry(node.inputs[0])
                    .or_default()
                    .push(node.id);
            }
        }
        let mut groups: Vec<FuseGroup> = Vec::new();
        for (input_id, matmul_ids) in input_to_matmuls {
            if matmul_ids.len() < 2 {
                continue;
            }
            let w0 = graph.shape(graph.node(matmul_ids[0]).inputs[1]);
            // The check covers the first weight too, so its rank is validated here as well.
            let compatible = matmul_ids.iter().all(|&id| {
                let w = graph.shape(graph.node(id).inputs[1]);
                w.rank() == 2
                    && w.dtype() == w0.dtype()
                    && w.dim(0) == w0.dim(0)
                    && w.dim(0).is_static()
                    && w.dim(1).is_static()
            });
            if compatible {
                groups.push(FuseGroup {
                    input_id,
                    matmul_ids,
                });
            }
        }
        if groups.is_empty() {
            return graph;
        }
        // The fused matmul is emitted at the position of each group's first matmul; the
        // remaining members are dropped and redirected to their narrows.
        let group_by_first: HashMap<NodeId, &FuseGroup> =
            groups.iter().map(|g| (g.matmul_ids[0], g)).collect();
        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
        for g in &groups {
            for &id in &g.matmul_ids[1..] {
                fused_away.insert(id, ());
            }
        }
        let mut rw = Rewriter::new(&graph.name);
        for node in graph.nodes() {
            if fused_away.contains_key(&node.id) {
                continue;
            }
            if let Some(group) = group_by_first.get(&node.id) {
                let matmuls: Vec<_> = group.matmul_ids.iter().map(|&id| graph.node(id)).collect();
                let weight_ids: Vec<NodeId> = matmuls.iter().map(|m| m.inputs[1]).collect();
                rw.ensure_mapped(&graph, std::slice::from_ref(&group.input_id));
                rw.ensure_mapped(&graph, &weight_ids);
                let w0_shape = graph.shape(weight_ids[0]);
                let k = w0_shape.dim(0).unwrap_static();
                let ns: Vec<usize> = weight_ids
                    .iter()
                    .map(|&w| graph.shape(w).dim(1).unwrap_static())
                    .collect();
                let combined_n: usize = ns.iter().sum();
                let concat_shape = Shape::new(&[k, combined_n], w0_shape.dtype());
                let concat_id = rw.add_fused(Op::Concat { axis: 1 }, &weight_ids, concat_shape);
                let out_rank = matmuls[0].shape.rank();
                let mut mm_dims: Vec<Dim> =
                    (0..out_rank).map(|i| matmuls[0].shape.dim(i)).collect();
                mm_dims[out_rank - 1] = Dim::Static(combined_n);
                let mm_shape = Shape::from_dims(&mm_dims, matmuls[0].shape.dtype());
                let mm_id = rw.new_graph.add_node(
                    Op::MatMul,
                    vec![rw.map(group.input_id), concat_id],
                    mm_shape,
                );
                // Slice each original matmul's columns back out, in concatenation order.
                let mut start = 0usize;
                for (mm, &n) in matmuls.iter().zip(&ns) {
                    let narrow = rw.new_graph.add_node(
                        Op::Narrow {
                            axis: out_rank - 1,
                            start,
                            len: n,
                        },
                        vec![mm_id],
                        mm.shape.clone(),
                    );
                    rw.replace(mm.id, narrow);
                    start += n;
                }
                continue;
            }
            rw.copy_node(node);
        }
        rw.finish(&graph.outputs)
    }
}
/// Replaces `Narrow(cat) * Silu(Narrow(cat))` over two halves of one tensor with a single
/// `Op::FusedSwiGLU` on that tensor.
pub struct FuseSwiGLU;
impl Pass for FuseSwiGLU {
fn name(&self) -> &str {
"fuse_swiglu"
}
fn run(&self, graph: Graph) -> Graph {
#[allow(dead_code)]
struct Match {
mul_id: NodeId,
up_narrow_id: NodeId,
silu_id: NodeId,
gate_narrow_id: NodeId,
cat_id: NodeId,
out_n: usize,
gate_first: bool,
}
let mut matches: Vec<Match> = Vec::new();
let mut consumed: HashMap<NodeId, ()> = HashMap::new();
for node in graph.nodes() {
if !matches!(node.op, Op::Binary(BinaryOp::Mul)) {
continue;
}
let lhs_id = node.inputs[0];
let rhs_id = node.inputs[1];
let lhs = graph.node(lhs_id);
let rhs = graph.node(rhs_id);
let (up_narrow, silu_id, silu_node) =
if matches!(rhs.op, Op::Activation(Activation::Silu)) {
(lhs, rhs_id, rhs)
} else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
(rhs, lhs_id, lhs)
} else {
continue;
};
let (up_axis, up_start, up_len) = match &up_narrow.op {
Op::Narrow { axis, start, len } => (*axis, *start, *len),
_ => continue,
};
let gate_narrow_id = silu_node.inputs[0];
let gate_narrow = graph.node(gate_narrow_id);
let (g_axis, g_start, g_len) = match &gate_narrow.op {
Op::Narrow { axis, start, len } => (*axis, *start, *len),
_ => continue,
};
if up_narrow.inputs[0] != gate_narrow.inputs[0] {
continue;
}
if up_axis != g_axis {
continue;
}
if up_len != g_len {
continue;
}
let n = up_len;
let gate_first = up_start == n && g_start == 0;
if !(gate_first || (up_start == 0 && g_start == n)) {
continue;
}
if graph.use_count(up_narrow.id) != 1 {
continue;
}
if graph.use_count(gate_narrow_id) != 1 {
continue;
}
if graph.use_count(silu_id) != 1 {
continue;
}
matches.push(Match {
mul_id: node.id,
up_narrow_id: up_narrow.id,
silu_id,
gate_narrow_id,
cat_id: up_narrow.inputs[0],
out_n: n,
gate_first,
});
consumed.insert(up_narrow.id, ());
consumed.insert(gate_narrow_id, ());
consumed.insert(silu_id, ());
}
if matches.is_empty() {
return graph;
}
let mut rw = Rewriter::new(&graph.name);
let match_by_mul: HashMap<NodeId, &Match> = matches.iter().map(|m| (m.mul_id, m)).collect();
for node in graph.nodes() {
if consumed.contains_key(&node.id) {
continue;
}
if let Some(m) = match_by_mul.get(&node.id) {
let out_shape = node.shape.clone();
debug_assert_eq!(
out_shape.dim(out_shape.rank() - 1).unwrap_static(),
m.out_n,
"FuseSwiGLU: output last dim should be N"
);
let fused_id = rw.add_fused(
Op::FusedSwiGLU {
cast_to: None,
gate_first: m.gate_first,
},
&[m.cat_id],
out_shape,
);
rw.replace(node.id, fused_id);
continue;
}
rw.copy_node(node);
}
rw.finish(&graph.outputs)
}
}
/// Fuses a QKV projection, its three `Narrow` splits, a custom-mask `Attention`, and the
/// output projection into one `Op::FusedAttentionBlock`. The pass is gated by a size
/// heuristic (`should_fuse`).
pub struct FuseAttentionBlock;
impl FuseAttentionBlock {
    /// Size heuristic gating the attention and transformer-layer fusions.
    ///
    /// Accepts the graph when some `Input` node of rank >= 2 has static dims 0 and 1 whose
    /// product is at most a threshold. Those dims are presumably batch and sequence.
    ///
    /// The threshold comes from the `RLX_FUSE_ATTN_THRESHOLD` environment variable, and
    /// defaults to 64 when the variable is unset or unparsable.
    fn should_fuse(graph: &Graph) -> bool {
        let threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
            .and_then(|v| v.parse().ok())
            .unwrap_or(64);
        graph.nodes().iter().any(|node| {
            if !matches!(node.op, Op::Input { .. }) || node.shape.rank() < 2 {
                return false;
            }
            let (lead, second) = (node.shape.dim(0), node.shape.dim(1));
            lead.is_static()
                && second.is_static()
                && lead.unwrap_static() * second.unwrap_static() <= threshold
        })
    }
}
/// If `node` is a `Narrow`, returns `(source, axis, start, len)`; otherwise `None`.
fn narrow_parent(node: &Node) -> Option<(NodeId, usize, usize, usize)> {
    if let Op::Narrow { axis, start, len } = &node.op {
        Some((node.inputs[0], *axis, *start, *len))
    } else {
        None
    }
}
/// If `node` is a three-input `FusedMatMulBiasAct` without an activation, returns its
/// operands `(input, weight, bias)`; otherwise `None`.
fn fused_mm_bias_none(node: &Node) -> Option<(NodeId, NodeId, NodeId)> {
    match (&node.op, &node.inputs[..]) {
        (Op::FusedMatMulBiasAct { activation: None }, &[input, weight, bias]) => {
            Some((input, weight, bias))
        }
        _ => None,
    }
}
impl Pass for FuseAttentionBlock {
fn name(&self) -> &str {
"fuse_attention_block"
}
fn run(&self, graph: Graph) -> Graph {
if !Self::should_fuse(&graph) {
return graph;
}
let mut is_output: HashMap<NodeId, ()> = HashMap::new();
for &oid in &graph.outputs {
is_output.insert(oid, ());
}
struct Match {
attn_id: NodeId,
qkv_mm_id: NodeId,
out_mm_id: NodeId,
narrows: [NodeId; 3],
hidden_id: NodeId,
qkv_w: NodeId,
qkv_b: NodeId,
out_w: NodeId,
out_b: NodeId,
mask: NodeId,
num_heads: usize,
head_dim: usize,
out_shape: Shape,
}
let mut matches: Vec<Match> = Vec::new();
let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
for node in graph.nodes() {
let Op::Attention {
num_heads,
head_dim,
mask_kind,
score_scale,
attn_logit_softcap,
} = &node.op
else {
continue;
};
if !matches!(mask_kind, MaskKind::Custom)
|| score_scale.is_some()
|| attn_logit_softcap.is_some()
|| node.inputs.len() != 4
{
continue;
}
let (q, k, v, mask) = (
node.inputs[0],
node.inputs[1],
node.inputs[2],
node.inputs[3],
);
let qn = graph.node(q);
let kn = graph.node(k);
let vn = graph.node(v);
let (qp, q_axis, q_start, q_len) = match narrow_parent(qn) {
Some(p) => p,
None => continue,
};
let (kp, k_axis, k_start, k_len) = match narrow_parent(kn) {
Some(p) => p,
None => continue,
};
let (vp, v_axis, v_start, v_len) = match narrow_parent(vn) {
Some(p) => p,
None => continue,
};
if qp != kp || kp != vp {
continue;
}
let h = num_heads * head_dim;
let parent_rank = graph.node(qp).shape.rank();
let last_ax = parent_rank.saturating_sub(1);
if q_axis != last_ax || k_axis != last_ax || v_axis != last_ax {
continue;
}
if q_len != h || k_len != h || v_len != h {
continue;
}
if q_start != 0 || k_start != h || v_start != 2 * h {
continue;
}
if graph.use_count(q) != 1
|| graph.use_count(k) != 1
|| graph.use_count(v) != 1
|| is_output.contains_key(&q)
|| is_output.contains_key(&k)
|| is_output.contains_key(&v)
{
continue;
}
let qkv_mm_node = graph.node(qp);
let (hidden_id, qkv_w, qkv_b) = match fused_mm_bias_none(qkv_mm_node) {
Some(t) => t,
None => continue,
};
if graph.use_count(qp) != 3 || is_output.contains_key(&qp) {
continue;
}
if graph.use_count(node.id) != 1 || is_output.contains_key(&node.id) {
continue;
}
let out_consumer_id = match graph
.nodes()
.iter()
.find(|n| n.inputs.contains(&node.id))
.map(|n| n.id)
{
Some(id) => id,
None => continue,
};
let out_mm_node = graph.node(out_consumer_id);
let (out_in, out_w, out_b) = match fused_mm_bias_none(out_mm_node) {
Some(t) if t.0 == node.id => t,
_ => continue,
};
let _ = out_in;
matches.push(Match {
attn_id: node.id,
qkv_mm_id: qp,
out_mm_id: out_consumer_id,
narrows: [q, k, v],
hidden_id,
qkv_w,
qkv_b,
out_w,
out_b,
mask,
num_heads: *num_heads,
head_dim: *head_dim,
out_shape: out_mm_node.shape.clone(),
});
fused_away.insert(qp, ());
fused_away.insert(q, ());
fused_away.insert(k, ());
fused_away.insert(v, ());
fused_away.insert(node.id, ());
fused_away.insert(out_consumer_id, ());
}
if matches.is_empty() {
return graph;
}
let mut by_out: HashMap<NodeId, &Match> = HashMap::new();
for m in &matches {
by_out.insert(m.out_mm_id, m);
}
let mut rw = Rewriter::new(&graph.name);
for node in graph.nodes() {
if fused_away.contains_key(&node.id) {
if let Some(m) = by_out.get(&node.id) {
rw.ensure_mapped(
&graph,
&[m.hidden_id, m.qkv_w, m.out_w, m.mask, m.qkv_b, m.out_b],
);
let fused_id = rw.add_fused(
Op::FusedAttentionBlock {
num_heads: m.num_heads,
head_dim: m.head_dim,
has_bias: true,
has_rope: false,
},
&[m.hidden_id, m.qkv_w, m.out_w, m.mask, m.qkv_b, m.out_b],
m.out_shape.clone(),
);
rw.replace(m.qkv_mm_id, fused_id);
rw.replace(m.narrows[0], fused_id);
rw.replace(m.narrows[1], fused_id);
rw.replace(m.narrows[2], fused_id);
rw.replace(m.attn_id, fused_id);
rw.replace(node.id, fused_id);
}
continue;
}
rw.copy_node(node);
}
rw.finish(&graph.outputs)
}
}
/// Fuses an encoder layer into one `Op::FusedTransformerLayer`. The layer is built from a
/// `FusedAttentionBlock`, two no-bias `FusedResidualLN`s, and a two-matmul MLP.
pub struct FuseTransformerLayer;
impl FuseTransformerLayer {
    /// Reuses `FuseAttentionBlock::should_fuse`, so for a given graph both fusions are
    /// enabled or disabled together.
    fn should_fuse(graph: &Graph) -> bool {
        FuseAttentionBlock::should_fuse(graph)
    }
}
/// If `node` is a four-input `FusedResidualLN` with `has_bias: false`, returns
/// `(x, residual, gamma, beta, eps)`; otherwise `None`.
fn fused_residual_ln_no_bias(node: &Node) -> Option<(NodeId, NodeId, NodeId, NodeId, f32)> {
    match (&node.op, &node.inputs[..]) {
        (
            Op::FusedResidualLN {
                has_bias: false,
                eps,
            },
            &[x, residual, gamma, beta],
        ) => Some((x, residual, gamma, beta, *eps)),
        _ => None,
    }
}
/// If `node` is a three-input `FusedMatMulBiasAct` that carries an activation, returns
/// `(input, weight, bias, activation)`; otherwise `None`.
fn fused_mm_bias_act(node: &Node) -> Option<(NodeId, NodeId, NodeId, Activation)> {
    match (&node.op, &node.inputs[..]) {
        (
            Op::FusedMatMulBiasAct {
                activation: Some(act),
            },
            &[input, weight, bias],
        ) => Some((input, weight, bias, *act)),
        _ => None,
    }
}
/// If `node` is a six-input `FusedAttentionBlock` with biases and without RoPE, returns
/// `(num_heads, head_dim, hidden, qkv_w, out_w, mask, qkv_b, out_b)`; otherwise `None`.
fn fused_attn_block_bert(
    node: &Node,
) -> Option<(usize, usize, NodeId, NodeId, NodeId, NodeId, NodeId, NodeId)> {
    let Op::FusedAttentionBlock {
        num_heads,
        head_dim,
        has_bias: true,
        has_rope: false,
    } = &node.op
    else {
        return None;
    };
    let &[hidden, qkv_w, out_w, mask, qkv_b, out_b] = &node.inputs[..] else {
        return None;
    };
    Some((*num_heads, *head_dim, hidden, qkv_w, out_w, mask, qkv_b, out_b))
}
impl Pass for FuseTransformerLayer {
fn name(&self) -> &str {
"fuse_transformer_layer"
}
fn run(&self, graph: Graph) -> Graph {
if !Self::should_fuse(&graph) {
return graph;
}
let mut is_output: HashMap<NodeId, ()> = HashMap::new();
for &oid in &graph.outputs {
is_output.insert(oid, ());
}
struct LayerMatch {
attn_id: NodeId,
ln1_id: NodeId,
fc1_id: NodeId,
fc2_id: NodeId,
ln2_id: NodeId,
inputs: [NodeId; 14],
num_heads: usize,
head_dim: usize,
intermediate_size: usize,
eps1: f32,
eps2: f32,
activation: Activation,
out_shape: Shape,
}
let mut matches: Vec<LayerMatch> = Vec::new();
let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
for node in graph.nodes() {
let Some((num_heads, head_dim, hidden_id, qkv_w, out_w, mask, qkv_b, out_b)) =
fused_attn_block_bert(node)
else {
continue;
};
let attn_id = node.id;
if graph.use_count(attn_id) != 1 || is_output.contains_key(&attn_id) {
continue;
}
let ln1_id = match graph
.nodes()
.iter()
.find(|n| n.inputs.contains(&attn_id))
.map(|n| n.id)
{
Some(id) => id,
None => continue,
};
let ln1_node = graph.node(ln1_id);
let Some((ln1_x, ln1_res, ln1_g, ln1_b, eps1)) = fused_residual_ln_no_bias(ln1_node)
else {
continue;
};
if ln1_x != attn_id || ln1_res != hidden_id {
continue;
}
if graph.use_count(ln1_id) != 2 || is_output.contains_key(&ln1_id) {
continue;
}
let mut fc1_candidate: Option<NodeId> = None;
let mut ln2_candidate: Option<NodeId> = None;
for cn in graph.nodes() {
if !cn.inputs.contains(&ln1_id) {
continue;
}
if fused_mm_bias_act(cn).is_some() && cn.inputs[0] == ln1_id {
fc1_candidate = Some(cn.id);
} else if fused_residual_ln_no_bias(cn).is_some() && cn.inputs[1] == ln1_id {
ln2_candidate = Some(cn.id);
}
}
let (Some(fc1_id), Some(ln2_id)) = (fc1_candidate, ln2_candidate) else {
continue;
};
let fc1_node = graph.node(fc1_id);
let Some((_, fc1_w, fc1_b, activation)) = fused_mm_bias_act(fc1_node) else {
continue;
};
if graph.use_count(fc1_id) != 1 || is_output.contains_key(&fc1_id) {
continue;
}
let fc2_id = match graph
.nodes()
.iter()
.find(|n| n.inputs.contains(&fc1_id))
.map(|n| n.id)
{
Some(id) => id,
None => continue,
};
let fc2_node = graph.node(fc2_id);
let Some((fc2_in, fc2_w, fc2_b)) = fused_mm_bias_none(fc2_node) else {
continue;
};
if fc2_in != fc1_id {
continue;
}
if graph.use_count(fc2_id) != 1 || is_output.contains_key(&fc2_id) {
continue;
}
let ln2_node = graph.node(ln2_id);
let Some((ln2_x, ln2_res, ln2_g, ln2_b, eps2)) = fused_residual_ln_no_bias(ln2_node)
else {
continue;
};
if ln2_x != fc2_id || ln2_res != ln1_id {
continue;
}
let intermediate_size = {
let s = &graph.node(fc1_w).shape;
if s.rank() != 2 {
continue;
}
let d = s.dim(s.rank() - 1);
if !d.is_static() {
continue;
}
d.unwrap_static()
};
matches.push(LayerMatch {
attn_id,
ln1_id,
fc1_id,
fc2_id,
ln2_id,
inputs: [
hidden_id, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w,
fc2_b, ln2_g, ln2_b, mask,
],
num_heads,
head_dim,
intermediate_size,
eps1,
eps2,
activation,
out_shape: ln2_node.shape.clone(),
});
fused_away.insert(attn_id, ());
fused_away.insert(ln1_id, ());
fused_away.insert(fc1_id, ());
fused_away.insert(fc2_id, ());
fused_away.insert(ln2_id, ());
}
if matches.is_empty() {
return graph;
}
let mut by_terminal: HashMap<NodeId, &LayerMatch> = HashMap::new();
for m in &matches {
by_terminal.insert(m.ln2_id, m);
}
let mut rw = Rewriter::new(&graph.name);
for node in graph.nodes() {
if fused_away.contains_key(&node.id) {
if let Some(m) = by_terminal.get(&node.id) {
rw.ensure_mapped(&graph, &m.inputs);
let fused_id = rw.add_fused(
Op::FusedTransformerLayer {
num_heads: m.num_heads,
head_dim: m.head_dim,
intermediate_size: m.intermediate_size,
eps1: m.eps1,
eps2: m.eps2,
activation: m.activation,
has_bias: true,
},
&m.inputs,
m.out_shape.clone(),
);
rw.replace(m.attn_id, fused_id);
rw.replace(m.ln1_id, fused_id);
rw.replace(m.fc1_id, fused_id);
rw.replace(m.fc2_id, fused_id);
rw.replace(node.id, fused_id);
}
continue;
}
rw.copy_node(node);
}
rw.finish(&graph.outputs)
}
}
/// Groups chains of elementwise ops into single `Op::ElementwiseRegion` nodes. Eligible ops
/// are activations, dtype-preserving casts, binary ops, compares and `Where`.
pub struct MarkElementwiseRegions;
impl Pass for MarkElementwiseRegions {
fn name(&self) -> &str {
"mark_elementwise_regions"
}
fn run(&self, graph: Graph) -> Graph {
let mut consumers: HashMap<NodeId, usize> = HashMap::new();
for node in graph.nodes() {
for &input in &node.inputs {
*consumers.entry(input).or_insert(0) += 1;
}
}
for &out in &graph.outputs {
*consumers.entry(out).or_insert(0) += 1;
}
let chain_eligible = |op: &Op| -> bool {
matches!(
op,
Op::Activation(_) | Op::Cast { .. } | Op::Binary(_) | Op::Compare(_) | Op::Where
)
};
let chain_step_safe = |graph: &Graph, node: &rlx_ir::Node| -> bool {
match &node.op {
Op::Cast { to } => {
let in_dt = graph.shape(node.inputs[0]).dtype();
*to == in_dt
}
_ => true,
}
};
let mut region_of: HashMap<NodeId, NodeId> = HashMap::new();
let mut chain_step_idx: HashMap<NodeId, u32> = HashMap::new();
for node in graph.nodes() {
if !chain_eligible(&node.op) {
continue;
}
if !chain_step_safe(&graph, node) {
continue;
}
let out_shape = &node.shape;
let out_elems = out_shape.num_elements();
let shape_ok = node.inputs.iter().all(|id| {
let in_elems = graph.shape(*id).num_elements();
match (in_elems, out_elems) {
(Some(i), Some(o)) if i == o => true,
(Some(i), Some(o)) if i > 0 && o % i == 0 => true,
_ => false,
}
});
if !shape_ok {
continue;
}
let mut parent_root: Option<NodeId> = None;
let mut all_inputs_single_consumer = true;
for &input in &node.inputs {
if graph.node(input).op.is_fusion_boundary() {
parent_root = None;
all_inputs_single_consumer = false;
break;
}
if let Some(&root) = region_of.get(&input) {
if consumers.get(&input).copied() != Some(1) {
all_inputs_single_consumer = false;
break;
}
match parent_root {
None => parent_root = Some(root),
Some(r) if r == root => {}
Some(_) => {
parent_root = None;
all_inputs_single_consumer = false;
break;
}
}
}
}
if !all_inputs_single_consumer {
region_of.insert(node.id, node.id);
chain_step_idx.insert(node.id, 0);
continue;
}
let root = parent_root.unwrap_or(node.id);
let next_idx = node
.inputs
.iter()
.filter_map(|id| {
if region_of.get(id) == Some(&root) {
chain_step_idx.get(id).copied()
} else {
None
}
})
.max()
.map(|m| m + 1)
.unwrap_or(0);
let limits = crate::limits::active_fusion_limits();
if next_idx >= limits.max_elementwise_steps {
region_of.insert(node.id, node.id);
chain_step_idx.insert(node.id, 0);
continue;
}
region_of.insert(node.id, root);
chain_step_idx.insert(node.id, next_idx);
}
let mut by_region: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
for node in graph.nodes() {
if let Some(&root) = region_of.get(&node.id) {
by_region.entry(root).or_default().push(node.id);
}
}
let mut tail_of_region: HashMap<NodeId, NodeId> = HashMap::new();
for (root, members) in &by_region {
if members.len() < 2 {
continue;
}
let max_idx = members.iter().map(|id| chain_step_idx[id]).max().unwrap();
let tails: Vec<_> = members
.iter()
.filter(|id| chain_step_idx[id] == max_idx)
.collect();
if tails.len() != 1 {
continue;
}
tail_of_region.insert(*root, *tails[0]);
}
let by_region: HashMap<NodeId, Vec<NodeId>> = by_region
.into_iter()
.filter(|(root, _)| tail_of_region.contains_key(root))
.collect();
if by_region.is_empty() {
return graph;
}
let mut rw = Rewriter::new(&graph.name);
let mut emitted_region: HashMap<NodeId, NodeId> = HashMap::new();
for node in graph.nodes() {
if let Some(&root) = region_of.get(&node.id)
&& let Some(&tail) = tail_of_region.get(&root)
{
if emitted_region.contains_key(&root) {
let region_new = emitted_region[&root];
rw.replace(node.id, region_new);
continue;
}
if node.id == tail {
let members = &by_region[&root];
let mut ordered: Vec<NodeId> = members.clone();
ordered.sort_by_key(|id| chain_step_idx[id]);
let mut external_inputs: Vec<NodeId> = Vec::new();
let mut input_idx_of: HashMap<NodeId, u32> = HashMap::new();
let mut step_idx_of: HashMap<NodeId, u32> = HashMap::new();
for (i, member_id) in ordered.iter().enumerate() {
step_idx_of.insert(*member_id, i as u32);
let n = graph.node(*member_id);
for &inp in &n.inputs {
if !step_idx_of.contains_key(&inp) && !input_idx_of.contains_key(&inp) {
let idx = external_inputs.len() as u32;
input_idx_of.insert(inp, idx);
external_inputs.push(inp);
}
}
}
let limits = crate::limits::active_fusion_limits();
if external_inputs.len() as u32 > limits.max_elementwise_inputs
|| ordered.len() as u32 > limits.max_elementwise_steps
{
for &mid in &ordered {
rw.copy_node(graph.node(mid));
}
continue;
}
let resolve = |id: NodeId| -> ChainOperand {
if let Some(&i) = input_idx_of.get(&id) {
ChainOperand::Input(i)
} else {
ChainOperand::Step(step_idx_of[&id])
}
};
let mut chain: Vec<ChainStep> = Vec::with_capacity(ordered.len());
for member_id in &ordered {
let n = graph.node(*member_id);
let step = match &n.op {
Op::Activation(a) => ChainStep::Activation(*a, resolve(n.inputs[0])),
Op::Cast { to } => ChainStep::Cast(*to, resolve(n.inputs[0])),
Op::Binary(op) => {
ChainStep::Binary(*op, resolve(n.inputs[0]), resolve(n.inputs[1]))
}
Op::Compare(op) => {
ChainStep::Compare(*op, resolve(n.inputs[0]), resolve(n.inputs[1]))
}
Op::Where => ChainStep::Where(
resolve(n.inputs[0]),
resolve(n.inputs[1]),
resolve(n.inputs[2]),
),
_ => unreachable!("non-chain-eligible op in region"),
};
chain.push(step);
}
let mut scalar_input_mask: u32 = 0;
let mut input_modulus = [0u32; 16];
let region_shape_elems = graph.node(tail).shape.num_elements();
for (i, &ext) in external_inputs.iter().enumerate() {
if i >= 16 {
break;
}
let in_elems = graph.shape(ext).num_elements();
match (in_elems, region_shape_elems) {
(Some(1), Some(o)) if o != 1 => {
scalar_input_mask |= 1u32 << i;
input_modulus[i] = 1;
}
(Some(i_n), Some(o)) if i_n != o && i_n > 0 => {
input_modulus[i] = i_n as u32;
}
_ => { }
}
}
let region_new = rw.add_fused(
Op::ElementwiseRegion {
chain,
num_inputs: external_inputs.len() as u32,
scalar_input_mask,
input_modulus,
prologue: RegionPrologue::None,
prologue_input: 0,
},
&external_inputs,
graph.node(tail).shape.clone(),
);
emitted_region.insert(root, region_new);
rw.replace(node.id, region_new);
continue;
} else {
rw.replace(node.id, NodeId(u32::MAX)); continue;
}
}
rw.copy_node(node);
}
rw.finish(&graph.outputs)
}
}
/// Expands `ElementwiseRegion` nodes back into the individual ops of their chain.
pub struct UnfuseElementwiseRegions {
    /// When `false`, regions whose prologue is not `RegionPrologue::None` are kept fused.
    /// When `true`, they are expanded too, and a `ResizeNearest2x` prologue becomes an
    /// explicit op.
    pub unfuse_prologue: bool,
}
impl UnfuseElementwiseRegions {
pub const FOR_GPU: UnfuseElementwiseRegions = UnfuseElementwiseRegions {
unfuse_prologue: false,
};
pub const FOR_CPU: UnfuseElementwiseRegions = UnfuseElementwiseRegions {
unfuse_prologue: true,
};
}
impl Pass for UnfuseElementwiseRegions {
fn name(&self) -> &str {
"unfuse_elementwise_regions"
}
fn run(&self, graph: Graph) -> Graph {
let any_region = graph
.nodes()
.iter()
.any(|n| matches!(n.op, Op::ElementwiseRegion { .. }));
if !any_region {
return graph;
}
let mut rw = Rewriter::new(&graph.name);
for node in graph.nodes() {
if let Op::ElementwiseRegion {
chain,
num_inputs: _,
scalar_input_mask: _,
input_modulus: _,
prologue,
prologue_input: _,
} = &node.op
{
if *prologue != RegionPrologue::None && !self.unfuse_prologue {
rw.copy_node(node);
continue;
}
let mut region_inputs: Vec<NodeId> =
node.inputs.iter().map(|id| rw.map(*id)).collect();
if *prologue == RegionPrologue::ResizeNearest2x {
let in_shape = rw.new_graph.node(region_inputs[0]).shape.clone();
let out_shape = if in_shape.rank() == 4 {
Shape::new(
&[
in_shape.dim(0).unwrap_static(),
in_shape.dim(1).unwrap_static(),
in_shape.dim(2).unwrap_static() * 2,
in_shape.dim(3).unwrap_static() * 2,
],
in_shape.dtype(),
)
} else {
node.shape.clone()
};
region_inputs[0] = rw.new_graph.add_node(
Op::ResizeNearest2x,
vec![region_inputs[0]],
out_shape,
);
}
let mut step_ids: Vec<NodeId> = Vec::with_capacity(chain.len());
let region_shape = node.shape.clone();
let region_dims: Vec<_> = region_shape.dims().to_vec();
let mut step_dtypes: Vec<rlx_ir::DType> = Vec::with_capacity(chain.len());
let region_dtype = region_shape.dtype();
let dtype_of = |op: &ChainOperand,
ins: &[NodeId],
step_dt: &[rlx_ir::DType],
rw: &Rewriter|
-> rlx_ir::DType {
match *op {
ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.dtype(),
ChainOperand::Step(i) => step_dt[i as usize],
}
};
let shape_of = |op: &ChainOperand,
ins: &[NodeId],
step_ids: &[NodeId],
rw: &Rewriter|
-> Shape {
match *op {
ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.clone(),
ChainOperand::Step(i) => {
rw.new_graph.node(step_ids[i as usize]).shape.clone()
}
}
};
for step in chain {
let resolve = |op: &ChainOperand| -> NodeId {
match *op {
ChainOperand::Input(i) => region_inputs[i as usize],
ChainOperand::Step(i) => step_ids[i as usize],
}
};
let (new_id, dt) = match step {
ChainStep::Activation(a, src) => {
let s = resolve(src);
let dt = dtype_of(src, ®ion_inputs, &step_dtypes, &rw);
let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
let dims: Vec<_> = src_shape.dims().to_vec();
let shape = Shape::from_dims(&dims, dt);
(
rw.new_graph.add_node(Op::Activation(*a), vec![s], shape),
dt,
)
}
ChainStep::Cast(to, src) => {
let s = resolve(src);
let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
let dims: Vec<_> = src_shape.dims().to_vec();
let shape = Shape::from_dims(&dims, *to);
(
rw.new_graph.add_node(Op::Cast { to: *to }, vec![s], shape),
*to,
)
}
ChainStep::Binary(op, lhs, rhs) => {
let l = resolve(lhs);
let r = resolve(rhs);
let dt = dtype_of(lhs, ®ion_inputs, &step_dtypes, &rw);
let lhs_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
let rhs_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
let bcast = rlx_ir::shape::broadcast(&lhs_shape, &rhs_shape)
.unwrap_or_else(|e| {
panic!(
"unfuse_elementwise_regions: cannot broadcast \
{lhs_shape:?} ⊗ {rhs_shape:?} for Binary({op:?}): {e}"
)
});
let dims: Vec<_> = bcast.dims().to_vec();
let shape = Shape::from_dims(&dims, dt);
(
rw.new_graph.add_node(Op::Binary(*op), vec![l, r], shape),
dt,
)
}
ChainStep::Compare(op, lhs, rhs) => {
let l = resolve(lhs);
let r = resolve(rhs);
let lhs_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
let rhs_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
let bcast = rlx_ir::shape::broadcast(&lhs_shape, &rhs_shape)
.unwrap_or_else(|e| {
panic!(
"unfuse_elementwise_regions: cannot broadcast \
{lhs_shape:?} ⊗ {rhs_shape:?} for Compare({op:?}): {e}"
)
});
let dims: Vec<_> = bcast.dims().to_vec();
let shape = Shape::from_dims(&dims, rlx_ir::DType::Bool);
(
rw.new_graph.add_node(Op::Compare(*op), vec![l, r], shape),
rlx_ir::DType::Bool,
)
}
ChainStep::Where(c, x, y) => {
let cn = resolve(c);
let xn = resolve(x);
let yn = resolve(y);
let dt = dtype_of(x, ®ion_inputs, &step_dtypes, &rw);
let c_shape = shape_of(c, ®ion_inputs, &step_ids, &rw);
let x_shape = shape_of(x, ®ion_inputs, &step_ids, &rw);
let y_shape = shape_of(y, ®ion_inputs, &step_ids, &rw);
let bcast_xy = rlx_ir::shape::broadcast(&x_shape, &y_shape)
.unwrap_or_else(|e| {
panic!(
"unfuse_elementwise_regions: cannot broadcast \
then/else {x_shape:?} ⊗ {y_shape:?} for Where: {e}"
)
});
let bcast = rlx_ir::shape::broadcast(&c_shape, &bcast_xy)
.unwrap_or_else(|e| {
panic!(
"unfuse_elementwise_regions: cannot broadcast cond \
{c_shape:?} ⊗ {bcast_xy:?} for Where: {e}"
)
});
let dims: Vec<_> = bcast.dims().to_vec();
let shape = Shape::from_dims(&dims, dt);
(
rw.new_graph.add_node(Op::Where, vec![cn, xn, yn], shape),
dt,
)
}
};
step_ids.push(new_id);
step_dtypes.push(dt);
}
let _ = region_dtype;
let _ = region_dims;
let last = *step_ids.last().expect("chain non-empty per pass invariant");
rw.replace(node.id, last);
continue;
}
rw.copy_node(node);
}
rw.finish(&graph.outputs)
}
}
pub fn clip_elementwise_regions(graph: Graph, limits: crate::limits::FusionLimits) -> Graph {
let oversize = |n: &rlx_ir::Node| -> bool {
matches!(
&n.op,
Op::ElementwiseRegion {
chain,
num_inputs,
..
} if *num_inputs > limits.max_elementwise_inputs
|| chain.len() as u32 > limits.max_elementwise_steps
)
};
if !graph.nodes().iter().any(oversize) {
return graph;
}
let mut rw = Rewriter::new(&graph.name);
for node in graph.nodes() {
if !oversize(node) {
rw.copy_node(node);
continue;
}
let Op::ElementwiseRegion {
chain,
num_inputs: _,
scalar_input_mask: _,
input_modulus: _,
prologue: _,
prologue_input: _,
} = &node.op
else {
unreachable!();
};
let region_inputs: Vec<NodeId> = node.inputs.iter().map(|id| rw.map(*id)).collect();
let mut step_ids: Vec<NodeId> = Vec::with_capacity(chain.len());
let region_shape = node.shape.clone();
let region_dims: Vec<_> = region_shape.dims().to_vec();
let mut step_dtypes: Vec<rlx_ir::DType> = Vec::with_capacity(chain.len());
let region_dtype = region_shape.dtype();
let dtype_of = |op: &ChainOperand,
ins: &[NodeId],
step_dt: &[rlx_ir::DType],
rw: &Rewriter|
-> rlx_ir::DType {
match *op {
ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.dtype(),
ChainOperand::Step(i) => step_dt[i as usize],
}
};
let shape_of =
|op: &ChainOperand, ins: &[NodeId], step_ids: &[NodeId], rw: &Rewriter| -> Shape {
match *op {
ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.clone(),
ChainOperand::Step(i) => rw.new_graph.node(step_ids[i as usize]).shape.clone(),
}
};
for step in chain {
let resolve = |op: &ChainOperand| -> NodeId {
match *op {
ChainOperand::Input(i) => region_inputs[i as usize],
ChainOperand::Step(i) => step_ids[i as usize],
}
};
let (new_id, dt) = match step {
ChainStep::Activation(a, src) => {
let s = resolve(src);
let dt = dtype_of(src, ®ion_inputs, &step_dtypes, &rw);
let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
let dims: Vec<_> = src_shape.dims().to_vec();
let shape = Shape::from_dims(&dims, dt);
(
rw.new_graph.add_node(Op::Activation(*a), vec![s], shape),
dt,
)
}
ChainStep::Cast(to, src) => {
let s = resolve(src);
let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
let dims: Vec<_> = src_shape.dims().to_vec();
let shape = Shape::from_dims(&dims, *to);
(
rw.new_graph.add_node(Op::Cast { to: *to }, vec![s], shape),
*to,
)
}
ChainStep::Binary(op, lhs, rhs) => {
let l = resolve(lhs);
let r = resolve(rhs);
let dt = dtype_of(lhs, ®ion_inputs, &step_dtypes, &rw);
let l_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
let r_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
let bcast = l_shape
.broadcast_with(&r_shape)
.unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
let dims: Vec<_> = bcast.dims().to_vec();
let shape = Shape::from_dims(&dims, dt);
(
rw.new_graph.add_node(Op::Binary(*op), vec![l, r], shape),
dt,
)
}
ChainStep::Compare(op, lhs, rhs) => {
let l = resolve(lhs);
let r = resolve(rhs);
let l_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
let r_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
let bcast = l_shape
.broadcast_with(&r_shape)
.unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
let dims: Vec<_> = bcast.dims().to_vec();
let shape = Shape::from_dims(&dims, rlx_ir::DType::U8);
(
rw.new_graph.add_node(Op::Compare(*op), vec![l, r], shape),
rlx_ir::DType::U8,
)
}
ChainStep::Where(cond, x, y) => {
let cn = resolve(cond);
let xn = resolve(x);
let yn = resolve(y);
let dt = dtype_of(x, ®ion_inputs, &step_dtypes, &rw);
let x_shape = shape_of(x, ®ion_inputs, &step_ids, &rw);
let y_shape = shape_of(y, ®ion_inputs, &step_ids, &rw);
let c_shape = shape_of(cond, ®ion_inputs, &step_ids, &rw);
let bcast_xy = x_shape
.broadcast_with(&y_shape)
.unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
let bcast = c_shape.broadcast_with(&bcast_xy).unwrap_or_else(|e| {
panic!("clip_elementwise_regions: cannot broadcast cond {c_shape:?} ⊗ {bcast_xy:?} for Where: {e}")
});
let dims: Vec<_> = bcast.dims().to_vec();
let shape = Shape::from_dims(&dims, dt);
(
rw.new_graph.add_node(Op::Where, vec![cn, xn, yn], shape),
dt,
)
}
};
step_ids.push(new_id);
step_dtypes.push(dt);
}
let _ = (region_dtype, region_dims);
let last = *step_ids
.last()
.expect("oversize region has non-empty chain");
rw.replace(node.id, last);
}
rw.finish(&graph.outputs)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::limits::FusionLimits;
use crate::pass::run_passes;
fn f32_shape(dims: &[usize]) -> Shape {
Shape::new(dims, DType::F32)
}
#[test]
fn fuse_matmul_bias_gelu() {
let mut g = Graph::new("test");
let x = g.input("x", f32_shape(&[4, 15, 384]));
let w = g.param("w", f32_shape(&[384, 1536]));
let b = g.param("b", f32_shape(&[1536]));
let mm = g.matmul(x, w, f32_shape(&[4, 15, 1536]));
let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[4, 15, 1536]));
let out = g.activation(Activation::Gelu, add, f32_shape(&[4, 15, 1536]));
g.set_outputs(vec![out]);
assert_eq!(g.len(), 6);
let fused = FuseMatMulBiasAct.run(g);
println!("{fused}");
assert_eq!(fused.len(), 4);
let out_node = fused.node(fused.outputs[0]);
assert!(matches!(
out_node.op,
Op::FusedMatMulBiasAct {
activation: Some(Activation::Gelu)
}
));
}
#[test]
fn fuse_matmul_bias_no_act() {
let mut g = Graph::new("test");
let x = g.input("x", f32_shape(&[4, 15, 384]));
let w = g.param("w", f32_shape(&[384, 384]));
let b = g.param("b", f32_shape(&[384]));
let mm = g.matmul(x, w, f32_shape(&[4, 15, 384]));
let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[4, 15, 384]));
g.set_outputs(vec![add]);
let fused = FuseMatMulBiasAct.run(g);
assert_eq!(fused.len(), 4);
let out_node = fused.node(fused.outputs[0]);
assert!(matches!(
out_node.op,
Op::FusedMatMulBiasAct { activation: None }
));
}
#[test]
fn fuse_matmul_bias_skips_unsupported_activation_epilogue() {
let mut g = Graph::new("test");
let x = g.input("x", f32_shape(&[8, 1024]));
let w = g.param("w", f32_shape(&[1024, 16]));
let b = g.param("b", f32_shape(&[16]));
let mm = g.matmul(x, w, f32_shape(&[8, 16]));
let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[8, 16]));
let exp = g.activation(Activation::Exp, add, f32_shape(&[8, 16]));
g.set_outputs(vec![exp]);
let fused = FuseMatMulBiasAct.run(g);
assert_eq!(fused.len(), 5);
let out_node = fused.node(fused.outputs[0]);
assert!(matches!(out_node.op, Op::Activation(Activation::Exp)));
let add_node = fused.node(out_node.inputs[0]);
assert!(matches!(
add_node.op,
Op::FusedMatMulBiasAct { activation: None }
));
}
#[test]
fn fuse_matmul_bias_act_with_late_bias_param() {
use rlx_ir::infer::GraphExt;
let mut g = Graph::new("late_bias");
let x = g.input("x", f32_shape(&[8, 16]));
let w = g.param("w", f32_shape(&[16, 32]));
let out = {
let mm = g.mm(x, w);
let b = g.param("b", f32_shape(&[32]));
let biased = g.add(mm, b);
g.gelu(biased)
};
g.set_outputs(vec![out]);
let fused = FuseMatMulBiasAct.run(g);
assert!(
fused
.nodes()
.iter()
.any(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. })),
"bias param declared after matmul must still fuse:\n{fused}"
);
}
#[test]
fn swiglu_ffn_builder_fuses_end_to_end() {
let mut g = Graph::new("swiglu_block");
let x = g.input("x", f32_shape(&[4, 768]));
let up_w = g.param("up", f32_shape(&[768, 2048]));
let gate_w = g.param("gate", f32_shape(&[768, 2048]));
let down_w = g.param("down", f32_shape(&[2048, 768]));
let out = g.swiglu_ffn(x, up_w, gate_w, down_w);
g.set_outputs(vec![out]);
let g = FuseSharedInputMatMul.run(g);
let g = FuseSwiGLU.run(g);
assert!(
g.nodes()
.iter()
.any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
"swiglu_ffn builder should match FuseSwiGLU:\n{g}"
);
}
#[test]
fn fuse_swiglu_dual_matmul_gate_first() {
use rlx_ir::infer::GraphExt;
let mut g = Graph::new("qwen3_ffn");
let x = g.input("x", f32_shape(&[4, 768]));
let gate_w = g.param("gate", f32_shape(&[768, 2048]));
let up_w = g.param("up", f32_shape(&[768, 2048]));
let gate = g.mm(x, gate_w);
let up = g.mm(x, up_w);
let gate_act = g.silu(gate);
let out = g.mul(gate_act, up);
g.set_outputs(vec![out]);
let fused = FuseSwiGLUDualMatmul.run(g);
assert!(
fused
.nodes()
.iter()
.any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
"gate-first dual matmul should fuse:\n{fused}"
);
assert!(
fused.len() <= 6,
"dual fusion should collapse to x + weights + concat + mm + fused_swiglu, got {} nodes",
fused.len()
);
}
#[test]
fn fuse_shared_input_matmul_three_way_qkv() {
let mut g = Graph::new("qkv");
let x = g.input("x", f32_shape(&[8, 512]));
let wq = g.param("wq", f32_shape(&[512, 128]));
let wk = g.param("wk", f32_shape(&[512, 128]));
let wv = g.param("wv", f32_shape(&[512, 128]));
let q = g.matmul(x, wq, f32_shape(&[8, 128]));
let k = g.matmul(x, wk, f32_shape(&[8, 128]));
let v = g.matmul(x, wv, f32_shape(&[8, 128]));
g.set_outputs(vec![q, k, v]);
let fused = FuseSharedInputMatMul.run(g);
assert_eq!(
fused.len(),
9,
"x + 3 weights + concat + mm + 3 narrows = 9"
);
for &out in &fused.outputs {
assert!(matches!(fused.node(out).op, Op::Narrow { .. }));
}
}
#[test]
fn fuse_residual_layer_norm() {
let mut g = Graph::new("test");
let x = g.input("x", f32_shape(&[4, 15, 384]));
let residual = g.input("residual", f32_shape(&[4, 15, 384]));
let gamma = g.param("gamma", f32_shape(&[384]));
let beta = g.param("beta", f32_shape(&[384]));
let add = g.binary(BinaryOp::Add, x, residual, f32_shape(&[4, 15, 384]));
let ln = g.layer_norm(add, gamma, beta, -1, 1e-12, f32_shape(&[4, 15, 384]));
g.set_outputs(vec![ln]);
assert_eq!(g.len(), 6);
let fused = FuseResidualLN.run(g);
println!("{fused}");
assert_eq!(fused.len(), 5);
let out_node = fused.node(fused.outputs[0]);
assert!(matches!(
out_node.op,
Op::FusedResidualLN {
has_bias: false,
..
}
));
}
#[test]
fn fuse_residual_rms_norm() {
let mut g = Graph::new("test");
let x = g.input("x", f32_shape(&[4, 15, 384]));
let residual = g.input("residual", f32_shape(&[4, 15, 384]));
let gamma = g.param("gamma", f32_shape(&[384]));
let beta = g.param("beta", f32_shape(&[384]));
let add = g.binary(BinaryOp::Add, x, residual, f32_shape(&[4, 15, 384]));
let rn = g.add_node(
Op::RmsNorm {
axis: -1,
eps: 1e-6,
},
vec![add, gamma, beta],
f32_shape(&[4, 15, 384]),
);
g.set_outputs(vec![rn]);
assert_eq!(g.len(), 6);
let fused = FuseResidualRmsNorm.run(g);
assert_eq!(fused.len(), 5);
let out_node = fused.node(fused.outputs[0]);
assert!(matches!(
out_node.op,
Op::FusedResidualRmsNorm {
has_bias: false,
..
}
));
}
#[test]
fn fuse_rms_norm_reshape() {
let mut g = Graph::new("test");
let x = g.input("x", f32_shape(&[1, 8, 512]));
let gamma = g.param("gamma", f32_shape(&[512]));
let beta = g.param("beta", f32_shape(&[512]));
let rn = g.add_node(
Op::RmsNorm {
axis: -1,
eps: 1e-6,
},
vec![x, gamma, beta],
f32_shape(&[1, 8, 512]),
);
let flat = g.add_node(
Op::Reshape {
new_shape: vec![8, 512],
},
vec![rn],
f32_shape(&[8, 512]),
);
let w = g.param("w", f32_shape(&[512, 128]));
let mm = g.matmul(flat, w, f32_shape(&[8, 128]));
g.set_outputs(vec![mm]);
let fused = FuseRmsNormReshape.run(g);
assert_eq!(fused.len(), 6);
let rn_node = fused.node(fused.node(fused.outputs[0]).inputs[0]);
assert!(matches!(rn_node.op, Op::RmsNorm { .. }));
assert_eq!(rn_node.shape.dim(0).unwrap_static(), 8);
assert_eq!(rn_node.shape.dim(1).unwrap_static(), 512);
}
#[test]
fn fuse_shared_input_matmul() {
let mut g = Graph::new("swiglu");
let x = g.input("x", f32_shape(&[60, 768]));
let w1 = g.param("fc11", f32_shape(&[768, 2048]));
let w2 = g.param("fc12", f32_shape(&[768, 2048]));
let mm1 = g.matmul(x, w1, f32_shape(&[60, 2048]));
let mm2 = g.matmul(x, w2, f32_shape(&[60, 2048]));
g.set_outputs(vec![mm1, mm2]);
assert_eq!(g.len(), 5);
let fused = FuseSharedInputMatMul.run(g);
println!("{fused}");
assert!(fused.len() <= 7);
for &out in &fused.outputs {
assert!(matches!(fused.node(out).op, Op::Narrow { .. }));
}
}
#[test]
fn fuse_shared_input_matmul_with_late_w2_param() {
let mut g = Graph::new("late_w2");
let x = g.input("x", f32_shape(&[8, 16]));
let w1 = g.param("w1", f32_shape(&[16, 8]));
let mm1 = g.matmul(x, w1, f32_shape(&[8, 8]));
let w2 = g.param("w2", f32_shape(&[16, 8]));
let mm2 = g.matmul(x, w2, f32_shape(&[8, 8]));
g.set_outputs(vec![mm1, mm2]);
let fused = FuseSharedInputMatMul.run(g);
for &out in &fused.outputs {
assert!(
matches!(fused.node(out).op, Op::Narrow { .. }),
"late w2 should still fuse via ensure_mapped, got {:?}",
fused.node(out).op
);
}
}
#[test]
fn fuse_shared_input_matmul_moe_ffn_pattern() {
let mut g = Graph::new("moe_ffn");
let rows = 4usize;
let n_embd = 16usize;
let n_expert = 4usize;
let n_ff = 16usize;
let h_in = g.input("h", f32_shape(&[1, rows, n_embd]));
let h_2d = g.reshape_(h_in, vec![rows as i64, n_embd as i64]);
let router_w = g.param("router_w", f32_shape(&[n_embd, n_expert]));
let router_logits = g.matmul(h_2d, router_w, f32_shape(&[rows, n_expert]));
let shared_router_w = g.param("shared_router_w", f32_shape(&[n_embd, 1]));
let shared_logits = g.matmul(h_2d, shared_router_w, f32_shape(&[rows, 1]));
let shared_gate = g.activation(Activation::Sigmoid, shared_logits, f32_shape(&[rows, 1]));
let s_gate_w = g.param("s_gate_w", f32_shape(&[n_embd, n_ff]));
let s_up_w = g.param("s_up_w", f32_shape(&[n_embd, n_ff]));
let s_gate = g.matmul(h_2d, s_gate_w, f32_shape(&[rows, n_ff]));
let s_up = g.matmul(h_2d, s_up_w, f32_shape(&[rows, n_ff]));
let s_gate_silu = g.silu(s_gate);
let s_swiglu = g.mul(s_gate_silu, s_up);
g.set_outputs(vec![router_logits, shared_gate, s_swiglu]);
let fused = FuseSharedInputMatMul.run(g);
let narrow_count = fused
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::Narrow { .. }))
.count();
assert!(
narrow_count >= 4,
"expected four narrow slices from fused h_2d matmuls, got {narrow_count}"
);
}
#[test]
fn full_bert_ffn_fusion() {
let mut g = Graph::new("bert_ffn");
let f = DType::F32;
let x = g.input("hidden", Shape::new(&[4, 15, 384], f));
let residual = g.input("residual", Shape::new(&[4, 15, 384], f));
let out_w = g.param("out.w", Shape::new(&[384, 384], f));
let out_b = g.param("out.b", Shape::new(&[384], f));
let out_mm = g.matmul(x, out_w, Shape::new(&[4, 15, 384], f));
let out_add = g.binary(BinaryOp::Add, out_mm, out_b, Shape::new(&[4, 15, 384], f));
let res_add = g.binary(
BinaryOp::Add,
out_add,
residual,
Shape::new(&[4, 15, 384], f),
);
let gamma = g.param("ln.g", Shape::new(&[384], f));
let beta = g.param("ln.b", Shape::new(&[384], f));
let ln = g.layer_norm(
res_add,
gamma,
beta,
-1,
1e-12,
Shape::new(&[4, 15, 384], f),
);
let int_w = g.param("int.w", Shape::new(&[384, 1536], f));
let int_b = g.param("int.b", Shape::new(&[1536], f));
let int_mm = g.matmul(ln, int_w, Shape::new(&[4, 15, 1536], f));
let int_add = g.binary(BinaryOp::Add, int_mm, int_b, Shape::new(&[4, 15, 1536], f));
let gelu = g.activation(Activation::Gelu, int_add, Shape::new(&[4, 15, 1536], f));
let out2_w = g.param("out2.w", Shape::new(&[1536, 384], f));
let out2_b = g.param("out2.b", Shape::new(&[384], f));
let out2_mm = g.matmul(gelu, out2_w, Shape::new(&[4, 15, 384], f));
let out2_add = g.binary(BinaryOp::Add, out2_mm, out2_b, Shape::new(&[4, 15, 384], f));
g.set_outputs(vec![out2_add]);
let before = g.len();
println!("=== BEFORE fusion ({before} nodes) ===\n{g}");
let passes: Vec<&dyn Pass> = vec![&FuseMatMulBiasAct, &FuseResidualLN];
let optimized = run_passes(g, &passes, false);
let after = optimized.len();
println!("=== AFTER fusion ({after} nodes) ===\n{optimized}");
assert!(
after < before,
"fusion should reduce node count: {before} → {after}"
);
let ops: Vec<String> = optimized
.nodes()
.iter()
.map(|n| format!("{}", n.op))
.collect();
let has_fused_mm = ops.iter().any(|s| s.contains("fused_mm_bias"));
assert!(has_fused_mm, "should have fused_mm_bias_act: {ops:?}");
}
#[test]
fn fuse_swiglu_canonical() {
let mut g = Graph::new("nomic_ffn");
let f = DType::F32;
let cat = g.input("cat", Shape::new(&[60, 4096], f));
let up = g.add_node(
Op::Narrow {
axis: 1,
start: 0,
len: 2048,
},
vec![cat],
Shape::new(&[60, 2048], f),
);
let gate = g.add_node(
Op::Narrow {
axis: 1,
start: 2048,
len: 2048,
},
vec![cat],
Shape::new(&[60, 2048], f),
);
let silu = g.activation(Activation::Silu, gate, Shape::new(&[60, 2048], f));
let out = g.binary(BinaryOp::Mul, up, silu, Shape::new(&[60, 2048], f));
g.set_outputs(vec![out]);
let before = g.len();
let fused = FuseSwiGLU.run(g);
let after = fused.len();
assert_eq!(
after,
before - 3,
"should remove narrows+silu+mul, add FusedSwiGLU"
);
let out_node = fused.node(fused.outputs[0]);
assert!(
matches!(
out_node.op,
Op::FusedSwiGLU {
cast_to: None,
gate_first: false
}
),
"output should be FusedSwiGLU, got {}",
out_node.op
);
let in_id = out_node.inputs[0];
assert!(matches!(fused.node(in_id).op, Op::Input { .. }));
}
#[test]
fn fuse_swiglu_skips_when_narrow_has_extra_user() {
let mut g = Graph::new("contended");
let f = DType::F32;
let cat = g.input("cat", Shape::new(&[60, 4096], f));
let up = g.add_node(
Op::Narrow {
axis: 1,
start: 0,
len: 2048,
},
vec![cat],
Shape::new(&[60, 2048], f),
);
let gate = g.add_node(
Op::Narrow {
axis: 1,
start: 2048,
len: 2048,
},
vec![cat],
Shape::new(&[60, 2048], f),
);
let silu = g.activation(Activation::Silu, gate, Shape::new(&[60, 2048], f));
let out = g.binary(BinaryOp::Mul, up, silu, Shape::new(&[60, 2048], f));
let extra = g.activation(Activation::Relu, up, Shape::new(&[60, 2048], f));
g.set_outputs(vec![out, extra]);
let before = g.len();
let fused = FuseSwiGLU.run(g);
assert_eq!(fused.len(), before);
let any_fused = fused
.nodes()
.iter()
.any(|n| matches!(n.op, Op::FusedSwiGLU { .. }));
assert!(!any_fused, "should not fuse when narrow has extra user");
}
#[test]
fn region_collapses_add_mul_relu_chain() {
let f = DType::F32;
let mut g = Graph::new("ew");
let a = g.input("a", Shape::new(&[8], f));
let b = g.input("b", Shape::new(&[8], f));
let c = g.input("c", Shape::new(&[8], f));
let s = Shape::new(&[8], f);
let add = g.binary(BinaryOp::Add, a, b, s.clone());
let mul = g.binary(BinaryOp::Mul, add, c, s.clone());
let relu = g.activation(Activation::Relu, mul, s.clone());
g.set_outputs(vec![relu]);
let before = g.len();
let fused = MarkElementwiseRegions.run(g);
let regions: Vec<_> = fused
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
.collect();
assert_eq!(regions.len(), 1, "expected one ElementwiseRegion");
let region = regions[0];
assert_eq!(
region.inputs.len(),
3,
"region has 3 external inputs (a, b, c)"
);
if let Op::ElementwiseRegion {
chain, num_inputs, ..
} = ®ion.op
{
assert_eq!(*num_inputs, 3);
assert_eq!(chain.len(), 3);
match &chain[0] {
ChainStep::Binary(
BinaryOp::Add,
ChainOperand::Input(0),
ChainOperand::Input(1),
) => {}
other => panic!("step 0 unexpected: {other:?}"),
}
match &chain[1] {
ChainStep::Binary(BinaryOp::Mul, ChainOperand::Step(0), ChainOperand::Input(2)) => {
}
other => panic!("step 1 unexpected: {other:?}"),
}
match &chain[2] {
ChainStep::Activation(Activation::Relu, ChainOperand::Step(1)) => {}
other => panic!("step 2 unexpected: {other:?}"),
}
} else {
unreachable!();
}
assert!(fused.len() < before);
}
#[test]
fn region_does_not_fuse_when_intermediate_has_multiple_consumers() {
let f = DType::F32;
let mut g = Graph::new("ew");
let a = g.input("a", Shape::new(&[4], f));
let b = g.input("b", Shape::new(&[4], f));
let s = Shape::new(&[4], f);
let add = g.binary(BinaryOp::Add, a, b, s.clone());
let relu = g.activation(Activation::Relu, add, s.clone());
let extra = g.activation(Activation::Sigmoid, add, s.clone());
g.set_outputs(vec![relu, extra]);
let before = g.len();
let fused = MarkElementwiseRegions.run(g);
let regions: Vec<_> = fused
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
.collect();
assert_eq!(regions.len(), 0);
assert_eq!(fused.len(), before);
}
#[test]
fn region_skips_chains_of_length_one() {
let f = DType::F32;
let mut g = Graph::new("ew");
let a = g.input("a", Shape::new(&[4], f));
let r = g.activation(Activation::Relu, a, Shape::new(&[4], f));
g.set_outputs(vec![r]);
let fused = MarkElementwiseRegions.run(g);
let any_region = fused
.nodes()
.iter()
.any(|n| matches!(n.op, Op::ElementwiseRegion { .. }));
assert!(!any_region);
}
#[test]
fn unfuse_decomposes_region_back_to_atomic_ops() {
let f = DType::F32;
let mut g = Graph::new("ew_unfuse");
let a = g.input("a", Shape::new(&[8], f));
let b = g.input("b", Shape::new(&[8], f));
let c = g.input("c", Shape::new(&[8], f));
let s = Shape::new(&[8], f);
let add = g.binary(BinaryOp::Add, a, b, s.clone());
let mul = g.binary(BinaryOp::Mul, add, c, s.clone());
let relu = g.activation(Activation::Relu, mul, s);
g.set_outputs(vec![relu]);
let fused = MarkElementwiseRegions.run(g);
assert!(
fused
.nodes()
.iter()
.any(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
);
let unfused = UnfuseElementwiseRegions::FOR_CPU.run(fused);
assert!(
!unfused
.nodes()
.iter()
.any(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
);
let bin_count = unfused
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::Binary(_)))
.count();
let act_count = unfused
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::Activation(_)))
.count();
assert_eq!(bin_count, 2, "Add + Mul restored");
assert_eq!(act_count, 1, "Relu restored");
}
#[test]
fn clip_unfuses_region_over_step_cap() {
use rlx_ir::op::{Activation, ChainOperand, ChainStep};
let mut g = Graph::new("clip");
let x = g.input("x", f32_shape(&[4]));
let mut chain: Vec<ChainStep> = Vec::new();
let mut prev = ChainOperand::Input(0);
for _ in 0..40 {
chain.push(ChainStep::Activation(Activation::Relu, prev));
prev = ChainOperand::Step(chain.len() as u32 - 1);
}
let y = g.add_node(
Op::ElementwiseRegion {
chain,
num_inputs: 1,
scalar_input_mask: 0,
input_modulus: [0; 16],
prologue: RegionPrologue::None,
prologue_input: 0,
},
vec![x],
f32_shape(&[4]),
);
g.set_outputs(vec![y]);
let clipped = clip_elementwise_regions(g, FusionLimits::GPU_NATIVE);
assert!(
!clipped
.nodes()
.iter()
.any(|n| matches!(n.op, Op::ElementwiseRegion { .. })),
"oversized region should be decomposed"
);
assert!(clipped.len() > 5);
}
#[test]
fn unfuse_is_noop_when_no_region_present() {
let f = DType::F32;
let mut g = Graph::new("noop");
let a = g.input("a", Shape::new(&[4], f));
let r = g.activation(Activation::Relu, a, Shape::new(&[4], f));
g.set_outputs(vec![r]);
let n_before = g.len();
let result = UnfuseElementwiseRegions::FOR_CPU.run(g);
assert_eq!(result.len(), n_before);
}
#[test]
fn region_includes_where_step() {
let f = DType::F32;
let mut g = Graph::new("region_where");
let a = g.input("a", Shape::new(&[4], f));
let b = g.input("b", Shape::new(&[4], f));
let s = Shape::new(&[4], f);
let cmp = g.add_node(Op::Compare(CmpOp::Gt), vec![a, b], s.clone());
let sel = g.add_node(Op::Where, vec![cmp, a, b], s.clone());
let add = g.binary(BinaryOp::Add, sel, a, s.clone());
g.set_outputs(vec![add]);
let fused = MarkElementwiseRegions.run(g);
let regions: Vec<_> = fused
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
.collect();
assert_eq!(regions.len(), 1, "expected one ElementwiseRegion");
if let Op::ElementwiseRegion { chain, .. } = ®ions[0].op {
assert_eq!(chain.len(), 3);
assert!(
matches!(chain[1], ChainStep::Where(_, _, _)),
"step 1 should be Where, got {:?}",
chain[1]
);
} else {
unreachable!();
}
}
#[test]
fn unfuse_decomposes_where_step_back_to_op_where() {
let f = DType::F32;
let mut g = Graph::new("unfuse_where");
let a = g.input("a", Shape::new(&[4], f));
let b = g.input("b", Shape::new(&[4], f));
let s = Shape::new(&[4], f);
let cmp = g.add_node(Op::Compare(CmpOp::Gt), vec![a, b], s.clone());
let sel = g.add_node(Op::Where, vec![cmp, a, b], s.clone());
let add = g.binary(BinaryOp::Add, sel, a, s.clone());
g.set_outputs(vec![add]);
let fused = MarkElementwiseRegions.run(g);
let unfused = UnfuseElementwiseRegions::FOR_CPU.run(fused);
let where_count = unfused
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::Where))
.count();
assert_eq!(
where_count, 1,
"decomposer should re-emit one Op::Where for the chain step"
);
}
#[test]
fn fuse_attention_block_collapses_qkv_attn_outproj() {
let nh: usize = 4;
let dh: usize = 8;
let h: usize = nh * dh; let b: usize = 1;
let s: usize = 4;
let mut g = Graph::new("attn-block");
let hidden = g.input("hidden", f32_shape(&[b, s, h]));
let mask = g.input("attention_mask", f32_shape(&[b, s]));
let qkv_w = g.param("qkv_w", f32_shape(&[h, 3 * h]));
let qkv_b = g.param("qkv_b", f32_shape(&[3 * h]));
let qkv_mm = g.matmul(hidden, qkv_w, f32_shape(&[b, s, 3 * h]));
let qkv = g.binary(BinaryOp::Add, qkv_mm, qkv_b, f32_shape(&[b, s, 3 * h]));
let q = g.add_node(
Op::Narrow {
axis: 2,
start: 0,
len: h,
},
vec![qkv],
f32_shape(&[b, s, h]),
);
let k = g.add_node(
Op::Narrow {
axis: 2,
start: h,
len: h,
},
vec![qkv],
f32_shape(&[b, s, h]),
);
let v = g.add_node(
Op::Narrow {
axis: 2,
start: 2 * h,
len: h,
},
vec![qkv],
f32_shape(&[b, s, h]),
);
let attn = g.attention(q, k, v, mask, nh, dh, f32_shape(&[b, s, h]));
let out_w = g.param("out_w", f32_shape(&[h, h]));
let out_b = g.param("out_b", f32_shape(&[h]));
let out_mm = g.matmul(attn, out_w, f32_shape(&[b, s, h]));
let out = g.binary(BinaryOp::Add, out_mm, out_b, f32_shape(&[b, s, h]));
g.set_outputs(vec![out]);
let fused1 = FuseMatMulBiasAct.run(g);
let mm_bias_count = fused1
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::FusedMatMulBiasAct { activation: None }))
.count();
assert_eq!(mm_bias_count, 2, "QKV + OutProj should each fuse");
let fused2 = FuseAttentionBlock.run(fused1);
let fab_count = fused2
.nodes()
.iter()
.filter(|n| {
matches!(
n.op,
Op::FusedAttentionBlock {
has_bias: true,
has_rope: false,
..
}
)
})
.count();
assert_eq!(
fab_count, 1,
"should produce exactly one FusedAttentionBlock"
);
let narrow_count = fused2
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::Narrow { .. }))
.count();
let attention_count = fused2
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::Attention { .. }))
.count();
let mm_bias_remaining = fused2
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. }))
.count();
assert_eq!(narrow_count, 0, "QKV narrows absorbed");
assert_eq!(attention_count, 0, "Attention absorbed");
assert_eq!(mm_bias_remaining, 0, "both projections absorbed");
let out_node = fused2.node(fused2.outputs[0]);
assert!(matches!(out_node.op, Op::FusedAttentionBlock { .. }));
}
#[test]
fn fuse_transformer_layer_collapses_full_bert_block() {
let nh: usize = 4;
let dh: usize = 8;
let h: usize = nh * dh;
let inter = 4 * h;
let eps1: f32 = 1e-12;
let eps2: f32 = 1e-12;
let b: usize = 1;
let s: usize = 4;
let mut g = Graph::new("bert-layer");
let hidden = g.input("hidden", f32_shape(&[b, s, h]));
let mask = g.input("attention_mask", f32_shape(&[b, s]));
let qkv_w = g.param("qkv_w", f32_shape(&[h, 3 * h]));
let qkv_b = g.param("qkv_b", f32_shape(&[3 * h]));
let qkv_mm = g.matmul(hidden, qkv_w, f32_shape(&[b, s, 3 * h]));
let qkv = g.binary(BinaryOp::Add, qkv_mm, qkv_b, f32_shape(&[b, s, 3 * h]));
let q = g.add_node(
Op::Narrow {
axis: 2,
start: 0,
len: h,
},
vec![qkv],
f32_shape(&[b, s, h]),
);
let k = g.add_node(
Op::Narrow {
axis: 2,
start: h,
len: h,
},
vec![qkv],
f32_shape(&[b, s, h]),
);
let v = g.add_node(
Op::Narrow {
axis: 2,
start: 2 * h,
len: h,
},
vec![qkv],
f32_shape(&[b, s, h]),
);
let attn = g.attention(q, k, v, mask, nh, dh, f32_shape(&[b, s, h]));
let out_w = g.param("out_w", f32_shape(&[h, h]));
let out_b = g.param("out_b", f32_shape(&[h]));
let out_mm = g.matmul(attn, out_w, f32_shape(&[b, s, h]));
let attn_out = g.binary(BinaryOp::Add, out_mm, out_b, f32_shape(&[b, s, h]));
let res1 = g.binary(BinaryOp::Add, attn_out, hidden, f32_shape(&[b, s, h]));
let ln1_g = g.param("ln1_g", f32_shape(&[h]));
let ln1_b = g.param("ln1_b", f32_shape(&[h]));
let h1 = g.add_node(
Op::LayerNorm {
axis: -1,
eps: eps1,
},
vec![res1, ln1_g, ln1_b],
f32_shape(&[b, s, h]),
);
let fc1_w = g.param("fc1_w", f32_shape(&[h, inter]));
let fc1_b = g.param("fc1_b", f32_shape(&[inter]));
let fc1_mm = g.matmul(h1, fc1_w, f32_shape(&[b, s, inter]));
let fc1_add = g.binary(BinaryOp::Add, fc1_mm, fc1_b, f32_shape(&[b, s, inter]));
let fc1_act = g.activation(Activation::Gelu, fc1_add, f32_shape(&[b, s, inter]));
let fc2_w = g.param("fc2_w", f32_shape(&[inter, h]));
let fc2_b = g.param("fc2_b", f32_shape(&[h]));
let fc2_mm = g.matmul(fc1_act, fc2_w, f32_shape(&[b, s, h]));
let ffn_out = g.binary(BinaryOp::Add, fc2_mm, fc2_b, f32_shape(&[b, s, h]));
let res2 = g.binary(BinaryOp::Add, ffn_out, h1, f32_shape(&[b, s, h]));
let ln2_g = g.param("ln2_g", f32_shape(&[h]));
let ln2_b = g.param("ln2_b", f32_shape(&[h]));
let out = g.add_node(
Op::LayerNorm {
axis: -1,
eps: eps2,
},
vec![res2, ln2_g, ln2_b],
f32_shape(&[b, s, h]),
);
g.set_outputs(vec![out]);
let g = FuseMatMulBiasAct.run(g);
let g = FuseResidualLN.run(g);
let g = FuseAttentionBlock.run(g);
let g = FuseTransformerLayer.run(g);
let ftl_count = g
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::FusedTransformerLayer { .. }))
.count();
assert_eq!(
ftl_count, 1,
"single layer should collapse to one FusedTransformerLayer"
);
let leftover_fab = g
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::FusedAttentionBlock { .. }))
.count();
let leftover_frln = g
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::FusedResidualLN { .. }))
.count();
let leftover_fmba = g
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. }))
.count();
assert_eq!(leftover_fab, 0, "attn block absorbed into layer");
assert_eq!(leftover_frln, 0, "both residual+LNs absorbed");
assert_eq!(leftover_fmba, 0, "FFN matmuls absorbed");
let out_node = g.node(g.outputs[0]);
assert!(matches!(
out_node.op,
Op::FusedTransformerLayer {
num_heads: 4,
head_dim: 8,
intermediate_size: 128,
has_bias: true,
..
}
));
assert_eq!(out_node.inputs.len(), 14);
}
#[test]
/// `FuseAttentionBlock` must not produce a `FusedAttentionBlock` for a
/// QKV-projection -> attention -> output-projection chain when the activations
/// are large (batch 16, sequence 128 in this test).
///
/// `FuseMatMulBiasAct` runs first, mirroring the pass order used elsewhere.
///
/// NOTE(review): the exact size cut-off lives in `FuseAttentionBlock`. This test
/// assumes a [16, 128, 32] activation lies above it. Confirm if the threshold changes.
fn fuse_attention_block_skips_large_inputs() {
    let (num_heads, head_dim): (usize, usize) = (4, 8);
    let hidden_dim = num_heads * head_dim;
    let (batch, seq): (usize, usize) = (16, 128);
    // Every activation in this graph is [batch, seq, last]; only `last` varies.
    let act_shape = |last: usize| f32_shape(&[batch, seq, last]);

    let mut graph = Graph::new("attn-block-large");
    let hidden = graph.input("hidden", act_shape(hidden_dim));
    let mask = graph.input("attention_mask", f32_shape(&[batch, seq]));

    // Packed QKV projection: hidden @ qkv_w + qkv_b.
    let qkv_w = graph.param("qkv_w", f32_shape(&[hidden_dim, 3 * hidden_dim]));
    let qkv_b = graph.param("qkv_b", f32_shape(&[3 * hidden_dim]));
    let qkv_mm = graph.matmul(hidden, qkv_w, act_shape(3 * hidden_dim));
    let qkv = graph.binary(BinaryOp::Add, qkv_mm, qkv_b, act_shape(3 * hidden_dim));

    // Slice Q, K, V out of axis 2. `map` visits the starts in order, so the
    // nodes are created in the sequence Q, K, V.
    let [q, k, v] = [0, hidden_dim, 2 * hidden_dim].map(|start| {
        graph.add_node(
            Op::Narrow {
                axis: 2,
                start,
                len: hidden_dim,
            },
            vec![qkv],
            act_shape(hidden_dim),
        )
    });
    let attn = graph.attention(q, k, v, mask, num_heads, head_dim, act_shape(hidden_dim));

    // Output projection: attn @ out_w + out_b.
    let out_w = graph.param("out_w", f32_shape(&[hidden_dim, hidden_dim]));
    let out_b = graph.param("out_b", f32_shape(&[hidden_dim]));
    let out_mm = graph.matmul(attn, out_w, act_shape(hidden_dim));
    let out = graph.binary(BinaryOp::Add, out_mm, out_b, act_shape(hidden_dim));
    graph.set_outputs(vec![out]);

    let graph = FuseAttentionBlock.run(FuseMatMulBiasAct.run(graph));

    let fab_count = graph
        .nodes()
        .iter()
        .filter(|node| matches!(node.op, Op::FusedAttentionBlock { .. }))
        .count();
    assert_eq!(fab_count, 0, "block-fusion must skip large batches");
}
}