1use rlx_ir::op::{Activation, BinaryOp, RegionPrologue};
19use rlx_ir::{Graph, NodeId, Op, node_label};
20use std::fmt;
21
22#[derive(Debug, Clone, PartialEq, Eq)]
24pub enum MissReason {
25 MultiConsumer,
26 NonAddBiasConsumer,
27 BiasRankTooHigh { rank: usize },
28 UnsupportedEpilogueActivation(Activation),
29 SharedMatmulCount { count: usize },
30 SwigluGateBeforeUp,
31 SwigluNotSharedInput,
32 NotFused,
33}
34
35#[derive(Debug, Clone, PartialEq, Eq)]
37pub struct MissedFusion {
38 pub pattern: &'static str,
39 pub node: NodeId,
40 pub reason: MissReason,
41 pub context: Option<String>,
43 pub hint: Option<String>,
45}
46
47#[derive(Debug, Clone, Default, PartialEq, Eq)]
49pub struct FusionReport {
50 pub nodes_before: usize,
51 pub nodes_after: usize,
52 pub matmul_before: usize,
53 pub attention: usize,
54 pub rope: usize,
55 pub narrow: usize,
56 pub matmul_after: usize,
57 pub silu: usize,
58 pub mul: usize,
59 pub fused_matmul_bias_act: usize,
60 pub fused_swiglu: usize,
61 pub fused_residual_ln: usize,
62 pub fused_residual_rms_norm: usize,
63 pub fused_attention_block: usize,
64 pub fused_transformer_layer: usize,
65 pub elementwise_region: usize,
66 pub transform_region: usize,
67 pub batch_elementwise_region: usize,
68 pub fk_prologue_region: usize,
69 pub missed: Vec<MissedFusion>,
70}
71
72impl FusionReport {
73 pub fn analyze(before: &Graph, after: &Graph) -> Self {
75 let before_stats = count_ops(before);
76 let after_stats = count_ops(after);
77 let missed = scan_misses(after);
78 Self {
79 nodes_before: before.len(),
80 nodes_after: after.len(),
81 matmul_before: before_stats.matmul,
82 attention: after_stats.attention,
83 rope: after_stats.rope,
84 narrow: after_stats.narrow,
85 matmul_after: after_stats.matmul,
86 silu: after_stats.silu,
87 mul: after_stats.mul,
88 fused_matmul_bias_act: after_stats.fused_matmul_bias_act,
89 fused_swiglu: after_stats.fused_swiglu,
90 fused_residual_ln: after_stats.fused_residual_ln,
91 fused_residual_rms_norm: after_stats.fused_residual_rms_norm,
92 fused_attention_block: after_stats.fused_attention_block,
93 fused_transformer_layer: after_stats.fused_transformer_layer,
94 elementwise_region: after_stats.elementwise_region,
95 transform_region: after_stats.transform_region,
96 batch_elementwise_region: after_stats.batch_elementwise_region,
97 fk_prologue_region: after_stats.fk_prologue_region,
98 missed,
99 }
100 }
101
102 pub fn scan(graph: &Graph) -> Self {
105 let stats = count_ops(graph);
106 let missed = scan_misses(graph);
107 Self {
108 nodes_before: graph.len(),
109 nodes_after: graph.len(),
110 matmul_before: stats.matmul,
111 matmul_after: stats.matmul,
112 attention: stats.attention,
113 rope: stats.rope,
114 narrow: stats.narrow,
115 silu: stats.silu,
116 mul: stats.mul,
117 fused_matmul_bias_act: stats.fused_matmul_bias_act,
118 fused_swiglu: stats.fused_swiglu,
119 fused_residual_ln: stats.fused_residual_ln,
120 fused_residual_rms_norm: stats.fused_residual_rms_norm,
121 fused_attention_block: stats.fused_attention_block,
122 fused_transformer_layer: stats.fused_transformer_layer,
123 elementwise_region: stats.elementwise_region,
124 transform_region: stats.transform_region,
125 batch_elementwise_region: stats.batch_elementwise_region,
126 fk_prologue_region: stats.fk_prologue_region,
127 missed,
128 }
129 }
130
131 pub fn missed_matmul_bias_act(&self) -> usize {
132 self.missed
133 .iter()
134 .filter(|m| m.pattern == "matmul_bias_act")
135 .count()
136 }
137
138 pub fn missed_swiglu(&self) -> usize {
139 self.missed.iter().filter(|m| m.pattern == "swiglu").count()
140 }
141
142 pub fn missed_shared_matmul(&self) -> usize {
143 self.missed
144 .iter()
145 .filter(|m| m.pattern == "shared_input_matmul")
146 .count()
147 }
148
149 pub fn summary_line(&self) -> String {
151 format!(
152 "nodes={}→{} matmul={}→{} fused_mm_act={} fused_swiglu={} \
153 elementwise_region={} transform_region={} batch_region={} fk_prologue={} \
154 missed_mm_act={} missed_swiglu={} missed_shared_mm={}",
155 self.nodes_before,
156 self.nodes_after,
157 self.matmul_before,
158 self.matmul_after,
159 self.fused_matmul_bias_act,
160 self.fused_swiglu,
161 self.elementwise_region,
162 self.transform_region,
163 self.batch_elementwise_region,
164 self.fk_prologue_region,
165 self.missed_matmul_bias_act(),
166 self.missed_swiglu(),
167 self.missed_shared_matmul(),
168 )
169 }
170}
171
172impl fmt::Display for FusionReport {
173 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174 writeln!(f, "fusion report:")?;
175 writeln!(f, " {}", self.summary_line())?;
176 if !self.missed.is_empty() {
177 writeln!(f, " missed patterns:")?;
178 for m in &self.missed {
179 write!(f, " {} @ {}", m.pattern, m.node)?;
180 if let Some(ref c) = m.context {
181 write!(f, " ({c})")?;
182 }
183 write!(f, " — {:?}", m.reason)?;
184 if let Some(ref h) = m.hint {
185 write!(f, " → {h}")?;
186 }
187 writeln!(f)?;
188 }
189 }
190 Ok(())
191 }
192}
193
/// Per-kind node tallies gathered by `count_ops`; all start at zero.
#[derive(Default)]
struct OpCounts {
    // Plain (unfused) ops.
    matmul: usize,
    attention: usize,
    rope: usize,
    narrow: usize,
    silu: usize,
    mul: usize,
    // Fused ops.
    fused_matmul_bias_act: usize,
    fused_swiglu: usize,
    fused_residual_ln: usize,
    fused_residual_rms_norm: usize,
    fused_attention_block: usize,
    fused_transformer_layer: usize,
    // Region ops. `fk_prologue_region` is a subset of `elementwise_region`.
    elementwise_region: usize,
    transform_region: usize,
    batch_elementwise_region: usize,
    fk_prologue_region: usize,
}
213
214fn count_ops(graph: &Graph) -> OpCounts {
215 let mut s = OpCounts::default();
216 for node in graph.nodes() {
217 match &node.op {
218 Op::Attention { .. } => s.attention += 1,
219 Op::Rope { .. } => s.rope += 1,
220 Op::Narrow { .. } => s.narrow += 1,
221 Op::MatMul => s.matmul += 1,
222 Op::Activation(Activation::Silu) => s.silu += 1,
223 Op::Binary(BinaryOp::Mul) => s.mul += 1,
224 Op::FusedMatMulBiasAct { .. } => s.fused_matmul_bias_act += 1,
225 Op::FusedSwiGLU { .. } => s.fused_swiglu += 1,
226 Op::FusedResidualLN { .. } => s.fused_residual_ln += 1,
227 Op::FusedResidualRmsNorm { .. } => s.fused_residual_rms_norm += 1,
228 Op::FusedAttentionBlock { .. } => s.fused_attention_block += 1,
229 Op::FusedTransformerLayer { .. } => s.fused_transformer_layer += 1,
230 Op::ElementwiseRegion { prologue, .. } => {
231 s.elementwise_region += 1;
232 if *prologue != RegionPrologue::None {
233 s.fk_prologue_region += 1;
234 }
235 }
236 Op::TransformRegion { .. } => s.transform_region += 1,
237 Op::BatchElementwiseRegion { .. } => s.batch_elementwise_region += 1,
238 _ => {}
239 }
240 }
241 s
242}
243
244fn missed_entry(
245 graph: &Graph,
246 pattern: &'static str,
247 node: NodeId,
248 reason: MissReason,
249) -> MissedFusion {
250 MissedFusion {
251 pattern,
252 node,
253 context: Some(node_label(graph, node)),
254 hint: Some(fusion_hint(&reason)),
255 reason,
256 }
257}
258
259fn fusion_hint(reason: &MissReason) -> String {
260 match reason {
261 MissReason::MultiConsumer => {
262 "single-consumer chain required — clone input or use HirOp::LinearFused".into()
263 }
264 MissReason::NonAddBiasConsumer => "use linear+bias or HirModule::linear_fused".into(),
265 MissReason::BiasRankTooHigh { .. } => "bias must be rank-1".into(),
266 MissReason::UnsupportedEpilogueActivation(_) => {
267 "FuseMatMulBiasAct supports Gelu/Silu only".into()
268 }
269 MissReason::SharedMatmulCount { .. } => "use shared_linear_pair or HirOp::SwiGLU".into(),
270 MissReason::SwigluGateBeforeUp => "pass up_w before gate_w in swiglu_ffn".into(),
271 MissReason::SwigluNotSharedInput => "gate and up must share the same input".into(),
272 MissReason::NotFused => "check inspect_pipeline / RLX_FUSION_REPORT=1".into(),
273 }
274}
275
276fn scan_misses(graph: &Graph) -> Vec<MissedFusion> {
277 let mut missed = Vec::new();
278 missed.extend(scan_missed_matmul_bias_act(graph));
279 missed.extend(scan_missed_shared_matmul(graph));
280 missed.extend(scan_missed_swiglu(graph));
281 missed
282}
283
284fn scan_missed_matmul_bias_act(graph: &Graph) -> Vec<MissedFusion> {
285 let mut out = Vec::new();
286 for node in graph.nodes() {
287 if !matches!(node.op, Op::MatMul) {
288 continue;
289 }
290 let mm_id = node.id;
291 let users = graph.users(mm_id);
292 if users.len() != 1 {
293 if users.len() > 1 {
294 out.push(missed_entry(
295 graph,
296 "matmul_bias_act",
297 mm_id,
298 MissReason::MultiConsumer,
299 ));
300 }
301 continue;
302 }
303 let add_node = graph.node(users[0]);
304 let Op::Binary(BinaryOp::Add) = &add_node.op else {
305 out.push(missed_entry(
306 graph,
307 "matmul_bias_act",
308 mm_id,
309 MissReason::NonAddBiasConsumer,
310 ));
311 continue;
312 };
313 let bias_id = if add_node.inputs[0] == mm_id {
314 add_node.inputs[1]
315 } else {
316 add_node.inputs[0]
317 };
318 let bias_rank = graph.shape(bias_id).rank();
319 if bias_rank > 1 {
320 out.push(missed_entry(
321 graph,
322 "matmul_bias_act",
323 mm_id,
324 MissReason::BiasRankTooHigh { rank: bias_rank },
325 ));
326 continue;
327 }
328 let add_users = graph.users(add_node.id);
329 if add_users.len() == 1 {
330 if let Op::Activation(act) = &graph.node(add_users[0]).op
331 && !fusible_mm_bias_epilogue(*act)
332 {
333 out.push(missed_entry(
334 graph,
335 "matmul_bias_act",
336 mm_id,
337 MissReason::UnsupportedEpilogueActivation(*act),
338 ));
339 }
340 }
341 }
342 out
343}
344
345fn fusible_mm_bias_epilogue(act: Activation) -> bool {
346 matches!(act, Activation::Gelu | Activation::Silu)
347}
348
349fn scan_missed_shared_matmul(graph: &Graph) -> Vec<MissedFusion> {
350 let mut input_to_matmuls: std::collections::HashMap<NodeId, Vec<NodeId>> =
351 std::collections::HashMap::new();
352 for node in graph.nodes() {
353 if matches!(node.op, Op::MatMul) {
354 input_to_matmuls
355 .entry(node.inputs[0])
356 .or_default()
357 .push(node.id);
358 }
359 }
360 let mut out = Vec::new();
361 for matmuls in input_to_matmuls.values() {
362 if matmuls.len() == 2 {
363 let a = graph.node(matmuls[0]);
364 let b = graph.node(matmuls[1]);
365 let w1 = graph.shape(a.inputs[1]);
366 let w2 = graph.shape(b.inputs[1]);
367 if w1.rank() == 2 && w2.rank() == 2 && w1.dim(0) == w2.dim(0) {
368 out.push(missed_entry(
369 graph,
370 "shared_input_matmul",
371 matmuls[0],
372 MissReason::NotFused,
373 ));
374 }
375 } else if matmuls.len() > 2 {
376 out.push(missed_entry(
377 graph,
378 "shared_input_matmul",
379 matmuls[0],
380 MissReason::SharedMatmulCount {
381 count: matmuls.len(),
382 },
383 ));
384 }
385 }
386 out
387}
388
389fn scan_missed_swiglu(graph: &Graph) -> Vec<MissedFusion> {
390 let mut out = Vec::new();
391 for node in graph.nodes() {
392 if !matches!(node.op, Op::Binary(BinaryOp::Mul)) {
393 continue;
394 }
395 let lhs = graph.node(node.inputs[0]);
396 let rhs = graph.node(node.inputs[1]);
397 let (up_side, silu_side) = if matches!(rhs.op, Op::Activation(Activation::Silu)) {
398 (lhs, rhs)
399 } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
400 (rhs, lhs)
401 } else {
402 continue;
403 };
404 if !matches!(up_side.op, Op::MatMul) {
405 continue;
406 }
407 let gate_mm = graph.node(silu_side.inputs[0]);
408 if !matches!(gate_mm.op, Op::MatMul) {
409 continue;
410 }
411 if up_side.inputs[0] != gate_mm.inputs[0] {
412 out.push(missed_entry(
413 graph,
414 "swiglu",
415 node.id,
416 MissReason::SwigluNotSharedInput,
417 ));
418 continue;
419 }
420 if graph
422 .nodes()
423 .iter()
424 .position(|n| n.id == up_side.id)
425 .zip(graph.nodes().iter().position(|n| n.id == gate_mm.id))
426 .is_some_and(|(up_idx, gate_idx)| gate_idx < up_idx)
427 {
428 out.push(missed_entry(
429 graph,
430 "swiglu",
431 node.id,
432 MissReason::SwigluGateBeforeUp,
433 ));
434 }
435 }
436 out
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::infer::GraphExt;
    use rlx_ir::{DType, Shape};

    /// Shorthand for an `F32` shape with the given dimensions.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    /// A SwiGLU FFN run through the shared-input and SwiGLU passes should
    /// produce exactly one fused SwiGLU node and a smaller graph.
    #[test]
    fn report_counts_fused_ops() {
        use crate::fusion::{FuseSharedInputMatMul, FuseSwiGLU};
        use crate::pass::Pass;

        let mut graph = Graph::new("report");
        let x = graph.input("x", f32_shape(&[4, 768]));
        let up_w = graph.param("up", f32_shape(&[768, 128]));
        let gate_w = graph.param("gate", f32_shape(&[768, 128]));
        let down_w = graph.param("down", f32_shape(&[128, 768]));
        let ffn_out = graph.swiglu_ffn(x, up_w, gate_w, down_w);
        graph.set_outputs(vec![ffn_out]);
        let unfused = graph.clone();

        let graph = FuseSharedInputMatMul.run(graph);
        let graph = FuseSwiGLU.run(graph);

        let report = FusionReport::analyze(&unfused, &graph);
        assert_eq!(report.fused_swiglu, 1);
        assert!(report.nodes_after < report.nodes_before);
    }

    /// Emitting the gate matmul before the up matmul must be flagged.
    #[test]
    fn report_flags_gate_before_up() {
        let mut graph = Graph::new("gate_first");
        let x = graph.input("x", f32_shape(&[4, 8]));
        let gate_w = graph.param("gate", f32_shape(&[8, 16]));
        let up_w = graph.param("up", f32_shape(&[8, 16]));
        // Node creation order matters: gate first, then up.
        let gate = graph.mm(x, gate_w);
        let up = graph.mm(x, up_w);
        let activated = graph.silu(gate);
        let product = graph.mul(activated, up);
        graph.set_outputs(vec![product]);

        let report = FusionReport::scan(&graph);
        assert!(report.missed_swiglu() >= 1);
        let flagged = report
            .missed
            .iter()
            .any(|m| m.reason == MissReason::SwigluGateBeforeUp);
        assert!(flagged);
    }
}