1use rlx_ir::OpKind;
24
25use crate::DeadCodeElimination;
26use crate::io_output_gate::SelectPeaksOnlyOutputs;
27use rlx_fusion::control_flow::LowerControlFlow;
28use rlx_fusion::fk_fusion::{
29 DecomposeFusionRegions, FuseBatchPreprocess, FuseRegionPrologue, MarkBatchSliceRegions,
30 MarkTransformRegions,
31};
32use rlx_fusion::fusion::{
33 FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
34 FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, FuseTransformerLayer,
35 MarkElementwiseRegions, UnfuseElementwiseRegions,
36};
37use rlx_fusion::limits::{FusionLimits, with_fusion_limits};
38use rlx_fusion::lower_dot_general::LowerDotGeneral;
39use rlx_fusion::pass::{Pass, run_passes};
40use rlx_ir::Graph;
41
42use crate::fusion_target::with_fusion_target;
43
/// Backend a fusion pipeline is assembled for.
///
/// The variant selects the supported-op list ([`supported_for_target`]), the
/// fusion limits ([`fusion_limits_for_target`]), the I/O fusion gate
/// ([`io_fusion_gate_for_target`]) and the native-FK default applied by
/// [`FusionOptions::apply_native_fk_defaults`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionTarget {
    /// The only target that does not default to native FK regions and the
    /// only one given unbounded fusion limits.
    Cpu,
    /// Shares its I/O gate and supported-op list with [`FusionTarget::Mlx`].
    Metal,
    /// Shares its I/O gate and supported-op list with [`FusionTarget::Metal`].
    Mlx,
    /// Shares its I/O gate with [`FusionTarget::Tpu`].
    Wgpu,
    /// Shares its I/O gate and supported-op list with [`FusionTarget::Rocm`].
    Cuda,
    /// Shares its I/O gate and supported-op list with [`FusionTarget::Cuda`].
    Rocm,
    /// Uses dedicated fusion limits; shares its I/O gate with [`FusionTarget::Wgpu`].
    Tpu,
}
55
/// Switches that decide which passes [`fusion_passes_for_supported`] schedules.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusionOptions {
    /// When set, only the two lowering passes (`LowerControlFlow`,
    /// `LowerDotGeneral`) are scheduled.
    pub skip_fusion: bool,
    /// When set, `UnfuseElementwiseRegions` is scheduled even if the backend
    /// claims `ElementwiseRegion`.
    pub unfuse_elementwise_regions: bool,
    /// When set, [`fusion_passes`] does not force
    /// `unfuse_elementwise_regions` on for the CPU and Metal targets.
    pub keep_elementwise_regions: bool,
    /// When set, `DecomposeFusionRegions` is always scheduled.
    pub decompose_fusion_regions: bool,
    /// Enables `MarkBatchSliceRegions` / `MarkTransformRegions` and the two
    /// optional passes gated by the next two fields.
    pub fk_fusion: bool,
    /// Schedules `FuseRegionPrologue` (only consulted when `fk_fusion` is set).
    pub fuse_region_prologue: bool,
    /// Schedules `FuseBatchPreprocess` (only consulted when `fk_fusion` is set).
    pub fuse_batch_preprocess: bool,
    /// When set and the backend claims both `TransformRegion` and
    /// `BatchElementwiseRegion`, `DecomposeFusionRegions` is skipped unless
    /// `decompose_fusion_regions` forces it.
    pub native_fk_regions: bool,
    /// Limits installed by [`run_fusion_pipeline`] while the passes run. A
    /// value equal to `FusionLimits::default()` is replaced by the target's
    /// limits there.
    pub fusion_limits: FusionLimits,
}
78
79impl Default for FusionOptions {
80 fn default() -> Self {
81 Self {
82 skip_fusion: false,
83 unfuse_elementwise_regions: false,
84 keep_elementwise_regions: false,
85 decompose_fusion_regions: false,
86 fk_fusion: true,
87 fuse_region_prologue: true,
88 fuse_batch_preprocess: true,
89 native_fk_regions: false,
90 fusion_limits: FusionLimits::default(),
91 }
92 }
93}
94
95impl FusionOptions {
96 pub fn from_metal_env() -> Self {
98 Self {
99 skip_fusion: rlx_ir::env::flag("RLX_METAL_NO_FUSION"),
100 unfuse_elementwise_regions: rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS"),
101 keep_elementwise_regions: rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS"),
102 decompose_fusion_regions: rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS"),
103 fk_fusion: !rlx_ir::env::flag("RLX_NO_FK_FUSION"),
104 fuse_region_prologue: if rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
105 true
106 } else {
107 rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE")
108 },
109 fuse_batch_preprocess: if rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
110 true
111 } else {
112 rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS")
113 },
114 native_fk_regions: rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS"),
115 ..Self::default()
116 }
117 }
118
119 pub fn merge_env(mut self) -> Self {
121 if rlx_ir::env::flag("RLX_METAL_NO_FUSION") {
122 self.skip_fusion = true;
123 }
124 if rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS") {
125 self.unfuse_elementwise_regions = true;
126 }
127 if rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS") {
128 self.keep_elementwise_regions = true;
129 }
130 if rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS") {
131 self.decompose_fusion_regions = true;
132 }
133 if rlx_ir::env::flag("RLX_NO_FK_FUSION") {
134 self.fk_fusion = false;
135 }
136 if !rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
137 self.fuse_region_prologue = rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE");
138 }
139 if !rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
140 self.fuse_batch_preprocess = rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS");
141 }
142 if rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
143 self.native_fk_regions = true;
144 }
145 if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
146 self.native_fk_regions = false;
147 }
148 self
149 }
150
151 pub fn apply_native_fk_defaults(mut self, target: FusionTarget) -> Self {
153 if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
154 self.native_fk_regions = false;
155 return self;
156 }
157 if self.native_fk_regions || rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
158 self.native_fk_regions = true;
159 return self;
160 }
161 if matches!(
162 target,
163 FusionTarget::Metal
164 | FusionTarget::Cuda
165 | FusionTarget::Rocm
166 | FusionTarget::Wgpu
167 | FusionTarget::Mlx
168 | FusionTarget::Tpu
169 ) {
170 self.native_fk_regions = true;
171 }
172 self
173 }
174
175 pub fn for_cpu() -> Self {
177 Self {
178 unfuse_elementwise_regions: true,
179 fusion_limits: FusionLimits::UNBOUNDED,
180 ..Self::default()
181 }
182 }
183
184 pub fn for_metal() -> Self {
187 let mut opts = Self::from_metal_env();
188 opts.unfuse_elementwise_regions = true;
189 opts
190 }
191
192 pub fn for_wgpu() -> Self {
195 let keep = rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS");
196 Self {
197 unfuse_elementwise_regions: !keep,
198 keep_elementwise_regions: keep,
199 ..Self::default()
200 }
201 }
202}
203
204pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
206 match target {
207 FusionTarget::Cpu => FusionLimits::UNBOUNDED,
208 FusionTarget::Tpu => FusionLimits {
209 max_elementwise_steps: 32,
210 max_elementwise_inputs: 16,
211 },
212 _ => FusionLimits::GPU_NATIVE,
213 }
214}
215
216#[inline]
218pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
219 supported.is_empty() || supported.contains(&kind)
220}
221
222pub fn fusion_passes_for_supported(
231 supported: &[OpKind],
232 opts: FusionOptions,
233 target: FusionTarget,
234) -> Vec<&'static dyn Pass> {
235 let opts = opts.apply_native_fk_defaults(target);
236 if opts.skip_fusion {
237 return vec![&LowerControlFlow, &LowerDotGeneral];
238 }
239
240 let mut passes: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
241
242 if supports_op(supported, OpKind::FusedMatMulBiasAct) {
248 passes.push(&FuseMatMulBiasAct);
249 }
250 if supports_op(supported, OpKind::FusedAttentionBlock) {
258 passes.push(&FuseAttentionBlock);
259 }
260 if supports_op(supported, OpKind::FusedResidualLN) {
264 passes.push(&FuseResidualLN);
265 }
266 if supports_op(supported, OpKind::FusedResidualRmsNorm) {
267 passes.push(&FuseResidualRmsNorm);
268 }
269 passes.push(&FuseRmsNormReshape);
270
271 if rlx_ir::env::flag("RLX_ENABLE_FUSE_TRANSFORMER_LAYER")
280 && supports_op(supported, OpKind::FusedTransformerLayer)
281 && supports_op(supported, OpKind::FusedAttentionBlock)
282 {
283 passes.push(&FuseTransformerLayer);
284 }
285
286 if supports_op(supported, OpKind::FusedSwiGLU) {
287 passes.push(&FuseSwiGLUDualMatmul);
288 }
289 if supports_op(supported, OpKind::MatMul) {
290 passes.push(&FuseSharedInputMatMul);
291 }
292 if supports_op(supported, OpKind::FusedSwiGLU) {
293 passes.push(&FuseSwiGLU);
294 }
295
296 passes.push(&MarkElementwiseRegions);
299 if opts.fk_fusion {
300 passes.push(&MarkBatchSliceRegions);
301 passes.push(&MarkTransformRegions);
302 if opts.fuse_region_prologue {
303 passes.push(&FuseRegionPrologue);
304 }
305 if opts.fuse_batch_preprocess {
306 passes.push(&FuseBatchPreprocess);
307 }
308 }
309 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
310 && supports_op(supported, OpKind::BatchElementwiseRegion);
311 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
312 if opts.decompose_fusion_regions || !keep_native_fk {
313 passes.push(&DecomposeFusionRegions);
314 }
315 let keep_regions =
316 supports_op(supported, OpKind::ElementwiseRegion) && !opts.unfuse_elementwise_regions;
317 if !keep_regions {
318 let unfuse = if matches!(target, FusionTarget::Cpu) {
319 &UnfuseElementwiseRegions::FOR_CPU
320 } else {
321 &UnfuseElementwiseRegions::FOR_GPU
322 };
323 passes.push(unfuse);
324 }
325
326 if supports_op(supported, OpKind::Fft) && supports_op(supported, OpKind::WelchPeaks) {
327 passes.push(&SelectPeaksOnlyOutputs);
328 }
329
330 finish_pipeline(passes)
331}
332
333pub fn fk_passes_after_elementwise_regions(
335 supported: &[OpKind],
336 opts: FusionOptions,
337) -> Vec<&'static dyn Pass> {
338 let mut passes: Vec<&'static dyn Pass> = Vec::new();
339 if !opts.fk_fusion {
340 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
341 && supports_op(supported, OpKind::BatchElementwiseRegion);
342 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
343 if opts.decompose_fusion_regions || !keep_native_fk {
344 passes.push(&DecomposeFusionRegions);
345 }
346 return finish_pipeline(passes);
347 }
348 passes.push(&MarkBatchSliceRegions);
349 passes.push(&MarkTransformRegions);
350 if opts.fuse_region_prologue {
351 passes.push(&FuseRegionPrologue);
352 }
353 if opts.fuse_batch_preprocess {
354 passes.push(&FuseBatchPreprocess);
355 }
356 let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
357 && supports_op(supported, OpKind::BatchElementwiseRegion);
358 let keep_native_fk = opts.native_fk_regions && backend_native_fk;
359 if opts.decompose_fusion_regions || !keep_native_fk {
360 passes.push(&DecomposeFusionRegions);
361 }
362 finish_pipeline(passes)
363}
364
365pub fn should_fuse_with_target(
367 target: FusionTarget,
368 before: &crate::fusion_benefit::GraphIoProfile,
369 after: &crate::fusion_benefit::GraphIoProfile,
370) -> bool {
371 io_fusion_gate_for_target(target).should_fuse(before, after)
372}
373
374pub fn io_fusion_gate_for_target(target: FusionTarget) -> crate::fusion_benefit::IoFusionGate {
376 use crate::fusion_benefit::IoFusionGate;
377 match target {
378 FusionTarget::Metal | FusionTarget::Mlx => IoFusionGate {
379 dispatch_ns: 500.0,
380 roundtrip_ns: 5_000.0,
381 memory_bw: 200.0,
382 host_readback_bw: 200.0,
383 unified_memory: true,
384 host_thunk_penalty_ns: 2_000_000.0,
385 min_gain_ns: 1_000.0,
386 },
387 FusionTarget::Cuda | FusionTarget::Rocm => IoFusionGate {
388 dispatch_ns: 2_000.0,
389 roundtrip_ns: 20_000.0,
390 memory_bw: 800.0,
391 host_readback_bw: 50.0,
392 unified_memory: false,
393 host_thunk_penalty_ns: 15_000_000.0,
394 min_gain_ns: 5_000.0,
395 },
396 FusionTarget::Wgpu | FusionTarget::Tpu => IoFusionGate {
397 dispatch_ns: 3_000.0,
398 roundtrip_ns: 30_000.0,
399 memory_bw: 100.0,
400 host_readback_bw: 40.0,
401 unified_memory: false,
402 host_thunk_penalty_ns: 25_000_000.0,
403 min_gain_ns: 10_000.0,
404 },
405 FusionTarget::Cpu => IoFusionGate {
406 dispatch_ns: 50.0,
407 roundtrip_ns: 0.0,
408 memory_bw: 50.0,
409 host_readback_bw: 50.0,
410 unified_memory: true,
411 host_thunk_penalty_ns: 0.0,
412 min_gain_ns: 0.0,
413 },
414 }
415}
416
417pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
419 let mut opts = opts;
420 if !opts.keep_elementwise_regions
424 && matches!(target, FusionTarget::Cpu | FusionTarget::Metal)
425 && !opts.unfuse_elementwise_regions
426 {
427 opts.unfuse_elementwise_regions = true;
428 }
429 if opts.fusion_limits == FusionLimits::default() {
430 opts.fusion_limits = fusion_limits_for_target(target);
431 }
432 opts = opts.apply_native_fk_defaults(target);
433 fusion_passes_for_supported(supported_for_target(target), opts, target)
434}
435
436pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
440 use OpKind::*;
441 match target {
442 FusionTarget::Cpu => &[
443 MatMul,
444 DotGeneral,
445 ElementwiseRegion,
446 FusedSwiGLU,
447 FusedMatMulBiasAct,
448 FusedResidualLN,
449 FusedResidualRmsNorm,
450 FusedAttentionBlock,
451 ],
452 FusionTarget::Metal => &[
453 MatMul,
454 DotGeneral,
455 ElementwiseRegion,
456 TransformRegion,
457 BatchElementwiseRegion,
458 FusedSwiGLU,
459 FusedMatMulBiasAct,
460 FusedResidualLN,
461 FusedResidualRmsNorm,
462 ],
463 FusionTarget::Mlx => &[
464 MatMul,
465 DotGeneral,
466 ElementwiseRegion,
467 TransformRegion,
468 BatchElementwiseRegion,
469 FusedSwiGLU,
470 FusedMatMulBiasAct,
471 FusedResidualLN,
472 FusedResidualRmsNorm,
473 ],
474 FusionTarget::Wgpu => &[
475 MatMul,
476 ElementwiseRegion,
477 TransformRegion,
478 BatchElementwiseRegion,
479 FusedSwiGLU,
480 FusedMatMulBiasAct,
481 FusedResidualLN,
482 FusedResidualRmsNorm,
483 FusedAttentionBlock,
484 FusedTransformerLayer,
485 ],
486 FusionTarget::Cuda | FusionTarget::Rocm => &[
487 MatMul,
488 DotGeneral,
489 ElementwiseRegion,
490 TransformRegion,
491 BatchElementwiseRegion,
492 FusedMatMulBiasAct,
493 FusedResidualLN,
494 FusedResidualRmsNorm,
495 ],
496 FusionTarget::Tpu => &[
497 MatMul,
498 ElementwiseRegion,
499 TransformRegion,
500 BatchElementwiseRegion,
501 FusedMatMulBiasAct,
502 FusedResidualLN,
503 ],
504 }
505}
506
507fn finish_pipeline(mut passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
508 passes.push(&DeadCodeElimination);
509 passes
510}
511
512pub fn run_fusion_pipeline(
514 graph: Graph,
515 target: FusionTarget,
516 supported: &[OpKind],
517 opts: FusionOptions,
518) -> Graph {
519 let mut opts = opts.apply_native_fk_defaults(target);
520 if opts.fusion_limits == FusionLimits::default() {
521 opts.fusion_limits = fusion_limits_for_target(target);
522 }
523 let limits = opts.fusion_limits;
524 let passes = fusion_passes_for_supported(supported, opts, target);
525 with_fusion_target(target, || {
526 with_fusion_limits(limits, || run_passes(graph, &passes, false))
527 })
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serialises every test whose assertions depend on the native-FK
    /// environment switches, because tests run in parallel by default.
    // NOTE(review): assumes `rlx_ir::env::set` mutates state shared by all
    // test threads - confirm against `rlx_ir::env`.
    static ENV_FK_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Takes the environment lock, recovering from poisoning so one failed
    /// test does not cascade into unrelated failures.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_FK_TEST_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets an environment switch and unsets it again on drop, including
    /// when the code under test panics.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(key: &'static str, value: &'static str) -> Self {
            rlx_ir::env::set(key, value);
            Self(key)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            rlx_ir::env::unset(self.0);
        }
    }

    /// Returns true when `passes` contains a pass called `name`.
    fn has_pass(passes: &[&'static dyn Pass], name: &str) -> bool {
        passes.iter().any(|p| p.name() == name)
    }

    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let passes = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(
            passes.len(),
            18,
            "CPU default supported_ops omit Fft/WelchPeaks"
        );
        assert_eq!(passes[2].name(), "fuse_matmul_bias_act");
        assert_eq!(passes[3].name(), "fuse_attention_block");
        assert!(
            has_pass(&passes, "fuse_region_prologue"),
            "default CPU pipeline should run FKL prologue fusion"
        );
        assert_eq!(passes.last().unwrap().name(), "dead_code_elimination");
    }

    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                skip_fusion: true,
                ..FusionOptions::default()
            },
        );
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "LowerControlFlow");
        assert_eq!(passes[1].name(), "lower_dot_general");
    }

    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions::default(),
            FusionTarget::Metal,
        );
        assert!(
            !has_pass(&passes, "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            has_pass(&passes, "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cuda),
            FusionOptions::default(),
            FusionTarget::Cuda,
        );
        assert!(
            has_pass(&passes, "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !has_pass(&passes, "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cpu),
            FusionOptions::for_cpu(),
            FusionTarget::Cpu,
        );
        assert!(has_pass(&passes, "unfuse_elementwise_regions"));
    }

    #[test]
    fn metal_unfuses_elementwise_regions_by_default() {
        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        assert!(has_pass(&passes, "unfuse_elementwise_regions"));
    }

    #[test]
    fn metal_default_unfuse_preserves_prologue_regions() {
        // The native-FK opt-out changes which passes run on Metal, so keep
        // this test from overlapping with the one that sets it.
        let _lock = env_lock();

        let mut g = rlx_ir::Graph::new("t");
        let shape_in = rlx_ir::Shape::new(&[1, 3, 8, 8], rlx_ir::DType::F32);
        let shape_out = rlx_ir::Shape::new(&[1, 3, 16, 16], rlx_ir::DType::F32);
        let x = g.input("x", shape_in);
        let up = g.add_node(rlx_ir::Op::ResizeNearest2x, vec![x], shape_out.clone());
        let r = g.add_node(
            rlx_ir::Op::Activation(rlx_ir::op::Activation::Relu),
            vec![up],
            shape_out,
        );
        g.set_outputs(vec![r]);

        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        let out = rlx_fusion::pass::run_passes(g, &passes, false);
        assert!(out.nodes().iter().any(|n| {
            matches!(
                n.op,
                rlx_ir::Op::ElementwiseRegion {
                    prologue: rlx_ir::RegionPrologue::ResizeNearest2x,
                    ..
                }
            )
        }));
    }

    #[test]
    fn fk_passes_after_elementwise_includes_batch_fusion() {
        // `apply_native_fk_defaults` reads RLX_NO_NATIVE_FK_REGIONS.
        let _lock = env_lock();

        let opts = FusionOptions::default().apply_native_fk_defaults(FusionTarget::Tpu);
        let passes =
            fk_passes_after_elementwise_regions(supported_for_target(FusionTarget::Tpu), opts);
        let names: Vec<_> = passes.iter().map(|p| p.name()).collect();
        assert!(names.contains(&"mark_batch_slice_regions"));
        assert!(names.contains(&"fuse_batch_preprocess"));
        assert!(
            !names.contains(&"decompose_fusion_regions"),
            "TPU native FK defaults should keep batch/transform regions"
        );
    }

    #[test]
    fn tpu_native_fk_region_pass_policy() {
        let _lock = env_lock();

        let default_passes = fusion_passes(FusionTarget::Tpu, FusionOptions::default());
        assert!(
            !has_pass(&default_passes, "decompose_fusion_regions"),
            "default TPU pipeline keeps batch/transform regions via native_fk_defaults"
        );

        let opt_out = {
            // Dropped at the end of this scope, even if `fusion_passes` panics.
            let _env = EnvVarGuard::set("RLX_NO_NATIVE_FK_REGIONS", "1");
            fusion_passes(FusionTarget::Tpu, FusionOptions::default())
        };
        assert!(
            has_pass(&opt_out, "decompose_fusion_regions"),
            "RLX_NO_NATIVE_FK_REGIONS should force decompose on TPU"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_tpu() {
        // The opt-out variable overrides `native_fk_regions: true`.
        let _lock = env_lock();

        let passes = fusion_passes(
            FusionTarget::Tpu,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "decompose_fusion_regions"),
            "native_fk_regions should skip decompose on TPU when batch/transform are supported"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_metal() {
        // The opt-out variable overrides `native_fk_regions: true`.
        let _lock = env_lock();

        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "decompose_fusion_regions"),
            "native_fk_regions should skip decompose when backend claims batch/transform ops"
        );
    }

    #[test]
    fn metal_keeps_elementwise_regions_when_requested() {
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                keep_elementwise_regions: true,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "unfuse_elementwise_regions"),
            "keep_elementwise_regions should skip unfuse pass"
        );
        assert!(
            has_pass(&passes, "fuse_region_prologue"),
            "FKL prologue fusion should still run"
        );
    }

    #[test]
    fn metal_audio_ops_pipeline_includes_peaks_output_gate() {
        let mut supported = supported_for_target(FusionTarget::Metal).to_vec();
        supported.push(OpKind::Fft);
        supported.push(OpKind::WelchPeaks);
        let passes =
            fusion_passes_for_supported(&supported, FusionOptions::default(), FusionTarget::Metal);
        assert!(
            has_pass(&passes, "select_peaks_only_outputs"),
            "Metal + Fft/WelchPeaks should run IO peaks-only output gate"
        );
    }

    #[test]
    fn should_fuse_with_target_matches_gate() {
        use crate::fusion_benefit::GraphIoProfile;
        let dense = GraphIoProfile {
            kernel_launches: 3,
            sync_points: 0,
            host_output_bytes: 33_554_432,
            device_traffic_bytes: 184_549_376,
        };
        let fused = GraphIoProfile {
            kernel_launches: 4,
            sync_points: 1,
            host_output_bytes: 1_048_576,
            device_traffic_bytes: 219_152_384,
        };
        assert!(should_fuse_with_target(FusionTarget::Metal, &dense, &fused));
        assert!(!should_fuse_with_target(FusionTarget::Wgpu, &dense, &fused));
    }
}