Skip to main content

rlx_fusion/
fusion.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Fusion passes — pattern-match and replace subgraphs with fused ops.
17//!
//! Each pass walks `graph.nodes()` front to back, looking for specific
//! multi-node patterns and replacing them with single fused nodes.
//! These are the same fusions we hand-coded in burnembed's ndarray_fused.rs.
21
use crate::pass::Pass;
use rlx_ir::op::*;
use rlx_ir::*;
use std::collections::{HashMap, HashSet};
26
27// ── Helper: graph rewriter ──────────────────────────────────────────────
28
29/// Maps old NodeIds to new NodeIds during graph rewriting.
30struct Rewriter {
31    new_graph: Graph,
32    id_map: HashMap<NodeId, NodeId>,
33}
34
35impl Rewriter {
36    fn new(name: &str) -> Self {
37        Self {
38            new_graph: Graph::new(name),
39            id_map: HashMap::new(),
40        }
41    }
42
43    /// Map an old NodeId to its new equivalent.
44    fn map(&self, old: NodeId) -> NodeId {
45        self.id_map[&old]
46    }
47
48    /// Map a list of old NodeIds.
49    fn map_inputs(&self, old_inputs: &[NodeId]) -> Vec<NodeId> {
50        old_inputs.iter().map(|id| self.map(*id)).collect()
51    }
52
53    /// True iff every old NodeId in `ids` has already been mapped — used by fusion
54    /// patterns to gate a rewrite on its inputs being live in the new graph.
55    #[allow(dead_code)]
56    fn all_mapped(&self, ids: &[NodeId]) -> bool {
57        ids.iter().all(|id| self.id_map.contains_key(id))
58    }
59
60    /// Copy any not-yet-mapped nodes from `old` so fusion rewrites can
61    /// reference operands declared later in the source graph (e.g. a bias
62    /// param appended after its matmul consumer, or a reshape input that
63    /// has not been reached in the linear rewrite walk yet).
64    fn ensure_mapped(&mut self, old: &Graph, ids: &[NodeId]) {
65        for &id in ids {
66            if self.id_map.contains_key(&id) {
67                continue;
68            }
69            let node = old.node(id);
70            if !node.inputs.is_empty() {
71                self.ensure_mapped(old, &node.inputs);
72            }
73            self.copy_node(node);
74        }
75    }
76
77    /// Copy a node from the old graph, remapping inputs.
78    fn copy_node(&mut self, node: &Node) -> NodeId {
79        let new_inputs = self.map_inputs(&node.inputs);
80        let new_id = self
81            .new_graph
82            .add_node(node.op.clone(), new_inputs, node.shape.clone());
83        let new_node = self.new_graph.node_mut(new_id);
84        new_node.name = node.name.clone();
85        new_node.origin = node.origin.clone();
86        self.id_map.insert(node.id, new_id);
87        new_id
88    }
89
90    /// Add a new fused node (not from the old graph).
91    fn add_fused(&mut self, op: Op, old_inputs: &[NodeId], shape: Shape) -> NodeId {
92        let new_inputs: Vec<NodeId> = old_inputs.iter().map(|id| self.map(*id)).collect();
93        self.new_graph.add_node(op, new_inputs, shape)
94    }
95
96    /// Mark an old node as replaced by a new node.
97    fn replace(&mut self, old_id: NodeId, new_id: NodeId) {
98        self.id_map.insert(old_id, new_id);
99    }
100
101    fn finish(mut self, old_outputs: &[NodeId]) -> Graph {
102        let new_outputs = old_outputs.iter().map(|id| self.map(*id)).collect();
103        self.new_graph.set_outputs(new_outputs);
104        self.new_graph
105    }
106}
107
108// ── Pass 1: MatMul + Bias + Activation → FusedMatMulBiasAct ─────────────
109
/// Collapses `matmul → add(bias) → activation` into one FusedMatMulBiasAct node.
///
/// Highest-payoff fusion in this file: two intermediate tensors disappear and
/// three memory passes (matmul write, bias read+write, activation read+write)
/// shrink to a single matmul write with bias and activation applied inline.
///
/// A bare `matmul → add(bias)` with no trailing activation is fused as well.
///
/// An activation is folded into the epilogue only when every backend can
/// apply it inline with the matmul (today: Gelu and Silu). Anything else —
/// e.g. the Exp in qwen35 softplus — is left as a standalone op so that Metal
/// does not silently drop the epilogue.
pub struct FuseMatMulBiasAct;
123
124/// Activations that may be folded into `FusedMatMulBiasAct` epilogues.
125fn fusible_mm_bias_epilogue_activation(act: Activation) -> bool {
126    matches!(act, Activation::Gelu | Activation::Silu)
127}
128
129impl Pass for FuseMatMulBiasAct {
130    fn name(&self) -> &str {
131        "fuse_matmul_bias_act"
132    }
133
134    fn run(&self, graph: Graph) -> Graph {
135        let mut rw = Rewriter::new(&graph.name);
136        // Track which nodes are consumed by fusion (skip them in copy)
137        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
138
139        // Forward pass: copy nodes, detect patterns
140        for node in graph.nodes() {
141            if fused_away.contains_key(&node.id) {
142                continue;
143            }
144
145            // Pattern: MatMul → Add(bias) → Activation
146            // or:      MatMul → Add(bias)
147            if matches!(node.op, Op::MatMul) {
148                let mm_id = node.id;
149                let mm_users: Vec<_> = graph.users(mm_id);
150
151                // Check for single-use Add(bias) consumer
152                if mm_users.len() == 1 {
153                    let add_node = graph.node(mm_users[0]);
154                    if let Op::Binary(BinaryOp::Add) = &add_node.op {
155                        // Determine which input is the bias (the non-matmul one)
156                        let (bias_id, _mm_input) = if add_node.inputs[0] == mm_id {
157                            (add_node.inputs[1], add_node.inputs[0])
158                        } else {
159                            (add_node.inputs[0], add_node.inputs[1])
160                        };
161
162                        // Check if bias is a param/const with broadcastable shape
163                        let bias_shape = graph.shape(bias_id);
164                        if bias_shape.rank() <= 1 {
165                            let add_id = add_node.id;
166                            let add_users = graph.users(add_id);
167
168                            // Check for activation consumer
169                            let mut activation = None;
170                            let mut act_id = None;
171                            if add_users.len() == 1 {
172                                let act_node = graph.node(add_users[0]);
173                                if let Op::Activation(a) = &act_node.op
174                                    && fusible_mm_bias_epilogue_activation(*a)
175                                {
176                                    activation = Some(*a);
177                                    act_id = Some(act_node.id);
178                                }
179                            }
180
181                            // Emit fused node. Bias may be declared after
182                            // the matmul in the source graph — copy it early
183                            // instead of requiring builders to order params first.
184                            let out_shape = if let Some(aid) = act_id {
185                                graph.shape(aid).clone()
186                            } else {
187                                add_node.shape.clone()
188                            };
189
190                            rw.ensure_mapped(&graph, &[node.inputs[0], node.inputs[1], bias_id]);
191                            let fused_id = rw.add_fused(
192                                Op::FusedMatMulBiasAct { activation },
193                                &[node.inputs[0], node.inputs[1], bias_id],
194                                out_shape,
195                            );
196
197                            // Map old nodes to the fused result
198                            rw.replace(mm_id, fused_id);
199                            rw.replace(add_id, fused_id);
200                            fused_away.insert(add_id, ());
201                            if let Some(aid) = act_id {
202                                rw.replace(aid, fused_id);
203                                fused_away.insert(aid, ());
204                            }
205                            continue;
206                        }
207                    }
208                }
209            }
210
211            // No fusion — copy as-is
212            rw.copy_node(node);
213        }
214
215        rw.finish(&graph.outputs)
216    }
217}
218
219// ── Pass 2: Add(residual) + LayerNorm → FusedResidualLN ─────────────────
220
221/// Fuses `add(x, residual) → layer_norm` into FusedResidualLN.
222///
223/// Also detects `add(x, residual) → add(bias) → layer_norm` for the
224/// bias variant (used in BERT's output projection).
225pub struct FuseResidualLN;
226
227impl Pass for FuseResidualLN {
228    fn name(&self) -> &str {
229        "fuse_residual_ln"
230    }
231
232    fn run(&self, graph: Graph) -> Graph {
233        // Graph outputs hold implicit references to their producing
234        // nodes that don't show up in any node's `inputs` (use_count
235        // walks node inputs only). Treat being-a-graph-output as a
236        // use so we don't fuse-away an intermediate the caller still
237        // wants to read — this used to silently corrupt multi-block
238        // encoders (e.g. SAM 2 stage outputs) by collapsing the
239        // residual add of block N into block N+1's LN.
240        let mut is_output: HashMap<NodeId, ()> = HashMap::new();
241        for &oid in &graph.outputs {
242            is_output.insert(oid, ());
243        }
244        // Pre-scan: find all Add nodes consumed by LayerNorm
245        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
246        for node in graph.nodes() {
247            if let Op::LayerNorm { .. } = &node.op {
248                let ln_input_id = node.inputs[0];
249                let ln_input = graph.node(ln_input_id);
250                if matches!(ln_input.op, Op::Binary(BinaryOp::Add))
251                    && graph.use_count(ln_input_id) == 1
252                    && !is_output.contains_key(&ln_input_id)
253                {
254                    fused_away.insert(ln_input_id, ());
255                }
256            }
257        }
258
259        let mut rw = Rewriter::new(&graph.name);
260
261        for node in graph.nodes() {
262            if fused_away.contains_key(&node.id) {
263                continue;
264            }
265
266            if let Op::LayerNorm { eps, .. } = &node.op {
267                let ln_input_id = node.inputs[0];
268                let ln_input = graph.node(ln_input_id);
269
270                if matches!(ln_input.op, Op::Binary(BinaryOp::Add))
271                    && fused_away.contains_key(&ln_input_id)
272                {
273                    let (x_id, residual_id) = (ln_input.inputs[0], ln_input.inputs[1]);
274                    let gamma_id = node.inputs[1];
275                    let beta_id = node.inputs[2];
276
277                    let fused_id = rw.add_fused(
278                        Op::FusedResidualLN {
279                            has_bias: false,
280                            eps: *eps,
281                        },
282                        &[x_id, residual_id, gamma_id, beta_id],
283                        node.shape.clone(),
284                    );
285
286                    rw.replace(ln_input_id, fused_id);
287                    rw.replace(node.id, fused_id);
288                    continue;
289                }
290            }
291
292            rw.copy_node(node);
293        }
294
295        rw.finish(&graph.outputs)
296    }
297}
298
299// ── Pass 2b: Add(residual) + RmsNorm → FusedResidualRmsNorm ─────────────
300
301/// Fuses `add(x, residual) → rms_norm` into [`Op::FusedResidualRmsNorm`].
302pub struct FuseResidualRmsNorm;
303
304impl Pass for FuseResidualRmsNorm {
305    fn name(&self) -> &str {
306        "fuse_residual_rms_norm"
307    }
308
309    fn run(&self, graph: Graph) -> Graph {
310        let mut is_output: HashMap<NodeId, ()> = HashMap::new();
311        for &oid in &graph.outputs {
312            is_output.insert(oid, ());
313        }
314        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
315        for node in graph.nodes() {
316            if let Op::RmsNorm { .. } = &node.op {
317                let rn_input_id = node.inputs[0];
318                let rn_input = graph.node(rn_input_id);
319                if matches!(rn_input.op, Op::Binary(BinaryOp::Add))
320                    && graph.use_count(rn_input_id) == 1
321                    && !is_output.contains_key(&rn_input_id)
322                {
323                    fused_away.insert(rn_input_id, ());
324                }
325            }
326        }
327
328        let mut rw = Rewriter::new(&graph.name);
329
330        for node in graph.nodes() {
331            if fused_away.contains_key(&node.id) {
332                continue;
333            }
334
335            if let Op::RmsNorm { eps, .. } = &node.op {
336                let rn_input_id = node.inputs[0];
337                let rn_input = graph.node(rn_input_id);
338
339                if matches!(rn_input.op, Op::Binary(BinaryOp::Add))
340                    && fused_away.contains_key(&rn_input_id)
341                {
342                    let (x_id, residual_id) = (rn_input.inputs[0], rn_input.inputs[1]);
343                    let gamma_id = node.inputs[1];
344                    let beta_id = node.inputs[2];
345
346                    let fused_id = rw.add_fused(
347                        Op::FusedResidualRmsNorm {
348                            has_bias: false,
349                            eps: *eps,
350                        },
351                        &[x_id, residual_id, gamma_id, beta_id],
352                        node.shape.clone(),
353                    );
354
355                    rw.replace(rn_input_id, fused_id);
356                    rw.replace(node.id, fused_id);
357                    continue;
358                }
359            }
360
361            rw.copy_node(node);
362        }
363
364        rw.finish(&graph.outputs)
365    }
366}
367
368// ── Pass 2c: RmsNorm → Reshape(leading flatten) ─────────────────────────
369
/// Folds `rms_norm([…, H]) → reshape([∏leading, H])` into one `RmsNorm`
/// node that directly carries the flattened output shape, so the reshape's
/// memcpy goes away.
///
/// This is the Qwen3.5 pre-norm pattern: normalized activations are reshaped
/// to 2-D right away to feed a matmul.
pub struct FuseRmsNormReshape;
376
377fn leading_flatten_shape(in_shape: &Shape, new_shape: &[i64]) -> Option<Shape> {
378    rlx_ir::shape::leading_flatten_shape(in_shape, new_shape)
379}
380
381fn sole_consumer(graph: &Graph, id: NodeId) -> Option<NodeId> {
382    graph
383        .nodes()
384        .iter()
385        .find(|n| n.inputs.contains(&id))
386        .map(|n| n.id)
387}
388
389impl Pass for FuseRmsNormReshape {
390    fn name(&self) -> &str {
391        "fuse_rms_norm_reshape"
392    }
393
394    fn run(&self, graph: Graph) -> Graph {
395        let mut is_output: HashMap<NodeId, ()> = HashMap::new();
396        for &oid in &graph.outputs {
397            is_output.insert(oid, ());
398        }
399
400        let mut flat_shape: HashMap<NodeId, Shape> = HashMap::new();
401        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
402        for node in graph.nodes() {
403            if let Op::RmsNorm { .. } = &node.op {
404                if graph.use_count(node.id) != 1 || is_output.contains_key(&node.id) {
405                    continue;
406                }
407                let Some(reshape_id) = sole_consumer(&graph, node.id) else {
408                    continue;
409                };
410                if is_output.contains_key(&reshape_id) {
411                    continue;
412                }
413                let reshape = graph.node(reshape_id);
414                if let Op::Reshape { new_shape } = &reshape.op {
415                    if let Some(flat) = leading_flatten_shape(&node.shape, new_shape) {
416                        flat_shape.insert(node.id, flat);
417                        fused_away.insert(reshape_id, ());
418                    }
419                }
420            }
421        }
422
423        let mut rw = Rewriter::new(&graph.name);
424
425        for node in graph.nodes() {
426            if fused_away.contains_key(&node.id) {
427                continue;
428            }
429
430            if let Op::RmsNorm { axis, eps, .. } = &node.op {
431                if let Some(flat) = flat_shape.get(&node.id) {
432                    let Some(reshape_id) = sole_consumer(&graph, node.id) else {
433                        rw.copy_node(node);
434                        continue;
435                    };
436                    let fused_id = rw.add_fused(
437                        Op::RmsNorm {
438                            axis: *axis,
439                            eps: *eps,
440                        },
441                        &node.inputs,
442                        flat.clone(),
443                    );
444                    rw.replace(node.id, fused_id);
445                    rw.replace(reshape_id, fused_id);
446                    continue;
447                }
448            }
449
450            rw.copy_node(node);
451        }
452
453        rw.finish(&graph.outputs)
454    }
455}
456
457// ── Pass 3b: Dual MatMul SwiGLU (gate+up before shared-input concat) ─────
458
459/// Fuses the common LLM FFN pattern in one rewrite:
460///   gate = matmul(x, wg); up = matmul(x, wu); out = mul(silu(gate), up)
461///
462/// Becomes:
463///   cat = matmul(x, concat(wu, wg))   // up weights first for kernel layout
464///   out = fused_swiglu(cat)
465///
466/// Eliminates two `[..., N]` matmul outputs plus a silu buffer — the
467/// largest memory win on transformer FFN blocks.
468pub struct FuseSwiGLUDualMatmul;
469
470impl FuseSwiGLUDualMatmul {
471    fn match_dual_swiglu(
472        graph: &Graph,
473        mul_node: &Node,
474    ) -> Option<(NodeId, NodeId, NodeId, NodeId, NodeId)> {
475        if !matches!(mul_node.op, Op::Binary(BinaryOp::Mul)) {
476            return None;
477        }
478        let lhs = graph.node(mul_node.inputs[0]);
479        let rhs = graph.node(mul_node.inputs[1]);
480        let (up_mm, silu_id, silu_node) = if matches!(rhs.op, Op::Activation(Activation::Silu)) {
481            (lhs, mul_node.inputs[1], rhs)
482        } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
483            (rhs, mul_node.inputs[0], lhs)
484        } else {
485            return None;
486        };
487        if !matches!(up_mm.op, Op::MatMul) {
488            return None;
489        }
490        let gate_mm = graph.node(silu_node.inputs[0]);
491        if !matches!(gate_mm.op, Op::MatMul) {
492            return None;
493        }
494        if up_mm.inputs[0] != gate_mm.inputs[0] {
495            return None;
496        }
497        if graph.use_count(silu_id) != 1 {
498            return None;
499        }
500        Some((mul_node.id, gate_mm.id, up_mm.id, up_mm.inputs[0], silu_id))
501    }
502}
503
504impl Pass for FuseSwiGLUDualMatmul {
505    fn name(&self) -> &str {
506        "fuse_swiglu_dual_matmul"
507    }
508
509    fn run(&self, graph: Graph) -> Graph {
510        let mut matches: Vec<(NodeId, NodeId, NodeId, NodeId, NodeId)> = Vec::new();
511        let mut consumed: HashMap<NodeId, ()> = HashMap::new();
512
513        for node in graph.nodes() {
514            if let Some((mul_id, gate_mm, up_mm, _, silu_id)) =
515                Self::match_dual_swiglu(&graph, node)
516            {
517                matches.push((mul_id, gate_mm, up_mm, graph.node(up_mm).inputs[0], silu_id));
518                consumed.insert(gate_mm, ());
519                consumed.insert(up_mm, ());
520                consumed.insert(silu_id, ());
521            }
522        }
523
524        if matches.is_empty() {
525            return graph;
526        }
527
528        let match_by_mul: HashMap<NodeId, (NodeId, NodeId, NodeId)> = matches
529            .into_iter()
530            .map(|(mul, gate, up, input, _silu)| (mul, (gate, up, input)))
531            .collect();
532
533        let mut rw = Rewriter::new(&graph.name);
534        for node in graph.nodes() {
535            if consumed.contains_key(&node.id) {
536                continue;
537            }
538            if let Some(&(gate_mm, up_mm, input_id)) = match_by_mul.get(&node.id) {
539                let gate = graph.node(gate_mm);
540                let up = graph.node(up_mm);
541                let wg = gate.inputs[1];
542                let wu = up.inputs[1];
543                rw.ensure_mapped(&graph, &[input_id, wg, wu]);
544
545                let wu_shape = graph.shape(wu);
546                let wg_shape = graph.shape(wg);
547                let k = wu_shape.dim(0).unwrap_static();
548                let n_up = wu_shape.dim(1).unwrap_static();
549                let n_gate = wg_shape.dim(1).unwrap_static();
550                debug_assert_eq!(wu_shape.dim(0), wg_shape.dim(0));
551
552                // Up weights first → canonical FusedSwiGLU layout (gate_first=false).
553                let concat_shape = Shape::new(&[k, n_up + n_gate], wu_shape.dtype());
554                let concat_w = rw.add_fused(Op::Concat { axis: 1 }, &[wu, wg], concat_shape);
555
556                let out_rank = up.shape.rank();
557                let mut mm_dims: Vec<Dim> = (0..out_rank).map(|i| up.shape.dim(i)).collect();
558                mm_dims[out_rank - 1] = Dim::Static(n_up + n_gate);
559                let cat_shape = Shape::from_dims(&mm_dims, up.shape.dtype());
560                let cat_id =
561                    rw.new_graph
562                        .add_node(Op::MatMul, vec![rw.map(input_id), concat_w], cat_shape);
563
564                let fused_id = rw.new_graph.add_node(
565                    Op::FusedSwiGLU {
566                        cast_to: None,
567                        gate_first: false,
568                    },
569                    vec![cat_id],
570                    node.shape.clone(),
571                );
572                rw.replace(node.id, fused_id);
573                continue;
574            }
575            rw.copy_node(node);
576        }
577        rw.finish(&graph.outputs)
578    }
579}
580
581// ── Pass 3: Shared-input MatMul concat (QKV, SwiGLU fc11+fc12) ──────────
582
583/// Detects two MatMul nodes with the same input and concatenates their
584/// weight matrices into a single larger MatMul.
585///
586/// Pattern:
587///   %a = matmul(%x, %w1)
588///   %b = matmul(%x, %w2)
589/// Becomes:
590///   %ab = matmul(%x, concat(%w1, %w2))
591///   %a = narrow(%ab, ..., 0, n1)
592///   %b = narrow(%ab, ..., n1, n2)
593///
594/// This saves one full input read (the shared input is read once instead
595/// of twice). Critical for SwiGLU (fc11+fc12) and QKV fusion.
596pub struct FuseSharedInputMatMul;
597
598impl Pass for FuseSharedInputMatMul {
599    fn name(&self) -> &str {
600        "fuse_shared_input_matmul"
601    }
602
603    fn run(&self, graph: Graph) -> Graph {
604        struct FuseGroup {
605            input_id: NodeId,
606            matmul_ids: Vec<NodeId>,
607        }
608
609        let mut input_to_matmuls: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
610        for node in graph.nodes() {
611            if matches!(node.op, Op::MatMul) {
612                input_to_matmuls
613                    .entry(node.inputs[0])
614                    .or_default()
615                    .push(node.id);
616            }
617        }
618
619        let mut groups: Vec<FuseGroup> = Vec::new();
620        for (input_id, matmul_ids) in input_to_matmuls {
621            if matmul_ids.len() < 2 {
622                continue;
623            }
624            let first = graph.node(matmul_ids[0]);
625            let w0 = graph.shape(first.inputs[1]);
626            if w0.rank() != 2 {
627                continue;
628            }
629            let compatible = matmul_ids.iter().all(|&id| {
630                let m = graph.node(id);
631                matches!(m.op, Op::MatMul)
632                    && graph.shape(m.inputs[1]).rank() == 2
633                    && graph.shape(m.inputs[1]).dim(0) == w0.dim(0)
634            });
635            if compatible {
636                groups.push(FuseGroup {
637                    input_id,
638                    matmul_ids,
639                });
640            }
641        }
642
643        if groups.is_empty() {
644            return graph;
645        }
646
647        let group_by_first: HashMap<NodeId, &FuseGroup> =
648            groups.iter().map(|g| (g.matmul_ids[0], g)).collect();
649
650        let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
651        for g in &groups {
652            for &id in &g.matmul_ids[1..] {
653                fused_away.insert(id, ());
654            }
655        }
656
657        let mut rw = Rewriter::new(&graph.name);
658        for node in graph.nodes() {
659            if fused_away.contains_key(&node.id) {
660                continue;
661            }
662
663            if let Some(group) = group_by_first.get(&node.id) {
664                let matmuls: Vec<_> = group.matmul_ids.iter().map(|&id| graph.node(id)).collect();
665                let weight_ids: Vec<NodeId> = matmuls.iter().map(|m| m.inputs[1]).collect();
666                rw.ensure_mapped(&graph, std::slice::from_ref(&group.input_id));
667                rw.ensure_mapped(&graph, &weight_ids);
668
669                let w0_shape = graph.shape(weight_ids[0]);
670                let k = w0_shape.dim(0).unwrap_static();
671                let ns: Vec<usize> = weight_ids
672                    .iter()
673                    .map(|&w| graph.shape(w).dim(1).unwrap_static())
674                    .collect();
675                let combined_n: usize = ns.iter().sum();
676
677                let concat_shape = Shape::new(&[k, combined_n], w0_shape.dtype());
678                let concat_id = rw.add_fused(Op::Concat { axis: 1 }, &weight_ids, concat_shape);
679
680                let out_rank = matmuls[0].shape.rank();
681                let mut mm_dims: Vec<Dim> =
682                    (0..out_rank).map(|i| matmuls[0].shape.dim(i)).collect();
683                mm_dims[out_rank - 1] = Dim::Static(combined_n);
684                let mm_shape = Shape::from_dims(&mm_dims, matmuls[0].shape.dtype());
685                let mm_id = rw.new_graph.add_node(
686                    Op::MatMul,
687                    vec![rw.map(group.input_id), concat_id],
688                    mm_shape,
689                );
690
691                let mut start = 0usize;
692                for (mm, &n) in matmuls.iter().zip(&ns) {
693                    let narrow = rw.new_graph.add_node(
694                        Op::Narrow {
695                            axis: out_rank - 1,
696                            start,
697                            len: n,
698                        },
699                        vec![mm_id],
700                        mm.shape.clone(),
701                    );
702                    rw.replace(mm.id, narrow);
703                    start += n;
704                }
705                continue;
706            }
707
708            rw.copy_node(node);
709        }
710
711        rw.finish(&graph.outputs)
712    }
713}
714
715// ── Pass 4: Detect SwiGLU pattern → FusedSwiGLU ────────────────────────
716
/// Recognizes the SwiGLU shape left behind by `FuseSharedInputMatMul` and
/// swaps it for one `Op::FusedSwiGLU` node fed by the concatenated matmul.
///
/// Input pattern (fc11+fc12 already merged into one matmul):
///   %cat   = matmul(%x, concat(%fc11_w, %fc12_w))   ; shape [..., 2N]
///   %up    = narrow(%cat, axis=-1, 0, N)            ; shape [..., N]
///   %gate  = narrow(%cat, axis=-1, N, N)            ; shape [..., N]
///   %silu  = silu(%gate)
///   %out   = mul(%up, %silu)
///
/// Result:
///   %out   = fused_swiglu(%cat)
///
/// The mirrored layout (gate in the first half, up in the second) is matched
/// too and recorded as `gate_first`.
///
/// Four ops (two narrows, silu, mul) collapse into one kernel — three fewer
/// launches — and up/gate stay resident in registers.
///
/// Single-use guard: the two narrows and the silu must each have exactly one
/// consumer for the fusion to apply. The mul may have any number of consumers.
pub struct FuseSwiGLU;
736
737impl Pass for FuseSwiGLU {
738    fn name(&self) -> &str {
739        "fuse_swiglu"
740    }
741
742    fn run(&self, graph: Graph) -> Graph {
743        // Scan for Mul nodes whose two inputs match the SwiGLU pattern.
744        // Collect rewrites first, then rebuild.
745        // up_narrow_id / silu_id / gate_narrow_id are kept for pattern-shape
746        // self-documentation even though only the rewrite path reads
747        // mul_id / cat_id / out_n.
748        #[allow(dead_code)]
749        struct Match {
750            mul_id: NodeId,
751            up_narrow_id: NodeId,
752            silu_id: NodeId,
753            gate_narrow_id: NodeId,
754            cat_id: NodeId,
755            out_n: usize,
756            gate_first: bool,
757        }
758
759        let mut matches: Vec<Match> = Vec::new();
760        let mut consumed: HashMap<NodeId, ()> = HashMap::new();
761
762        for node in graph.nodes() {
763            // Looking for: mul(narrow(cat, 0, n), silu(narrow(cat, n, n)))
764            //   — or symmetrically with up/gate swapped.
765            if !matches!(node.op, Op::Binary(BinaryOp::Mul)) {
766                continue;
767            }
768            let lhs_id = node.inputs[0];
769            let rhs_id = node.inputs[1];
770            let lhs = graph.node(lhs_id);
771            let rhs = graph.node(rhs_id);
772
773            // Decide which side is silu(gate) — the silu branch.
774            let (up_narrow, silu_id, silu_node) =
775                if matches!(rhs.op, Op::Activation(Activation::Silu)) {
776                    (lhs, rhs_id, rhs)
777                } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
778                    (rhs, lhs_id, lhs)
779                } else {
780                    continue;
781                };
782
783            // up side must be a Narrow.
784            let (up_axis, up_start, up_len) = match &up_narrow.op {
785                Op::Narrow { axis, start, len } => (*axis, *start, *len),
786                _ => continue,
787            };
788            // silu input must be a Narrow.
789            let gate_narrow_id = silu_node.inputs[0];
790            let gate_narrow = graph.node(gate_narrow_id);
791            let (g_axis, g_start, g_len) = match &gate_narrow.op {
792                Op::Narrow { axis, start, len } => (*axis, *start, *len),
793                _ => continue,
794            };
795
796            // Both narrows must come from the same source on the same axis,
797            // covering the two halves: (0..N) and (N..2N).
798            if up_narrow.inputs[0] != gate_narrow.inputs[0] {
799                continue;
800            }
801            if up_axis != g_axis {
802                continue;
803            }
804            if up_len != g_len {
805                continue;
806            }
807            let n = up_len;
808            // Canonical: up @ 0, gate @ N. Swapped (gate-first builders): gate @ 0, up @ N.
809            let gate_first = up_start == n && g_start == 0;
810            if !(gate_first || (up_start == 0 && g_start == n)) {
811                continue;
812            }
813
814            // Single-use checks: narrows feed only into silu+mul, silu feeds
815            // only into mul. The cat itself can have arbitrary other users.
816            if graph.use_count(up_narrow.id) != 1 {
817                continue;
818            }
819            if graph.use_count(gate_narrow_id) != 1 {
820                continue;
821            }
822            if graph.use_count(silu_id) != 1 {
823                continue;
824            }
825
826            matches.push(Match {
827                mul_id: node.id,
828                up_narrow_id: up_narrow.id,
829                silu_id,
830                gate_narrow_id,
831                cat_id: up_narrow.inputs[0],
832                out_n: n,
833                gate_first,
834            });
835            consumed.insert(up_narrow.id, ());
836            consumed.insert(gate_narrow_id, ());
837            consumed.insert(silu_id, ());
838        }
839
840        if matches.is_empty() {
841            return graph;
842        }
843
844        // Rebuild graph, replacing matched mul nodes with FusedSwiGLU.
845        let mut rw = Rewriter::new(&graph.name);
846        let match_by_mul: HashMap<NodeId, &Match> = matches.iter().map(|m| (m.mul_id, m)).collect();
847
848        for node in graph.nodes() {
849            if consumed.contains_key(&node.id) {
850                continue;
851            }
852
853            if let Some(m) = match_by_mul.get(&node.id) {
854                // Output shape = mul's output shape (= [..., N]).
855                let out_shape = node.shape.clone();
856                debug_assert_eq!(
857                    out_shape.dim(out_shape.rank() - 1).unwrap_static(),
858                    m.out_n,
859                    "FuseSwiGLU: output last dim should be N"
860                );
861                let fused_id = rw.add_fused(
862                    Op::FusedSwiGLU {
863                        cast_to: None,
864                        gate_first: m.gate_first,
865                    },
866                    &[m.cat_id],
867                    out_shape,
868                );
869                rw.replace(node.id, fused_id);
870                continue;
871            }
872
873            rw.copy_node(node);
874        }
875
876        rw.finish(&graph.outputs)
877    }
878}
879
880// ── Pass 5: Fuse Attention Block (QKV → SDPA → OutProj) ────────────────
881
882/// Fuses `matmul(QKV) → narrow(Q,K,V) → [rope] → attention → matmul(out)`
883/// into a single FusedAttentionBlock when batch*seq is small.
884///
885/// The optimizer auto-detects batch size from graph input shapes. For small
886/// inputs (batch*seq ≤ 64), intermediate tensors fit in L1 cache, making a
887/// monolithic kernel faster than separate BLAS calls.
888///
889/// Threshold is configurable via `RLX_FUSE_ATTN_THRESHOLD` (default: 64).
890pub struct FuseAttentionBlock;
891
892impl FuseAttentionBlock {
893    /// Check if the graph has small enough inputs to benefit from fusion.
894    /// Currently unused — `Pass::run` is a no-op since attention fusion
895    /// happens at thunk-compile time, not graph-rewrite time. Kept here
896    /// for the planned graph-level rewrite path.
897    #[allow(dead_code)]
898    fn should_fuse(graph: &Graph) -> bool {
899        let threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
900            .and_then(|v| v.parse().ok())
901            .unwrap_or(64);
902        for node in graph.nodes() {
903            if let Op::Input { .. } = &node.op
904                && node.shape.rank() >= 2
905            {
906                let d0 = node.shape.dim(0);
907                let d1 = node.shape.dim(1);
908                if d0.is_static() && d1.is_static() {
909                    let b = d0.unwrap_static();
910                    let s = d1.unwrap_static();
911                    if b * s <= threshold {
912                        return true;
913                    }
914                }
915            }
916        }
917        false
918    }
919}
920
921impl Pass for FuseAttentionBlock {
922    fn name(&self) -> &str {
923        "fuse_attention_block"
924    }
925
926    fn run(&self, graph: Graph) -> Graph {
927        // Attention block fusion is done at the thunk level (compile_thunks)
928        // instead of the graph level, to avoid complex Rewriter issues.
929        // This pass is a no-op; the thunk compiler handles it directly.
930        graph
931    }
932}
933
934// ── PLAN L2: MarkElementwiseRegions ─────────────────────────────────────
935//
// Walk the graph and collapse maximal chains of element-wise ops
// (Activation / Cast / Binary / Compare / Where) into a single
// `Op::ElementwiseRegion`. Conditions for inclusion in a chain:
//   - Op is one of the five kinds listed above. A `Cast` only qualifies
//     when its target dtype equals its operand's dtype; cross-dtype
//     casts stay as separate nodes.
//   - Every input's element count either equals the node's output
//     element count or evenly divides it (broadcast input, which the
//     region kernel reads modulo the input's own element count).
//   - Every intermediate (chain-internal) value has exactly one
//     consumer in the *whole* graph. Multi-consumer values must
//     materialize.
947// The chain start can read graph-level inputs / params / earlier-fused
948// nodes; the chain end is the last single-consumer or terminal node.
949// This is the simplest correct cut — N-ary chain fusion replaces the
950// pairwise `fuse_elementwise_chains` pattern in each backend with one
951// IR-level pass + a single backend kernel. See PLAN L2.
952//
953// Fusion boundaries: chains do not extend across inputs whose producer
954// satisfies [`rlx_ir::Op::is_fusion_boundary`] (BLAS, Gaussian splat, …).
955
956pub struct MarkElementwiseRegions;
957
958impl Pass for MarkElementwiseRegions {
959    fn name(&self) -> &str {
960        "mark_elementwise_regions"
961    }
962
963    fn run(&self, graph: Graph) -> Graph {
964        // Tally consumer counts for every node id.
965        let mut consumers: HashMap<NodeId, usize> = HashMap::new();
966        for node in graph.nodes() {
967            for &input in &node.inputs {
968                *consumers.entry(input).or_insert(0) += 1;
969            }
970        }
971        for &out in &graph.outputs {
972            *consumers.entry(out).or_insert(0) += 1;
973        }
974
975        // Predicate: does this op qualify for chain inclusion?
976        let chain_eligible = |op: &Op| -> bool {
977            matches!(
978                op,
979                Op::Activation(_) | Op::Cast { .. } | Op::Binary(_) | Op::Compare(_) | Op::Where
980            )
981        };
982
983        // Per-node refinement: a `Cast { to }` only qualifies when the
984        // destination dtype matches the operand's dtype. The chain
985        // kernel runs entirely in f32 register scratch and writes the
986        // tail back to the output node's arena slot — which is sized
987        // for the tail dtype. A cross-dtype Cast inside the chain would
988        // lose precision (no actual conversion happens in scratch) AND
989        // mis-size the final write (an F16 output slot is half the
990        // bytes of f32). Same-dtype Casts are trivially propagated.
991        let chain_step_safe = |graph: &Graph, node: &rlx_ir::Node| -> bool {
992            match &node.op {
993                Op::Cast { to } => {
994                    let in_dt = graph.shape(node.inputs[0]).dtype();
995                    *to == in_dt
996                }
997                _ => true,
998            }
999        };
1000
1001        // For each node, compute which "chain root" it belongs to.
1002        // A chain consists of a sequence of single-consumer chain-eligible
1003        // nodes leading to a chain "tail" (last node before a multi-consumer
1004        // or non-eligible boundary). We assign each node a `region_id`
1005        // (= the tail's NodeId) iff it's part of a region with ≥2 ops.
1006        // Walk in topological (forward) order; for each chain-eligible
1007        // node whose every input is either non-region OR a single-consumer
1008        // region member, extend its parent chain.
1009        let mut region_of: HashMap<NodeId, NodeId> = HashMap::new();
1010        let mut chain_step_idx: HashMap<NodeId, u32> = HashMap::new();
1011
1012        for node in graph.nodes() {
1013            if !chain_eligible(&node.op) {
1014                continue;
1015            }
1016            if !chain_step_safe(&graph, node) {
1017                continue;
1018            }
1019            // Each input must either match the output element count
1020            // exactly OR be a trailing-shape broadcast (its element
1021            // count divides the output's). The kernel reads
1022            // `arena[input_offs[i] + (gid % input_modulus[i])]` for
1023            // broadcast inputs; non-broadcast inputs leave the modulus
1024            // at 0 to skip the modulo.
1025            let out_shape = &node.shape;
1026            let out_elems = out_shape.num_elements();
1027            let shape_ok = node.inputs.iter().all(|id| {
1028                let in_elems = graph.shape(*id).num_elements();
1029                match (in_elems, out_elems) {
1030                    (Some(i), Some(o)) if i == o => true,
1031                    (Some(i), Some(o)) if i > 0 && o % i == 0 => true,
1032                    _ => false,
1033                }
1034            });
1035            if !shape_ok {
1036                continue;
1037            }
1038            // A chain extends an input's chain when the input is itself
1039            // chain-eligible AND has exactly one consumer (= this node).
1040            // If multiple inputs satisfy this, the chains must be the same
1041            // (= they share a chain root); pick that root.
1042            let mut parent_root: Option<NodeId> = None;
1043            let mut all_inputs_single_consumer = true;
1044            for &input in &node.inputs {
1045                // BLAS / splat render ops are explicit fusion boundaries.
1046                if graph.node(input).op.is_fusion_boundary() {
1047                    parent_root = None;
1048                    all_inputs_single_consumer = false;
1049                    break;
1050                }
1051                if let Some(&root) = region_of.get(&input) {
1052                    if consumers.get(&input).copied() != Some(1) {
1053                        all_inputs_single_consumer = false;
1054                        break;
1055                    }
1056                    match parent_root {
1057                        None => parent_root = Some(root),
1058                        Some(r) if r == root => {}
1059                        Some(_) => {
1060                            parent_root = None;
1061                            all_inputs_single_consumer = false;
1062                            break;
1063                        }
1064                    }
1065                }
1066            }
1067            if !all_inputs_single_consumer {
1068                // Start a fresh chain rooted at this node.
1069                region_of.insert(node.id, node.id);
1070                chain_step_idx.insert(node.id, 0);
1071                continue;
1072            }
1073            let root = parent_root.unwrap_or(node.id);
1074            // step idx = max(parents' idx in same chain) + 1
1075            let next_idx = node
1076                .inputs
1077                .iter()
1078                .filter_map(|id| {
1079                    if region_of.get(id) == Some(&root) {
1080                        chain_step_idx.get(id).copied()
1081                    } else {
1082                        None
1083                    }
1084                })
1085                .max()
1086                .map(|m| m + 1)
1087                .unwrap_or(0);
1088            let limits = crate::limits::active_fusion_limits();
1089            if next_idx >= limits.max_elementwise_steps {
1090                region_of.insert(node.id, node.id);
1091                chain_step_idx.insert(node.id, 0);
1092                continue;
1093            }
1094            region_of.insert(node.id, root);
1095            chain_step_idx.insert(node.id, next_idx);
1096        }
1097
1098        // Group nodes by region_id; only regions with ≥2 nodes are worth fusing.
1099        // The "region tail" (= last node) becomes the new ElementwiseRegion node.
1100        let mut by_region: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
1101        for node in graph.nodes() {
1102            if let Some(&root) = region_of.get(&node.id) {
1103                by_region.entry(root).or_default().push(node.id);
1104            }
1105        }
1106
1107        // Each region's "tail" is the node with the highest chain_step_idx.
1108        // For correctness, the tail must be the only node in the region with
1109        // a non-region or multi-consumer outflow — otherwise the region would
1110        // span past it. Skip regions where the tail isn't unique (= chain
1111        // forks internally).
1112        let mut tail_of_region: HashMap<NodeId, NodeId> = HashMap::new();
1113        for (root, members) in &by_region {
1114            if members.len() < 2 {
1115                continue;
1116            }
1117            let max_idx = members.iter().map(|id| chain_step_idx[id]).max().unwrap();
1118            let tails: Vec<_> = members
1119                .iter()
1120                .filter(|id| chain_step_idx[id] == max_idx)
1121                .collect();
1122            if tails.len() != 1 {
1123                continue;
1124            }
1125            tail_of_region.insert(*root, *tails[0]);
1126        }
1127
1128        // Drop "regions" that aren't worth fusing (size < 2 or non-unique tail).
1129        let by_region: HashMap<NodeId, Vec<NodeId>> = by_region
1130            .into_iter()
1131            .filter(|(root, _)| tail_of_region.contains_key(root))
1132            .collect();
1133
1134        if by_region.is_empty() {
1135            return graph;
1136        }
1137
1138        // Rewrite the graph: copy non-region nodes verbatim; for each region,
1139        // emit a single ElementwiseRegion at the tail's position (in topo order)
1140        // and replace each region member's NodeId in the id map with that.
1141        let mut rw = Rewriter::new(&graph.name);
1142        // Track region nodes already emitted (we emit at tail's topo position).
1143        let mut emitted_region: HashMap<NodeId, NodeId> = HashMap::new();
1144
1145        for node in graph.nodes() {
1146            if let Some(&root) = region_of.get(&node.id)
1147                && let Some(&tail) = tail_of_region.get(&root)
1148            {
1149                if emitted_region.contains_key(&root) {
1150                    // Member but tail already emitted (or not tail). Map to
1151                    // either the new region node (if tail) or to a sentinel
1152                    // we never look up directly. Internal members are not
1153                    // referenced after fusion (single-consumer guarantee),
1154                    // so we map them to the region node id for safety.
1155                    let region_new = emitted_region[&root];
1156                    rw.replace(node.id, region_new);
1157                    continue;
1158                }
1159                if node.id == tail {
1160                    // Sort region members in topological (= chain step) order.
1161                    let members = &by_region[&root];
1162                    let mut ordered: Vec<NodeId> = members.clone();
1163                    ordered.sort_by_key(|id| chain_step_idx[id]);
1164
1165                    // Collect external inputs (chain inputs that aren't members).
1166                    // SSA: each chain step refers to either an external input
1167                    // or a previous step. Build the chain.
1168                    let mut external_inputs: Vec<NodeId> = Vec::new();
1169                    let mut input_idx_of: HashMap<NodeId, u32> = HashMap::new();
1170                    let mut step_idx_of: HashMap<NodeId, u32> = HashMap::new();
1171                    for (i, member_id) in ordered.iter().enumerate() {
1172                        step_idx_of.insert(*member_id, i as u32);
1173                        let n = graph.node(*member_id);
1174                        for &inp in &n.inputs {
1175                            if !step_idx_of.contains_key(&inp) && !input_idx_of.contains_key(&inp) {
1176                                let idx = external_inputs.len() as u32;
1177                                input_idx_of.insert(inp, idx);
1178                                external_inputs.push(inp);
1179                            }
1180                        }
1181                    }
1182
1183                    let limits = crate::limits::active_fusion_limits();
1184                    if external_inputs.len() as u32 > limits.max_elementwise_inputs
1185                        || ordered.len() as u32 > limits.max_elementwise_steps
1186                    {
1187                        for &mid in &ordered {
1188                            rw.copy_node(graph.node(mid));
1189                        }
1190                        continue;
1191                    }
1192
1193                    let resolve = |id: NodeId| -> ChainOperand {
1194                        if let Some(&i) = input_idx_of.get(&id) {
1195                            ChainOperand::Input(i)
1196                        } else {
1197                            ChainOperand::Step(step_idx_of[&id])
1198                        }
1199                    };
1200                    let mut chain: Vec<ChainStep> = Vec::with_capacity(ordered.len());
1201                    for member_id in &ordered {
1202                        let n = graph.node(*member_id);
1203                        let step = match &n.op {
1204                            Op::Activation(a) => ChainStep::Activation(*a, resolve(n.inputs[0])),
1205                            Op::Cast { to } => ChainStep::Cast(*to, resolve(n.inputs[0])),
1206                            Op::Binary(op) => {
1207                                ChainStep::Binary(*op, resolve(n.inputs[0]), resolve(n.inputs[1]))
1208                            }
1209                            Op::Compare(op) => {
1210                                ChainStep::Compare(*op, resolve(n.inputs[0]), resolve(n.inputs[1]))
1211                            }
1212                            Op::Where => ChainStep::Where(
1213                                resolve(n.inputs[0]),
1214                                resolve(n.inputs[1]),
1215                                resolve(n.inputs[2]),
1216                            ),
1217                            _ => unreachable!("non-chain-eligible op in region"),
1218                        };
1219                        chain.push(step);
1220                    }
1221
1222                    // PLAN L2 quality: per-input broadcast metadata.
1223                    // `scalar_input_mask` is the fast-path bitfield
1224                    // (bit `i` set ⇒ input `i` is a single-element
1225                    // scalar). `input_modulus[i]` is the per-input
1226                    // element count: 0 means "no broadcast" (kernel
1227                    // reads gid directly), >0 means tile by modulo.
1228                    // Encoder enforces `out_elems % in_elems == 0`
1229                    // upstream so the modulo divides cleanly.
1230                    let mut scalar_input_mask: u32 = 0;
1231                    let mut input_modulus = [0u32; 16];
1232                    let region_shape_elems = graph.node(tail).shape.num_elements();
1233                    for (i, &ext) in external_inputs.iter().enumerate() {
1234                        if i >= 16 {
1235                            break;
1236                        }
1237                        let in_elems = graph.shape(ext).num_elements();
1238                        match (in_elems, region_shape_elems) {
1239                            (Some(1), Some(o)) if o != 1 => {
1240                                scalar_input_mask |= 1u32 << i;
1241                                input_modulus[i] = 1;
1242                            }
1243                            (Some(i_n), Some(o)) if i_n != o && i_n > 0 => {
1244                                input_modulus[i] = i_n as u32;
1245                            }
1246                            _ => { /* no broadcast: leave modulus 0 */ }
1247                        }
1248                    }
1249                    let region_new = rw.add_fused(
1250                        Op::ElementwiseRegion {
1251                            chain,
1252                            num_inputs: external_inputs.len() as u32,
1253                            scalar_input_mask,
1254                            input_modulus,
1255                        },
1256                        &external_inputs,
1257                        graph.node(tail).shape.clone(),
1258                    );
1259                    emitted_region.insert(root, region_new);
1260                    rw.replace(node.id, region_new);
1261                    continue;
1262                } else {
1263                    // Region member but not tail; skip (will be replaced
1264                    // when the tail is processed).
1265                    rw.replace(node.id, NodeId(u32::MAX)); // sentinel
1266                    continue;
1267                }
1268            }
1269            rw.copy_node(node);
1270        }
1271
1272        // Final cleanup pass: any sentinel id_map entries get rewired to
1273        // their region's emitted node now that emission is done.
1274        // (Actually the order above means tails are processed in topo
1275        // order and members appear before tails in topo order, so by the
1276        // time a member's consumer is rewritten its id_map points to the
1277        // sentinel. Fix-up: walk again, rewrite sentinels.)
1278        // Simpler approach: process region members in second pass.
1279        // The current order processes tail last per region, so non-tail
1280        // members get sentinels. Their consumers are either other region
1281        // members (which we don't directly use the input from) or the
1282        // tail itself. Since the tail builds its own chain via members
1283        // directly from the original graph, the rewriter's id_map for
1284        // non-tail members is only consulted for the tail's input list —
1285        // which we resolve via `external_inputs` (already correctly
1286        // mapped via add_fused → map_inputs). So sentinels are safe.
1287
1288        rw.finish(&graph.outputs)
1289    }
1290}
1291
1292// ── PLAN L2 fallback: UnfuseElementwiseRegions ───────────────────────
1293//
// Decompose `Op::ElementwiseRegion` back into its constituent atomic
// ops (Activation / Cast / Binary / Compare / Where). The output of the
1296// region is replaced with the result of the chain's last step;
1297// internal step results become individual nodes wired into the rest
1298// of the graph. Used by backends that don't have a native region
1299// kernel — they get the *correctness* of L2's IR-level fusion (no op
1300// missing) without needing to implement region codegen. Run BEFORE
1301// the backend's own lowering. No-op when the graph contains no
1302// ElementwiseRegion nodes.
1303
1304pub struct UnfuseElementwiseRegions;
1305
1306impl Pass for UnfuseElementwiseRegions {
1307    fn name(&self) -> &str {
1308        "unfuse_elementwise_regions"
1309    }
1310
1311    fn run(&self, graph: Graph) -> Graph {
1312        let any_region = graph
1313            .nodes()
1314            .iter()
1315            .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }));
1316        if !any_region {
1317            return graph;
1318        }
1319
1320        let mut rw = Rewriter::new(&graph.name);
1321        for node in graph.nodes() {
1322            if let Op::ElementwiseRegion {
1323                chain,
1324                num_inputs: _,
1325                scalar_input_mask: _,
1326                input_modulus: _,
1327            } = &node.op
1328            {
1329                // Region inputs (in the new graph) — the rewriter has
1330                // already mapped each old input id.
1331                let region_inputs: Vec<NodeId> = node.inputs.iter().map(|id| rw.map(*id)).collect();
1332                let mut step_ids: Vec<NodeId> = Vec::with_capacity(chain.len());
1333                let region_shape = node.shape.clone();
1334                let region_dims: Vec<_> = region_shape.dims().to_vec();
1335                // Per-step result dtype, indexed by step position.
1336                // The chain may pass through Cast steps that change the
1337                // dtype mid-chain; using `region_shape.dtype()` blindly
1338                // would mis-tag intermediate Activation/Binary/Where
1339                // shapes. Track the dtype propagated through each step.
1340                let mut step_dtypes: Vec<rlx_ir::DType> = Vec::with_capacity(chain.len());
1341                let region_dtype = region_shape.dtype();
1342                let dtype_of = |op: &ChainOperand,
1343                                ins: &[NodeId],
1344                                step_dt: &[rlx_ir::DType],
1345                                rw: &Rewriter|
1346                 -> rlx_ir::DType {
1347                    match *op {
1348                        ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.dtype(),
1349                        ChainOperand::Step(i) => step_dt[i as usize],
1350                    }
1351                };
1352                // Shape of an operand in the rewritten graph. Critical
1353                // for broadcast inputs: a region whose final shape is
1354                // `[8, 1]` can still have a scalar operand at some
1355                // step; tagging that step with region_dims would lie
1356                // about its element count and trip the binary/activation
1357                // kernels (which size their reads/writes off the IR
1358                // shape, not the broadcast-aware semantics the L2
1359                // region kernel would have used). Use the actual node
1360                // shape so the unfused pipeline matches what each op
1361                // semantically produces.
1362                let shape_of = |op: &ChainOperand,
1363                                ins: &[NodeId],
1364                                step_ids: &[NodeId],
1365                                rw: &Rewriter|
1366                 -> Shape {
1367                    match *op {
1368                        ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.clone(),
1369                        ChainOperand::Step(i) => {
1370                            rw.new_graph.node(step_ids[i as usize]).shape.clone()
1371                        }
1372                    }
1373                };
1374                for step in chain {
1375                    let resolve = |op: &ChainOperand| -> NodeId {
1376                        match *op {
1377                            ChainOperand::Input(i) => region_inputs[i as usize],
1378                            ChainOperand::Step(i) => step_ids[i as usize],
1379                        }
1380                    };
1381                    let (new_id, dt) = match step {
1382                        ChainStep::Activation(a, src) => {
1383                            let s = resolve(src);
1384                            let dt = dtype_of(src, &region_inputs, &step_dtypes, &rw);
1385                            // Activation is element-wise: output shape
1386                            // == input shape (preserve broadcast-source
1387                            // shapes; do NOT promote to region_dims).
1388                            let src_shape = shape_of(src, &region_inputs, &step_ids, &rw);
1389                            let dims: Vec<_> = src_shape.dims().to_vec();
1390                            let shape = Shape::from_dims(&dims, dt);
1391                            (
1392                                rw.new_graph.add_node(Op::Activation(*a), vec![s], shape),
1393                                dt,
1394                            )
1395                        }
1396                        ChainStep::Cast(to, src) => {
1397                            let s = resolve(src);
1398                            let src_shape = shape_of(src, &region_inputs, &step_ids, &rw);
1399                            let dims: Vec<_> = src_shape.dims().to_vec();
1400                            let shape = Shape::from_dims(&dims, *to);
1401                            (
1402                                rw.new_graph.add_node(Op::Cast { to: *to }, vec![s], shape),
1403                                *to,
1404                            )
1405                        }
1406                        ChainStep::Binary(op, lhs, rhs) => {
1407                            let l = resolve(lhs);
1408                            let r = resolve(rhs);
1409                            let dt = dtype_of(lhs, &region_inputs, &step_dtypes, &rw);
1410                            // Binary: NumPy-style broadcast of operands.
1411                            let lhs_shape = shape_of(lhs, &region_inputs, &step_ids, &rw);
1412                            let rhs_shape = shape_of(rhs, &region_inputs, &step_ids, &rw);
1413                            let bcast = rlx_ir::shape::broadcast(&lhs_shape, &rhs_shape)
1414                                .unwrap_or_else(|e| {
1415                                    panic!(
1416                                        "unfuse_elementwise_regions: cannot broadcast \
1417                                         {lhs_shape:?} ⊗ {rhs_shape:?} for Binary({op:?}): {e}"
1418                                    )
1419                                });
1420                            let dims: Vec<_> = bcast.dims().to_vec();
1421                            let shape = Shape::from_dims(&dims, dt);
1422                            (
1423                                rw.new_graph.add_node(Op::Binary(*op), vec![l, r], shape),
1424                                dt,
1425                            )
1426                        }
1427                        ChainStep::Compare(op, lhs, rhs) => {
1428                            let l = resolve(lhs);
1429                            let r = resolve(rhs);
1430                            let lhs_shape = shape_of(lhs, &region_inputs, &step_ids, &rw);
1431                            let rhs_shape = shape_of(rhs, &region_inputs, &step_ids, &rw);
1432                            let bcast = rlx_ir::shape::broadcast(&lhs_shape, &rhs_shape)
1433                                .unwrap_or_else(|e| {
1434                                    panic!(
1435                                        "unfuse_elementwise_regions: cannot broadcast \
1436                                         {lhs_shape:?} ⊗ {rhs_shape:?} for Compare({op:?}): {e}"
1437                                    )
1438                                });
1439                            let dims: Vec<_> = bcast.dims().to_vec();
1440                            let shape = Shape::from_dims(&dims, rlx_ir::DType::Bool);
1441                            (
1442                                rw.new_graph.add_node(Op::Compare(*op), vec![l, r], shape),
1443                                rlx_ir::DType::Bool,
1444                            )
1445                        }
1446                        ChainStep::Where(c, x, y) => {
1447                            let cn = resolve(c);
1448                            let xn = resolve(x);
1449                            let yn = resolve(y);
1450                            let dt = dtype_of(x, &region_inputs, &step_dtypes, &rw);
1451                            // Where: broadcast across (cond, then, else).
1452                            let c_shape = shape_of(c, &region_inputs, &step_ids, &rw);
1453                            let x_shape = shape_of(x, &region_inputs, &step_ids, &rw);
1454                            let y_shape = shape_of(y, &region_inputs, &step_ids, &rw);
1455                            let bcast_xy = rlx_ir::shape::broadcast(&x_shape, &y_shape)
1456                                .unwrap_or_else(|e| {
1457                                    panic!(
1458                                        "unfuse_elementwise_regions: cannot broadcast \
1459                                         then/else {x_shape:?} ⊗ {y_shape:?} for Where: {e}"
1460                                    )
1461                                });
1462                            let bcast = rlx_ir::shape::broadcast(&c_shape, &bcast_xy)
1463                                .unwrap_or_else(|e| {
1464                                    panic!(
1465                                        "unfuse_elementwise_regions: cannot broadcast cond \
1466                                         {c_shape:?} ⊗ {bcast_xy:?} for Where: {e}"
1467                                    )
1468                                });
1469                            let dims: Vec<_> = bcast.dims().to_vec();
1470                            let shape = Shape::from_dims(&dims, dt);
1471                            (
1472                                rw.new_graph.add_node(Op::Where, vec![cn, xn, yn], shape),
1473                                dt,
1474                            )
1475                        }
1476                    };
1477                    step_ids.push(new_id);
1478                    step_dtypes.push(dt);
1479                }
1480                let _ = region_dtype;
1481                let _ = region_dims;
1482                // The region's "output" (= last step) replaces the original
1483                // ElementwiseRegion node id.
1484                let last = *step_ids.last().expect("chain non-empty per pass invariant");
1485                rw.replace(node.id, last);
1486                continue;
1487            }
1488            rw.copy_node(node);
1489        }
1490        rw.finish(&graph.outputs)
1491    }
1492}
1493
1494/// Unfuse only `ElementwiseRegion` nodes that exceed [`crate::limits::FusionLimits`].
1495///
1496/// Run after [`MarkElementwiseRegions`] when marking may still produce
1497/// oversized chains (e.g. limits tightened per backend).
1498pub fn clip_elementwise_regions(graph: Graph, limits: crate::limits::FusionLimits) -> Graph {
1499    let oversize = |n: &rlx_ir::Node| -> bool {
1500        matches!(
1501            &n.op,
1502            Op::ElementwiseRegion {
1503                chain,
1504                num_inputs,
1505                ..
1506            } if *num_inputs > limits.max_elementwise_inputs
1507                || chain.len() as u32 > limits.max_elementwise_steps
1508        )
1509    };
1510    if !graph.nodes().iter().any(oversize) {
1511        return graph;
1512    }
1513
1514    let mut rw = Rewriter::new(&graph.name);
1515    for node in graph.nodes() {
1516        if !oversize(node) {
1517            rw.copy_node(node);
1518            continue;
1519        }
1520
1521        let Op::ElementwiseRegion {
1522            chain,
1523            num_inputs: _,
1524            scalar_input_mask: _,
1525            input_modulus: _,
1526        } = &node.op
1527        else {
1528            unreachable!();
1529        };
1530
1531        let region_inputs: Vec<NodeId> = node.inputs.iter().map(|id| rw.map(*id)).collect();
1532        let mut step_ids: Vec<NodeId> = Vec::with_capacity(chain.len());
1533        let region_shape = node.shape.clone();
1534        let region_dims: Vec<_> = region_shape.dims().to_vec();
1535        let mut step_dtypes: Vec<rlx_ir::DType> = Vec::with_capacity(chain.len());
1536        let region_dtype = region_shape.dtype();
1537        let dtype_of = |op: &ChainOperand,
1538                        ins: &[NodeId],
1539                        step_dt: &[rlx_ir::DType],
1540                        rw: &Rewriter|
1541         -> rlx_ir::DType {
1542            match *op {
1543                ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.dtype(),
1544                ChainOperand::Step(i) => step_dt[i as usize],
1545            }
1546        };
1547        let shape_of =
1548            |op: &ChainOperand, ins: &[NodeId], step_ids: &[NodeId], rw: &Rewriter| -> Shape {
1549                match *op {
1550                    ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.clone(),
1551                    ChainOperand::Step(i) => rw.new_graph.node(step_ids[i as usize]).shape.clone(),
1552                }
1553            };
1554        for step in chain {
1555            let resolve = |op: &ChainOperand| -> NodeId {
1556                match *op {
1557                    ChainOperand::Input(i) => region_inputs[i as usize],
1558                    ChainOperand::Step(i) => step_ids[i as usize],
1559                }
1560            };
1561            let (new_id, dt) = match step {
1562                ChainStep::Activation(a, src) => {
1563                    let s = resolve(src);
1564                    let dt = dtype_of(src, &region_inputs, &step_dtypes, &rw);
1565                    let src_shape = shape_of(src, &region_inputs, &step_ids, &rw);
1566                    let dims: Vec<_> = src_shape.dims().to_vec();
1567                    let shape = Shape::from_dims(&dims, dt);
1568                    (
1569                        rw.new_graph.add_node(Op::Activation(*a), vec![s], shape),
1570                        dt,
1571                    )
1572                }
1573                ChainStep::Cast(to, src) => {
1574                    let s = resolve(src);
1575                    let src_shape = shape_of(src, &region_inputs, &step_ids, &rw);
1576                    let dims: Vec<_> = src_shape.dims().to_vec();
1577                    let shape = Shape::from_dims(&dims, *to);
1578                    (
1579                        rw.new_graph.add_node(Op::Cast { to: *to }, vec![s], shape),
1580                        *to,
1581                    )
1582                }
1583                ChainStep::Binary(op, lhs, rhs) => {
1584                    let l = resolve(lhs);
1585                    let r = resolve(rhs);
1586                    let dt = dtype_of(lhs, &region_inputs, &step_dtypes, &rw);
1587                    let l_shape = shape_of(lhs, &region_inputs, &step_ids, &rw);
1588                    let r_shape = shape_of(rhs, &region_inputs, &step_ids, &rw);
1589                    let bcast = l_shape
1590                        .broadcast_with(&r_shape)
1591                        .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
1592                    let dims: Vec<_> = bcast.dims().to_vec();
1593                    let shape = Shape::from_dims(&dims, dt);
1594                    (
1595                        rw.new_graph.add_node(Op::Binary(*op), vec![l, r], shape),
1596                        dt,
1597                    )
1598                }
1599                ChainStep::Compare(op, lhs, rhs) => {
1600                    let l = resolve(lhs);
1601                    let r = resolve(rhs);
1602                    let l_shape = shape_of(lhs, &region_inputs, &step_ids, &rw);
1603                    let r_shape = shape_of(rhs, &region_inputs, &step_ids, &rw);
1604                    let bcast = l_shape
1605                        .broadcast_with(&r_shape)
1606                        .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
1607                    let dims: Vec<_> = bcast.dims().to_vec();
1608                    let shape = Shape::from_dims(&dims, rlx_ir::DType::U8);
1609                    (
1610                        rw.new_graph.add_node(Op::Compare(*op), vec![l, r], shape),
1611                        rlx_ir::DType::U8,
1612                    )
1613                }
1614                ChainStep::Where(cond, x, y) => {
1615                    let cn = resolve(cond);
1616                    let xn = resolve(x);
1617                    let yn = resolve(y);
1618                    let dt = dtype_of(x, &region_inputs, &step_dtypes, &rw);
1619                    let x_shape = shape_of(x, &region_inputs, &step_ids, &rw);
1620                    let y_shape = shape_of(y, &region_inputs, &step_ids, &rw);
1621                    let c_shape = shape_of(cond, &region_inputs, &step_ids, &rw);
1622                    let bcast_xy = x_shape
1623                        .broadcast_with(&y_shape)
1624                        .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
1625                    let bcast = c_shape.broadcast_with(&bcast_xy).unwrap_or_else(|e| {
1626                        panic!("clip_elementwise_regions: cannot broadcast cond {c_shape:?} ⊗ {bcast_xy:?} for Where: {e}")
1627                    });
1628                    let dims: Vec<_> = bcast.dims().to_vec();
1629                    let shape = Shape::from_dims(&dims, dt);
1630                    (
1631                        rw.new_graph.add_node(Op::Where, vec![cn, xn, yn], shape),
1632                        dt,
1633                    )
1634                }
1635            };
1636            step_ids.push(new_id);
1637            step_dtypes.push(dt);
1638        }
1639        let _ = (region_dtype, region_dims);
1640        let last = *step_ids
1641            .last()
1642            .expect("oversize region has non-empty chain");
1643        rw.replace(node.id, last);
1644    }
1645    rw.finish(&graph.outputs)
1646}
1647
1648#[cfg(test)]
1649mod tests {
1650    use super::*;
1651    use crate::limits::FusionLimits;
1652    use crate::pass::run_passes;
1653
1654    fn f32_shape(dims: &[usize]) -> Shape {
1655        Shape::new(dims, DType::F32)
1656    }
1657
1658    #[test]
1659    fn fuse_matmul_bias_gelu() {
1660        let mut g = Graph::new("test");
1661        let x = g.input("x", f32_shape(&[4, 15, 384]));
1662        let w = g.param("w", f32_shape(&[384, 1536]));
1663        let b = g.param("b", f32_shape(&[1536]));
1664        let mm = g.matmul(x, w, f32_shape(&[4, 15, 1536]));
1665        let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[4, 15, 1536]));
1666        let out = g.activation(Activation::Gelu, add, f32_shape(&[4, 15, 1536]));
1667        g.set_outputs(vec![out]);
1668
1669        assert_eq!(g.len(), 6); // input, w, b, mm, add, gelu
1670
1671        let fused = FuseMatMulBiasAct.run(g);
1672        println!("{fused}");
1673
1674        // Should be: input, w, b, fused_mm_bias_gelu
1675        assert_eq!(fused.len(), 4);
1676        let out_node = fused.node(fused.outputs[0]);
1677        assert!(matches!(
1678            out_node.op,
1679            Op::FusedMatMulBiasAct {
1680                activation: Some(Activation::Gelu)
1681            }
1682        ));
1683    }
1684
1685    #[test]
1686    fn fuse_matmul_bias_no_act() {
1687        let mut g = Graph::new("test");
1688        let x = g.input("x", f32_shape(&[4, 15, 384]));
1689        let w = g.param("w", f32_shape(&[384, 384]));
1690        let b = g.param("b", f32_shape(&[384]));
1691        let mm = g.matmul(x, w, f32_shape(&[4, 15, 384]));
1692        let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[4, 15, 384]));
1693        g.set_outputs(vec![add]);
1694
1695        let fused = FuseMatMulBiasAct.run(g);
1696        assert_eq!(fused.len(), 4);
1697        let out_node = fused.node(fused.outputs[0]);
1698        assert!(matches!(
1699            out_node.op,
1700            Op::FusedMatMulBiasAct { activation: None }
1701        ));
1702    }
1703
1704    #[test]
1705    fn fuse_matmul_bias_skips_unsupported_activation_epilogue() {
1706        let mut g = Graph::new("test");
1707        let x = g.input("x", f32_shape(&[8, 1024]));
1708        let w = g.param("w", f32_shape(&[1024, 16]));
1709        let b = g.param("b", f32_shape(&[16]));
1710        let mm = g.matmul(x, w, f32_shape(&[8, 16]));
1711        let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[8, 16]));
1712        let exp = g.activation(Activation::Exp, add, f32_shape(&[8, 16]));
1713        g.set_outputs(vec![exp]);
1714
1715        let fused = FuseMatMulBiasAct.run(g);
1716        // mm + bias fuse; Exp stays separate (qwen35 softplus pattern).
1717        assert_eq!(fused.len(), 5);
1718        let out_node = fused.node(fused.outputs[0]);
1719        assert!(matches!(out_node.op, Op::Activation(Activation::Exp)));
1720        let add_node = fused.node(out_node.inputs[0]);
1721        assert!(matches!(
1722            add_node.op,
1723            Op::FusedMatMulBiasAct { activation: None }
1724        ));
1725    }
1726
1727    #[test]
1728    fn fuse_matmul_bias_act_with_late_bias_param() {
1729        use rlx_ir::infer::GraphExt;
1730
1731        let mut g = Graph::new("late_bias");
1732        let x = g.input("x", f32_shape(&[8, 16]));
1733        let w = g.param("w", f32_shape(&[16, 32]));
1734        let out = {
1735            let mm = g.mm(x, w);
1736            let b = g.param("b", f32_shape(&[32]));
1737            let biased = g.add(mm, b);
1738            g.gelu(biased)
1739        };
1740        g.set_outputs(vec![out]);
1741
1742        let fused = FuseMatMulBiasAct.run(g);
1743        assert!(
1744            fused
1745                .nodes()
1746                .iter()
1747                .any(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. })),
1748            "bias param declared after matmul must still fuse:\n{fused}"
1749        );
1750    }
1751
1752    #[test]
1753    fn swiglu_ffn_builder_fuses_end_to_end() {
1754        let mut g = Graph::new("swiglu_block");
1755        let x = g.input("x", f32_shape(&[4, 768]));
1756        let up_w = g.param("up", f32_shape(&[768, 2048]));
1757        let gate_w = g.param("gate", f32_shape(&[768, 2048]));
1758        let down_w = g.param("down", f32_shape(&[2048, 768]));
1759        let out = g.swiglu_ffn(x, up_w, gate_w, down_w);
1760        g.set_outputs(vec![out]);
1761
1762        let g = FuseSharedInputMatMul.run(g);
1763        let g = FuseSwiGLU.run(g);
1764        assert!(
1765            g.nodes()
1766                .iter()
1767                .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
1768            "swiglu_ffn builder should match FuseSwiGLU:\n{g}"
1769        );
1770    }
1771
1772    #[test]
1773    fn fuse_swiglu_dual_matmul_gate_first() {
1774        use rlx_ir::infer::GraphExt;
1775
1776        let mut g = Graph::new("qwen3_ffn");
1777        let x = g.input("x", f32_shape(&[4, 768]));
1778        let gate_w = g.param("gate", f32_shape(&[768, 2048]));
1779        let up_w = g.param("up", f32_shape(&[768, 2048]));
1780        let gate = g.mm(x, gate_w);
1781        let up = g.mm(x, up_w);
1782        let gate_act = g.silu(gate);
1783        let out = g.mul(gate_act, up);
1784        g.set_outputs(vec![out]);
1785
1786        let fused = FuseSwiGLUDualMatmul.run(g);
1787        assert!(
1788            fused
1789                .nodes()
1790                .iter()
1791                .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
1792            "gate-first dual matmul should fuse:\n{fused}"
1793        );
1794        assert!(
1795            fused.len() <= 6,
1796            "dual fusion should collapse to x + weights + concat + mm + fused_swiglu, got {} nodes",
1797            fused.len()
1798        );
1799    }
1800
1801    #[test]
1802    fn fuse_shared_input_matmul_three_way_qkv() {
1803        let mut g = Graph::new("qkv");
1804        let x = g.input("x", f32_shape(&[8, 512]));
1805        let wq = g.param("wq", f32_shape(&[512, 128]));
1806        let wk = g.param("wk", f32_shape(&[512, 128]));
1807        let wv = g.param("wv", f32_shape(&[512, 128]));
1808        let q = g.matmul(x, wq, f32_shape(&[8, 128]));
1809        let k = g.matmul(x, wk, f32_shape(&[8, 128]));
1810        let v = g.matmul(x, wv, f32_shape(&[8, 128]));
1811        g.set_outputs(vec![q, k, v]);
1812
1813        let fused = FuseSharedInputMatMul.run(g);
1814        assert_eq!(
1815            fused.len(),
1816            9,
1817            "x + 3 weights + concat + mm + 3 narrows = 9"
1818        );
1819        for &out in &fused.outputs {
1820            assert!(matches!(fused.node(out).op, Op::Narrow { .. }));
1821        }
1822    }
1823
1824    #[test]
1825    fn fuse_residual_layer_norm() {
1826        let mut g = Graph::new("test");
1827        let x = g.input("x", f32_shape(&[4, 15, 384]));
1828        let residual = g.input("residual", f32_shape(&[4, 15, 384]));
1829        let gamma = g.param("gamma", f32_shape(&[384]));
1830        let beta = g.param("beta", f32_shape(&[384]));
1831        let add = g.binary(BinaryOp::Add, x, residual, f32_shape(&[4, 15, 384]));
1832        let ln = g.layer_norm(add, gamma, beta, -1, 1e-12, f32_shape(&[4, 15, 384]));
1833        g.set_outputs(vec![ln]);
1834
1835        assert_eq!(g.len(), 6); // x, residual, gamma, beta, add, ln
1836
1837        let fused = FuseResidualLN.run(g);
1838        println!("{fused}");
1839
1840        // Should be: x, residual, gamma, beta, fused_residual_ln
1841        assert_eq!(fused.len(), 5);
1842        let out_node = fused.node(fused.outputs[0]);
1843        assert!(matches!(
1844            out_node.op,
1845            Op::FusedResidualLN {
1846                has_bias: false,
1847                ..
1848            }
1849        ));
1850    }
1851
1852    #[test]
1853    fn fuse_residual_rms_norm() {
1854        let mut g = Graph::new("test");
1855        let x = g.input("x", f32_shape(&[4, 15, 384]));
1856        let residual = g.input("residual", f32_shape(&[4, 15, 384]));
1857        let gamma = g.param("gamma", f32_shape(&[384]));
1858        let beta = g.param("beta", f32_shape(&[384]));
1859        let add = g.binary(BinaryOp::Add, x, residual, f32_shape(&[4, 15, 384]));
1860        let rn = g.add_node(
1861            Op::RmsNorm {
1862                axis: -1,
1863                eps: 1e-6,
1864            },
1865            vec![add, gamma, beta],
1866            f32_shape(&[4, 15, 384]),
1867        );
1868        g.set_outputs(vec![rn]);
1869
1870        assert_eq!(g.len(), 6);
1871
1872        let fused = FuseResidualRmsNorm.run(g);
1873        assert_eq!(fused.len(), 5);
1874        let out_node = fused.node(fused.outputs[0]);
1875        assert!(matches!(
1876            out_node.op,
1877            Op::FusedResidualRmsNorm {
1878                has_bias: false,
1879                ..
1880            }
1881        ));
1882    }
1883
1884    #[test]
1885    fn fuse_rms_norm_reshape() {
1886        let mut g = Graph::new("test");
1887        let x = g.input("x", f32_shape(&[1, 8, 512]));
1888        let gamma = g.param("gamma", f32_shape(&[512]));
1889        let beta = g.param("beta", f32_shape(&[512]));
1890        let rn = g.add_node(
1891            Op::RmsNorm {
1892                axis: -1,
1893                eps: 1e-6,
1894            },
1895            vec![x, gamma, beta],
1896            f32_shape(&[1, 8, 512]),
1897        );
1898        let flat = g.add_node(
1899            Op::Reshape {
1900                new_shape: vec![8, 512],
1901            },
1902            vec![rn],
1903            f32_shape(&[8, 512]),
1904        );
1905        let w = g.param("w", f32_shape(&[512, 128]));
1906        let mm = g.matmul(flat, w, f32_shape(&[8, 128]));
1907        g.set_outputs(vec![mm]);
1908
1909        let fused = FuseRmsNormReshape.run(g);
1910        // x, gamma, beta, rms_norm(2d), w, matmul — no separate reshape
1911        assert_eq!(fused.len(), 6);
1912        let rn_node = fused.node(fused.node(fused.outputs[0]).inputs[0]);
1913        assert!(matches!(rn_node.op, Op::RmsNorm { .. }));
1914        assert_eq!(rn_node.shape.dim(0).unwrap_static(), 8);
1915        assert_eq!(rn_node.shape.dim(1).unwrap_static(), 512);
1916    }
1917
1918    #[test]
1919    fn fuse_shared_input_matmul() {
1920        let mut g = Graph::new("swiglu");
1921        let x = g.input("x", f32_shape(&[60, 768]));
1922        let w1 = g.param("fc11", f32_shape(&[768, 2048]));
1923        let w2 = g.param("fc12", f32_shape(&[768, 2048]));
1924        let mm1 = g.matmul(x, w1, f32_shape(&[60, 2048]));
1925        let mm2 = g.matmul(x, w2, f32_shape(&[60, 2048]));
1926        g.set_outputs(vec![mm1, mm2]);
1927
1928        assert_eq!(g.len(), 5); // x, w1, w2, mm1, mm2
1929
1930        let fused = FuseSharedInputMatMul.run(g);
1931        println!("{fused}");
1932
1933        // Should have: x, w1, w2, concat(w1,w2), combined_mm, narrow1, narrow2
1934        assert!(fused.len() <= 7);
1935        // Both outputs should be Narrow ops
1936        for &out in &fused.outputs {
1937            assert!(matches!(fused.node(out).op, Op::Narrow { .. }));
1938        }
1939    }
1940
1941    /// Regression: `FuseSharedInputMatMul` used to panic when `w2` is
1942    /// declared after `mm1`. `ensure_mapped` now copies late operands.
1943    #[test]
1944    fn fuse_shared_input_matmul_with_late_w2_param() {
1945        let mut g = Graph::new("late_w2");
1946        let x = g.input("x", f32_shape(&[8, 16]));
1947        let w1 = g.param("w1", f32_shape(&[16, 8]));
1948        let mm1 = g.matmul(x, w1, f32_shape(&[8, 8]));
1949        let w2 = g.param("w2", f32_shape(&[16, 8]));
1950        let mm2 = g.matmul(x, w2, f32_shape(&[8, 8]));
1951        g.set_outputs(vec![mm1, mm2]);
1952
1953        let fused = FuseSharedInputMatMul.run(g);
1954        for &out in &fused.outputs {
1955            assert!(
1956                matches!(fused.node(out).op, Op::Narrow { .. }),
1957                "late w2 should still fuse via ensure_mapped, got {:?}",
1958                fused.node(out).op
1959            );
1960        }
1961    }
1962
1963    /// Regression: qwen35moe FFN declares router / shared-expert matmuls on the
1964    /// same flattened hidden state with weights scattered through the block.
1965    #[test]
1966    fn fuse_shared_input_matmul_moe_ffn_pattern() {
1967        let mut g = Graph::new("moe_ffn");
1968        let rows = 4usize;
1969        let n_embd = 16usize;
1970        let n_expert = 4usize;
1971        let n_ff = 16usize;
1972
1973        let h_in = g.input("h", f32_shape(&[1, rows, n_embd]));
1974        let h_2d = g.reshape_(h_in, vec![rows as i64, n_embd as i64]);
1975
1976        let router_w = g.param("router_w", f32_shape(&[n_embd, n_expert]));
1977        let router_logits = g.matmul(h_2d, router_w, f32_shape(&[rows, n_expert]));
1978
1979        // MoE body omitted — only the shared-expert tail matters for fusion order.
1980        let shared_router_w = g.param("shared_router_w", f32_shape(&[n_embd, 1]));
1981        let shared_logits = g.matmul(h_2d, shared_router_w, f32_shape(&[rows, 1]));
1982        let shared_gate = g.activation(Activation::Sigmoid, shared_logits, f32_shape(&[rows, 1]));
1983
1984        let s_gate_w = g.param("s_gate_w", f32_shape(&[n_embd, n_ff]));
1985        let s_up_w = g.param("s_up_w", f32_shape(&[n_embd, n_ff]));
1986        let s_gate = g.matmul(h_2d, s_gate_w, f32_shape(&[rows, n_ff]));
1987        let s_up = g.matmul(h_2d, s_up_w, f32_shape(&[rows, n_ff]));
1988        let s_gate_silu = g.silu(s_gate);
1989        let s_swiglu = g.mul(s_gate_silu, s_up);
1990
1991        g.set_outputs(vec![router_logits, shared_gate, s_swiglu]);
1992
1993        let fused = FuseSharedInputMatMul.run(g);
1994        let narrow_count = fused
1995            .nodes()
1996            .iter()
1997            .filter(|n| matches!(n.op, Op::Narrow { .. }))
1998            .count();
1999        assert!(
2000            narrow_count >= 4,
2001            "expected four narrow slices from fused h_2d matmuls, got {narrow_count}"
2002        );
2003    }
2004
2005    /// Full pipeline: build a BERT FFN subgraph and run all fusion passes.
2006    #[test]
2007    fn full_bert_ffn_fusion() {
2008        let mut g = Graph::new("bert_ffn");
2009        let f = DType::F32;
2010
2011        let x = g.input("hidden", Shape::new(&[4, 15, 384], f));
2012        let residual = g.input("residual", Shape::new(&[4, 15, 384], f));
2013
2014        // Output projection result + residual + LN
2015        let out_w = g.param("out.w", Shape::new(&[384, 384], f));
2016        let out_b = g.param("out.b", Shape::new(&[384], f));
2017        let out_mm = g.matmul(x, out_w, Shape::new(&[4, 15, 384], f));
2018        let out_add = g.binary(BinaryOp::Add, out_mm, out_b, Shape::new(&[4, 15, 384], f));
2019        let res_add = g.binary(
2020            BinaryOp::Add,
2021            out_add,
2022            residual,
2023            Shape::new(&[4, 15, 384], f),
2024        );
2025        let gamma = g.param("ln.g", Shape::new(&[384], f));
2026        let beta = g.param("ln.b", Shape::new(&[384], f));
2027        let ln = g.layer_norm(
2028            res_add,
2029            gamma,
2030            beta,
2031            -1,
2032            1e-12,
2033            Shape::new(&[4, 15, 384], f),
2034        );
2035
2036        // FFN intermediate: matmul + bias + gelu
2037        let int_w = g.param("int.w", Shape::new(&[384, 1536], f));
2038        let int_b = g.param("int.b", Shape::new(&[1536], f));
2039        let int_mm = g.matmul(ln, int_w, Shape::new(&[4, 15, 1536], f));
2040        let int_add = g.binary(BinaryOp::Add, int_mm, int_b, Shape::new(&[4, 15, 1536], f));
2041        let gelu = g.activation(Activation::Gelu, int_add, Shape::new(&[4, 15, 1536], f));
2042
2043        // FFN output: matmul + bias
2044        let out2_w = g.param("out2.w", Shape::new(&[1536, 384], f));
2045        let out2_b = g.param("out2.b", Shape::new(&[384], f));
2046        let out2_mm = g.matmul(gelu, out2_w, Shape::new(&[4, 15, 384], f));
2047        let out2_add = g.binary(BinaryOp::Add, out2_mm, out2_b, Shape::new(&[4, 15, 384], f));
2048
2049        g.set_outputs(vec![out2_add]);
2050
2051        let before = g.len();
2052        println!("=== BEFORE fusion ({before} nodes) ===\n{g}");
2053
2054        // Run all passes
2055        let passes: Vec<&dyn Pass> = vec![&FuseMatMulBiasAct, &FuseResidualLN];
2056        let optimized = run_passes(g, &passes, false);
2057        let after = optimized.len();
2058        println!("=== AFTER fusion ({after} nodes) ===\n{optimized}");
2059
2060        // Should have eliminated:
2061        // - 2 Add + 1 Gelu from matmul_bias_gelu fusion (×2 matmuls)
2062        // - 1 Add from residual_ln fusion
2063        assert!(
2064            after < before,
2065            "fusion should reduce node count: {before} → {after}"
2066        );
2067
2068        // Check that fused ops exist
2069        let ops: Vec<String> = optimized
2070            .nodes()
2071            .iter()
2072            .map(|n| format!("{}", n.op))
2073            .collect();
2074        let has_fused_mm = ops.iter().any(|s| s.contains("fused_mm_bias"));
2075        assert!(has_fused_mm, "should have fused_mm_bias_act: {ops:?}");
2076    }
2077
2078    /// FuseSwiGLU fires on the canonical Nomic-style pattern produced by
2079    /// `FuseSharedInputMatMul` (concat'd matmul → narrow×2 → silu → mul).
2080    #[test]
2081    fn fuse_swiglu_canonical() {
2082        let mut g = Graph::new("nomic_ffn");
2083        let f = DType::F32;
2084        // After FuseSharedInputMatMul: cat = mm(x, concat(fc11, fc12)) → [60, 4096]
2085        let cat = g.input("cat", Shape::new(&[60, 4096], f));
2086        let up = g.add_node(
2087            Op::Narrow {
2088                axis: 1,
2089                start: 0,
2090                len: 2048,
2091            },
2092            vec![cat],
2093            Shape::new(&[60, 2048], f),
2094        );
2095        let gate = g.add_node(
2096            Op::Narrow {
2097                axis: 1,
2098                start: 2048,
2099                len: 2048,
2100            },
2101            vec![cat],
2102            Shape::new(&[60, 2048], f),
2103        );
2104        let silu = g.activation(Activation::Silu, gate, Shape::new(&[60, 2048], f));
2105        let out = g.binary(BinaryOp::Mul, up, silu, Shape::new(&[60, 2048], f));
2106        g.set_outputs(vec![out]);
2107
2108        let before = g.len();
2109        let fused = FuseSwiGLU.run(g);
2110        let after = fused.len();
2111        // Removed: up, gate, silu, mul → replaced by FusedSwiGLU.
2112        // Net: -3 nodes (4 removed, 1 added).
2113        assert_eq!(
2114            after,
2115            before - 3,
2116            "should remove narrows+silu+mul, add FusedSwiGLU"
2117        );
2118        let out_node = fused.node(fused.outputs[0]);
2119        assert!(
2120            matches!(
2121                out_node.op,
2122                Op::FusedSwiGLU {
2123                    cast_to: None,
2124                    gate_first: false
2125                }
2126            ),
2127            "output should be FusedSwiGLU, got {}",
2128            out_node.op
2129        );
2130        // FusedSwiGLU's input is the cat tensor.
2131        let in_id = out_node.inputs[0];
2132        assert!(matches!(fused.node(in_id).op, Op::Input { .. }));
2133    }
2134
2135    /// FuseSwiGLU does NOT fire when narrows are shared with another consumer
2136    /// (would corrupt the second consumer's view of the data).
2137    #[test]
2138    fn fuse_swiglu_skips_when_narrow_has_extra_user() {
2139        let mut g = Graph::new("contended");
2140        let f = DType::F32;
2141        let cat = g.input("cat", Shape::new(&[60, 4096], f));
2142        let up = g.add_node(
2143            Op::Narrow {
2144                axis: 1,
2145                start: 0,
2146                len: 2048,
2147            },
2148            vec![cat],
2149            Shape::new(&[60, 2048], f),
2150        );
2151        let gate = g.add_node(
2152            Op::Narrow {
2153                axis: 1,
2154                start: 2048,
2155                len: 2048,
2156            },
2157            vec![cat],
2158            Shape::new(&[60, 2048], f),
2159        );
2160        let silu = g.activation(Activation::Silu, gate, Shape::new(&[60, 2048], f));
2161        let out = g.binary(BinaryOp::Mul, up, silu, Shape::new(&[60, 2048], f));
2162        // Extra user of `up` — this should block fusion.
2163        let extra = g.activation(Activation::Relu, up, Shape::new(&[60, 2048], f));
2164        g.set_outputs(vec![out, extra]);
2165
2166        let before = g.len();
2167        let fused = FuseSwiGLU.run(g);
2168        // Pass should be a no-op when fusion is unsafe.
2169        assert_eq!(fused.len(), before);
2170        // No FusedSwiGLU node anywhere.
2171        let any_fused = fused
2172            .nodes()
2173            .iter()
2174            .any(|n| matches!(n.op, Op::FusedSwiGLU { .. }));
2175        assert!(!any_fused, "should not fuse when narrow has extra user");
2176    }
2177
2178    // ── MarkElementwiseRegions (PLAN L2) ────────────────────────────
2179
2180    #[test]
2181    fn region_collapses_add_mul_relu_chain() {
2182        // Build: out = relu(add(a, b) * c). All same shape, single consumer
2183        // chain. Should fuse into one ElementwiseRegion.
2184        let f = DType::F32;
2185        let mut g = Graph::new("ew");
2186        let a = g.input("a", Shape::new(&[8], f));
2187        let b = g.input("b", Shape::new(&[8], f));
2188        let c = g.input("c", Shape::new(&[8], f));
2189        let s = Shape::new(&[8], f);
2190        let add = g.binary(BinaryOp::Add, a, b, s.clone());
2191        let mul = g.binary(BinaryOp::Mul, add, c, s.clone());
2192        let relu = g.activation(Activation::Relu, mul, s.clone());
2193        g.set_outputs(vec![relu]);
2194
2195        let before = g.len();
2196        let fused = MarkElementwiseRegions.run(g);
2197
2198        // Three element-wise ops collapsed into one region node.
2199        let regions: Vec<_> = fused
2200            .nodes()
2201            .iter()
2202            .filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
2203            .collect();
2204        assert_eq!(regions.len(), 1, "expected one ElementwiseRegion");
2205        let region = regions[0];
2206        assert_eq!(
2207            region.inputs.len(),
2208            3,
2209            "region has 3 external inputs (a, b, c)"
2210        );
2211        if let Op::ElementwiseRegion {
2212            chain, num_inputs, ..
2213        } = &region.op
2214        {
2215            assert_eq!(*num_inputs, 3);
2216            assert_eq!(chain.len(), 3);
2217            // Step 0: Add(Input(0), Input(1))
2218            match &chain[0] {
2219                ChainStep::Binary(
2220                    BinaryOp::Add,
2221                    ChainOperand::Input(0),
2222                    ChainOperand::Input(1),
2223                ) => {}
2224                other => panic!("step 0 unexpected: {other:?}"),
2225            }
2226            // Step 1: Mul(Step(0), Input(2))
2227            match &chain[1] {
2228                ChainStep::Binary(BinaryOp::Mul, ChainOperand::Step(0), ChainOperand::Input(2)) => {
2229                }
2230                other => panic!("step 1 unexpected: {other:?}"),
2231            }
2232            // Step 2: Activation(Relu, Step(1))
2233            match &chain[2] {
2234                ChainStep::Activation(Activation::Relu, ChainOperand::Step(1)) => {}
2235                other => panic!("step 2 unexpected: {other:?}"),
2236            }
2237        } else {
2238            unreachable!();
2239        }
2240        // Original chain (3 ops) replaced by 1 region; net node count is
2241        // (inputs 3) + (region 1) = 4 (vs 3 + 3 = 6 before).
2242        assert!(fused.len() < before);
2243    }
2244
2245    #[test]
2246    fn region_does_not_fuse_when_intermediate_has_multiple_consumers() {
2247        // out1 = add(a, b); out2 = relu(out1). out1 also fed to out_extra.
2248        // Multi-consumer on out1 forbids fusion.
2249        let f = DType::F32;
2250        let mut g = Graph::new("ew");
2251        let a = g.input("a", Shape::new(&[4], f));
2252        let b = g.input("b", Shape::new(&[4], f));
2253        let s = Shape::new(&[4], f);
2254        let add = g.binary(BinaryOp::Add, a, b, s.clone());
2255        let relu = g.activation(Activation::Relu, add, s.clone());
2256        let extra = g.activation(Activation::Sigmoid, add, s.clone());
2257        g.set_outputs(vec![relu, extra]);
2258
2259        let before = g.len();
2260        let fused = MarkElementwiseRegions.run(g);
2261        // No region: add has two consumers (relu and extra), so the chain
2262        // can't extend through it. Each downstream activation is alone in
2263        // its region (size 1, doesn't fuse).
2264        let regions: Vec<_> = fused
2265            .nodes()
2266            .iter()
2267            .filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
2268            .collect();
2269        assert_eq!(regions.len(), 0);
2270        assert_eq!(fused.len(), before);
2271    }
2272
2273    #[test]
2274    fn region_skips_chains_of_length_one() {
2275        // Single relu — no fusion (size 1 = degenerate).
2276        let f = DType::F32;
2277        let mut g = Graph::new("ew");
2278        let a = g.input("a", Shape::new(&[4], f));
2279        let r = g.activation(Activation::Relu, a, Shape::new(&[4], f));
2280        g.set_outputs(vec![r]);
2281
2282        let fused = MarkElementwiseRegions.run(g);
2283        let any_region = fused
2284            .nodes()
2285            .iter()
2286            .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }));
2287        assert!(!any_region);
2288    }
2289
2290    #[test]
2291    fn unfuse_decomposes_region_back_to_atomic_ops() {
2292        // Build the same chain, fuse it, then unfuse — expect the
2293        // original atomic ops back (Add, Mul, Relu).
2294        let f = DType::F32;
2295        let mut g = Graph::new("ew_unfuse");
2296        let a = g.input("a", Shape::new(&[8], f));
2297        let b = g.input("b", Shape::new(&[8], f));
2298        let c = g.input("c", Shape::new(&[8], f));
2299        let s = Shape::new(&[8], f);
2300        let add = g.binary(BinaryOp::Add, a, b, s.clone());
2301        let mul = g.binary(BinaryOp::Mul, add, c, s.clone());
2302        let relu = g.activation(Activation::Relu, mul, s);
2303        g.set_outputs(vec![relu]);
2304
2305        let fused = MarkElementwiseRegions.run(g);
2306        // Sanity: fusion happened.
2307        assert!(
2308            fused
2309                .nodes()
2310                .iter()
2311                .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
2312        );
2313
2314        let unfused = UnfuseElementwiseRegions.run(fused);
2315        // No region nodes left.
2316        assert!(
2317            !unfused
2318                .nodes()
2319                .iter()
2320                .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
2321        );
2322        // Original atomic ops are back: Add, Mul, Relu.
2323        let bin_count = unfused
2324            .nodes()
2325            .iter()
2326            .filter(|n| matches!(n.op, Op::Binary(_)))
2327            .count();
2328        let act_count = unfused
2329            .nodes()
2330            .iter()
2331            .filter(|n| matches!(n.op, Op::Activation(_)))
2332            .count();
2333        assert_eq!(bin_count, 2, "Add + Mul restored");
2334        assert_eq!(act_count, 1, "Relu restored");
2335    }
2336
2337    #[test]
2338    fn clip_unfuses_region_over_step_cap() {
2339        use rlx_ir::op::{Activation, ChainOperand, ChainStep};
2340
2341        let mut g = Graph::new("clip");
2342        let x = g.input("x", f32_shape(&[4]));
2343        let mut chain: Vec<ChainStep> = Vec::new();
2344        let mut prev = ChainOperand::Input(0);
2345        for _ in 0..40 {
2346            chain.push(ChainStep::Activation(Activation::Relu, prev));
2347            prev = ChainOperand::Step(chain.len() as u32 - 1);
2348        }
2349        let y = g.add_node(
2350            Op::ElementwiseRegion {
2351                chain,
2352                num_inputs: 1,
2353                scalar_input_mask: 0,
2354                input_modulus: [0; 16],
2355            },
2356            vec![x],
2357            f32_shape(&[4]),
2358        );
2359        g.set_outputs(vec![y]);
2360
2361        let clipped = clip_elementwise_regions(g, FusionLimits::GPU_NATIVE);
2362        assert!(
2363            !clipped
2364                .nodes()
2365                .iter()
2366                .any(|n| matches!(n.op, Op::ElementwiseRegion { .. })),
2367            "oversized region should be decomposed"
2368        );
2369        assert!(clipped.len() > 5);
2370    }
2371
2372    #[test]
2373    fn unfuse_is_noop_when_no_region_present() {
2374        let f = DType::F32;
2375        let mut g = Graph::new("noop");
2376        let a = g.input("a", Shape::new(&[4], f));
2377        let r = g.activation(Activation::Relu, a, Shape::new(&[4], f));
2378        g.set_outputs(vec![r]);
2379        let n_before = g.len();
2380        let result = UnfuseElementwiseRegions.run(g);
2381        // Pass returns unchanged graph (early return on no-region check).
2382        assert_eq!(result.len(), n_before);
2383    }
2384
2385    #[test]
2386    fn region_includes_where_step() {
2387        // Build: cmp = a > b; sel = where(cmp, a, b); out = sel + a
2388        // The compare → where → add chain is fully element-wise; the
2389        // Where step lands inside the region thanks to the L2-quality
2390        // extension that adds `Op::Where` to the chain-eligible set.
2391        let f = DType::F32;
2392        let mut g = Graph::new("region_where");
2393        let a = g.input("a", Shape::new(&[4], f));
2394        let b = g.input("b", Shape::new(&[4], f));
2395        let s = Shape::new(&[4], f);
2396        let cmp = g.add_node(Op::Compare(CmpOp::Gt), vec![a, b], s.clone());
2397        let sel = g.add_node(Op::Where, vec![cmp, a, b], s.clone());
2398        let add = g.binary(BinaryOp::Add, sel, a, s.clone());
2399        g.set_outputs(vec![add]);
2400
2401        let fused = MarkElementwiseRegions.run(g);
2402        let regions: Vec<_> = fused
2403            .nodes()
2404            .iter()
2405            .filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
2406            .collect();
2407        assert_eq!(regions.len(), 1, "expected one ElementwiseRegion");
2408        if let Op::ElementwiseRegion { chain, .. } = &regions[0].op {
2409            // 3 steps: Compare a > b, Where, Add
2410            assert_eq!(chain.len(), 3);
2411            assert!(
2412                matches!(chain[1], ChainStep::Where(_, _, _)),
2413                "step 1 should be Where, got {:?}",
2414                chain[1]
2415            );
2416        } else {
2417            unreachable!();
2418        }
2419    }
2420
2421    #[test]
2422    fn unfuse_decomposes_where_step_back_to_op_where() {
2423        // Round-trip: build a region with a Where step, decompose it,
2424        // verify the resulting graph contains an Op::Where node.
2425        let f = DType::F32;
2426        let mut g = Graph::new("unfuse_where");
2427        let a = g.input("a", Shape::new(&[4], f));
2428        let b = g.input("b", Shape::new(&[4], f));
2429        let s = Shape::new(&[4], f);
2430        let cmp = g.add_node(Op::Compare(CmpOp::Gt), vec![a, b], s.clone());
2431        let sel = g.add_node(Op::Where, vec![cmp, a, b], s.clone());
2432        let add = g.binary(BinaryOp::Add, sel, a, s.clone());
2433        g.set_outputs(vec![add]);
2434        let fused = MarkElementwiseRegions.run(g);
2435        let unfused = UnfuseElementwiseRegions.run(fused);
2436        let where_count = unfused
2437            .nodes()
2438            .iter()
2439            .filter(|n| matches!(n.op, Op::Where))
2440            .count();
2441        assert_eq!(
2442            where_count, 1,
2443            "decomposer should re-emit one Op::Where for the chain step"
2444        );
2445    }
2446}