Skip to main content

rlx_fusion/
fusion_report.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Fusion diagnostics — what fused, what missed, and why.
17
18use rlx_ir::op::{Activation, BinaryOp};
19use rlx_ir::{Graph, NodeId, Op, node_label};
20use std::fmt;
21
/// Why a recognizable fusion pattern was not collapsed.
///
/// Each variant is produced by one of the `scan_missed_*` functions in this
/// module and is turned into a user-facing hint by `fusion_hint`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MissReason {
    /// The matmul result has more than one user.
    MultiConsumer,
    /// The matmul's only user is not a `BinaryOp::Add`.
    NonAddBiasConsumer,
    /// The other operand of the add following the matmul has a rank above 1;
    /// `rank` is the rank that was found.
    BiasRankTooHigh { rank: usize },
    /// The activation consuming the bias add is neither `Gelu` nor `Silu`.
    UnsupportedEpilogueActivation(Activation),
    /// More than two matmuls read the same first input; `count` is how many.
    SharedMatmulCount { count: usize },
    /// In a `silu(gate) * up` subgraph the gate matmul precedes the up matmul
    /// in the graph's node list.
    SwigluGateBeforeUp,
    /// In a `silu(gate) * up` subgraph the two matmuls read different first
    /// inputs.
    SwigluNotSharedInput,
    /// The pattern is still present without a more specific cause; emitted for
    /// exactly two matmuls that share a first input and whose rank-2 weights
    /// agree on dimension 0.
    NotFused,
}
34
/// A single fusion opportunity that remains in the graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MissedFusion {
    /// Pattern identifier; the scanners in this module emit
    /// `"matmul_bias_act"`, `"shared_input_matmul"` and `"swiglu"`.
    pub pattern: &'static str,
    /// Node the miss is reported against (a matmul for the matmul patterns,
    /// the multiply for `"swiglu"`).
    pub node: NodeId,
    /// Why the pattern was not collapsed.
    pub reason: MissReason,
    /// HIR label / node name when available.
    pub context: Option<String>,
    /// Actionable fix hint.
    pub hint: Option<String>,
}
46
/// Before/after fusion statistics and missed-pattern tally.
///
/// When built by `FusionReport::analyze`, only `nodes_before` and
/// `matmul_before` describe the pre-pass graph; every other counter and
/// `missed` describe the post-pass graph. `FusionReport::scan` fills both
/// sides from the single graph it is given.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FusionReport {
    /// Node count (`Graph::len`) of the graph before fusion.
    pub nodes_before: usize,
    /// Node count (`Graph::len`) of the graph after fusion.
    pub nodes_after: usize,
    /// Number of `Op::MatMul` nodes before fusion.
    pub matmul_before: usize,
    /// Number of `Op::Attention` nodes.
    pub attention: usize,
    /// Number of `Op::Rope` nodes.
    pub rope: usize,
    /// Number of `Op::Narrow` nodes.
    pub narrow: usize,
    /// Number of `Op::MatMul` nodes after fusion.
    pub matmul_after: usize,
    /// Number of `Op::Activation(Activation::Silu)` nodes.
    pub silu: usize,
    /// Number of `Op::Binary(BinaryOp::Mul)` nodes.
    pub mul: usize,
    /// Number of `Op::FusedMatMulBiasAct` nodes.
    pub fused_matmul_bias_act: usize,
    /// Number of `Op::FusedSwiGLU` nodes.
    pub fused_swiglu: usize,
    /// Number of `Op::FusedResidualLN` nodes.
    pub fused_residual_ln: usize,
    /// Number of `Op::FusedResidualRmsNorm` nodes.
    pub fused_residual_rms_norm: usize,
    /// Number of `Op::FusedAttentionBlock` nodes.
    pub fused_attention_block: usize,
    /// Number of `Op::FusedTransformerLayer` nodes.
    pub fused_transformer_layer: usize,
    /// Number of `Op::ElementwiseRegion` nodes.
    pub elementwise_region: usize,
    /// Patterns that were recognized but are still unfused.
    pub missed: Vec<MissedFusion>,
}
68
69impl FusionReport {
70    /// Compare an unfused graph with the post-pass result.
71    pub fn analyze(before: &Graph, after: &Graph) -> Self {
72        let before_stats = count_ops(before);
73        let after_stats = count_ops(after);
74        let missed = scan_misses(after);
75        Self {
76            nodes_before: before.len(),
77            nodes_after: after.len(),
78            matmul_before: before_stats.matmul,
79            attention: after_stats.attention,
80            rope: after_stats.rope,
81            narrow: after_stats.narrow,
82            matmul_after: after_stats.matmul,
83            silu: after_stats.silu,
84            mul: after_stats.mul,
85            fused_matmul_bias_act: after_stats.fused_matmul_bias_act,
86            fused_swiglu: after_stats.fused_swiglu,
87            fused_residual_ln: after_stats.fused_residual_ln,
88            fused_residual_rms_norm: after_stats.fused_residual_rms_norm,
89            fused_attention_block: after_stats.fused_attention_block,
90            fused_transformer_layer: after_stats.fused_transformer_layer,
91            elementwise_region: after_stats.elementwise_region,
92            missed,
93        }
94    }
95
96    /// Scan a graph (typically post-fusion) for patterns that should
97    /// have collapsed but did not.
98    pub fn scan(graph: &Graph) -> Self {
99        let stats = count_ops(graph);
100        let missed = scan_misses(graph);
101        Self {
102            nodes_before: graph.len(),
103            nodes_after: graph.len(),
104            matmul_before: stats.matmul,
105            matmul_after: stats.matmul,
106            attention: stats.attention,
107            rope: stats.rope,
108            narrow: stats.narrow,
109            silu: stats.silu,
110            mul: stats.mul,
111            fused_matmul_bias_act: stats.fused_matmul_bias_act,
112            fused_swiglu: stats.fused_swiglu,
113            fused_residual_ln: stats.fused_residual_ln,
114            fused_residual_rms_norm: stats.fused_residual_rms_norm,
115            fused_attention_block: stats.fused_attention_block,
116            fused_transformer_layer: stats.fused_transformer_layer,
117            elementwise_region: stats.elementwise_region,
118            missed,
119        }
120    }
121
122    pub fn missed_matmul_bias_act(&self) -> usize {
123        self.missed
124            .iter()
125            .filter(|m| m.pattern == "matmul_bias_act")
126            .count()
127    }
128
129    pub fn missed_swiglu(&self) -> usize {
130        self.missed.iter().filter(|m| m.pattern == "swiglu").count()
131    }
132
133    pub fn missed_shared_matmul(&self) -> usize {
134        self.missed
135            .iter()
136            .filter(|m| m.pattern == "shared_input_matmul")
137            .count()
138    }
139
140    /// One-line summary suitable for logs and CSV benches.
141    pub fn summary_line(&self) -> String {
142        format!(
143            "nodes={}→{} matmul={}→{} fused_mm_act={} fused_swiglu={} \
144             elementwise_region={} missed_mm_act={} missed_swiglu={} missed_shared_mm={}",
145            self.nodes_before,
146            self.nodes_after,
147            self.matmul_before,
148            self.matmul_after,
149            self.fused_matmul_bias_act,
150            self.fused_swiglu,
151            self.elementwise_region,
152            self.missed_matmul_bias_act(),
153            self.missed_swiglu(),
154            self.missed_shared_matmul(),
155        )
156    }
157}
158
159impl fmt::Display for FusionReport {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161        writeln!(f, "fusion report:")?;
162        writeln!(f, "  {}", self.summary_line())?;
163        if !self.missed.is_empty() {
164            writeln!(f, "  missed patterns:")?;
165            for m in &self.missed {
166                write!(f, "    {} @ {}", m.pattern, m.node)?;
167                if let Some(ref c) = m.context {
168                    write!(f, " ({c})")?;
169                }
170                write!(f, " — {:?}", m.reason)?;
171                if let Some(ref h) = m.hint {
172                    write!(f, " → {h}")?;
173                }
174                writeln!(f)?;
175            }
176        }
177        Ok(())
178    }
179}
180
/// Per-graph tally of the op kinds that `count_ops` recognizes.
///
/// Every field starts at zero (via `Default`) and counts the nodes whose op
/// matches the variant named in the field's comment.
#[derive(Default)]
struct OpCounts {
    // `Op::MatMul`
    matmul: usize,
    // `Op::Attention`
    attention: usize,
    // `Op::Rope`
    rope: usize,
    // `Op::Narrow`
    narrow: usize,
    // `Op::Activation(Activation::Silu)`
    silu: usize,
    // `Op::Binary(BinaryOp::Mul)`
    mul: usize,
    // `Op::FusedMatMulBiasAct`
    fused_matmul_bias_act: usize,
    // `Op::FusedSwiGLU`
    fused_swiglu: usize,
    // `Op::FusedResidualLN`
    fused_residual_ln: usize,
    // `Op::FusedResidualRmsNorm`
    fused_residual_rms_norm: usize,
    // `Op::FusedAttentionBlock`
    fused_attention_block: usize,
    // `Op::FusedTransformerLayer`
    fused_transformer_layer: usize,
    // `Op::ElementwiseRegion`
    elementwise_region: usize,
}
197
198fn count_ops(graph: &Graph) -> OpCounts {
199    let mut s = OpCounts::default();
200    for node in graph.nodes() {
201        match &node.op {
202            Op::Attention { .. } => s.attention += 1,
203            Op::Rope { .. } => s.rope += 1,
204            Op::Narrow { .. } => s.narrow += 1,
205            Op::MatMul => s.matmul += 1,
206            Op::Activation(Activation::Silu) => s.silu += 1,
207            Op::Binary(BinaryOp::Mul) => s.mul += 1,
208            Op::FusedMatMulBiasAct { .. } => s.fused_matmul_bias_act += 1,
209            Op::FusedSwiGLU { .. } => s.fused_swiglu += 1,
210            Op::FusedResidualLN { .. } => s.fused_residual_ln += 1,
211            Op::FusedResidualRmsNorm { .. } => s.fused_residual_rms_norm += 1,
212            Op::FusedAttentionBlock { .. } => s.fused_attention_block += 1,
213            Op::FusedTransformerLayer { .. } => s.fused_transformer_layer += 1,
214            Op::ElementwiseRegion { .. } => s.elementwise_region += 1,
215            _ => {}
216        }
217    }
218    s
219}
220
221fn missed_entry(
222    graph: &Graph,
223    pattern: &'static str,
224    node: NodeId,
225    reason: MissReason,
226) -> MissedFusion {
227    MissedFusion {
228        pattern,
229        node,
230        context: Some(node_label(graph, node)),
231        hint: Some(fusion_hint(&reason)),
232        reason,
233    }
234}
235
236fn fusion_hint(reason: &MissReason) -> String {
237    match reason {
238        MissReason::MultiConsumer => {
239            "single-consumer chain required — clone input or use HirOp::LinearFused".into()
240        }
241        MissReason::NonAddBiasConsumer => "use linear+bias or HirModule::linear_fused".into(),
242        MissReason::BiasRankTooHigh { .. } => "bias must be rank-1".into(),
243        MissReason::UnsupportedEpilogueActivation(_) => {
244            "FuseMatMulBiasAct supports Gelu/Silu only".into()
245        }
246        MissReason::SharedMatmulCount { .. } => "use shared_linear_pair or HirOp::SwiGLU".into(),
247        MissReason::SwigluGateBeforeUp => "pass up_w before gate_w in swiglu_ffn".into(),
248        MissReason::SwigluNotSharedInput => "gate and up must share the same input".into(),
249        MissReason::NotFused => "check inspect_pipeline / RLX_FUSION_REPORT=1".into(),
250    }
251}
252
253fn scan_misses(graph: &Graph) -> Vec<MissedFusion> {
254    let mut missed = Vec::new();
255    missed.extend(scan_missed_matmul_bias_act(graph));
256    missed.extend(scan_missed_shared_matmul(graph));
257    missed.extend(scan_missed_swiglu(graph));
258    missed
259}
260
261fn scan_missed_matmul_bias_act(graph: &Graph) -> Vec<MissedFusion> {
262    let mut out = Vec::new();
263    for node in graph.nodes() {
264        if !matches!(node.op, Op::MatMul) {
265            continue;
266        }
267        let mm_id = node.id;
268        let users = graph.users(mm_id);
269        if users.len() != 1 {
270            if users.len() > 1 {
271                out.push(missed_entry(
272                    graph,
273                    "matmul_bias_act",
274                    mm_id,
275                    MissReason::MultiConsumer,
276                ));
277            }
278            continue;
279        }
280        let add_node = graph.node(users[0]);
281        let Op::Binary(BinaryOp::Add) = &add_node.op else {
282            out.push(missed_entry(
283                graph,
284                "matmul_bias_act",
285                mm_id,
286                MissReason::NonAddBiasConsumer,
287            ));
288            continue;
289        };
290        let bias_id = if add_node.inputs[0] == mm_id {
291            add_node.inputs[1]
292        } else {
293            add_node.inputs[0]
294        };
295        let bias_rank = graph.shape(bias_id).rank();
296        if bias_rank > 1 {
297            out.push(missed_entry(
298                graph,
299                "matmul_bias_act",
300                mm_id,
301                MissReason::BiasRankTooHigh { rank: bias_rank },
302            ));
303            continue;
304        }
305        let add_users = graph.users(add_node.id);
306        if add_users.len() == 1 {
307            if let Op::Activation(act) = &graph.node(add_users[0]).op
308                && !fusible_mm_bias_epilogue(*act)
309            {
310                out.push(missed_entry(
311                    graph,
312                    "matmul_bias_act",
313                    mm_id,
314                    MissReason::UnsupportedEpilogueActivation(*act),
315                ));
316            }
317        }
318    }
319    out
320}
321
322fn fusible_mm_bias_epilogue(act: Activation) -> bool {
323    matches!(act, Activation::Gelu | Activation::Silu)
324}
325
326fn scan_missed_shared_matmul(graph: &Graph) -> Vec<MissedFusion> {
327    let mut input_to_matmuls: std::collections::HashMap<NodeId, Vec<NodeId>> =
328        std::collections::HashMap::new();
329    for node in graph.nodes() {
330        if matches!(node.op, Op::MatMul) {
331            input_to_matmuls
332                .entry(node.inputs[0])
333                .or_default()
334                .push(node.id);
335        }
336    }
337    let mut out = Vec::new();
338    for matmuls in input_to_matmuls.values() {
339        if matmuls.len() == 2 {
340            let a = graph.node(matmuls[0]);
341            let b = graph.node(matmuls[1]);
342            let w1 = graph.shape(a.inputs[1]);
343            let w2 = graph.shape(b.inputs[1]);
344            if w1.rank() == 2 && w2.rank() == 2 && w1.dim(0) == w2.dim(0) {
345                out.push(missed_entry(
346                    graph,
347                    "shared_input_matmul",
348                    matmuls[0],
349                    MissReason::NotFused,
350                ));
351            }
352        } else if matmuls.len() > 2 {
353            out.push(missed_entry(
354                graph,
355                "shared_input_matmul",
356                matmuls[0],
357                MissReason::SharedMatmulCount {
358                    count: matmuls.len(),
359                },
360            ));
361        }
362    }
363    out
364}
365
366fn scan_missed_swiglu(graph: &Graph) -> Vec<MissedFusion> {
367    let mut out = Vec::new();
368    for node in graph.nodes() {
369        if !matches!(node.op, Op::Binary(BinaryOp::Mul)) {
370            continue;
371        }
372        let lhs = graph.node(node.inputs[0]);
373        let rhs = graph.node(node.inputs[1]);
374        let (up_side, silu_side) = if matches!(rhs.op, Op::Activation(Activation::Silu)) {
375            (lhs, rhs)
376        } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
377            (rhs, lhs)
378        } else {
379            continue;
380        };
381        if !matches!(up_side.op, Op::MatMul) {
382            continue;
383        }
384        let gate_mm = graph.node(silu_side.inputs[0]);
385        if !matches!(gate_mm.op, Op::MatMul) {
386            continue;
387        }
388        if up_side.inputs[0] != gate_mm.inputs[0] {
389            out.push(missed_entry(
390                graph,
391                "swiglu",
392                node.id,
393                MissReason::SwigluNotSharedInput,
394            ));
395            continue;
396        }
397        // Gate-before-up declaration order prevents FuseSwiGLU after shared-input concat.
398        if graph
399            .nodes()
400            .iter()
401            .position(|n| n.id == up_side.id)
402            .zip(graph.nodes().iter().position(|n| n.id == gate_mm.id))
403            .is_some_and(|(up_idx, gate_idx)| gate_idx < up_idx)
404        {
405            out.push(missed_entry(
406                graph,
407                "swiglu",
408                node.id,
409                MissReason::SwigluGateBeforeUp,
410            ));
411        }
412    }
413    out
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::infer::GraphExt;
    use rlx_ir::{DType, Shape};

    /// Build an `F32` shape with the given dimensions.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    /// After running the shared-input and SwiGLU passes on a `swiglu_ffn`
    /// graph, the report must show one fused SwiGLU and fewer nodes.
    #[test]
    fn report_counts_fused_ops() {
        use crate::fusion::{FuseSharedInputMatMul, FuseSwiGLU};
        use crate::pass::Pass;

        let mut graph = Graph::new("report");
        let input = graph.input("x", f32_shape(&[4, 768]));
        let up_weight = graph.param("up", f32_shape(&[768, 128]));
        let gate_weight = graph.param("gate", f32_shape(&[768, 128]));
        let down_weight = graph.param("down", f32_shape(&[128, 768]));
        let ffn_out = graph.swiglu_ffn(input, up_weight, gate_weight, down_weight);
        graph.set_outputs(vec![ffn_out]);

        // Keep an unfused copy to compare against.
        let unfused = graph.clone();
        let fused = FuseSwiGLU.run(FuseSharedInputMatMul.run(graph));

        let report = FusionReport::analyze(&unfused, &fused);
        assert_eq!(report.fused_swiglu, 1);
        assert!(report.nodes_after < report.nodes_before);
    }

    /// A hand-built SwiGLU whose gate matmul is created before the up matmul
    /// must be flagged with `SwigluGateBeforeUp`.
    #[test]
    fn report_flags_gate_before_up() {
        let mut graph = Graph::new("gate_first");
        let input = graph.input("x", f32_shape(&[4, 8]));
        let gate_weight = graph.param("gate", f32_shape(&[8, 16]));
        let up_weight = graph.param("up", f32_shape(&[8, 16]));
        // Order matters here: the gate matmul is added first on purpose.
        let gate_proj = graph.mm(input, gate_weight);
        let up_proj = graph.mm(input, up_weight);
        let gated = graph.silu(gate_proj);
        let product = graph.mul(gated, up_proj);
        graph.set_outputs(vec![product]);

        let report = FusionReport::scan(&graph);
        assert!(report.missed_swiglu() >= 1);

        let flagged_gate_first = report
            .missed
            .iter()
            .any(|entry| entry.reason == MissReason::SwigluGateBeforeUp);
        assert!(flagged_gate_first);
    }
}