1use rlx_ir::op::{Activation, BinaryOp};
19use rlx_ir::{Graph, NodeId, Op, node_label};
20use std::fmt;
21
22#[derive(Debug, Clone, PartialEq, Eq)]
24pub enum MissReason {
25 MultiConsumer,
26 NonAddBiasConsumer,
27 BiasRankTooHigh { rank: usize },
28 UnsupportedEpilogueActivation(Activation),
29 SharedMatmulCount { count: usize },
30 SwigluGateBeforeUp,
31 SwigluNotSharedInput,
32 NotFused,
33}
34
35#[derive(Debug, Clone, PartialEq, Eq)]
37pub struct MissedFusion {
38 pub pattern: &'static str,
39 pub node: NodeId,
40 pub reason: MissReason,
41 pub context: Option<String>,
43 pub hint: Option<String>,
45}
46
47#[derive(Debug, Clone, Default, PartialEq, Eq)]
49pub struct FusionReport {
50 pub nodes_before: usize,
51 pub nodes_after: usize,
52 pub matmul_before: usize,
53 pub attention: usize,
54 pub rope: usize,
55 pub narrow: usize,
56 pub matmul_after: usize,
57 pub silu: usize,
58 pub mul: usize,
59 pub fused_matmul_bias_act: usize,
60 pub fused_swiglu: usize,
61 pub fused_residual_ln: usize,
62 pub fused_residual_rms_norm: usize,
63 pub fused_attention_block: usize,
64 pub fused_transformer_layer: usize,
65 pub elementwise_region: usize,
66 pub missed: Vec<MissedFusion>,
67}
68
69impl FusionReport {
70 pub fn analyze(before: &Graph, after: &Graph) -> Self {
72 let before_stats = count_ops(before);
73 let after_stats = count_ops(after);
74 let missed = scan_misses(after);
75 Self {
76 nodes_before: before.len(),
77 nodes_after: after.len(),
78 matmul_before: before_stats.matmul,
79 attention: after_stats.attention,
80 rope: after_stats.rope,
81 narrow: after_stats.narrow,
82 matmul_after: after_stats.matmul,
83 silu: after_stats.silu,
84 mul: after_stats.mul,
85 fused_matmul_bias_act: after_stats.fused_matmul_bias_act,
86 fused_swiglu: after_stats.fused_swiglu,
87 fused_residual_ln: after_stats.fused_residual_ln,
88 fused_residual_rms_norm: after_stats.fused_residual_rms_norm,
89 fused_attention_block: after_stats.fused_attention_block,
90 fused_transformer_layer: after_stats.fused_transformer_layer,
91 elementwise_region: after_stats.elementwise_region,
92 missed,
93 }
94 }
95
96 pub fn scan(graph: &Graph) -> Self {
99 let stats = count_ops(graph);
100 let missed = scan_misses(graph);
101 Self {
102 nodes_before: graph.len(),
103 nodes_after: graph.len(),
104 matmul_before: stats.matmul,
105 matmul_after: stats.matmul,
106 attention: stats.attention,
107 rope: stats.rope,
108 narrow: stats.narrow,
109 silu: stats.silu,
110 mul: stats.mul,
111 fused_matmul_bias_act: stats.fused_matmul_bias_act,
112 fused_swiglu: stats.fused_swiglu,
113 fused_residual_ln: stats.fused_residual_ln,
114 fused_residual_rms_norm: stats.fused_residual_rms_norm,
115 fused_attention_block: stats.fused_attention_block,
116 fused_transformer_layer: stats.fused_transformer_layer,
117 elementwise_region: stats.elementwise_region,
118 missed,
119 }
120 }
121
122 pub fn missed_matmul_bias_act(&self) -> usize {
123 self.missed
124 .iter()
125 .filter(|m| m.pattern == "matmul_bias_act")
126 .count()
127 }
128
129 pub fn missed_swiglu(&self) -> usize {
130 self.missed.iter().filter(|m| m.pattern == "swiglu").count()
131 }
132
133 pub fn missed_shared_matmul(&self) -> usize {
134 self.missed
135 .iter()
136 .filter(|m| m.pattern == "shared_input_matmul")
137 .count()
138 }
139
140 pub fn summary_line(&self) -> String {
142 format!(
143 "nodes={}→{} matmul={}→{} fused_mm_act={} fused_swiglu={} \
144 elementwise_region={} missed_mm_act={} missed_swiglu={} missed_shared_mm={}",
145 self.nodes_before,
146 self.nodes_after,
147 self.matmul_before,
148 self.matmul_after,
149 self.fused_matmul_bias_act,
150 self.fused_swiglu,
151 self.elementwise_region,
152 self.missed_matmul_bias_act(),
153 self.missed_swiglu(),
154 self.missed_shared_matmul(),
155 )
156 }
157}
158
159impl fmt::Display for FusionReport {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 writeln!(f, "fusion report:")?;
162 writeln!(f, " {}", self.summary_line())?;
163 if !self.missed.is_empty() {
164 writeln!(f, " missed patterns:")?;
165 for m in &self.missed {
166 write!(f, " {} @ {}", m.pattern, m.node)?;
167 if let Some(ref c) = m.context {
168 write!(f, " ({c})")?;
169 }
170 write!(f, " — {:?}", m.reason)?;
171 if let Some(ref h) = m.hint {
172 write!(f, " → {h}")?;
173 }
174 writeln!(f)?;
175 }
176 }
177 Ok(())
178 }
179}
180
/// Per-op-kind node counts for one graph, filled in by `count_ops`.
#[derive(Default)]
struct OpCounts {
    /// `Op::MatMul` nodes.
    matmul: usize,
    /// `Op::Attention` nodes.
    attention: usize,
    /// `Op::Rope` nodes.
    rope: usize,
    /// `Op::Narrow` nodes.
    narrow: usize,
    /// SiLU activation nodes.
    silu: usize,
    /// Elementwise `Mul` nodes.
    mul: usize,
    /// `Op::FusedMatMulBiasAct` nodes.
    fused_matmul_bias_act: usize,
    /// `Op::FusedSwiGLU` nodes.
    fused_swiglu: usize,
    /// `Op::FusedResidualLN` nodes.
    fused_residual_ln: usize,
    /// `Op::FusedResidualRmsNorm` nodes.
    fused_residual_rms_norm: usize,
    /// `Op::FusedAttentionBlock` nodes.
    fused_attention_block: usize,
    /// `Op::FusedTransformerLayer` nodes.
    fused_transformer_layer: usize,
    /// `Op::ElementwiseRegion` nodes.
    elementwise_region: usize,
}
197
198fn count_ops(graph: &Graph) -> OpCounts {
199 let mut s = OpCounts::default();
200 for node in graph.nodes() {
201 match &node.op {
202 Op::Attention { .. } => s.attention += 1,
203 Op::Rope { .. } => s.rope += 1,
204 Op::Narrow { .. } => s.narrow += 1,
205 Op::MatMul => s.matmul += 1,
206 Op::Activation(Activation::Silu) => s.silu += 1,
207 Op::Binary(BinaryOp::Mul) => s.mul += 1,
208 Op::FusedMatMulBiasAct { .. } => s.fused_matmul_bias_act += 1,
209 Op::FusedSwiGLU { .. } => s.fused_swiglu += 1,
210 Op::FusedResidualLN { .. } => s.fused_residual_ln += 1,
211 Op::FusedResidualRmsNorm { .. } => s.fused_residual_rms_norm += 1,
212 Op::FusedAttentionBlock { .. } => s.fused_attention_block += 1,
213 Op::FusedTransformerLayer { .. } => s.fused_transformer_layer += 1,
214 Op::ElementwiseRegion { .. } => s.elementwise_region += 1,
215 _ => {}
216 }
217 }
218 s
219}
220
221fn missed_entry(
222 graph: &Graph,
223 pattern: &'static str,
224 node: NodeId,
225 reason: MissReason,
226) -> MissedFusion {
227 MissedFusion {
228 pattern,
229 node,
230 context: Some(node_label(graph, node)),
231 hint: Some(fusion_hint(&reason)),
232 reason,
233 }
234}
235
236fn fusion_hint(reason: &MissReason) -> String {
237 match reason {
238 MissReason::MultiConsumer => {
239 "single-consumer chain required — clone input or use HirOp::LinearFused".into()
240 }
241 MissReason::NonAddBiasConsumer => "use linear+bias or HirModule::linear_fused".into(),
242 MissReason::BiasRankTooHigh { .. } => "bias must be rank-1".into(),
243 MissReason::UnsupportedEpilogueActivation(_) => {
244 "FuseMatMulBiasAct supports Gelu/Silu only".into()
245 }
246 MissReason::SharedMatmulCount { .. } => "use shared_linear_pair or HirOp::SwiGLU".into(),
247 MissReason::SwigluGateBeforeUp => "pass up_w before gate_w in swiglu_ffn".into(),
248 MissReason::SwigluNotSharedInput => "gate and up must share the same input".into(),
249 MissReason::NotFused => "check inspect_pipeline / RLX_FUSION_REPORT=1".into(),
250 }
251}
252
253fn scan_misses(graph: &Graph) -> Vec<MissedFusion> {
254 let mut missed = Vec::new();
255 missed.extend(scan_missed_matmul_bias_act(graph));
256 missed.extend(scan_missed_shared_matmul(graph));
257 missed.extend(scan_missed_swiglu(graph));
258 missed
259}
260
261fn scan_missed_matmul_bias_act(graph: &Graph) -> Vec<MissedFusion> {
262 let mut out = Vec::new();
263 for node in graph.nodes() {
264 if !matches!(node.op, Op::MatMul) {
265 continue;
266 }
267 let mm_id = node.id;
268 let users = graph.users(mm_id);
269 if users.len() != 1 {
270 if users.len() > 1 {
271 out.push(missed_entry(
272 graph,
273 "matmul_bias_act",
274 mm_id,
275 MissReason::MultiConsumer,
276 ));
277 }
278 continue;
279 }
280 let add_node = graph.node(users[0]);
281 let Op::Binary(BinaryOp::Add) = &add_node.op else {
282 out.push(missed_entry(
283 graph,
284 "matmul_bias_act",
285 mm_id,
286 MissReason::NonAddBiasConsumer,
287 ));
288 continue;
289 };
290 let bias_id = if add_node.inputs[0] == mm_id {
291 add_node.inputs[1]
292 } else {
293 add_node.inputs[0]
294 };
295 let bias_rank = graph.shape(bias_id).rank();
296 if bias_rank > 1 {
297 out.push(missed_entry(
298 graph,
299 "matmul_bias_act",
300 mm_id,
301 MissReason::BiasRankTooHigh { rank: bias_rank },
302 ));
303 continue;
304 }
305 let add_users = graph.users(add_node.id);
306 if add_users.len() == 1 {
307 if let Op::Activation(act) = &graph.node(add_users[0]).op
308 && !fusible_mm_bias_epilogue(*act)
309 {
310 out.push(missed_entry(
311 graph,
312 "matmul_bias_act",
313 mm_id,
314 MissReason::UnsupportedEpilogueActivation(*act),
315 ));
316 }
317 }
318 }
319 out
320}
321
322fn fusible_mm_bias_epilogue(act: Activation) -> bool {
323 matches!(act, Activation::Gelu | Activation::Silu)
324}
325
326fn scan_missed_shared_matmul(graph: &Graph) -> Vec<MissedFusion> {
327 let mut input_to_matmuls: std::collections::HashMap<NodeId, Vec<NodeId>> =
328 std::collections::HashMap::new();
329 for node in graph.nodes() {
330 if matches!(node.op, Op::MatMul) {
331 input_to_matmuls
332 .entry(node.inputs[0])
333 .or_default()
334 .push(node.id);
335 }
336 }
337 let mut out = Vec::new();
338 for matmuls in input_to_matmuls.values() {
339 if matmuls.len() == 2 {
340 let a = graph.node(matmuls[0]);
341 let b = graph.node(matmuls[1]);
342 let w1 = graph.shape(a.inputs[1]);
343 let w2 = graph.shape(b.inputs[1]);
344 if w1.rank() == 2 && w2.rank() == 2 && w1.dim(0) == w2.dim(0) {
345 out.push(missed_entry(
346 graph,
347 "shared_input_matmul",
348 matmuls[0],
349 MissReason::NotFused,
350 ));
351 }
352 } else if matmuls.len() > 2 {
353 out.push(missed_entry(
354 graph,
355 "shared_input_matmul",
356 matmuls[0],
357 MissReason::SharedMatmulCount {
358 count: matmuls.len(),
359 },
360 ));
361 }
362 }
363 out
364}
365
366fn scan_missed_swiglu(graph: &Graph) -> Vec<MissedFusion> {
367 let mut out = Vec::new();
368 for node in graph.nodes() {
369 if !matches!(node.op, Op::Binary(BinaryOp::Mul)) {
370 continue;
371 }
372 let lhs = graph.node(node.inputs[0]);
373 let rhs = graph.node(node.inputs[1]);
374 let (up_side, silu_side) = if matches!(rhs.op, Op::Activation(Activation::Silu)) {
375 (lhs, rhs)
376 } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
377 (rhs, lhs)
378 } else {
379 continue;
380 };
381 if !matches!(up_side.op, Op::MatMul) {
382 continue;
383 }
384 let gate_mm = graph.node(silu_side.inputs[0]);
385 if !matches!(gate_mm.op, Op::MatMul) {
386 continue;
387 }
388 if up_side.inputs[0] != gate_mm.inputs[0] {
389 out.push(missed_entry(
390 graph,
391 "swiglu",
392 node.id,
393 MissReason::SwigluNotSharedInput,
394 ));
395 continue;
396 }
397 if graph
399 .nodes()
400 .iter()
401 .position(|n| n.id == up_side.id)
402 .zip(graph.nodes().iter().position(|n| n.id == gate_mm.id))
403 .is_some_and(|(up_idx, gate_idx)| gate_idx < up_idx)
404 {
405 out.push(missed_entry(
406 graph,
407 "swiglu",
408 node.id,
409 MissReason::SwigluGateBeforeUp,
410 ));
411 }
412 }
413 out
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::DType;
    use rlx_ir::Shape;
    use rlx_ir::infer::GraphExt;

    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    #[test]
    fn report_counts_fused_ops() {
        use crate::fusion::{FuseSharedInputMatMul, FuseSwiGLU};
        use crate::pass::Pass;

        let mut g = Graph::new("report");
        let x = g.input("x", f32_shape(&[4, 768]));
        let up_w = g.param("up", f32_shape(&[768, 128]));
        let gate_w = g.param("gate", f32_shape(&[768, 128]));
        let down_w = g.param("down", f32_shape(&[128, 768]));
        let out = g.swiglu_ffn(x, up_w, gate_w, down_w);
        g.set_outputs(vec![out]);
        let before = g.clone();

        g = FuseSharedInputMatMul.run(g);
        g = FuseSwiGLU.run(g);

        let report = FusionReport::analyze(&before, &g);
        assert_eq!(report.fused_swiglu, 1);
        assert!(report.nodes_after < report.nodes_before);
    }

    #[test]
    fn report_flags_gate_before_up() {
        let mut g = Graph::new("gate_first");
        let x = g.input("x", f32_shape(&[4, 8]));
        let gate_w = g.param("gate", f32_shape(&[8, 16]));
        let up_w = g.param("up", f32_shape(&[8, 16]));
        let gate = g.mm(x, gate_w);
        let up = g.mm(x, up_w);
        let gate_silu = g.silu(gate);
        let out = g.mul(gate_silu, up);
        g.set_outputs(vec![out]);

        let report = FusionReport::scan(&g);
        assert!(report.missed_swiglu() >= 1);
        assert!(
            report
                .missed
                .iter()
                .any(|m| m.reason == MissReason::SwigluGateBeforeUp)
        );
    }

    /// Two matmuls sharing `x` with compatible rank-2 weights are reported
    /// once as an unfused shared-input pair and show up in `Display` output.
    #[test]
    fn report_flags_unfused_shared_input_pair() {
        let mut g = Graph::new("shared_pair");
        let x = g.input("x", f32_shape(&[4, 8]));
        let w1 = g.param("w1", f32_shape(&[8, 16]));
        let w2 = g.param("w2", f32_shape(&[8, 32]));
        let a = g.mm(x, w1);
        let b = g.mm(x, w2);
        g.set_outputs(vec![a, b]);

        let report = FusionReport::scan(&g);
        assert_eq!(report.matmul_before, 2);
        assert_eq!(report.missed_shared_matmul(), 1);
        assert!(
            report
                .missed
                .iter()
                .any(|m| m.pattern == "shared_input_matmul" && m.reason == MissReason::NotFused)
        );

        let text = report.to_string();
        assert!(text.starts_with("fusion report:\n"));
        assert!(text.contains("missed patterns:"));
        assert!(text.contains("shared_input_matmul @ "));
    }

    /// Three matmuls sharing one input are reported once, with their count.
    #[test]
    fn report_counts_wide_shared_input_group() {
        let mut g = Graph::new("shared_triple");
        let x = g.input("x", f32_shape(&[4, 8]));
        let w1 = g.param("w1", f32_shape(&[8, 16]));
        let w2 = g.param("w2", f32_shape(&[8, 16]));
        let w3 = g.param("w3", f32_shape(&[8, 16]));
        let a = g.mm(x, w1);
        let b = g.mm(x, w2);
        let c = g.mm(x, w3);
        g.set_outputs(vec![a, b, c]);

        let report = FusionReport::scan(&g);
        assert_eq!(report.missed_shared_matmul(), 1);
        assert!(
            report
                .missed
                .iter()
                .any(|m| m.reason == MissReason::SharedMatmulCount { count: 3 })
        );
    }
}