1use crate::pass::Pass;
23use rlx_ir::op::*;
24use rlx_ir::*;
25use std::collections::HashMap;
26
27use crate::graph_rewrite::Rewriter;
30
/// Graph pass that folds a `MatMul` followed by a bias `Add` (and, when
/// present, a trailing GELU/SiLU activation) into a single
/// `Op::FusedMatMulBiasAct` node.
pub struct FuseMatMulBiasAct;
46
47fn fusible_mm_bias_epilogue_activation(act: Activation) -> bool {
49 matches!(act, Activation::Gelu | Activation::Silu)
50}
51
52impl Pass for FuseMatMulBiasAct {
53 fn name(&self) -> &str {
54 "fuse_matmul_bias_act"
55 }
56
57 fn run(&self, graph: Graph) -> Graph {
58 let mut rw = Rewriter::new(&graph.name);
59 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
61
62 for node in graph.nodes() {
64 if fused_away.contains_key(&node.id) {
65 continue;
66 }
67
68 if matches!(node.op, Op::MatMul) {
71 let mm_id = node.id;
72 let mm_users: Vec<_> = graph.users(mm_id);
73
74 if mm_users.len() == 1 {
76 let add_node = graph.node(mm_users[0]);
77 if let Op::Binary(BinaryOp::Add) = &add_node.op {
78 let (bias_id, _mm_input) = if add_node.inputs[0] == mm_id {
80 (add_node.inputs[1], add_node.inputs[0])
81 } else {
82 (add_node.inputs[0], add_node.inputs[1])
83 };
84
85 let bias_shape = graph.shape(bias_id);
87 if bias_shape.rank() <= 1 {
88 let add_id = add_node.id;
89 let add_users = graph.users(add_id);
90
91 let mut activation = None;
93 let mut act_id = None;
94 if add_users.len() == 1 {
95 let act_node = graph.node(add_users[0]);
96 if let Op::Activation(a) = &act_node.op
97 && fusible_mm_bias_epilogue_activation(*a)
98 {
99 activation = Some(*a);
100 act_id = Some(act_node.id);
101 }
102 }
103
104 let out_shape = if let Some(aid) = act_id {
108 graph.shape(aid).clone()
109 } else {
110 add_node.shape.clone()
111 };
112
113 rw.ensure_mapped(&graph, &[node.inputs[0], node.inputs[1], bias_id]);
114 let fused_id = rw.add_fused(
115 Op::FusedMatMulBiasAct { activation },
116 &[node.inputs[0], node.inputs[1], bias_id],
117 out_shape,
118 );
119
120 rw.replace(mm_id, fused_id);
122 rw.replace(add_id, fused_id);
123 fused_away.insert(add_id, ());
124 if let Some(aid) = act_id {
125 rw.replace(aid, fused_id);
126 fused_away.insert(aid, ());
127 }
128 continue;
129 }
130 }
131 }
132 }
133
134 rw.copy_node(node);
136 }
137
138 rw.finish(&graph.outputs)
139 }
140}
141
/// Graph pass that merges a residual `Add` and the `LayerNorm` consuming it
/// into a single `Op::FusedResidualLN` node.
pub struct FuseResidualLN;
149
150impl Pass for FuseResidualLN {
151 fn name(&self) -> &str {
152 "fuse_residual_ln"
153 }
154
155 fn run(&self, graph: Graph) -> Graph {
156 let mut is_output: HashMap<NodeId, ()> = HashMap::new();
164 for &oid in &graph.outputs {
165 is_output.insert(oid, ());
166 }
167 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
169 for node in graph.nodes() {
170 if let Op::LayerNorm { .. } = &node.op {
171 let ln_input_id = node.inputs[0];
172 let ln_input = graph.node(ln_input_id);
173 if matches!(ln_input.op, Op::Binary(BinaryOp::Add))
174 && graph.use_count(ln_input_id) == 1
175 && !is_output.contains_key(&ln_input_id)
176 {
177 fused_away.insert(ln_input_id, ());
178 }
179 }
180 }
181
182 let mut rw = Rewriter::new(&graph.name);
183
184 for node in graph.nodes() {
185 if fused_away.contains_key(&node.id) {
186 continue;
187 }
188
189 if let Op::LayerNorm { eps, .. } = &node.op {
190 let ln_input_id = node.inputs[0];
191 let ln_input = graph.node(ln_input_id);
192
193 if matches!(ln_input.op, Op::Binary(BinaryOp::Add))
194 && fused_away.contains_key(&ln_input_id)
195 {
196 let (x_id, residual_id) = (ln_input.inputs[0], ln_input.inputs[1]);
197 let gamma_id = node.inputs[1];
198 let beta_id = node.inputs[2];
199
200 let fused_id = rw.add_fused(
201 Op::FusedResidualLN {
202 has_bias: false,
203 eps: *eps,
204 },
205 &[x_id, residual_id, gamma_id, beta_id],
206 node.shape.clone(),
207 );
208
209 rw.replace(ln_input_id, fused_id);
210 rw.replace(node.id, fused_id);
211 continue;
212 }
213 }
214
215 rw.copy_node(node);
216 }
217
218 rw.finish(&graph.outputs)
219 }
220}
221
/// Graph pass that merges a residual `Add` and the `RmsNorm` consuming it
/// into a single `Op::FusedResidualRmsNorm` node.
pub struct FuseResidualRmsNorm;
226
227impl Pass for FuseResidualRmsNorm {
228 fn name(&self) -> &str {
229 "fuse_residual_rms_norm"
230 }
231
232 fn run(&self, graph: Graph) -> Graph {
233 let mut is_output: HashMap<NodeId, ()> = HashMap::new();
234 for &oid in &graph.outputs {
235 is_output.insert(oid, ());
236 }
237 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
238 for node in graph.nodes() {
239 if let Op::RmsNorm { .. } = &node.op {
240 let rn_input_id = node.inputs[0];
241 let rn_input = graph.node(rn_input_id);
242 if matches!(rn_input.op, Op::Binary(BinaryOp::Add))
243 && graph.use_count(rn_input_id) == 1
244 && !is_output.contains_key(&rn_input_id)
245 {
246 fused_away.insert(rn_input_id, ());
247 }
248 }
249 }
250
251 let mut rw = Rewriter::new(&graph.name);
252
253 for node in graph.nodes() {
254 if fused_away.contains_key(&node.id) {
255 continue;
256 }
257
258 if let Op::RmsNorm { eps, .. } = &node.op {
259 let rn_input_id = node.inputs[0];
260 let rn_input = graph.node(rn_input_id);
261
262 if matches!(rn_input.op, Op::Binary(BinaryOp::Add))
263 && fused_away.contains_key(&rn_input_id)
264 {
265 let (x_id, residual_id) = (rn_input.inputs[0], rn_input.inputs[1]);
266 let gamma_id = node.inputs[1];
267 let beta_id = node.inputs[2];
268
269 let fused_id = rw.add_fused(
270 Op::FusedResidualRmsNorm {
271 has_bias: false,
272 eps: *eps,
273 },
274 &[x_id, residual_id, gamma_id, beta_id],
275 node.shape.clone(),
276 );
277
278 rw.replace(rn_input_id, fused_id);
279 rw.replace(node.id, fused_id);
280 continue;
281 }
282 }
283
284 rw.copy_node(node);
285 }
286
287 rw.finish(&graph.outputs)
288 }
289}
290
/// Graph pass that absorbs a `Reshape` following an `RmsNorm` by emitting the
/// norm directly with the reshaped output shape.
pub struct FuseRmsNormReshape;
299
/// Thin wrapper that forwards to `rlx_ir::shape::leading_flatten_shape` and
/// returns its result unchanged.
///
/// `in_shape` is the shape before the reshape and `new_shape` the requested
/// reshape target.
// NOTE(review): presumably yields `Some(shape)` only when the reshape merely
// merges leading dimensions - confirm against the `rlx_ir` implementation.
fn leading_flatten_shape(in_shape: &Shape, new_shape: &[i64]) -> Option<Shape> {
    rlx_ir::shape::leading_flatten_shape(in_shape, new_shape)
}
303
304fn sole_consumer(graph: &Graph, id: NodeId) -> Option<NodeId> {
305 graph
306 .nodes()
307 .iter()
308 .find(|n| n.inputs.contains(&id))
309 .map(|n| n.id)
310}
311
312impl Pass for FuseRmsNormReshape {
313 fn name(&self) -> &str {
314 "fuse_rms_norm_reshape"
315 }
316
317 fn run(&self, graph: Graph) -> Graph {
318 let mut is_output: HashMap<NodeId, ()> = HashMap::new();
319 for &oid in &graph.outputs {
320 is_output.insert(oid, ());
321 }
322
323 let mut flat_shape: HashMap<NodeId, Shape> = HashMap::new();
324 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
325 for node in graph.nodes() {
326 if let Op::RmsNorm { .. } = &node.op {
327 if graph.use_count(node.id) != 1 || is_output.contains_key(&node.id) {
328 continue;
329 }
330 let Some(reshape_id) = sole_consumer(&graph, node.id) else {
331 continue;
332 };
333 if is_output.contains_key(&reshape_id) {
334 continue;
335 }
336 let reshape = graph.node(reshape_id);
337 if let Op::Reshape { new_shape } = &reshape.op {
338 if let Some(flat) = leading_flatten_shape(&node.shape, new_shape) {
339 flat_shape.insert(node.id, flat);
340 fused_away.insert(reshape_id, ());
341 }
342 }
343 }
344 }
345
346 let mut rw = Rewriter::new(&graph.name);
347
348 for node in graph.nodes() {
349 if fused_away.contains_key(&node.id) {
350 continue;
351 }
352
353 if let Op::RmsNorm { axis, eps, .. } = &node.op {
354 if let Some(flat) = flat_shape.get(&node.id) {
355 let Some(reshape_id) = sole_consumer(&graph, node.id) else {
356 rw.copy_node(node);
357 continue;
358 };
359 let fused_id = rw.add_fused(
360 Op::RmsNorm {
361 axis: *axis,
362 eps: *eps,
363 },
364 &node.inputs,
365 flat.clone(),
366 );
367 rw.replace(node.id, fused_id);
368 rw.replace(reshape_id, fused_id);
369 continue;
370 }
371 }
372
373 rw.copy_node(node);
374 }
375
376 rw.finish(&graph.outputs)
377 }
378}
379
/// Graph pass that rewrites `Mul(MatMul(x, Wu), Silu(MatMul(x, Wg)))` into one
/// `MatMul` against the concatenated weights followed by `Op::FusedSwiGLU`.
pub struct FuseSwiGLUDualMatmul;
392
393impl FuseSwiGLUDualMatmul {
394 fn match_dual_swiglu(
395 graph: &Graph,
396 mul_node: &Node,
397 ) -> Option<(NodeId, NodeId, NodeId, NodeId, NodeId)> {
398 if !matches!(mul_node.op, Op::Binary(BinaryOp::Mul)) {
399 return None;
400 }
401 let lhs = graph.node(mul_node.inputs[0]);
402 let rhs = graph.node(mul_node.inputs[1]);
403 let (up_mm, silu_id, silu_node) = if matches!(rhs.op, Op::Activation(Activation::Silu)) {
404 (lhs, mul_node.inputs[1], rhs)
405 } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
406 (rhs, mul_node.inputs[0], lhs)
407 } else {
408 return None;
409 };
410 if !matches!(up_mm.op, Op::MatMul) {
411 return None;
412 }
413 let gate_mm = graph.node(silu_node.inputs[0]);
414 if !matches!(gate_mm.op, Op::MatMul) {
415 return None;
416 }
417 if up_mm.inputs[0] != gate_mm.inputs[0] {
418 return None;
419 }
420 if graph.use_count(silu_id) != 1 {
421 return None;
422 }
423 Some((mul_node.id, gate_mm.id, up_mm.id, up_mm.inputs[0], silu_id))
424 }
425}
426
427impl Pass for FuseSwiGLUDualMatmul {
428 fn name(&self) -> &str {
429 "fuse_swiglu_dual_matmul"
430 }
431
432 fn run(&self, graph: Graph) -> Graph {
433 let mut matches: Vec<(NodeId, NodeId, NodeId, NodeId, NodeId)> = Vec::new();
434 let mut consumed: HashMap<NodeId, ()> = HashMap::new();
435
436 for node in graph.nodes() {
437 if let Some((mul_id, gate_mm, up_mm, _, silu_id)) =
438 Self::match_dual_swiglu(&graph, node)
439 {
440 matches.push((mul_id, gate_mm, up_mm, graph.node(up_mm).inputs[0], silu_id));
441 consumed.insert(gate_mm, ());
442 consumed.insert(up_mm, ());
443 consumed.insert(silu_id, ());
444 }
445 }
446
447 if matches.is_empty() {
448 return graph;
449 }
450
451 let match_by_mul: HashMap<NodeId, (NodeId, NodeId, NodeId)> = matches
452 .into_iter()
453 .map(|(mul, gate, up, input, _silu)| (mul, (gate, up, input)))
454 .collect();
455
456 let mut rw = Rewriter::new(&graph.name);
457 for node in graph.nodes() {
458 if consumed.contains_key(&node.id) {
459 continue;
460 }
461 if let Some(&(gate_mm, up_mm, input_id)) = match_by_mul.get(&node.id) {
462 let gate = graph.node(gate_mm);
463 let up = graph.node(up_mm);
464 let wg = gate.inputs[1];
465 let wu = up.inputs[1];
466 rw.ensure_mapped(&graph, &[input_id, wg, wu]);
467
468 let wu_shape = graph.shape(wu);
469 let wg_shape = graph.shape(wg);
470 let k = wu_shape.dim(0).unwrap_static();
471 let n_up = wu_shape.dim(1).unwrap_static();
472 let n_gate = wg_shape.dim(1).unwrap_static();
473 debug_assert_eq!(wu_shape.dim(0), wg_shape.dim(0));
474
475 let concat_shape = Shape::new(&[k, n_up + n_gate], wu_shape.dtype());
477 let concat_w = rw.add_fused(Op::Concat { axis: 1 }, &[wu, wg], concat_shape);
478
479 let out_rank = up.shape.rank();
480 let mut mm_dims: Vec<Dim> = (0..out_rank).map(|i| up.shape.dim(i)).collect();
481 mm_dims[out_rank - 1] = Dim::Static(n_up + n_gate);
482 let cat_shape = Shape::from_dims(&mm_dims, up.shape.dtype());
483 let cat_id =
484 rw.new_graph
485 .add_node(Op::MatMul, vec![rw.map(input_id), concat_w], cat_shape);
486
487 let fused_id = rw.new_graph.add_node(
488 Op::FusedSwiGLU {
489 cast_to: None,
490 gate_first: false,
491 },
492 vec![cat_id],
493 node.shape.clone(),
494 );
495 rw.replace(node.id, fused_id);
496 continue;
497 }
498 rw.copy_node(node);
499 }
500 rw.finish(&graph.outputs)
501 }
502}
503
/// Graph pass that merges several `MatMul` nodes sharing the same left-hand
/// input into one `MatMul` over concatenated weights, followed by one
/// `Narrow` per original matmul.
pub struct FuseSharedInputMatMul;
520
521impl Pass for FuseSharedInputMatMul {
522 fn name(&self) -> &str {
523 "fuse_shared_input_matmul"
524 }
525
526 fn run(&self, graph: Graph) -> Graph {
527 struct FuseGroup {
528 input_id: NodeId,
529 matmul_ids: Vec<NodeId>,
530 }
531
532 let mut input_to_matmuls: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
533 for node in graph.nodes() {
534 if matches!(node.op, Op::MatMul) {
535 input_to_matmuls
536 .entry(node.inputs[0])
537 .or_default()
538 .push(node.id);
539 }
540 }
541
542 let mut groups: Vec<FuseGroup> = Vec::new();
543 for (input_id, matmul_ids) in input_to_matmuls {
544 if matmul_ids.len() < 2 {
545 continue;
546 }
547 let first = graph.node(matmul_ids[0]);
548 let w0 = graph.shape(first.inputs[1]);
549 if w0.rank() != 2 {
550 continue;
551 }
552 let compatible = matmul_ids.iter().all(|&id| {
553 let m = graph.node(id);
554 matches!(m.op, Op::MatMul)
555 && graph.shape(m.inputs[1]).rank() == 2
556 && graph.shape(m.inputs[1]).dim(0) == w0.dim(0)
557 });
558 if compatible {
559 groups.push(FuseGroup {
560 input_id,
561 matmul_ids,
562 });
563 }
564 }
565
566 if groups.is_empty() {
567 return graph;
568 }
569
570 let group_by_first: HashMap<NodeId, &FuseGroup> =
571 groups.iter().map(|g| (g.matmul_ids[0], g)).collect();
572
573 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
574 for g in &groups {
575 for &id in &g.matmul_ids[1..] {
576 fused_away.insert(id, ());
577 }
578 }
579
580 let mut rw = Rewriter::new(&graph.name);
581 for node in graph.nodes() {
582 if fused_away.contains_key(&node.id) {
583 continue;
584 }
585
586 if let Some(group) = group_by_first.get(&node.id) {
587 let matmuls: Vec<_> = group.matmul_ids.iter().map(|&id| graph.node(id)).collect();
588 let weight_ids: Vec<NodeId> = matmuls.iter().map(|m| m.inputs[1]).collect();
589 rw.ensure_mapped(&graph, std::slice::from_ref(&group.input_id));
590 rw.ensure_mapped(&graph, &weight_ids);
591
592 let w0_shape = graph.shape(weight_ids[0]);
593 let k = w0_shape.dim(0).unwrap_static();
594 let ns: Vec<usize> = weight_ids
595 .iter()
596 .map(|&w| graph.shape(w).dim(1).unwrap_static())
597 .collect();
598 let combined_n: usize = ns.iter().sum();
599
600 let concat_shape = Shape::new(&[k, combined_n], w0_shape.dtype());
601 let concat_id = rw.add_fused(Op::Concat { axis: 1 }, &weight_ids, concat_shape);
602
603 let out_rank = matmuls[0].shape.rank();
604 let mut mm_dims: Vec<Dim> =
605 (0..out_rank).map(|i| matmuls[0].shape.dim(i)).collect();
606 mm_dims[out_rank - 1] = Dim::Static(combined_n);
607 let mm_shape = Shape::from_dims(&mm_dims, matmuls[0].shape.dtype());
608 let mm_id = rw.new_graph.add_node(
609 Op::MatMul,
610 vec![rw.map(group.input_id), concat_id],
611 mm_shape,
612 );
613
614 let mut start = 0usize;
615 for (mm, &n) in matmuls.iter().zip(&ns) {
616 let narrow = rw.new_graph.add_node(
617 Op::Narrow {
618 axis: out_rank - 1,
619 start,
620 len: n,
621 },
622 vec![mm_id],
623 mm.shape.clone(),
624 );
625 rw.replace(mm.id, narrow);
626 start += n;
627 }
628 continue;
629 }
630
631 rw.copy_node(node);
632 }
633
634 rw.finish(&graph.outputs)
635 }
636}
637
/// Graph pass that rewrites `Mul(Narrow(cat), Silu(Narrow(cat)))`, where the
/// two narrows are adjacent equal-length slices of the same tensor, into a
/// single `Op::FusedSwiGLU` reading that tensor directly.
pub struct FuseSwiGLU;
659
660impl Pass for FuseSwiGLU {
661 fn name(&self) -> &str {
662 "fuse_swiglu"
663 }
664
665 fn run(&self, graph: Graph) -> Graph {
666 #[allow(dead_code)]
672 struct Match {
673 mul_id: NodeId,
674 up_narrow_id: NodeId,
675 silu_id: NodeId,
676 gate_narrow_id: NodeId,
677 cat_id: NodeId,
678 out_n: usize,
679 gate_first: bool,
680 }
681
682 let mut matches: Vec<Match> = Vec::new();
683 let mut consumed: HashMap<NodeId, ()> = HashMap::new();
684
685 for node in graph.nodes() {
686 if !matches!(node.op, Op::Binary(BinaryOp::Mul)) {
689 continue;
690 }
691 let lhs_id = node.inputs[0];
692 let rhs_id = node.inputs[1];
693 let lhs = graph.node(lhs_id);
694 let rhs = graph.node(rhs_id);
695
696 let (up_narrow, silu_id, silu_node) =
698 if matches!(rhs.op, Op::Activation(Activation::Silu)) {
699 (lhs, rhs_id, rhs)
700 } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
701 (rhs, lhs_id, lhs)
702 } else {
703 continue;
704 };
705
706 let (up_axis, up_start, up_len) = match &up_narrow.op {
708 Op::Narrow { axis, start, len } => (*axis, *start, *len),
709 _ => continue,
710 };
711 let gate_narrow_id = silu_node.inputs[0];
713 let gate_narrow = graph.node(gate_narrow_id);
714 let (g_axis, g_start, g_len) = match &gate_narrow.op {
715 Op::Narrow { axis, start, len } => (*axis, *start, *len),
716 _ => continue,
717 };
718
719 if up_narrow.inputs[0] != gate_narrow.inputs[0] {
722 continue;
723 }
724 if up_axis != g_axis {
725 continue;
726 }
727 if up_len != g_len {
728 continue;
729 }
730 let n = up_len;
731 let gate_first = up_start == n && g_start == 0;
733 if !(gate_first || (up_start == 0 && g_start == n)) {
734 continue;
735 }
736
737 if graph.use_count(up_narrow.id) != 1 {
740 continue;
741 }
742 if graph.use_count(gate_narrow_id) != 1 {
743 continue;
744 }
745 if graph.use_count(silu_id) != 1 {
746 continue;
747 }
748
749 matches.push(Match {
750 mul_id: node.id,
751 up_narrow_id: up_narrow.id,
752 silu_id,
753 gate_narrow_id,
754 cat_id: up_narrow.inputs[0],
755 out_n: n,
756 gate_first,
757 });
758 consumed.insert(up_narrow.id, ());
759 consumed.insert(gate_narrow_id, ());
760 consumed.insert(silu_id, ());
761 }
762
763 if matches.is_empty() {
764 return graph;
765 }
766
767 let mut rw = Rewriter::new(&graph.name);
769 let match_by_mul: HashMap<NodeId, &Match> = matches.iter().map(|m| (m.mul_id, m)).collect();
770
771 for node in graph.nodes() {
772 if consumed.contains_key(&node.id) {
773 continue;
774 }
775
776 if let Some(m) = match_by_mul.get(&node.id) {
777 let out_shape = node.shape.clone();
779 debug_assert_eq!(
780 out_shape.dim(out_shape.rank() - 1).unwrap_static(),
781 m.out_n,
782 "FuseSwiGLU: output last dim should be N"
783 );
784 let fused_id = rw.add_fused(
785 Op::FusedSwiGLU {
786 cast_to: None,
787 gate_first: m.gate_first,
788 },
789 &[m.cat_id],
790 out_shape,
791 );
792 rw.replace(node.id, fused_id);
793 continue;
794 }
795
796 rw.copy_node(node);
797 }
798
799 rw.finish(&graph.outputs)
800 }
801}
802
/// Graph pass that collapses a fused QKV projection, its three `Narrow`
/// slices, the `Attention` node and the fused output projection into one
/// `Op::FusedAttentionBlock`.
pub struct FuseAttentionBlock;
814
815impl FuseAttentionBlock {
816 fn should_fuse(graph: &Graph) -> bool {
824 let threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
825 .and_then(|v| v.parse().ok())
826 .unwrap_or(64);
827 for node in graph.nodes() {
828 if let Op::Input { .. } = &node.op
829 && node.shape.rank() >= 2
830 {
831 let d0 = node.shape.dim(0);
832 let d1 = node.shape.dim(1);
833 if d0.is_static() && d1.is_static() {
834 let b = d0.unwrap_static();
835 let s = d1.unwrap_static();
836 if b * s <= threshold {
837 return true;
838 }
839 }
840 }
841 }
842 false
843 }
844}
845
846fn narrow_parent(node: &Node) -> Option<(NodeId, usize, usize, usize)> {
848 match &node.op {
849 Op::Narrow { axis, start, len } => Some((node.inputs[0], *axis, *start, *len)),
850 _ => None,
851 }
852}
853
854fn fused_mm_bias_none(node: &Node) -> Option<(NodeId, NodeId, NodeId)> {
856 if let Op::FusedMatMulBiasAct { activation: None } = &node.op
857 && node.inputs.len() == 3
858 {
859 return Some((node.inputs[0], node.inputs[1], node.inputs[2]));
860 }
861 None
862}
863
864impl Pass for FuseAttentionBlock {
865 fn name(&self) -> &str {
866 "fuse_attention_block"
867 }
868
869 fn run(&self, graph: Graph) -> Graph {
870 if !Self::should_fuse(&graph) {
873 return graph;
874 }
875
876 let mut is_output: HashMap<NodeId, ()> = HashMap::new();
892 for &oid in &graph.outputs {
893 is_output.insert(oid, ());
894 }
895
896 struct Match {
899 attn_id: NodeId,
900 qkv_mm_id: NodeId,
901 out_mm_id: NodeId,
902 narrows: [NodeId; 3],
903 hidden_id: NodeId,
904 qkv_w: NodeId,
905 qkv_b: NodeId,
906 out_w: NodeId,
907 out_b: NodeId,
908 mask: NodeId,
909 num_heads: usize,
910 head_dim: usize,
911 out_shape: Shape,
912 }
913 let mut matches: Vec<Match> = Vec::new();
914 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
915
916 for node in graph.nodes() {
917 let Op::Attention {
918 num_heads,
919 head_dim,
920 mask_kind,
921 score_scale,
922 attn_logit_softcap,
923 } = &node.op
924 else {
925 continue;
926 };
927 if !matches!(mask_kind, MaskKind::Custom)
930 || score_scale.is_some()
931 || attn_logit_softcap.is_some()
932 || node.inputs.len() != 4
933 {
934 continue;
935 }
936 let (q, k, v, mask) = (
937 node.inputs[0],
938 node.inputs[1],
939 node.inputs[2],
940 node.inputs[3],
941 );
942
943 let qn = graph.node(q);
946 let kn = graph.node(k);
947 let vn = graph.node(v);
948 let (qp, q_axis, q_start, q_len) = match narrow_parent(qn) {
949 Some(p) => p,
950 None => continue,
951 };
952 let (kp, k_axis, k_start, k_len) = match narrow_parent(kn) {
953 Some(p) => p,
954 None => continue,
955 };
956 let (vp, v_axis, v_start, v_len) = match narrow_parent(vn) {
957 Some(p) => p,
958 None => continue,
959 };
960 if qp != kp || kp != vp {
961 continue;
962 }
963 let h = num_heads * head_dim;
964 let parent_rank = graph.node(qp).shape.rank();
965 let last_ax = parent_rank.saturating_sub(1);
966 if q_axis != last_ax || k_axis != last_ax || v_axis != last_ax {
967 continue;
968 }
969 if q_len != h || k_len != h || v_len != h {
970 continue;
971 }
972 if q_start != 0 || k_start != h || v_start != 2 * h {
973 continue;
974 }
975 if graph.use_count(q) != 1
977 || graph.use_count(k) != 1
978 || graph.use_count(v) != 1
979 || is_output.contains_key(&q)
980 || is_output.contains_key(&k)
981 || is_output.contains_key(&v)
982 {
983 continue;
984 }
985
986 let qkv_mm_node = graph.node(qp);
988 let (hidden_id, qkv_w, qkv_b) = match fused_mm_bias_none(qkv_mm_node) {
989 Some(t) => t,
990 None => continue,
991 };
992 if graph.use_count(qp) != 3 || is_output.contains_key(&qp) {
995 continue;
996 }
997
998 if graph.use_count(node.id) != 1 || is_output.contains_key(&node.id) {
1000 continue;
1001 }
1002 let out_consumer_id = match graph
1003 .nodes()
1004 .iter()
1005 .find(|n| n.inputs.contains(&node.id))
1006 .map(|n| n.id)
1007 {
1008 Some(id) => id,
1009 None => continue,
1010 };
1011 let out_mm_node = graph.node(out_consumer_id);
1012 let (out_in, out_w, out_b) = match fused_mm_bias_none(out_mm_node) {
1013 Some(t) if t.0 == node.id => t,
1014 _ => continue,
1015 };
1016 let _ = out_in;
1017
1018 matches.push(Match {
1020 attn_id: node.id,
1021 qkv_mm_id: qp,
1022 out_mm_id: out_consumer_id,
1023 narrows: [q, k, v],
1024 hidden_id,
1025 qkv_w,
1026 qkv_b,
1027 out_w,
1028 out_b,
1029 mask,
1030 num_heads: *num_heads,
1031 head_dim: *head_dim,
1032 out_shape: out_mm_node.shape.clone(),
1033 });
1034 fused_away.insert(qp, ());
1035 fused_away.insert(q, ());
1036 fused_away.insert(k, ());
1037 fused_away.insert(v, ());
1038 fused_away.insert(node.id, ());
1039 fused_away.insert(out_consumer_id, ());
1040 }
1041
1042 if matches.is_empty() {
1043 return graph;
1044 }
1045
1046 let mut by_out: HashMap<NodeId, &Match> = HashMap::new();
1048 for m in &matches {
1049 by_out.insert(m.out_mm_id, m);
1050 }
1051
1052 let mut rw = Rewriter::new(&graph.name);
1053 for node in graph.nodes() {
1054 if fused_away.contains_key(&node.id) {
1055 if let Some(m) = by_out.get(&node.id) {
1056 rw.ensure_mapped(
1058 &graph,
1059 &[m.hidden_id, m.qkv_w, m.out_w, m.mask, m.qkv_b, m.out_b],
1060 );
1061 let fused_id = rw.add_fused(
1062 Op::FusedAttentionBlock {
1063 num_heads: m.num_heads,
1064 head_dim: m.head_dim,
1065 has_bias: true,
1066 has_rope: false,
1067 },
1068 &[m.hidden_id, m.qkv_w, m.out_w, m.mask, m.qkv_b, m.out_b],
1069 m.out_shape.clone(),
1070 );
1071 rw.replace(m.qkv_mm_id, fused_id);
1074 rw.replace(m.narrows[0], fused_id);
1075 rw.replace(m.narrows[1], fused_id);
1076 rw.replace(m.narrows[2], fused_id);
1077 rw.replace(m.attn_id, fused_id);
1078 rw.replace(node.id, fused_id);
1079 }
1080 continue;
1081 }
1082 rw.copy_node(node);
1083 }
1084 rw.finish(&graph.outputs)
1085 }
1086}
1087
/// Graph pass that collapses a fused attention block, the two fused
/// residual layer norms and the two fused feed-forward matmuls between them
/// into one `Op::FusedTransformerLayer`.
pub struct FuseTransformerLayer;
1116
impl FuseTransformerLayer {
    /// Decides whether this pass should run on `graph`.
    ///
    /// Delegates to `FuseAttentionBlock::should_fuse`, so both passes are
    /// gated by the same check.
    fn should_fuse(graph: &Graph) -> bool {
        FuseAttentionBlock::should_fuse(graph)
    }
}
1124
1125fn fused_residual_ln_no_bias(node: &Node) -> Option<(NodeId, NodeId, NodeId, NodeId, f32)> {
1127 if let Op::FusedResidualLN {
1128 has_bias: false,
1129 eps,
1130 } = &node.op
1131 && node.inputs.len() == 4
1132 {
1133 return Some((
1134 node.inputs[0],
1135 node.inputs[1],
1136 node.inputs[2],
1137 node.inputs[3],
1138 *eps,
1139 ));
1140 }
1141 None
1142}
1143
1144fn fused_mm_bias_act(node: &Node) -> Option<(NodeId, NodeId, NodeId, Activation)> {
1146 if let Op::FusedMatMulBiasAct {
1147 activation: Some(a),
1148 } = &node.op
1149 && node.inputs.len() == 3
1150 {
1151 return Some((node.inputs[0], node.inputs[1], node.inputs[2], *a));
1152 }
1153 None
1154}
1155
1156fn fused_attn_block_bert(
1158 node: &Node,
1159) -> Option<(usize, usize, NodeId, NodeId, NodeId, NodeId, NodeId, NodeId)> {
1160 if let Op::FusedAttentionBlock {
1161 num_heads,
1162 head_dim,
1163 has_bias: true,
1164 has_rope: false,
1165 } = &node.op
1166 && node.inputs.len() == 6
1167 {
1168 return Some((
1170 *num_heads,
1171 *head_dim,
1172 node.inputs[0],
1173 node.inputs[1],
1174 node.inputs[2],
1175 node.inputs[3],
1176 node.inputs[4],
1177 node.inputs[5],
1178 ));
1179 }
1180 None
1181}
1182
1183impl Pass for FuseTransformerLayer {
1184 fn name(&self) -> &str {
1185 "fuse_transformer_layer"
1186 }
1187
1188 fn run(&self, graph: Graph) -> Graph {
1189 if !Self::should_fuse(&graph) {
1190 return graph;
1191 }
1192
1193 let mut is_output: HashMap<NodeId, ()> = HashMap::new();
1197 for &oid in &graph.outputs {
1198 is_output.insert(oid, ());
1199 }
1200
1201 struct LayerMatch {
1202 attn_id: NodeId,
1203 ln1_id: NodeId,
1204 fc1_id: NodeId,
1205 fc2_id: NodeId,
1206 ln2_id: NodeId,
1207 inputs: [NodeId; 14],
1208 num_heads: usize,
1209 head_dim: usize,
1210 intermediate_size: usize,
1211 eps1: f32,
1212 eps2: f32,
1213 activation: Activation,
1214 out_shape: Shape,
1215 }
1216
1217 let mut matches: Vec<LayerMatch> = Vec::new();
1218 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
1219
1220 for node in graph.nodes() {
1221 let Some((num_heads, head_dim, hidden_id, qkv_w, out_w, mask, qkv_b, out_b)) =
1223 fused_attn_block_bert(node)
1224 else {
1225 continue;
1226 };
1227 let attn_id = node.id;
1228 if graph.use_count(attn_id) != 1 || is_output.contains_key(&attn_id) {
1230 continue;
1231 }
1232 let ln1_id = match graph
1233 .nodes()
1234 .iter()
1235 .find(|n| n.inputs.contains(&attn_id))
1236 .map(|n| n.id)
1237 {
1238 Some(id) => id,
1239 None => continue,
1240 };
1241 let ln1_node = graph.node(ln1_id);
1242 let Some((ln1_x, ln1_res, ln1_g, ln1_b, eps1)) = fused_residual_ln_no_bias(ln1_node)
1243 else {
1244 continue;
1245 };
1246 if ln1_x != attn_id || ln1_res != hidden_id {
1248 continue;
1249 }
1250 if graph.use_count(ln1_id) != 2 || is_output.contains_key(&ln1_id) {
1252 continue;
1253 }
1254
1255 let mut fc1_candidate: Option<NodeId> = None;
1257 let mut ln2_candidate: Option<NodeId> = None;
1258 for cn in graph.nodes() {
1259 if !cn.inputs.contains(&ln1_id) {
1260 continue;
1261 }
1262 if fused_mm_bias_act(cn).is_some() && cn.inputs[0] == ln1_id {
1263 fc1_candidate = Some(cn.id);
1264 } else if fused_residual_ln_no_bias(cn).is_some() && cn.inputs[1] == ln1_id {
1265 ln2_candidate = Some(cn.id);
1266 }
1267 }
1268 let (Some(fc1_id), Some(ln2_id)) = (fc1_candidate, ln2_candidate) else {
1269 continue;
1270 };
1271 let fc1_node = graph.node(fc1_id);
1272 let Some((_, fc1_w, fc1_b, activation)) = fused_mm_bias_act(fc1_node) else {
1273 continue;
1274 };
1275 if graph.use_count(fc1_id) != 1 || is_output.contains_key(&fc1_id) {
1277 continue;
1278 }
1279 let fc2_id = match graph
1280 .nodes()
1281 .iter()
1282 .find(|n| n.inputs.contains(&fc1_id))
1283 .map(|n| n.id)
1284 {
1285 Some(id) => id,
1286 None => continue,
1287 };
1288 let fc2_node = graph.node(fc2_id);
1289 let Some((fc2_in, fc2_w, fc2_b)) = fused_mm_bias_none(fc2_node) else {
1291 continue;
1292 };
1293 if fc2_in != fc1_id {
1294 continue;
1295 }
1296 if graph.use_count(fc2_id) != 1 || is_output.contains_key(&fc2_id) {
1297 continue;
1298 }
1299 let ln2_node = graph.node(ln2_id);
1301 let Some((ln2_x, ln2_res, ln2_g, ln2_b, eps2)) = fused_residual_ln_no_bias(ln2_node)
1302 else {
1303 continue;
1304 };
1305 if ln2_x != fc2_id || ln2_res != ln1_id {
1306 continue;
1307 }
1308 let intermediate_size = {
1310 let s = &graph.node(fc1_w).shape;
1311 if s.rank() != 2 {
1312 continue;
1313 }
1314 let d = s.dim(s.rank() - 1);
1315 if !d.is_static() {
1316 continue;
1317 }
1318 d.unwrap_static()
1319 };
1320
1321 matches.push(LayerMatch {
1322 attn_id,
1323 ln1_id,
1324 fc1_id,
1325 fc2_id,
1326 ln2_id,
1327 inputs: [
1328 hidden_id, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w,
1329 fc2_b, ln2_g, ln2_b, mask,
1330 ],
1331 num_heads,
1332 head_dim,
1333 intermediate_size,
1334 eps1,
1335 eps2,
1336 activation,
1337 out_shape: ln2_node.shape.clone(),
1338 });
1339 fused_away.insert(attn_id, ());
1340 fused_away.insert(ln1_id, ());
1341 fused_away.insert(fc1_id, ());
1342 fused_away.insert(fc2_id, ());
1343 fused_away.insert(ln2_id, ());
1344 }
1345
1346 if matches.is_empty() {
1347 return graph;
1348 }
1349
1350 let mut by_terminal: HashMap<NodeId, &LayerMatch> = HashMap::new();
1352 for m in &matches {
1353 by_terminal.insert(m.ln2_id, m);
1354 }
1355
1356 let mut rw = Rewriter::new(&graph.name);
1357 for node in graph.nodes() {
1358 if fused_away.contains_key(&node.id) {
1359 if let Some(m) = by_terminal.get(&node.id) {
1360 rw.ensure_mapped(&graph, &m.inputs);
1361 let fused_id = rw.add_fused(
1362 Op::FusedTransformerLayer {
1363 num_heads: m.num_heads,
1364 head_dim: m.head_dim,
1365 intermediate_size: m.intermediate_size,
1366 eps1: m.eps1,
1367 eps2: m.eps2,
1368 activation: m.activation,
1369 has_bias: true,
1370 },
1371 &m.inputs,
1372 m.out_shape.clone(),
1373 );
1374 rw.replace(m.attn_id, fused_id);
1375 rw.replace(m.ln1_id, fused_id);
1376 rw.replace(m.fc1_id, fused_id);
1377 rw.replace(m.fc2_id, fused_id);
1378 rw.replace(node.id, fused_id);
1379 }
1380 continue;
1381 }
1382 rw.copy_node(node);
1383 }
1384 rw.finish(&graph.outputs)
1385 }
1386}
1387
/// Graph pass registered under the name `mark_elementwise_regions`.
// NOTE(review): from the visible part of the implementation this groups chains
// of element-wise nodes (activations, same-dtype casts, binary ops,
// comparisons, `Where`) into regions - confirm against the full `run` body.
pub struct MarkElementwiseRegions;
1411
1412impl Pass for MarkElementwiseRegions {
1413 fn name(&self) -> &str {
1414 "mark_elementwise_regions"
1415 }
1416
1417 fn run(&self, graph: Graph) -> Graph {
1418 let mut consumers: HashMap<NodeId, usize> = HashMap::new();
1420 for node in graph.nodes() {
1421 for &input in &node.inputs {
1422 *consumers.entry(input).or_insert(0) += 1;
1423 }
1424 }
1425 for &out in &graph.outputs {
1426 *consumers.entry(out).or_insert(0) += 1;
1427 }
1428
1429 let chain_eligible = |op: &Op| -> bool {
1431 matches!(
1432 op,
1433 Op::Activation(_) | Op::Cast { .. } | Op::Binary(_) | Op::Compare(_) | Op::Where
1434 )
1435 };
1436
1437 let chain_step_safe = |graph: &Graph, node: &rlx_ir::Node| -> bool {
1446 match &node.op {
1447 Op::Cast { to } => {
1448 let in_dt = graph.shape(node.inputs[0]).dtype();
1449 *to == in_dt
1450 }
1451 _ => true,
1452 }
1453 };
1454
1455 let mut region_of: HashMap<NodeId, NodeId> = HashMap::new();
1464 let mut chain_step_idx: HashMap<NodeId, u32> = HashMap::new();
1465
1466 for node in graph.nodes() {
1467 if !chain_eligible(&node.op) {
1468 continue;
1469 }
1470 if !chain_step_safe(&graph, node) {
1471 continue;
1472 }
1473 let out_shape = &node.shape;
1480 let out_elems = out_shape.num_elements();
1481 let shape_ok = node.inputs.iter().all(|id| {
1482 let in_elems = graph.shape(*id).num_elements();
1483 match (in_elems, out_elems) {
1484 (Some(i), Some(o)) if i == o => true,
1485 (Some(i), Some(o)) if i > 0 && o % i == 0 => true,
1486 _ => false,
1487 }
1488 });
1489 if !shape_ok {
1490 continue;
1491 }
1492 let mut parent_root: Option<NodeId> = None;
1497 let mut all_inputs_single_consumer = true;
1498 for &input in &node.inputs {
1499 if graph.node(input).op.is_fusion_boundary() {
1501 parent_root = None;
1502 all_inputs_single_consumer = false;
1503 break;
1504 }
1505 if let Some(&root) = region_of.get(&input) {
1506 if consumers.get(&input).copied() != Some(1) {
1507 all_inputs_single_consumer = false;
1508 break;
1509 }
1510 match parent_root {
1511 None => parent_root = Some(root),
1512 Some(r) if r == root => {}
1513 Some(_) => {
1514 parent_root = None;
1515 all_inputs_single_consumer = false;
1516 break;
1517 }
1518 }
1519 }
1520 }
1521 if !all_inputs_single_consumer {
1522 region_of.insert(node.id, node.id);
1524 chain_step_idx.insert(node.id, 0);
1525 continue;
1526 }
1527 let root = parent_root.unwrap_or(node.id);
1528 let next_idx = node
1530 .inputs
1531 .iter()
1532 .filter_map(|id| {
1533 if region_of.get(id) == Some(&root) {
1534 chain_step_idx.get(id).copied()
1535 } else {
1536 None
1537 }
1538 })
1539 .max()
1540 .map(|m| m + 1)
1541 .unwrap_or(0);
1542 let limits = crate::limits::active_fusion_limits();
1543 if next_idx >= limits.max_elementwise_steps {
1544 region_of.insert(node.id, node.id);
1545 chain_step_idx.insert(node.id, 0);
1546 continue;
1547 }
1548 region_of.insert(node.id, root);
1549 chain_step_idx.insert(node.id, next_idx);
1550 }
1551
1552 let mut by_region: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
1555 for node in graph.nodes() {
1556 if let Some(&root) = region_of.get(&node.id) {
1557 by_region.entry(root).or_default().push(node.id);
1558 }
1559 }
1560
1561 let mut tail_of_region: HashMap<NodeId, NodeId> = HashMap::new();
1567 for (root, members) in &by_region {
1568 if members.len() < 2 {
1569 continue;
1570 }
1571 let max_idx = members.iter().map(|id| chain_step_idx[id]).max().unwrap();
1572 let tails: Vec<_> = members
1573 .iter()
1574 .filter(|id| chain_step_idx[id] == max_idx)
1575 .collect();
1576 if tails.len() != 1 {
1577 continue;
1578 }
1579 tail_of_region.insert(*root, *tails[0]);
1580 }
1581
1582 let by_region: HashMap<NodeId, Vec<NodeId>> = by_region
1584 .into_iter()
1585 .filter(|(root, _)| tail_of_region.contains_key(root))
1586 .collect();
1587
1588 if by_region.is_empty() {
1589 return graph;
1590 }
1591
1592 let mut rw = Rewriter::new(&graph.name);
1596 let mut emitted_region: HashMap<NodeId, NodeId> = HashMap::new();
1598
1599 for node in graph.nodes() {
1600 if let Some(&root) = region_of.get(&node.id)
1601 && let Some(&tail) = tail_of_region.get(&root)
1602 {
1603 if emitted_region.contains_key(&root) {
1604 let region_new = emitted_region[&root];
1610 rw.replace(node.id, region_new);
1611 continue;
1612 }
1613 if node.id == tail {
1614 let members = &by_region[&root];
1616 let mut ordered: Vec<NodeId> = members.clone();
1617 ordered.sort_by_key(|id| chain_step_idx[id]);
1618
1619 let mut external_inputs: Vec<NodeId> = Vec::new();
1623 let mut input_idx_of: HashMap<NodeId, u32> = HashMap::new();
1624 let mut step_idx_of: HashMap<NodeId, u32> = HashMap::new();
1625 for (i, member_id) in ordered.iter().enumerate() {
1626 step_idx_of.insert(*member_id, i as u32);
1627 let n = graph.node(*member_id);
1628 for &inp in &n.inputs {
1629 if !step_idx_of.contains_key(&inp) && !input_idx_of.contains_key(&inp) {
1630 let idx = external_inputs.len() as u32;
1631 input_idx_of.insert(inp, idx);
1632 external_inputs.push(inp);
1633 }
1634 }
1635 }
1636
1637 let limits = crate::limits::active_fusion_limits();
1638 if external_inputs.len() as u32 > limits.max_elementwise_inputs
1639 || ordered.len() as u32 > limits.max_elementwise_steps
1640 {
1641 for &mid in &ordered {
1642 rw.copy_node(graph.node(mid));
1643 }
1644 continue;
1645 }
1646
1647 let resolve = |id: NodeId| -> ChainOperand {
1648 if let Some(&i) = input_idx_of.get(&id) {
1649 ChainOperand::Input(i)
1650 } else {
1651 ChainOperand::Step(step_idx_of[&id])
1652 }
1653 };
1654 let mut chain: Vec<ChainStep> = Vec::with_capacity(ordered.len());
1655 for member_id in &ordered {
1656 let n = graph.node(*member_id);
1657 let step = match &n.op {
1658 Op::Activation(a) => ChainStep::Activation(*a, resolve(n.inputs[0])),
1659 Op::Cast { to } => ChainStep::Cast(*to, resolve(n.inputs[0])),
1660 Op::Binary(op) => {
1661 ChainStep::Binary(*op, resolve(n.inputs[0]), resolve(n.inputs[1]))
1662 }
1663 Op::Compare(op) => {
1664 ChainStep::Compare(*op, resolve(n.inputs[0]), resolve(n.inputs[1]))
1665 }
1666 Op::Where => ChainStep::Where(
1667 resolve(n.inputs[0]),
1668 resolve(n.inputs[1]),
1669 resolve(n.inputs[2]),
1670 ),
1671 _ => unreachable!("non-chain-eligible op in region"),
1672 };
1673 chain.push(step);
1674 }
1675
1676 let mut scalar_input_mask: u32 = 0;
1685 let mut input_modulus = [0u32; 16];
1686 let region_shape_elems = graph.node(tail).shape.num_elements();
1687 for (i, &ext) in external_inputs.iter().enumerate() {
1688 if i >= 16 {
1689 break;
1690 }
1691 let in_elems = graph.shape(ext).num_elements();
1692 match (in_elems, region_shape_elems) {
1693 (Some(1), Some(o)) if o != 1 => {
1694 scalar_input_mask |= 1u32 << i;
1695 input_modulus[i] = 1;
1696 }
1697 (Some(i_n), Some(o)) if i_n != o && i_n > 0 => {
1698 input_modulus[i] = i_n as u32;
1699 }
1700 _ => { }
1701 }
1702 }
1703 let region_new = rw.add_fused(
1704 Op::ElementwiseRegion {
1705 chain,
1706 num_inputs: external_inputs.len() as u32,
1707 scalar_input_mask,
1708 input_modulus,
1709 prologue: RegionPrologue::None,
1710 prologue_input: 0,
1711 },
1712 &external_inputs,
1713 graph.node(tail).shape.clone(),
1714 );
1715 emitted_region.insert(root, region_new);
1716 rw.replace(node.id, region_new);
1717 continue;
1718 } else {
1719 rw.replace(node.id, NodeId(u32::MAX)); continue;
1723 }
1724 }
1725 rw.copy_node(node);
1726 }
1727
1728 rw.finish(&graph.outputs)
1745 }
1746}
1747
/// Pass that expands `Op::ElementwiseRegion` nodes back into one primitive
/// node per chain step (see its `Pass::run` implementation).
pub struct UnfuseElementwiseRegions {
    /// Controls what happens to regions whose prologue is not
    /// `RegionPrologue::None`: when `false` such regions are copied through
    /// unchanged, when `true` they are expanded like every other region.
    pub unfuse_prologue: bool,
}
1765
impl UnfuseElementwiseRegions {
    /// Preset with `unfuse_prologue: false`: regions that carry a prologue are
    /// left fused.
    // NOTE(review): presumably the GPU backend executes prologue regions
    // natively — confirm against the backend.
    pub const FOR_GPU: UnfuseElementwiseRegions = UnfuseElementwiseRegions {
        unfuse_prologue: false,
    };
    /// Preset with `unfuse_prologue: true`: every region is expanded, including
    /// those that carry a prologue.
    // NOTE(review): presumably the CPU backend has no prologue-region kernel —
    // confirm against the backend.
    pub const FOR_CPU: UnfuseElementwiseRegions = UnfuseElementwiseRegions {
        unfuse_prologue: true,
    };
}
1776
1777impl Pass for UnfuseElementwiseRegions {
1778 fn name(&self) -> &str {
1779 "unfuse_elementwise_regions"
1780 }
1781
1782 fn run(&self, graph: Graph) -> Graph {
1783 let any_region = graph
1784 .nodes()
1785 .iter()
1786 .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }));
1787 if !any_region {
1788 return graph;
1789 }
1790
1791 let mut rw = Rewriter::new(&graph.name);
1792 for node in graph.nodes() {
1793 if let Op::ElementwiseRegion {
1794 chain,
1795 num_inputs: _,
1796 scalar_input_mask: _,
1797 input_modulus: _,
1798 prologue,
1799 prologue_input: _,
1800 } = &node.op
1801 {
1802 if *prologue != RegionPrologue::None && !self.unfuse_prologue {
1803 rw.copy_node(node);
1804 continue;
1805 }
1806 let mut region_inputs: Vec<NodeId> =
1807 node.inputs.iter().map(|id| rw.map(*id)).collect();
1808 if *prologue == RegionPrologue::ResizeNearest2x {
1809 let in_shape = rw.new_graph.node(region_inputs[0]).shape.clone();
1810 let out_shape = if in_shape.rank() == 4 {
1811 Shape::new(
1812 &[
1813 in_shape.dim(0).unwrap_static(),
1814 in_shape.dim(1).unwrap_static(),
1815 in_shape.dim(2).unwrap_static() * 2,
1816 in_shape.dim(3).unwrap_static() * 2,
1817 ],
1818 in_shape.dtype(),
1819 )
1820 } else {
1821 node.shape.clone()
1822 };
1823 region_inputs[0] = rw.new_graph.add_node(
1824 Op::ResizeNearest2x,
1825 vec![region_inputs[0]],
1826 out_shape,
1827 );
1828 }
1829 let mut step_ids: Vec<NodeId> = Vec::with_capacity(chain.len());
1830 let region_shape = node.shape.clone();
1831 let region_dims: Vec<_> = region_shape.dims().to_vec();
1832 let mut step_dtypes: Vec<rlx_ir::DType> = Vec::with_capacity(chain.len());
1838 let region_dtype = region_shape.dtype();
1839 let dtype_of = |op: &ChainOperand,
1840 ins: &[NodeId],
1841 step_dt: &[rlx_ir::DType],
1842 rw: &Rewriter|
1843 -> rlx_ir::DType {
1844 match *op {
1845 ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.dtype(),
1846 ChainOperand::Step(i) => step_dt[i as usize],
1847 }
1848 };
1849 let shape_of = |op: &ChainOperand,
1860 ins: &[NodeId],
1861 step_ids: &[NodeId],
1862 rw: &Rewriter|
1863 -> Shape {
1864 match *op {
1865 ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.clone(),
1866 ChainOperand::Step(i) => {
1867 rw.new_graph.node(step_ids[i as usize]).shape.clone()
1868 }
1869 }
1870 };
1871 for step in chain {
1872 let resolve = |op: &ChainOperand| -> NodeId {
1873 match *op {
1874 ChainOperand::Input(i) => region_inputs[i as usize],
1875 ChainOperand::Step(i) => step_ids[i as usize],
1876 }
1877 };
1878 let (new_id, dt) = match step {
1879 ChainStep::Activation(a, src) => {
1880 let s = resolve(src);
1881 let dt = dtype_of(src, ®ion_inputs, &step_dtypes, &rw);
1882 let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
1886 let dims: Vec<_> = src_shape.dims().to_vec();
1887 let shape = Shape::from_dims(&dims, dt);
1888 (
1889 rw.new_graph.add_node(Op::Activation(*a), vec![s], shape),
1890 dt,
1891 )
1892 }
1893 ChainStep::Cast(to, src) => {
1894 let s = resolve(src);
1895 let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
1896 let dims: Vec<_> = src_shape.dims().to_vec();
1897 let shape = Shape::from_dims(&dims, *to);
1898 (
1899 rw.new_graph.add_node(Op::Cast { to: *to }, vec![s], shape),
1900 *to,
1901 )
1902 }
1903 ChainStep::Binary(op, lhs, rhs) => {
1904 let l = resolve(lhs);
1905 let r = resolve(rhs);
1906 let dt = dtype_of(lhs, ®ion_inputs, &step_dtypes, &rw);
1907 let lhs_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
1909 let rhs_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
1910 let bcast = rlx_ir::shape::broadcast(&lhs_shape, &rhs_shape)
1911 .unwrap_or_else(|e| {
1912 panic!(
1913 "unfuse_elementwise_regions: cannot broadcast \
1914 {lhs_shape:?} ⊗ {rhs_shape:?} for Binary({op:?}): {e}"
1915 )
1916 });
1917 let dims: Vec<_> = bcast.dims().to_vec();
1918 let shape = Shape::from_dims(&dims, dt);
1919 (
1920 rw.new_graph.add_node(Op::Binary(*op), vec![l, r], shape),
1921 dt,
1922 )
1923 }
1924 ChainStep::Compare(op, lhs, rhs) => {
1925 let l = resolve(lhs);
1926 let r = resolve(rhs);
1927 let lhs_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
1928 let rhs_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
1929 let bcast = rlx_ir::shape::broadcast(&lhs_shape, &rhs_shape)
1930 .unwrap_or_else(|e| {
1931 panic!(
1932 "unfuse_elementwise_regions: cannot broadcast \
1933 {lhs_shape:?} ⊗ {rhs_shape:?} for Compare({op:?}): {e}"
1934 )
1935 });
1936 let dims: Vec<_> = bcast.dims().to_vec();
1937 let shape = Shape::from_dims(&dims, rlx_ir::DType::Bool);
1938 (
1939 rw.new_graph.add_node(Op::Compare(*op), vec![l, r], shape),
1940 rlx_ir::DType::Bool,
1941 )
1942 }
1943 ChainStep::Where(c, x, y) => {
1944 let cn = resolve(c);
1945 let xn = resolve(x);
1946 let yn = resolve(y);
1947 let dt = dtype_of(x, ®ion_inputs, &step_dtypes, &rw);
1948 let c_shape = shape_of(c, ®ion_inputs, &step_ids, &rw);
1950 let x_shape = shape_of(x, ®ion_inputs, &step_ids, &rw);
1951 let y_shape = shape_of(y, ®ion_inputs, &step_ids, &rw);
1952 let bcast_xy = rlx_ir::shape::broadcast(&x_shape, &y_shape)
1953 .unwrap_or_else(|e| {
1954 panic!(
1955 "unfuse_elementwise_regions: cannot broadcast \
1956 then/else {x_shape:?} ⊗ {y_shape:?} for Where: {e}"
1957 )
1958 });
1959 let bcast = rlx_ir::shape::broadcast(&c_shape, &bcast_xy)
1960 .unwrap_or_else(|e| {
1961 panic!(
1962 "unfuse_elementwise_regions: cannot broadcast cond \
1963 {c_shape:?} ⊗ {bcast_xy:?} for Where: {e}"
1964 )
1965 });
1966 let dims: Vec<_> = bcast.dims().to_vec();
1967 let shape = Shape::from_dims(&dims, dt);
1968 (
1969 rw.new_graph.add_node(Op::Where, vec![cn, xn, yn], shape),
1970 dt,
1971 )
1972 }
1973 };
1974 step_ids.push(new_id);
1975 step_dtypes.push(dt);
1976 }
1977 let _ = region_dtype;
1978 let _ = region_dims;
1979 let last = *step_ids.last().expect("chain non-empty per pass invariant");
1982 rw.replace(node.id, last);
1983 continue;
1984 }
1985 rw.copy_node(node);
1986 }
1987 rw.finish(&graph.outputs)
1988 }
1989}
1990
1991pub fn clip_elementwise_regions(graph: Graph, limits: crate::limits::FusionLimits) -> Graph {
1996 let oversize = |n: &rlx_ir::Node| -> bool {
1997 matches!(
1998 &n.op,
1999 Op::ElementwiseRegion {
2000 chain,
2001 num_inputs,
2002 ..
2003 } if *num_inputs > limits.max_elementwise_inputs
2004 || chain.len() as u32 > limits.max_elementwise_steps
2005 )
2006 };
2007 if !graph.nodes().iter().any(oversize) {
2008 return graph;
2009 }
2010
2011 let mut rw = Rewriter::new(&graph.name);
2012 for node in graph.nodes() {
2013 if !oversize(node) {
2014 rw.copy_node(node);
2015 continue;
2016 }
2017
2018 let Op::ElementwiseRegion {
2019 chain,
2020 num_inputs: _,
2021 scalar_input_mask: _,
2022 input_modulus: _,
2023 prologue: _,
2024 prologue_input: _,
2025 } = &node.op
2026 else {
2027 unreachable!();
2028 };
2029
2030 let region_inputs: Vec<NodeId> = node.inputs.iter().map(|id| rw.map(*id)).collect();
2031 let mut step_ids: Vec<NodeId> = Vec::with_capacity(chain.len());
2032 let region_shape = node.shape.clone();
2033 let region_dims: Vec<_> = region_shape.dims().to_vec();
2034 let mut step_dtypes: Vec<rlx_ir::DType> = Vec::with_capacity(chain.len());
2035 let region_dtype = region_shape.dtype();
2036 let dtype_of = |op: &ChainOperand,
2037 ins: &[NodeId],
2038 step_dt: &[rlx_ir::DType],
2039 rw: &Rewriter|
2040 -> rlx_ir::DType {
2041 match *op {
2042 ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.dtype(),
2043 ChainOperand::Step(i) => step_dt[i as usize],
2044 }
2045 };
2046 let shape_of =
2047 |op: &ChainOperand, ins: &[NodeId], step_ids: &[NodeId], rw: &Rewriter| -> Shape {
2048 match *op {
2049 ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.clone(),
2050 ChainOperand::Step(i) => rw.new_graph.node(step_ids[i as usize]).shape.clone(),
2051 }
2052 };
2053 for step in chain {
2054 let resolve = |op: &ChainOperand| -> NodeId {
2055 match *op {
2056 ChainOperand::Input(i) => region_inputs[i as usize],
2057 ChainOperand::Step(i) => step_ids[i as usize],
2058 }
2059 };
2060 let (new_id, dt) = match step {
2061 ChainStep::Activation(a, src) => {
2062 let s = resolve(src);
2063 let dt = dtype_of(src, ®ion_inputs, &step_dtypes, &rw);
2064 let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
2065 let dims: Vec<_> = src_shape.dims().to_vec();
2066 let shape = Shape::from_dims(&dims, dt);
2067 (
2068 rw.new_graph.add_node(Op::Activation(*a), vec![s], shape),
2069 dt,
2070 )
2071 }
2072 ChainStep::Cast(to, src) => {
2073 let s = resolve(src);
2074 let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
2075 let dims: Vec<_> = src_shape.dims().to_vec();
2076 let shape = Shape::from_dims(&dims, *to);
2077 (
2078 rw.new_graph.add_node(Op::Cast { to: *to }, vec![s], shape),
2079 *to,
2080 )
2081 }
2082 ChainStep::Binary(op, lhs, rhs) => {
2083 let l = resolve(lhs);
2084 let r = resolve(rhs);
2085 let dt = dtype_of(lhs, ®ion_inputs, &step_dtypes, &rw);
2086 let l_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
2087 let r_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
2088 let bcast = l_shape
2089 .broadcast_with(&r_shape)
2090 .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
2091 let dims: Vec<_> = bcast.dims().to_vec();
2092 let shape = Shape::from_dims(&dims, dt);
2093 (
2094 rw.new_graph.add_node(Op::Binary(*op), vec![l, r], shape),
2095 dt,
2096 )
2097 }
2098 ChainStep::Compare(op, lhs, rhs) => {
2099 let l = resolve(lhs);
2100 let r = resolve(rhs);
2101 let l_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
2102 let r_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
2103 let bcast = l_shape
2104 .broadcast_with(&r_shape)
2105 .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
2106 let dims: Vec<_> = bcast.dims().to_vec();
2107 let shape = Shape::from_dims(&dims, rlx_ir::DType::U8);
2108 (
2109 rw.new_graph.add_node(Op::Compare(*op), vec![l, r], shape),
2110 rlx_ir::DType::U8,
2111 )
2112 }
2113 ChainStep::Where(cond, x, y) => {
2114 let cn = resolve(cond);
2115 let xn = resolve(x);
2116 let yn = resolve(y);
2117 let dt = dtype_of(x, ®ion_inputs, &step_dtypes, &rw);
2118 let x_shape = shape_of(x, ®ion_inputs, &step_ids, &rw);
2119 let y_shape = shape_of(y, ®ion_inputs, &step_ids, &rw);
2120 let c_shape = shape_of(cond, ®ion_inputs, &step_ids, &rw);
2121 let bcast_xy = x_shape
2122 .broadcast_with(&y_shape)
2123 .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
2124 let bcast = c_shape.broadcast_with(&bcast_xy).unwrap_or_else(|e| {
2125 panic!("clip_elementwise_regions: cannot broadcast cond {c_shape:?} ⊗ {bcast_xy:?} for Where: {e}")
2126 });
2127 let dims: Vec<_> = bcast.dims().to_vec();
2128 let shape = Shape::from_dims(&dims, dt);
2129 (
2130 rw.new_graph.add_node(Op::Where, vec![cn, xn, yn], shape),
2131 dt,
2132 )
2133 }
2134 };
2135 step_ids.push(new_id);
2136 step_dtypes.push(dt);
2137 }
2138 let _ = (region_dtype, region_dims);
2139 let last = *step_ids
2140 .last()
2141 .expect("oversize region has non-empty chain");
2142 rw.replace(node.id, last);
2143 }
2144 rw.finish(&graph.outputs)
2145}
2146
2147#[cfg(test)]
2148mod tests {
2149 use super::*;
2150 use crate::limits::FusionLimits;
2151 use crate::pass::run_passes;
2152
    /// Test helper: builds a `Shape` with the given dims and `DType::F32`.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }
2156
2157 #[test]
2158 fn fuse_matmul_bias_gelu() {
2159 let mut g = Graph::new("test");
2160 let x = g.input("x", f32_shape(&[4, 15, 384]));
2161 let w = g.param("w", f32_shape(&[384, 1536]));
2162 let b = g.param("b", f32_shape(&[1536]));
2163 let mm = g.matmul(x, w, f32_shape(&[4, 15, 1536]));
2164 let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[4, 15, 1536]));
2165 let out = g.activation(Activation::Gelu, add, f32_shape(&[4, 15, 1536]));
2166 g.set_outputs(vec![out]);
2167
2168 assert_eq!(g.len(), 6); let fused = FuseMatMulBiasAct.run(g);
2171 println!("{fused}");
2172
2173 assert_eq!(fused.len(), 4);
2175 let out_node = fused.node(fused.outputs[0]);
2176 assert!(matches!(
2177 out_node.op,
2178 Op::FusedMatMulBiasAct {
2179 activation: Some(Activation::Gelu)
2180 }
2181 ));
2182 }
2183
2184 #[test]
2185 fn fuse_matmul_bias_no_act() {
2186 let mut g = Graph::new("test");
2187 let x = g.input("x", f32_shape(&[4, 15, 384]));
2188 let w = g.param("w", f32_shape(&[384, 384]));
2189 let b = g.param("b", f32_shape(&[384]));
2190 let mm = g.matmul(x, w, f32_shape(&[4, 15, 384]));
2191 let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[4, 15, 384]));
2192 g.set_outputs(vec![add]);
2193
2194 let fused = FuseMatMulBiasAct.run(g);
2195 assert_eq!(fused.len(), 4);
2196 let out_node = fused.node(fused.outputs[0]);
2197 assert!(matches!(
2198 out_node.op,
2199 Op::FusedMatMulBiasAct { activation: None }
2200 ));
2201 }
2202
2203 #[test]
2204 fn fuse_matmul_bias_skips_unsupported_activation_epilogue() {
2205 let mut g = Graph::new("test");
2206 let x = g.input("x", f32_shape(&[8, 1024]));
2207 let w = g.param("w", f32_shape(&[1024, 16]));
2208 let b = g.param("b", f32_shape(&[16]));
2209 let mm = g.matmul(x, w, f32_shape(&[8, 16]));
2210 let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[8, 16]));
2211 let exp = g.activation(Activation::Exp, add, f32_shape(&[8, 16]));
2212 g.set_outputs(vec![exp]);
2213
2214 let fused = FuseMatMulBiasAct.run(g);
2215 assert_eq!(fused.len(), 5);
2217 let out_node = fused.node(fused.outputs[0]);
2218 assert!(matches!(out_node.op, Op::Activation(Activation::Exp)));
2219 let add_node = fused.node(out_node.inputs[0]);
2220 assert!(matches!(
2221 add_node.op,
2222 Op::FusedMatMulBiasAct { activation: None }
2223 ));
2224 }
2225
2226 #[test]
2227 fn fuse_matmul_bias_act_with_late_bias_param() {
2228 use rlx_ir::infer::GraphExt;
2229
2230 let mut g = Graph::new("late_bias");
2231 let x = g.input("x", f32_shape(&[8, 16]));
2232 let w = g.param("w", f32_shape(&[16, 32]));
2233 let out = {
2234 let mm = g.mm(x, w);
2235 let b = g.param("b", f32_shape(&[32]));
2236 let biased = g.add(mm, b);
2237 g.gelu(biased)
2238 };
2239 g.set_outputs(vec![out]);
2240
2241 let fused = FuseMatMulBiasAct.run(g);
2242 assert!(
2243 fused
2244 .nodes()
2245 .iter()
2246 .any(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. })),
2247 "bias param declared after matmul must still fuse:\n{fused}"
2248 );
2249 }
2250
2251 #[test]
2252 fn swiglu_ffn_builder_fuses_end_to_end() {
2253 let mut g = Graph::new("swiglu_block");
2254 let x = g.input("x", f32_shape(&[4, 768]));
2255 let up_w = g.param("up", f32_shape(&[768, 2048]));
2256 let gate_w = g.param("gate", f32_shape(&[768, 2048]));
2257 let down_w = g.param("down", f32_shape(&[2048, 768]));
2258 let out = g.swiglu_ffn(x, up_w, gate_w, down_w);
2259 g.set_outputs(vec![out]);
2260
2261 let g = FuseSharedInputMatMul.run(g);
2262 let g = FuseSwiGLU.run(g);
2263 assert!(
2264 g.nodes()
2265 .iter()
2266 .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
2267 "swiglu_ffn builder should match FuseSwiGLU:\n{g}"
2268 );
2269 }
2270
2271 #[test]
2272 fn fuse_swiglu_dual_matmul_gate_first() {
2273 use rlx_ir::infer::GraphExt;
2274
2275 let mut g = Graph::new("qwen3_ffn");
2276 let x = g.input("x", f32_shape(&[4, 768]));
2277 let gate_w = g.param("gate", f32_shape(&[768, 2048]));
2278 let up_w = g.param("up", f32_shape(&[768, 2048]));
2279 let gate = g.mm(x, gate_w);
2280 let up = g.mm(x, up_w);
2281 let gate_act = g.silu(gate);
2282 let out = g.mul(gate_act, up);
2283 g.set_outputs(vec![out]);
2284
2285 let fused = FuseSwiGLUDualMatmul.run(g);
2286 assert!(
2287 fused
2288 .nodes()
2289 .iter()
2290 .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
2291 "gate-first dual matmul should fuse:\n{fused}"
2292 );
2293 assert!(
2294 fused.len() <= 6,
2295 "dual fusion should collapse to x + weights + concat + mm + fused_swiglu, got {} nodes",
2296 fused.len()
2297 );
2298 }
2299
2300 #[test]
2301 fn fuse_shared_input_matmul_three_way_qkv() {
2302 let mut g = Graph::new("qkv");
2303 let x = g.input("x", f32_shape(&[8, 512]));
2304 let wq = g.param("wq", f32_shape(&[512, 128]));
2305 let wk = g.param("wk", f32_shape(&[512, 128]));
2306 let wv = g.param("wv", f32_shape(&[512, 128]));
2307 let q = g.matmul(x, wq, f32_shape(&[8, 128]));
2308 let k = g.matmul(x, wk, f32_shape(&[8, 128]));
2309 let v = g.matmul(x, wv, f32_shape(&[8, 128]));
2310 g.set_outputs(vec![q, k, v]);
2311
2312 let fused = FuseSharedInputMatMul.run(g);
2313 assert_eq!(
2314 fused.len(),
2315 9,
2316 "x + 3 weights + concat + mm + 3 narrows = 9"
2317 );
2318 for &out in &fused.outputs {
2319 assert!(matches!(fused.node(out).op, Op::Narrow { .. }));
2320 }
2321 }
2322
2323 #[test]
2324 fn fuse_residual_layer_norm() {
2325 let mut g = Graph::new("test");
2326 let x = g.input("x", f32_shape(&[4, 15, 384]));
2327 let residual = g.input("residual", f32_shape(&[4, 15, 384]));
2328 let gamma = g.param("gamma", f32_shape(&[384]));
2329 let beta = g.param("beta", f32_shape(&[384]));
2330 let add = g.binary(BinaryOp::Add, x, residual, f32_shape(&[4, 15, 384]));
2331 let ln = g.layer_norm(add, gamma, beta, -1, 1e-12, f32_shape(&[4, 15, 384]));
2332 g.set_outputs(vec![ln]);
2333
2334 assert_eq!(g.len(), 6); let fused = FuseResidualLN.run(g);
2337 println!("{fused}");
2338
2339 assert_eq!(fused.len(), 5);
2341 let out_node = fused.node(fused.outputs[0]);
2342 assert!(matches!(
2343 out_node.op,
2344 Op::FusedResidualLN {
2345 has_bias: false,
2346 ..
2347 }
2348 ));
2349 }
2350
2351 #[test]
2352 fn fuse_residual_rms_norm() {
2353 let mut g = Graph::new("test");
2354 let x = g.input("x", f32_shape(&[4, 15, 384]));
2355 let residual = g.input("residual", f32_shape(&[4, 15, 384]));
2356 let gamma = g.param("gamma", f32_shape(&[384]));
2357 let beta = g.param("beta", f32_shape(&[384]));
2358 let add = g.binary(BinaryOp::Add, x, residual, f32_shape(&[4, 15, 384]));
2359 let rn = g.add_node(
2360 Op::RmsNorm {
2361 axis: -1,
2362 eps: 1e-6,
2363 },
2364 vec![add, gamma, beta],
2365 f32_shape(&[4, 15, 384]),
2366 );
2367 g.set_outputs(vec![rn]);
2368
2369 assert_eq!(g.len(), 6);
2370
2371 let fused = FuseResidualRmsNorm.run(g);
2372 assert_eq!(fused.len(), 5);
2373 let out_node = fused.node(fused.outputs[0]);
2374 assert!(matches!(
2375 out_node.op,
2376 Op::FusedResidualRmsNorm {
2377 has_bias: false,
2378 ..
2379 }
2380 ));
2381 }
2382
2383 #[test]
2384 fn fuse_rms_norm_reshape() {
2385 let mut g = Graph::new("test");
2386 let x = g.input("x", f32_shape(&[1, 8, 512]));
2387 let gamma = g.param("gamma", f32_shape(&[512]));
2388 let beta = g.param("beta", f32_shape(&[512]));
2389 let rn = g.add_node(
2390 Op::RmsNorm {
2391 axis: -1,
2392 eps: 1e-6,
2393 },
2394 vec![x, gamma, beta],
2395 f32_shape(&[1, 8, 512]),
2396 );
2397 let flat = g.add_node(
2398 Op::Reshape {
2399 new_shape: vec![8, 512],
2400 },
2401 vec![rn],
2402 f32_shape(&[8, 512]),
2403 );
2404 let w = g.param("w", f32_shape(&[512, 128]));
2405 let mm = g.matmul(flat, w, f32_shape(&[8, 128]));
2406 g.set_outputs(vec![mm]);
2407
2408 let fused = FuseRmsNormReshape.run(g);
2409 assert_eq!(fused.len(), 6);
2411 let rn_node = fused.node(fused.node(fused.outputs[0]).inputs[0]);
2412 assert!(matches!(rn_node.op, Op::RmsNorm { .. }));
2413 assert_eq!(rn_node.shape.dim(0).unwrap_static(), 8);
2414 assert_eq!(rn_node.shape.dim(1).unwrap_static(), 512);
2415 }
2416
2417 #[test]
2418 fn fuse_shared_input_matmul() {
2419 let mut g = Graph::new("swiglu");
2420 let x = g.input("x", f32_shape(&[60, 768]));
2421 let w1 = g.param("fc11", f32_shape(&[768, 2048]));
2422 let w2 = g.param("fc12", f32_shape(&[768, 2048]));
2423 let mm1 = g.matmul(x, w1, f32_shape(&[60, 2048]));
2424 let mm2 = g.matmul(x, w2, f32_shape(&[60, 2048]));
2425 g.set_outputs(vec![mm1, mm2]);
2426
2427 assert_eq!(g.len(), 5); let fused = FuseSharedInputMatMul.run(g);
2430 println!("{fused}");
2431
2432 assert!(fused.len() <= 7);
2434 for &out in &fused.outputs {
2436 assert!(matches!(fused.node(out).op, Op::Narrow { .. }));
2437 }
2438 }
2439
2440 #[test]
2443 fn fuse_shared_input_matmul_with_late_w2_param() {
2444 let mut g = Graph::new("late_w2");
2445 let x = g.input("x", f32_shape(&[8, 16]));
2446 let w1 = g.param("w1", f32_shape(&[16, 8]));
2447 let mm1 = g.matmul(x, w1, f32_shape(&[8, 8]));
2448 let w2 = g.param("w2", f32_shape(&[16, 8]));
2449 let mm2 = g.matmul(x, w2, f32_shape(&[8, 8]));
2450 g.set_outputs(vec![mm1, mm2]);
2451
2452 let fused = FuseSharedInputMatMul.run(g);
2453 for &out in &fused.outputs {
2454 assert!(
2455 matches!(fused.node(out).op, Op::Narrow { .. }),
2456 "late w2 should still fuse via ensure_mapped, got {:?}",
2457 fused.node(out).op
2458 );
2459 }
2460 }
2461
2462 #[test]
2465 fn fuse_shared_input_matmul_moe_ffn_pattern() {
2466 let mut g = Graph::new("moe_ffn");
2467 let rows = 4usize;
2468 let n_embd = 16usize;
2469 let n_expert = 4usize;
2470 let n_ff = 16usize;
2471
2472 let h_in = g.input("h", f32_shape(&[1, rows, n_embd]));
2473 let h_2d = g.reshape_(h_in, vec![rows as i64, n_embd as i64]);
2474
2475 let router_w = g.param("router_w", f32_shape(&[n_embd, n_expert]));
2476 let router_logits = g.matmul(h_2d, router_w, f32_shape(&[rows, n_expert]));
2477
2478 let shared_router_w = g.param("shared_router_w", f32_shape(&[n_embd, 1]));
2480 let shared_logits = g.matmul(h_2d, shared_router_w, f32_shape(&[rows, 1]));
2481 let shared_gate = g.activation(Activation::Sigmoid, shared_logits, f32_shape(&[rows, 1]));
2482
2483 let s_gate_w = g.param("s_gate_w", f32_shape(&[n_embd, n_ff]));
2484 let s_up_w = g.param("s_up_w", f32_shape(&[n_embd, n_ff]));
2485 let s_gate = g.matmul(h_2d, s_gate_w, f32_shape(&[rows, n_ff]));
2486 let s_up = g.matmul(h_2d, s_up_w, f32_shape(&[rows, n_ff]));
2487 let s_gate_silu = g.silu(s_gate);
2488 let s_swiglu = g.mul(s_gate_silu, s_up);
2489
2490 g.set_outputs(vec![router_logits, shared_gate, s_swiglu]);
2491
2492 let fused = FuseSharedInputMatMul.run(g);
2493 let narrow_count = fused
2494 .nodes()
2495 .iter()
2496 .filter(|n| matches!(n.op, Op::Narrow { .. }))
2497 .count();
2498 assert!(
2499 narrow_count >= 4,
2500 "expected four narrow slices from fused h_2d matmuls, got {narrow_count}"
2501 );
2502 }
2503
2504 #[test]
2506 fn full_bert_ffn_fusion() {
2507 let mut g = Graph::new("bert_ffn");
2508 let f = DType::F32;
2509
2510 let x = g.input("hidden", Shape::new(&[4, 15, 384], f));
2511 let residual = g.input("residual", Shape::new(&[4, 15, 384], f));
2512
2513 let out_w = g.param("out.w", Shape::new(&[384, 384], f));
2515 let out_b = g.param("out.b", Shape::new(&[384], f));
2516 let out_mm = g.matmul(x, out_w, Shape::new(&[4, 15, 384], f));
2517 let out_add = g.binary(BinaryOp::Add, out_mm, out_b, Shape::new(&[4, 15, 384], f));
2518 let res_add = g.binary(
2519 BinaryOp::Add,
2520 out_add,
2521 residual,
2522 Shape::new(&[4, 15, 384], f),
2523 );
2524 let gamma = g.param("ln.g", Shape::new(&[384], f));
2525 let beta = g.param("ln.b", Shape::new(&[384], f));
2526 let ln = g.layer_norm(
2527 res_add,
2528 gamma,
2529 beta,
2530 -1,
2531 1e-12,
2532 Shape::new(&[4, 15, 384], f),
2533 );
2534
2535 let int_w = g.param("int.w", Shape::new(&[384, 1536], f));
2537 let int_b = g.param("int.b", Shape::new(&[1536], f));
2538 let int_mm = g.matmul(ln, int_w, Shape::new(&[4, 15, 1536], f));
2539 let int_add = g.binary(BinaryOp::Add, int_mm, int_b, Shape::new(&[4, 15, 1536], f));
2540 let gelu = g.activation(Activation::Gelu, int_add, Shape::new(&[4, 15, 1536], f));
2541
2542 let out2_w = g.param("out2.w", Shape::new(&[1536, 384], f));
2544 let out2_b = g.param("out2.b", Shape::new(&[384], f));
2545 let out2_mm = g.matmul(gelu, out2_w, Shape::new(&[4, 15, 384], f));
2546 let out2_add = g.binary(BinaryOp::Add, out2_mm, out2_b, Shape::new(&[4, 15, 384], f));
2547
2548 g.set_outputs(vec![out2_add]);
2549
2550 let before = g.len();
2551 println!("=== BEFORE fusion ({before} nodes) ===\n{g}");
2552
2553 let passes: Vec<&dyn Pass> = vec![&FuseMatMulBiasAct, &FuseResidualLN];
2555 let optimized = run_passes(g, &passes, false);
2556 let after = optimized.len();
2557 println!("=== AFTER fusion ({after} nodes) ===\n{optimized}");
2558
2559 assert!(
2563 after < before,
2564 "fusion should reduce node count: {before} → {after}"
2565 );
2566
2567 let ops: Vec<String> = optimized
2569 .nodes()
2570 .iter()
2571 .map(|n| format!("{}", n.op))
2572 .collect();
2573 let has_fused_mm = ops.iter().any(|s| s.contains("fused_mm_bias"));
2574 assert!(has_fused_mm, "should have fused_mm_bias_act: {ops:?}");
2575 }
2576
2577 #[test]
2580 fn fuse_swiglu_canonical() {
2581 let mut g = Graph::new("nomic_ffn");
2582 let f = DType::F32;
2583 let cat = g.input("cat", Shape::new(&[60, 4096], f));
2585 let up = g.add_node(
2586 Op::Narrow {
2587 axis: 1,
2588 start: 0,
2589 len: 2048,
2590 },
2591 vec![cat],
2592 Shape::new(&[60, 2048], f),
2593 );
2594 let gate = g.add_node(
2595 Op::Narrow {
2596 axis: 1,
2597 start: 2048,
2598 len: 2048,
2599 },
2600 vec![cat],
2601 Shape::new(&[60, 2048], f),
2602 );
2603 let silu = g.activation(Activation::Silu, gate, Shape::new(&[60, 2048], f));
2604 let out = g.binary(BinaryOp::Mul, up, silu, Shape::new(&[60, 2048], f));
2605 g.set_outputs(vec![out]);
2606
2607 let before = g.len();
2608 let fused = FuseSwiGLU.run(g);
2609 let after = fused.len();
2610 assert_eq!(
2613 after,
2614 before - 3,
2615 "should remove narrows+silu+mul, add FusedSwiGLU"
2616 );
2617 let out_node = fused.node(fused.outputs[0]);
2618 assert!(
2619 matches!(
2620 out_node.op,
2621 Op::FusedSwiGLU {
2622 cast_to: None,
2623 gate_first: false
2624 }
2625 ),
2626 "output should be FusedSwiGLU, got {}",
2627 out_node.op
2628 );
2629 let in_id = out_node.inputs[0];
2631 assert!(matches!(fused.node(in_id).op, Op::Input { .. }));
2632 }
2633
2634 #[test]
2637 fn fuse_swiglu_skips_when_narrow_has_extra_user() {
2638 let mut g = Graph::new("contended");
2639 let f = DType::F32;
2640 let cat = g.input("cat", Shape::new(&[60, 4096], f));
2641 let up = g.add_node(
2642 Op::Narrow {
2643 axis: 1,
2644 start: 0,
2645 len: 2048,
2646 },
2647 vec![cat],
2648 Shape::new(&[60, 2048], f),
2649 );
2650 let gate = g.add_node(
2651 Op::Narrow {
2652 axis: 1,
2653 start: 2048,
2654 len: 2048,
2655 },
2656 vec![cat],
2657 Shape::new(&[60, 2048], f),
2658 );
2659 let silu = g.activation(Activation::Silu, gate, Shape::new(&[60, 2048], f));
2660 let out = g.binary(BinaryOp::Mul, up, silu, Shape::new(&[60, 2048], f));
2661 let extra = g.activation(Activation::Relu, up, Shape::new(&[60, 2048], f));
2663 g.set_outputs(vec![out, extra]);
2664
2665 let before = g.len();
2666 let fused = FuseSwiGLU.run(g);
2667 assert_eq!(fused.len(), before);
2669 let any_fused = fused
2671 .nodes()
2672 .iter()
2673 .any(|n| matches!(n.op, Op::FusedSwiGLU { .. }));
2674 assert!(!any_fused, "should not fuse when narrow has extra user");
2675 }
2676
/// `MarkElementwiseRegions` should fold an `Add -> Mul -> Relu` chain over
/// three inputs into a single `ElementwiseRegion` whose chain records the
/// three steps in order.
#[test]
fn region_collapses_add_mul_relu_chain() {
    let f = DType::F32;
    let mut g = Graph::new("ew");
    let a = g.input("a", Shape::new(&[8], f));
    let b = g.input("b", Shape::new(&[8], f));
    let c = g.input("c", Shape::new(&[8], f));
    let s = Shape::new(&[8], f);
    let add = g.binary(BinaryOp::Add, a, b, s.clone());
    let mul = g.binary(BinaryOp::Mul, add, c, s.clone());
    let relu = g.activation(Activation::Relu, mul, s);
    g.set_outputs(vec![relu]);

    let before = g.len();
    let fused = MarkElementwiseRegions.run(g);

    let regions: Vec<_> = fused
        .nodes()
        .iter()
        .filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
        .collect();
    assert_eq!(regions.len(), 1, "expected one ElementwiseRegion");
    let region = regions[0];
    assert_eq!(
        region.inputs.len(),
        3,
        "region has 3 external inputs (a, b, c)"
    );

    // `regions` was filtered on this variant, so the else arm cannot be hit.
    let Op::ElementwiseRegion {
        chain, num_inputs, ..
    } = &region.op
    else {
        unreachable!();
    };
    assert_eq!(*num_inputs, 3);
    assert_eq!(chain.len(), 3);

    // Step 0: Add over the first two external inputs.
    assert!(
        matches!(
            chain[0],
            ChainStep::Binary(
                BinaryOp::Add,
                ChainOperand::Input(0),
                ChainOperand::Input(1),
            )
        ),
        "step 0 unexpected: {:?}",
        chain[0]
    );
    // Step 1: Mul of step 0's result with the third external input.
    assert!(
        matches!(
            chain[1],
            ChainStep::Binary(BinaryOp::Mul, ChainOperand::Step(0), ChainOperand::Input(2))
        ),
        "step 1 unexpected: {:?}",
        chain[1]
    );
    // Step 2: Relu applied to step 1's result.
    assert!(
        matches!(
            chain[2],
            ChainStep::Activation(Activation::Relu, ChainOperand::Step(1))
        ),
        "step 2 unexpected: {:?}",
        chain[2]
    );

    // Collapsing three ops into one region must shrink the graph.
    assert!(fused.len() < before);
}
2743
/// When the `Add` result feeds two different activations, the pass is
/// expected to emit no `ElementwiseRegion` and leave the node count unchanged.
#[test]
fn region_does_not_fuse_when_intermediate_has_multiple_consumers() {
    let dtype = DType::F32;
    let mut graph = Graph::new("ew");
    let lhs = graph.input("a", Shape::new(&[4], dtype));
    let rhs = graph.input("b", Shape::new(&[4], dtype));
    let shape = Shape::new(&[4], dtype);
    let sum = graph.binary(BinaryOp::Add, lhs, rhs, shape.clone());
    // Two consumers of `sum`, both exposed as graph outputs.
    let relu_branch = graph.activation(Activation::Relu, sum, shape.clone());
    let sigmoid_branch = graph.activation(Activation::Sigmoid, sum, shape.clone());
    graph.set_outputs(vec![relu_branch, sigmoid_branch]);

    let nodes_before = graph.len();
    let rewritten = MarkElementwiseRegions.run(graph);

    let region_count = rewritten
        .nodes()
        .iter()
        .filter(|node| matches!(node.op, Op::ElementwiseRegion { .. }))
        .count();
    assert_eq!(region_count, 0);
    assert_eq!(rewritten.len(), nodes_before);
}
2771
/// A lone `Relu` on an input is a chain of length one; the pass should not
/// wrap it in an `ElementwiseRegion`.
#[test]
fn region_skips_chains_of_length_one() {
    let dtype = DType::F32;
    let mut graph = Graph::new("ew");
    let input = graph.input("a", Shape::new(&[4], dtype));
    let relu = graph.activation(Activation::Relu, input, Shape::new(&[4], dtype));
    graph.set_outputs(vec![relu]);

    let rewritten = MarkElementwiseRegions.run(graph);

    let mut any_region = false;
    for node in rewritten.nodes().iter() {
        if matches!(node.op, Op::ElementwiseRegion { .. }) {
            any_region = true;
            break;
        }
    }
    assert!(!any_region);
}
2788
/// Round trip: fuse `Add -> Mul -> Relu` into a region, then run
/// `UnfuseElementwiseRegions::FOR_CPU` and check the atomic ops are back.
#[test]
fn unfuse_decomposes_region_back_to_atomic_ops() {
    // Counts the nodes of a graph whose op matches the given pattern.
    macro_rules! count_ops {
        ($graph:expr, $pat:pat) => {
            $graph
                .nodes()
                .iter()
                .filter(|n| matches!(n.op, $pat))
                .count()
        };
    }

    let dtype = DType::F32;
    let mut g = Graph::new("ew_unfuse");
    let in_a = g.input("a", Shape::new(&[8], dtype));
    let in_b = g.input("b", Shape::new(&[8], dtype));
    let in_c = g.input("c", Shape::new(&[8], dtype));
    let shape = Shape::new(&[8], dtype);
    let sum = g.binary(BinaryOp::Add, in_a, in_b, shape.clone());
    let scaled = g.binary(BinaryOp::Mul, sum, in_c, shape.clone());
    let activated = g.activation(Activation::Relu, scaled, shape);
    g.set_outputs(vec![activated]);

    // Fusing must produce at least one region...
    let fused = MarkElementwiseRegions.run(g);
    assert!(
        fused
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
    );

    // ...and unfusing must remove every one of them.
    let unfused = UnfuseElementwiseRegions::FOR_CPU.run(fused);
    assert!(
        !unfused
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
    );

    let bin_count = count_ops!(unfused, Op::Binary(_));
    let act_count = count_ops!(unfused, Op::Activation(_));
    assert_eq!(bin_count, 2, "Add + Mul restored");
    assert_eq!(act_count, 1, "Relu restored");
}
2835
/// Hand-builds a 40-step `ElementwiseRegion` and checks that
/// `clip_elementwise_regions` with `FusionLimits::GPU_NATIVE` decomposes it.
/// (Presumably 40 steps exceeds that limit's step cap; the cap itself is
/// defined elsewhere.)
#[test]
fn clip_unfuses_region_over_step_cap() {
    use rlx_ir::op::{Activation, ChainOperand, ChainStep};

    // Linear chain of Relu steps: step 0 reads input 0, every later step
    // reads the step immediately before it.
    let chain: Vec<ChainStep> = (0..40u32)
        .map(|step| {
            let operand = match step {
                0 => ChainOperand::Input(0),
                later => ChainOperand::Step(later - 1),
            };
            ChainStep::Activation(Activation::Relu, operand)
        })
        .collect();

    let mut graph = Graph::new("clip");
    let x = graph.input("x", f32_shape(&[4]));
    let region = graph.add_node(
        Op::ElementwiseRegion {
            chain,
            num_inputs: 1,
            scalar_input_mask: 0,
            input_modulus: [0; 16],
            prologue: RegionPrologue::None,
            prologue_input: 0,
        },
        vec![x],
        f32_shape(&[4]),
    );
    graph.set_outputs(vec![region]);

    let clipped = clip_elementwise_regions(graph, FusionLimits::GPU_NATIVE);

    let region_survived = clipped
        .nodes()
        .iter()
        .any(|node| matches!(node.op, Op::ElementwiseRegion { .. }));
    assert!(!region_survived, "oversized region should be decomposed");
    assert!(clipped.len() > 5);
}
2872
/// Running the unfuse pass on a graph without any region must keep the
/// node count unchanged.
#[test]
fn unfuse_is_noop_when_no_region_present() {
    let dtype = DType::F32;
    let mut graph = Graph::new("noop");
    let input = graph.input("a", Shape::new(&[4], dtype));
    let relu = graph.activation(Activation::Relu, input, Shape::new(&[4], dtype));
    graph.set_outputs(vec![relu]);

    let nodes_before = graph.len();
    let after = UnfuseElementwiseRegions::FOR_CPU.run(graph);
    assert_eq!(after.len(), nodes_before);
}
2885
/// A `Compare -> Where -> Add` chain should be marked as one
/// `ElementwiseRegion` with three steps, the middle one being a
/// `ChainStep::Where`.
#[test]
fn region_includes_where_step() {
    let f = DType::F32;
    let mut g = Graph::new("region_where");
    let a = g.input("a", Shape::new(&[4], f));
    let b = g.input("b", Shape::new(&[4], f));
    let s = Shape::new(&[4], f);
    let cmp = g.add_node(Op::Compare(CmpOp::Gt), vec![a, b], s.clone());
    let sel = g.add_node(Op::Where, vec![cmp, a, b], s.clone());
    let add = g.binary(BinaryOp::Add, sel, a, s);
    g.set_outputs(vec![add]);

    let fused = MarkElementwiseRegions.run(g);
    let regions: Vec<_> = fused
        .nodes()
        .iter()
        .filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
        .collect();
    assert_eq!(regions.len(), 1, "expected one ElementwiseRegion");

    // `regions` was filtered on this variant, so the else arm cannot be hit.
    let Op::ElementwiseRegion { chain, .. } = &regions[0].op else {
        unreachable!();
    };
    assert_eq!(chain.len(), 3);
    assert!(
        matches!(chain[1], ChainStep::Where(_, _, _)),
        "step 1 should be Where, got {:?}",
        chain[1]
    );
}
2921
/// Fuses a `Compare -> Where -> Add` chain and then unfuses it for CPU;
/// exactly one `Op::Where` node must be present afterwards.
#[test]
fn unfuse_decomposes_where_step_back_to_op_where() {
    let dtype = DType::F32;
    let mut graph = Graph::new("unfuse_where");
    let lhs = graph.input("a", Shape::new(&[4], dtype));
    let rhs = graph.input("b", Shape::new(&[4], dtype));
    let shape = Shape::new(&[4], dtype);
    let is_greater = graph.add_node(Op::Compare(CmpOp::Gt), vec![lhs, rhs], shape.clone());
    let selected = graph.add_node(Op::Where, vec![is_greater, lhs, rhs], shape.clone());
    let sum = graph.binary(BinaryOp::Add, selected, lhs, shape.clone());
    graph.set_outputs(vec![sum]);

    let round_tripped =
        UnfuseElementwiseRegions::FOR_CPU.run(MarkElementwiseRegions.run(graph));

    // Tally the `Where` nodes by hand.
    let where_count = round_tripped
        .nodes()
        .iter()
        .fold(0usize, |seen, node| {
            seen + usize::from(matches!(node.op, Op::Where))
        });
    assert_eq!(
        where_count, 1,
        "decomposer should re-emit one Op::Where for the chain step"
    );
}
2947
/// Builds QKV projection -> three narrows -> attention -> output projection,
/// then checks that `FuseMatMulBiasAct` followed by `FuseAttentionBlock`
/// collapses the whole thing into a single `FusedAttentionBlock`.
#[test]
fn fuse_attention_block_collapses_qkv_attn_outproj() {
    // Counts the nodes of a graph whose op matches the given pattern.
    macro_rules! count_ops {
        ($graph:expr, $pat:pat) => {
            $graph
                .nodes()
                .iter()
                .filter(|n| matches!(n.op, $pat))
                .count()
        };
    }

    let nh: usize = 4;
    let dh: usize = 8;
    let h: usize = nh * dh;
    let b: usize = 1;
    let s: usize = 4;
    // Shapes that recur throughout the graph.
    let hidden_shape = || f32_shape(&[b, s, h]);
    let packed_shape = || f32_shape(&[b, s, 3 * h]);

    let mut g = Graph::new("attn-block");
    let hidden = g.input("hidden", hidden_shape());
    let mask = g.input("attention_mask", f32_shape(&[b, s]));

    // Packed QKV projection: matmul followed by a bias add.
    let qkv_w = g.param("qkv_w", f32_shape(&[h, 3 * h]));
    let qkv_b = g.param("qkv_b", f32_shape(&[3 * h]));
    let qkv_mm = g.matmul(hidden, qkv_w, packed_shape());
    let qkv = g.binary(BinaryOp::Add, qkv_mm, qkv_b, packed_shape());

    // Slice an `h`-wide window out of the packed projection along axis 2.
    let slice_at = |graph: &mut Graph, start: usize| {
        graph.add_node(
            Op::Narrow {
                axis: 2,
                start,
                len: h,
            },
            vec![qkv],
            hidden_shape(),
        )
    };
    let q = slice_at(&mut g, 0);
    let k = slice_at(&mut g, h);
    let v = slice_at(&mut g, 2 * h);

    let attn = g.attention(q, k, v, mask, nh, dh, hidden_shape());

    // Output projection: matmul followed by a bias add.
    let out_w = g.param("out_w", f32_shape(&[h, h]));
    let out_b = g.param("out_b", f32_shape(&[h]));
    let out_mm = g.matmul(attn, out_w, hidden_shape());
    let out = g.binary(BinaryOp::Add, out_mm, out_b, hidden_shape());
    g.set_outputs(vec![out]);

    // Stage 1: both projections become bias-fused matmuls without activation.
    let after_mm = FuseMatMulBiasAct.run(g);
    let mm_bias_count = count_ops!(after_mm, Op::FusedMatMulBiasAct { activation: None });
    assert_eq!(mm_bias_count, 2, "QKV + OutProj should each fuse");

    // Stage 2: the attention-block pass absorbs everything into one node.
    let after_block = FuseAttentionBlock.run(after_mm);
    let fab_count = count_ops!(
        after_block,
        Op::FusedAttentionBlock {
            has_bias: true,
            has_rope: false,
            ..
        }
    );
    assert_eq!(
        fab_count, 1,
        "should produce exactly one FusedAttentionBlock"
    );

    let narrow_count = count_ops!(after_block, Op::Narrow { .. });
    let attention_count = count_ops!(after_block, Op::Attention { .. });
    let mm_bias_remaining = count_ops!(after_block, Op::FusedMatMulBiasAct { .. });
    assert_eq!(narrow_count, 0, "QKV narrows absorbed");
    assert_eq!(attention_count, 0, "Attention absorbed");
    assert_eq!(mm_bias_remaining, 0, "both projections absorbed");

    // The graph output must now be the fused block itself.
    let out_node = after_block.node(after_block.outputs[0]);
    assert!(matches!(out_node.op, Op::FusedAttentionBlock { .. }));
}
3063
/// Builds one BERT-style encoder layer (attention, residual + LayerNorm,
/// GELU feed-forward, residual + LayerNorm) and checks that the four fusion
/// passes, run in sequence, collapse it to a single `FusedTransformerLayer`.
#[test]
fn fuse_transformer_layer_collapses_full_bert_block() {
    // Counts the nodes of a graph whose op matches the given pattern.
    macro_rules! count_ops {
        ($graph:expr, $pat:pat) => {
            $graph
                .nodes()
                .iter()
                .filter(|n| matches!(n.op, $pat))
                .count()
        };
    }

    let nh: usize = 4;
    let dh: usize = 8;
    let h: usize = nh * dh;
    let inter = 4 * h;
    let eps1: f32 = 1e-12;
    let eps2: f32 = 1e-12;
    let b: usize = 1;
    let s: usize = 4;
    // Shapes that recur throughout the graph.
    let hidden_shape = || f32_shape(&[b, s, h]);
    let packed_shape = || f32_shape(&[b, s, 3 * h]);
    let inter_shape = || f32_shape(&[b, s, inter]);

    let mut graph = Graph::new("bert-layer");
    let hidden = graph.input("hidden", hidden_shape());
    let mask = graph.input("attention_mask", f32_shape(&[b, s]));

    // --- Self-attention: packed QKV projection, narrows, attention, out-proj.
    let qkv_w = graph.param("qkv_w", f32_shape(&[h, 3 * h]));
    let qkv_b = graph.param("qkv_b", f32_shape(&[3 * h]));
    let qkv_mm = graph.matmul(hidden, qkv_w, packed_shape());
    let qkv = graph.binary(BinaryOp::Add, qkv_mm, qkv_b, packed_shape());
    // Slice an `h`-wide window out of the packed projection along axis 2.
    let slice_at = |gr: &mut Graph, start: usize| {
        gr.add_node(
            Op::Narrow {
                axis: 2,
                start,
                len: h,
            },
            vec![qkv],
            hidden_shape(),
        )
    };
    let q = slice_at(&mut graph, 0);
    let k = slice_at(&mut graph, h);
    let v = slice_at(&mut graph, 2 * h);
    let attn = graph.attention(q, k, v, mask, nh, dh, hidden_shape());
    let out_w = graph.param("out_w", f32_shape(&[h, h]));
    let out_b = graph.param("out_b", f32_shape(&[h]));
    let out_mm = graph.matmul(attn, out_w, hidden_shape());
    let attn_out = graph.binary(BinaryOp::Add, out_mm, out_b, hidden_shape());

    // --- First residual connection and LayerNorm.
    let res1 = graph.binary(BinaryOp::Add, attn_out, hidden, hidden_shape());
    let ln1_g = graph.param("ln1_g", f32_shape(&[h]));
    let ln1_b = graph.param("ln1_b", f32_shape(&[h]));
    let h1 = graph.add_node(
        Op::LayerNorm {
            axis: -1,
            eps: eps1,
        },
        vec![res1, ln1_g, ln1_b],
        hidden_shape(),
    );

    // --- Feed-forward: fc1 + bias + GELU, then fc2 + bias.
    let fc1_w = graph.param("fc1_w", f32_shape(&[h, inter]));
    let fc1_b = graph.param("fc1_b", f32_shape(&[inter]));
    let fc1_mm = graph.matmul(h1, fc1_w, inter_shape());
    let fc1_add = graph.binary(BinaryOp::Add, fc1_mm, fc1_b, inter_shape());
    let fc1_act = graph.activation(Activation::Gelu, fc1_add, inter_shape());
    let fc2_w = graph.param("fc2_w", f32_shape(&[inter, h]));
    let fc2_b = graph.param("fc2_b", f32_shape(&[h]));
    let fc2_mm = graph.matmul(fc1_act, fc2_w, hidden_shape());
    let ffn_out = graph.binary(BinaryOp::Add, fc2_mm, fc2_b, hidden_shape());

    // --- Second residual connection and LayerNorm.
    let res2 = graph.binary(BinaryOp::Add, ffn_out, h1, hidden_shape());
    let ln2_g = graph.param("ln2_g", f32_shape(&[h]));
    let ln2_b = graph.param("ln2_b", f32_shape(&[h]));
    let out = graph.add_node(
        Op::LayerNorm {
            axis: -1,
            eps: eps2,
        },
        vec![res2, ln2_g, ln2_b],
        hidden_shape(),
    );
    graph.set_outputs(vec![out]);

    // Run the fusion passes in the same order as the original pipeline.
    let after_mm = FuseMatMulBiasAct.run(graph);
    let after_ln = FuseResidualLN.run(after_mm);
    let after_attn = FuseAttentionBlock.run(after_ln);
    let g = FuseTransformerLayer.run(after_attn);

    let ftl_count = count_ops!(g, Op::FusedTransformerLayer { .. });
    assert_eq!(
        ftl_count, 1,
        "single layer should collapse to one FusedTransformerLayer"
    );

    // None of the intermediate fused ops may survive the final pass.
    let leftover_fab = count_ops!(g, Op::FusedAttentionBlock { .. });
    let leftover_frln = count_ops!(g, Op::FusedResidualLN { .. });
    let leftover_fmba = count_ops!(g, Op::FusedMatMulBiasAct { .. });
    assert_eq!(leftover_fab, 0, "attn block absorbed into layer");
    assert_eq!(leftover_frln, 0, "both residual+LNs absorbed");
    assert_eq!(leftover_fmba, 0, "FFN matmuls absorbed");

    // The output node carries the layer hyper-parameters (intermediate size
    // is 4 * h = 128) and takes 14 inputs.
    let out_node = g.node(g.outputs[0]);
    assert!(matches!(
        out_node.op,
        Op::FusedTransformerLayer {
            num_heads: 4,
            head_dim: 8,
            intermediate_size: 128,
            has_bias: true,
            ..
        }
    ));
    assert_eq!(out_node.inputs.len(), 14);
}
3209
/// Same attention sub-graph as the collapsing test, but with batch 16 and
/// sequence length 128: `FuseAttentionBlock` is expected to emit no
/// `FusedAttentionBlock` for an input of this size.
#[test]
fn fuse_attention_block_skips_large_inputs() {
    let nh: usize = 4;
    let dh: usize = 8;
    let h: usize = nh * dh;
    let b: usize = 16;
    let s: usize = 128;
    // Shapes that recur throughout the graph.
    let hidden_shape = || f32_shape(&[b, s, h]);
    let packed_shape = || f32_shape(&[b, s, 3 * h]);

    let mut graph = Graph::new("attn-block-large");
    let hidden = graph.input("hidden", hidden_shape());
    let mask = graph.input("attention_mask", f32_shape(&[b, s]));

    // Packed QKV projection: matmul followed by a bias add.
    let qkv_w = graph.param("qkv_w", f32_shape(&[h, 3 * h]));
    let qkv_b = graph.param("qkv_b", f32_shape(&[3 * h]));
    let qkv_mm = graph.matmul(hidden, qkv_w, packed_shape());
    let qkv = graph.binary(BinaryOp::Add, qkv_mm, qkv_b, packed_shape());

    // Slice an `h`-wide window out of the packed projection along axis 2.
    let slice_at = |gr: &mut Graph, start: usize| {
        gr.add_node(
            Op::Narrow {
                axis: 2,
                start,
                len: h,
            },
            vec![qkv],
            hidden_shape(),
        )
    };
    let q = slice_at(&mut graph, 0);
    let k = slice_at(&mut graph, h);
    let v = slice_at(&mut graph, 2 * h);

    let attn = graph.attention(q, k, v, mask, nh, dh, hidden_shape());

    // Output projection: matmul followed by a bias add.
    let out_w = graph.param("out_w", f32_shape(&[h, h]));
    let out_b = graph.param("out_b", f32_shape(&[h]));
    let out_mm = graph.matmul(attn, out_w, hidden_shape());
    let out = graph.binary(BinaryOp::Add, out_mm, out_b, hidden_shape());
    graph.set_outputs(vec![out]);

    let after_block = FuseAttentionBlock.run(FuseMatMulBiasAct.run(graph));

    let fab_count = after_block
        .nodes()
        .iter()
        .filter(|node| matches!(node.op, Op::FusedAttentionBlock { .. }))
        .count();
    assert_eq!(fab_count, 0, "block-fusion must skip large batches");
}
3270}