Skip to main content

rlx_fusion/
fusion.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Fusion passes — pattern-match and replace subgraphs with fused ops.
17//!
//! Each pass walks the graph's nodes in forward (topological) order, looking
//! for specific multi-node patterns and replacing them with single fused nodes.
20//! These are the same fusions we hand-coded in burnembed's ndarray_fused.rs.
21
22use crate::pass::Pass;
23use rlx_ir::op::*;
24use rlx_ir::*;
25use std::collections::HashMap;
26
27// ── Helper: graph rewriter ──────────────────────────────────────────────
28
29use crate::graph_rewrite::Rewriter;
30
31// ── Pass 1: MatMul + Bias + Activation → FusedMatMulBiasAct ─────────────
32
33/// Fuses `matmul → add(bias) → activation` into a single FusedMatMulBiasAct.
34///
35/// This is the single most impactful fusion — it eliminates two intermediate
36/// tensors and three memory passes (matmul write, bias read+write, act read+write)
37/// down to one (matmul write with inline bias+activation).
38///
39/// Also fuses `matmul → add(bias)` without activation.
40///
41/// Epilogue activations are fused only when every backend can apply them
42/// inline with the matmul (today: Gelu and Silu). Other activations — e.g.
43/// Exp in qwen35 softplus — stay as separate ops so Metal does not silently
44/// drop the epilogue.
45pub struct FuseMatMulBiasAct;
46
47/// Activations that may be folded into `FusedMatMulBiasAct` epilogues.
48fn fusible_mm_bias_epilogue_activation(act: Activation) -> bool {
49    matches!(act, Activation::Gelu | Activation::Silu)
50}
51
52impl Pass for FuseMatMulBiasAct {
53    fn name(&self) -> &str {
54        "fuse_matmul_bias_act"
55    }
56
57    fn run(&self, graph: Graph) -> Graph {
58        let mut rw = Rewriter::new(&graph.name);
59        // Track which nodes are consumed by fusion (skip them in copy)
60        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
61
62        // Forward pass: copy nodes, detect patterns
63        for node in graph.nodes() {
64            if fused_away.contains_key(&node.id) {
65                continue;
66            }
67
68            // Pattern: MatMul → Add(bias) → Activation
69            // or:      MatMul → Add(bias)
70            if matches!(node.op, Op::MatMul) {
71                let mm_id = node.id;
72                let mm_users: Vec<_> = graph.users(mm_id);
73
74                // Check for single-use Add(bias) consumer
75                if mm_users.len() == 1 {
76                    let add_node = graph.node(mm_users[0]);
77                    if let Op::Binary(BinaryOp::Add) = &add_node.op {
78                        // Determine which input is the bias (the non-matmul one)
79                        let (bias_id, _mm_input) = if add_node.inputs[0] == mm_id {
80                            (add_node.inputs[1], add_node.inputs[0])
81                        } else {
82                            (add_node.inputs[0], add_node.inputs[1])
83                        };
84
85                        // Check if bias is a param/const with broadcastable shape
86                        let bias_shape = graph.shape(bias_id);
87                        if bias_shape.rank() <= 1 {
88                            let add_id = add_node.id;
89                            let add_users = graph.users(add_id);
90
91                            // Check for activation consumer
92                            let mut activation = None;
93                            let mut act_id = None;
94                            if add_users.len() == 1 {
95                                let act_node = graph.node(add_users[0]);
96                                if let Op::Activation(a) = &act_node.op
97                                    && fusible_mm_bias_epilogue_activation(*a)
98                                {
99                                    activation = Some(*a);
100                                    act_id = Some(act_node.id);
101                                }
102                            }
103
104                            // Emit fused node. Bias may be declared after
105                            // the matmul in the source graph — copy it early
106                            // instead of requiring builders to order params first.
107                            let out_shape = if let Some(aid) = act_id {
108                                graph.shape(aid).clone()
109                            } else {
110                                add_node.shape.clone()
111                            };
112
113                            rw.ensure_mapped(&graph, &[node.inputs[0], node.inputs[1], bias_id]);
114                            let fused_id = rw.add_fused(
115                                Op::FusedMatMulBiasAct { activation },
116                                &[node.inputs[0], node.inputs[1], bias_id],
117                                out_shape,
118                            );
119
120                            // Map old nodes to the fused result
121                            rw.replace(mm_id, fused_id);
122                            rw.replace(add_id, fused_id);
123                            fused_away.insert(add_id, ());
124                            if let Some(aid) = act_id {
125                                rw.replace(aid, fused_id);
126                                fused_away.insert(aid, ());
127                            }
128                            continue;
129                        }
130                    }
131                }
132            }
133
134            // No fusion — copy as-is
135            rw.copy_node(node);
136        }
137
138        rw.finish(&graph.outputs)
139    }
140}
141
142// ── Pass 2: Add(residual) + LayerNorm → FusedResidualLN ─────────────────
143
144/// Fuses `add(x, residual) → layer_norm` into FusedResidualLN.
145///
146/// Also detects `add(x, residual) → add(bias) → layer_norm` for the
147/// bias variant (used in BERT's output projection).
148pub struct FuseResidualLN;
149
150impl Pass for FuseResidualLN {
151    fn name(&self) -> &str {
152        "fuse_residual_ln"
153    }
154
155    fn run(&self, graph: Graph) -> Graph {
156        // Graph outputs hold implicit references to their producing
157        // nodes that don't show up in any node's `inputs` (use_count
158        // walks node inputs only). Treat being-a-graph-output as a
159        // use so we don't fuse-away an intermediate the caller still
160        // wants to read — this used to silently corrupt multi-block
161        // encoders (e.g. SAM 2 stage outputs) by collapsing the
162        // residual add of block N into block N+1's LN.
163        let mut is_output: HashMap<NodeId, ()> = HashMap::new();
164        for &oid in &graph.outputs {
165            is_output.insert(oid, ());
166        }
167        // Pre-scan: find all Add nodes consumed by LayerNorm
168        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
169        for node in graph.nodes() {
170            if let Op::LayerNorm { .. } = &node.op {
171                let ln_input_id = node.inputs[0];
172                let ln_input = graph.node(ln_input_id);
173                if matches!(ln_input.op, Op::Binary(BinaryOp::Add))
174                    && graph.use_count(ln_input_id) == 1
175                    && !is_output.contains_key(&ln_input_id)
176                {
177                    fused_away.insert(ln_input_id, ());
178                }
179            }
180        }
181
182        let mut rw = Rewriter::new(&graph.name);
183
184        for node in graph.nodes() {
185            if fused_away.contains_key(&node.id) {
186                continue;
187            }
188
189            if let Op::LayerNorm { eps, .. } = &node.op {
190                let ln_input_id = node.inputs[0];
191                let ln_input = graph.node(ln_input_id);
192
193                if matches!(ln_input.op, Op::Binary(BinaryOp::Add))
194                    && fused_away.contains_key(&ln_input_id)
195                {
196                    let (x_id, residual_id) = (ln_input.inputs[0], ln_input.inputs[1]);
197                    let gamma_id = node.inputs[1];
198                    let beta_id = node.inputs[2];
199
200                    let fused_id = rw.add_fused(
201                        Op::FusedResidualLN {
202                            has_bias: false,
203                            eps: *eps,
204                        },
205                        &[x_id, residual_id, gamma_id, beta_id],
206                        node.shape.clone(),
207                    );
208
209                    rw.replace(ln_input_id, fused_id);
210                    rw.replace(node.id, fused_id);
211                    continue;
212                }
213            }
214
215            rw.copy_node(node);
216        }
217
218        rw.finish(&graph.outputs)
219    }
220}
221
222// ── Pass 2b: Add(residual) + RmsNorm → FusedResidualRmsNorm ─────────────
223
224/// Fuses `add(x, residual) → rms_norm` into [`Op::FusedResidualRmsNorm`].
225pub struct FuseResidualRmsNorm;
226
227impl Pass for FuseResidualRmsNorm {
228    fn name(&self) -> &str {
229        "fuse_residual_rms_norm"
230    }
231
232    fn run(&self, graph: Graph) -> Graph {
233        let mut is_output: HashMap<NodeId, ()> = HashMap::new();
234        for &oid in &graph.outputs {
235            is_output.insert(oid, ());
236        }
237        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
238        for node in graph.nodes() {
239            if let Op::RmsNorm { .. } = &node.op {
240                let rn_input_id = node.inputs[0];
241                let rn_input = graph.node(rn_input_id);
242                if matches!(rn_input.op, Op::Binary(BinaryOp::Add))
243                    && graph.use_count(rn_input_id) == 1
244                    && !is_output.contains_key(&rn_input_id)
245                {
246                    fused_away.insert(rn_input_id, ());
247                }
248            }
249        }
250
251        let mut rw = Rewriter::new(&graph.name);
252
253        for node in graph.nodes() {
254            if fused_away.contains_key(&node.id) {
255                continue;
256            }
257
258            if let Op::RmsNorm { eps, .. } = &node.op {
259                let rn_input_id = node.inputs[0];
260                let rn_input = graph.node(rn_input_id);
261
262                if matches!(rn_input.op, Op::Binary(BinaryOp::Add))
263                    && fused_away.contains_key(&rn_input_id)
264                {
265                    let (x_id, residual_id) = (rn_input.inputs[0], rn_input.inputs[1]);
266                    let gamma_id = node.inputs[1];
267                    let beta_id = node.inputs[2];
268
269                    let fused_id = rw.add_fused(
270                        Op::FusedResidualRmsNorm {
271                            has_bias: false,
272                            eps: *eps,
273                        },
274                        &[x_id, residual_id, gamma_id, beta_id],
275                        node.shape.clone(),
276                    );
277
278                    rw.replace(rn_input_id, fused_id);
279                    rw.replace(node.id, fused_id);
280                    continue;
281                }
282            }
283
284            rw.copy_node(node);
285        }
286
287        rw.finish(&graph.outputs)
288    }
289}
290
291// ── Pass 2c: RmsNorm → Reshape(leading flatten) ─────────────────────────
292
293/// Fuses `rms_norm([…, H]) → reshape([∏leading, H])` into a single
294/// `RmsNorm` with the flattened output shape, eliminating a memcpy.
295///
296/// Matches the Qwen3.5 pre-norm pattern where normalized activations
297/// are immediately reshaped to 2-D for matmul.
298pub struct FuseRmsNormReshape;
299
300fn leading_flatten_shape(in_shape: &Shape, new_shape: &[i64]) -> Option<Shape> {
301    rlx_ir::shape::leading_flatten_shape(in_shape, new_shape)
302}
303
304fn sole_consumer(graph: &Graph, id: NodeId) -> Option<NodeId> {
305    graph
306        .nodes()
307        .iter()
308        .find(|n| n.inputs.contains(&id))
309        .map(|n| n.id)
310}
311
312impl Pass for FuseRmsNormReshape {
313    fn name(&self) -> &str {
314        "fuse_rms_norm_reshape"
315    }
316
317    fn run(&self, graph: Graph) -> Graph {
318        let mut is_output: HashMap<NodeId, ()> = HashMap::new();
319        for &oid in &graph.outputs {
320            is_output.insert(oid, ());
321        }
322
323        let mut flat_shape: HashMap<NodeId, Shape> = HashMap::new();
324        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
325        for node in graph.nodes() {
326            if let Op::RmsNorm { .. } = &node.op {
327                if graph.use_count(node.id) != 1 || is_output.contains_key(&node.id) {
328                    continue;
329                }
330                let Some(reshape_id) = sole_consumer(&graph, node.id) else {
331                    continue;
332                };
333                if is_output.contains_key(&reshape_id) {
334                    continue;
335                }
336                let reshape = graph.node(reshape_id);
337                if let Op::Reshape { new_shape } = &reshape.op {
338                    if let Some(flat) = leading_flatten_shape(&node.shape, new_shape) {
339                        flat_shape.insert(node.id, flat);
340                        fused_away.insert(reshape_id, ());
341                    }
342                }
343            }
344        }
345
346        let mut rw = Rewriter::new(&graph.name);
347
348        for node in graph.nodes() {
349            if fused_away.contains_key(&node.id) {
350                continue;
351            }
352
353            if let Op::RmsNorm { axis, eps, .. } = &node.op {
354                if let Some(flat) = flat_shape.get(&node.id) {
355                    let Some(reshape_id) = sole_consumer(&graph, node.id) else {
356                        rw.copy_node(node);
357                        continue;
358                    };
359                    let fused_id = rw.add_fused(
360                        Op::RmsNorm {
361                            axis: *axis,
362                            eps: *eps,
363                        },
364                        &node.inputs,
365                        flat.clone(),
366                    );
367                    rw.replace(node.id, fused_id);
368                    rw.replace(reshape_id, fused_id);
369                    continue;
370                }
371            }
372
373            rw.copy_node(node);
374        }
375
376        rw.finish(&graph.outputs)
377    }
378}
379
380// ── Pass 3b: Dual MatMul SwiGLU (gate+up before shared-input concat) ─────
381
382/// Fuses the common LLM FFN pattern in one rewrite:
383///   gate = matmul(x, wg); up = matmul(x, wu); out = mul(silu(gate), up)
384///
385/// Becomes:
386///   cat = matmul(x, concat(wu, wg))   // up weights first for kernel layout
387///   out = fused_swiglu(cat)
388///
389/// Eliminates two `[..., N]` matmul outputs plus a silu buffer — the
390/// largest memory win on transformer FFN blocks.
391pub struct FuseSwiGLUDualMatmul;
392
393impl FuseSwiGLUDualMatmul {
394    fn match_dual_swiglu(
395        graph: &Graph,
396        mul_node: &Node,
397    ) -> Option<(NodeId, NodeId, NodeId, NodeId, NodeId)> {
398        if !matches!(mul_node.op, Op::Binary(BinaryOp::Mul)) {
399            return None;
400        }
401        let lhs = graph.node(mul_node.inputs[0]);
402        let rhs = graph.node(mul_node.inputs[1]);
403        let (up_mm, silu_id, silu_node) = if matches!(rhs.op, Op::Activation(Activation::Silu)) {
404            (lhs, mul_node.inputs[1], rhs)
405        } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
406            (rhs, mul_node.inputs[0], lhs)
407        } else {
408            return None;
409        };
410        if !matches!(up_mm.op, Op::MatMul) {
411            return None;
412        }
413        let gate_mm = graph.node(silu_node.inputs[0]);
414        if !matches!(gate_mm.op, Op::MatMul) {
415            return None;
416        }
417        if up_mm.inputs[0] != gate_mm.inputs[0] {
418            return None;
419        }
420        if graph.use_count(silu_id) != 1 {
421            return None;
422        }
423        Some((mul_node.id, gate_mm.id, up_mm.id, up_mm.inputs[0], silu_id))
424    }
425}
426
427impl Pass for FuseSwiGLUDualMatmul {
428    fn name(&self) -> &str {
429        "fuse_swiglu_dual_matmul"
430    }
431
432    fn run(&self, graph: Graph) -> Graph {
433        let mut matches: Vec<(NodeId, NodeId, NodeId, NodeId, NodeId)> = Vec::new();
434        let mut consumed: HashMap<NodeId, ()> = HashMap::new();
435
436        for node in graph.nodes() {
437            if let Some((mul_id, gate_mm, up_mm, _, silu_id)) =
438                Self::match_dual_swiglu(&graph, node)
439            {
440                matches.push((mul_id, gate_mm, up_mm, graph.node(up_mm).inputs[0], silu_id));
441                consumed.insert(gate_mm, ());
442                consumed.insert(up_mm, ());
443                consumed.insert(silu_id, ());
444            }
445        }
446
447        if matches.is_empty() {
448            return graph;
449        }
450
451        let match_by_mul: HashMap<NodeId, (NodeId, NodeId, NodeId)> = matches
452            .into_iter()
453            .map(|(mul, gate, up, input, _silu)| (mul, (gate, up, input)))
454            .collect();
455
456        let mut rw = Rewriter::new(&graph.name);
457        for node in graph.nodes() {
458            if consumed.contains_key(&node.id) {
459                continue;
460            }
461            if let Some(&(gate_mm, up_mm, input_id)) = match_by_mul.get(&node.id) {
462                let gate = graph.node(gate_mm);
463                let up = graph.node(up_mm);
464                let wg = gate.inputs[1];
465                let wu = up.inputs[1];
466                rw.ensure_mapped(&graph, &[input_id, wg, wu]);
467
468                let wu_shape = graph.shape(wu);
469                let wg_shape = graph.shape(wg);
470                let k = wu_shape.dim(0).unwrap_static();
471                let n_up = wu_shape.dim(1).unwrap_static();
472                let n_gate = wg_shape.dim(1).unwrap_static();
473                debug_assert_eq!(wu_shape.dim(0), wg_shape.dim(0));
474
475                // Up weights first → canonical FusedSwiGLU layout (gate_first=false).
476                let concat_shape = Shape::new(&[k, n_up + n_gate], wu_shape.dtype());
477                let concat_w = rw.add_fused(Op::Concat { axis: 1 }, &[wu, wg], concat_shape);
478
479                let out_rank = up.shape.rank();
480                let mut mm_dims: Vec<Dim> = (0..out_rank).map(|i| up.shape.dim(i)).collect();
481                mm_dims[out_rank - 1] = Dim::Static(n_up + n_gate);
482                let cat_shape = Shape::from_dims(&mm_dims, up.shape.dtype());
483                let cat_id =
484                    rw.new_graph
485                        .add_node(Op::MatMul, vec![rw.map(input_id), concat_w], cat_shape);
486
487                let fused_id = rw.new_graph.add_node(
488                    Op::FusedSwiGLU {
489                        cast_to: None,
490                        gate_first: false,
491                    },
492                    vec![cat_id],
493                    node.shape.clone(),
494                );
495                rw.replace(node.id, fused_id);
496                continue;
497            }
498            rw.copy_node(node);
499        }
500        rw.finish(&graph.outputs)
501    }
502}
503
504// ── Pass 3: Shared-input MatMul concat (QKV, SwiGLU fc11+fc12) ──────────
505
506/// Detects two MatMul nodes with the same input and concatenates their
507/// weight matrices into a single larger MatMul.
508///
509/// Pattern:
510///   %a = matmul(%x, %w1)
511///   %b = matmul(%x, %w2)
512/// Becomes:
513///   %ab = matmul(%x, concat(%w1, %w2))
514///   %a = narrow(%ab, ..., 0, n1)
515///   %b = narrow(%ab, ..., n1, n2)
516///
517/// This saves one full input read (the shared input is read once instead
518/// of twice). Critical for SwiGLU (fc11+fc12) and QKV fusion.
519pub struct FuseSharedInputMatMul;
520
521impl Pass for FuseSharedInputMatMul {
522    fn name(&self) -> &str {
523        "fuse_shared_input_matmul"
524    }
525
526    fn run(&self, graph: Graph) -> Graph {
527        struct FuseGroup {
528            input_id: NodeId,
529            matmul_ids: Vec<NodeId>,
530        }
531
532        let mut input_to_matmuls: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
533        for node in graph.nodes() {
534            if matches!(node.op, Op::MatMul) {
535                input_to_matmuls
536                    .entry(node.inputs[0])
537                    .or_default()
538                    .push(node.id);
539            }
540        }
541
542        let mut groups: Vec<FuseGroup> = Vec::new();
543        for (input_id, matmul_ids) in input_to_matmuls {
544            if matmul_ids.len() < 2 {
545                continue;
546            }
547            let first = graph.node(matmul_ids[0]);
548            let w0 = graph.shape(first.inputs[1]);
549            if w0.rank() != 2 {
550                continue;
551            }
552            let compatible = matmul_ids.iter().all(|&id| {
553                let m = graph.node(id);
554                matches!(m.op, Op::MatMul)
555                    && graph.shape(m.inputs[1]).rank() == 2
556                    && graph.shape(m.inputs[1]).dim(0) == w0.dim(0)
557            });
558            if compatible {
559                groups.push(FuseGroup {
560                    input_id,
561                    matmul_ids,
562                });
563            }
564        }
565
566        if groups.is_empty() {
567            return graph;
568        }
569
570        let group_by_first: HashMap<NodeId, &FuseGroup> =
571            groups.iter().map(|g| (g.matmul_ids[0], g)).collect();
572
573        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
574        for g in &groups {
575            for &id in &g.matmul_ids[1..] {
576                fused_away.insert(id, ());
577            }
578        }
579
580        let mut rw = Rewriter::new(&graph.name);
581        for node in graph.nodes() {
582            if fused_away.contains_key(&node.id) {
583                continue;
584            }
585
586            if let Some(group) = group_by_first.get(&node.id) {
587                let matmuls: Vec<_> = group.matmul_ids.iter().map(|&id| graph.node(id)).collect();
588                let weight_ids: Vec<NodeId> = matmuls.iter().map(|m| m.inputs[1]).collect();
589                rw.ensure_mapped(&graph, std::slice::from_ref(&group.input_id));
590                rw.ensure_mapped(&graph, &weight_ids);
591
592                let w0_shape = graph.shape(weight_ids[0]);
593                let k = w0_shape.dim(0).unwrap_static();
594                let ns: Vec<usize> = weight_ids
595                    .iter()
596                    .map(|&w| graph.shape(w).dim(1).unwrap_static())
597                    .collect();
598                let combined_n: usize = ns.iter().sum();
599
600                let concat_shape = Shape::new(&[k, combined_n], w0_shape.dtype());
601                let concat_id = rw.add_fused(Op::Concat { axis: 1 }, &weight_ids, concat_shape);
602
603                let out_rank = matmuls[0].shape.rank();
604                let mut mm_dims: Vec<Dim> =
605                    (0..out_rank).map(|i| matmuls[0].shape.dim(i)).collect();
606                mm_dims[out_rank - 1] = Dim::Static(combined_n);
607                let mm_shape = Shape::from_dims(&mm_dims, matmuls[0].shape.dtype());
608                let mm_id = rw.new_graph.add_node(
609                    Op::MatMul,
610                    vec![rw.map(group.input_id), concat_id],
611                    mm_shape,
612                );
613
614                let mut start = 0usize;
615                for (mm, &n) in matmuls.iter().zip(&ns) {
616                    let narrow = rw.new_graph.add_node(
617                        Op::Narrow {
618                            axis: out_rank - 1,
619                            start,
620                            len: n,
621                        },
622                        vec![mm_id],
623                        mm.shape.clone(),
624                    );
625                    rw.replace(mm.id, narrow);
626                    start += n;
627                }
628                continue;
629            }
630
631            rw.copy_node(node);
632        }
633
634        rw.finish(&graph.outputs)
635    }
636}
637
638// ── Pass 4: Detect SwiGLU pattern → FusedSwiGLU ────────────────────────
639
640/// Detects the post-`FuseSharedInputMatMul` SwiGLU pattern and replaces it
641/// with a single `Op::FusedSwiGLU` node consuming the concatenated matmul.
642///
643/// Pattern (after `FuseSharedInputMatMul` has fused fc11+fc12 into one mm):
644///   %cat   = matmul(%x, concat(%fc11_w, %fc12_w))   ; shape [..., 2N]
645///   %up    = narrow(%cat, axis=-1, 0, N)            ; shape [..., N]
646///   %gate  = narrow(%cat, axis=-1, N, N)            ; shape [..., N]
647///   %silu  = silu(%gate)
648///   %out   = mul(%up, %silu)
649///
650/// Becomes:
651///   %out   = fused_swiglu(%cat)
652///
653/// Saves three kernel launches (two narrows + silu + mul → one kernel) and
654/// keeps up/gate resident in registers.
655///
656/// Single-use guard: only fuses when each intermediate (narrow, narrow, silu)
657/// has exactly one consumer. The mul may have any number of consumers.
658pub struct FuseSwiGLU;
659
660impl Pass for FuseSwiGLU {
661    fn name(&self) -> &str {
662        "fuse_swiglu"
663    }
664
665    fn run(&self, graph: Graph) -> Graph {
666        // Scan for Mul nodes whose two inputs match the SwiGLU pattern.
667        // Collect rewrites first, then rebuild.
668        // up_narrow_id / silu_id / gate_narrow_id are kept for pattern-shape
669        // self-documentation even though only the rewrite path reads
670        // mul_id / cat_id / out_n.
671        #[allow(dead_code)]
672        struct Match {
673            mul_id: NodeId,
674            up_narrow_id: NodeId,
675            silu_id: NodeId,
676            gate_narrow_id: NodeId,
677            cat_id: NodeId,
678            out_n: usize,
679            gate_first: bool,
680        }
681
682        let mut matches: Vec<Match> = Vec::new();
683        let mut consumed: HashMap<NodeId, ()> = HashMap::new();
684
685        for node in graph.nodes() {
686            // Looking for: mul(narrow(cat, 0, n), silu(narrow(cat, n, n)))
687            //   — or symmetrically with up/gate swapped.
688            if !matches!(node.op, Op::Binary(BinaryOp::Mul)) {
689                continue;
690            }
691            let lhs_id = node.inputs[0];
692            let rhs_id = node.inputs[1];
693            let lhs = graph.node(lhs_id);
694            let rhs = graph.node(rhs_id);
695
696            // Decide which side is silu(gate) — the silu branch.
697            let (up_narrow, silu_id, silu_node) =
698                if matches!(rhs.op, Op::Activation(Activation::Silu)) {
699                    (lhs, rhs_id, rhs)
700                } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
701                    (rhs, lhs_id, lhs)
702                } else {
703                    continue;
704                };
705
706            // up side must be a Narrow.
707            let (up_axis, up_start, up_len) = match &up_narrow.op {
708                Op::Narrow { axis, start, len } => (*axis, *start, *len),
709                _ => continue,
710            };
711            // silu input must be a Narrow.
712            let gate_narrow_id = silu_node.inputs[0];
713            let gate_narrow = graph.node(gate_narrow_id);
714            let (g_axis, g_start, g_len) = match &gate_narrow.op {
715                Op::Narrow { axis, start, len } => (*axis, *start, *len),
716                _ => continue,
717            };
718
719            // Both narrows must come from the same source on the same axis,
720            // covering the two halves: (0..N) and (N..2N).
721            if up_narrow.inputs[0] != gate_narrow.inputs[0] {
722                continue;
723            }
724            if up_axis != g_axis {
725                continue;
726            }
727            if up_len != g_len {
728                continue;
729            }
730            let n = up_len;
731            // Canonical: up @ 0, gate @ N. Swapped (gate-first builders): gate @ 0, up @ N.
732            let gate_first = up_start == n && g_start == 0;
733            if !(gate_first || (up_start == 0 && g_start == n)) {
734                continue;
735            }
736
737            // Single-use checks: narrows feed only into silu+mul, silu feeds
738            // only into mul. The cat itself can have arbitrary other users.
739            if graph.use_count(up_narrow.id) != 1 {
740                continue;
741            }
742            if graph.use_count(gate_narrow_id) != 1 {
743                continue;
744            }
745            if graph.use_count(silu_id) != 1 {
746                continue;
747            }
748
749            matches.push(Match {
750                mul_id: node.id,
751                up_narrow_id: up_narrow.id,
752                silu_id,
753                gate_narrow_id,
754                cat_id: up_narrow.inputs[0],
755                out_n: n,
756                gate_first,
757            });
758            consumed.insert(up_narrow.id, ());
759            consumed.insert(gate_narrow_id, ());
760            consumed.insert(silu_id, ());
761        }
762
763        if matches.is_empty() {
764            return graph;
765        }
766
767        // Rebuild graph, replacing matched mul nodes with FusedSwiGLU.
768        let mut rw = Rewriter::new(&graph.name);
769        let match_by_mul: HashMap<NodeId, &Match> = matches.iter().map(|m| (m.mul_id, m)).collect();
770
771        for node in graph.nodes() {
772            if consumed.contains_key(&node.id) {
773                continue;
774            }
775
776            if let Some(m) = match_by_mul.get(&node.id) {
777                // Output shape = mul's output shape (= [..., N]).
778                let out_shape = node.shape.clone();
779                debug_assert_eq!(
780                    out_shape.dim(out_shape.rank() - 1).unwrap_static(),
781                    m.out_n,
782                    "FuseSwiGLU: output last dim should be N"
783                );
784                let fused_id = rw.add_fused(
785                    Op::FusedSwiGLU {
786                        cast_to: None,
787                        gate_first: m.gate_first,
788                    },
789                    &[m.cat_id],
790                    out_shape,
791                );
792                rw.replace(node.id, fused_id);
793                continue;
794            }
795
796            rw.copy_node(node);
797        }
798
799        rw.finish(&graph.outputs)
800    }
801}
802
803// ── Pass 5: Fuse Attention Block (QKV → SDPA → OutProj) ────────────────
804
/// Fuses `FusedMatMulBiasAct(QKV) → narrow(Q,K,V) → attention → FusedMatMulBiasAct(out)`
/// into a single FusedAttentionBlock when batch*seq is small.
///
/// Only the bias-carrying, RoPE-free form is matched: the emitted op always
/// carries `has_bias: true, has_rope: false`. Q/K/V must be `Narrow`s taken
/// directly from the QKV projection, so any op (e.g. a rotary embedding)
/// between the narrows and the attention prevents the match.
///
/// The size gate inspects graph inputs. Fusion is attempted when any `Input`
/// of rank ≥ 2 has a static `dim(0) * dim(1) ≤ threshold`. For such small
/// inputs (batch*seq ≤ 64 by default), intermediate tensors fit in L1 cache,
/// which makes a monolithic kernel faster than separate BLAS calls.
///
/// Threshold is configurable via `RLX_FUSE_ATTN_THRESHOLD` (default: 64).
pub struct FuseAttentionBlock;
814
815impl FuseAttentionBlock {
816    /// Check if the graph has small enough inputs to benefit from fusion.
817    ///
818    /// Returns `true` when any 2-D+ input has `dim(0) * dim(1) ≤ threshold`,
819    /// where `threshold` defaults to 64 (overridable via
820    /// `RLX_FUSE_ATTN_THRESHOLD`). The cutoff matches the L1-cache budget for
821    /// keeping Q/K/V resident on CPU and reflects the dispatch-overhead
822    /// crossover for small-batch BERT-family encoders on GPU backends.
823    fn should_fuse(graph: &Graph) -> bool {
824        let threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
825            .and_then(|v| v.parse().ok())
826            .unwrap_or(64);
827        for node in graph.nodes() {
828            if let Op::Input { .. } = &node.op
829                && node.shape.rank() >= 2
830            {
831                let d0 = node.shape.dim(0);
832                let d1 = node.shape.dim(1);
833                if d0.is_static() && d1.is_static() {
834                    let b = d0.unwrap_static();
835                    let s = d1.unwrap_static();
836                    if b * s <= threshold {
837                        return true;
838                    }
839                }
840            }
841        }
842        false
843    }
844}
845
846/// Match a single producer node id that produces a tensor consumed by `narrow`.
847fn narrow_parent(node: &Node) -> Option<(NodeId, usize, usize, usize)> {
848    match &node.op {
849        Op::Narrow { axis, start, len } => Some((node.inputs[0], *axis, *start, *len)),
850        _ => None,
851    }
852}
853
854/// Match `FusedMatMulBiasAct{activation: None}` and return its (input, weight, bias) tuple.
855fn fused_mm_bias_none(node: &Node) -> Option<(NodeId, NodeId, NodeId)> {
856    if let Op::FusedMatMulBiasAct { activation: None } = &node.op
857        && node.inputs.len() == 3
858    {
859        return Some((node.inputs[0], node.inputs[1], node.inputs[2]));
860    }
861    None
862}
863
864impl Pass for FuseAttentionBlock {
865    fn name(&self) -> &str {
866        "fuse_attention_block"
867    }
868
869    fn run(&self, graph: Graph) -> Graph {
870        // Bail when graph input shape is too large to benefit (the L1-resident
871        // / single-dispatch win disappears once Q/K/V no longer fit on-chip).
872        if !Self::should_fuse(&graph) {
873            return graph;
874        }
875
876        // We rewrite the chain
877        //   hidden ─ FusedMatMulBiasAct(qkv_w, qkv_b) ─ narrow×3 ─ Attention(mask) ─ FusedMatMulBiasAct(out_w, out_b)
878        // into a single `Op::FusedAttentionBlock { has_bias: true, has_rope: false }`.
879        //
880        // Pattern preconditions:
881        //   * QKV producer's only consumers are the three narrows (and not a graph
882        //     output) — otherwise we'd duplicate compute on un-fuse.
883        //   * Each narrow has exactly one consumer (the attention).
884        //   * The attention has `MaskKind::Custom` (caller-supplied mask tensor).
885        //   * The attention's only consumer is the OutProj `FusedMatMulBiasAct`.
886        //   * The OutProj is not a graph output of an *intermediate* block (i.e.
887        //     fusing it is safe — its result is the layer's actual output).
888        //
889        // When any precondition fails we fall back to copying the chain through.
890
891        let mut is_output: HashMap<NodeId, ()> = HashMap::new();
892        for &oid in &graph.outputs {
893            is_output.insert(oid, ());
894        }
895
896        // Pre-scan: for each Attention with Custom mask, decide whether the
897        // surrounding chain matches. If yes, record the IDs that get folded away.
898        struct Match {
899            attn_id: NodeId,
900            qkv_mm_id: NodeId,
901            out_mm_id: NodeId,
902            narrows: [NodeId; 3],
903            hidden_id: NodeId,
904            qkv_w: NodeId,
905            qkv_b: NodeId,
906            out_w: NodeId,
907            out_b: NodeId,
908            mask: NodeId,
909            num_heads: usize,
910            head_dim: usize,
911            out_shape: Shape,
912        }
913        let mut matches: Vec<Match> = Vec::new();
914        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
915
916        for node in graph.nodes() {
917            let Op::Attention {
918                num_heads,
919                head_dim,
920                mask_kind,
921                score_scale,
922                attn_logit_softcap,
923            } = &node.op
924            else {
925                continue;
926            };
927            // Only the BERT-style mask form (caller-supplied [B, S] tensor),
928            // no score scale tweaks, no soft-cap.
929            if !matches!(mask_kind, MaskKind::Custom)
930                || score_scale.is_some()
931                || attn_logit_softcap.is_some()
932                || node.inputs.len() != 4
933            {
934                continue;
935            }
936            let (q, k, v, mask) = (
937                node.inputs[0],
938                node.inputs[1],
939                node.inputs[2],
940                node.inputs[3],
941            );
942
943            // All three of Q, K, V must be Narrows on the same parent at
944            // start=0,h,2h with len=h on the last (innermost) axis.
945            let qn = graph.node(q);
946            let kn = graph.node(k);
947            let vn = graph.node(v);
948            let (qp, q_axis, q_start, q_len) = match narrow_parent(qn) {
949                Some(p) => p,
950                None => continue,
951            };
952            let (kp, k_axis, k_start, k_len) = match narrow_parent(kn) {
953                Some(p) => p,
954                None => continue,
955            };
956            let (vp, v_axis, v_start, v_len) = match narrow_parent(vn) {
957                Some(p) => p,
958                None => continue,
959            };
960            if qp != kp || kp != vp {
961                continue;
962            }
963            let h = num_heads * head_dim;
964            let parent_rank = graph.node(qp).shape.rank();
965            let last_ax = parent_rank.saturating_sub(1);
966            if q_axis != last_ax || k_axis != last_ax || v_axis != last_ax {
967                continue;
968            }
969            if q_len != h || k_len != h || v_len != h {
970                continue;
971            }
972            if q_start != 0 || k_start != h || v_start != 2 * h {
973                continue;
974            }
975            // Narrows must be single-consumer to be safely consumed.
976            if graph.use_count(q) != 1
977                || graph.use_count(k) != 1
978                || graph.use_count(v) != 1
979                || is_output.contains_key(&q)
980                || is_output.contains_key(&k)
981                || is_output.contains_key(&v)
982            {
983                continue;
984            }
985
986            // Parent must be FusedMatMulBiasAct (post-FuseMatMulBiasAct shape).
987            let qkv_mm_node = graph.node(qp);
988            let (hidden_id, qkv_w, qkv_b) = match fused_mm_bias_none(qkv_mm_node) {
989                Some(t) => t,
990                None => continue,
991            };
992            // The QKV MM must have exactly the three narrows as consumers and
993            // must not be a graph output itself.
994            if graph.use_count(qp) != 3 || is_output.contains_key(&qp) {
995                continue;
996            }
997
998            // Find the OutProj consumer of the Attention.
999            if graph.use_count(node.id) != 1 || is_output.contains_key(&node.id) {
1000                continue;
1001            }
1002            let out_consumer_id = match graph
1003                .nodes()
1004                .iter()
1005                .find(|n| n.inputs.contains(&node.id))
1006                .map(|n| n.id)
1007            {
1008                Some(id) => id,
1009                None => continue,
1010            };
1011            let out_mm_node = graph.node(out_consumer_id);
1012            let (out_in, out_w, out_b) = match fused_mm_bias_none(out_mm_node) {
1013                Some(t) if t.0 == node.id => t,
1014                _ => continue,
1015            };
1016            let _ = out_in;
1017
1018            // All checks passed — record the match.
1019            matches.push(Match {
1020                attn_id: node.id,
1021                qkv_mm_id: qp,
1022                out_mm_id: out_consumer_id,
1023                narrows: [q, k, v],
1024                hidden_id,
1025                qkv_w,
1026                qkv_b,
1027                out_w,
1028                out_b,
1029                mask,
1030                num_heads: *num_heads,
1031                head_dim: *head_dim,
1032                out_shape: out_mm_node.shape.clone(),
1033            });
1034            fused_away.insert(qp, ());
1035            fused_away.insert(q, ());
1036            fused_away.insert(k, ());
1037            fused_away.insert(v, ());
1038            fused_away.insert(node.id, ());
1039            fused_away.insert(out_consumer_id, ());
1040        }
1041
1042        if matches.is_empty() {
1043            return graph;
1044        }
1045
1046        // Index matches by the out-projection node id so we can swap it in-place.
1047        let mut by_out: HashMap<NodeId, &Match> = HashMap::new();
1048        for m in &matches {
1049            by_out.insert(m.out_mm_id, m);
1050        }
1051
1052        let mut rw = Rewriter::new(&graph.name);
1053        for node in graph.nodes() {
1054            if fused_away.contains_key(&node.id) {
1055                if let Some(m) = by_out.get(&node.id) {
1056                    // Make sure all referenced inputs are already in the new graph.
1057                    rw.ensure_mapped(
1058                        &graph,
1059                        &[m.hidden_id, m.qkv_w, m.out_w, m.mask, m.qkv_b, m.out_b],
1060                    );
1061                    let fused_id = rw.add_fused(
1062                        Op::FusedAttentionBlock {
1063                            num_heads: m.num_heads,
1064                            head_dim: m.head_dim,
1065                            has_bias: true,
1066                            has_rope: false,
1067                        },
1068                        &[m.hidden_id, m.qkv_w, m.out_w, m.mask, m.qkv_b, m.out_b],
1069                        m.out_shape.clone(),
1070                    );
1071                    // Wire every old chain node to the new fused id so any
1072                    // downstream consumer (residual add, LN, etc.) picks it up.
1073                    rw.replace(m.qkv_mm_id, fused_id);
1074                    rw.replace(m.narrows[0], fused_id);
1075                    rw.replace(m.narrows[1], fused_id);
1076                    rw.replace(m.narrows[2], fused_id);
1077                    rw.replace(m.attn_id, fused_id);
1078                    rw.replace(node.id, fused_id);
1079                }
1080                continue;
1081            }
1082            rw.copy_node(node);
1083        }
1084        rw.finish(&graph.outputs)
1085    }
1086}
1087
1088// ── Pass 5b: Full BERT layer → FusedTransformerLayer ────────────────────
1089
/// Fuses an entire BERT-style transformer layer (attention block + residual+LN +
/// FFN + residual+LN) into one [`Op::FusedTransformerLayer`] node.
///
/// Pattern (after [`FuseMatMulBiasAct`], [`FuseResidualLN`], and
/// [`FuseAttentionBlock`] have run — order matters):
///
/// ```text
///   skip ──┬─→ FusedAttentionBlock(qkv_w, out_w, mask, qkv_b, out_b) ─→ attn_out
///          └─→ FusedResidualLN(attn_out, skip, ln1_g, ln1_b) ─→ h1
///                                                                ├─→ FusedMatMulBiasAct(fc1_w, fc1_b, GeLU) ─→ ffn_int
///                                                                │                                              ↓
///                                                                │           FusedMatMulBiasAct(fc2_w, fc2_b, None) ─→ ffn_out
///                                                                └────────────────────→ FusedResidualLN(ffn_out, h1, ln2_g, ln2_b) ─→ out
/// ```
///
/// The FFN.1 activation is copied from the matched node into the fused op's
/// `activation` field (GeLU for BERT); the matcher itself accepts any
/// activation there. Every absorbed intermediate must stay inside the layer:
/// `attn_out`, `ffn_int` and `ffn_out` have exactly one consumer, `h1` has
/// exactly two (FFN.1 and the second residual+LN), and none of them may be a
/// graph output. Otherwise the layer is left unfused.
///
/// All five nodes collapse into a single `FusedTransformerLayer { num_heads,
/// head_dim, intermediate_size, eps1, eps2, activation, has_bias: true }`
/// with the 14-input layout consumed by `rlx-mlx`'s lowering at
/// `rlx-mlx/src/lower.rs:1528`:
/// `[hidden, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask]`.
///
/// Threshold is the same as [`FuseAttentionBlock`] (`RLX_FUSE_ATTN_THRESHOLD`,
/// default 64). Backends that don't natively support `FusedTransformerLayer`
/// un-fuse it back to primitives at compile time; backends that do (MLX) can
/// emit one monolithic kernel per layer.
pub struct FuseTransformerLayer;
1116
1117impl FuseTransformerLayer {
1118    fn should_fuse(graph: &Graph) -> bool {
1119        // Same gate as FuseAttentionBlock — single-source of truth for
1120        // "this graph is small enough for L1-resident block fusion".
1121        FuseAttentionBlock::should_fuse(graph)
1122    }
1123}
1124
1125/// Match `FusedResidualLN { has_bias: false }` and return `(x, residual, gamma, beta, eps)`.
1126fn fused_residual_ln_no_bias(node: &Node) -> Option<(NodeId, NodeId, NodeId, NodeId, f32)> {
1127    if let Op::FusedResidualLN {
1128        has_bias: false,
1129        eps,
1130    } = &node.op
1131        && node.inputs.len() == 4
1132    {
1133        return Some((
1134            node.inputs[0],
1135            node.inputs[1],
1136            node.inputs[2],
1137            node.inputs[3],
1138            *eps,
1139        ));
1140    }
1141    None
1142}
1143
1144/// Match `FusedMatMulBiasAct { activation: Some(a) }` and return `(input, weight, bias, activation)`.
1145fn fused_mm_bias_act(node: &Node) -> Option<(NodeId, NodeId, NodeId, Activation)> {
1146    if let Op::FusedMatMulBiasAct {
1147        activation: Some(a),
1148    } = &node.op
1149        && node.inputs.len() == 3
1150    {
1151        return Some((node.inputs[0], node.inputs[1], node.inputs[2], *a));
1152    }
1153    None
1154}
1155
1156/// Match `FusedAttentionBlock { has_bias: true, has_rope: false }` BERT shape.
1157fn fused_attn_block_bert(
1158    node: &Node,
1159) -> Option<(usize, usize, NodeId, NodeId, NodeId, NodeId, NodeId, NodeId)> {
1160    if let Op::FusedAttentionBlock {
1161        num_heads,
1162        head_dim,
1163        has_bias: true,
1164        has_rope: false,
1165    } = &node.op
1166        && node.inputs.len() == 6
1167    {
1168        // [hidden, qkv_w, out_w, mask, qkv_b, out_b]
1169        return Some((
1170            *num_heads,
1171            *head_dim,
1172            node.inputs[0],
1173            node.inputs[1],
1174            node.inputs[2],
1175            node.inputs[3],
1176            node.inputs[4],
1177            node.inputs[5],
1178        ));
1179    }
1180    None
1181}
1182
1183impl Pass for FuseTransformerLayer {
1184    fn name(&self) -> &str {
1185        "fuse_transformer_layer"
1186    }
1187
1188    fn run(&self, graph: Graph) -> Graph {
1189        if !Self::should_fuse(&graph) {
1190            return graph;
1191        }
1192
1193        // Graph-output guard: any intermediate we'd absorb must not be an
1194        // explicit output, otherwise a downstream caller would see the
1195        // collapsed result instead of the per-stage tensor it expects.
1196        let mut is_output: HashMap<NodeId, ()> = HashMap::new();
1197        for &oid in &graph.outputs {
1198            is_output.insert(oid, ());
1199        }
1200
1201        struct LayerMatch {
1202            attn_id: NodeId,
1203            ln1_id: NodeId,
1204            fc1_id: NodeId,
1205            fc2_id: NodeId,
1206            ln2_id: NodeId,
1207            inputs: [NodeId; 14],
1208            num_heads: usize,
1209            head_dim: usize,
1210            intermediate_size: usize,
1211            eps1: f32,
1212            eps2: f32,
1213            activation: Activation,
1214            out_shape: Shape,
1215        }
1216
1217        let mut matches: Vec<LayerMatch> = Vec::new();
1218        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
1219
1220        for node in graph.nodes() {
1221            // Anchor on each FusedAttentionBlock — every BERT layer starts here.
1222            let Some((num_heads, head_dim, hidden_id, qkv_w, out_w, mask, qkv_b, out_b)) =
1223                fused_attn_block_bert(node)
1224            else {
1225                continue;
1226            };
1227            let attn_id = node.id;
1228            // Attention's only consumer must be the post-attn FusedResidualLN.
1229            if graph.use_count(attn_id) != 1 || is_output.contains_key(&attn_id) {
1230                continue;
1231            }
1232            let ln1_id = match graph
1233                .nodes()
1234                .iter()
1235                .find(|n| n.inputs.contains(&attn_id))
1236                .map(|n| n.id)
1237            {
1238                Some(id) => id,
1239                None => continue,
1240            };
1241            let ln1_node = graph.node(ln1_id);
1242            let Some((ln1_x, ln1_res, ln1_g, ln1_b, eps1)) = fused_residual_ln_no_bias(ln1_node)
1243            else {
1244                continue;
1245            };
1246            // Order in the residual+LN: x = attn_out, residual = skip (= hidden).
1247            if ln1_x != attn_id || ln1_res != hidden_id {
1248                continue;
1249            }
1250            // h1 must have exactly 2 consumers (FFN.1 input AND ln2 residual).
1251            if graph.use_count(ln1_id) != 2 || is_output.contains_key(&ln1_id) {
1252                continue;
1253            }
1254
1255            // Find FFN.1: FusedMatMulBiasAct(h1, fc1_w, fc1_b) with GeLU.
1256            let mut fc1_candidate: Option<NodeId> = None;
1257            let mut ln2_candidate: Option<NodeId> = None;
1258            for cn in graph.nodes() {
1259                if !cn.inputs.contains(&ln1_id) {
1260                    continue;
1261                }
1262                if fused_mm_bias_act(cn).is_some() && cn.inputs[0] == ln1_id {
1263                    fc1_candidate = Some(cn.id);
1264                } else if fused_residual_ln_no_bias(cn).is_some() && cn.inputs[1] == ln1_id {
1265                    ln2_candidate = Some(cn.id);
1266                }
1267            }
1268            let (Some(fc1_id), Some(ln2_id)) = (fc1_candidate, ln2_candidate) else {
1269                continue;
1270            };
1271            let fc1_node = graph.node(fc1_id);
1272            let Some((_, fc1_w, fc1_b, activation)) = fused_mm_bias_act(fc1_node) else {
1273                continue;
1274            };
1275            // FFN.1 output → FFN.2 (single consumer).
1276            if graph.use_count(fc1_id) != 1 || is_output.contains_key(&fc1_id) {
1277                continue;
1278            }
1279            let fc2_id = match graph
1280                .nodes()
1281                .iter()
1282                .find(|n| n.inputs.contains(&fc1_id))
1283                .map(|n| n.id)
1284            {
1285                Some(id) => id,
1286                None => continue,
1287            };
1288            let fc2_node = graph.node(fc2_id);
1289            // FFN.2 must be FusedMatMulBiasAct with activation=None.
1290            let Some((fc2_in, fc2_w, fc2_b)) = fused_mm_bias_none(fc2_node) else {
1291                continue;
1292            };
1293            if fc2_in != fc1_id {
1294                continue;
1295            }
1296            if graph.use_count(fc2_id) != 1 || is_output.contains_key(&fc2_id) {
1297                continue;
1298            }
1299            // Final residual+LN: x = ffn_out, residual = h1, gamma/beta + eps2.
1300            let ln2_node = graph.node(ln2_id);
1301            let Some((ln2_x, ln2_res, ln2_g, ln2_b, eps2)) = fused_residual_ln_no_bias(ln2_node)
1302            else {
1303                continue;
1304            };
1305            if ln2_x != fc2_id || ln2_res != ln1_id {
1306                continue;
1307            }
1308            // intermediate_size from fc1_w (`[H, intermediate_size]`).
1309            let intermediate_size = {
1310                let s = &graph.node(fc1_w).shape;
1311                if s.rank() != 2 {
1312                    continue;
1313                }
1314                let d = s.dim(s.rank() - 1);
1315                if !d.is_static() {
1316                    continue;
1317                }
1318                d.unwrap_static()
1319            };
1320
1321            matches.push(LayerMatch {
1322                attn_id,
1323                ln1_id,
1324                fc1_id,
1325                fc2_id,
1326                ln2_id,
1327                inputs: [
1328                    hidden_id, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w,
1329                    fc2_b, ln2_g, ln2_b, mask,
1330                ],
1331                num_heads,
1332                head_dim,
1333                intermediate_size,
1334                eps1,
1335                eps2,
1336                activation,
1337                out_shape: ln2_node.shape.clone(),
1338            });
1339            fused_away.insert(attn_id, ());
1340            fused_away.insert(ln1_id, ());
1341            fused_away.insert(fc1_id, ());
1342            fused_away.insert(fc2_id, ());
1343            fused_away.insert(ln2_id, ());
1344        }
1345
1346        if matches.is_empty() {
1347            return graph;
1348        }
1349
1350        // Index by ln2 (the layer's terminal node) so we know when to emit.
1351        let mut by_terminal: HashMap<NodeId, &LayerMatch> = HashMap::new();
1352        for m in &matches {
1353            by_terminal.insert(m.ln2_id, m);
1354        }
1355
1356        let mut rw = Rewriter::new(&graph.name);
1357        for node in graph.nodes() {
1358            if fused_away.contains_key(&node.id) {
1359                if let Some(m) = by_terminal.get(&node.id) {
1360                    rw.ensure_mapped(&graph, &m.inputs);
1361                    let fused_id = rw.add_fused(
1362                        Op::FusedTransformerLayer {
1363                            num_heads: m.num_heads,
1364                            head_dim: m.head_dim,
1365                            intermediate_size: m.intermediate_size,
1366                            eps1: m.eps1,
1367                            eps2: m.eps2,
1368                            activation: m.activation,
1369                            has_bias: true,
1370                        },
1371                        &m.inputs,
1372                        m.out_shape.clone(),
1373                    );
1374                    rw.replace(m.attn_id, fused_id);
1375                    rw.replace(m.ln1_id, fused_id);
1376                    rw.replace(m.fc1_id, fused_id);
1377                    rw.replace(m.fc2_id, fused_id);
1378                    rw.replace(node.id, fused_id);
1379                }
1380                continue;
1381            }
1382            rw.copy_node(node);
1383        }
1384        rw.finish(&graph.outputs)
1385    }
1386}
1387
1388// ── PLAN L2: MarkElementwiseRegions ─────────────────────────────────────
1389//
// Walk the graph and collapse maximal chains of element-wise ops
// (Activation / Cast / Binary / Compare / Where) into a single
// `Op::ElementwiseRegion`. Conditions for inclusion in a chain:
//   - Op is one of Activation / Cast / Binary / Compare / Where. A Cast
//     qualifies only when its destination dtype equals its input dtype,
//     because the chain kernel performs no actual dtype conversion.
//   - Every input's element count either equals the output's or evenly
//     divides it (trailing-shape broadcast, read modulo the input size).
1398//   - Every intermediate (chain-internal) value has exactly one
1399//     consumer in the *whole* graph. Multi-consumer values must
1400//     materialize.
1401// The chain start can read graph-level inputs / params / earlier-fused
1402// nodes; the chain end is the last single-consumer or terminal node.
1403// This is the simplest correct cut — N-ary chain fusion replaces the
1404// pairwise `fuse_elementwise_chains` pattern in each backend with one
1405// IR-level pass + a single backend kernel. See PLAN L2.
1406//
1407// Fusion boundaries: chains do not extend across inputs whose producer
1408// satisfies [`rlx_ir::Op::is_fusion_boundary`] (BLAS, Gaussian splat, …).
1409
/// Graph pass (`mark_elementwise_regions`) that collapses chains of
/// single-consumer element-wise ops into `Op::ElementwiseRegion` nodes.
pub struct MarkElementwiseRegions;
1411
1412impl Pass for MarkElementwiseRegions {
1413    fn name(&self) -> &str {
1414        "mark_elementwise_regions"
1415    }
1416
1417    fn run(&self, graph: Graph) -> Graph {
1418        // Tally consumer counts for every node id.
1419        let mut consumers: HashMap<NodeId, usize> = HashMap::new();
1420        for node in graph.nodes() {
1421            for &input in &node.inputs {
1422                *consumers.entry(input).or_insert(0) += 1;
1423            }
1424        }
1425        for &out in &graph.outputs {
1426            *consumers.entry(out).or_insert(0) += 1;
1427        }
1428
1429        // Predicate: does this op qualify for chain inclusion?
1430        let chain_eligible = |op: &Op| -> bool {
1431            matches!(
1432                op,
1433                Op::Activation(_) | Op::Cast { .. } | Op::Binary(_) | Op::Compare(_) | Op::Where
1434            )
1435        };
1436
1437        // Per-node refinement: a `Cast { to }` only qualifies when the
1438        // destination dtype matches the operand's dtype. The chain
1439        // kernel runs entirely in f32 register scratch and writes the
1440        // tail back to the output node's arena slot — which is sized
1441        // for the tail dtype. A cross-dtype Cast inside the chain would
1442        // lose precision (no actual conversion happens in scratch) AND
1443        // mis-size the final write (an F16 output slot is half the
1444        // bytes of f32). Same-dtype Casts are trivially propagated.
1445        let chain_step_safe = |graph: &Graph, node: &rlx_ir::Node| -> bool {
1446            match &node.op {
1447                Op::Cast { to } => {
1448                    let in_dt = graph.shape(node.inputs[0]).dtype();
1449                    *to == in_dt
1450                }
1451                _ => true,
1452            }
1453        };
1454
1455        // For each node, compute which "chain root" it belongs to.
1456        // A chain consists of a sequence of single-consumer chain-eligible
1457        // nodes leading to a chain "tail" (last node before a multi-consumer
1458        // or non-eligible boundary). We assign each node a `region_id`
1459        // (= the tail's NodeId) iff it's part of a region with ≥2 ops.
1460        // Walk in topological (forward) order; for each chain-eligible
1461        // node whose every input is either non-region OR a single-consumer
1462        // region member, extend its parent chain.
1463        let mut region_of: HashMap<NodeId, NodeId> = HashMap::new();
1464        let mut chain_step_idx: HashMap<NodeId, u32> = HashMap::new();
1465
1466        for node in graph.nodes() {
1467            if !chain_eligible(&node.op) {
1468                continue;
1469            }
1470            if !chain_step_safe(&graph, node) {
1471                continue;
1472            }
1473            // Each input must either match the output element count
1474            // exactly OR be a trailing-shape broadcast (its element
1475            // count divides the output's). The kernel reads
1476            // `arena[input_offs[i] + (gid % input_modulus[i])]` for
1477            // broadcast inputs; non-broadcast inputs leave the modulus
1478            // at 0 to skip the modulo.
1479            let out_shape = &node.shape;
1480            let out_elems = out_shape.num_elements();
1481            let shape_ok = node.inputs.iter().all(|id| {
1482                let in_elems = graph.shape(*id).num_elements();
1483                match (in_elems, out_elems) {
1484                    (Some(i), Some(o)) if i == o => true,
1485                    (Some(i), Some(o)) if i > 0 && o % i == 0 => true,
1486                    _ => false,
1487                }
1488            });
1489            if !shape_ok {
1490                continue;
1491            }
1492            // A chain extends an input's chain when the input is itself
1493            // chain-eligible AND has exactly one consumer (= this node).
1494            // If multiple inputs satisfy this, the chains must be the same
1495            // (= they share a chain root); pick that root.
1496            let mut parent_root: Option<NodeId> = None;
1497            let mut all_inputs_single_consumer = true;
1498            for &input in &node.inputs {
1499                // BLAS / splat render ops are explicit fusion boundaries.
1500                if graph.node(input).op.is_fusion_boundary() {
1501                    parent_root = None;
1502                    all_inputs_single_consumer = false;
1503                    break;
1504                }
1505                if let Some(&root) = region_of.get(&input) {
1506                    if consumers.get(&input).copied() != Some(1) {
1507                        all_inputs_single_consumer = false;
1508                        break;
1509                    }
1510                    match parent_root {
1511                        None => parent_root = Some(root),
1512                        Some(r) if r == root => {}
1513                        Some(_) => {
1514                            parent_root = None;
1515                            all_inputs_single_consumer = false;
1516                            break;
1517                        }
1518                    }
1519                }
1520            }
1521            if !all_inputs_single_consumer {
1522                // Start a fresh chain rooted at this node.
1523                region_of.insert(node.id, node.id);
1524                chain_step_idx.insert(node.id, 0);
1525                continue;
1526            }
1527            let root = parent_root.unwrap_or(node.id);
1528            // step idx = max(parents' idx in same chain) + 1
1529            let next_idx = node
1530                .inputs
1531                .iter()
1532                .filter_map(|id| {
1533                    if region_of.get(id) == Some(&root) {
1534                        chain_step_idx.get(id).copied()
1535                    } else {
1536                        None
1537                    }
1538                })
1539                .max()
1540                .map(|m| m + 1)
1541                .unwrap_or(0);
1542            let limits = crate::limits::active_fusion_limits();
1543            if next_idx >= limits.max_elementwise_steps {
1544                region_of.insert(node.id, node.id);
1545                chain_step_idx.insert(node.id, 0);
1546                continue;
1547            }
1548            region_of.insert(node.id, root);
1549            chain_step_idx.insert(node.id, next_idx);
1550        }
1551
1552        // Group nodes by region_id; only regions with ≥2 nodes are worth fusing.
1553        // The "region tail" (= last node) becomes the new ElementwiseRegion node.
1554        let mut by_region: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
1555        for node in graph.nodes() {
1556            if let Some(&root) = region_of.get(&node.id) {
1557                by_region.entry(root).or_default().push(node.id);
1558            }
1559        }
1560
1561        // Each region's "tail" is the node with the highest chain_step_idx.
1562        // For correctness, the tail must be the only node in the region with
1563        // a non-region or multi-consumer outflow — otherwise the region would
1564        // span past it. Skip regions where the tail isn't unique (= chain
1565        // forks internally).
1566        let mut tail_of_region: HashMap<NodeId, NodeId> = HashMap::new();
1567        for (root, members) in &by_region {
1568            if members.len() < 2 {
1569                continue;
1570            }
1571            let max_idx = members.iter().map(|id| chain_step_idx[id]).max().unwrap();
1572            let tails: Vec<_> = members
1573                .iter()
1574                .filter(|id| chain_step_idx[id] == max_idx)
1575                .collect();
1576            if tails.len() != 1 {
1577                continue;
1578            }
1579            tail_of_region.insert(*root, *tails[0]);
1580        }
1581
1582        // Drop "regions" that aren't worth fusing (size < 2 or non-unique tail).
1583        let by_region: HashMap<NodeId, Vec<NodeId>> = by_region
1584            .into_iter()
1585            .filter(|(root, _)| tail_of_region.contains_key(root))
1586            .collect();
1587
1588        if by_region.is_empty() {
1589            return graph;
1590        }
1591
1592        // Rewrite the graph: copy non-region nodes verbatim; for each region,
1593        // emit a single ElementwiseRegion at the tail's position (in topo order)
1594        // and replace each region member's NodeId in the id map with that.
1595        let mut rw = Rewriter::new(&graph.name);
1596        // Track region nodes already emitted (we emit at tail's topo position).
1597        let mut emitted_region: HashMap<NodeId, NodeId> = HashMap::new();
1598
1599        for node in graph.nodes() {
1600            if let Some(&root) = region_of.get(&node.id)
1601                && let Some(&tail) = tail_of_region.get(&root)
1602            {
1603                if emitted_region.contains_key(&root) {
1604                    // Member but tail already emitted (or not tail). Map to
1605                    // either the new region node (if tail) or to a sentinel
1606                    // we never look up directly. Internal members are not
1607                    // referenced after fusion (single-consumer guarantee),
1608                    // so we map them to the region node id for safety.
1609                    let region_new = emitted_region[&root];
1610                    rw.replace(node.id, region_new);
1611                    continue;
1612                }
1613                if node.id == tail {
1614                    // Sort region members in topological (= chain step) order.
1615                    let members = &by_region[&root];
1616                    let mut ordered: Vec<NodeId> = members.clone();
1617                    ordered.sort_by_key(|id| chain_step_idx[id]);
1618
1619                    // Collect external inputs (chain inputs that aren't members).
1620                    // SSA: each chain step refers to either an external input
1621                    // or a previous step. Build the chain.
1622                    let mut external_inputs: Vec<NodeId> = Vec::new();
1623                    let mut input_idx_of: HashMap<NodeId, u32> = HashMap::new();
1624                    let mut step_idx_of: HashMap<NodeId, u32> = HashMap::new();
1625                    for (i, member_id) in ordered.iter().enumerate() {
1626                        step_idx_of.insert(*member_id, i as u32);
1627                        let n = graph.node(*member_id);
1628                        for &inp in &n.inputs {
1629                            if !step_idx_of.contains_key(&inp) && !input_idx_of.contains_key(&inp) {
1630                                let idx = external_inputs.len() as u32;
1631                                input_idx_of.insert(inp, idx);
1632                                external_inputs.push(inp);
1633                            }
1634                        }
1635                    }
1636
1637                    let limits = crate::limits::active_fusion_limits();
1638                    if external_inputs.len() as u32 > limits.max_elementwise_inputs
1639                        || ordered.len() as u32 > limits.max_elementwise_steps
1640                    {
1641                        for &mid in &ordered {
1642                            rw.copy_node(graph.node(mid));
1643                        }
1644                        continue;
1645                    }
1646
1647                    let resolve = |id: NodeId| -> ChainOperand {
1648                        if let Some(&i) = input_idx_of.get(&id) {
1649                            ChainOperand::Input(i)
1650                        } else {
1651                            ChainOperand::Step(step_idx_of[&id])
1652                        }
1653                    };
1654                    let mut chain: Vec<ChainStep> = Vec::with_capacity(ordered.len());
1655                    for member_id in &ordered {
1656                        let n = graph.node(*member_id);
1657                        let step = match &n.op {
1658                            Op::Activation(a) => ChainStep::Activation(*a, resolve(n.inputs[0])),
1659                            Op::Cast { to } => ChainStep::Cast(*to, resolve(n.inputs[0])),
1660                            Op::Binary(op) => {
1661                                ChainStep::Binary(*op, resolve(n.inputs[0]), resolve(n.inputs[1]))
1662                            }
1663                            Op::Compare(op) => {
1664                                ChainStep::Compare(*op, resolve(n.inputs[0]), resolve(n.inputs[1]))
1665                            }
1666                            Op::Where => ChainStep::Where(
1667                                resolve(n.inputs[0]),
1668                                resolve(n.inputs[1]),
1669                                resolve(n.inputs[2]),
1670                            ),
1671                            _ => unreachable!("non-chain-eligible op in region"),
1672                        };
1673                        chain.push(step);
1674                    }
1675
1676                    // PLAN L2 quality: per-input broadcast metadata.
1677                    // `scalar_input_mask` is the fast-path bitfield
1678                    // (bit `i` set ⇒ input `i` is a single-element
1679                    // scalar). `input_modulus[i]` is the per-input
1680                    // element count: 0 means "no broadcast" (kernel
1681                    // reads gid directly), >0 means tile by modulo.
1682                    // Encoder enforces `out_elems % in_elems == 0`
1683                    // upstream so the modulo divides cleanly.
1684                    let mut scalar_input_mask: u32 = 0;
1685                    let mut input_modulus = [0u32; 16];
1686                    let region_shape_elems = graph.node(tail).shape.num_elements();
1687                    for (i, &ext) in external_inputs.iter().enumerate() {
1688                        if i >= 16 {
1689                            break;
1690                        }
1691                        let in_elems = graph.shape(ext).num_elements();
1692                        match (in_elems, region_shape_elems) {
1693                            (Some(1), Some(o)) if o != 1 => {
1694                                scalar_input_mask |= 1u32 << i;
1695                                input_modulus[i] = 1;
1696                            }
1697                            (Some(i_n), Some(o)) if i_n != o && i_n > 0 => {
1698                                input_modulus[i] = i_n as u32;
1699                            }
1700                            _ => { /* no broadcast: leave modulus 0 */ }
1701                        }
1702                    }
1703                    let region_new = rw.add_fused(
1704                        Op::ElementwiseRegion {
1705                            chain,
1706                            num_inputs: external_inputs.len() as u32,
1707                            scalar_input_mask,
1708                            input_modulus,
1709                            prologue: RegionPrologue::None,
1710                            prologue_input: 0,
1711                        },
1712                        &external_inputs,
1713                        graph.node(tail).shape.clone(),
1714                    );
1715                    emitted_region.insert(root, region_new);
1716                    rw.replace(node.id, region_new);
1717                    continue;
1718                } else {
1719                    // Region member but not tail; skip (will be replaced
1720                    // when the tail is processed).
1721                    rw.replace(node.id, NodeId(u32::MAX)); // sentinel
1722                    continue;
1723                }
1724            }
1725            rw.copy_node(node);
1726        }
1727
1728        // Final cleanup pass: any sentinel id_map entries get rewired to
1729        // their region's emitted node now that emission is done.
1730        // (Actually the order above means tails are processed in topo
1731        // order and members appear before tails in topo order, so by the
1732        // time a member's consumer is rewritten its id_map points to the
1733        // sentinel. Fix-up: walk again, rewrite sentinels.)
1734        // Simpler approach: process region members in second pass.
1735        // The current order processes tail last per region, so non-tail
1736        // members get sentinels. Their consumers are either other region
1737        // members (which we don't directly use the input from) or the
1738        // tail itself. Since the tail builds its own chain via members
1739        // directly from the original graph, the rewriter's id_map for
1740        // non-tail members is only consulted for the tail's input list —
1741        // which we resolve via `external_inputs` (already correctly
1742        // mapped via add_fused → map_inputs). So sentinels are safe.
1743
1744        rw.finish(&graph.outputs)
1745    }
1746}
1747
// ── PLAN L2 fallback: UnfuseElementwiseRegions ───────────────────────

/// Pass that decomposes `Op::ElementwiseRegion` nodes back into the atomic
/// ops recorded in their chain (Activation / Cast / Binary / Compare / Where).
///
/// Each chain step becomes its own node in the rewritten graph, and the last
/// step takes over the region node's id, so consumers of the region end up
/// wired to the tail of the decomposed chain. Backends without a native
/// region kernel run this before their own lowering: they keep the
/// correctness of L2's IR-level fusion (no op goes missing) without having to
/// implement region codegen. A graph containing no `ElementwiseRegion` node
/// passes through unchanged.
pub struct UnfuseElementwiseRegions {
    /// Whether regions that carry an FKL prologue are decomposed as well.
    ///
    /// `false` leaves prologue regions intact for native GPU region kernels;
    /// `true` (CPU) decomposes every region, turning a `ResizeNearest2x`
    /// prologue into a standalone op in front of the chain.
    pub unfuse_prologue: bool,
}

impl UnfuseElementwiseRegions {
    /// Preset for GPU backends (Metal / CUDA / wgpu): plain regions are
    /// decomposed, regions with a resize prologue are kept for the native
    /// region kernel.
    pub const FOR_GPU: UnfuseElementwiseRegions = Self { unfuse_prologue: false };

    /// Preset for CPU: there is no native region executor, so every region,
    /// prologue or not, is decomposed.
    pub const FOR_CPU: UnfuseElementwiseRegions = Self { unfuse_prologue: true };
}
1776
1777impl Pass for UnfuseElementwiseRegions {
1778    fn name(&self) -> &str {
1779        "unfuse_elementwise_regions"
1780    }
1781
1782    fn run(&self, graph: Graph) -> Graph {
1783        let any_region = graph
1784            .nodes()
1785            .iter()
1786            .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }));
1787        if !any_region {
1788            return graph;
1789        }
1790
1791        let mut rw = Rewriter::new(&graph.name);
1792        for node in graph.nodes() {
1793            if let Op::ElementwiseRegion {
1794                chain,
1795                num_inputs: _,
1796                scalar_input_mask: _,
1797                input_modulus: _,
1798                prologue,
1799                prologue_input: _,
1800            } = &node.op
1801            {
1802                if *prologue != RegionPrologue::None && !self.unfuse_prologue {
1803                    rw.copy_node(node);
1804                    continue;
1805                }
1806                let mut region_inputs: Vec<NodeId> =
1807                    node.inputs.iter().map(|id| rw.map(*id)).collect();
1808                if *prologue == RegionPrologue::ResizeNearest2x {
1809                    let in_shape = rw.new_graph.node(region_inputs[0]).shape.clone();
1810                    let out_shape = if in_shape.rank() == 4 {
1811                        Shape::new(
1812                            &[
1813                                in_shape.dim(0).unwrap_static(),
1814                                in_shape.dim(1).unwrap_static(),
1815                                in_shape.dim(2).unwrap_static() * 2,
1816                                in_shape.dim(3).unwrap_static() * 2,
1817                            ],
1818                            in_shape.dtype(),
1819                        )
1820                    } else {
1821                        node.shape.clone()
1822                    };
1823                    region_inputs[0] = rw.new_graph.add_node(
1824                        Op::ResizeNearest2x,
1825                        vec![region_inputs[0]],
1826                        out_shape,
1827                    );
1828                }
1829                let mut step_ids: Vec<NodeId> = Vec::with_capacity(chain.len());
1830                let region_shape = node.shape.clone();
1831                let region_dims: Vec<_> = region_shape.dims().to_vec();
1832                // Per-step result dtype, indexed by step position.
1833                // The chain may pass through Cast steps that change the
1834                // dtype mid-chain; using `region_shape.dtype()` blindly
1835                // would mis-tag intermediate Activation/Binary/Where
1836                // shapes. Track the dtype propagated through each step.
1837                let mut step_dtypes: Vec<rlx_ir::DType> = Vec::with_capacity(chain.len());
1838                let region_dtype = region_shape.dtype();
1839                let dtype_of = |op: &ChainOperand,
1840                                ins: &[NodeId],
1841                                step_dt: &[rlx_ir::DType],
1842                                rw: &Rewriter|
1843                 -> rlx_ir::DType {
1844                    match *op {
1845                        ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.dtype(),
1846                        ChainOperand::Step(i) => step_dt[i as usize],
1847                    }
1848                };
1849                // Shape of an operand in the rewritten graph. Critical
1850                // for broadcast inputs: a region whose final shape is
1851                // `[8, 1]` can still have a scalar operand at some
1852                // step; tagging that step with region_dims would lie
1853                // about its element count and trip the binary/activation
1854                // kernels (which size their reads/writes off the IR
1855                // shape, not the broadcast-aware semantics the L2
1856                // region kernel would have used). Use the actual node
1857                // shape so the unfused pipeline matches what each op
1858                // semantically produces.
1859                let shape_of = |op: &ChainOperand,
1860                                ins: &[NodeId],
1861                                step_ids: &[NodeId],
1862                                rw: &Rewriter|
1863                 -> Shape {
1864                    match *op {
1865                        ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.clone(),
1866                        ChainOperand::Step(i) => {
1867                            rw.new_graph.node(step_ids[i as usize]).shape.clone()
1868                        }
1869                    }
1870                };
1871                for step in chain {
1872                    let resolve = |op: &ChainOperand| -> NodeId {
1873                        match *op {
1874                            ChainOperand::Input(i) => region_inputs[i as usize],
1875                            ChainOperand::Step(i) => step_ids[i as usize],
1876                        }
1877                    };
1878                    let (new_id, dt) = match step {
1879                        ChainStep::Activation(a, src) => {
1880                            let s = resolve(src);
1881                            let dt = dtype_of(src, &region_inputs, &step_dtypes, &rw);
1882                            // Activation is element-wise: output shape
1883                            // == input shape (preserve broadcast-source
1884                            // shapes; do NOT promote to region_dims).
1885                            let src_shape = shape_of(src, &region_inputs, &step_ids, &rw);
1886                            let dims: Vec<_> = src_shape.dims().to_vec();
1887                            let shape = Shape::from_dims(&dims, dt);
1888                            (
1889                                rw.new_graph.add_node(Op::Activation(*a), vec![s], shape),
1890                                dt,
1891                            )
1892                        }
1893                        ChainStep::Cast(to, src) => {
1894                            let s = resolve(src);
1895                            let src_shape = shape_of(src, &region_inputs, &step_ids, &rw);
1896                            let dims: Vec<_> = src_shape.dims().to_vec();
1897                            let shape = Shape::from_dims(&dims, *to);
1898                            (
1899                                rw.new_graph.add_node(Op::Cast { to: *to }, vec![s], shape),
1900                                *to,
1901                            )
1902                        }
1903                        ChainStep::Binary(op, lhs, rhs) => {
1904                            let l = resolve(lhs);
1905                            let r = resolve(rhs);
1906                            let dt = dtype_of(lhs, &region_inputs, &step_dtypes, &rw);
1907                            // Binary: NumPy-style broadcast of operands.
1908                            let lhs_shape = shape_of(lhs, &region_inputs, &step_ids, &rw);
1909                            let rhs_shape = shape_of(rhs, &region_inputs, &step_ids, &rw);
1910                            let bcast = rlx_ir::shape::broadcast(&lhs_shape, &rhs_shape)
1911                                .unwrap_or_else(|e| {
1912                                    panic!(
1913                                        "unfuse_elementwise_regions: cannot broadcast \
1914                                         {lhs_shape:?} ⊗ {rhs_shape:?} for Binary({op:?}): {e}"
1915                                    )
1916                                });
1917                            let dims: Vec<_> = bcast.dims().to_vec();
1918                            let shape = Shape::from_dims(&dims, dt);
1919                            (
1920                                rw.new_graph.add_node(Op::Binary(*op), vec![l, r], shape),
1921                                dt,
1922                            )
1923                        }
1924                        ChainStep::Compare(op, lhs, rhs) => {
1925                            let l = resolve(lhs);
1926                            let r = resolve(rhs);
1927                            let lhs_shape = shape_of(lhs, &region_inputs, &step_ids, &rw);
1928                            let rhs_shape = shape_of(rhs, &region_inputs, &step_ids, &rw);
1929                            let bcast = rlx_ir::shape::broadcast(&lhs_shape, &rhs_shape)
1930                                .unwrap_or_else(|e| {
1931                                    panic!(
1932                                        "unfuse_elementwise_regions: cannot broadcast \
1933                                         {lhs_shape:?} ⊗ {rhs_shape:?} for Compare({op:?}): {e}"
1934                                    )
1935                                });
1936                            let dims: Vec<_> = bcast.dims().to_vec();
1937                            let shape = Shape::from_dims(&dims, rlx_ir::DType::Bool);
1938                            (
1939                                rw.new_graph.add_node(Op::Compare(*op), vec![l, r], shape),
1940                                rlx_ir::DType::Bool,
1941                            )
1942                        }
1943                        ChainStep::Where(c, x, y) => {
1944                            let cn = resolve(c);
1945                            let xn = resolve(x);
1946                            let yn = resolve(y);
1947                            let dt = dtype_of(x, &region_inputs, &step_dtypes, &rw);
1948                            // Where: broadcast across (cond, then, else).
1949                            let c_shape = shape_of(c, &region_inputs, &step_ids, &rw);
1950                            let x_shape = shape_of(x, &region_inputs, &step_ids, &rw);
1951                            let y_shape = shape_of(y, &region_inputs, &step_ids, &rw);
1952                            let bcast_xy = rlx_ir::shape::broadcast(&x_shape, &y_shape)
1953                                .unwrap_or_else(|e| {
1954                                    panic!(
1955                                        "unfuse_elementwise_regions: cannot broadcast \
1956                                         then/else {x_shape:?} ⊗ {y_shape:?} for Where: {e}"
1957                                    )
1958                                });
1959                            let bcast = rlx_ir::shape::broadcast(&c_shape, &bcast_xy)
1960                                .unwrap_or_else(|e| {
1961                                    panic!(
1962                                        "unfuse_elementwise_regions: cannot broadcast cond \
1963                                         {c_shape:?} ⊗ {bcast_xy:?} for Where: {e}"
1964                                    )
1965                                });
1966                            let dims: Vec<_> = bcast.dims().to_vec();
1967                            let shape = Shape::from_dims(&dims, dt);
1968                            (
1969                                rw.new_graph.add_node(Op::Where, vec![cn, xn, yn], shape),
1970                                dt,
1971                            )
1972                        }
1973                    };
1974                    step_ids.push(new_id);
1975                    step_dtypes.push(dt);
1976                }
1977                let _ = region_dtype;
1978                let _ = region_dims;
1979                // The region's "output" (= last step) replaces the original
1980                // ElementwiseRegion node id.
1981                let last = *step_ids.last().expect("chain non-empty per pass invariant");
1982                rw.replace(node.id, last);
1983                continue;
1984            }
1985            rw.copy_node(node);
1986        }
1987        rw.finish(&graph.outputs)
1988    }
1989}
1990
1991/// Unfuse only `ElementwiseRegion` nodes that exceed [`crate::limits::FusionLimits`].
1992///
1993/// Run after [`MarkElementwiseRegions`] when marking may still produce
1994/// oversized chains (e.g. limits tightened per backend).
1995pub fn clip_elementwise_regions(graph: Graph, limits: crate::limits::FusionLimits) -> Graph {
1996    let oversize = |n: &rlx_ir::Node| -> bool {
1997        matches!(
1998            &n.op,
1999            Op::ElementwiseRegion {
2000                chain,
2001                num_inputs,
2002                ..
2003            } if *num_inputs > limits.max_elementwise_inputs
2004                || chain.len() as u32 > limits.max_elementwise_steps
2005        )
2006    };
2007    if !graph.nodes().iter().any(oversize) {
2008        return graph;
2009    }
2010
2011    let mut rw = Rewriter::new(&graph.name);
2012    for node in graph.nodes() {
2013        if !oversize(node) {
2014            rw.copy_node(node);
2015            continue;
2016        }
2017
2018        let Op::ElementwiseRegion {
2019            chain,
2020            num_inputs: _,
2021            scalar_input_mask: _,
2022            input_modulus: _,
2023            prologue: _,
2024            prologue_input: _,
2025        } = &node.op
2026        else {
2027            unreachable!();
2028        };
2029
2030        let region_inputs: Vec<NodeId> = node.inputs.iter().map(|id| rw.map(*id)).collect();
2031        let mut step_ids: Vec<NodeId> = Vec::with_capacity(chain.len());
2032        let region_shape = node.shape.clone();
2033        let region_dims: Vec<_> = region_shape.dims().to_vec();
2034        let mut step_dtypes: Vec<rlx_ir::DType> = Vec::with_capacity(chain.len());
2035        let region_dtype = region_shape.dtype();
2036        let dtype_of = |op: &ChainOperand,
2037                        ins: &[NodeId],
2038                        step_dt: &[rlx_ir::DType],
2039                        rw: &Rewriter|
2040         -> rlx_ir::DType {
2041            match *op {
2042                ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.dtype(),
2043                ChainOperand::Step(i) => step_dt[i as usize],
2044            }
2045        };
2046        let shape_of =
2047            |op: &ChainOperand, ins: &[NodeId], step_ids: &[NodeId], rw: &Rewriter| -> Shape {
2048                match *op {
2049                    ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.clone(),
2050                    ChainOperand::Step(i) => rw.new_graph.node(step_ids[i as usize]).shape.clone(),
2051                }
2052            };
2053        for step in chain {
2054            let resolve = |op: &ChainOperand| -> NodeId {
2055                match *op {
2056                    ChainOperand::Input(i) => region_inputs[i as usize],
2057                    ChainOperand::Step(i) => step_ids[i as usize],
2058                }
2059            };
2060            let (new_id, dt) = match step {
2061                ChainStep::Activation(a, src) => {
2062                    let s = resolve(src);
2063                    let dt = dtype_of(src, &region_inputs, &step_dtypes, &rw);
2064                    let src_shape = shape_of(src, &region_inputs, &step_ids, &rw);
2065                    let dims: Vec<_> = src_shape.dims().to_vec();
2066                    let shape = Shape::from_dims(&dims, dt);
2067                    (
2068                        rw.new_graph.add_node(Op::Activation(*a), vec![s], shape),
2069                        dt,
2070                    )
2071                }
2072                ChainStep::Cast(to, src) => {
2073                    let s = resolve(src);
2074                    let src_shape = shape_of(src, &region_inputs, &step_ids, &rw);
2075                    let dims: Vec<_> = src_shape.dims().to_vec();
2076                    let shape = Shape::from_dims(&dims, *to);
2077                    (
2078                        rw.new_graph.add_node(Op::Cast { to: *to }, vec![s], shape),
2079                        *to,
2080                    )
2081                }
2082                ChainStep::Binary(op, lhs, rhs) => {
2083                    let l = resolve(lhs);
2084                    let r = resolve(rhs);
2085                    let dt = dtype_of(lhs, &region_inputs, &step_dtypes, &rw);
2086                    let l_shape = shape_of(lhs, &region_inputs, &step_ids, &rw);
2087                    let r_shape = shape_of(rhs, &region_inputs, &step_ids, &rw);
2088                    let bcast = l_shape
2089                        .broadcast_with(&r_shape)
2090                        .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
2091                    let dims: Vec<_> = bcast.dims().to_vec();
2092                    let shape = Shape::from_dims(&dims, dt);
2093                    (
2094                        rw.new_graph.add_node(Op::Binary(*op), vec![l, r], shape),
2095                        dt,
2096                    )
2097                }
2098                ChainStep::Compare(op, lhs, rhs) => {
2099                    let l = resolve(lhs);
2100                    let r = resolve(rhs);
2101                    let l_shape = shape_of(lhs, &region_inputs, &step_ids, &rw);
2102                    let r_shape = shape_of(rhs, &region_inputs, &step_ids, &rw);
2103                    let bcast = l_shape
2104                        .broadcast_with(&r_shape)
2105                        .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
2106                    let dims: Vec<_> = bcast.dims().to_vec();
2107                    let shape = Shape::from_dims(&dims, rlx_ir::DType::U8);
2108                    (
2109                        rw.new_graph.add_node(Op::Compare(*op), vec![l, r], shape),
2110                        rlx_ir::DType::U8,
2111                    )
2112                }
2113                ChainStep::Where(cond, x, y) => {
2114                    let cn = resolve(cond);
2115                    let xn = resolve(x);
2116                    let yn = resolve(y);
2117                    let dt = dtype_of(x, &region_inputs, &step_dtypes, &rw);
2118                    let x_shape = shape_of(x, &region_inputs, &step_ids, &rw);
2119                    let y_shape = shape_of(y, &region_inputs, &step_ids, &rw);
2120                    let c_shape = shape_of(cond, &region_inputs, &step_ids, &rw);
2121                    let bcast_xy = x_shape
2122                        .broadcast_with(&y_shape)
2123                        .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
2124                    let bcast = c_shape.broadcast_with(&bcast_xy).unwrap_or_else(|e| {
2125                        panic!("clip_elementwise_regions: cannot broadcast cond {c_shape:?} ⊗ {bcast_xy:?} for Where: {e}")
2126                    });
2127                    let dims: Vec<_> = bcast.dims().to_vec();
2128                    let shape = Shape::from_dims(&dims, dt);
2129                    (
2130                        rw.new_graph.add_node(Op::Where, vec![cn, xn, yn], shape),
2131                        dt,
2132                    )
2133                }
2134            };
2135            step_ids.push(new_id);
2136            step_dtypes.push(dt);
2137        }
2138        let _ = (region_dtype, region_dims);
2139        let last = *step_ids
2140            .last()
2141            .expect("oversize region has non-empty chain");
2142        rw.replace(node.id, last);
2143    }
2144    rw.finish(&graph.outputs)
2145}
2146
2147#[cfg(test)]
2148mod tests {
2149    use super::*;
2150    use crate::limits::FusionLimits;
2151    use crate::pass::run_passes;
2152
2153    fn f32_shape(dims: &[usize]) -> Shape {
2154        Shape::new(dims, DType::F32)
2155    }
2156
2157    #[test]
2158    fn fuse_matmul_bias_gelu() {
2159        let mut g = Graph::new("test");
2160        let x = g.input("x", f32_shape(&[4, 15, 384]));
2161        let w = g.param("w", f32_shape(&[384, 1536]));
2162        let b = g.param("b", f32_shape(&[1536]));
2163        let mm = g.matmul(x, w, f32_shape(&[4, 15, 1536]));
2164        let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[4, 15, 1536]));
2165        let out = g.activation(Activation::Gelu, add, f32_shape(&[4, 15, 1536]));
2166        g.set_outputs(vec![out]);
2167
2168        assert_eq!(g.len(), 6); // input, w, b, mm, add, gelu
2169
2170        let fused = FuseMatMulBiasAct.run(g);
2171        println!("{fused}");
2172
2173        // Should be: input, w, b, fused_mm_bias_gelu
2174        assert_eq!(fused.len(), 4);
2175        let out_node = fused.node(fused.outputs[0]);
2176        assert!(matches!(
2177            out_node.op,
2178            Op::FusedMatMulBiasAct {
2179                activation: Some(Activation::Gelu)
2180            }
2181        ));
2182    }
2183
2184    #[test]
2185    fn fuse_matmul_bias_no_act() {
2186        let mut g = Graph::new("test");
2187        let x = g.input("x", f32_shape(&[4, 15, 384]));
2188        let w = g.param("w", f32_shape(&[384, 384]));
2189        let b = g.param("b", f32_shape(&[384]));
2190        let mm = g.matmul(x, w, f32_shape(&[4, 15, 384]));
2191        let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[4, 15, 384]));
2192        g.set_outputs(vec![add]);
2193
2194        let fused = FuseMatMulBiasAct.run(g);
2195        assert_eq!(fused.len(), 4);
2196        let out_node = fused.node(fused.outputs[0]);
2197        assert!(matches!(
2198            out_node.op,
2199            Op::FusedMatMulBiasAct { activation: None }
2200        ));
2201    }
2202
2203    #[test]
2204    fn fuse_matmul_bias_skips_unsupported_activation_epilogue() {
2205        let mut g = Graph::new("test");
2206        let x = g.input("x", f32_shape(&[8, 1024]));
2207        let w = g.param("w", f32_shape(&[1024, 16]));
2208        let b = g.param("b", f32_shape(&[16]));
2209        let mm = g.matmul(x, w, f32_shape(&[8, 16]));
2210        let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[8, 16]));
2211        let exp = g.activation(Activation::Exp, add, f32_shape(&[8, 16]));
2212        g.set_outputs(vec![exp]);
2213
2214        let fused = FuseMatMulBiasAct.run(g);
2215        // mm + bias fuse; Exp stays separate (qwen35 softplus pattern).
2216        assert_eq!(fused.len(), 5);
2217        let out_node = fused.node(fused.outputs[0]);
2218        assert!(matches!(out_node.op, Op::Activation(Activation::Exp)));
2219        let add_node = fused.node(out_node.inputs[0]);
2220        assert!(matches!(
2221            add_node.op,
2222            Op::FusedMatMulBiasAct { activation: None }
2223        ));
2224    }
2225
2226    #[test]
2227    fn fuse_matmul_bias_act_with_late_bias_param() {
2228        use rlx_ir::infer::GraphExt;
2229
2230        let mut g = Graph::new("late_bias");
2231        let x = g.input("x", f32_shape(&[8, 16]));
2232        let w = g.param("w", f32_shape(&[16, 32]));
2233        let out = {
2234            let mm = g.mm(x, w);
2235            let b = g.param("b", f32_shape(&[32]));
2236            let biased = g.add(mm, b);
2237            g.gelu(biased)
2238        };
2239        g.set_outputs(vec![out]);
2240
2241        let fused = FuseMatMulBiasAct.run(g);
2242        assert!(
2243            fused
2244                .nodes()
2245                .iter()
2246                .any(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. })),
2247            "bias param declared after matmul must still fuse:\n{fused}"
2248        );
2249    }
2250
2251    #[test]
2252    fn swiglu_ffn_builder_fuses_end_to_end() {
2253        let mut g = Graph::new("swiglu_block");
2254        let x = g.input("x", f32_shape(&[4, 768]));
2255        let up_w = g.param("up", f32_shape(&[768, 2048]));
2256        let gate_w = g.param("gate", f32_shape(&[768, 2048]));
2257        let down_w = g.param("down", f32_shape(&[2048, 768]));
2258        let out = g.swiglu_ffn(x, up_w, gate_w, down_w);
2259        g.set_outputs(vec![out]);
2260
2261        let g = FuseSharedInputMatMul.run(g);
2262        let g = FuseSwiGLU.run(g);
2263        assert!(
2264            g.nodes()
2265                .iter()
2266                .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
2267            "swiglu_ffn builder should match FuseSwiGLU:\n{g}"
2268        );
2269    }
2270
2271    #[test]
2272    fn fuse_swiglu_dual_matmul_gate_first() {
2273        use rlx_ir::infer::GraphExt;
2274
2275        let mut g = Graph::new("qwen3_ffn");
2276        let x = g.input("x", f32_shape(&[4, 768]));
2277        let gate_w = g.param("gate", f32_shape(&[768, 2048]));
2278        let up_w = g.param("up", f32_shape(&[768, 2048]));
2279        let gate = g.mm(x, gate_w);
2280        let up = g.mm(x, up_w);
2281        let gate_act = g.silu(gate);
2282        let out = g.mul(gate_act, up);
2283        g.set_outputs(vec![out]);
2284
2285        let fused = FuseSwiGLUDualMatmul.run(g);
2286        assert!(
2287            fused
2288                .nodes()
2289                .iter()
2290                .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
2291            "gate-first dual matmul should fuse:\n{fused}"
2292        );
2293        assert!(
2294            fused.len() <= 6,
2295            "dual fusion should collapse to x + weights + concat + mm + fused_swiglu, got {} nodes",
2296            fused.len()
2297        );
2298    }
2299
2300    #[test]
2301    fn fuse_shared_input_matmul_three_way_qkv() {
2302        let mut g = Graph::new("qkv");
2303        let x = g.input("x", f32_shape(&[8, 512]));
2304        let wq = g.param("wq", f32_shape(&[512, 128]));
2305        let wk = g.param("wk", f32_shape(&[512, 128]));
2306        let wv = g.param("wv", f32_shape(&[512, 128]));
2307        let q = g.matmul(x, wq, f32_shape(&[8, 128]));
2308        let k = g.matmul(x, wk, f32_shape(&[8, 128]));
2309        let v = g.matmul(x, wv, f32_shape(&[8, 128]));
2310        g.set_outputs(vec![q, k, v]);
2311
2312        let fused = FuseSharedInputMatMul.run(g);
2313        assert_eq!(
2314            fused.len(),
2315            9,
2316            "x + 3 weights + concat + mm + 3 narrows = 9"
2317        );
2318        for &out in &fused.outputs {
2319            assert!(matches!(fused.node(out).op, Op::Narrow { .. }));
2320        }
2321    }
2322
2323    #[test]
2324    fn fuse_residual_layer_norm() {
2325        let mut g = Graph::new("test");
2326        let x = g.input("x", f32_shape(&[4, 15, 384]));
2327        let residual = g.input("residual", f32_shape(&[4, 15, 384]));
2328        let gamma = g.param("gamma", f32_shape(&[384]));
2329        let beta = g.param("beta", f32_shape(&[384]));
2330        let add = g.binary(BinaryOp::Add, x, residual, f32_shape(&[4, 15, 384]));
2331        let ln = g.layer_norm(add, gamma, beta, -1, 1e-12, f32_shape(&[4, 15, 384]));
2332        g.set_outputs(vec![ln]);
2333
2334        assert_eq!(g.len(), 6); // x, residual, gamma, beta, add, ln
2335
2336        let fused = FuseResidualLN.run(g);
2337        println!("{fused}");
2338
2339        // Should be: x, residual, gamma, beta, fused_residual_ln
2340        assert_eq!(fused.len(), 5);
2341        let out_node = fused.node(fused.outputs[0]);
2342        assert!(matches!(
2343            out_node.op,
2344            Op::FusedResidualLN {
2345                has_bias: false,
2346                ..
2347            }
2348        ));
2349    }
2350
2351    #[test]
2352    fn fuse_residual_rms_norm() {
2353        let mut g = Graph::new("test");
2354        let x = g.input("x", f32_shape(&[4, 15, 384]));
2355        let residual = g.input("residual", f32_shape(&[4, 15, 384]));
2356        let gamma = g.param("gamma", f32_shape(&[384]));
2357        let beta = g.param("beta", f32_shape(&[384]));
2358        let add = g.binary(BinaryOp::Add, x, residual, f32_shape(&[4, 15, 384]));
2359        let rn = g.add_node(
2360            Op::RmsNorm {
2361                axis: -1,
2362                eps: 1e-6,
2363            },
2364            vec![add, gamma, beta],
2365            f32_shape(&[4, 15, 384]),
2366        );
2367        g.set_outputs(vec![rn]);
2368
2369        assert_eq!(g.len(), 6);
2370
2371        let fused = FuseResidualRmsNorm.run(g);
2372        assert_eq!(fused.len(), 5);
2373        let out_node = fused.node(fused.outputs[0]);
2374        assert!(matches!(
2375            out_node.op,
2376            Op::FusedResidualRmsNorm {
2377                has_bias: false,
2378                ..
2379            }
2380        ));
2381    }
2382
2383    #[test]
2384    fn fuse_rms_norm_reshape() {
2385        let mut g = Graph::new("test");
2386        let x = g.input("x", f32_shape(&[1, 8, 512]));
2387        let gamma = g.param("gamma", f32_shape(&[512]));
2388        let beta = g.param("beta", f32_shape(&[512]));
2389        let rn = g.add_node(
2390            Op::RmsNorm {
2391                axis: -1,
2392                eps: 1e-6,
2393            },
2394            vec![x, gamma, beta],
2395            f32_shape(&[1, 8, 512]),
2396        );
2397        let flat = g.add_node(
2398            Op::Reshape {
2399                new_shape: vec![8, 512],
2400            },
2401            vec![rn],
2402            f32_shape(&[8, 512]),
2403        );
2404        let w = g.param("w", f32_shape(&[512, 128]));
2405        let mm = g.matmul(flat, w, f32_shape(&[8, 128]));
2406        g.set_outputs(vec![mm]);
2407
2408        let fused = FuseRmsNormReshape.run(g);
2409        // x, gamma, beta, rms_norm(2d), w, matmul — no separate reshape
2410        assert_eq!(fused.len(), 6);
2411        let rn_node = fused.node(fused.node(fused.outputs[0]).inputs[0]);
2412        assert!(matches!(rn_node.op, Op::RmsNorm { .. }));
2413        assert_eq!(rn_node.shape.dim(0).unwrap_static(), 8);
2414        assert_eq!(rn_node.shape.dim(1).unwrap_static(), 512);
2415    }
2416
2417    #[test]
2418    fn fuse_shared_input_matmul() {
2419        let mut g = Graph::new("swiglu");
2420        let x = g.input("x", f32_shape(&[60, 768]));
2421        let w1 = g.param("fc11", f32_shape(&[768, 2048]));
2422        let w2 = g.param("fc12", f32_shape(&[768, 2048]));
2423        let mm1 = g.matmul(x, w1, f32_shape(&[60, 2048]));
2424        let mm2 = g.matmul(x, w2, f32_shape(&[60, 2048]));
2425        g.set_outputs(vec![mm1, mm2]);
2426
2427        assert_eq!(g.len(), 5); // x, w1, w2, mm1, mm2
2428
2429        let fused = FuseSharedInputMatMul.run(g);
2430        println!("{fused}");
2431
2432        // Should have: x, w1, w2, concat(w1,w2), combined_mm, narrow1, narrow2
2433        assert!(fused.len() <= 7);
2434        // Both outputs should be Narrow ops
2435        for &out in &fused.outputs {
2436            assert!(matches!(fused.node(out).op, Op::Narrow { .. }));
2437        }
2438    }
2439
2440    /// Regression: `FuseSharedInputMatMul` used to panic when `w2` is
2441    /// declared after `mm1`. `ensure_mapped` now copies late operands.
2442    #[test]
2443    fn fuse_shared_input_matmul_with_late_w2_param() {
2444        let mut g = Graph::new("late_w2");
2445        let x = g.input("x", f32_shape(&[8, 16]));
2446        let w1 = g.param("w1", f32_shape(&[16, 8]));
2447        let mm1 = g.matmul(x, w1, f32_shape(&[8, 8]));
2448        let w2 = g.param("w2", f32_shape(&[16, 8]));
2449        let mm2 = g.matmul(x, w2, f32_shape(&[8, 8]));
2450        g.set_outputs(vec![mm1, mm2]);
2451
2452        let fused = FuseSharedInputMatMul.run(g);
2453        for &out in &fused.outputs {
2454            assert!(
2455                matches!(fused.node(out).op, Op::Narrow { .. }),
2456                "late w2 should still fuse via ensure_mapped, got {:?}",
2457                fused.node(out).op
2458            );
2459        }
2460    }
2461
2462    /// Regression: qwen35moe FFN declares router / shared-expert matmuls on the
2463    /// same flattened hidden state with weights scattered through the block.
2464    #[test]
2465    fn fuse_shared_input_matmul_moe_ffn_pattern() {
2466        let mut g = Graph::new("moe_ffn");
2467        let rows = 4usize;
2468        let n_embd = 16usize;
2469        let n_expert = 4usize;
2470        let n_ff = 16usize;
2471
2472        let h_in = g.input("h", f32_shape(&[1, rows, n_embd]));
2473        let h_2d = g.reshape_(h_in, vec![rows as i64, n_embd as i64]);
2474
2475        let router_w = g.param("router_w", f32_shape(&[n_embd, n_expert]));
2476        let router_logits = g.matmul(h_2d, router_w, f32_shape(&[rows, n_expert]));
2477
2478        // MoE body omitted — only the shared-expert tail matters for fusion order.
2479        let shared_router_w = g.param("shared_router_w", f32_shape(&[n_embd, 1]));
2480        let shared_logits = g.matmul(h_2d, shared_router_w, f32_shape(&[rows, 1]));
2481        let shared_gate = g.activation(Activation::Sigmoid, shared_logits, f32_shape(&[rows, 1]));
2482
2483        let s_gate_w = g.param("s_gate_w", f32_shape(&[n_embd, n_ff]));
2484        let s_up_w = g.param("s_up_w", f32_shape(&[n_embd, n_ff]));
2485        let s_gate = g.matmul(h_2d, s_gate_w, f32_shape(&[rows, n_ff]));
2486        let s_up = g.matmul(h_2d, s_up_w, f32_shape(&[rows, n_ff]));
2487        let s_gate_silu = g.silu(s_gate);
2488        let s_swiglu = g.mul(s_gate_silu, s_up);
2489
2490        g.set_outputs(vec![router_logits, shared_gate, s_swiglu]);
2491
2492        let fused = FuseSharedInputMatMul.run(g);
2493        let narrow_count = fused
2494            .nodes()
2495            .iter()
2496            .filter(|n| matches!(n.op, Op::Narrow { .. }))
2497            .count();
2498        assert!(
2499            narrow_count >= 4,
2500            "expected four narrow slices from fused h_2d matmuls, got {narrow_count}"
2501        );
2502    }
2503
2504    /// Full pipeline: build a BERT FFN subgraph and run all fusion passes.
2505    #[test]
2506    fn full_bert_ffn_fusion() {
2507        let mut g = Graph::new("bert_ffn");
2508        let f = DType::F32;
2509
2510        let x = g.input("hidden", Shape::new(&[4, 15, 384], f));
2511        let residual = g.input("residual", Shape::new(&[4, 15, 384], f));
2512
2513        // Output projection result + residual + LN
2514        let out_w = g.param("out.w", Shape::new(&[384, 384], f));
2515        let out_b = g.param("out.b", Shape::new(&[384], f));
2516        let out_mm = g.matmul(x, out_w, Shape::new(&[4, 15, 384], f));
2517        let out_add = g.binary(BinaryOp::Add, out_mm, out_b, Shape::new(&[4, 15, 384], f));
2518        let res_add = g.binary(
2519            BinaryOp::Add,
2520            out_add,
2521            residual,
2522            Shape::new(&[4, 15, 384], f),
2523        );
2524        let gamma = g.param("ln.g", Shape::new(&[384], f));
2525        let beta = g.param("ln.b", Shape::new(&[384], f));
2526        let ln = g.layer_norm(
2527            res_add,
2528            gamma,
2529            beta,
2530            -1,
2531            1e-12,
2532            Shape::new(&[4, 15, 384], f),
2533        );
2534
2535        // FFN intermediate: matmul + bias + gelu
2536        let int_w = g.param("int.w", Shape::new(&[384, 1536], f));
2537        let int_b = g.param("int.b", Shape::new(&[1536], f));
2538        let int_mm = g.matmul(ln, int_w, Shape::new(&[4, 15, 1536], f));
2539        let int_add = g.binary(BinaryOp::Add, int_mm, int_b, Shape::new(&[4, 15, 1536], f));
2540        let gelu = g.activation(Activation::Gelu, int_add, Shape::new(&[4, 15, 1536], f));
2541
2542        // FFN output: matmul + bias
2543        let out2_w = g.param("out2.w", Shape::new(&[1536, 384], f));
2544        let out2_b = g.param("out2.b", Shape::new(&[384], f));
2545        let out2_mm = g.matmul(gelu, out2_w, Shape::new(&[4, 15, 384], f));
2546        let out2_add = g.binary(BinaryOp::Add, out2_mm, out2_b, Shape::new(&[4, 15, 384], f));
2547
2548        g.set_outputs(vec![out2_add]);
2549
2550        let before = g.len();
2551        println!("=== BEFORE fusion ({before} nodes) ===\n{g}");
2552
2553        // Run all passes
2554        let passes: Vec<&dyn Pass> = vec![&FuseMatMulBiasAct, &FuseResidualLN];
2555        let optimized = run_passes(g, &passes, false);
2556        let after = optimized.len();
2557        println!("=== AFTER fusion ({after} nodes) ===\n{optimized}");
2558
2559        // Should have eliminated:
2560        // - 2 Add + 1 Gelu from matmul_bias_gelu fusion (×2 matmuls)
2561        // - 1 Add from residual_ln fusion
2562        assert!(
2563            after < before,
2564            "fusion should reduce node count: {before} → {after}"
2565        );
2566
2567        // Check that fused ops exist
2568        let ops: Vec<String> = optimized
2569            .nodes()
2570            .iter()
2571            .map(|n| format!("{}", n.op))
2572            .collect();
2573        let has_fused_mm = ops.iter().any(|s| s.contains("fused_mm_bias"));
2574        assert!(has_fused_mm, "should have fused_mm_bias_act: {ops:?}");
2575    }
2576
2577    /// FuseSwiGLU fires on the canonical Nomic-style pattern produced by
2578    /// `FuseSharedInputMatMul` (concat'd matmul → narrow×2 → silu → mul).
2579    #[test]
2580    fn fuse_swiglu_canonical() {
2581        let mut g = Graph::new("nomic_ffn");
2582        let f = DType::F32;
2583        // After FuseSharedInputMatMul: cat = mm(x, concat(fc11, fc12)) → [60, 4096]
2584        let cat = g.input("cat", Shape::new(&[60, 4096], f));
2585        let up = g.add_node(
2586            Op::Narrow {
2587                axis: 1,
2588                start: 0,
2589                len: 2048,
2590            },
2591            vec![cat],
2592            Shape::new(&[60, 2048], f),
2593        );
2594        let gate = g.add_node(
2595            Op::Narrow {
2596                axis: 1,
2597                start: 2048,
2598                len: 2048,
2599            },
2600            vec![cat],
2601            Shape::new(&[60, 2048], f),
2602        );
2603        let silu = g.activation(Activation::Silu, gate, Shape::new(&[60, 2048], f));
2604        let out = g.binary(BinaryOp::Mul, up, silu, Shape::new(&[60, 2048], f));
2605        g.set_outputs(vec![out]);
2606
2607        let before = g.len();
2608        let fused = FuseSwiGLU.run(g);
2609        let after = fused.len();
2610        // Removed: up, gate, silu, mul → replaced by FusedSwiGLU.
2611        // Net: -3 nodes (4 removed, 1 added).
2612        assert_eq!(
2613            after,
2614            before - 3,
2615            "should remove narrows+silu+mul, add FusedSwiGLU"
2616        );
2617        let out_node = fused.node(fused.outputs[0]);
2618        assert!(
2619            matches!(
2620                out_node.op,
2621                Op::FusedSwiGLU {
2622                    cast_to: None,
2623                    gate_first: false
2624                }
2625            ),
2626            "output should be FusedSwiGLU, got {}",
2627            out_node.op
2628        );
2629        // FusedSwiGLU's input is the cat tensor.
2630        let in_id = out_node.inputs[0];
2631        assert!(matches!(fused.node(in_id).op, Op::Input { .. }));
2632    }
2633
2634    /// FuseSwiGLU does NOT fire when narrows are shared with another consumer
2635    /// (would corrupt the second consumer's view of the data).
2636    #[test]
2637    fn fuse_swiglu_skips_when_narrow_has_extra_user() {
2638        let mut g = Graph::new("contended");
2639        let f = DType::F32;
2640        let cat = g.input("cat", Shape::new(&[60, 4096], f));
2641        let up = g.add_node(
2642            Op::Narrow {
2643                axis: 1,
2644                start: 0,
2645                len: 2048,
2646            },
2647            vec![cat],
2648            Shape::new(&[60, 2048], f),
2649        );
2650        let gate = g.add_node(
2651            Op::Narrow {
2652                axis: 1,
2653                start: 2048,
2654                len: 2048,
2655            },
2656            vec![cat],
2657            Shape::new(&[60, 2048], f),
2658        );
2659        let silu = g.activation(Activation::Silu, gate, Shape::new(&[60, 2048], f));
2660        let out = g.binary(BinaryOp::Mul, up, silu, Shape::new(&[60, 2048], f));
2661        // Extra user of `up` — this should block fusion.
2662        let extra = g.activation(Activation::Relu, up, Shape::new(&[60, 2048], f));
2663        g.set_outputs(vec![out, extra]);
2664
2665        let before = g.len();
2666        let fused = FuseSwiGLU.run(g);
2667        // Pass should be a no-op when fusion is unsafe.
2668        assert_eq!(fused.len(), before);
2669        // No FusedSwiGLU node anywhere.
2670        let any_fused = fused
2671            .nodes()
2672            .iter()
2673            .any(|n| matches!(n.op, Op::FusedSwiGLU { .. }));
2674        assert!(!any_fused, "should not fuse when narrow has extra user");
2675    }
2676
2677    // ── MarkElementwiseRegions (PLAN L2) ────────────────────────────
2678
2679    #[test]
2680    fn region_collapses_add_mul_relu_chain() {
2681        // Build: out = relu(add(a, b) * c). All same shape, single consumer
2682        // chain. Should fuse into one ElementwiseRegion.
2683        let f = DType::F32;
2684        let mut g = Graph::new("ew");
2685        let a = g.input("a", Shape::new(&[8], f));
2686        let b = g.input("b", Shape::new(&[8], f));
2687        let c = g.input("c", Shape::new(&[8], f));
2688        let s = Shape::new(&[8], f);
2689        let add = g.binary(BinaryOp::Add, a, b, s.clone());
2690        let mul = g.binary(BinaryOp::Mul, add, c, s.clone());
2691        let relu = g.activation(Activation::Relu, mul, s.clone());
2692        g.set_outputs(vec![relu]);
2693
2694        let before = g.len();
2695        let fused = MarkElementwiseRegions.run(g);
2696
2697        // Three element-wise ops collapsed into one region node.
2698        let regions: Vec<_> = fused
2699            .nodes()
2700            .iter()
2701            .filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
2702            .collect();
2703        assert_eq!(regions.len(), 1, "expected one ElementwiseRegion");
2704        let region = regions[0];
2705        assert_eq!(
2706            region.inputs.len(),
2707            3,
2708            "region has 3 external inputs (a, b, c)"
2709        );
2710        if let Op::ElementwiseRegion {
2711            chain, num_inputs, ..
2712        } = &region.op
2713        {
2714            assert_eq!(*num_inputs, 3);
2715            assert_eq!(chain.len(), 3);
2716            // Step 0: Add(Input(0), Input(1))
2717            match &chain[0] {
2718                ChainStep::Binary(
2719                    BinaryOp::Add,
2720                    ChainOperand::Input(0),
2721                    ChainOperand::Input(1),
2722                ) => {}
2723                other => panic!("step 0 unexpected: {other:?}"),
2724            }
2725            // Step 1: Mul(Step(0), Input(2))
2726            match &chain[1] {
2727                ChainStep::Binary(BinaryOp::Mul, ChainOperand::Step(0), ChainOperand::Input(2)) => {
2728                }
2729                other => panic!("step 1 unexpected: {other:?}"),
2730            }
2731            // Step 2: Activation(Relu, Step(1))
2732            match &chain[2] {
2733                ChainStep::Activation(Activation::Relu, ChainOperand::Step(1)) => {}
2734                other => panic!("step 2 unexpected: {other:?}"),
2735            }
2736        } else {
2737            unreachable!();
2738        }
2739        // Original chain (3 ops) replaced by 1 region; net node count is
2740        // (inputs 3) + (region 1) = 4 (vs 3 + 3 = 6 before).
2741        assert!(fused.len() < before);
2742    }
2743
2744    #[test]
2745    fn region_does_not_fuse_when_intermediate_has_multiple_consumers() {
2746        // out1 = add(a, b); out2 = relu(out1). out1 also fed to out_extra.
2747        // Multi-consumer on out1 forbids fusion.
2748        let f = DType::F32;
2749        let mut g = Graph::new("ew");
2750        let a = g.input("a", Shape::new(&[4], f));
2751        let b = g.input("b", Shape::new(&[4], f));
2752        let s = Shape::new(&[4], f);
2753        let add = g.binary(BinaryOp::Add, a, b, s.clone());
2754        let relu = g.activation(Activation::Relu, add, s.clone());
2755        let extra = g.activation(Activation::Sigmoid, add, s.clone());
2756        g.set_outputs(vec![relu, extra]);
2757
2758        let before = g.len();
2759        let fused = MarkElementwiseRegions.run(g);
2760        // No region: add has two consumers (relu and extra), so the chain
2761        // can't extend through it. Each downstream activation is alone in
2762        // its region (size 1, doesn't fuse).
2763        let regions: Vec<_> = fused
2764            .nodes()
2765            .iter()
2766            .filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
2767            .collect();
2768        assert_eq!(regions.len(), 0);
2769        assert_eq!(fused.len(), before);
2770    }
2771
2772    #[test]
2773    fn region_skips_chains_of_length_one() {
2774        // Single relu — no fusion (size 1 = degenerate).
2775        let f = DType::F32;
2776        let mut g = Graph::new("ew");
2777        let a = g.input("a", Shape::new(&[4], f));
2778        let r = g.activation(Activation::Relu, a, Shape::new(&[4], f));
2779        g.set_outputs(vec![r]);
2780
2781        let fused = MarkElementwiseRegions.run(g);
2782        let any_region = fused
2783            .nodes()
2784            .iter()
2785            .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }));
2786        assert!(!any_region);
2787    }
2788
2789    #[test]
2790    fn unfuse_decomposes_region_back_to_atomic_ops() {
2791        // Build the same chain, fuse it, then unfuse — expect the
2792        // original atomic ops back (Add, Mul, Relu).
2793        let f = DType::F32;
2794        let mut g = Graph::new("ew_unfuse");
2795        let a = g.input("a", Shape::new(&[8], f));
2796        let b = g.input("b", Shape::new(&[8], f));
2797        let c = g.input("c", Shape::new(&[8], f));
2798        let s = Shape::new(&[8], f);
2799        let add = g.binary(BinaryOp::Add, a, b, s.clone());
2800        let mul = g.binary(BinaryOp::Mul, add, c, s.clone());
2801        let relu = g.activation(Activation::Relu, mul, s);
2802        g.set_outputs(vec![relu]);
2803
2804        let fused = MarkElementwiseRegions.run(g);
2805        // Sanity: fusion happened.
2806        assert!(
2807            fused
2808                .nodes()
2809                .iter()
2810                .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
2811        );
2812
2813        let unfused = UnfuseElementwiseRegions::FOR_CPU.run(fused);
2814        // No region nodes left.
2815        assert!(
2816            !unfused
2817                .nodes()
2818                .iter()
2819                .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
2820        );
2821        // Original atomic ops are back: Add, Mul, Relu.
2822        let bin_count = unfused
2823            .nodes()
2824            .iter()
2825            .filter(|n| matches!(n.op, Op::Binary(_)))
2826            .count();
2827        let act_count = unfused
2828            .nodes()
2829            .iter()
2830            .filter(|n| matches!(n.op, Op::Activation(_)))
2831            .count();
2832        assert_eq!(bin_count, 2, "Add + Mul restored");
2833        assert_eq!(act_count, 1, "Relu restored");
2834    }
2835
2836    #[test]
2837    fn clip_unfuses_region_over_step_cap() {
2838        use rlx_ir::op::{Activation, ChainOperand, ChainStep};
2839
2840        let mut g = Graph::new("clip");
2841        let x = g.input("x", f32_shape(&[4]));
2842        let mut chain: Vec<ChainStep> = Vec::new();
2843        let mut prev = ChainOperand::Input(0);
2844        for _ in 0..40 {
2845            chain.push(ChainStep::Activation(Activation::Relu, prev));
2846            prev = ChainOperand::Step(chain.len() as u32 - 1);
2847        }
2848        let y = g.add_node(
2849            Op::ElementwiseRegion {
2850                chain,
2851                num_inputs: 1,
2852                scalar_input_mask: 0,
2853                input_modulus: [0; 16],
2854                prologue: RegionPrologue::None,
2855                prologue_input: 0,
2856            },
2857            vec![x],
2858            f32_shape(&[4]),
2859        );
2860        g.set_outputs(vec![y]);
2861
2862        let clipped = clip_elementwise_regions(g, FusionLimits::GPU_NATIVE);
2863        assert!(
2864            !clipped
2865                .nodes()
2866                .iter()
2867                .any(|n| matches!(n.op, Op::ElementwiseRegion { .. })),
2868            "oversized region should be decomposed"
2869        );
2870        assert!(clipped.len() > 5);
2871    }
2872
2873    #[test]
2874    fn unfuse_is_noop_when_no_region_present() {
2875        let f = DType::F32;
2876        let mut g = Graph::new("noop");
2877        let a = g.input("a", Shape::new(&[4], f));
2878        let r = g.activation(Activation::Relu, a, Shape::new(&[4], f));
2879        g.set_outputs(vec![r]);
2880        let n_before = g.len();
2881        let result = UnfuseElementwiseRegions::FOR_CPU.run(g);
2882        // Pass returns unchanged graph (early return on no-region check).
2883        assert_eq!(result.len(), n_before);
2884    }
2885
2886    #[test]
2887    fn region_includes_where_step() {
2888        // Build: cmp = a > b; sel = where(cmp, a, b); out = sel + a
2889        // The compare → where → add chain is fully element-wise; the
2890        // Where step lands inside the region thanks to the L2-quality
2891        // extension that adds `Op::Where` to the chain-eligible set.
2892        let f = DType::F32;
2893        let mut g = Graph::new("region_where");
2894        let a = g.input("a", Shape::new(&[4], f));
2895        let b = g.input("b", Shape::new(&[4], f));
2896        let s = Shape::new(&[4], f);
2897        let cmp = g.add_node(Op::Compare(CmpOp::Gt), vec![a, b], s.clone());
2898        let sel = g.add_node(Op::Where, vec![cmp, a, b], s.clone());
2899        let add = g.binary(BinaryOp::Add, sel, a, s.clone());
2900        g.set_outputs(vec![add]);
2901
2902        let fused = MarkElementwiseRegions.run(g);
2903        let regions: Vec<_> = fused
2904            .nodes()
2905            .iter()
2906            .filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
2907            .collect();
2908        assert_eq!(regions.len(), 1, "expected one ElementwiseRegion");
2909        if let Op::ElementwiseRegion { chain, .. } = &regions[0].op {
2910            // 3 steps: Compare a > b, Where, Add
2911            assert_eq!(chain.len(), 3);
2912            assert!(
2913                matches!(chain[1], ChainStep::Where(_, _, _)),
2914                "step 1 should be Where, got {:?}",
2915                chain[1]
2916            );
2917        } else {
2918            unreachable!();
2919        }
2920    }
2921
2922    #[test]
2923    fn unfuse_decomposes_where_step_back_to_op_where() {
2924        // Round-trip: build a region with a Where step, decompose it,
2925        // verify the resulting graph contains an Op::Where node.
2926        let f = DType::F32;
2927        let mut g = Graph::new("unfuse_where");
2928        let a = g.input("a", Shape::new(&[4], f));
2929        let b = g.input("b", Shape::new(&[4], f));
2930        let s = Shape::new(&[4], f);
2931        let cmp = g.add_node(Op::Compare(CmpOp::Gt), vec![a, b], s.clone());
2932        let sel = g.add_node(Op::Where, vec![cmp, a, b], s.clone());
2933        let add = g.binary(BinaryOp::Add, sel, a, s.clone());
2934        g.set_outputs(vec![add]);
2935        let fused = MarkElementwiseRegions.run(g);
2936        let unfused = UnfuseElementwiseRegions::FOR_CPU.run(fused);
2937        let where_count = unfused
2938            .nodes()
2939            .iter()
2940            .filter(|n| matches!(n.op, Op::Where))
2941            .count();
2942        assert_eq!(
2943            where_count, 1,
2944            "decomposer should re-emit one Op::Where for the chain step"
2945        );
2946    }
2947
2948    /// Synthetic BERT attention block: input [B,S,H] → QKV proj (matmul+bias) →
2949    /// narrow×3 → Attention(mask) → OutProj (matmul+bias) → output [B,S,H].
2950    /// Runs FuseMatMulBiasAct then FuseAttentionBlock and asserts collapse.
2951    #[test]
2952    fn fuse_attention_block_collapses_qkv_attn_outproj() {
2953        let nh: usize = 4;
2954        let dh: usize = 8;
2955        let h: usize = nh * dh; // 32
2956        let b: usize = 1;
2957        let s: usize = 4; // tiny — keep b*s ≤ 64 so should_fuse fires
2958
2959        let mut g = Graph::new("attn-block");
2960        let hidden = g.input("hidden", f32_shape(&[b, s, h]));
2961        let mask = g.input("attention_mask", f32_shape(&[b, s]));
2962
2963        // QKV projection (matmul + bias).
2964        let qkv_w = g.param("qkv_w", f32_shape(&[h, 3 * h]));
2965        let qkv_b = g.param("qkv_b", f32_shape(&[3 * h]));
2966        let qkv_mm = g.matmul(hidden, qkv_w, f32_shape(&[b, s, 3 * h]));
2967        let qkv = g.binary(BinaryOp::Add, qkv_mm, qkv_b, f32_shape(&[b, s, 3 * h]));
2968
2969        // Three narrows on the innermost axis.
2970        let q = g.add_node(
2971            Op::Narrow {
2972                axis: 2,
2973                start: 0,
2974                len: h,
2975            },
2976            vec![qkv],
2977            f32_shape(&[b, s, h]),
2978        );
2979        let k = g.add_node(
2980            Op::Narrow {
2981                axis: 2,
2982                start: h,
2983                len: h,
2984            },
2985            vec![qkv],
2986            f32_shape(&[b, s, h]),
2987        );
2988        let v = g.add_node(
2989            Op::Narrow {
2990                axis: 2,
2991                start: 2 * h,
2992                len: h,
2993            },
2994            vec![qkv],
2995            f32_shape(&[b, s, h]),
2996        );
2997
2998        // Attention with custom (input) mask.
2999        let attn = g.attention(q, k, v, mask, nh, dh, f32_shape(&[b, s, h]));
3000
3001        // OutProj (matmul + bias).
3002        let out_w = g.param("out_w", f32_shape(&[h, h]));
3003        let out_b = g.param("out_b", f32_shape(&[h]));
3004        let out_mm = g.matmul(attn, out_w, f32_shape(&[b, s, h]));
3005        let out = g.binary(BinaryOp::Add, out_mm, out_b, f32_shape(&[b, s, h]));
3006        g.set_outputs(vec![out]);
3007
3008        // Step 1: FuseMatMulBiasAct collapses each matmul+bias into one node.
3009        let fused1 = FuseMatMulBiasAct.run(g);
3010        let mm_bias_count = fused1
3011            .nodes()
3012            .iter()
3013            .filter(|n| matches!(n.op, Op::FusedMatMulBiasAct { activation: None }))
3014            .count();
3015        assert_eq!(mm_bias_count, 2, "QKV + OutProj should each fuse");
3016
3017        // Step 2: FuseAttentionBlock collapses QKV-MM → narrow×3 → Attention → OutProj-MM
3018        // into one FusedAttentionBlock node.
3019        let fused2 = FuseAttentionBlock.run(fused1);
3020        let fab_count = fused2
3021            .nodes()
3022            .iter()
3023            .filter(|n| {
3024                matches!(
3025                    n.op,
3026                    Op::FusedAttentionBlock {
3027                        has_bias: true,
3028                        has_rope: false,
3029                        ..
3030                    }
3031                )
3032            })
3033            .count();
3034        assert_eq!(
3035            fab_count, 1,
3036            "should produce exactly one FusedAttentionBlock"
3037        );
3038
3039        // No stray Narrow / Attention / FusedMatMulBiasAct should remain from
3040        // the collapsed chain.
3041        let narrow_count = fused2
3042            .nodes()
3043            .iter()
3044            .filter(|n| matches!(n.op, Op::Narrow { .. }))
3045            .count();
3046        let attention_count = fused2
3047            .nodes()
3048            .iter()
3049            .filter(|n| matches!(n.op, Op::Attention { .. }))
3050            .count();
3051        let mm_bias_remaining = fused2
3052            .nodes()
3053            .iter()
3054            .filter(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. }))
3055            .count();
3056        assert_eq!(narrow_count, 0, "QKV narrows absorbed");
3057        assert_eq!(attention_count, 0, "Attention absorbed");
3058        assert_eq!(mm_bias_remaining, 0, "both projections absorbed");
3059
3060        let out_node = fused2.node(fused2.outputs[0]);
3061        assert!(matches!(out_node.op, Op::FusedAttentionBlock { .. }));
3062    }
3063
3064    /// Synthetic full BERT layer (one block): hidden → FusedAttentionBlock →
3065    /// FusedResidualLN → FusedMatMulBiasAct(GeLU) → FusedMatMulBiasAct →
3066    /// FusedResidualLN. Confirm FuseTransformerLayer collapses to one node.
3067    #[test]
3068    fn fuse_transformer_layer_collapses_full_bert_block() {
3069        let nh: usize = 4;
3070        let dh: usize = 8;
3071        let h: usize = nh * dh;
3072        let inter = 4 * h;
3073        let eps1: f32 = 1e-12;
3074        let eps2: f32 = 1e-12;
3075        let b: usize = 1;
3076        let s: usize = 4;
3077
3078        let mut g = Graph::new("bert-layer");
3079        let hidden = g.input("hidden", f32_shape(&[b, s, h]));
3080        let mask = g.input("attention_mask", f32_shape(&[b, s]));
3081
3082        // === Attention block ===
3083        let qkv_w = g.param("qkv_w", f32_shape(&[h, 3 * h]));
3084        let qkv_b = g.param("qkv_b", f32_shape(&[3 * h]));
3085        let qkv_mm = g.matmul(hidden, qkv_w, f32_shape(&[b, s, 3 * h]));
3086        let qkv = g.binary(BinaryOp::Add, qkv_mm, qkv_b, f32_shape(&[b, s, 3 * h]));
3087        let q = g.add_node(
3088            Op::Narrow {
3089                axis: 2,
3090                start: 0,
3091                len: h,
3092            },
3093            vec![qkv],
3094            f32_shape(&[b, s, h]),
3095        );
3096        let k = g.add_node(
3097            Op::Narrow {
3098                axis: 2,
3099                start: h,
3100                len: h,
3101            },
3102            vec![qkv],
3103            f32_shape(&[b, s, h]),
3104        );
3105        let v = g.add_node(
3106            Op::Narrow {
3107                axis: 2,
3108                start: 2 * h,
3109                len: h,
3110            },
3111            vec![qkv],
3112            f32_shape(&[b, s, h]),
3113        );
3114        let attn = g.attention(q, k, v, mask, nh, dh, f32_shape(&[b, s, h]));
3115        let out_w = g.param("out_w", f32_shape(&[h, h]));
3116        let out_b = g.param("out_b", f32_shape(&[h]));
3117        let out_mm = g.matmul(attn, out_w, f32_shape(&[b, s, h]));
3118        let attn_out = g.binary(BinaryOp::Add, out_mm, out_b, f32_shape(&[b, s, h]));
3119
3120        // === Post-attn residual + LN ===
3121        let res1 = g.binary(BinaryOp::Add, attn_out, hidden, f32_shape(&[b, s, h]));
3122        let ln1_g = g.param("ln1_g", f32_shape(&[h]));
3123        let ln1_b = g.param("ln1_b", f32_shape(&[h]));
3124        let h1 = g.add_node(
3125            Op::LayerNorm {
3126                axis: -1,
3127                eps: eps1,
3128            },
3129            vec![res1, ln1_g, ln1_b],
3130            f32_shape(&[b, s, h]),
3131        );
3132
3133        // === FFN ===
3134        let fc1_w = g.param("fc1_w", f32_shape(&[h, inter]));
3135        let fc1_b = g.param("fc1_b", f32_shape(&[inter]));
3136        let fc1_mm = g.matmul(h1, fc1_w, f32_shape(&[b, s, inter]));
3137        let fc1_add = g.binary(BinaryOp::Add, fc1_mm, fc1_b, f32_shape(&[b, s, inter]));
3138        let fc1_act = g.activation(Activation::Gelu, fc1_add, f32_shape(&[b, s, inter]));
3139        let fc2_w = g.param("fc2_w", f32_shape(&[inter, h]));
3140        let fc2_b = g.param("fc2_b", f32_shape(&[h]));
3141        let fc2_mm = g.matmul(fc1_act, fc2_w, f32_shape(&[b, s, h]));
3142        let ffn_out = g.binary(BinaryOp::Add, fc2_mm, fc2_b, f32_shape(&[b, s, h]));
3143
3144        // === Post-FFN residual + LN ===
3145        let res2 = g.binary(BinaryOp::Add, ffn_out, h1, f32_shape(&[b, s, h]));
3146        let ln2_g = g.param("ln2_g", f32_shape(&[h]));
3147        let ln2_b = g.param("ln2_b", f32_shape(&[h]));
3148        let out = g.add_node(
3149            Op::LayerNorm {
3150                axis: -1,
3151                eps: eps2,
3152            },
3153            vec![res2, ln2_g, ln2_b],
3154            f32_shape(&[b, s, h]),
3155        );
3156        g.set_outputs(vec![out]);
3157
3158        // Run the same pipeline order the production pipeline uses.
3159        let g = FuseMatMulBiasAct.run(g);
3160        let g = FuseResidualLN.run(g);
3161        let g = FuseAttentionBlock.run(g);
3162        let g = FuseTransformerLayer.run(g);
3163
3164        let ftl_count = g
3165            .nodes()
3166            .iter()
3167            .filter(|n| matches!(n.op, Op::FusedTransformerLayer { .. }))
3168            .count();
3169        assert_eq!(
3170            ftl_count, 1,
3171            "single layer should collapse to one FusedTransformerLayer"
3172        );
3173
3174        // After the full pipeline, the layer's intermediate fused ops should
3175        // be gone — only the parameter / input nodes and the single
3176        // FusedTransformerLayer remain.
3177        let leftover_fab = g
3178            .nodes()
3179            .iter()
3180            .filter(|n| matches!(n.op, Op::FusedAttentionBlock { .. }))
3181            .count();
3182        let leftover_frln = g
3183            .nodes()
3184            .iter()
3185            .filter(|n| matches!(n.op, Op::FusedResidualLN { .. }))
3186            .count();
3187        let leftover_fmba = g
3188            .nodes()
3189            .iter()
3190            .filter(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. }))
3191            .count();
3192        assert_eq!(leftover_fab, 0, "attn block absorbed into layer");
3193        assert_eq!(leftover_frln, 0, "both residual+LNs absorbed");
3194        assert_eq!(leftover_fmba, 0, "FFN matmuls absorbed");
3195
3196        let out_node = g.node(g.outputs[0]);
3197        assert!(matches!(
3198            out_node.op,
3199            Op::FusedTransformerLayer {
3200                num_heads: 4,
3201                head_dim: 8,
3202                intermediate_size: 128,
3203                has_bias: true,
3204                ..
3205            }
3206        ));
3207        assert_eq!(out_node.inputs.len(), 14);
3208    }
3209
3210    /// `should_fuse` must reject the pass when batch·seq exceeds the threshold,
3211    /// so attention block fusion stays opt-in for small inputs.
3212    #[test]
3213    fn fuse_attention_block_skips_large_inputs() {
3214        let nh: usize = 4;
3215        let dh: usize = 8;
3216        let h: usize = nh * dh;
3217        let b: usize = 16;
3218        let s: usize = 128; // b*s = 2048 ≫ 64 default threshold
3219
3220        let mut g = Graph::new("attn-block-large");
3221        let hidden = g.input("hidden", f32_shape(&[b, s, h]));
3222        let mask = g.input("attention_mask", f32_shape(&[b, s]));
3223        let qkv_w = g.param("qkv_w", f32_shape(&[h, 3 * h]));
3224        let qkv_b = g.param("qkv_b", f32_shape(&[3 * h]));
3225        let qkv_mm = g.matmul(hidden, qkv_w, f32_shape(&[b, s, 3 * h]));
3226        let qkv = g.binary(BinaryOp::Add, qkv_mm, qkv_b, f32_shape(&[b, s, 3 * h]));
3227        let q = g.add_node(
3228            Op::Narrow {
3229                axis: 2,
3230                start: 0,
3231                len: h,
3232            },
3233            vec![qkv],
3234            f32_shape(&[b, s, h]),
3235        );
3236        let k = g.add_node(
3237            Op::Narrow {
3238                axis: 2,
3239                start: h,
3240                len: h,
3241            },
3242            vec![qkv],
3243            f32_shape(&[b, s, h]),
3244        );
3245        let v = g.add_node(
3246            Op::Narrow {
3247                axis: 2,
3248                start: 2 * h,
3249                len: h,
3250            },
3251            vec![qkv],
3252            f32_shape(&[b, s, h]),
3253        );
3254        let attn = g.attention(q, k, v, mask, nh, dh, f32_shape(&[b, s, h]));
3255        let out_w = g.param("out_w", f32_shape(&[h, h]));
3256        let out_b = g.param("out_b", f32_shape(&[h]));
3257        let out_mm = g.matmul(attn, out_w, f32_shape(&[b, s, h]));
3258        let out = g.binary(BinaryOp::Add, out_mm, out_b, f32_shape(&[b, s, h]));
3259        g.set_outputs(vec![out]);
3260
3261        let fused1 = FuseMatMulBiasAct.run(g);
3262        let fused2 = FuseAttentionBlock.run(fused1);
3263        let fab_count = fused2
3264            .nodes()
3265            .iter()
3266            .filter(|n| matches!(n.op, Op::FusedAttentionBlock { .. }))
3267            .count();
3268        assert_eq!(fab_count, 0, "block-fusion must skip large batches");
3269    }
3270}