1use crate::pass::Pass;
23use rlx_ir::op::*;
24use rlx_ir::*;
25use std::collections::HashMap;
26
27struct Rewriter {
31 new_graph: Graph,
32 id_map: HashMap<NodeId, NodeId>,
33}
34
35impl Rewriter {
36 fn new(name: &str) -> Self {
37 Self {
38 new_graph: Graph::new(name),
39 id_map: HashMap::new(),
40 }
41 }
42
43 fn map(&self, old: NodeId) -> NodeId {
45 self.id_map[&old]
46 }
47
48 fn map_inputs(&self, old_inputs: &[NodeId]) -> Vec<NodeId> {
50 old_inputs.iter().map(|id| self.map(*id)).collect()
51 }
52
53 #[allow(dead_code)]
56 fn all_mapped(&self, ids: &[NodeId]) -> bool {
57 ids.iter().all(|id| self.id_map.contains_key(id))
58 }
59
60 fn ensure_mapped(&mut self, old: &Graph, ids: &[NodeId]) {
65 for &id in ids {
66 if self.id_map.contains_key(&id) {
67 continue;
68 }
69 let node = old.node(id);
70 if !node.inputs.is_empty() {
71 self.ensure_mapped(old, &node.inputs);
72 }
73 self.copy_node(node);
74 }
75 }
76
77 fn copy_node(&mut self, node: &Node) -> NodeId {
79 let new_inputs = self.map_inputs(&node.inputs);
80 let new_id = self
81 .new_graph
82 .add_node(node.op.clone(), new_inputs, node.shape.clone());
83 let new_node = self.new_graph.node_mut(new_id);
84 new_node.name = node.name.clone();
85 new_node.origin = node.origin.clone();
86 self.id_map.insert(node.id, new_id);
87 new_id
88 }
89
90 fn add_fused(&mut self, op: Op, old_inputs: &[NodeId], shape: Shape) -> NodeId {
92 let new_inputs: Vec<NodeId> = old_inputs.iter().map(|id| self.map(*id)).collect();
93 self.new_graph.add_node(op, new_inputs, shape)
94 }
95
96 fn replace(&mut self, old_id: NodeId, new_id: NodeId) {
98 self.id_map.insert(old_id, new_id);
99 }
100
101 fn finish(mut self, old_outputs: &[NodeId]) -> Graph {
102 let new_outputs = old_outputs.iter().map(|id| self.map(*id)).collect();
103 self.new_graph.set_outputs(new_outputs);
104 self.new_graph
105 }
106}
107
/// Pass `fuse_matmul_bias_act`: folds a `MatMul` followed by a rank-0/1 bias
/// `Add` (and optionally a GELU or SiLU activation on the sum) into a single
/// `FusedMatMulBiasAct` node.
pub struct FuseMatMulBiasAct;
123
124fn fusible_mm_bias_epilogue_activation(act: Activation) -> bool {
126 matches!(act, Activation::Gelu | Activation::Silu)
127}
128
129impl Pass for FuseMatMulBiasAct {
130 fn name(&self) -> &str {
131 "fuse_matmul_bias_act"
132 }
133
134 fn run(&self, graph: Graph) -> Graph {
135 let mut rw = Rewriter::new(&graph.name);
136 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
138
139 for node in graph.nodes() {
141 if fused_away.contains_key(&node.id) {
142 continue;
143 }
144
145 if matches!(node.op, Op::MatMul) {
148 let mm_id = node.id;
149 let mm_users: Vec<_> = graph.users(mm_id);
150
151 if mm_users.len() == 1 {
153 let add_node = graph.node(mm_users[0]);
154 if let Op::Binary(BinaryOp::Add) = &add_node.op {
155 let (bias_id, _mm_input) = if add_node.inputs[0] == mm_id {
157 (add_node.inputs[1], add_node.inputs[0])
158 } else {
159 (add_node.inputs[0], add_node.inputs[1])
160 };
161
162 let bias_shape = graph.shape(bias_id);
164 if bias_shape.rank() <= 1 {
165 let add_id = add_node.id;
166 let add_users = graph.users(add_id);
167
168 let mut activation = None;
170 let mut act_id = None;
171 if add_users.len() == 1 {
172 let act_node = graph.node(add_users[0]);
173 if let Op::Activation(a) = &act_node.op
174 && fusible_mm_bias_epilogue_activation(*a)
175 {
176 activation = Some(*a);
177 act_id = Some(act_node.id);
178 }
179 }
180
181 let out_shape = if let Some(aid) = act_id {
185 graph.shape(aid).clone()
186 } else {
187 add_node.shape.clone()
188 };
189
190 rw.ensure_mapped(&graph, &[node.inputs[0], node.inputs[1], bias_id]);
191 let fused_id = rw.add_fused(
192 Op::FusedMatMulBiasAct { activation },
193 &[node.inputs[0], node.inputs[1], bias_id],
194 out_shape,
195 );
196
197 rw.replace(mm_id, fused_id);
199 rw.replace(add_id, fused_id);
200 fused_away.insert(add_id, ());
201 if let Some(aid) = act_id {
202 rw.replace(aid, fused_id);
203 fused_away.insert(aid, ());
204 }
205 continue;
206 }
207 }
208 }
209 }
210
211 rw.copy_node(node);
213 }
214
215 rw.finish(&graph.outputs)
216 }
217}
218
/// Pass `fuse_residual_ln`: folds a single-use `Add` feeding a `LayerNorm`
/// into one `FusedResidualLN` node over `[x, residual, gamma, beta]`.
pub struct FuseResidualLN;
226
227impl Pass for FuseResidualLN {
228 fn name(&self) -> &str {
229 "fuse_residual_ln"
230 }
231
232 fn run(&self, graph: Graph) -> Graph {
233 let mut is_output: HashMap<NodeId, ()> = HashMap::new();
241 for &oid in &graph.outputs {
242 is_output.insert(oid, ());
243 }
244 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
246 for node in graph.nodes() {
247 if let Op::LayerNorm { .. } = &node.op {
248 let ln_input_id = node.inputs[0];
249 let ln_input = graph.node(ln_input_id);
250 if matches!(ln_input.op, Op::Binary(BinaryOp::Add))
251 && graph.use_count(ln_input_id) == 1
252 && !is_output.contains_key(&ln_input_id)
253 {
254 fused_away.insert(ln_input_id, ());
255 }
256 }
257 }
258
259 let mut rw = Rewriter::new(&graph.name);
260
261 for node in graph.nodes() {
262 if fused_away.contains_key(&node.id) {
263 continue;
264 }
265
266 if let Op::LayerNorm { eps, .. } = &node.op {
267 let ln_input_id = node.inputs[0];
268 let ln_input = graph.node(ln_input_id);
269
270 if matches!(ln_input.op, Op::Binary(BinaryOp::Add))
271 && fused_away.contains_key(&ln_input_id)
272 {
273 let (x_id, residual_id) = (ln_input.inputs[0], ln_input.inputs[1]);
274 let gamma_id = node.inputs[1];
275 let beta_id = node.inputs[2];
276
277 let fused_id = rw.add_fused(
278 Op::FusedResidualLN {
279 has_bias: false,
280 eps: *eps,
281 },
282 &[x_id, residual_id, gamma_id, beta_id],
283 node.shape.clone(),
284 );
285
286 rw.replace(ln_input_id, fused_id);
287 rw.replace(node.id, fused_id);
288 continue;
289 }
290 }
291
292 rw.copy_node(node);
293 }
294
295 rw.finish(&graph.outputs)
296 }
297}
298
/// Pass `fuse_residual_rms_norm`: folds a single-use `Add` feeding an `RmsNorm`
/// into one `FusedResidualRmsNorm` node over `[x, residual, gamma, beta]`.
pub struct FuseResidualRmsNorm;
303
304impl Pass for FuseResidualRmsNorm {
305 fn name(&self) -> &str {
306 "fuse_residual_rms_norm"
307 }
308
309 fn run(&self, graph: Graph) -> Graph {
310 let mut is_output: HashMap<NodeId, ()> = HashMap::new();
311 for &oid in &graph.outputs {
312 is_output.insert(oid, ());
313 }
314 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
315 for node in graph.nodes() {
316 if let Op::RmsNorm { .. } = &node.op {
317 let rn_input_id = node.inputs[0];
318 let rn_input = graph.node(rn_input_id);
319 if matches!(rn_input.op, Op::Binary(BinaryOp::Add))
320 && graph.use_count(rn_input_id) == 1
321 && !is_output.contains_key(&rn_input_id)
322 {
323 fused_away.insert(rn_input_id, ());
324 }
325 }
326 }
327
328 let mut rw = Rewriter::new(&graph.name);
329
330 for node in graph.nodes() {
331 if fused_away.contains_key(&node.id) {
332 continue;
333 }
334
335 if let Op::RmsNorm { eps, .. } = &node.op {
336 let rn_input_id = node.inputs[0];
337 let rn_input = graph.node(rn_input_id);
338
339 if matches!(rn_input.op, Op::Binary(BinaryOp::Add))
340 && fused_away.contains_key(&rn_input_id)
341 {
342 let (x_id, residual_id) = (rn_input.inputs[0], rn_input.inputs[1]);
343 let gamma_id = node.inputs[1];
344 let beta_id = node.inputs[2];
345
346 let fused_id = rw.add_fused(
347 Op::FusedResidualRmsNorm {
348 has_bias: false,
349 eps: *eps,
350 },
351 &[x_id, residual_id, gamma_id, beta_id],
352 node.shape.clone(),
353 );
354
355 rw.replace(rn_input_id, fused_id);
356 rw.replace(node.id, fused_id);
357 continue;
358 }
359 }
360
361 rw.copy_node(node);
362 }
363
364 rw.finish(&graph.outputs)
365 }
366}
367
/// Pass `fuse_rms_norm_reshape`: when a single-use `RmsNorm` feeds a `Reshape`
/// accepted by `leading_flatten_shape`, re-emits the `RmsNorm` with that
/// flattened shape and drops the `Reshape`.
pub struct FuseRmsNormReshape;
376
377fn leading_flatten_shape(in_shape: &Shape, new_shape: &[i64]) -> Option<Shape> {
378 rlx_ir::shape::leading_flatten_shape(in_shape, new_shape)
379}
380
381fn sole_consumer(graph: &Graph, id: NodeId) -> Option<NodeId> {
382 graph
383 .nodes()
384 .iter()
385 .find(|n| n.inputs.contains(&id))
386 .map(|n| n.id)
387}
388
389impl Pass for FuseRmsNormReshape {
390 fn name(&self) -> &str {
391 "fuse_rms_norm_reshape"
392 }
393
394 fn run(&self, graph: Graph) -> Graph {
395 let mut is_output: HashMap<NodeId, ()> = HashMap::new();
396 for &oid in &graph.outputs {
397 is_output.insert(oid, ());
398 }
399
400 let mut flat_shape: HashMap<NodeId, Shape> = HashMap::new();
401 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
402 for node in graph.nodes() {
403 if let Op::RmsNorm { .. } = &node.op {
404 if graph.use_count(node.id) != 1 || is_output.contains_key(&node.id) {
405 continue;
406 }
407 let Some(reshape_id) = sole_consumer(&graph, node.id) else {
408 continue;
409 };
410 if is_output.contains_key(&reshape_id) {
411 continue;
412 }
413 let reshape = graph.node(reshape_id);
414 if let Op::Reshape { new_shape } = &reshape.op {
415 if let Some(flat) = leading_flatten_shape(&node.shape, new_shape) {
416 flat_shape.insert(node.id, flat);
417 fused_away.insert(reshape_id, ());
418 }
419 }
420 }
421 }
422
423 let mut rw = Rewriter::new(&graph.name);
424
425 for node in graph.nodes() {
426 if fused_away.contains_key(&node.id) {
427 continue;
428 }
429
430 if let Op::RmsNorm { axis, eps, .. } = &node.op {
431 if let Some(flat) = flat_shape.get(&node.id) {
432 let Some(reshape_id) = sole_consumer(&graph, node.id) else {
433 rw.copy_node(node);
434 continue;
435 };
436 let fused_id = rw.add_fused(
437 Op::RmsNorm {
438 axis: *axis,
439 eps: *eps,
440 },
441 &node.inputs,
442 flat.clone(),
443 );
444 rw.replace(node.id, fused_id);
445 rw.replace(reshape_id, fused_id);
446 continue;
447 }
448 }
449
450 rw.copy_node(node);
451 }
452
453 rw.finish(&graph.outputs)
454 }
455}
456
/// Pass `fuse_swiglu_dual_matmul`: rewrites
/// `Mul(MatMul(x, W_up), Silu(MatMul(x, W_gate)))` into
/// `FusedSwiGLU(MatMul(x, Concat[W_up, W_gate]))` (up half first, `gate_first: false`).
pub struct FuseSwiGLUDualMatmul;
469
470impl FuseSwiGLUDualMatmul {
471 fn match_dual_swiglu(
472 graph: &Graph,
473 mul_node: &Node,
474 ) -> Option<(NodeId, NodeId, NodeId, NodeId, NodeId)> {
475 if !matches!(mul_node.op, Op::Binary(BinaryOp::Mul)) {
476 return None;
477 }
478 let lhs = graph.node(mul_node.inputs[0]);
479 let rhs = graph.node(mul_node.inputs[1]);
480 let (up_mm, silu_id, silu_node) = if matches!(rhs.op, Op::Activation(Activation::Silu)) {
481 (lhs, mul_node.inputs[1], rhs)
482 } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
483 (rhs, mul_node.inputs[0], lhs)
484 } else {
485 return None;
486 };
487 if !matches!(up_mm.op, Op::MatMul) {
488 return None;
489 }
490 let gate_mm = graph.node(silu_node.inputs[0]);
491 if !matches!(gate_mm.op, Op::MatMul) {
492 return None;
493 }
494 if up_mm.inputs[0] != gate_mm.inputs[0] {
495 return None;
496 }
497 if graph.use_count(silu_id) != 1 {
498 return None;
499 }
500 Some((mul_node.id, gate_mm.id, up_mm.id, up_mm.inputs[0], silu_id))
501 }
502}
503
504impl Pass for FuseSwiGLUDualMatmul {
505 fn name(&self) -> &str {
506 "fuse_swiglu_dual_matmul"
507 }
508
509 fn run(&self, graph: Graph) -> Graph {
510 let mut matches: Vec<(NodeId, NodeId, NodeId, NodeId, NodeId)> = Vec::new();
511 let mut consumed: HashMap<NodeId, ()> = HashMap::new();
512
513 for node in graph.nodes() {
514 if let Some((mul_id, gate_mm, up_mm, _, silu_id)) =
515 Self::match_dual_swiglu(&graph, node)
516 {
517 matches.push((mul_id, gate_mm, up_mm, graph.node(up_mm).inputs[0], silu_id));
518 consumed.insert(gate_mm, ());
519 consumed.insert(up_mm, ());
520 consumed.insert(silu_id, ());
521 }
522 }
523
524 if matches.is_empty() {
525 return graph;
526 }
527
528 let match_by_mul: HashMap<NodeId, (NodeId, NodeId, NodeId)> = matches
529 .into_iter()
530 .map(|(mul, gate, up, input, _silu)| (mul, (gate, up, input)))
531 .collect();
532
533 let mut rw = Rewriter::new(&graph.name);
534 for node in graph.nodes() {
535 if consumed.contains_key(&node.id) {
536 continue;
537 }
538 if let Some(&(gate_mm, up_mm, input_id)) = match_by_mul.get(&node.id) {
539 let gate = graph.node(gate_mm);
540 let up = graph.node(up_mm);
541 let wg = gate.inputs[1];
542 let wu = up.inputs[1];
543 rw.ensure_mapped(&graph, &[input_id, wg, wu]);
544
545 let wu_shape = graph.shape(wu);
546 let wg_shape = graph.shape(wg);
547 let k = wu_shape.dim(0).unwrap_static();
548 let n_up = wu_shape.dim(1).unwrap_static();
549 let n_gate = wg_shape.dim(1).unwrap_static();
550 debug_assert_eq!(wu_shape.dim(0), wg_shape.dim(0));
551
552 let concat_shape = Shape::new(&[k, n_up + n_gate], wu_shape.dtype());
554 let concat_w = rw.add_fused(Op::Concat { axis: 1 }, &[wu, wg], concat_shape);
555
556 let out_rank = up.shape.rank();
557 let mut mm_dims: Vec<Dim> = (0..out_rank).map(|i| up.shape.dim(i)).collect();
558 mm_dims[out_rank - 1] = Dim::Static(n_up + n_gate);
559 let cat_shape = Shape::from_dims(&mm_dims, up.shape.dtype());
560 let cat_id =
561 rw.new_graph
562 .add_node(Op::MatMul, vec![rw.map(input_id), concat_w], cat_shape);
563
564 let fused_id = rw.new_graph.add_node(
565 Op::FusedSwiGLU {
566 cast_to: None,
567 gate_first: false,
568 },
569 vec![cat_id],
570 node.shape.clone(),
571 );
572 rw.replace(node.id, fused_id);
573 continue;
574 }
575 rw.copy_node(node);
576 }
577 rw.finish(&graph.outputs)
578 }
579}
580
/// Pass `fuse_shared_input_matmul`: merges `MatMul`s that share their left
/// operand into one `MatMul` over column-concatenated weights, then recovers
/// each original result with a `Narrow` on the last axis.
pub struct FuseSharedInputMatMul;
597
598impl Pass for FuseSharedInputMatMul {
599 fn name(&self) -> &str {
600 "fuse_shared_input_matmul"
601 }
602
603 fn run(&self, graph: Graph) -> Graph {
604 struct FuseGroup {
605 input_id: NodeId,
606 matmul_ids: Vec<NodeId>,
607 }
608
609 let mut input_to_matmuls: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
610 for node in graph.nodes() {
611 if matches!(node.op, Op::MatMul) {
612 input_to_matmuls
613 .entry(node.inputs[0])
614 .or_default()
615 .push(node.id);
616 }
617 }
618
619 let mut groups: Vec<FuseGroup> = Vec::new();
620 for (input_id, matmul_ids) in input_to_matmuls {
621 if matmul_ids.len() < 2 {
622 continue;
623 }
624 let first = graph.node(matmul_ids[0]);
625 let w0 = graph.shape(first.inputs[1]);
626 if w0.rank() != 2 {
627 continue;
628 }
629 let compatible = matmul_ids.iter().all(|&id| {
630 let m = graph.node(id);
631 matches!(m.op, Op::MatMul)
632 && graph.shape(m.inputs[1]).rank() == 2
633 && graph.shape(m.inputs[1]).dim(0) == w0.dim(0)
634 });
635 if compatible {
636 groups.push(FuseGroup {
637 input_id,
638 matmul_ids,
639 });
640 }
641 }
642
643 if groups.is_empty() {
644 return graph;
645 }
646
647 let group_by_first: HashMap<NodeId, &FuseGroup> =
648 groups.iter().map(|g| (g.matmul_ids[0], g)).collect();
649
650 let mut fused_away: HashMap<NodeId, ()> = HashMap::new();
651 for g in &groups {
652 for &id in &g.matmul_ids[1..] {
653 fused_away.insert(id, ());
654 }
655 }
656
657 let mut rw = Rewriter::new(&graph.name);
658 for node in graph.nodes() {
659 if fused_away.contains_key(&node.id) {
660 continue;
661 }
662
663 if let Some(group) = group_by_first.get(&node.id) {
664 let matmuls: Vec<_> = group.matmul_ids.iter().map(|&id| graph.node(id)).collect();
665 let weight_ids: Vec<NodeId> = matmuls.iter().map(|m| m.inputs[1]).collect();
666 rw.ensure_mapped(&graph, std::slice::from_ref(&group.input_id));
667 rw.ensure_mapped(&graph, &weight_ids);
668
669 let w0_shape = graph.shape(weight_ids[0]);
670 let k = w0_shape.dim(0).unwrap_static();
671 let ns: Vec<usize> = weight_ids
672 .iter()
673 .map(|&w| graph.shape(w).dim(1).unwrap_static())
674 .collect();
675 let combined_n: usize = ns.iter().sum();
676
677 let concat_shape = Shape::new(&[k, combined_n], w0_shape.dtype());
678 let concat_id = rw.add_fused(Op::Concat { axis: 1 }, &weight_ids, concat_shape);
679
680 let out_rank = matmuls[0].shape.rank();
681 let mut mm_dims: Vec<Dim> =
682 (0..out_rank).map(|i| matmuls[0].shape.dim(i)).collect();
683 mm_dims[out_rank - 1] = Dim::Static(combined_n);
684 let mm_shape = Shape::from_dims(&mm_dims, matmuls[0].shape.dtype());
685 let mm_id = rw.new_graph.add_node(
686 Op::MatMul,
687 vec![rw.map(group.input_id), concat_id],
688 mm_shape,
689 );
690
691 let mut start = 0usize;
692 for (mm, &n) in matmuls.iter().zip(&ns) {
693 let narrow = rw.new_graph.add_node(
694 Op::Narrow {
695 axis: out_rank - 1,
696 start,
697 len: n,
698 },
699 vec![mm_id],
700 mm.shape.clone(),
701 );
702 rw.replace(mm.id, narrow);
703 start += n;
704 }
705 continue;
706 }
707
708 rw.copy_node(node);
709 }
710
711 rw.finish(&graph.outputs)
712 }
713}
714
/// Pass `fuse_swiglu`: rewrites `Mul(Narrow(cat), Silu(Narrow(cat)))`, where
/// the two narrows are complementary equal-length halves of the same `cat`,
/// into a single `FusedSwiGLU(cat)` node.
pub struct FuseSwiGLU;
736
737impl Pass for FuseSwiGLU {
738 fn name(&self) -> &str {
739 "fuse_swiglu"
740 }
741
742 fn run(&self, graph: Graph) -> Graph {
743 #[allow(dead_code)]
749 struct Match {
750 mul_id: NodeId,
751 up_narrow_id: NodeId,
752 silu_id: NodeId,
753 gate_narrow_id: NodeId,
754 cat_id: NodeId,
755 out_n: usize,
756 gate_first: bool,
757 }
758
759 let mut matches: Vec<Match> = Vec::new();
760 let mut consumed: HashMap<NodeId, ()> = HashMap::new();
761
762 for node in graph.nodes() {
763 if !matches!(node.op, Op::Binary(BinaryOp::Mul)) {
766 continue;
767 }
768 let lhs_id = node.inputs[0];
769 let rhs_id = node.inputs[1];
770 let lhs = graph.node(lhs_id);
771 let rhs = graph.node(rhs_id);
772
773 let (up_narrow, silu_id, silu_node) =
775 if matches!(rhs.op, Op::Activation(Activation::Silu)) {
776 (lhs, rhs_id, rhs)
777 } else if matches!(lhs.op, Op::Activation(Activation::Silu)) {
778 (rhs, lhs_id, lhs)
779 } else {
780 continue;
781 };
782
783 let (up_axis, up_start, up_len) = match &up_narrow.op {
785 Op::Narrow { axis, start, len } => (*axis, *start, *len),
786 _ => continue,
787 };
788 let gate_narrow_id = silu_node.inputs[0];
790 let gate_narrow = graph.node(gate_narrow_id);
791 let (g_axis, g_start, g_len) = match &gate_narrow.op {
792 Op::Narrow { axis, start, len } => (*axis, *start, *len),
793 _ => continue,
794 };
795
796 if up_narrow.inputs[0] != gate_narrow.inputs[0] {
799 continue;
800 }
801 if up_axis != g_axis {
802 continue;
803 }
804 if up_len != g_len {
805 continue;
806 }
807 let n = up_len;
808 let gate_first = up_start == n && g_start == 0;
810 if !(gate_first || (up_start == 0 && g_start == n)) {
811 continue;
812 }
813
814 if graph.use_count(up_narrow.id) != 1 {
817 continue;
818 }
819 if graph.use_count(gate_narrow_id) != 1 {
820 continue;
821 }
822 if graph.use_count(silu_id) != 1 {
823 continue;
824 }
825
826 matches.push(Match {
827 mul_id: node.id,
828 up_narrow_id: up_narrow.id,
829 silu_id,
830 gate_narrow_id,
831 cat_id: up_narrow.inputs[0],
832 out_n: n,
833 gate_first,
834 });
835 consumed.insert(up_narrow.id, ());
836 consumed.insert(gate_narrow_id, ());
837 consumed.insert(silu_id, ());
838 }
839
840 if matches.is_empty() {
841 return graph;
842 }
843
844 let mut rw = Rewriter::new(&graph.name);
846 let match_by_mul: HashMap<NodeId, &Match> = matches.iter().map(|m| (m.mul_id, m)).collect();
847
848 for node in graph.nodes() {
849 if consumed.contains_key(&node.id) {
850 continue;
851 }
852
853 if let Some(m) = match_by_mul.get(&node.id) {
854 let out_shape = node.shape.clone();
856 debug_assert_eq!(
857 out_shape.dim(out_shape.rank() - 1).unwrap_static(),
858 m.out_n,
859 "FuseSwiGLU: output last dim should be N"
860 );
861 let fused_id = rw.add_fused(
862 Op::FusedSwiGLU {
863 cast_to: None,
864 gate_first: m.gate_first,
865 },
866 &[m.cat_id],
867 out_shape,
868 );
869 rw.replace(node.id, fused_id);
870 continue;
871 }
872
873 rw.copy_node(node);
874 }
875
876 rw.finish(&graph.outputs)
877 }
878}
879
/// Pass `fuse_attention_block`. Its `run` currently returns the graph
/// unchanged; `should_fuse` holds an (unused) size heuristic.
pub struct FuseAttentionBlock;
891
892impl FuseAttentionBlock {
893 #[allow(dead_code)]
898 fn should_fuse(graph: &Graph) -> bool {
899 let threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
900 .and_then(|v| v.parse().ok())
901 .unwrap_or(64);
902 for node in graph.nodes() {
903 if let Op::Input { .. } = &node.op
904 && node.shape.rank() >= 2
905 {
906 let d0 = node.shape.dim(0);
907 let d1 = node.shape.dim(1);
908 if d0.is_static() && d1.is_static() {
909 let b = d0.unwrap_static();
910 let s = d1.unwrap_static();
911 if b * s <= threshold {
912 return true;
913 }
914 }
915 }
916 }
917 false
918 }
919}
920
921impl Pass for FuseAttentionBlock {
922 fn name(&self) -> &str {
923 "fuse_attention_block"
924 }
925
926 fn run(&self, graph: Graph) -> Graph {
927 graph
931 }
932}
933
/// Pass `mark_elementwise_regions`: groups chains of elementwise ops
/// (activation, same-dtype cast, binary, compare, where) into single
/// `Op::ElementwiseRegion` nodes.
pub struct MarkElementwiseRegions;
957
958impl Pass for MarkElementwiseRegions {
959 fn name(&self) -> &str {
960 "mark_elementwise_regions"
961 }
962
963 fn run(&self, graph: Graph) -> Graph {
964 let mut consumers: HashMap<NodeId, usize> = HashMap::new();
966 for node in graph.nodes() {
967 for &input in &node.inputs {
968 *consumers.entry(input).or_insert(0) += 1;
969 }
970 }
971 for &out in &graph.outputs {
972 *consumers.entry(out).or_insert(0) += 1;
973 }
974
975 let chain_eligible = |op: &Op| -> bool {
977 matches!(
978 op,
979 Op::Activation(_) | Op::Cast { .. } | Op::Binary(_) | Op::Compare(_) | Op::Where
980 )
981 };
982
983 let chain_step_safe = |graph: &Graph, node: &rlx_ir::Node| -> bool {
992 match &node.op {
993 Op::Cast { to } => {
994 let in_dt = graph.shape(node.inputs[0]).dtype();
995 *to == in_dt
996 }
997 _ => true,
998 }
999 };
1000
1001 let mut region_of: HashMap<NodeId, NodeId> = HashMap::new();
1010 let mut chain_step_idx: HashMap<NodeId, u32> = HashMap::new();
1011
1012 for node in graph.nodes() {
1013 if !chain_eligible(&node.op) {
1014 continue;
1015 }
1016 if !chain_step_safe(&graph, node) {
1017 continue;
1018 }
1019 let out_shape = &node.shape;
1026 let out_elems = out_shape.num_elements();
1027 let shape_ok = node.inputs.iter().all(|id| {
1028 let in_elems = graph.shape(*id).num_elements();
1029 match (in_elems, out_elems) {
1030 (Some(i), Some(o)) if i == o => true,
1031 (Some(i), Some(o)) if i > 0 && o % i == 0 => true,
1032 _ => false,
1033 }
1034 });
1035 if !shape_ok {
1036 continue;
1037 }
1038 let mut parent_root: Option<NodeId> = None;
1043 let mut all_inputs_single_consumer = true;
1044 for &input in &node.inputs {
1045 if graph.node(input).op.is_fusion_boundary() {
1047 parent_root = None;
1048 all_inputs_single_consumer = false;
1049 break;
1050 }
1051 if let Some(&root) = region_of.get(&input) {
1052 if consumers.get(&input).copied() != Some(1) {
1053 all_inputs_single_consumer = false;
1054 break;
1055 }
1056 match parent_root {
1057 None => parent_root = Some(root),
1058 Some(r) if r == root => {}
1059 Some(_) => {
1060 parent_root = None;
1061 all_inputs_single_consumer = false;
1062 break;
1063 }
1064 }
1065 }
1066 }
1067 if !all_inputs_single_consumer {
1068 region_of.insert(node.id, node.id);
1070 chain_step_idx.insert(node.id, 0);
1071 continue;
1072 }
1073 let root = parent_root.unwrap_or(node.id);
1074 let next_idx = node
1076 .inputs
1077 .iter()
1078 .filter_map(|id| {
1079 if region_of.get(id) == Some(&root) {
1080 chain_step_idx.get(id).copied()
1081 } else {
1082 None
1083 }
1084 })
1085 .max()
1086 .map(|m| m + 1)
1087 .unwrap_or(0);
1088 let limits = crate::limits::active_fusion_limits();
1089 if next_idx >= limits.max_elementwise_steps {
1090 region_of.insert(node.id, node.id);
1091 chain_step_idx.insert(node.id, 0);
1092 continue;
1093 }
1094 region_of.insert(node.id, root);
1095 chain_step_idx.insert(node.id, next_idx);
1096 }
1097
1098 let mut by_region: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
1101 for node in graph.nodes() {
1102 if let Some(&root) = region_of.get(&node.id) {
1103 by_region.entry(root).or_default().push(node.id);
1104 }
1105 }
1106
1107 let mut tail_of_region: HashMap<NodeId, NodeId> = HashMap::new();
1113 for (root, members) in &by_region {
1114 if members.len() < 2 {
1115 continue;
1116 }
1117 let max_idx = members.iter().map(|id| chain_step_idx[id]).max().unwrap();
1118 let tails: Vec<_> = members
1119 .iter()
1120 .filter(|id| chain_step_idx[id] == max_idx)
1121 .collect();
1122 if tails.len() != 1 {
1123 continue;
1124 }
1125 tail_of_region.insert(*root, *tails[0]);
1126 }
1127
1128 let by_region: HashMap<NodeId, Vec<NodeId>> = by_region
1130 .into_iter()
1131 .filter(|(root, _)| tail_of_region.contains_key(root))
1132 .collect();
1133
1134 if by_region.is_empty() {
1135 return graph;
1136 }
1137
1138 let mut rw = Rewriter::new(&graph.name);
1142 let mut emitted_region: HashMap<NodeId, NodeId> = HashMap::new();
1144
1145 for node in graph.nodes() {
1146 if let Some(&root) = region_of.get(&node.id)
1147 && let Some(&tail) = tail_of_region.get(&root)
1148 {
1149 if emitted_region.contains_key(&root) {
1150 let region_new = emitted_region[&root];
1156 rw.replace(node.id, region_new);
1157 continue;
1158 }
1159 if node.id == tail {
1160 let members = &by_region[&root];
1162 let mut ordered: Vec<NodeId> = members.clone();
1163 ordered.sort_by_key(|id| chain_step_idx[id]);
1164
1165 let mut external_inputs: Vec<NodeId> = Vec::new();
1169 let mut input_idx_of: HashMap<NodeId, u32> = HashMap::new();
1170 let mut step_idx_of: HashMap<NodeId, u32> = HashMap::new();
1171 for (i, member_id) in ordered.iter().enumerate() {
1172 step_idx_of.insert(*member_id, i as u32);
1173 let n = graph.node(*member_id);
1174 for &inp in &n.inputs {
1175 if !step_idx_of.contains_key(&inp) && !input_idx_of.contains_key(&inp) {
1176 let idx = external_inputs.len() as u32;
1177 input_idx_of.insert(inp, idx);
1178 external_inputs.push(inp);
1179 }
1180 }
1181 }
1182
1183 let limits = crate::limits::active_fusion_limits();
1184 if external_inputs.len() as u32 > limits.max_elementwise_inputs
1185 || ordered.len() as u32 > limits.max_elementwise_steps
1186 {
1187 for &mid in &ordered {
1188 rw.copy_node(graph.node(mid));
1189 }
1190 continue;
1191 }
1192
1193 let resolve = |id: NodeId| -> ChainOperand {
1194 if let Some(&i) = input_idx_of.get(&id) {
1195 ChainOperand::Input(i)
1196 } else {
1197 ChainOperand::Step(step_idx_of[&id])
1198 }
1199 };
1200 let mut chain: Vec<ChainStep> = Vec::with_capacity(ordered.len());
1201 for member_id in &ordered {
1202 let n = graph.node(*member_id);
1203 let step = match &n.op {
1204 Op::Activation(a) => ChainStep::Activation(*a, resolve(n.inputs[0])),
1205 Op::Cast { to } => ChainStep::Cast(*to, resolve(n.inputs[0])),
1206 Op::Binary(op) => {
1207 ChainStep::Binary(*op, resolve(n.inputs[0]), resolve(n.inputs[1]))
1208 }
1209 Op::Compare(op) => {
1210 ChainStep::Compare(*op, resolve(n.inputs[0]), resolve(n.inputs[1]))
1211 }
1212 Op::Where => ChainStep::Where(
1213 resolve(n.inputs[0]),
1214 resolve(n.inputs[1]),
1215 resolve(n.inputs[2]),
1216 ),
1217 _ => unreachable!("non-chain-eligible op in region"),
1218 };
1219 chain.push(step);
1220 }
1221
1222 let mut scalar_input_mask: u32 = 0;
1231 let mut input_modulus = [0u32; 16];
1232 let region_shape_elems = graph.node(tail).shape.num_elements();
1233 for (i, &ext) in external_inputs.iter().enumerate() {
1234 if i >= 16 {
1235 break;
1236 }
1237 let in_elems = graph.shape(ext).num_elements();
1238 match (in_elems, region_shape_elems) {
1239 (Some(1), Some(o)) if o != 1 => {
1240 scalar_input_mask |= 1u32 << i;
1241 input_modulus[i] = 1;
1242 }
1243 (Some(i_n), Some(o)) if i_n != o && i_n > 0 => {
1244 input_modulus[i] = i_n as u32;
1245 }
1246 _ => { }
1247 }
1248 }
1249 let region_new = rw.add_fused(
1250 Op::ElementwiseRegion {
1251 chain,
1252 num_inputs: external_inputs.len() as u32,
1253 scalar_input_mask,
1254 input_modulus,
1255 },
1256 &external_inputs,
1257 graph.node(tail).shape.clone(),
1258 );
1259 emitted_region.insert(root, region_new);
1260 rw.replace(node.id, region_new);
1261 continue;
1262 } else {
1263 rw.replace(node.id, NodeId(u32::MAX)); continue;
1267 }
1268 }
1269 rw.copy_node(node);
1270 }
1271
1272 rw.finish(&graph.outputs)
1289 }
1290}
1291
/// Pass `unfuse_elementwise_regions`: expands every `Op::ElementwiseRegion`
/// back into one ordinary node per chain step.
pub struct UnfuseElementwiseRegions;
1305
1306impl Pass for UnfuseElementwiseRegions {
1307 fn name(&self) -> &str {
1308 "unfuse_elementwise_regions"
1309 }
1310
1311 fn run(&self, graph: Graph) -> Graph {
1312 let any_region = graph
1313 .nodes()
1314 .iter()
1315 .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }));
1316 if !any_region {
1317 return graph;
1318 }
1319
1320 let mut rw = Rewriter::new(&graph.name);
1321 for node in graph.nodes() {
1322 if let Op::ElementwiseRegion {
1323 chain,
1324 num_inputs: _,
1325 scalar_input_mask: _,
1326 input_modulus: _,
1327 } = &node.op
1328 {
1329 let region_inputs: Vec<NodeId> = node.inputs.iter().map(|id| rw.map(*id)).collect();
1332 let mut step_ids: Vec<NodeId> = Vec::with_capacity(chain.len());
1333 let region_shape = node.shape.clone();
1334 let region_dims: Vec<_> = region_shape.dims().to_vec();
1335 let mut step_dtypes: Vec<rlx_ir::DType> = Vec::with_capacity(chain.len());
1341 let region_dtype = region_shape.dtype();
1342 let dtype_of = |op: &ChainOperand,
1343 ins: &[NodeId],
1344 step_dt: &[rlx_ir::DType],
1345 rw: &Rewriter|
1346 -> rlx_ir::DType {
1347 match *op {
1348 ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.dtype(),
1349 ChainOperand::Step(i) => step_dt[i as usize],
1350 }
1351 };
1352 let shape_of = |op: &ChainOperand,
1363 ins: &[NodeId],
1364 step_ids: &[NodeId],
1365 rw: &Rewriter|
1366 -> Shape {
1367 match *op {
1368 ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.clone(),
1369 ChainOperand::Step(i) => {
1370 rw.new_graph.node(step_ids[i as usize]).shape.clone()
1371 }
1372 }
1373 };
1374 for step in chain {
1375 let resolve = |op: &ChainOperand| -> NodeId {
1376 match *op {
1377 ChainOperand::Input(i) => region_inputs[i as usize],
1378 ChainOperand::Step(i) => step_ids[i as usize],
1379 }
1380 };
1381 let (new_id, dt) = match step {
1382 ChainStep::Activation(a, src) => {
1383 let s = resolve(src);
1384 let dt = dtype_of(src, ®ion_inputs, &step_dtypes, &rw);
1385 let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
1389 let dims: Vec<_> = src_shape.dims().to_vec();
1390 let shape = Shape::from_dims(&dims, dt);
1391 (
1392 rw.new_graph.add_node(Op::Activation(*a), vec![s], shape),
1393 dt,
1394 )
1395 }
1396 ChainStep::Cast(to, src) => {
1397 let s = resolve(src);
1398 let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
1399 let dims: Vec<_> = src_shape.dims().to_vec();
1400 let shape = Shape::from_dims(&dims, *to);
1401 (
1402 rw.new_graph.add_node(Op::Cast { to: *to }, vec![s], shape),
1403 *to,
1404 )
1405 }
1406 ChainStep::Binary(op, lhs, rhs) => {
1407 let l = resolve(lhs);
1408 let r = resolve(rhs);
1409 let dt = dtype_of(lhs, ®ion_inputs, &step_dtypes, &rw);
1410 let lhs_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
1412 let rhs_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
1413 let bcast = rlx_ir::shape::broadcast(&lhs_shape, &rhs_shape)
1414 .unwrap_or_else(|e| {
1415 panic!(
1416 "unfuse_elementwise_regions: cannot broadcast \
1417 {lhs_shape:?} ⊗ {rhs_shape:?} for Binary({op:?}): {e}"
1418 )
1419 });
1420 let dims: Vec<_> = bcast.dims().to_vec();
1421 let shape = Shape::from_dims(&dims, dt);
1422 (
1423 rw.new_graph.add_node(Op::Binary(*op), vec![l, r], shape),
1424 dt,
1425 )
1426 }
1427 ChainStep::Compare(op, lhs, rhs) => {
1428 let l = resolve(lhs);
1429 let r = resolve(rhs);
1430 let lhs_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
1431 let rhs_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
1432 let bcast = rlx_ir::shape::broadcast(&lhs_shape, &rhs_shape)
1433 .unwrap_or_else(|e| {
1434 panic!(
1435 "unfuse_elementwise_regions: cannot broadcast \
1436 {lhs_shape:?} ⊗ {rhs_shape:?} for Compare({op:?}): {e}"
1437 )
1438 });
1439 let dims: Vec<_> = bcast.dims().to_vec();
1440 let shape = Shape::from_dims(&dims, rlx_ir::DType::Bool);
1441 (
1442 rw.new_graph.add_node(Op::Compare(*op), vec![l, r], shape),
1443 rlx_ir::DType::Bool,
1444 )
1445 }
1446 ChainStep::Where(c, x, y) => {
1447 let cn = resolve(c);
1448 let xn = resolve(x);
1449 let yn = resolve(y);
1450 let dt = dtype_of(x, ®ion_inputs, &step_dtypes, &rw);
1451 let c_shape = shape_of(c, ®ion_inputs, &step_ids, &rw);
1453 let x_shape = shape_of(x, ®ion_inputs, &step_ids, &rw);
1454 let y_shape = shape_of(y, ®ion_inputs, &step_ids, &rw);
1455 let bcast_xy = rlx_ir::shape::broadcast(&x_shape, &y_shape)
1456 .unwrap_or_else(|e| {
1457 panic!(
1458 "unfuse_elementwise_regions: cannot broadcast \
1459 then/else {x_shape:?} ⊗ {y_shape:?} for Where: {e}"
1460 )
1461 });
1462 let bcast = rlx_ir::shape::broadcast(&c_shape, &bcast_xy)
1463 .unwrap_or_else(|e| {
1464 panic!(
1465 "unfuse_elementwise_regions: cannot broadcast cond \
1466 {c_shape:?} ⊗ {bcast_xy:?} for Where: {e}"
1467 )
1468 });
1469 let dims: Vec<_> = bcast.dims().to_vec();
1470 let shape = Shape::from_dims(&dims, dt);
1471 (
1472 rw.new_graph.add_node(Op::Where, vec![cn, xn, yn], shape),
1473 dt,
1474 )
1475 }
1476 };
1477 step_ids.push(new_id);
1478 step_dtypes.push(dt);
1479 }
1480 let _ = region_dtype;
1481 let _ = region_dims;
1482 let last = *step_ids.last().expect("chain non-empty per pass invariant");
1485 rw.replace(node.id, last);
1486 continue;
1487 }
1488 rw.copy_node(node);
1489 }
1490 rw.finish(&graph.outputs)
1491 }
1492}
1493
1494pub fn clip_elementwise_regions(graph: Graph, limits: crate::limits::FusionLimits) -> Graph {
1499 let oversize = |n: &rlx_ir::Node| -> bool {
1500 matches!(
1501 &n.op,
1502 Op::ElementwiseRegion {
1503 chain,
1504 num_inputs,
1505 ..
1506 } if *num_inputs > limits.max_elementwise_inputs
1507 || chain.len() as u32 > limits.max_elementwise_steps
1508 )
1509 };
1510 if !graph.nodes().iter().any(oversize) {
1511 return graph;
1512 }
1513
1514 let mut rw = Rewriter::new(&graph.name);
1515 for node in graph.nodes() {
1516 if !oversize(node) {
1517 rw.copy_node(node);
1518 continue;
1519 }
1520
1521 let Op::ElementwiseRegion {
1522 chain,
1523 num_inputs: _,
1524 scalar_input_mask: _,
1525 input_modulus: _,
1526 } = &node.op
1527 else {
1528 unreachable!();
1529 };
1530
1531 let region_inputs: Vec<NodeId> = node.inputs.iter().map(|id| rw.map(*id)).collect();
1532 let mut step_ids: Vec<NodeId> = Vec::with_capacity(chain.len());
1533 let region_shape = node.shape.clone();
1534 let region_dims: Vec<_> = region_shape.dims().to_vec();
1535 let mut step_dtypes: Vec<rlx_ir::DType> = Vec::with_capacity(chain.len());
1536 let region_dtype = region_shape.dtype();
1537 let dtype_of = |op: &ChainOperand,
1538 ins: &[NodeId],
1539 step_dt: &[rlx_ir::DType],
1540 rw: &Rewriter|
1541 -> rlx_ir::DType {
1542 match *op {
1543 ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.dtype(),
1544 ChainOperand::Step(i) => step_dt[i as usize],
1545 }
1546 };
1547 let shape_of =
1548 |op: &ChainOperand, ins: &[NodeId], step_ids: &[NodeId], rw: &Rewriter| -> Shape {
1549 match *op {
1550 ChainOperand::Input(i) => rw.new_graph.node(ins[i as usize]).shape.clone(),
1551 ChainOperand::Step(i) => rw.new_graph.node(step_ids[i as usize]).shape.clone(),
1552 }
1553 };
1554 for step in chain {
1555 let resolve = |op: &ChainOperand| -> NodeId {
1556 match *op {
1557 ChainOperand::Input(i) => region_inputs[i as usize],
1558 ChainOperand::Step(i) => step_ids[i as usize],
1559 }
1560 };
1561 let (new_id, dt) = match step {
1562 ChainStep::Activation(a, src) => {
1563 let s = resolve(src);
1564 let dt = dtype_of(src, ®ion_inputs, &step_dtypes, &rw);
1565 let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
1566 let dims: Vec<_> = src_shape.dims().to_vec();
1567 let shape = Shape::from_dims(&dims, dt);
1568 (
1569 rw.new_graph.add_node(Op::Activation(*a), vec![s], shape),
1570 dt,
1571 )
1572 }
1573 ChainStep::Cast(to, src) => {
1574 let s = resolve(src);
1575 let src_shape = shape_of(src, ®ion_inputs, &step_ids, &rw);
1576 let dims: Vec<_> = src_shape.dims().to_vec();
1577 let shape = Shape::from_dims(&dims, *to);
1578 (
1579 rw.new_graph.add_node(Op::Cast { to: *to }, vec![s], shape),
1580 *to,
1581 )
1582 }
1583 ChainStep::Binary(op, lhs, rhs) => {
1584 let l = resolve(lhs);
1585 let r = resolve(rhs);
1586 let dt = dtype_of(lhs, ®ion_inputs, &step_dtypes, &rw);
1587 let l_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
1588 let r_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
1589 let bcast = l_shape
1590 .broadcast_with(&r_shape)
1591 .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
1592 let dims: Vec<_> = bcast.dims().to_vec();
1593 let shape = Shape::from_dims(&dims, dt);
1594 (
1595 rw.new_graph.add_node(Op::Binary(*op), vec![l, r], shape),
1596 dt,
1597 )
1598 }
1599 ChainStep::Compare(op, lhs, rhs) => {
1600 let l = resolve(lhs);
1601 let r = resolve(rhs);
1602 let l_shape = shape_of(lhs, ®ion_inputs, &step_ids, &rw);
1603 let r_shape = shape_of(rhs, ®ion_inputs, &step_ids, &rw);
1604 let bcast = l_shape
1605 .broadcast_with(&r_shape)
1606 .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
1607 let dims: Vec<_> = bcast.dims().to_vec();
1608 let shape = Shape::from_dims(&dims, rlx_ir::DType::U8);
1609 (
1610 rw.new_graph.add_node(Op::Compare(*op), vec![l, r], shape),
1611 rlx_ir::DType::U8,
1612 )
1613 }
1614 ChainStep::Where(cond, x, y) => {
1615 let cn = resolve(cond);
1616 let xn = resolve(x);
1617 let yn = resolve(y);
1618 let dt = dtype_of(x, ®ion_inputs, &step_dtypes, &rw);
1619 let x_shape = shape_of(x, ®ion_inputs, &step_ids, &rw);
1620 let y_shape = shape_of(y, ®ion_inputs, &step_ids, &rw);
1621 let c_shape = shape_of(cond, ®ion_inputs, &step_ids, &rw);
1622 let bcast_xy = x_shape
1623 .broadcast_with(&y_shape)
1624 .unwrap_or_else(|e| panic!("clip_elementwise_regions: {e}"));
1625 let bcast = c_shape.broadcast_with(&bcast_xy).unwrap_or_else(|e| {
1626 panic!("clip_elementwise_regions: cannot broadcast cond {c_shape:?} ⊗ {bcast_xy:?} for Where: {e}")
1627 });
1628 let dims: Vec<_> = bcast.dims().to_vec();
1629 let shape = Shape::from_dims(&dims, dt);
1630 (
1631 rw.new_graph.add_node(Op::Where, vec![cn, xn, yn], shape),
1632 dt,
1633 )
1634 }
1635 };
1636 step_ids.push(new_id);
1637 step_dtypes.push(dt);
1638 }
1639 let _ = (region_dtype, region_dims);
1640 let last = *step_ids
1641 .last()
1642 .expect("oversize region has non-empty chain");
1643 rw.replace(node.id, last);
1644 }
1645 rw.finish(&graph.outputs)
1646}
1647
#[cfg(test)]
mod tests {
    // Unit tests for the fusion passes and the elementwise-region
    // mark / unfuse / clip passes defined in this module.
    use super::*;
    use crate::limits::FusionLimits;
    use crate::pass::run_passes;

    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    /// matmul + bias add + GELU collapses into one `FusedMatMulBiasAct(Gelu)`.
    #[test]
    fn fuse_matmul_bias_gelu() {
        let mut g = Graph::new("test");
        let x = g.input("x", f32_shape(&[4, 15, 384]));
        let w = g.param("w", f32_shape(&[384, 1536]));
        let b = g.param("b", f32_shape(&[1536]));
        let mm = g.matmul(x, w, f32_shape(&[4, 15, 1536]));
        let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[4, 15, 1536]));
        let out = g.activation(Activation::Gelu, add, f32_shape(&[4, 15, 1536]));
        g.set_outputs(vec![out]);

        // x, w, b, matmul, add, gelu
        assert_eq!(g.len(), 6);
        let fused = FuseMatMulBiasAct.run(g);
        println!("{fused}");

        // x, w, b, fused node
        assert_eq!(fused.len(), 4);
        let out_node = fused.node(fused.outputs[0]);
        assert!(matches!(
            out_node.op,
            Op::FusedMatMulBiasAct {
                activation: Some(Activation::Gelu)
            }
        ));
    }

    /// matmul + bias add with no activation fuses with `activation: None`.
    #[test]
    fn fuse_matmul_bias_no_act() {
        let mut g = Graph::new("test");
        let x = g.input("x", f32_shape(&[4, 15, 384]));
        let w = g.param("w", f32_shape(&[384, 384]));
        let b = g.param("b", f32_shape(&[384]));
        let mm = g.matmul(x, w, f32_shape(&[4, 15, 384]));
        let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[4, 15, 384]));
        g.set_outputs(vec![add]);

        let fused = FuseMatMulBiasAct.run(g);
        assert_eq!(fused.len(), 4);
        let out_node = fused.node(fused.outputs[0]);
        assert!(matches!(
            out_node.op,
            Op::FusedMatMulBiasAct { activation: None }
        ));
    }

    /// An activation outside the fusible epilogue set (Exp) stays a separate node.
    #[test]
    fn fuse_matmul_bias_skips_unsupported_activation_epilogue() {
        let mut g = Graph::new("test");
        let x = g.input("x", f32_shape(&[8, 1024]));
        let w = g.param("w", f32_shape(&[1024, 16]));
        let b = g.param("b", f32_shape(&[16]));
        let mm = g.matmul(x, w, f32_shape(&[8, 16]));
        let add = g.binary(BinaryOp::Add, mm, b, f32_shape(&[8, 16]));
        let exp = g.activation(Activation::Exp, add, f32_shape(&[8, 16]));
        g.set_outputs(vec![exp]);

        let fused = FuseMatMulBiasAct.run(g);
        assert_eq!(fused.len(), 5);
        let out_node = fused.node(fused.outputs[0]);
        assert!(matches!(out_node.op, Op::Activation(Activation::Exp)));
        let add_node = fused.node(out_node.inputs[0]);
        assert!(matches!(
            add_node.op,
            Op::FusedMatMulBiasAct { activation: None }
        ));
    }

    /// A bias param declared after the matmul node must still be fused.
    #[test]
    fn fuse_matmul_bias_act_with_late_bias_param() {
        use rlx_ir::infer::GraphExt;

        let mut g = Graph::new("late_bias");
        let x = g.input("x", f32_shape(&[8, 16]));
        let w = g.param("w", f32_shape(&[16, 32]));
        let out = {
            let mm = g.mm(x, w);
            let b = g.param("b", f32_shape(&[32]));
            let biased = g.add(mm, b);
            g.gelu(biased)
        };
        g.set_outputs(vec![out]);

        let fused = FuseMatMulBiasAct.run(g);
        assert!(
            fused
                .nodes()
                .iter()
                .any(|n| matches!(n.op, Op::FusedMatMulBiasAct { .. })),
            "bias param declared after matmul must still fuse:\n{fused}"
        );
    }

    /// `swiglu_ffn` builder output is recognized by FuseSharedInputMatMul + FuseSwiGLU.
    #[test]
    fn swiglu_ffn_builder_fuses_end_to_end() {
        let mut g = Graph::new("swiglu_block");
        let x = g.input("x", f32_shape(&[4, 768]));
        let up_w = g.param("up", f32_shape(&[768, 2048]));
        let gate_w = g.param("gate", f32_shape(&[768, 2048]));
        let down_w = g.param("down", f32_shape(&[2048, 768]));
        let out = g.swiglu_ffn(x, up_w, gate_w, down_w);
        g.set_outputs(vec![out]);

        let g = FuseSharedInputMatMul.run(g);
        let g = FuseSwiGLU.run(g);
        assert!(
            g.nodes()
                .iter()
                .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
            "swiglu_ffn builder should match FuseSwiGLU:\n{g}"
        );
    }

    /// silu(x @ gate) * (x @ up) fuses via FuseSwiGLUDualMatmul.
    #[test]
    fn fuse_swiglu_dual_matmul_gate_first() {
        use rlx_ir::infer::GraphExt;

        let mut g = Graph::new("qwen3_ffn");
        let x = g.input("x", f32_shape(&[4, 768]));
        let gate_w = g.param("gate", f32_shape(&[768, 2048]));
        let up_w = g.param("up", f32_shape(&[768, 2048]));
        let gate = g.mm(x, gate_w);
        let up = g.mm(x, up_w);
        let gate_act = g.silu(gate);
        let out = g.mul(gate_act, up);
        g.set_outputs(vec![out]);

        let fused = FuseSwiGLUDualMatmul.run(g);
        assert!(
            fused
                .nodes()
                .iter()
                .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
            "gate-first dual matmul should fuse:\n{fused}"
        );
        assert!(
            fused.len() <= 6,
            "dual fusion should collapse to x + weights + concat + mm + fused_swiglu, got {} nodes",
            fused.len()
        );
    }

    /// Three matmuls sharing `x` become one concat + matmul + three narrows.
    #[test]
    fn fuse_shared_input_matmul_three_way_qkv() {
        let mut g = Graph::new("qkv");
        let x = g.input("x", f32_shape(&[8, 512]));
        let wq = g.param("wq", f32_shape(&[512, 128]));
        let wk = g.param("wk", f32_shape(&[512, 128]));
        let wv = g.param("wv", f32_shape(&[512, 128]));
        let q = g.matmul(x, wq, f32_shape(&[8, 128]));
        let k = g.matmul(x, wk, f32_shape(&[8, 128]));
        let v = g.matmul(x, wv, f32_shape(&[8, 128]));
        g.set_outputs(vec![q, k, v]);

        let fused = FuseSharedInputMatMul.run(g);
        assert_eq!(
            fused.len(),
            9,
            "x + 3 weights + concat + mm + 3 narrows = 9"
        );
        for &out in &fused.outputs {
            assert!(matches!(fused.node(out).op, Op::Narrow { .. }));
        }
    }

    /// Residual add feeding LayerNorm fuses into `FusedResidualLN`.
    #[test]
    fn fuse_residual_layer_norm() {
        let mut g = Graph::new("test");
        let x = g.input("x", f32_shape(&[4, 15, 384]));
        let residual = g.input("residual", f32_shape(&[4, 15, 384]));
        let gamma = g.param("gamma", f32_shape(&[384]));
        let beta = g.param("beta", f32_shape(&[384]));
        let add = g.binary(BinaryOp::Add, x, residual, f32_shape(&[4, 15, 384]));
        let ln = g.layer_norm(add, gamma, beta, -1, 1e-12, f32_shape(&[4, 15, 384]));
        g.set_outputs(vec![ln]);

        // x, residual, gamma, beta, add, layer_norm
        assert_eq!(g.len(), 6);
        let fused = FuseResidualLN.run(g);
        println!("{fused}");

        // add + layer_norm collapse into one node
        assert_eq!(fused.len(), 5);
        let out_node = fused.node(fused.outputs[0]);
        assert!(matches!(
            out_node.op,
            Op::FusedResidualLN {
                has_bias: false,
                ..
            }
        ));
    }

    /// Residual add feeding RmsNorm fuses into `FusedResidualRmsNorm`.
    #[test]
    fn fuse_residual_rms_norm() {
        let mut g = Graph::new("test");
        let x = g.input("x", f32_shape(&[4, 15, 384]));
        let residual = g.input("residual", f32_shape(&[4, 15, 384]));
        let gamma = g.param("gamma", f32_shape(&[384]));
        let beta = g.param("beta", f32_shape(&[384]));
        let add = g.binary(BinaryOp::Add, x, residual, f32_shape(&[4, 15, 384]));
        let rn = g.add_node(
            Op::RmsNorm {
                axis: -1,
                eps: 1e-6,
            },
            vec![add, gamma, beta],
            f32_shape(&[4, 15, 384]),
        );
        g.set_outputs(vec![rn]);

        assert_eq!(g.len(), 6);

        let fused = FuseResidualRmsNorm.run(g);
        assert_eq!(fused.len(), 5);
        let out_node = fused.node(fused.outputs[0]);
        assert!(matches!(
            out_node.op,
            Op::FusedResidualRmsNorm {
                has_bias: false,
                ..
            }
        ));
    }

    /// RmsNorm followed by a flattening reshape is rewritten to a 2-D RmsNorm.
    #[test]
    fn fuse_rms_norm_reshape() {
        let mut g = Graph::new("test");
        let x = g.input("x", f32_shape(&[1, 8, 512]));
        let gamma = g.param("gamma", f32_shape(&[512]));
        let beta = g.param("beta", f32_shape(&[512]));
        let rn = g.add_node(
            Op::RmsNorm {
                axis: -1,
                eps: 1e-6,
            },
            vec![x, gamma, beta],
            f32_shape(&[1, 8, 512]),
        );
        let flat = g.add_node(
            Op::Reshape {
                new_shape: vec![8, 512],
            },
            vec![rn],
            f32_shape(&[8, 512]),
        );
        let w = g.param("w", f32_shape(&[512, 128]));
        let mm = g.matmul(flat, w, f32_shape(&[8, 128]));
        g.set_outputs(vec![mm]);

        let fused = FuseRmsNormReshape.run(g);
        assert_eq!(fused.len(), 6);
        let rn_node = fused.node(fused.node(fused.outputs[0]).inputs[0]);
        assert!(matches!(rn_node.op, Op::RmsNorm { .. }));
        assert_eq!(rn_node.shape.dim(0).unwrap_static(), 8);
        assert_eq!(rn_node.shape.dim(1).unwrap_static(), 512);
    }

    /// Two matmuls sharing `x` are fused; each original output becomes a Narrow.
    #[test]
    fn fuse_shared_input_matmul() {
        let mut g = Graph::new("swiglu");
        let x = g.input("x", f32_shape(&[60, 768]));
        let w1 = g.param("fc11", f32_shape(&[768, 2048]));
        let w2 = g.param("fc12", f32_shape(&[768, 2048]));
        let mm1 = g.matmul(x, w1, f32_shape(&[60, 2048]));
        let mm2 = g.matmul(x, w2, f32_shape(&[60, 2048]));
        g.set_outputs(vec![mm1, mm2]);

        // x, fc11, fc12, mm1, mm2
        assert_eq!(g.len(), 5);
        let fused = FuseSharedInputMatMul.run(g);
        println!("{fused}");

        assert!(fused.len() <= 7);
        for &out in &fused.outputs {
            assert!(matches!(fused.node(out).op, Op::Narrow { .. }));
        }
    }

    /// A second weight param declared after the first matmul must still fuse.
    #[test]
    fn fuse_shared_input_matmul_with_late_w2_param() {
        let mut g = Graph::new("late_w2");
        let x = g.input("x", f32_shape(&[8, 16]));
        let w1 = g.param("w1", f32_shape(&[16, 8]));
        let mm1 = g.matmul(x, w1, f32_shape(&[8, 8]));
        let w2 = g.param("w2", f32_shape(&[16, 8]));
        let mm2 = g.matmul(x, w2, f32_shape(&[8, 8]));
        g.set_outputs(vec![mm1, mm2]);

        let fused = FuseSharedInputMatMul.run(g);
        for &out in &fused.outputs {
            assert!(
                matches!(fused.node(out).op, Op::Narrow { .. }),
                "late w2 should still fuse via ensure_mapped, got {:?}",
                fused.node(out).op
            );
        }
    }

    /// MoE-style block: router, shared router, gate and up matmuls all read `h_2d`.
    #[test]
    fn fuse_shared_input_matmul_moe_ffn_pattern() {
        let mut g = Graph::new("moe_ffn");
        let rows = 4usize;
        let n_embd = 16usize;
        let n_expert = 4usize;
        let n_ff = 16usize;

        let h_in = g.input("h", f32_shape(&[1, rows, n_embd]));
        let h_2d = g.reshape_(h_in, vec![rows as i64, n_embd as i64]);

        let router_w = g.param("router_w", f32_shape(&[n_embd, n_expert]));
        let router_logits = g.matmul(h_2d, router_w, f32_shape(&[rows, n_expert]));

        let shared_router_w = g.param("shared_router_w", f32_shape(&[n_embd, 1]));
        let shared_logits = g.matmul(h_2d, shared_router_w, f32_shape(&[rows, 1]));
        let shared_gate = g.activation(Activation::Sigmoid, shared_logits, f32_shape(&[rows, 1]));

        let s_gate_w = g.param("s_gate_w", f32_shape(&[n_embd, n_ff]));
        let s_up_w = g.param("s_up_w", f32_shape(&[n_embd, n_ff]));
        let s_gate = g.matmul(h_2d, s_gate_w, f32_shape(&[rows, n_ff]));
        let s_up = g.matmul(h_2d, s_up_w, f32_shape(&[rows, n_ff]));
        let s_gate_silu = g.silu(s_gate);
        let s_swiglu = g.mul(s_gate_silu, s_up);

        g.set_outputs(vec![router_logits, shared_gate, s_swiglu]);

        let fused = FuseSharedInputMatMul.run(g);
        let narrow_count = fused
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::Narrow { .. }))
            .count();
        assert!(
            narrow_count >= 4,
            "expected four narrow slices from fused h_2d matmuls, got {narrow_count}"
        );
    }

    /// BERT FFN block through FuseMatMulBiasAct + FuseResidualLN shrinks the graph.
    #[test]
    fn full_bert_ffn_fusion() {
        let mut g = Graph::new("bert_ffn");
        let f = DType::F32;

        let x = g.input("hidden", Shape::new(&[4, 15, 384], f));
        let residual = g.input("residual", Shape::new(&[4, 15, 384], f));

        let out_w = g.param("out.w", Shape::new(&[384, 384], f));
        let out_b = g.param("out.b", Shape::new(&[384], f));
        let out_mm = g.matmul(x, out_w, Shape::new(&[4, 15, 384], f));
        let out_add = g.binary(BinaryOp::Add, out_mm, out_b, Shape::new(&[4, 15, 384], f));
        let res_add = g.binary(
            BinaryOp::Add,
            out_add,
            residual,
            Shape::new(&[4, 15, 384], f),
        );
        let gamma = g.param("ln.g", Shape::new(&[384], f));
        let beta = g.param("ln.b", Shape::new(&[384], f));
        let ln = g.layer_norm(
            res_add,
            gamma,
            beta,
            -1,
            1e-12,
            Shape::new(&[4, 15, 384], f),
        );

        let int_w = g.param("int.w", Shape::new(&[384, 1536], f));
        let int_b = g.param("int.b", Shape::new(&[1536], f));
        let int_mm = g.matmul(ln, int_w, Shape::new(&[4, 15, 1536], f));
        let int_add = g.binary(BinaryOp::Add, int_mm, int_b, Shape::new(&[4, 15, 1536], f));
        let gelu = g.activation(Activation::Gelu, int_add, Shape::new(&[4, 15, 1536], f));

        let out2_w = g.param("out2.w", Shape::new(&[1536, 384], f));
        let out2_b = g.param("out2.b", Shape::new(&[384], f));
        let out2_mm = g.matmul(gelu, out2_w, Shape::new(&[4, 15, 384], f));
        let out2_add = g.binary(BinaryOp::Add, out2_mm, out2_b, Shape::new(&[4, 15, 384], f));

        g.set_outputs(vec![out2_add]);

        let before = g.len();
        println!("=== BEFORE fusion ({before} nodes) ===\n{g}");

        let passes: Vec<&dyn Pass> = vec![&FuseMatMulBiasAct, &FuseResidualLN];
        let optimized = run_passes(g, &passes, false);
        let after = optimized.len();
        println!("=== AFTER fusion ({after} nodes) ===\n{optimized}");

        assert!(
            after < before,
            "fusion should reduce node count: {before} → {after}"
        );

        let ops: Vec<String> = optimized
            .nodes()
            .iter()
            .map(|n| format!("{}", n.op))
            .collect();
        let has_fused_mm = ops.iter().any(|s| s.contains("fused_mm_bias"));
        assert!(has_fused_mm, "should have fused_mm_bias_act: {ops:?}");
    }

    /// narrow(up) * silu(narrow(gate)) over one concatenated input fuses to FusedSwiGLU.
    #[test]
    fn fuse_swiglu_canonical() {
        let mut g = Graph::new("nomic_ffn");
        let f = DType::F32;
        let cat = g.input("cat", Shape::new(&[60, 4096], f));
        let up = g.add_node(
            Op::Narrow {
                axis: 1,
                start: 0,
                len: 2048,
            },
            vec![cat],
            Shape::new(&[60, 2048], f),
        );
        let gate = g.add_node(
            Op::Narrow {
                axis: 1,
                start: 2048,
                len: 2048,
            },
            vec![cat],
            Shape::new(&[60, 2048], f),
        );
        let silu = g.activation(Activation::Silu, gate, Shape::new(&[60, 2048], f));
        let out = g.binary(BinaryOp::Mul, up, silu, Shape::new(&[60, 2048], f));
        g.set_outputs(vec![out]);

        let before = g.len();
        let fused = FuseSwiGLU.run(g);
        let after = fused.len();
        assert_eq!(
            after,
            before - 3,
            "should remove narrows+silu+mul, add FusedSwiGLU"
        );
        let out_node = fused.node(fused.outputs[0]);
        assert!(
            matches!(
                out_node.op,
                Op::FusedSwiGLU {
                    cast_to: None,
                    gate_first: false
                }
            ),
            "output should be FusedSwiGLU, got {}",
            out_node.op
        );
        let in_id = out_node.inputs[0];
        assert!(matches!(fused.node(in_id).op, Op::Input { .. }));
    }

    /// No SwiGLU fusion when a narrow has a consumer outside the pattern.
    #[test]
    fn fuse_swiglu_skips_when_narrow_has_extra_user() {
        let mut g = Graph::new("contended");
        let f = DType::F32;
        let cat = g.input("cat", Shape::new(&[60, 4096], f));
        let up = g.add_node(
            Op::Narrow {
                axis: 1,
                start: 0,
                len: 2048,
            },
            vec![cat],
            Shape::new(&[60, 2048], f),
        );
        let gate = g.add_node(
            Op::Narrow {
                axis: 1,
                start: 2048,
                len: 2048,
            },
            vec![cat],
            Shape::new(&[60, 2048], f),
        );
        let silu = g.activation(Activation::Silu, gate, Shape::new(&[60, 2048], f));
        let out = g.binary(BinaryOp::Mul, up, silu, Shape::new(&[60, 2048], f));
        let extra = g.activation(Activation::Relu, up, Shape::new(&[60, 2048], f));
        g.set_outputs(vec![out, extra]);

        let before = g.len();
        let fused = FuseSwiGLU.run(g);
        assert_eq!(fused.len(), before);
        let any_fused = fused
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::FusedSwiGLU { .. }));
        assert!(!any_fused, "should not fuse when narrow has extra user");
    }

    /// (a + b) * c -> relu becomes one 3-step ElementwiseRegion over inputs a, b, c.
    #[test]
    fn region_collapses_add_mul_relu_chain() {
        let f = DType::F32;
        let mut g = Graph::new("ew");
        let a = g.input("a", Shape::new(&[8], f));
        let b = g.input("b", Shape::new(&[8], f));
        let c = g.input("c", Shape::new(&[8], f));
        let s = Shape::new(&[8], f);
        let add = g.binary(BinaryOp::Add, a, b, s.clone());
        let mul = g.binary(BinaryOp::Mul, add, c, s.clone());
        let relu = g.activation(Activation::Relu, mul, s.clone());
        g.set_outputs(vec![relu]);

        let before = g.len();
        let fused = MarkElementwiseRegions.run(g);

        let regions: Vec<_> = fused
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
            .collect();
        assert_eq!(regions.len(), 1, "expected one ElementwiseRegion");
        let region = regions[0];
        assert_eq!(
            region.inputs.len(),
            3,
            "region has 3 external inputs (a, b, c)"
        );
        if let Op::ElementwiseRegion {
            chain, num_inputs, ..
        } = &region.op
        {
            assert_eq!(*num_inputs, 3);
            assert_eq!(chain.len(), 3);
            match &chain[0] {
                ChainStep::Binary(
                    BinaryOp::Add,
                    ChainOperand::Input(0),
                    ChainOperand::Input(1),
                ) => {}
                other => panic!("step 0 unexpected: {other:?}"),
            }
            match &chain[1] {
                ChainStep::Binary(BinaryOp::Mul, ChainOperand::Step(0), ChainOperand::Input(2)) => {
                }
                other => panic!("step 1 unexpected: {other:?}"),
            }
            match &chain[2] {
                ChainStep::Activation(Activation::Relu, ChainOperand::Step(1)) => {}
                other => panic!("step 2 unexpected: {other:?}"),
            }
        } else {
            unreachable!();
        }
        assert!(fused.len() < before);
    }

    /// An intermediate with two consumers blocks region formation entirely.
    #[test]
    fn region_does_not_fuse_when_intermediate_has_multiple_consumers() {
        let f = DType::F32;
        let mut g = Graph::new("ew");
        let a = g.input("a", Shape::new(&[4], f));
        let b = g.input("b", Shape::new(&[4], f));
        let s = Shape::new(&[4], f);
        let add = g.binary(BinaryOp::Add, a, b, s.clone());
        let relu = g.activation(Activation::Relu, add, s.clone());
        let extra = g.activation(Activation::Sigmoid, add, s.clone());
        g.set_outputs(vec![relu, extra]);

        let before = g.len();
        let fused = MarkElementwiseRegions.run(g);
        let regions: Vec<_> = fused
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
            .collect();
        assert_eq!(regions.len(), 0);
        assert_eq!(fused.len(), before);
    }

    /// A lone elementwise op is not wrapped in a region.
    #[test]
    fn region_skips_chains_of_length_one() {
        let f = DType::F32;
        let mut g = Graph::new("ew");
        let a = g.input("a", Shape::new(&[4], f));
        let r = g.activation(Activation::Relu, a, Shape::new(&[4], f));
        g.set_outputs(vec![r]);

        let fused = MarkElementwiseRegions.run(g);
        let any_region = fused
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }));
        assert!(!any_region);
    }

    /// Mark then unfuse restores the original Add, Mul and Relu nodes.
    #[test]
    fn unfuse_decomposes_region_back_to_atomic_ops() {
        let f = DType::F32;
        let mut g = Graph::new("ew_unfuse");
        let a = g.input("a", Shape::new(&[8], f));
        let b = g.input("b", Shape::new(&[8], f));
        let c = g.input("c", Shape::new(&[8], f));
        let s = Shape::new(&[8], f);
        let add = g.binary(BinaryOp::Add, a, b, s.clone());
        let mul = g.binary(BinaryOp::Mul, add, c, s.clone());
        let relu = g.activation(Activation::Relu, mul, s);
        g.set_outputs(vec![relu]);

        let fused = MarkElementwiseRegions.run(g);
        assert!(
            fused
                .nodes()
                .iter()
                .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
        );

        let unfused = UnfuseElementwiseRegions.run(fused);
        assert!(
            !unfused
                .nodes()
                .iter()
                .any(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
        );
        let bin_count = unfused
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::Binary(_)))
            .count();
        let act_count = unfused
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::Activation(_)))
            .count();
        assert_eq!(bin_count, 2, "Add + Mul restored");
        assert_eq!(act_count, 1, "Relu restored");
    }

    /// A 40-step region exceeds the GPU step cap and is decomposed one node per step.
    #[test]
    fn clip_unfuses_region_over_step_cap() {
        let mut g = Graph::new("clip");
        let x = g.input("x", f32_shape(&[4]));
        // Step 0 reads input 0; every later step reads the previous step.
        let mut chain: Vec<ChainStep> = Vec::new();
        let mut prev = ChainOperand::Input(0);
        for _ in 0..40 {
            chain.push(ChainStep::Activation(Activation::Relu, prev));
            prev = ChainOperand::Step(chain.len() as u32 - 1);
        }
        let y = g.add_node(
            Op::ElementwiseRegion {
                chain,
                num_inputs: 1,
                scalar_input_mask: 0,
                input_modulus: [0; 16],
            },
            vec![x],
            f32_shape(&[4]),
        );
        g.set_outputs(vec![y]);

        let clipped = clip_elementwise_regions(g, FusionLimits::GPU_NATIVE);
        assert!(
            !clipped
                .nodes()
                .iter()
                .any(|n| matches!(n.op, Op::ElementwiseRegion { .. })),
            "oversized region should be decomposed"
        );
        let relu_count = clipped
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::Activation(Activation::Relu)))
            .count();
        assert_eq!(relu_count, 40, "one Relu node per chain step");
        assert_eq!(clipped.len(), 41, "x + 40 decomposed steps");
        assert!(matches!(
            clipped.node(clipped.outputs[0]).op,
            Op::Activation(Activation::Relu)
        ));
    }

    /// Clipping a graph without any ElementwiseRegion leaves it unchanged.
    #[test]
    fn clip_is_noop_when_no_region_present() {
        let f = DType::F32;
        let mut g = Graph::new("clip_noop");
        let a = g.input("a", Shape::new(&[4], f));
        let r = g.activation(Activation::Relu, a, Shape::new(&[4], f));
        g.set_outputs(vec![r]);
        let n_before = g.len();
        let result = clip_elementwise_regions(g, FusionLimits::GPU_NATIVE);
        assert_eq!(result.len(), n_before);
        assert!(matches!(
            result.node(result.outputs[0]).op,
            Op::Activation(Activation::Relu)
        ));
    }

    /// Unfusing a graph without any ElementwiseRegion leaves the node count unchanged.
    #[test]
    fn unfuse_is_noop_when_no_region_present() {
        let f = DType::F32;
        let mut g = Graph::new("noop");
        let a = g.input("a", Shape::new(&[4], f));
        let r = g.activation(Activation::Relu, a, Shape::new(&[4], f));
        g.set_outputs(vec![r]);
        let n_before = g.len();
        let result = UnfuseElementwiseRegions.run(g);
        assert_eq!(result.len(), n_before);
    }

    /// Compare -> Where -> Add is captured as one region with a Where step.
    #[test]
    fn region_includes_where_step() {
        let f = DType::F32;
        let mut g = Graph::new("region_where");
        let a = g.input("a", Shape::new(&[4], f));
        let b = g.input("b", Shape::new(&[4], f));
        let s = Shape::new(&[4], f);
        let cmp = g.add_node(Op::Compare(CmpOp::Gt), vec![a, b], s.clone());
        let sel = g.add_node(Op::Where, vec![cmp, a, b], s.clone());
        let add = g.binary(BinaryOp::Add, sel, a, s.clone());
        g.set_outputs(vec![add]);

        let fused = MarkElementwiseRegions.run(g);
        let regions: Vec<_> = fused
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::ElementwiseRegion { .. }))
            .collect();
        assert_eq!(regions.len(), 1, "expected one ElementwiseRegion");
        if let Op::ElementwiseRegion { chain, .. } = &regions[0].op {
            assert_eq!(chain.len(), 3);
            assert!(
                matches!(chain[1], ChainStep::Where(_, _, _)),
                "step 1 should be Where, got {:?}",
                chain[1]
            );
        } else {
            unreachable!();
        }
    }

    /// Unfusing a region containing a Where step re-emits exactly one `Op::Where`.
    #[test]
    fn unfuse_decomposes_where_step_back_to_op_where() {
        let f = DType::F32;
        let mut g = Graph::new("unfuse_where");
        let a = g.input("a", Shape::new(&[4], f));
        let b = g.input("b", Shape::new(&[4], f));
        let s = Shape::new(&[4], f);
        let cmp = g.add_node(Op::Compare(CmpOp::Gt), vec![a, b], s.clone());
        let sel = g.add_node(Op::Where, vec![cmp, a, b], s.clone());
        let add = g.binary(BinaryOp::Add, sel, a, s.clone());
        g.set_outputs(vec![add]);
        let fused = MarkElementwiseRegions.run(g);
        let unfused = UnfuseElementwiseRegions.run(fused);
        let where_count = unfused
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::Where))
            .count();
        assert_eq!(
            where_count, 1,
            "decomposer should re-emit one Op::Where for the chain step"
        );
    }
}