1use rlx_ir::OpKind;
24
25use crate::DeadCodeElimination;
26use rlx_fusion::control_flow::LowerControlFlow;
27use rlx_fusion::fk_fusion::{
28 DecomposeFusionRegions, FuseBatchPreprocess, FuseRegionPrologue, MarkBatchSliceRegions,
29 MarkTransformRegions,
30};
31use rlx_fusion::fusion::{
32 FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
33 FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, FuseTransformerLayer,
34 MarkElementwiseRegions, UnfuseElementwiseRegions,
35};
36use rlx_fusion::limits::FusionLimits;
37use rlx_fusion::lower_dot_general::LowerDotGeneral;
38use rlx_fusion::pass::Pass;
39
/// Backend that a fusion pipeline is assembled for.
///
/// The target picks the claimed op set ([`supported_for_target`]), the fusion
/// limits ([`fusion_limits_for_target`]) and the I/O fusion cost model
/// ([`io_fusion_gate_for_target`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FusionTarget {
    /// The only target with unbounded fusion limits and without native FK
    /// regions by default.
    Cpu,
    Metal,
    Mlx,
    Wgpu,
    Cuda,
    Rocm,
    /// Uses its own explicit elementwise fusion limits.
    Tpu,
}
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq)]
54pub struct FusionOptions {
55 pub skip_fusion: bool,
57 pub unfuse_elementwise_regions: bool,
59 pub keep_elementwise_regions: bool,
61 pub decompose_fusion_regions: bool,
63 pub fk_fusion: bool,
65 pub fuse_region_prologue: bool,
67 pub fuse_batch_preprocess: bool,
69 pub native_fk_regions: bool,
71 pub fusion_limits: FusionLimits,
73}
74
75impl Default for FusionOptions {
76 fn default() -> Self {
77 Self {
78 skip_fusion: false,
79 unfuse_elementwise_regions: false,
80 keep_elementwise_regions: false,
81 decompose_fusion_regions: false,
82 fk_fusion: true,
83 fuse_region_prologue: true,
84 fuse_batch_preprocess: true,
85 native_fk_regions: false,
86 fusion_limits: FusionLimits::default(),
87 }
88 }
89}
90
91impl FusionOptions {
92 pub fn from_metal_env() -> Self {
94 Self {
95 skip_fusion: rlx_ir::env::flag("RLX_METAL_NO_FUSION"),
96 unfuse_elementwise_regions: rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS"),
97 keep_elementwise_regions: rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS"),
98 decompose_fusion_regions: rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS"),
99 fk_fusion: !rlx_ir::env::flag("RLX_NO_FK_FUSION"),
100 fuse_region_prologue: if rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
101 true
102 } else {
103 rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE")
104 },
105 fuse_batch_preprocess: if rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
106 true
107 } else {
108 rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS")
109 },
110 native_fk_regions: rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS"),
111 ..Self::default()
112 }
113 }
114
115 pub fn merge_env(mut self) -> Self {
117 if rlx_ir::env::flag("RLX_METAL_NO_FUSION") {
118 self.skip_fusion = true;
119 }
120 if rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS") {
121 self.unfuse_elementwise_regions = true;
122 }
123 if rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS") {
124 self.keep_elementwise_regions = true;
125 }
126 if rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS") {
127 self.decompose_fusion_regions = true;
128 }
129 if rlx_ir::env::flag("RLX_NO_FK_FUSION") {
130 self.fk_fusion = false;
131 }
132 if !rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
133 self.fuse_region_prologue = rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE");
134 }
135 if !rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
136 self.fuse_batch_preprocess = rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS");
137 }
138 if rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
139 self.native_fk_regions = true;
140 }
141 if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
142 self.native_fk_regions = false;
143 }
144 self
145 }
146
147 pub fn apply_native_fk_defaults(mut self, target: FusionTarget) -> Self {
149 if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
150 self.native_fk_regions = false;
151 return self;
152 }
153 if self.native_fk_regions || rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
154 self.native_fk_regions = true;
155 return self;
156 }
157 if matches!(
158 target,
159 FusionTarget::Metal
160 | FusionTarget::Cuda
161 | FusionTarget::Rocm
162 | FusionTarget::Wgpu
163 | FusionTarget::Mlx
164 | FusionTarget::Tpu
165 ) {
166 self.native_fk_regions = true;
167 }
168 self
169 }
170
171 pub fn for_cpu() -> Self {
173 Self {
174 unfuse_elementwise_regions: true,
175 fusion_limits: FusionLimits::UNBOUNDED,
176 ..Self::default()
177 }
178 }
179
180 pub fn for_metal() -> Self {
183 let mut opts = Self::from_metal_env();
184 opts.unfuse_elementwise_regions = true;
185 opts
186 }
187
188 pub fn for_wgpu() -> Self {
191 let keep = rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS");
192 Self {
193 unfuse_elementwise_regions: !keep,
194 keep_elementwise_regions: keep,
195 ..Self::default()
196 }
197 }
198}
199
200pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
202 match target {
203 FusionTarget::Cpu => FusionLimits::UNBOUNDED,
204 FusionTarget::Tpu => FusionLimits {
205 max_elementwise_steps: 32,
206 max_elementwise_inputs: 16,
207 },
208 _ => FusionLimits::GPU_NATIVE,
209 }
210}
211
212#[inline]
214pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
215 supported.is_empty() || supported.contains(&kind)
216}
217
218pub fn fusion_passes_for_supported(
227 supported: &[OpKind],
228 opts: FusionOptions,
229 target: FusionTarget,
230) -> Vec<&'static dyn Pass> {
231 let opts = opts.apply_native_fk_defaults(target);
232 if opts.skip_fusion {
233 return vec![&LowerControlFlow, &LowerDotGeneral];
234 }
235
236 let mut passes: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
237
238 if supports_op(supported, OpKind::FusedMatMulBiasAct) {
244 passes.push(&FuseMatMulBiasAct);
245 }
246 if supports_op(supported, OpKind::FusedAttentionBlock) {
254 passes.push(&FuseAttentionBlock);
255 }
256 if supports_op(supported, OpKind::FusedResidualLN) {
260 passes.push(&FuseResidualLN);
261 }
262 if supports_op(supported, OpKind::FusedResidualRmsNorm) {
263 passes.push(&FuseResidualRmsNorm);
264 }
265 passes.push(&FuseRmsNormReshape);
266
267 if rlx_ir::env::flag("RLX_ENABLE_FUSE_TRANSFORMER_LAYER")
276 && supports_op(supported, OpKind::FusedTransformerLayer)
277 && supports_op(supported, OpKind::FusedAttentionBlock)
278 {
279 passes.push(&FuseTransformerLayer);
280 }
281
282 if supports_op(supported, OpKind::FusedSwiGLU) {
283 passes.push(&FuseSwiGLUDualMatmul);
284 }
285 if supports_op(supported, OpKind::MatMul) {
286 passes.push(&FuseSharedInputMatMul);
287 }
288 if supports_op(supported, OpKind::FusedSwiGLU) {
289 passes.push(&FuseSwiGLU);
290 }
291
292 passes.push(&MarkElementwiseRegions);
295 if opts.fk_fusion {
296 passes.push(&MarkBatchSliceRegions);
297 passes.push(&MarkTransformRegions);
298 if opts.fuse_region_prologue {
299 passes.push(&FuseRegionPrologue);
300 }
301 if opts.fuse_batch_preprocess {
302 passes.push(&FuseBatchPreprocess);
303 }
304 }
305 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
306 && supports_op(supported, OpKind::BatchElementwiseRegion);
307 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
308 if opts.decompose_fusion_regions || !keep_native_fk {
309 passes.push(&DecomposeFusionRegions);
310 }
311 let keep_regions =
312 supports_op(supported, OpKind::ElementwiseRegion) && !opts.unfuse_elementwise_regions;
313 if !keep_regions {
314 let unfuse = if matches!(target, FusionTarget::Cpu) {
315 &UnfuseElementwiseRegions::FOR_CPU
316 } else {
317 &UnfuseElementwiseRegions::FOR_GPU
318 };
319 passes.push(unfuse);
320 }
321
322 finish_pipeline(passes)
323}
324
325pub fn fk_passes_after_elementwise_regions(
327 supported: &[OpKind],
328 opts: FusionOptions,
329) -> Vec<&'static dyn Pass> {
330 let mut passes: Vec<&'static dyn Pass> = Vec::new();
331 if !opts.fk_fusion {
332 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
333 && supports_op(supported, OpKind::BatchElementwiseRegion);
334 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
335 if opts.decompose_fusion_regions || !keep_native_fk {
336 passes.push(&DecomposeFusionRegions);
337 }
338 return finish_pipeline(passes);
339 }
340 passes.push(&MarkBatchSliceRegions);
341 passes.push(&MarkTransformRegions);
342 if opts.fuse_region_prologue {
343 passes.push(&FuseRegionPrologue);
344 }
345 if opts.fuse_batch_preprocess {
346 passes.push(&FuseBatchPreprocess);
347 }
348 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
349 && supports_op(supported, OpKind::BatchElementwiseRegion);
350 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
351 if opts.decompose_fusion_regions || !keep_native_fk {
352 passes.push(&DecomposeFusionRegions);
353 }
354 finish_pipeline(passes)
355}
356
357pub fn io_fusion_gate_for_target(target: FusionTarget) -> crate::fusion_benefit::IoFusionGate {
359 use crate::fusion_benefit::IoFusionGate;
360 match target {
361 FusionTarget::Metal | FusionTarget::Mlx => IoFusionGate {
362 dispatch_ns: 500.0,
363 roundtrip_ns: 5_000.0,
364 memory_bw: 200.0,
365 host_readback_bw: 200.0,
366 unified_memory: true,
367 min_gain_ns: 1_000.0,
368 },
369 FusionTarget::Cuda | FusionTarget::Rocm => IoFusionGate {
370 dispatch_ns: 2_000.0,
371 roundtrip_ns: 20_000.0,
372 memory_bw: 800.0,
373 host_readback_bw: 50.0,
374 unified_memory: false,
375 min_gain_ns: 5_000.0,
376 },
377 FusionTarget::Wgpu | FusionTarget::Tpu => IoFusionGate {
378 dispatch_ns: 3_000.0,
379 roundtrip_ns: 30_000.0,
380 memory_bw: 100.0,
381 host_readback_bw: 40.0,
382 unified_memory: false,
383 min_gain_ns: 10_000.0,
384 },
385 FusionTarget::Cpu => IoFusionGate {
386 dispatch_ns: 50.0,
387 roundtrip_ns: 0.0,
388 memory_bw: 50.0,
389 host_readback_bw: 50.0,
390 unified_memory: true,
391 min_gain_ns: 0.0,
392 },
393 }
394}
395
396pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
398 let mut opts = opts;
399 if !opts.keep_elementwise_regions
403 && matches!(target, FusionTarget::Cpu | FusionTarget::Metal)
404 && !opts.unfuse_elementwise_regions
405 {
406 opts.unfuse_elementwise_regions = true;
407 }
408 if opts.fusion_limits == FusionLimits::default() {
409 opts.fusion_limits = fusion_limits_for_target(target);
410 }
411 opts = opts.apply_native_fk_defaults(target);
412 fusion_passes_for_supported(supported_for_target(target), opts, target)
413}
414
415pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
419 use OpKind::*;
420 match target {
421 FusionTarget::Cpu => &[
422 MatMul,
423 DotGeneral,
424 ElementwiseRegion,
425 FusedSwiGLU,
426 FusedMatMulBiasAct,
427 FusedResidualLN,
428 FusedResidualRmsNorm,
429 FusedAttentionBlock,
430 ],
431 FusionTarget::Metal => &[
432 MatMul,
433 DotGeneral,
434 ElementwiseRegion,
435 TransformRegion,
436 BatchElementwiseRegion,
437 FusedSwiGLU,
438 FusedMatMulBiasAct,
439 FusedResidualLN,
440 FusedResidualRmsNorm,
441 ],
442 FusionTarget::Mlx => &[
443 MatMul,
444 DotGeneral,
445 ElementwiseRegion,
446 TransformRegion,
447 BatchElementwiseRegion,
448 FusedSwiGLU,
449 FusedMatMulBiasAct,
450 FusedResidualLN,
451 FusedResidualRmsNorm,
452 ],
453 FusionTarget::Wgpu => &[
454 MatMul,
455 ElementwiseRegion,
456 TransformRegion,
457 BatchElementwiseRegion,
458 FusedSwiGLU,
459 FusedMatMulBiasAct,
460 FusedResidualLN,
461 FusedResidualRmsNorm,
462 FusedAttentionBlock,
463 FusedTransformerLayer,
464 ],
465 FusionTarget::Cuda | FusionTarget::Rocm => &[
466 MatMul,
467 DotGeneral,
468 ElementwiseRegion,
469 TransformRegion,
470 BatchElementwiseRegion,
471 FusedMatMulBiasAct,
472 FusedResidualLN,
473 FusedResidualRmsNorm,
474 ],
475 FusionTarget::Tpu => &[
476 MatMul,
477 ElementwiseRegion,
478 TransformRegion,
479 BatchElementwiseRegion,
480 FusedMatMulBiasAct,
481 FusedResidualLN,
482 ],
483 }
484}
485
486fn finish_pipeline(mut passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
487 passes.push(&DeadCodeElimination);
488 passes
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Guards the process-wide `RLX_NO_NATIVE_FK_REGIONS` flag that
    /// `tpu_native_fk_region_pass_policy` toggles.
    static ENV_FK_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Takes the env lock for the rest of the calling test.
    ///
    /// The test harness runs tests on parallel threads, so every test whose
    /// outcome depends on the native-FK env flag must hold this guard, not
    /// just the one that mutates the flag. Poisoning is ignored: the guarded
    /// value is `()`, and one failing test should not cascade into the others.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_FK_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let passes = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(passes.len(), 18);
        assert_eq!(passes[2].name(), "fuse_matmul_bias_act");
        assert_eq!(passes[3].name(), "fuse_attention_block");
        assert!(
            passes.iter().any(|p| p.name() == "fuse_region_prologue"),
            "default CPU pipeline should run FKL prologue fusion"
        );
        assert_eq!(passes.last().unwrap().name(), "dead_code_elimination");
    }

    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                skip_fusion: true,
                ..FusionOptions::default()
            },
        );
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "LowerControlFlow");
        assert_eq!(passes[1].name(), "lower_dot_general");
    }

    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions::default(),
            FusionTarget::Metal,
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cuda),
            FusionOptions::default(),
            FusionTarget::Cuda,
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cpu),
            FusionOptions::for_cpu(),
            FusionTarget::Cpu,
        );
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn metal_unfuses_elementwise_regions_by_default() {
        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn metal_default_unfuse_preserves_prologue_regions() {
        // The executed pipeline's composition depends on the native-FK env
        // flag, so keep it stable while the graph is being rewritten.
        let _env = env_guard();
        let mut g = rlx_ir::Graph::new("t");
        let shape_in = rlx_ir::Shape::new(&[1, 3, 8, 8], rlx_ir::DType::F32);
        let shape_out = rlx_ir::Shape::new(&[1, 3, 16, 16], rlx_ir::DType::F32);
        let x = g.input("x", shape_in);
        let up = g.add_node(rlx_ir::Op::ResizeNearest2x, vec![x], shape_out.clone());
        let r = g.add_node(
            rlx_ir::Op::Activation(rlx_ir::op::Activation::Relu),
            vec![up],
            shape_out,
        );
        g.set_outputs(vec![r]);

        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        let out = rlx_fusion::pass::run_passes(g, &passes, false);
        assert!(out.nodes().iter().any(|n| {
            matches!(
                n.op,
                rlx_ir::Op::ElementwiseRegion {
                    prologue: rlx_ir::RegionPrologue::ResizeNearest2x,
                    ..
                }
            )
        }));
    }

    #[test]
    fn fk_passes_after_elementwise_includes_batch_fusion() {
        // `apply_native_fk_defaults` reads RLX_NO_NATIVE_FK_REGIONS.
        let _env = env_guard();
        let opts = FusionOptions::default().apply_native_fk_defaults(FusionTarget::Tpu);
        let passes =
            fk_passes_after_elementwise_regions(supported_for_target(FusionTarget::Tpu), opts);
        let names: Vec<_> = passes.iter().map(|p| p.name()).collect();
        assert!(names.contains(&"mark_batch_slice_regions"));
        assert!(names.contains(&"fuse_batch_preprocess"));
        assert!(
            !names.contains(&"decompose_fusion_regions"),
            "TPU native FK defaults should keep batch/transform regions"
        );
    }

    #[test]
    fn tpu_native_fk_region_pass_policy() {
        let _env = env_guard();
        let default_passes = fusion_passes(FusionTarget::Tpu, FusionOptions::default());
        assert!(
            !default_passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "default TPU pipeline keeps batch/transform regions via native_fk_defaults"
        );

        // Restore the flag before asserting so a failure cannot leak it.
        rlx_ir::env::set("RLX_NO_NATIVE_FK_REGIONS", "1");
        let opt_out = fusion_passes(FusionTarget::Tpu, FusionOptions::default());
        rlx_ir::env::unset("RLX_NO_NATIVE_FK_REGIONS");
        assert!(
            opt_out
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "RLX_NO_NATIVE_FK_REGIONS should force decompose on TPU"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_tpu() {
        // RLX_NO_NATIVE_FK_REGIONS overrides `native_fk_regions: true`.
        let _env = env_guard();
        let passes = fusion_passes(
            FusionTarget::Tpu,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "native_fk_regions should skip decompose on TPU when batch/transform are supported"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_metal() {
        // RLX_NO_NATIVE_FK_REGIONS overrides `native_fk_regions: true`.
        let _env = env_guard();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "native_fk_regions should skip decompose when backend claims batch/transform ops"
        );
    }

    #[test]
    fn metal_keeps_elementwise_regions_when_requested() {
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                keep_elementwise_regions: true,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions"),
            "keep_elementwise_regions should skip unfuse pass"
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_region_prologue"),
            "FKL prologue fusion should still run"
        );
    }
}