1use rlx_ir::OpKind;
24
25use crate::DeadCodeElimination;
26use crate::io_output_gate::SelectPeaksOnlyOutputs;
27use rlx_fusion::control_flow::LowerControlFlow;
28use rlx_fusion::fk_fusion::{
29 DecomposeFusionRegions, FuseBatchPreprocess, FuseRegionPrologue, MarkBatchSliceRegions,
30 MarkTransformRegions,
31};
32use rlx_fusion::fusion::{
33 FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
34 FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, FuseTransformerLayer,
35 MarkElementwiseRegions, UnfuseElementwiseRegions,
36};
37use rlx_fusion::limits::{FusionLimits, with_fusion_limits};
38use rlx_fusion::lower_dot_general::LowerDotGeneral;
39use rlx_fusion::pass::{Pass, run_passes};
40use rlx_ir::Graph;
41
42use crate::fusion_target::with_fusion_target;
43
/// Backend a fusion pipeline is being assembled for.
///
/// The variant selects the per-target behaviour implemented elsewhere in this
/// file: the fusion limits (`fusion_limits_for_target`), the set of claimed
/// op kinds (`supported_for_target`), the IO fusion gate constants
/// (`io_fusion_gate_for_target`) and whether native FK regions are kept by
/// default (`FusionOptions::apply_native_fk_defaults`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionTarget {
    /// The only target given unbounded fusion limits and the CPU flavour of
    /// the unfuse pass; native FK regions are not enabled by default for it.
    Cpu,
    /// Shares its claimed-op table and IO gate constants with `Mlx`.
    Metal,
    /// Shares its claimed-op table and IO gate constants with `Metal`.
    Mlx,
    /// Shares its IO gate constants with `Tpu`.
    Wgpu,
    /// Shares its claimed-op table and IO gate constants with `Rocm`.
    Cuda,
    /// Shares its claimed-op table and IO gate constants with `Cuda`.
    Rocm,
    /// Has its own explicit fusion limits; shares IO gate constants with `Wgpu`.
    Tpu,
}
55
/// Switches controlling which passes the pipeline builders in this file emit.
///
/// The `Default` impl enables FK fusion together with its prologue and
/// batch-preprocess stages and leaves every other switch off.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusionOptions {
    /// When set, `fusion_passes_for_supported` returns only the
    /// `LowerControlFlow` and `LowerDotGeneral` passes.
    pub skip_fusion: bool,
    /// When set, `MarkElementwiseRegions` is not scheduled and an
    /// `UnfuseElementwiseRegions` pass is appended instead.
    pub unfuse_elementwise_regions: bool,
    /// Opts out of the automatic unfuse that `fusion_passes` turns on for the
    /// CPU and Metal targets.
    pub keep_elementwise_regions: bool,
    /// Schedules `DecomposeFusionRegions` even when native FK regions would
    /// otherwise be kept.
    pub decompose_fusion_regions: bool,
    /// Schedules `MarkBatchSliceRegions` and `MarkTransformRegions`, plus the
    /// two optional stages below.
    pub fk_fusion: bool,
    /// With `fk_fusion`, additionally schedules `FuseRegionPrologue`.
    pub fuse_region_prologue: bool,
    /// With `fk_fusion`, additionally schedules `FuseBatchPreprocess`.
    pub fuse_batch_preprocess: bool,
    /// Skips `DecomposeFusionRegions` when the backend also claims both
    /// `TransformRegion` and `BatchElementwiseRegion`.
    pub native_fk_regions: bool,
    /// Limits installed by `run_fusion_pipeline`. A value equal to
    /// `FusionLimits::default()` is replaced by the per-target limits in
    /// `fusion_passes` and `run_fusion_pipeline`.
    pub fusion_limits: FusionLimits,
}
78
79impl Default for FusionOptions {
80 fn default() -> Self {
81 Self {
82 skip_fusion: false,
83 unfuse_elementwise_regions: false,
84 keep_elementwise_regions: false,
85 decompose_fusion_regions: false,
86 fk_fusion: true,
87 fuse_region_prologue: true,
88 fuse_batch_preprocess: true,
89 native_fk_regions: false,
90 fusion_limits: FusionLimits::default(),
91 }
92 }
93}
94
95impl FusionOptions {
96 pub fn from_metal_env() -> Self {
98 Self {
99 skip_fusion: rlx_ir::env::flag("RLX_METAL_NO_FUSION"),
100 unfuse_elementwise_regions: rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS"),
101 keep_elementwise_regions: rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS"),
102 decompose_fusion_regions: rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS"),
103 fk_fusion: !rlx_ir::env::flag("RLX_NO_FK_FUSION"),
104 fuse_region_prologue: if rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
105 true
106 } else {
107 rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE")
108 },
109 fuse_batch_preprocess: if rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
110 true
111 } else {
112 rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS")
113 },
114 native_fk_regions: rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS"),
115 ..Self::default()
116 }
117 }
118
119 pub fn merge_env(mut self) -> Self {
121 if rlx_ir::env::flag("RLX_METAL_NO_FUSION") {
122 self.skip_fusion = true;
123 }
124 if rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS") {
125 self.unfuse_elementwise_regions = true;
126 }
127 if rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS") {
128 self.keep_elementwise_regions = true;
129 }
130 if rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS") {
131 self.decompose_fusion_regions = true;
132 }
133 if rlx_ir::env::flag("RLX_NO_FK_FUSION") {
134 self.fk_fusion = false;
135 }
136 if !rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
137 self.fuse_region_prologue = rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE");
138 }
139 if !rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
140 self.fuse_batch_preprocess = rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS");
141 }
142 if rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
143 self.native_fk_regions = true;
144 }
145 if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
146 self.native_fk_regions = false;
147 }
148 self
149 }
150
151 pub fn apply_native_fk_defaults(mut self, target: FusionTarget) -> Self {
153 if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
154 self.native_fk_regions = false;
155 return self;
156 }
157 if self.native_fk_regions || rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
158 self.native_fk_regions = true;
159 return self;
160 }
161 if matches!(
162 target,
163 FusionTarget::Metal
164 | FusionTarget::Cuda
165 | FusionTarget::Rocm
166 | FusionTarget::Wgpu
167 | FusionTarget::Mlx
168 | FusionTarget::Tpu
169 ) {
170 self.native_fk_regions = true;
171 }
172 self
173 }
174
175 pub fn for_cpu() -> Self {
177 Self {
178 unfuse_elementwise_regions: true,
179 fusion_limits: FusionLimits::UNBOUNDED,
180 ..Self::default()
181 }
182 }
183
184 pub fn for_metal() -> Self {
187 let mut opts = Self::from_metal_env();
188 opts.unfuse_elementwise_regions = true;
189 opts
190 }
191
192 pub fn for_wgpu() -> Self {
195 let keep = rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS");
196 Self {
197 unfuse_elementwise_regions: !keep,
198 keep_elementwise_regions: keep,
199 ..Self::default()
200 }
201 }
202}
203
204pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
206 match target {
207 FusionTarget::Cpu => FusionLimits::UNBOUNDED,
208 FusionTarget::Tpu => FusionLimits {
209 max_elementwise_steps: 32,
210 max_elementwise_inputs: 16,
211 },
212 _ => FusionLimits::GPU_NATIVE,
213 }
214}
215
216#[inline]
218pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
219 supported.is_empty() || supported.contains(&kind)
220}
221
222pub fn fusion_passes_for_supported(
231 supported: &[OpKind],
232 opts: FusionOptions,
233 target: FusionTarget,
234) -> Vec<&'static dyn Pass> {
235 let opts = opts.apply_native_fk_defaults(target);
236 if opts.skip_fusion {
237 return vec![&LowerControlFlow, &LowerDotGeneral];
238 }
239
240 let mut passes: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
241
242 if supports_op(supported, OpKind::FusedMatMulBiasAct) {
248 passes.push(&FuseMatMulBiasAct);
249 }
250 if supports_op(supported, OpKind::FusedAttentionBlock) {
258 passes.push(&FuseAttentionBlock);
259 }
260 if supports_op(supported, OpKind::FusedResidualLN) {
264 passes.push(&FuseResidualLN);
265 }
266 if supports_op(supported, OpKind::FusedResidualRmsNorm) {
267 passes.push(&FuseResidualRmsNorm);
268 }
269 passes.push(&FuseRmsNormReshape);
270
271 if rlx_ir::env::flag("RLX_ENABLE_FUSE_TRANSFORMER_LAYER")
280 && supports_op(supported, OpKind::FusedTransformerLayer)
281 && supports_op(supported, OpKind::FusedAttentionBlock)
282 {
283 passes.push(&FuseTransformerLayer);
284 }
285
286 if supports_op(supported, OpKind::FusedSwiGLU) {
287 passes.push(&FuseSwiGLUDualMatmul);
288 }
289 if supports_op(supported, OpKind::MatMul) {
290 passes.push(&FuseSharedInputMatMul);
291 }
292 if supports_op(supported, OpKind::FusedSwiGLU) {
293 passes.push(&FuseSwiGLU);
294 }
295
296 let keep_regions =
300 supports_op(supported, OpKind::ElementwiseRegion) && !opts.unfuse_elementwise_regions;
301 if keep_regions {
302 passes.push(&MarkElementwiseRegions);
303 }
304 if opts.fk_fusion {
305 passes.push(&MarkBatchSliceRegions);
306 passes.push(&MarkTransformRegions);
307 if opts.fuse_region_prologue {
308 passes.push(&FuseRegionPrologue);
309 }
310 if opts.fuse_batch_preprocess {
311 passes.push(&FuseBatchPreprocess);
312 }
313 }
314 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
315 && supports_op(supported, OpKind::BatchElementwiseRegion);
316 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
317 if opts.decompose_fusion_regions || !keep_native_fk {
318 passes.push(&DecomposeFusionRegions);
319 }
320 if !keep_regions {
321 let unfuse = if matches!(target, FusionTarget::Cpu) {
322 &UnfuseElementwiseRegions::FOR_CPU
323 } else {
324 &UnfuseElementwiseRegions::FOR_GPU
325 };
326 passes.push(unfuse);
327 }
328
329 if supports_op(supported, OpKind::Fft) && supports_op(supported, OpKind::WelchPeaks) {
330 passes.push(&SelectPeaksOnlyOutputs);
331 }
332
333 finish_pipeline(passes)
334}
335
336pub fn fk_passes_after_elementwise_regions(
338 supported: &[OpKind],
339 opts: FusionOptions,
340) -> Vec<&'static dyn Pass> {
341 let mut passes: Vec<&'static dyn Pass> = Vec::new();
342 if !opts.fk_fusion {
343 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
344 && supports_op(supported, OpKind::BatchElementwiseRegion);
345 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
346 if opts.decompose_fusion_regions || !keep_native_fk {
347 passes.push(&DecomposeFusionRegions);
348 }
349 return finish_pipeline(passes);
350 }
351 passes.push(&MarkBatchSliceRegions);
352 passes.push(&MarkTransformRegions);
353 if opts.fuse_region_prologue {
354 passes.push(&FuseRegionPrologue);
355 }
356 if opts.fuse_batch_preprocess {
357 passes.push(&FuseBatchPreprocess);
358 }
359 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
360 && supports_op(supported, OpKind::BatchElementwiseRegion);
361 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
362 if opts.decompose_fusion_regions || !keep_native_fk {
363 passes.push(&DecomposeFusionRegions);
364 }
365 finish_pipeline(passes)
366}
367
368pub fn should_fuse_with_target(
370 target: FusionTarget,
371 before: &crate::fusion_benefit::GraphIoProfile,
372 after: &crate::fusion_benefit::GraphIoProfile,
373) -> bool {
374 io_fusion_gate_for_target(target).should_fuse(before, after)
375}
376
377pub fn io_fusion_gate_for_target(target: FusionTarget) -> crate::fusion_benefit::IoFusionGate {
379 use crate::fusion_benefit::IoFusionGate;
380 match target {
381 FusionTarget::Metal | FusionTarget::Mlx => IoFusionGate {
382 dispatch_ns: 500.0,
383 roundtrip_ns: 5_000.0,
384 memory_bw: 200.0,
385 host_readback_bw: 200.0,
386 unified_memory: true,
387 host_thunk_penalty_ns: 2_000_000.0,
388 min_gain_ns: 1_000.0,
389 },
390 FusionTarget::Cuda | FusionTarget::Rocm => IoFusionGate {
391 dispatch_ns: 2_000.0,
392 roundtrip_ns: 20_000.0,
393 memory_bw: 800.0,
394 host_readback_bw: 50.0,
395 unified_memory: false,
396 host_thunk_penalty_ns: 15_000_000.0,
397 min_gain_ns: 5_000.0,
398 },
399 FusionTarget::Wgpu | FusionTarget::Tpu => IoFusionGate {
400 dispatch_ns: 3_000.0,
401 roundtrip_ns: 30_000.0,
402 memory_bw: 100.0,
403 host_readback_bw: 40.0,
404 unified_memory: false,
405 host_thunk_penalty_ns: 25_000_000.0,
406 min_gain_ns: 10_000.0,
407 },
408 FusionTarget::Cpu => IoFusionGate {
409 dispatch_ns: 50.0,
410 roundtrip_ns: 0.0,
411 memory_bw: 50.0,
412 host_readback_bw: 50.0,
413 unified_memory: true,
414 host_thunk_penalty_ns: 0.0,
415 min_gain_ns: 0.0,
416 },
417 }
418}
419
420pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
422 let mut opts = opts;
423 if !opts.keep_elementwise_regions
427 && matches!(target, FusionTarget::Cpu | FusionTarget::Metal)
428 && !opts.unfuse_elementwise_regions
429 {
430 opts.unfuse_elementwise_regions = true;
431 }
432 if opts.fusion_limits == FusionLimits::default() {
433 opts.fusion_limits = fusion_limits_for_target(target);
434 }
435 opts = opts.apply_native_fk_defaults(target);
436 fusion_passes_for_supported(supported_for_target(target), opts, target)
437}
438
439pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
443 use OpKind::*;
444 match target {
445 FusionTarget::Cpu => &[
446 MatMul,
447 DotGeneral,
448 ElementwiseRegion,
449 FusedSwiGLU,
450 FusedMatMulBiasAct,
451 FusedResidualLN,
452 FusedResidualRmsNorm,
453 FusedAttentionBlock,
454 ],
455 FusionTarget::Metal => &[
456 MatMul,
457 DotGeneral,
458 ElementwiseRegion,
459 TransformRegion,
460 BatchElementwiseRegion,
461 FusedSwiGLU,
462 FusedMatMulBiasAct,
463 FusedResidualLN,
464 FusedResidualRmsNorm,
465 ],
466 FusionTarget::Mlx => &[
467 MatMul,
468 DotGeneral,
469 ElementwiseRegion,
470 TransformRegion,
471 BatchElementwiseRegion,
472 FusedSwiGLU,
473 FusedMatMulBiasAct,
474 FusedResidualLN,
475 FusedResidualRmsNorm,
476 ],
477 FusionTarget::Wgpu => &[
478 MatMul,
479 ElementwiseRegion,
480 TransformRegion,
481 BatchElementwiseRegion,
482 FusedSwiGLU,
483 FusedMatMulBiasAct,
484 FusedResidualLN,
485 FusedResidualRmsNorm,
486 FusedAttentionBlock,
487 FusedTransformerLayer,
488 ],
489 FusionTarget::Cuda | FusionTarget::Rocm => &[
490 MatMul,
491 DotGeneral,
492 ElementwiseRegion,
493 TransformRegion,
494 BatchElementwiseRegion,
495 FusedMatMulBiasAct,
496 FusedResidualLN,
497 FusedResidualRmsNorm,
498 ],
499 FusionTarget::Tpu => &[
500 MatMul,
501 ElementwiseRegion,
502 TransformRegion,
503 BatchElementwiseRegion,
504 FusedMatMulBiasAct,
505 FusedResidualLN,
506 ],
507 }
508}
509
510fn finish_pipeline(mut passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
511 passes.push(&DeadCodeElimination);
512 passes
513}
514
515pub fn run_fusion_pipeline(
517 graph: Graph,
518 target: FusionTarget,
519 supported: &[OpKind],
520 opts: FusionOptions,
521) -> Graph {
522 let mut opts = opts.apply_native_fk_defaults(target);
523 if opts.fusion_limits == FusionLimits::default() {
524 opts.fusion_limits = fusion_limits_for_target(target);
525 }
526 let limits = opts.fusion_limits;
527 let passes = fusion_passes_for_supported(supported, opts, target);
528 with_fusion_target(target, || {
529 with_fusion_limits(limits, || run_passes(graph, &passes, false))
530 })
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test whose outcome depends on the native-FK
    /// environment overrides. One test sets `RLX_NO_NATIVE_FK_REGIONS`;
    /// without this lock a test running in parallel could observe the
    /// override and see an unexpected `decompose_fusion_regions` pass.
    static ENV_FK_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Takes `ENV_FK_TEST_LOCK`, ignoring poisoning so that one failing test
    /// does not cascade into every other test that needs the lock.
    fn lock_fk_env() -> MutexGuard<'static, ()> {
        ENV_FK_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets an environment override and removes it again when dropped, so the
    /// override cannot leak into later tests even if an assertion panics.
    struct EnvOverride(&'static str);

    impl EnvOverride {
        fn set(name: &'static str, value: &'static str) -> Self {
            rlx_ir::env::set(name, value);
            Self(name)
        }
    }

    impl Drop for EnvOverride {
        fn drop(&mut self) {
            rlx_ir::env::unset(self.0);
        }
    }

    /// True when a pass called `name` is present in `passes`.
    fn has_pass(passes: &[&'static dyn Pass], name: &str) -> bool {
        passes.iter().any(|p| p.name() == name)
    }

    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let passes = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(
            passes.len(),
            17,
            "CPU default supported_ops omit Fft/WelchPeaks and mark_elementwise (unfuse-only backends skip mark)"
        );
        assert_eq!(passes[2].name(), "fuse_matmul_bias_act");
        assert_eq!(passes[3].name(), "fuse_attention_block");
        assert!(
            has_pass(&passes, "fuse_region_prologue"),
            "default CPU pipeline should run FKL prologue fusion"
        );
        assert!(
            !has_pass(&passes, "mark_elementwise_regions"),
            "CPU unfuse backends should not mark elementwise regions before unfusing"
        );
        assert_eq!(passes.last().unwrap().name(), "dead_code_elimination");
    }

    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                skip_fusion: true,
                ..FusionOptions::default()
            },
        );
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "LowerControlFlow");
        assert_eq!(passes[1].name(), "lower_dot_general");
    }

    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions::default(),
            FusionTarget::Metal,
        );
        assert!(
            !has_pass(&passes, "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            has_pass(&passes, "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cuda),
            FusionOptions::default(),
            FusionTarget::Cuda,
        );
        assert!(
            has_pass(&passes, "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !has_pass(&passes, "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cpu),
            FusionOptions::for_cpu(),
            FusionTarget::Cpu,
        );
        assert!(has_pass(&passes, "unfuse_elementwise_regions"));
    }

    #[test]
    fn metal_unfuses_elementwise_regions_by_default() {
        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        assert!(has_pass(&passes, "unfuse_elementwise_regions"));
    }

    #[test]
    fn metal_default_unfuse_preserves_prologue_regions() {
        // The pass list (decompose or not) depends on the native-FK overrides.
        let _lock = lock_fk_env();

        let mut g = rlx_ir::Graph::new("t");
        let shape_in = rlx_ir::Shape::new(&[1, 3, 8, 8], rlx_ir::DType::F32);
        let shape_out = rlx_ir::Shape::new(&[1, 3, 16, 16], rlx_ir::DType::F32);
        let x = g.input("x", shape_in);
        let up = g.add_node(rlx_ir::Op::ResizeNearest2x, vec![x], shape_out.clone());
        let r = g.add_node(
            rlx_ir::Op::Activation(rlx_ir::op::Activation::Relu),
            vec![up],
            shape_out,
        );
        g.set_outputs(vec![r]);

        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        let out = rlx_fusion::pass::run_passes(g, &passes, false);
        assert!(out.nodes().iter().any(|n| {
            matches!(
                n.op,
                rlx_ir::Op::ElementwiseRegion {
                    prologue: rlx_ir::RegionPrologue::ResizeNearest2x,
                    ..
                }
            )
        }));
    }

    #[test]
    fn fk_passes_after_elementwise_includes_batch_fusion() {
        let _lock = lock_fk_env();
        let opts = FusionOptions::default().apply_native_fk_defaults(FusionTarget::Tpu);
        let passes =
            fk_passes_after_elementwise_regions(supported_for_target(FusionTarget::Tpu), opts);
        let names: Vec<_> = passes.iter().map(|p| p.name()).collect();
        assert!(names.contains(&"mark_batch_slice_regions"));
        assert!(names.contains(&"fuse_batch_preprocess"));
        assert!(
            !names.contains(&"decompose_fusion_regions"),
            "TPU native FK defaults should keep batch/transform regions"
        );
    }

    #[test]
    fn tpu_native_fk_region_pass_policy() {
        let _lock = lock_fk_env();
        let default_passes = fusion_passes(FusionTarget::Tpu, FusionOptions::default());
        assert!(
            !has_pass(&default_passes, "decompose_fusion_regions"),
            "default TPU pipeline keeps batch/transform regions via native_fk_defaults"
        );

        let opt_out = {
            // Dropped at the end of this scope, before the assertion runs.
            let _override = EnvOverride::set("RLX_NO_NATIVE_FK_REGIONS", "1");
            fusion_passes(FusionTarget::Tpu, FusionOptions::default())
        };
        assert!(
            has_pass(&opt_out, "decompose_fusion_regions"),
            "RLX_NO_NATIVE_FK_REGIONS should force decompose on TPU"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_tpu() {
        let _lock = lock_fk_env();
        let passes = fusion_passes(
            FusionTarget::Tpu,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "decompose_fusion_regions"),
            "native_fk_regions should skip decompose on TPU when batch/transform are supported"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_metal() {
        let _lock = lock_fk_env();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "decompose_fusion_regions"),
            "native_fk_regions should skip decompose when backend claims batch/transform ops"
        );
    }

    #[test]
    fn metal_keeps_elementwise_regions_when_requested() {
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                keep_elementwise_regions: true,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "unfuse_elementwise_regions"),
            "keep_elementwise_regions should skip unfuse pass"
        );
        assert!(
            has_pass(&passes, "fuse_region_prologue"),
            "FKL prologue fusion should still run"
        );
    }

    #[test]
    fn metal_audio_ops_pipeline_includes_peaks_output_gate() {
        let mut supported = supported_for_target(FusionTarget::Metal).to_vec();
        supported.push(OpKind::Fft);
        supported.push(OpKind::WelchPeaks);
        let passes =
            fusion_passes_for_supported(&supported, FusionOptions::default(), FusionTarget::Metal);
        assert!(
            has_pass(&passes, "select_peaks_only_outputs"),
            "Metal + Fft/WelchPeaks should run IO peaks-only output gate"
        );
    }

    #[test]
    fn should_fuse_with_target_matches_gate() {
        use crate::fusion_benefit::GraphIoProfile;
        let dense = GraphIoProfile {
            kernel_launches: 3,
            sync_points: 0,
            host_output_bytes: 33_554_432,
            device_traffic_bytes: 184_549_376,
        };
        let fused = GraphIoProfile {
            kernel_launches: 4,
            sync_points: 1,
            host_output_bytes: 1_048_576,
            device_traffic_bytes: 219_152_384,
        };
        assert!(should_fuse_with_target(FusionTarget::Metal, &dense, &fused));
        assert!(!should_fuse_with_target(FusionTarget::Wgpu, &dense, &fused));
    }
}