1use rlx_ir::OpKind;
24
25use crate::DeadCodeElimination;
26use rlx_fusion::control_flow::LowerControlFlow;
27use rlx_fusion::fk_fusion::{
28 DecomposeFusionRegions, FuseBatchPreprocess, FuseRegionPrologue, MarkBatchSliceRegions,
29 MarkTransformRegions,
30};
31use rlx_fusion::fusion::{
32 FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
33 FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, FuseTransformerLayer,
34 MarkElementwiseRegions, UnfuseElementwiseRegions,
35};
36use rlx_fusion::limits::FusionLimits;
37use rlx_fusion::lower_dot_general::LowerDotGeneral;
38use rlx_fusion::pass::Pass;
39
/// Backend a fusion pipeline is assembled for.
///
/// The target selects the default supported-op list ([`supported_for_target`]),
/// the default fusion limits ([`fusion_limits_for_target`]) and whether native
/// FK regions are kept by default
/// ([`FusionOptions::apply_native_fk_defaults`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionTarget {
    /// The only target that gets `FusionLimits::UNBOUNDED`, the CPU flavour of
    /// the unfuse pass, and no native-FK-regions default.
    Cpu,
    /// Native FK regions on by default; `fusion_passes` also forces the unfuse
    /// pass unless `keep_elementwise_regions` is set.
    Metal,
    /// Native FK regions on by default.
    Mlx,
    /// Native FK regions on by default; the only built-in op list that claims
    /// `FusedTransformerLayer`.
    Wgpu,
    /// Native FK regions on by default; shares its op list with `Rocm`.
    Cuda,
    /// Native FK regions on by default; shares its op list with `Cuda`.
    Rocm,
    /// Native FK regions on by default; uses tighter elementwise limits
    /// (32 steps / 16 inputs) than the other non-CPU targets.
    Tpu,
}
51
/// Switches that decide which passes end up in a fusion pipeline.
///
/// The struct is `Copy`; the builder-style methods take `self` by value and
/// return the adjusted copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusionOptions {
    /// When set, `fusion_passes_for_supported` returns only the two lowering
    /// passes (control flow and dot-general) and nothing else.
    pub skip_fusion: bool,
    /// When set, the unfuse-elementwise-regions pass is appended even if the
    /// backend claims `ElementwiseRegion`.
    pub unfuse_elementwise_regions: bool,
    /// When set, `fusion_passes` does not force `unfuse_elementwise_regions`
    /// on for the CPU and Metal targets.
    pub keep_elementwise_regions: bool,
    /// When set, `DecomposeFusionRegions` is appended even if native FK
    /// regions would otherwise be kept.
    pub decompose_fusion_regions: bool,
    /// Enables the FK marking passes (`MarkBatchSliceRegions`,
    /// `MarkTransformRegions`) and the two optional passes below.
    pub fk_fusion: bool,
    /// Adds `FuseRegionPrologue`; only consulted when `fk_fusion` is set.
    pub fuse_region_prologue: bool,
    /// Adds `FuseBatchPreprocess`; only consulted when `fk_fusion` is set.
    pub fuse_batch_preprocess: bool,
    /// Keep transform / batch-elementwise regions instead of decomposing them,
    /// provided the backend claims both `TransformRegion` and
    /// `BatchElementwiseRegion`.
    pub native_fk_regions: bool,
    /// Elementwise fusion limits. `fusion_passes` swaps a default value for
    /// the per-target limits.
    /// NOTE(review): pass selection in this file never reads this field —
    /// confirm where the limits are consumed.
    pub fusion_limits: FusionLimits,
}
74
75impl Default for FusionOptions {
76 fn default() -> Self {
77 Self {
78 skip_fusion: false,
79 unfuse_elementwise_regions: false,
80 keep_elementwise_regions: false,
81 decompose_fusion_regions: false,
82 fk_fusion: true,
83 fuse_region_prologue: true,
84 fuse_batch_preprocess: true,
85 native_fk_regions: false,
86 fusion_limits: FusionLimits::default(),
87 }
88 }
89}
90
91impl FusionOptions {
92 pub fn from_metal_env() -> Self {
94 Self {
95 skip_fusion: rlx_ir::env::flag("RLX_METAL_NO_FUSION"),
96 unfuse_elementwise_regions: rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS"),
97 keep_elementwise_regions: rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS"),
98 decompose_fusion_regions: rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS"),
99 fk_fusion: !rlx_ir::env::flag("RLX_NO_FK_FUSION"),
100 fuse_region_prologue: if rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
101 true
102 } else {
103 rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE")
104 },
105 fuse_batch_preprocess: if rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
106 true
107 } else {
108 rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS")
109 },
110 native_fk_regions: rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS"),
111 ..Self::default()
112 }
113 }
114
115 pub fn merge_env(mut self) -> Self {
117 if rlx_ir::env::flag("RLX_METAL_NO_FUSION") {
118 self.skip_fusion = true;
119 }
120 if rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS") {
121 self.unfuse_elementwise_regions = true;
122 }
123 if rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS") {
124 self.keep_elementwise_regions = true;
125 }
126 if rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS") {
127 self.decompose_fusion_regions = true;
128 }
129 if rlx_ir::env::flag("RLX_NO_FK_FUSION") {
130 self.fk_fusion = false;
131 }
132 if !rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
133 self.fuse_region_prologue = rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE");
134 }
135 if !rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
136 self.fuse_batch_preprocess = rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS");
137 }
138 if rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
139 self.native_fk_regions = true;
140 }
141 if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
142 self.native_fk_regions = false;
143 }
144 self
145 }
146
147 pub fn apply_native_fk_defaults(mut self, target: FusionTarget) -> Self {
149 if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
150 self.native_fk_regions = false;
151 return self;
152 }
153 if self.native_fk_regions || rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
154 self.native_fk_regions = true;
155 return self;
156 }
157 if matches!(
158 target,
159 FusionTarget::Metal
160 | FusionTarget::Cuda
161 | FusionTarget::Rocm
162 | FusionTarget::Wgpu
163 | FusionTarget::Mlx
164 | FusionTarget::Tpu
165 ) {
166 self.native_fk_regions = true;
167 }
168 self
169 }
170
171 pub fn for_cpu() -> Self {
173 Self {
174 unfuse_elementwise_regions: true,
175 fusion_limits: FusionLimits::UNBOUNDED,
176 ..Self::default()
177 }
178 }
179
180 pub fn for_metal() -> Self {
183 let mut opts = Self::from_metal_env();
184 opts.unfuse_elementwise_regions = true;
185 opts
186 }
187
188 pub fn for_wgpu() -> Self {
191 let keep = rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS");
192 Self {
193 unfuse_elementwise_regions: !keep,
194 keep_elementwise_regions: keep,
195 ..Self::default()
196 }
197 }
198}
199
200pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
202 match target {
203 FusionTarget::Cpu => FusionLimits::UNBOUNDED,
204 FusionTarget::Tpu => FusionLimits {
205 max_elementwise_steps: 32,
206 max_elementwise_inputs: 16,
207 },
208 _ => FusionLimits::GPU_NATIVE,
209 }
210}
211
212#[inline]
214pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
215 supported.is_empty() || supported.contains(&kind)
216}
217
218pub fn fusion_passes_for_supported(
227 supported: &[OpKind],
228 opts: FusionOptions,
229 target: FusionTarget,
230) -> Vec<&'static dyn Pass> {
231 let opts = opts.apply_native_fk_defaults(target);
232 if opts.skip_fusion {
233 return vec![&LowerControlFlow, &LowerDotGeneral];
234 }
235
236 let mut passes: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
237
238 if supports_op(supported, OpKind::FusedMatMulBiasAct) {
244 passes.push(&FuseMatMulBiasAct);
245 }
246 if supports_op(supported, OpKind::FusedAttentionBlock) {
254 passes.push(&FuseAttentionBlock);
255 }
256 if supports_op(supported, OpKind::FusedResidualLN) {
260 passes.push(&FuseResidualLN);
261 }
262 if supports_op(supported, OpKind::FusedResidualRmsNorm) {
263 passes.push(&FuseResidualRmsNorm);
264 }
265 passes.push(&FuseRmsNormReshape);
266
267 if rlx_ir::env::flag("RLX_ENABLE_FUSE_TRANSFORMER_LAYER")
276 && supports_op(supported, OpKind::FusedTransformerLayer)
277 && supports_op(supported, OpKind::FusedAttentionBlock)
278 {
279 passes.push(&FuseTransformerLayer);
280 }
281
282 if supports_op(supported, OpKind::FusedSwiGLU) {
283 passes.push(&FuseSwiGLUDualMatmul);
284 }
285 if supports_op(supported, OpKind::MatMul) {
286 passes.push(&FuseSharedInputMatMul);
287 }
288 if supports_op(supported, OpKind::FusedSwiGLU) {
289 passes.push(&FuseSwiGLU);
290 }
291
292 passes.push(&MarkElementwiseRegions);
295 if opts.fk_fusion {
296 passes.push(&MarkBatchSliceRegions);
297 passes.push(&MarkTransformRegions);
298 if opts.fuse_region_prologue {
299 passes.push(&FuseRegionPrologue);
300 }
301 if opts.fuse_batch_preprocess {
302 passes.push(&FuseBatchPreprocess);
303 }
304 }
305 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
306 && supports_op(supported, OpKind::BatchElementwiseRegion);
307 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
308 if opts.decompose_fusion_regions || !keep_native_fk {
309 passes.push(&DecomposeFusionRegions);
310 }
311 let keep_regions =
312 supports_op(supported, OpKind::ElementwiseRegion) && !opts.unfuse_elementwise_regions;
313 if !keep_regions {
314 let unfuse = if matches!(target, FusionTarget::Cpu) {
315 &UnfuseElementwiseRegions::FOR_CPU
316 } else {
317 &UnfuseElementwiseRegions::FOR_GPU
318 };
319 passes.push(unfuse);
320 }
321
322 finish_pipeline(passes)
323}
324
325pub fn fk_passes_after_elementwise_regions(
327 supported: &[OpKind],
328 opts: FusionOptions,
329) -> Vec<&'static dyn Pass> {
330 let mut passes: Vec<&'static dyn Pass> = Vec::new();
331 if !opts.fk_fusion {
332 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
333 && supports_op(supported, OpKind::BatchElementwiseRegion);
334 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
335 if opts.decompose_fusion_regions || !keep_native_fk {
336 passes.push(&DecomposeFusionRegions);
337 }
338 return finish_pipeline(passes);
339 }
340 passes.push(&MarkBatchSliceRegions);
341 passes.push(&MarkTransformRegions);
342 if opts.fuse_region_prologue {
343 passes.push(&FuseRegionPrologue);
344 }
345 if opts.fuse_batch_preprocess {
346 passes.push(&FuseBatchPreprocess);
347 }
348 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
349 && supports_op(supported, OpKind::BatchElementwiseRegion);
350 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
351 if opts.decompose_fusion_regions || !keep_native_fk {
352 passes.push(&DecomposeFusionRegions);
353 }
354 finish_pipeline(passes)
355}
356
357pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
359 let mut opts = opts;
360 if !opts.keep_elementwise_regions
364 && matches!(target, FusionTarget::Cpu | FusionTarget::Metal)
365 && !opts.unfuse_elementwise_regions
366 {
367 opts.unfuse_elementwise_regions = true;
368 }
369 if opts.fusion_limits == FusionLimits::default() {
370 opts.fusion_limits = fusion_limits_for_target(target);
371 }
372 opts = opts.apply_native_fk_defaults(target);
373 fusion_passes_for_supported(supported_for_target(target), opts, target)
374}
375
376pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
380 use OpKind::*;
381 match target {
382 FusionTarget::Cpu => &[
383 MatMul,
384 DotGeneral,
385 ElementwiseRegion,
386 FusedSwiGLU,
387 FusedMatMulBiasAct,
388 FusedResidualLN,
389 FusedResidualRmsNorm,
390 FusedAttentionBlock,
391 ],
392 FusionTarget::Metal => &[
393 MatMul,
394 DotGeneral,
395 ElementwiseRegion,
396 TransformRegion,
397 BatchElementwiseRegion,
398 FusedSwiGLU,
399 FusedMatMulBiasAct,
400 FusedResidualLN,
401 FusedResidualRmsNorm,
402 ],
403 FusionTarget::Mlx => &[
404 MatMul,
405 DotGeneral,
406 ElementwiseRegion,
407 TransformRegion,
408 BatchElementwiseRegion,
409 FusedSwiGLU,
410 FusedMatMulBiasAct,
411 FusedResidualLN,
412 FusedResidualRmsNorm,
413 ],
414 FusionTarget::Wgpu => &[
415 MatMul,
416 ElementwiseRegion,
417 TransformRegion,
418 BatchElementwiseRegion,
419 FusedSwiGLU,
420 FusedMatMulBiasAct,
421 FusedResidualLN,
422 FusedResidualRmsNorm,
423 FusedAttentionBlock,
424 FusedTransformerLayer,
425 ],
426 FusionTarget::Cuda | FusionTarget::Rocm => &[
427 MatMul,
428 DotGeneral,
429 ElementwiseRegion,
430 TransformRegion,
431 BatchElementwiseRegion,
432 FusedMatMulBiasAct,
433 FusedResidualLN,
434 FusedResidualRmsNorm,
435 ],
436 FusionTarget::Tpu => &[
437 MatMul,
438 ElementwiseRegion,
439 TransformRegion,
440 BatchElementwiseRegion,
441 FusedMatMulBiasAct,
442 FusedResidualLN,
443 ],
444 }
445}
446
447fn finish_pipeline(mut passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
448 passes.push(&DeadCodeElimination);
449 passes
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises tests whose assertions depend on the native-FK environment
    /// variables. One test mutates `RLX_NO_NATIVE_FK_REGIONS`, and the test
    /// harness runs tests on parallel threads by default, so every test that
    /// asserts on native-FK / decompose behaviour must hold this lock.
    static ENV_FK_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Takes the environment lock, ignoring poisoning: a failed assertion in
    /// one test must not cascade into unrelated failures in the others.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_FK_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets an environment flag to `"1"` and unsets it again on drop, so the
    /// variable is restored even if the guarded code panics.
    struct ScopedEnvFlag(&'static str);

    impl ScopedEnvFlag {
        fn enable(name: &'static str) -> Self {
            rlx_ir::env::set(name, "1");
            Self(name)
        }
    }

    impl Drop for ScopedEnvFlag {
        fn drop(&mut self) {
            rlx_ir::env::unset(self.0);
        }
    }

    /// True when `passes` contains a pass called `name`.
    fn has_pass(passes: &[&'static dyn Pass], name: &str) -> bool {
        passes.iter().any(|p| p.name() == name)
    }

    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let passes = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(passes.len(), 18);
        assert_eq!(passes[2].name(), "fuse_matmul_bias_act");
        assert_eq!(passes[3].name(), "fuse_attention_block");
        assert!(
            has_pass(&passes, "fuse_region_prologue"),
            "default CPU pipeline should run FKL prologue fusion"
        );
        assert_eq!(passes.last().unwrap().name(), "dead_code_elimination");
    }

    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                skip_fusion: true,
                ..FusionOptions::default()
            },
        );
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "LowerControlFlow");
        assert_eq!(passes[1].name(), "lower_dot_general");
    }

    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions::default(),
            FusionTarget::Metal,
        );
        assert!(
            !has_pass(&passes, "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            has_pass(&passes, "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cuda),
            FusionOptions::default(),
            FusionTarget::Cuda,
        );
        assert!(
            has_pass(&passes, "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !has_pass(&passes, "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cpu),
            FusionOptions::for_cpu(),
            FusionTarget::Cpu,
        );
        assert!(has_pass(&passes, "unfuse_elementwise_regions"));
    }

    #[test]
    fn metal_unfuses_elementwise_regions_by_default() {
        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        assert!(has_pass(&passes, "unfuse_elementwise_regions"));
    }

    #[test]
    fn metal_default_unfuse_preserves_prologue_regions() {
        // The pipeline (decompose or not) depends on the native-FK env flags.
        let _lock = lock_env();
        let mut g = rlx_ir::Graph::new("t");
        let shape_in = rlx_ir::Shape::new(&[1, 3, 8, 8], rlx_ir::DType::F32);
        let shape_out = rlx_ir::Shape::new(&[1, 3, 16, 16], rlx_ir::DType::F32);
        let x = g.input("x", shape_in);
        let up = g.add_node(rlx_ir::Op::ResizeNearest2x, vec![x], shape_out.clone());
        let r = g.add_node(
            rlx_ir::Op::Activation(rlx_ir::op::Activation::Relu),
            vec![up],
            shape_out,
        );
        g.set_outputs(vec![r]);

        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        let out = rlx_fusion::pass::run_passes(g, &passes, false);
        assert!(out.nodes().iter().any(|n| {
            matches!(
                n.op,
                rlx_ir::Op::ElementwiseRegion {
                    prologue: rlx_ir::RegionPrologue::ResizeNearest2x,
                    ..
                }
            )
        }));
    }

    #[test]
    fn fk_passes_after_elementwise_includes_batch_fusion() {
        let _lock = lock_env();
        let opts = FusionOptions::default().apply_native_fk_defaults(FusionTarget::Tpu);
        let passes =
            fk_passes_after_elementwise_regions(supported_for_target(FusionTarget::Tpu), opts);
        let names: Vec<_> = passes.iter().map(|p| p.name()).collect();
        assert!(names.contains(&"mark_batch_slice_regions"));
        assert!(names.contains(&"fuse_batch_preprocess"));
        assert!(
            !names.contains(&"decompose_fusion_regions"),
            "TPU native FK defaults should keep batch/transform regions"
        );
    }

    #[test]
    fn tpu_native_fk_region_pass_policy() {
        let _lock = lock_env();
        let default_passes = fusion_passes(FusionTarget::Tpu, FusionOptions::default());
        assert!(
            !has_pass(&default_passes, "decompose_fusion_regions"),
            "default TPU pipeline keeps batch/transform regions via native_fk_defaults"
        );

        let opt_out = {
            // Restored when the guard drops, even if building the pipeline panics.
            let _flag = ScopedEnvFlag::enable("RLX_NO_NATIVE_FK_REGIONS");
            fusion_passes(FusionTarget::Tpu, FusionOptions::default())
        };
        assert!(
            has_pass(&opt_out, "decompose_fusion_regions"),
            "RLX_NO_NATIVE_FK_REGIONS should force decompose on TPU"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_tpu() {
        let _lock = lock_env();
        let passes = fusion_passes(
            FusionTarget::Tpu,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "decompose_fusion_regions"),
            "native_fk_regions should skip decompose on TPU when batch/transform are supported"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_metal() {
        let _lock = lock_env();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "decompose_fusion_regions"),
            "native_fk_regions should skip decompose when backend claims batch/transform ops"
        );
    }

    #[test]
    fn metal_keeps_elementwise_regions_when_requested() {
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                keep_elementwise_regions: true,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "unfuse_elementwise_regions"),
            "keep_elementwise_regions should skip unfuse pass"
        );
        assert!(
            has_pass(&passes, "fuse_region_prologue"),
            "FKL prologue fusion should still run"
        );
    }
}