Skip to main content

rlx_compile/
fusion_pipeline.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Centralized fusion pass pipelines per backend target.
17//!
18//! [`fusion_passes_for_supported`] selects passes from a backend's
19//! [`rlx_ir::OpKind`] claim set so fusion never emits fused ops the
20//! target cannot lower. [`fusion_passes`] keeps the legacy
21//! [`FusionTarget`] entry point and delegates to the same selector.
22
23use rlx_ir::OpKind;
24
25use crate::DeadCodeElimination;
26use crate::io_output_gate::SelectPeaksOnlyOutputs;
27use rlx_fusion::control_flow::LowerControlFlow;
28use rlx_fusion::fk_fusion::{
29    DecomposeFusionRegions, FuseBatchPreprocess, FuseRegionPrologue, MarkBatchSliceRegions,
30    MarkTransformRegions,
31};
32use rlx_fusion::fusion::{
33    FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
34    FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, FuseTransformerLayer,
35    MarkElementwiseRegions, UnfuseElementwiseRegions,
36};
37use rlx_fusion::limits::{FusionLimits, with_fusion_limits};
38use rlx_fusion::lower_dot_general::LowerDotGeneral;
39use rlx_fusion::pass::{Pass, run_passes};
40use rlx_ir::Graph;
41
42use crate::fusion_target::with_fusion_target;
43
/// Compile target that selects a fusion pipeline.
///
/// The target picks the default op-claim set ([`supported_for_target`]),
/// element-wise region caps ([`fusion_limits_for_target`]) and the IO-aware
/// fusion gate ([`io_fusion_gate_for_target`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionTarget {
    /// Host CPU backend.
    Cpu,
    /// Metal backend.
    Metal,
    /// MLX backend.
    Mlx,
    /// wgpu backend.
    Wgpu,
    /// CUDA backend.
    Cuda,
    /// ROCm backend.
    Rocm,
    /// TPU backend.
    Tpu,
}
55
56/// Per-target fusion toggles (env-driven on Metal today).
57#[derive(Debug, Clone, Copy, PartialEq, Eq)]
58pub struct FusionOptions {
59    /// Skip all pattern fusions (Metal: `RLX_METAL_NO_FUSION`).
60    pub skip_fusion: bool,
61    /// Break `ElementwiseRegion` back into primitives after marking.
62    pub unfuse_elementwise_regions: bool,
63    /// Keep fused `ElementwiseRegion` through lowering (env: `RLX_KEEP_ELEMENTWISE_REGIONS`).
64    pub keep_elementwise_regions: bool,
65    /// Decompose FKL-style transform / batch regions before backend lowering.
66    pub decompose_fusion_regions: bool,
67    /// Run FKL passes (`MarkTransformRegions`, prologue, batch). Env: `RLX_NO_FK_FUSION=1` disables.
68    pub fk_fusion: bool,
69    /// Fold `ResizeNearest2x` into `ElementwiseRegion` prologue. Env: `RLX_FUSE_REGION_PROLOGUE=0` disables.
70    pub fuse_region_prologue: bool,
71    /// Merge parallel region slices into `BatchElementwiseRegion`. Env: `RLX_FUSE_BATCH_PREPROCESS=0` disables.
72    pub fuse_batch_preprocess: bool,
73    /// Keep `TransformRegion` / `BatchElementwiseRegion` in MIR for native lowering. Env: `RLX_NATIVE_FK_REGIONS=1`.
74    pub native_fk_regions: bool,
75    /// Caps for fused elementwise chains (encoder / scratch limits).
76    pub fusion_limits: FusionLimits,
77}
78
79impl Default for FusionOptions {
80    fn default() -> Self {
81        Self {
82            skip_fusion: false,
83            unfuse_elementwise_regions: false,
84            keep_elementwise_regions: false,
85            decompose_fusion_regions: false,
86            fk_fusion: true,
87            fuse_region_prologue: true,
88            fuse_batch_preprocess: true,
89            native_fk_regions: false,
90            fusion_limits: FusionLimits::default(),
91        }
92    }
93}
94
95impl FusionOptions {
96    /// Read Metal-specific env overrides.
97    pub fn from_metal_env() -> Self {
98        Self {
99            skip_fusion: rlx_ir::env::flag("RLX_METAL_NO_FUSION"),
100            unfuse_elementwise_regions: rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS"),
101            keep_elementwise_regions: rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS"),
102            decompose_fusion_regions: rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS"),
103            fk_fusion: !rlx_ir::env::flag("RLX_NO_FK_FUSION"),
104            fuse_region_prologue: if rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
105                true
106            } else {
107                rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE")
108            },
109            fuse_batch_preprocess: if rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
110                true
111            } else {
112                rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS")
113            },
114            native_fk_regions: rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS"),
115            ..Self::default()
116        }
117    }
118
119    /// Merge session options with compile-time env overrides.
120    pub fn merge_env(mut self) -> Self {
121        if rlx_ir::env::flag("RLX_METAL_NO_FUSION") {
122            self.skip_fusion = true;
123        }
124        if rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS") {
125            self.unfuse_elementwise_regions = true;
126        }
127        if rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS") {
128            self.keep_elementwise_regions = true;
129        }
130        if rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS") {
131            self.decompose_fusion_regions = true;
132        }
133        if rlx_ir::env::flag("RLX_NO_FK_FUSION") {
134            self.fk_fusion = false;
135        }
136        if !rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
137            self.fuse_region_prologue = rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE");
138        }
139        if !rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
140            self.fuse_batch_preprocess = rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS");
141        }
142        if rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
143            self.native_fk_regions = true;
144        }
145        if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
146            self.native_fk_regions = false;
147        }
148        self
149    }
150
151    /// GPU-class targets keep native FKL regions unless opted out.
152    pub fn apply_native_fk_defaults(mut self, target: FusionTarget) -> Self {
153        if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
154            self.native_fk_regions = false;
155            return self;
156        }
157        if self.native_fk_regions || rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
158            self.native_fk_regions = true;
159            return self;
160        }
161        if matches!(
162            target,
163            FusionTarget::Metal
164                | FusionTarget::Cuda
165                | FusionTarget::Rocm
166                | FusionTarget::Wgpu
167                | FusionTarget::Mlx
168                | FusionTarget::Tpu
169        ) {
170            self.native_fk_regions = true;
171        }
172        self
173    }
174
175    /// CPU executes element-wise chains as per-op thunks — mark then unfuse.
176    pub fn for_cpu() -> Self {
177        Self {
178            unfuse_elementwise_regions: true,
179            fusion_limits: FusionLimits::UNBOUNDED,
180            ..Self::default()
181        }
182    }
183
184    /// Metal keeps RMSNorm / matmul fusions but unfuses `ElementwiseRegion`
185    /// (fused MSL mis-lowers long chains on deep transformer graphs).
186    pub fn for_metal() -> Self {
187        let mut opts = Self::from_metal_env();
188        opts.unfuse_elementwise_regions = true;
189        opts
190    }
191
192    /// wgpu region kernel only supports trailing/scalar broadcast via
193    /// modulus — unfuse so LegalizeBroadcast Expand + Binary run separately.
194    pub fn for_wgpu() -> Self {
195        let keep = rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS");
196        Self {
197            unfuse_elementwise_regions: !keep,
198            keep_elementwise_regions: keep,
199            ..Self::default()
200        }
201    }
202}
203
204/// Elementwise-region caps for `target` (matches GPU kernel encoders).
205pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
206    match target {
207        FusionTarget::Cpu => FusionLimits::UNBOUNDED,
208        FusionTarget::Tpu => FusionLimits {
209            max_elementwise_steps: 32,
210            max_elementwise_inputs: 16,
211        },
212        _ => FusionLimits::GPU_NATIVE,
213    }
214}
215
216/// True when `supported` is empty (no claim) or contains `kind`.
217#[inline]
218pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
219    supported.is_empty() || supported.contains(&kind)
220}
221
222/// Return the ordered fusion passes allowed for `supported`.
223///
224/// When `supported` is empty every fusion pass runs (legacy "accept
225/// all" backends). When non-empty, each pattern fusion pass is
226/// included only if the backend claims the fused [`OpKind`] it
227/// emits. Lowering passes (`LowerControlFlow`, `LowerDotGeneral`) and
228/// `FuseRmsNormReshape` (topology-only) always run unless
229/// `skip_fusion` is set.
230pub fn fusion_passes_for_supported(
231    supported: &[OpKind],
232    opts: FusionOptions,
233    target: FusionTarget,
234) -> Vec<&'static dyn Pass> {
235    let opts = opts.apply_native_fk_defaults(target);
236    if opts.skip_fusion {
237        return vec![&LowerControlFlow, &LowerDotGeneral];
238    }
239
240    let mut passes: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
241
242    // ORDER: FuseMatMulBiasAct first, then FuseAttentionBlock. The block-level
243    // pass matches the post-fusion shape
244    //   FusedMatMulBiasAct(qkv) → narrow×3 → Attention → FusedMatMulBiasAct(out)
245    // which is the pattern BERT-family encoders actually present after the
246    // per-layer matmul+bias fusion has collapsed Q, K, V, and out projections.
247    if supports_op(supported, OpKind::FusedMatMulBiasAct) {
248        passes.push(&FuseMatMulBiasAct);
249    }
250    // Block-level fusion: `Op::FusedAttentionBlock`. All backends that claim
251    // this op now produce parity-correct output (the MLX
252    // `Op::FusedAttentionBlock` lowering at `rlx-mlx/src/lower.rs:1689`
253    // historically diverged on `MaskKind::Custom` BERT masks because it
254    // bypassed the binary→additive conversion and the contiguous
255    // materialization the unfused `Op::Attention` path applies — fixed
256    // alongside this pass landing).
257    if supports_op(supported, OpKind::FusedAttentionBlock) {
258        passes.push(&FuseAttentionBlock);
259    }
260    // FuseResidualLN must run BEFORE FuseTransformerLayer: the layer-level
261    // pass matches `FAB → FusedResidualLN → FMBA(GeLU) → FMBA → FusedResidualLN`
262    // and needs the residual+LN ops already collapsed.
263    if supports_op(supported, OpKind::FusedResidualLN) {
264        passes.push(&FuseResidualLN);
265    }
266    if supports_op(supported, OpKind::FusedResidualRmsNorm) {
267        passes.push(&FuseResidualRmsNorm);
268    }
269    passes.push(&FuseRmsNormReshape);
270
271    // Layer-level fusion runs AFTER FuseResidualLN so it can match the
272    // post-fusion shape `FAB → FusedResidualLN → FMBA(GeLU) → FMBA →
273    // FusedResidualLN`. Opt-in via `RLX_ENABLE_FUSE_TRANSFORMER_LAYER`
274    // because backend perf wins are uneven: WGPU un-fuses with no
275    // dispatch reduction; MLX's lowering is correct (per the FAB fix
276    // above) but the MLX `compile()` already collapses sub-ops, so the
277    // extra IR-level fusion doesn't beat the natural pipeline. The pass
278    // exists for backends planning a monolithic transformer-layer kernel.
279    if rlx_ir::env::flag("RLX_ENABLE_FUSE_TRANSFORMER_LAYER")
280        && supports_op(supported, OpKind::FusedTransformerLayer)
281        && supports_op(supported, OpKind::FusedAttentionBlock)
282    {
283        passes.push(&FuseTransformerLayer);
284    }
285
286    if supports_op(supported, OpKind::FusedSwiGLU) {
287        passes.push(&FuseSwiGLUDualMatmul);
288    }
289    if supports_op(supported, OpKind::MatMul) {
290        passes.push(&FuseSharedInputMatMul);
291    }
292    if supports_op(supported, OpKind::FusedSwiGLU) {
293        passes.push(&FuseSwiGLU);
294    }
295
296    // Mark eligible element-wise chains only when the backend keeps regions.
297    // CPU/Metal unfuse immediately afterward — marking first duplicates the
298    // full graph for no benefit.
299    let keep_regions =
300        supports_op(supported, OpKind::ElementwiseRegion) && !opts.unfuse_elementwise_regions;
301    if keep_regions {
302        passes.push(&MarkElementwiseRegions);
303    }
304    if opts.fk_fusion {
305        passes.push(&MarkBatchSliceRegions);
306        passes.push(&MarkTransformRegions);
307        if opts.fuse_region_prologue {
308            passes.push(&FuseRegionPrologue);
309        }
310        if opts.fuse_batch_preprocess {
311            passes.push(&FuseBatchPreprocess);
312        }
313    }
314    let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
315        && supports_op(supported, OpKind::BatchElementwiseRegion);
316    let keep_native_fk = opts.native_fk_regions && backend_native_fk;
317    if opts.decompose_fusion_regions || !keep_native_fk {
318        passes.push(&DecomposeFusionRegions);
319    }
320    if !keep_regions {
321        let unfuse = if matches!(target, FusionTarget::Cpu) {
322            &UnfuseElementwiseRegions::FOR_CPU
323        } else {
324            &UnfuseElementwiseRegions::FOR_GPU
325        };
326        passes.push(unfuse);
327    }
328
329    if supports_op(supported, OpKind::Fft) && supports_op(supported, OpKind::WelchPeaks) {
330        passes.push(&SelectPeaksOnlyOutputs);
331    }
332
333    finish_pipeline(passes)
334}
335
336/// FKL passes to run after [`MarkElementwiseRegions`] (e.g. `TpuExecutable::compile`).
337pub fn fk_passes_after_elementwise_regions(
338    supported: &[OpKind],
339    opts: FusionOptions,
340) -> Vec<&'static dyn Pass> {
341    let mut passes: Vec<&'static dyn Pass> = Vec::new();
342    if !opts.fk_fusion {
343        let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
344            && supports_op(supported, OpKind::BatchElementwiseRegion);
345        let keep_native_fk = opts.native_fk_regions && backend_native_fk;
346        if opts.decompose_fusion_regions || !keep_native_fk {
347            passes.push(&DecomposeFusionRegions);
348        }
349        return finish_pipeline(passes);
350    }
351    passes.push(&MarkBatchSliceRegions);
352    passes.push(&MarkTransformRegions);
353    if opts.fuse_region_prologue {
354        passes.push(&FuseRegionPrologue);
355    }
356    if opts.fuse_batch_preprocess {
357        passes.push(&FuseBatchPreprocess);
358    }
359    let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
360        && supports_op(supported, OpKind::BatchElementwiseRegion);
361    let keep_native_fk = opts.native_fk_regions && backend_native_fk;
362    if opts.decompose_fusion_regions || !keep_native_fk {
363        passes.push(&DecomposeFusionRegions);
364    }
365    finish_pipeline(passes)
366}
367
368/// IO gate decision for a rewrite on `target` (convenience for compile passes / model crates).
369pub fn should_fuse_with_target(
370    target: FusionTarget,
371    before: &crate::fusion_benefit::GraphIoProfile,
372    after: &crate::fusion_benefit::GraphIoProfile,
373) -> bool {
374    io_fusion_gate_for_target(target).should_fuse(before, after)
375}
376
377/// Phase 3 — IO-aware gate defaults for fusion rewrites on `target`.
378pub fn io_fusion_gate_for_target(target: FusionTarget) -> crate::fusion_benefit::IoFusionGate {
379    use crate::fusion_benefit::IoFusionGate;
380    match target {
381        FusionTarget::Metal | FusionTarget::Mlx => IoFusionGate {
382            dispatch_ns: 500.0,
383            roundtrip_ns: 5_000.0,
384            memory_bw: 200.0,
385            host_readback_bw: 200.0,
386            unified_memory: true,
387            host_thunk_penalty_ns: 2_000_000.0,
388            min_gain_ns: 1_000.0,
389        },
390        FusionTarget::Cuda | FusionTarget::Rocm => IoFusionGate {
391            dispatch_ns: 2_000.0,
392            roundtrip_ns: 20_000.0,
393            memory_bw: 800.0,
394            host_readback_bw: 50.0,
395            unified_memory: false,
396            host_thunk_penalty_ns: 15_000_000.0,
397            min_gain_ns: 5_000.0,
398        },
399        FusionTarget::Wgpu | FusionTarget::Tpu => IoFusionGate {
400            dispatch_ns: 3_000.0,
401            roundtrip_ns: 30_000.0,
402            memory_bw: 100.0,
403            host_readback_bw: 40.0,
404            unified_memory: false,
405            host_thunk_penalty_ns: 25_000_000.0,
406            min_gain_ns: 10_000.0,
407        },
408        FusionTarget::Cpu => IoFusionGate {
409            dispatch_ns: 50.0,
410            roundtrip_ns: 0.0,
411            memory_bw: 50.0,
412            host_readback_bw: 50.0,
413            unified_memory: true,
414            host_thunk_penalty_ns: 0.0,
415            min_gain_ns: 0.0,
416        },
417    }
418}
419
420/// Return the ordered fusion passes for `target`.
421pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
422    let mut opts = opts;
423    // CPU thunks execute element-wise chains per-op. Metal's fused
424    // `ElementwiseRegion` MSL kernel mis-lowers long chains on deep
425    // transformer graphs (NaNs past ~14 blocks); keep FAB/RMSNorm fusions.
426    if !opts.keep_elementwise_regions
427        && matches!(target, FusionTarget::Cpu | FusionTarget::Metal)
428        && !opts.unfuse_elementwise_regions
429    {
430        opts.unfuse_elementwise_regions = true;
431    }
432    if opts.fusion_limits == FusionLimits::default() {
433        opts.fusion_limits = fusion_limits_for_target(target);
434    }
435    opts = opts.apply_native_fk_defaults(target);
436    fusion_passes_for_supported(supported_for_target(target), opts, target)
437}
438
439/// Per-target op claims used when a backend doesn't supply an explicit
440/// `supported_ops` slice. Must stay aligned with each backend's
441/// `*_SUPPORTED_OPS` in `rlx-runtime/src/backend.rs`.
442pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
443    use OpKind::*;
444    match target {
445        FusionTarget::Cpu => &[
446            MatMul,
447            DotGeneral,
448            ElementwiseRegion,
449            FusedSwiGLU,
450            FusedMatMulBiasAct,
451            FusedResidualLN,
452            FusedResidualRmsNorm,
453            FusedAttentionBlock,
454        ],
455        FusionTarget::Metal => &[
456            MatMul,
457            DotGeneral,
458            ElementwiseRegion,
459            TransformRegion,
460            BatchElementwiseRegion,
461            FusedSwiGLU,
462            FusedMatMulBiasAct,
463            FusedResidualLN,
464            FusedResidualRmsNorm,
465        ],
466        FusionTarget::Mlx => &[
467            MatMul,
468            DotGeneral,
469            ElementwiseRegion,
470            TransformRegion,
471            BatchElementwiseRegion,
472            FusedSwiGLU,
473            FusedMatMulBiasAct,
474            FusedResidualLN,
475            FusedResidualRmsNorm,
476        ],
477        FusionTarget::Wgpu => &[
478            MatMul,
479            ElementwiseRegion,
480            TransformRegion,
481            BatchElementwiseRegion,
482            FusedSwiGLU,
483            FusedMatMulBiasAct,
484            FusedResidualLN,
485            FusedResidualRmsNorm,
486            FusedAttentionBlock,
487            FusedTransformerLayer,
488        ],
489        FusionTarget::Cuda | FusionTarget::Rocm => &[
490            MatMul,
491            DotGeneral,
492            ElementwiseRegion,
493            TransformRegion,
494            BatchElementwiseRegion,
495            FusedMatMulBiasAct,
496            FusedResidualLN,
497            FusedResidualRmsNorm,
498        ],
499        FusionTarget::Tpu => &[
500            MatMul,
501            ElementwiseRegion,
502            TransformRegion,
503            BatchElementwiseRegion,
504            FusedMatMulBiasAct,
505            FusedResidualLN,
506        ],
507    }
508}
509
510fn finish_pipeline(mut passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
511    passes.push(&DeadCodeElimination);
512    passes
513}
514
515/// Run the fusion pipeline for `target` on a MIR graph (IO-gated passes included).
516pub fn run_fusion_pipeline(
517    graph: Graph,
518    target: FusionTarget,
519    supported: &[OpKind],
520    opts: FusionOptions,
521) -> Graph {
522    let mut opts = opts.apply_native_fk_defaults(target);
523    if opts.fusion_limits == FusionLimits::default() {
524        opts.fusion_limits = fusion_limits_for_target(target);
525    }
526    let limits = opts.fusion_limits;
527    let passes = fusion_passes_for_supported(supported, opts, target);
528    with_fusion_target(target, || {
529        with_fusion_limits(limits, || run_passes(graph, &passes, false))
530    })
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that read or write process environment variables.
    ///
    /// Every pipeline builder reaches `apply_native_fk_defaults`, which reads
    /// `RLX_NO_NATIVE_FK_REGIONS` / `RLX_NATIVE_FK_REGIONS`.
    /// `tpu_native_fk_region_pass_policy` temporarily sets the former, so any
    /// pipeline-building test that does not hold this lock can observe the
    /// override under the parallel test harness and flake.
    static ENV_FK_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire [`ENV_FK_TEST_LOCK`], recovering from poisoning so that one
    /// failing test does not cascade into unrelated failures.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_FK_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets an env var for the guard's lifetime and unsets it on drop, even
    /// if the test panics in between.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(key: &'static str, value: &'static str) -> Self {
            rlx_ir::env::set(key, value);
            Self(key)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            rlx_ir::env::unset(self.0);
        }
    }

    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let _env = env_lock();
        let passes = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(
            passes.len(),
            17,
            "CPU default supported_ops omit Fft/WelchPeaks and mark_elementwise (unfuse-only backends skip mark)"
        );
        assert_eq!(passes[2].name(), "fuse_matmul_bias_act");
        assert_eq!(passes[3].name(), "fuse_attention_block");
        assert!(
            passes.iter().any(|p| p.name() == "fuse_region_prologue"),
            "default CPU pipeline should run FKL prologue fusion"
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "mark_elementwise_regions"),
            "CPU unfuse backends should not mark elementwise regions before unfusing"
        );
        assert_eq!(passes.last().unwrap().name(), "dead_code_elimination");
    }

    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let _env = env_lock();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                skip_fusion: true,
                ..FusionOptions::default()
            },
        );
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "LowerControlFlow");
        assert_eq!(passes[1].name(), "lower_dot_general");
    }

    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let _env = env_lock();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions::default(),
            FusionTarget::Metal,
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let _env = env_lock();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cuda),
            FusionOptions::default(),
            FusionTarget::Cuda,
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let _env = env_lock();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cpu),
            FusionOptions::for_cpu(),
            FusionTarget::Cpu,
        );
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn metal_unfuses_elementwise_regions_by_default() {
        let _env = env_lock();
        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn metal_default_unfuse_preserves_prologue_regions() {
        let _env = env_lock();
        let mut g = rlx_ir::Graph::new("t");
        let shape_in = rlx_ir::Shape::new(&[1, 3, 8, 8], rlx_ir::DType::F32);
        let shape_out = rlx_ir::Shape::new(&[1, 3, 16, 16], rlx_ir::DType::F32);
        let x = g.input("x", shape_in);
        let up = g.add_node(rlx_ir::Op::ResizeNearest2x, vec![x], shape_out.clone());
        let r = g.add_node(
            rlx_ir::Op::Activation(rlx_ir::op::Activation::Relu),
            vec![up],
            shape_out,
        );
        g.set_outputs(vec![r]);

        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        let out = rlx_fusion::pass::run_passes(g, &passes, false);
        assert!(out.nodes().iter().any(|n| {
            matches!(
                n.op,
                rlx_ir::Op::ElementwiseRegion {
                    prologue: rlx_ir::RegionPrologue::ResizeNearest2x,
                    ..
                }
            )
        }));
    }

    #[test]
    fn fk_passes_after_elementwise_includes_batch_fusion() {
        let _env = env_lock();
        let opts = FusionOptions::default().apply_native_fk_defaults(FusionTarget::Tpu);
        let passes =
            fk_passes_after_elementwise_regions(supported_for_target(FusionTarget::Tpu), opts);
        let names: Vec<_> = passes.iter().map(|p| p.name()).collect();
        assert!(names.contains(&"mark_batch_slice_regions"));
        assert!(names.contains(&"fuse_batch_preprocess"));
        assert!(
            !names.contains(&"decompose_fusion_regions"),
            "TPU native FK defaults should keep batch/transform regions"
        );
    }

    #[test]
    fn tpu_native_fk_region_pass_policy() {
        let _env = env_lock();
        let default_passes = fusion_passes(FusionTarget::Tpu, FusionOptions::default());
        assert!(
            !default_passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "default TPU pipeline keeps batch/transform regions via native_fk_defaults"
        );

        let opt_out = {
            // Unset again when the guard drops, even if building the pipeline panics.
            let _no_native = EnvVarGuard::set("RLX_NO_NATIVE_FK_REGIONS", "1");
            fusion_passes(FusionTarget::Tpu, FusionOptions::default())
        };
        assert!(
            opt_out
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "RLX_NO_NATIVE_FK_REGIONS should force decompose on TPU"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_tpu() {
        let _env = env_lock();
        let passes = fusion_passes(
            FusionTarget::Tpu,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "native_fk_regions should skip decompose on TPU when batch/transform are supported"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_metal() {
        let _env = env_lock();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "native_fk_regions should skip decompose when backend claims batch/transform ops"
        );
    }

    #[test]
    fn metal_keeps_elementwise_regions_when_requested() {
        let _env = env_lock();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                keep_elementwise_regions: true,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions"),
            "keep_elementwise_regions should skip unfuse pass"
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_region_prologue"),
            "FKL prologue fusion should still run"
        );
    }

    #[test]
    fn metal_audio_ops_pipeline_includes_peaks_output_gate() {
        let _env = env_lock();
        let mut supported = supported_for_target(FusionTarget::Metal).to_vec();
        supported.push(OpKind::Fft);
        supported.push(OpKind::WelchPeaks);
        let passes =
            fusion_passes_for_supported(&supported, FusionOptions::default(), FusionTarget::Metal);
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "select_peaks_only_outputs"),
            "Metal + Fft/WelchPeaks should run IO peaks-only output gate"
        );
    }

    #[test]
    fn should_fuse_with_target_matches_gate() {
        use crate::fusion_benefit::GraphIoProfile;
        let dense = GraphIoProfile {
            kernel_launches: 3,
            sync_points: 0,
            host_output_bytes: 33_554_432,
            device_traffic_bytes: 184_549_376,
        };
        let fused = GraphIoProfile {
            kernel_launches: 4,
            sync_points: 1,
            host_output_bytes: 1_048_576,
            device_traffic_bytes: 219_152_384,
        };
        assert!(should_fuse_with_target(FusionTarget::Metal, &dense, &fused));
        assert!(!should_fuse_with_target(FusionTarget::Wgpu, &dense, &fused));
    }
}