Skip to main content

rlx_compile/
fusion_pipeline.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Centralized fusion pass pipelines per backend target.
17//!
18//! [`fusion_passes_for_supported`] selects passes from a backend's
19//! [`rlx_ir::OpKind`] claim set so fusion never emits fused ops the
20//! target cannot lower. [`fusion_passes`] keeps the legacy
21//! [`FusionTarget`] entry point and delegates to the same selector.
22
23use rlx_ir::OpKind;
24
25use crate::DeadCodeElimination;
26use crate::io_output_gate::SelectPeaksOnlyOutputs;
27use rlx_fusion::control_flow::LowerControlFlow;
28use rlx_fusion::fk_fusion::{
29    DecomposeFusionRegions, FuseBatchPreprocess, FuseRegionPrologue, MarkBatchSliceRegions,
30    MarkTransformRegions,
31};
32use rlx_fusion::fusion::{
33    FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
34    FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, FuseTransformerLayer,
35    MarkElementwiseRegions, UnfuseElementwiseRegions,
36};
37use rlx_fusion::limits::{FusionLimits, with_fusion_limits};
38use rlx_fusion::lower_dot_general::LowerDotGeneral;
39use rlx_fusion::pass::{Pass, run_passes};
40use rlx_ir::Graph;
41
42use crate::fusion_target::with_fusion_target;
43
/// Backend a graph is compiled for; determines which fusion pipeline,
/// fusion limits and IO-gate defaults apply.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FusionTarget {
    /// CPU backend.
    Cpu,
    /// Metal backend.
    Metal,
    /// MLX backend.
    Mlx,
    /// wgpu backend.
    Wgpu,
    /// CUDA backend.
    Cuda,
    /// ROCm backend.
    Rocm,
    /// TPU backend.
    Tpu,
}
55
56/// Per-target fusion toggles (env-driven on Metal today).
57#[derive(Debug, Clone, Copy, PartialEq, Eq)]
58pub struct FusionOptions {
59    /// Skip all pattern fusions (Metal: `RLX_METAL_NO_FUSION`).
60    pub skip_fusion: bool,
61    /// Break `ElementwiseRegion` back into primitives after marking.
62    pub unfuse_elementwise_regions: bool,
63    /// Keep fused `ElementwiseRegion` through lowering (env: `RLX_KEEP_ELEMENTWISE_REGIONS`).
64    pub keep_elementwise_regions: bool,
65    /// Decompose FKL-style transform / batch regions before backend lowering.
66    pub decompose_fusion_regions: bool,
67    /// Run FKL passes (`MarkTransformRegions`, prologue, batch). Env: `RLX_NO_FK_FUSION=1` disables.
68    pub fk_fusion: bool,
69    /// Fold `ResizeNearest2x` into `ElementwiseRegion` prologue. Env: `RLX_FUSE_REGION_PROLOGUE=0` disables.
70    pub fuse_region_prologue: bool,
71    /// Merge parallel region slices into `BatchElementwiseRegion`. Env: `RLX_FUSE_BATCH_PREPROCESS=0` disables.
72    pub fuse_batch_preprocess: bool,
73    /// Keep `TransformRegion` / `BatchElementwiseRegion` in MIR for native lowering. Env: `RLX_NATIVE_FK_REGIONS=1`.
74    pub native_fk_regions: bool,
75    /// Caps for fused elementwise chains (encoder / scratch limits).
76    pub fusion_limits: FusionLimits,
77}
78
79impl Default for FusionOptions {
80    fn default() -> Self {
81        Self {
82            skip_fusion: false,
83            unfuse_elementwise_regions: false,
84            keep_elementwise_regions: false,
85            decompose_fusion_regions: false,
86            fk_fusion: true,
87            fuse_region_prologue: true,
88            fuse_batch_preprocess: true,
89            native_fk_regions: false,
90            fusion_limits: FusionLimits::default(),
91        }
92    }
93}
94
95impl FusionOptions {
96    /// Read Metal-specific env overrides.
97    pub fn from_metal_env() -> Self {
98        Self {
99            skip_fusion: rlx_ir::env::flag("RLX_METAL_NO_FUSION"),
100            unfuse_elementwise_regions: rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS"),
101            keep_elementwise_regions: rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS"),
102            decompose_fusion_regions: rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS"),
103            fk_fusion: !rlx_ir::env::flag("RLX_NO_FK_FUSION"),
104            fuse_region_prologue: if rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
105                true
106            } else {
107                rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE")
108            },
109            fuse_batch_preprocess: if rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
110                true
111            } else {
112                rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS")
113            },
114            native_fk_regions: rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS"),
115            ..Self::default()
116        }
117    }
118
119    /// Merge session options with compile-time env overrides.
120    pub fn merge_env(mut self) -> Self {
121        if rlx_ir::env::flag("RLX_METAL_NO_FUSION") {
122            self.skip_fusion = true;
123        }
124        if rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS") {
125            self.unfuse_elementwise_regions = true;
126        }
127        if rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS") {
128            self.keep_elementwise_regions = true;
129        }
130        if rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS") {
131            self.decompose_fusion_regions = true;
132        }
133        if rlx_ir::env::flag("RLX_NO_FK_FUSION") {
134            self.fk_fusion = false;
135        }
136        if !rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
137            self.fuse_region_prologue = rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE");
138        }
139        if !rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
140            self.fuse_batch_preprocess = rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS");
141        }
142        if rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
143            self.native_fk_regions = true;
144        }
145        if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
146            self.native_fk_regions = false;
147        }
148        self
149    }
150
151    /// GPU-class targets keep native FKL regions unless opted out.
152    pub fn apply_native_fk_defaults(mut self, target: FusionTarget) -> Self {
153        if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
154            self.native_fk_regions = false;
155            return self;
156        }
157        if self.native_fk_regions || rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
158            self.native_fk_regions = true;
159            return self;
160        }
161        if matches!(
162            target,
163            FusionTarget::Metal
164                | FusionTarget::Cuda
165                | FusionTarget::Rocm
166                | FusionTarget::Wgpu
167                | FusionTarget::Mlx
168                | FusionTarget::Tpu
169        ) {
170            self.native_fk_regions = true;
171        }
172        self
173    }
174
175    /// CPU executes element-wise chains as per-op thunks — mark then unfuse.
176    pub fn for_cpu() -> Self {
177        Self {
178            unfuse_elementwise_regions: true,
179            fusion_limits: FusionLimits::UNBOUNDED,
180            ..Self::default()
181        }
182    }
183
184    /// Metal keeps RMSNorm / matmul fusions but unfuses `ElementwiseRegion`
185    /// (fused MSL mis-lowers long chains on deep transformer graphs).
186    pub fn for_metal() -> Self {
187        let mut opts = Self::from_metal_env();
188        opts.unfuse_elementwise_regions = true;
189        opts
190    }
191
192    /// wgpu region kernel only supports trailing/scalar broadcast via
193    /// modulus — unfuse so LegalizeBroadcast Expand + Binary run separately.
194    pub fn for_wgpu() -> Self {
195        let keep = rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS");
196        Self {
197            unfuse_elementwise_regions: !keep,
198            keep_elementwise_regions: keep,
199            ..Self::default()
200        }
201    }
202}
203
204/// Elementwise-region caps for `target` (matches GPU kernel encoders).
205pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
206    match target {
207        FusionTarget::Cpu => FusionLimits::UNBOUNDED,
208        FusionTarget::Tpu => FusionLimits {
209            max_elementwise_steps: 32,
210            max_elementwise_inputs: 16,
211        },
212        _ => FusionLimits::GPU_NATIVE,
213    }
214}
215
216/// True when `supported` is empty (no claim) or contains `kind`.
217#[inline]
218pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
219    supported.is_empty() || supported.contains(&kind)
220}
221
222/// Return the ordered fusion passes allowed for `supported`.
223///
224/// When `supported` is empty every fusion pass runs (legacy "accept
225/// all" backends). When non-empty, each pattern fusion pass is
226/// included only if the backend claims the fused [`OpKind`] it
227/// emits. Lowering passes (`LowerControlFlow`, `LowerDotGeneral`) and
228/// `FuseRmsNormReshape` (topology-only) always run unless
229/// `skip_fusion` is set.
230pub fn fusion_passes_for_supported(
231    supported: &[OpKind],
232    opts: FusionOptions,
233    target: FusionTarget,
234) -> Vec<&'static dyn Pass> {
235    let opts = opts.apply_native_fk_defaults(target);
236    if opts.skip_fusion {
237        return vec![&LowerControlFlow, &LowerDotGeneral];
238    }
239
240    let mut passes: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
241
242    // ORDER: FuseMatMulBiasAct first, then FuseAttentionBlock. The block-level
243    // pass matches the post-fusion shape
244    //   FusedMatMulBiasAct(qkv) → narrow×3 → Attention → FusedMatMulBiasAct(out)
245    // which is the pattern BERT-family encoders actually present after the
246    // per-layer matmul+bias fusion has collapsed Q, K, V, and out projections.
247    if supports_op(supported, OpKind::FusedMatMulBiasAct) {
248        passes.push(&FuseMatMulBiasAct);
249    }
250    // Block-level fusion: `Op::FusedAttentionBlock`. All backends that claim
251    // this op now produce parity-correct output (the MLX
252    // `Op::FusedAttentionBlock` lowering at `rlx-mlx/src/lower.rs:1689`
253    // historically diverged on `MaskKind::Custom` BERT masks because it
254    // bypassed the binary→additive conversion and the contiguous
255    // materialization the unfused `Op::Attention` path applies — fixed
256    // alongside this pass landing).
257    if supports_op(supported, OpKind::FusedAttentionBlock) {
258        passes.push(&FuseAttentionBlock);
259    }
260    // FuseResidualLN must run BEFORE FuseTransformerLayer: the layer-level
261    // pass matches `FAB → FusedResidualLN → FMBA(GeLU) → FMBA → FusedResidualLN`
262    // and needs the residual+LN ops already collapsed.
263    if supports_op(supported, OpKind::FusedResidualLN) {
264        passes.push(&FuseResidualLN);
265    }
266    if supports_op(supported, OpKind::FusedResidualRmsNorm) {
267        passes.push(&FuseResidualRmsNorm);
268    }
269    passes.push(&FuseRmsNormReshape);
270
271    // Layer-level fusion runs AFTER FuseResidualLN so it can match the
272    // post-fusion shape `FAB → FusedResidualLN → FMBA(GeLU) → FMBA →
273    // FusedResidualLN`. Opt-in via `RLX_ENABLE_FUSE_TRANSFORMER_LAYER`
274    // because backend perf wins are uneven: WGPU un-fuses with no
275    // dispatch reduction; MLX's lowering is correct (per the FAB fix
276    // above) but the MLX `compile()` already collapses sub-ops, so the
277    // extra IR-level fusion doesn't beat the natural pipeline. The pass
278    // exists for backends planning a monolithic transformer-layer kernel.
279    if rlx_ir::env::flag("RLX_ENABLE_FUSE_TRANSFORMER_LAYER")
280        && supports_op(supported, OpKind::FusedTransformerLayer)
281        && supports_op(supported, OpKind::FusedAttentionBlock)
282    {
283        passes.push(&FuseTransformerLayer);
284    }
285
286    if supports_op(supported, OpKind::FusedSwiGLU) {
287        passes.push(&FuseSwiGLUDualMatmul);
288    }
289    if supports_op(supported, OpKind::MatMul) {
290        passes.push(&FuseSharedInputMatMul);
291    }
292    if supports_op(supported, OpKind::FusedSwiGLU) {
293        passes.push(&FuseSwiGLU);
294    }
295
296    // Mark eligible element-wise chains. Backends that don't lower
297    // ElementwiseRegion natively unfuse immediately afterward.
298    passes.push(&MarkElementwiseRegions);
299    if opts.fk_fusion {
300        passes.push(&MarkBatchSliceRegions);
301        passes.push(&MarkTransformRegions);
302        if opts.fuse_region_prologue {
303            passes.push(&FuseRegionPrologue);
304        }
305        if opts.fuse_batch_preprocess {
306            passes.push(&FuseBatchPreprocess);
307        }
308    }
309    let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
310        && supports_op(supported, OpKind::BatchElementwiseRegion);
311    let keep_native_fk = opts.native_fk_regions && backend_native_fk;
312    if opts.decompose_fusion_regions || !keep_native_fk {
313        passes.push(&DecomposeFusionRegions);
314    }
315    let keep_regions =
316        supports_op(supported, OpKind::ElementwiseRegion) && !opts.unfuse_elementwise_regions;
317    if !keep_regions {
318        let unfuse = if matches!(target, FusionTarget::Cpu) {
319            &UnfuseElementwiseRegions::FOR_CPU
320        } else {
321            &UnfuseElementwiseRegions::FOR_GPU
322        };
323        passes.push(unfuse);
324    }
325
326    if supports_op(supported, OpKind::Fft) && supports_op(supported, OpKind::WelchPeaks) {
327        passes.push(&SelectPeaksOnlyOutputs);
328    }
329
330    finish_pipeline(passes)
331}
332
333/// FKL passes to run after [`MarkElementwiseRegions`] (e.g. `TpuExecutable::compile`).
334pub fn fk_passes_after_elementwise_regions(
335    supported: &[OpKind],
336    opts: FusionOptions,
337) -> Vec<&'static dyn Pass> {
338    let mut passes: Vec<&'static dyn Pass> = Vec::new();
339    if !opts.fk_fusion {
340        let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
341            && supports_op(supported, OpKind::BatchElementwiseRegion);
342        let keep_native_fk = opts.native_fk_regions && backend_native_fk;
343        if opts.decompose_fusion_regions || !keep_native_fk {
344            passes.push(&DecomposeFusionRegions);
345        }
346        return finish_pipeline(passes);
347    }
348    passes.push(&MarkBatchSliceRegions);
349    passes.push(&MarkTransformRegions);
350    if opts.fuse_region_prologue {
351        passes.push(&FuseRegionPrologue);
352    }
353    if opts.fuse_batch_preprocess {
354        passes.push(&FuseBatchPreprocess);
355    }
356    let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
357        && supports_op(supported, OpKind::BatchElementwiseRegion);
358    let keep_native_fk = opts.native_fk_regions && backend_native_fk;
359    if opts.decompose_fusion_regions || !keep_native_fk {
360        passes.push(&DecomposeFusionRegions);
361    }
362    finish_pipeline(passes)
363}
364
365/// IO gate decision for a rewrite on `target` (convenience for compile passes / model crates).
366pub fn should_fuse_with_target(
367    target: FusionTarget,
368    before: &crate::fusion_benefit::GraphIoProfile,
369    after: &crate::fusion_benefit::GraphIoProfile,
370) -> bool {
371    io_fusion_gate_for_target(target).should_fuse(before, after)
372}
373
374/// Phase 3 — IO-aware gate defaults for fusion rewrites on `target`.
375pub fn io_fusion_gate_for_target(target: FusionTarget) -> crate::fusion_benefit::IoFusionGate {
376    use crate::fusion_benefit::IoFusionGate;
377    match target {
378        FusionTarget::Metal | FusionTarget::Mlx => IoFusionGate {
379            dispatch_ns: 500.0,
380            roundtrip_ns: 5_000.0,
381            memory_bw: 200.0,
382            host_readback_bw: 200.0,
383            unified_memory: true,
384            host_thunk_penalty_ns: 2_000_000.0,
385            min_gain_ns: 1_000.0,
386        },
387        FusionTarget::Cuda | FusionTarget::Rocm => IoFusionGate {
388            dispatch_ns: 2_000.0,
389            roundtrip_ns: 20_000.0,
390            memory_bw: 800.0,
391            host_readback_bw: 50.0,
392            unified_memory: false,
393            host_thunk_penalty_ns: 15_000_000.0,
394            min_gain_ns: 5_000.0,
395        },
396        FusionTarget::Wgpu | FusionTarget::Tpu => IoFusionGate {
397            dispatch_ns: 3_000.0,
398            roundtrip_ns: 30_000.0,
399            memory_bw: 100.0,
400            host_readback_bw: 40.0,
401            unified_memory: false,
402            host_thunk_penalty_ns: 25_000_000.0,
403            min_gain_ns: 10_000.0,
404        },
405        FusionTarget::Cpu => IoFusionGate {
406            dispatch_ns: 50.0,
407            roundtrip_ns: 0.0,
408            memory_bw: 50.0,
409            host_readback_bw: 50.0,
410            unified_memory: true,
411            host_thunk_penalty_ns: 0.0,
412            min_gain_ns: 0.0,
413        },
414    }
415}
416
417/// Return the ordered fusion passes for `target`.
418pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
419    let mut opts = opts;
420    // CPU thunks execute element-wise chains per-op. Metal's fused
421    // `ElementwiseRegion` MSL kernel mis-lowers long chains on deep
422    // transformer graphs (NaNs past ~14 blocks); keep FAB/RMSNorm fusions.
423    if !opts.keep_elementwise_regions
424        && matches!(target, FusionTarget::Cpu | FusionTarget::Metal)
425        && !opts.unfuse_elementwise_regions
426    {
427        opts.unfuse_elementwise_regions = true;
428    }
429    if opts.fusion_limits == FusionLimits::default() {
430        opts.fusion_limits = fusion_limits_for_target(target);
431    }
432    opts = opts.apply_native_fk_defaults(target);
433    fusion_passes_for_supported(supported_for_target(target), opts, target)
434}
435
436/// Per-target op claims used when a backend doesn't supply an explicit
437/// `supported_ops` slice. Must stay aligned with each backend's
438/// `*_SUPPORTED_OPS` in `rlx-runtime/src/backend.rs`.
439pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
440    use OpKind::*;
441    match target {
442        FusionTarget::Cpu => &[
443            MatMul,
444            DotGeneral,
445            ElementwiseRegion,
446            FusedSwiGLU,
447            FusedMatMulBiasAct,
448            FusedResidualLN,
449            FusedResidualRmsNorm,
450            FusedAttentionBlock,
451        ],
452        FusionTarget::Metal => &[
453            MatMul,
454            DotGeneral,
455            ElementwiseRegion,
456            TransformRegion,
457            BatchElementwiseRegion,
458            FusedSwiGLU,
459            FusedMatMulBiasAct,
460            FusedResidualLN,
461            FusedResidualRmsNorm,
462        ],
463        FusionTarget::Mlx => &[
464            MatMul,
465            DotGeneral,
466            ElementwiseRegion,
467            TransformRegion,
468            BatchElementwiseRegion,
469            FusedSwiGLU,
470            FusedMatMulBiasAct,
471            FusedResidualLN,
472            FusedResidualRmsNorm,
473        ],
474        FusionTarget::Wgpu => &[
475            MatMul,
476            ElementwiseRegion,
477            TransformRegion,
478            BatchElementwiseRegion,
479            FusedSwiGLU,
480            FusedMatMulBiasAct,
481            FusedResidualLN,
482            FusedResidualRmsNorm,
483            FusedAttentionBlock,
484            FusedTransformerLayer,
485        ],
486        FusionTarget::Cuda | FusionTarget::Rocm => &[
487            MatMul,
488            DotGeneral,
489            ElementwiseRegion,
490            TransformRegion,
491            BatchElementwiseRegion,
492            FusedMatMulBiasAct,
493            FusedResidualLN,
494            FusedResidualRmsNorm,
495        ],
496        FusionTarget::Tpu => &[
497            MatMul,
498            ElementwiseRegion,
499            TransformRegion,
500            BatchElementwiseRegion,
501            FusedMatMulBiasAct,
502            FusedResidualLN,
503        ],
504    }
505}
506
507fn finish_pipeline(mut passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
508    passes.push(&DeadCodeElimination);
509    passes
510}
511
512/// Run the fusion pipeline for `target` on a MIR graph (IO-gated passes included).
513pub fn run_fusion_pipeline(
514    graph: Graph,
515    target: FusionTarget,
516    supported: &[OpKind],
517    opts: FusionOptions,
518) -> Graph {
519    let mut opts = opts.apply_native_fk_defaults(target);
520    if opts.fusion_limits == FusionLimits::default() {
521        opts.fusion_limits = fusion_limits_for_target(target);
522    }
523    let limits = opts.fusion_limits;
524    let passes = fusion_passes_for_supported(supported, opts, target);
525    with_fusion_target(target, || {
526        with_fusion_limits(limits, || run_passes(graph, &passes, false))
527    })
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes tests that read or write `RLX_NO_NATIVE_FK_REGIONS`.
    ///
    /// The test harness runs tests on parallel threads, and
    /// `apply_native_fk_defaults` consults that variable on every pipeline
    /// build. Every test whose assertions depend on it must hold this
    /// lock, not only the one that mutates it.
    static ENV_FK_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire [`ENV_FK_TEST_LOCK`], ignoring poisoning so that one failed
    /// test does not cascade into unrelated `PoisonError` panics.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_FK_TEST_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let passes = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(
            passes.len(),
            18,
            "CPU default supported_ops omit Fft/WelchPeaks"
        );
        assert_eq!(passes[2].name(), "fuse_matmul_bias_act");
        assert_eq!(passes[3].name(), "fuse_attention_block");
        assert!(
            passes.iter().any(|p| p.name() == "fuse_region_prologue"),
            "default CPU pipeline should run FKL prologue fusion"
        );
        assert_eq!(passes.last().unwrap().name(), "dead_code_elimination");
    }

    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                skip_fusion: true,
                ..FusionOptions::default()
            },
        );
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "LowerControlFlow");
        assert_eq!(passes[1].name(), "lower_dot_general");
    }

    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions::default(),
            FusionTarget::Metal,
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cuda),
            FusionOptions::default(),
            FusionTarget::Cuda,
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cpu),
            FusionOptions::for_cpu(),
            FusionTarget::Cpu,
        );
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn metal_unfuses_elementwise_regions_by_default() {
        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn metal_default_unfuse_preserves_prologue_regions() {
        // Whether the pipeline contains the decompose pass depends on the
        // env var; hold the lock so the executed pipeline is deterministic.
        let _lock = env_lock();
        let mut g = rlx_ir::Graph::new("t");
        let shape_in = rlx_ir::Shape::new(&[1, 3, 8, 8], rlx_ir::DType::F32);
        let shape_out = rlx_ir::Shape::new(&[1, 3, 16, 16], rlx_ir::DType::F32);
        let x = g.input("x", shape_in);
        let up = g.add_node(rlx_ir::Op::ResizeNearest2x, vec![x], shape_out.clone());
        let r = g.add_node(
            rlx_ir::Op::Activation(rlx_ir::op::Activation::Relu),
            vec![up],
            shape_out,
        );
        g.set_outputs(vec![r]);

        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        let out = rlx_fusion::pass::run_passes(g, &passes, false);
        assert!(out.nodes().iter().any(|n| {
            matches!(
                n.op,
                rlx_ir::Op::ElementwiseRegion {
                    prologue: rlx_ir::RegionPrologue::ResizeNearest2x,
                    ..
                }
            )
        }));
    }

    #[test]
    fn fk_passes_after_elementwise_includes_batch_fusion() {
        // `apply_native_fk_defaults` reads RLX_NO_NATIVE_FK_REGIONS.
        let _lock = env_lock();
        let opts = FusionOptions::default().apply_native_fk_defaults(FusionTarget::Tpu);
        let passes =
            fk_passes_after_elementwise_regions(supported_for_target(FusionTarget::Tpu), opts);
        let names: Vec<_> = passes.iter().map(|p| p.name()).collect();
        assert!(names.contains(&"mark_batch_slice_regions"));
        assert!(names.contains(&"fuse_batch_preprocess"));
        assert!(
            !names.contains(&"decompose_fusion_regions"),
            "TPU native FK defaults should keep batch/transform regions"
        );
    }

    #[test]
    fn tpu_native_fk_region_pass_policy() {
        let _lock = env_lock();
        let default_passes = fusion_passes(FusionTarget::Tpu, FusionOptions::default());
        assert!(
            !default_passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "default TPU pipeline keeps batch/transform regions via native_fk_defaults"
        );

        // Restore the environment before asserting so a failed assertion
        // cannot leave the override set for later tests.
        rlx_ir::env::set("RLX_NO_NATIVE_FK_REGIONS", "1");
        let opt_out = fusion_passes(FusionTarget::Tpu, FusionOptions::default());
        rlx_ir::env::unset("RLX_NO_NATIVE_FK_REGIONS");
        assert!(
            opt_out
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "RLX_NO_NATIVE_FK_REGIONS should force decompose on TPU"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_tpu() {
        // RLX_NO_NATIVE_FK_REGIONS overrides `native_fk_regions: true`.
        let _lock = env_lock();
        let passes = fusion_passes(
            FusionTarget::Tpu,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "native_fk_regions should skip decompose on TPU when batch/transform are supported"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_metal() {
        // RLX_NO_NATIVE_FK_REGIONS overrides `native_fk_regions: true`.
        let _lock = env_lock();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "native_fk_regions should skip decompose when backend claims batch/transform ops"
        );
    }

    #[test]
    fn metal_keeps_elementwise_regions_when_requested() {
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                keep_elementwise_regions: true,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions"),
            "keep_elementwise_regions should skip unfuse pass"
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_region_prologue"),
            "FKL prologue fusion should still run"
        );
    }

    #[test]
    fn metal_audio_ops_pipeline_includes_peaks_output_gate() {
        let mut supported = supported_for_target(FusionTarget::Metal).to_vec();
        supported.push(OpKind::Fft);
        supported.push(OpKind::WelchPeaks);
        let passes =
            fusion_passes_for_supported(&supported, FusionOptions::default(), FusionTarget::Metal);
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "select_peaks_only_outputs"),
            "Metal + Fft/WelchPeaks should run IO peaks-only output gate"
        );
    }

    #[test]
    fn should_fuse_with_target_matches_gate() {
        use crate::fusion_benefit::GraphIoProfile;
        let dense = GraphIoProfile {
            kernel_launches: 3,
            sync_points: 0,
            host_output_bytes: 33_554_432,
            device_traffic_bytes: 184_549_376,
        };
        let fused = GraphIoProfile {
            kernel_launches: 4,
            sync_points: 1,
            host_output_bytes: 1_048_576,
            device_traffic_bytes: 219_152_384,
        };
        assert!(should_fuse_with_target(FusionTarget::Metal, &dense, &fused));
        assert!(!should_fuse_with_target(FusionTarget::Wgpu, &dense, &fused));
    }
}