Skip to main content

rlx_fusion/
fusion_report.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Fusion diagnostics — what fused, what missed, and why.
17
18use rlx_ir::op::{Activation, BinaryOp, RegionPrologue};
19use rlx_ir::{Graph, NodeId, Op, node_label};
20use std::fmt;
21
22/// Why a recognizable fusion pattern was not collapsed.
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub enum MissReason {
25    MultiConsumer,
26    NonAddBiasConsumer,
27    BiasRankTooHigh { rank: usize },
28    UnsupportedEpilogueActivation(Activation),
29    SharedMatmulCount { count: usize },
30    SwigluGateBeforeUp,
31    SwigluNotSharedInput,
32    NotFused,
33}
34
35/// A single fusion opportunity that remains in the graph.
36#[derive(Debug, Clone, PartialEq, Eq)]
37pub struct MissedFusion {
38    pub pattern: &'static str,
39    pub node: NodeId,
40    pub reason: MissReason,
41    /// HIR label / node name when available.
42    pub context: Option<String>,
43    /// Actionable fix hint.
44    pub hint: Option<String>,
45}
46
47/// Before/after fusion statistics and missed-pattern tally.
48#[derive(Debug, Clone, Default, PartialEq, Eq)]
49pub struct FusionReport {
50    pub nodes_before: usize,
51    pub nodes_after: usize,
52    pub matmul_before: usize,
53    pub attention: usize,
54    pub rope: usize,
55    pub narrow: usize,
56    pub matmul_after: usize,
57    pub silu: usize,
58    pub mul: usize,
59    pub fused_matmul_bias_act: usize,
60    pub fused_swiglu: usize,
61    pub fused_residual_ln: usize,
62    pub fused_residual_rms_norm: usize,
63    pub fused_attention_block: usize,
64    pub fused_transformer_layer: usize,
65    pub elementwise_region: usize,
66    pub transform_region: usize,
67    pub batch_elementwise_region: usize,
68    pub fk_prologue_region: usize,
69    pub missed: Vec<MissedFusion>,
70}
71
72impl FusionReport {
73    /// Compare an unfused graph with the post-pass result.
74    pub fn analyze(before: &Graph, after: &Graph) -> Self {
75        let before_stats = count_ops(before);
76        let after_stats = count_ops(after);
77        let missed = scan_misses(after);
78        Self {
79            nodes_before: before.len(),
80            nodes_after: after.len(),
81            matmul_before: before_stats.matmul,
82            attention: after_stats.attention,
83            rope: after_stats.rope,
84            narrow: after_stats.narrow,
85            matmul_after: after_stats.matmul,
86            silu: after_stats.silu,
87            mul: after_stats.mul,
88            fused_matmul_bias_act: after_stats.fused_matmul_bias_act,
89            fused_swiglu: after_stats.fused_swiglu,
90            fused_residual_ln: after_stats.fused_residual_ln,
91            fused_residual_rms_norm: after_stats.fused_residual_rms_norm,
92            fused_attention_block: after_stats.fused_attention_block,
93            fused_transformer_layer: after_stats.fused_transformer_layer,
94            elementwise_region: after_stats.elementwise_region,
95            transform_region: after_stats.transform_region,
96            batch_elementwise_region: after_stats.batch_elementwise_region,
97            fk_prologue_region: after_stats.fk_prologue_region,
98            missed,
99        }
100    }
101
102    /// Scan a graph (typically post-fusion) for patterns that should
103    /// have collapsed but did not.
104    pub fn scan(graph: &Graph) -> Self {
105        let stats = count_ops(graph);
106        let missed = scan_misses(graph);
107        Self {
108            nodes_before: graph.len(),
109            nodes_after: graph.len(),
110            matmul_before: stats.matmul,
111            matmul_after: stats.matmul,
112            attention: stats.attention,
113            rope: stats.rope,
114            narrow: stats.narrow,
115            silu: stats.silu,
116            mul: stats.mul,
117            fused_matmul_bias_act: stats.fused_matmul_bias_act,
118            fused_swiglu: stats.fused_swiglu,
119            fused_residual_ln: stats.fused_residual_ln,
120            fused_residual_rms_norm: stats.fused_residual_rms_norm,
121            fused_attention_block: stats.fused_attention_block,
122            fused_transformer_layer: stats.fused_transformer_layer,
123            elementwise_region: stats.elementwise_region,
124            transform_region: stats.transform_region,
125            batch_elementwise_region: stats.batch_elementwise_region,
126            fk_prologue_region: stats.fk_prologue_region,
127            missed,
128        }
129    }
130
131    pub fn missed_matmul_bias_act(&self) -> usize {
132        self.missed
133            .iter()
134            .filter(|m| m.pattern == "matmul_bias_act")
135            .count()
136    }
137
138    pub fn missed_swiglu(&self) -> usize {
139        self.missed.iter().filter(|m| m.pattern == "swiglu").count()
140    }
141
142    pub fn missed_shared_matmul(&self) -> usize {
143        self.missed
144            .iter()
145            .filter(|m| m.pattern == "shared_input_matmul")
146            .count()
147    }
148
149    /// One-line summary suitable for logs and CSV benches.
150    pub fn summary_line(&self) -> String {
151        format!(
152            "nodes={}→{} matmul={}→{} fused_mm_act={} fused_swiglu={} \
153             elementwise_region={} transform_region={} batch_region={} fk_prologue={} \
154             missed_mm_act={} missed_swiglu={} missed_shared_mm={}",
155            self.nodes_before,
156            self.nodes_after,
157            self.matmul_before,
158            self.matmul_after,
159            self.fused_matmul_bias_act,
160            self.fused_swiglu,
161            self.elementwise_region,
162            self.transform_region,
163            self.batch_elementwise_region,
164            self.fk_prologue_region,
165            self.missed_matmul_bias_act(),
166            self.missed_swiglu(),
167            self.missed_shared_matmul(),
168        )
169    }
170}
171
172impl fmt::Display for FusionReport {
173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174        writeln!(f, "fusion report:")?;
175        writeln!(f, "  {}", self.summary_line())?;
176        if !self.missed.is_empty() {
177            writeln!(f, "  missed patterns:")?;
178            for m in &self.missed {
179                write!(f, "    {} @ {}", m.pattern, m.node)?;
180                if let Some(ref c) = m.context {
181                    write!(f, " ({c})")?;
182                }
183                write!(f, " — {:?}", m.reason)?;
184                if let Some(ref h) = m.hint {
185                    write!(f, " → {h}")?;
186                }
187                writeln!(f)?;
188            }
189        }
190        Ok(())
191    }
192}
193
/// Per-op-kind tallies gathered from one graph by `count_ops`.
#[derive(Default)]
struct OpCounts {
    // Individual op kinds.
    /// `MatMul` nodes.
    matmul: usize,
    /// `Silu` activations.
    silu: usize,
    /// Binary `Mul` nodes.
    mul: usize,
    /// `Attention` nodes.
    attention: usize,
    /// `Rope` nodes.
    rope: usize,
    /// `Narrow` nodes.
    narrow: usize,

    // Fused ops.
    fused_matmul_bias_act: usize,
    fused_swiglu: usize,
    fused_residual_ln: usize,
    fused_residual_rms_norm: usize,
    fused_attention_block: usize,
    fused_transformer_layer: usize,

    // Region ops.
    /// All `ElementwiseRegion` nodes.
    elementwise_region: usize,
    /// `ElementwiseRegion` nodes whose prologue is not `None`.
    fk_prologue_region: usize,
    transform_region: usize,
    batch_elementwise_region: usize,
}
213
214fn count_ops(graph: &Graph) -> OpCounts {
215    let mut s = OpCounts::default();
216    for node in graph.nodes() {
217        match &node.op {
218            Op::Attention { .. } => s.attention += 1,
219            Op::Rope { .. } => s.rope += 1,
220            Op::Narrow { .. } => s.narrow += 1,
221            Op::MatMul => s.matmul += 1,
222            Op::Activation(Activation::Silu) => s.silu += 1,
223            Op::Binary(BinaryOp::Mul) => s.mul += 1,
224            Op::FusedMatMulBiasAct { .. } => s.fused_matmul_bias_act += 1,
225            Op::FusedSwiGLU { .. } => s.fused_swiglu += 1,
226            Op::FusedResidualLN { .. } => s.fused_residual_ln += 1,
227            Op::FusedResidualRmsNorm { .. } => s.fused_residual_rms_norm += 1,
228            Op::FusedAttentionBlock { .. } => s.fused_attention_block += 1,
229            Op::FusedTransformerLayer { .. } => s.fused_transformer_layer += 1,
230            Op::ElementwiseRegion { prologue, .. } => {
231                s.elementwise_region += 1;
232                if *prologue != RegionPrologue::None {
233                    s.fk_prologue_region += 1;
234                }
235            }
236            Op::TransformRegion { .. } => s.transform_region += 1,
237            Op::BatchElementwiseRegion { .. } => s.batch_elementwise_region += 1,
238            _ => {}
239        }
240    }
241    s
242}
243
244fn missed_entry(
245    graph: &Graph,
246    pattern: &'static str,
247    node: NodeId,
248    reason: MissReason,
249) -> MissedFusion {
250    MissedFusion {
251        pattern,
252        node,
253        context: Some(node_label(graph, node)),
254        hint: Some(fusion_hint(&reason)),
255        reason,
256    }
257}
258
259fn fusion_hint(reason: &MissReason) -> String {
260    match reason {
261        MissReason::MultiConsumer => {
262            "single-consumer chain required — clone input or use HirOp::LinearFused".into()
263        }
264        MissReason::NonAddBiasConsumer => "use linear+bias or HirModule::linear_fused".into(),
265        MissReason::BiasRankTooHigh { .. } => "bias must be rank-1".into(),
266        MissReason::UnsupportedEpilogueActivation(_) => {
267            "FuseMatMulBiasAct supports Gelu/Silu only".into()
268        }
269        MissReason::SharedMatmulCount { .. } => "use shared_linear_pair or HirOp::SwiGLU".into(),
270        MissReason::SwigluGateBeforeUp => "pass up_w before gate_w in swiglu_ffn".into(),
271        MissReason::SwigluNotSharedInput => "gate and up must share the same input".into(),
272        MissReason::NotFused => "check inspect_pipeline / RLX_FUSION_REPORT=1".into(),
273    }
274}
275
276fn scan_misses(graph: &Graph) -> Vec<MissedFusion> {
277    let mut missed = Vec::new();
278    missed.extend(scan_missed_matmul_bias_act(graph));
279    missed.extend(scan_missed_shared_matmul(graph));
280    missed.extend(scan_missed_swiglu(graph));
281    missed
282}
283
284fn scan_missed_matmul_bias_act(graph: &Graph) -> Vec<MissedFusion> {
285    let mut out = Vec::new();
286    for node in graph.nodes() {
287        if !matches!(node.op, Op::MatMul) {
288            continue;
289        }
290        let mm_id = node.id;
291        let users = graph.users(mm_id);
292        if users.len() != 1 {
293            if users.len() > 1 {
294                out.push(missed_entry(
295                    graph,
296                    "matmul_bias_act",
297                    mm_id,
298                    MissReason::MultiConsumer,
299                ));
300            }
301            continue;
302        }
303        let add_node = graph.node(users[0]);
304        let Op::Binary(BinaryOp::Add) = &add_node.op else {
305            out.push(missed_entry(
306                graph,
307                "matmul_bias_act",
308                mm_id,
309                MissReason::NonAddBiasConsumer,
310            ));
311            continue;
312        };
313        let bias_id = if add_node.inputs[0] == mm_id {
314            add_node.inputs[1]
315        } else {
316            add_node.inputs[0]
317        };
318        let bias_rank = graph.shape(bias_id).rank();
319        if bias_rank > 1 {
320            out.push(missed_entry(
321                graph,
322                "matmul_bias_act",
323                mm_id,
324                MissReason::BiasRankTooHigh { rank: bias_rank },
325            ));
326            continue;
327        }
328        let add_users = graph.users(add_node.id);
329        if add_users.len() == 1 {
330            if let Op::Activation(act) = &graph.node(add_users[0]).op
331                && !fusible_mm_bias_epilogue(*act)
332            {
333                out.push(missed_entry(
334                    graph,
335                    "matmul_bias_act",
336                    mm_id,
337                    MissReason::UnsupportedEpilogueActivation(*act),
338                ));
339            }
340        }
341    }
342    out
343}
344
345fn fusible_mm_bias_epilogue(act: Activation) -> bool {
346    matches!(act, Activation::Gelu | Activation::Silu)
347}
348
349fn scan_missed_shared_matmul(graph: &Graph) -> Vec<MissedFusion> {
350    let mut input_to_matmuls: std::collections::HashMap<NodeId, Vec<NodeId>> =
351        std::collections::HashMap::new();
352    for node in graph.nodes() {
353        if matches!(node.op, Op::MatMul) {
354            input_to_matmuls
355                .entry(node.inputs[0])
356                .or_default()
357                .push(node.id);
358        }
359    }
360    let mut out = Vec::new();
361    for matmuls in input_to_matmuls.values() {
362        if matmuls.len() == 2 {
363            let a = graph.node(matmuls[0]);
364            let b = graph.node(matmuls[1]);
365            let w1 = graph.shape(a.inputs[1]);
366            let w2 = graph.shape(b.inputs[1]);
367            if w1.rank() == 2 && w2.rank() == 2 && w1.dim(0) == w2.dim(0) {
368                out.push(missed_entry(
369                    graph,
370                    "shared_input_matmul",
371                    matmuls[0],
372                    MissReason::NotFused,
373                ));
374            }
375        } else if matmuls.len() > 2 {
376            out.push(missed_entry(
377                graph,
378                "shared_input_matmul",
379                matmuls[0],
380                MissReason::SharedMatmulCount {
381                    count: matmuls.len(),
382                },
383            ));
384        }
385    }
386    out
387}
388
389fn scan_missed_swiglu(graph: &Graph) -> Vec<MissedFusion> {
390    let mut out = Vec::new();
391    for node in graph.nodes() {
392        if !matches!(node.op, Op::Binary(BinaryOp::Mul)) {
393            continue;
394        }
395        let lhs = graph.node(node.inputs[0]);
396        let rhs = graph.node(node.inputs[1]);
397        let (up_side, silu_side) = if matches!(rhs.op, Op::Activation(Activation::Silu)) {
398            (lhs, rhs)
399        } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
400            (rhs, lhs)
401        } else {
402            continue;
403        };
404        if !matches!(up_side.op, Op::MatMul) {
405            continue;
406        }
407        let gate_mm = graph.node(silu_side.inputs[0]);
408        if !matches!(gate_mm.op, Op::MatMul) {
409            continue;
410        }
411        if up_side.inputs[0] != gate_mm.inputs[0] {
412            out.push(missed_entry(
413                graph,
414                "swiglu",
415                node.id,
416                MissReason::SwigluNotSharedInput,
417            ));
418            continue;
419        }
420        // Gate-before-up declaration order prevents FuseSwiGLU after shared-input concat.
421        if graph
422            .nodes()
423            .iter()
424            .position(|n| n.id == up_side.id)
425            .zip(graph.nodes().iter().position(|n| n.id == gate_mm.id))
426            .is_some_and(|(up_idx, gate_idx)| gate_idx < up_idx)
427        {
428            out.push(missed_entry(
429                graph,
430                "swiglu",
431                node.id,
432                MissReason::SwigluGateBeforeUp,
433            ));
434        }
435    }
436    out
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::infer::GraphExt;
    use rlx_ir::{DType, Shape};

    /// Build an `F32` shape with the given dimensions.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    /// After the shared-input and SwiGLU passes, a SwiGLU FFN should
    /// contain exactly one fused SwiGLU, and the graph should be smaller.
    #[test]
    fn report_counts_fused_ops() {
        use crate::fusion::{FuseSharedInputMatMul, FuseSwiGLU};
        use crate::pass::Pass;

        let mut graph = Graph::new("report");
        let x = graph.input("x", f32_shape(&[4, 768]));
        let up_w = graph.param("up", f32_shape(&[768, 128]));
        let gate_w = graph.param("gate", f32_shape(&[768, 128]));
        let down_w = graph.param("down", f32_shape(&[128, 768]));
        let ffn_out = graph.swiglu_ffn(x, up_w, gate_w, down_w);
        graph.set_outputs(vec![ffn_out]);

        let unfused = graph.clone();
        let fused = FuseSwiGLU.run(FuseSharedInputMatMul.run(graph));

        let report = FusionReport::analyze(&unfused, &fused);
        assert_eq!(report.fused_swiglu, 1);
        assert!(report.nodes_after < report.nodes_before);
    }

    /// Declaring the gate matmul before the up matmul must surface as a
    /// `SwigluGateBeforeUp` miss.
    #[test]
    fn report_flags_gate_before_up() {
        let mut graph = Graph::new("gate_first");
        let x = graph.input("x", f32_shape(&[4, 8]));
        let gate_w = graph.param("gate", f32_shape(&[8, 16]));
        let up_w = graph.param("up", f32_shape(&[8, 16]));
        let gate = graph.mm(x, gate_w);
        let up = graph.mm(x, up_w);
        let gated = graph.silu(gate);
        let product = graph.mul(gated, up);
        graph.set_outputs(vec![product]);

        let report = FusionReport::scan(&graph);
        assert!(report.missed_swiglu() >= 1);
        let flagged = report
            .missed
            .iter()
            .any(|m| m.reason == MissReason::SwigluGateBeforeUp);
        assert!(flagged);
    }
}