Skip to main content

rlx_compile/
fusion_pipeline.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Centralized fusion pass pipelines per backend target.
17//!
18//! [`fusion_passes_for_supported`] selects passes from a backend's
19//! [`rlx_ir::OpKind`] claim set so fusion never emits fused ops the
20//! target cannot lower. [`fusion_passes`] keeps the legacy
21//! [`FusionTarget`] entry point and delegates to the same selector.
22
23use rlx_ir::OpKind;
24
25use crate::DeadCodeElimination;
26use rlx_fusion::control_flow::LowerControlFlow;
27use rlx_fusion::fk_fusion::{
28    DecomposeFusionRegions, FuseBatchPreprocess, FuseRegionPrologue, MarkBatchSliceRegions,
29    MarkTransformRegions,
30};
31use rlx_fusion::fusion::{
32    FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
33    FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, FuseTransformerLayer,
34    MarkElementwiseRegions, UnfuseElementwiseRegions,
35};
36use rlx_fusion::limits::FusionLimits;
37use rlx_fusion::lower_dot_general::LowerDotGeneral;
38use rlx_fusion::pass::Pass;
39
/// Compile target that selects a fusion pipeline.
///
/// The target picks the default op-claim set ([`supported_for_target`]),
/// the element-wise region caps ([`fusion_limits_for_target`]) and the
/// IO cost model ([`io_fusion_gate_for_target`]). Every variant except
/// `Cpu` defaults to native FKL regions in
/// [`FusionOptions::apply_native_fk_defaults`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionTarget {
    Cpu,
    Metal,
    Mlx,
    Wgpu,
    Cuda,
    Rocm,
    Tpu,
}
51
/// Per-target fusion toggles (env-driven on Metal today).
///
/// Build from [`FusionOptions::default`] or a per-target constructor
/// (`for_cpu`, `for_metal`, `for_wgpu`); [`FusionOptions::merge_env`]
/// layers the `RLX_*` environment overrides on top of session values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusionOptions {
    /// Skip all pattern fusions; only the lowering passes are selected
    /// (Metal: `RLX_METAL_NO_FUSION`).
    pub skip_fusion: bool,
    /// Break `ElementwiseRegion` back into primitives after marking
    /// (env: `RLX_METAL_UNFUSE_REGIONS`).
    pub unfuse_elementwise_regions: bool,
    /// Keep fused `ElementwiseRegion` through lowering (env: `RLX_KEEP_ELEMENTWISE_REGIONS`).
    /// `fusion_passes` uses it to suppress the CPU/Metal default unfuse.
    pub keep_elementwise_regions: bool,
    /// Decompose FKL-style transform / batch regions before backend lowering,
    /// even when the backend could keep them natively
    /// (env: `RLX_DECOMPOSE_FUSION_REGIONS`).
    pub decompose_fusion_regions: bool,
    /// Run FKL passes (`MarkTransformRegions`, prologue, batch). Env: `RLX_NO_FK_FUSION=1` disables.
    pub fk_fusion: bool,
    /// Fold `ResizeNearest2x` into `ElementwiseRegion` prologue. Env: `RLX_FUSE_REGION_PROLOGUE=0` disables.
    pub fuse_region_prologue: bool,
    /// Merge parallel region slices into `BatchElementwiseRegion`. Env: `RLX_FUSE_BATCH_PREPROCESS=0` disables.
    pub fuse_batch_preprocess: bool,
    /// Keep `TransformRegion` / `BatchElementwiseRegion` in MIR for native lowering. Env: `RLX_NATIVE_FK_REGIONS=1`
    /// enables, `RLX_NO_NATIVE_FK_REGIONS=1` forces off. Only effective when the backend claims
    /// both `TransformRegion` and `BatchElementwiseRegion`.
    pub native_fk_regions: bool,
    /// Caps for fused elementwise chains (encoder / scratch limits).
    /// `fusion_passes` replaces a value equal to `FusionLimits::default()`
    /// with the per-target caps from `fusion_limits_for_target`.
    pub fusion_limits: FusionLimits,
}
74
75impl Default for FusionOptions {
76    fn default() -> Self {
77        Self {
78            skip_fusion: false,
79            unfuse_elementwise_regions: false,
80            keep_elementwise_regions: false,
81            decompose_fusion_regions: false,
82            fk_fusion: true,
83            fuse_region_prologue: true,
84            fuse_batch_preprocess: true,
85            native_fk_regions: false,
86            fusion_limits: FusionLimits::default(),
87        }
88    }
89}
90
91impl FusionOptions {
92    /// Read Metal-specific env overrides.
93    pub fn from_metal_env() -> Self {
94        Self {
95            skip_fusion: rlx_ir::env::flag("RLX_METAL_NO_FUSION"),
96            unfuse_elementwise_regions: rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS"),
97            keep_elementwise_regions: rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS"),
98            decompose_fusion_regions: rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS"),
99            fk_fusion: !rlx_ir::env::flag("RLX_NO_FK_FUSION"),
100            fuse_region_prologue: if rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
101                true
102            } else {
103                rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE")
104            },
105            fuse_batch_preprocess: if rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
106                true
107            } else {
108                rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS")
109            },
110            native_fk_regions: rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS"),
111            ..Self::default()
112        }
113    }
114
115    /// Merge session options with compile-time env overrides.
116    pub fn merge_env(mut self) -> Self {
117        if rlx_ir::env::flag("RLX_METAL_NO_FUSION") {
118            self.skip_fusion = true;
119        }
120        if rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS") {
121            self.unfuse_elementwise_regions = true;
122        }
123        if rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS") {
124            self.keep_elementwise_regions = true;
125        }
126        if rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS") {
127            self.decompose_fusion_regions = true;
128        }
129        if rlx_ir::env::flag("RLX_NO_FK_FUSION") {
130            self.fk_fusion = false;
131        }
132        if !rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
133            self.fuse_region_prologue = rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE");
134        }
135        if !rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
136            self.fuse_batch_preprocess = rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS");
137        }
138        if rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
139            self.native_fk_regions = true;
140        }
141        if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
142            self.native_fk_regions = false;
143        }
144        self
145    }
146
147    /// GPU-class targets keep native FKL regions unless opted out.
148    pub fn apply_native_fk_defaults(mut self, target: FusionTarget) -> Self {
149        if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
150            self.native_fk_regions = false;
151            return self;
152        }
153        if self.native_fk_regions || rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
154            self.native_fk_regions = true;
155            return self;
156        }
157        if matches!(
158            target,
159            FusionTarget::Metal
160                | FusionTarget::Cuda
161                | FusionTarget::Rocm
162                | FusionTarget::Wgpu
163                | FusionTarget::Mlx
164                | FusionTarget::Tpu
165        ) {
166            self.native_fk_regions = true;
167        }
168        self
169    }
170
171    /// CPU executes element-wise chains as per-op thunks — mark then unfuse.
172    pub fn for_cpu() -> Self {
173        Self {
174            unfuse_elementwise_regions: true,
175            fusion_limits: FusionLimits::UNBOUNDED,
176            ..Self::default()
177        }
178    }
179
180    /// Metal keeps RMSNorm / matmul fusions but unfuses `ElementwiseRegion`
181    /// (fused MSL mis-lowers long chains on deep transformer graphs).
182    pub fn for_metal() -> Self {
183        let mut opts = Self::from_metal_env();
184        opts.unfuse_elementwise_regions = true;
185        opts
186    }
187
188    /// wgpu region kernel only supports trailing/scalar broadcast via
189    /// modulus — unfuse so LegalizeBroadcast Expand + Binary run separately.
190    pub fn for_wgpu() -> Self {
191        let keep = rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS");
192        Self {
193            unfuse_elementwise_regions: !keep,
194            keep_elementwise_regions: keep,
195            ..Self::default()
196        }
197    }
198}
199
200/// Elementwise-region caps for `target` (matches GPU kernel encoders).
201pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
202    match target {
203        FusionTarget::Cpu => FusionLimits::UNBOUNDED,
204        FusionTarget::Tpu => FusionLimits {
205            max_elementwise_steps: 32,
206            max_elementwise_inputs: 16,
207        },
208        _ => FusionLimits::GPU_NATIVE,
209    }
210}
211
212/// True when `supported` is empty (no claim) or contains `kind`.
213#[inline]
214pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
215    supported.is_empty() || supported.contains(&kind)
216}
217
218/// Return the ordered fusion passes allowed for `supported`.
219///
220/// When `supported` is empty every fusion pass runs (legacy "accept
221/// all" backends). When non-empty, each pattern fusion pass is
222/// included only if the backend claims the fused [`OpKind`] it
223/// emits. Lowering passes (`LowerControlFlow`, `LowerDotGeneral`) and
224/// `FuseRmsNormReshape` (topology-only) always run unless
225/// `skip_fusion` is set.
226pub fn fusion_passes_for_supported(
227    supported: &[OpKind],
228    opts: FusionOptions,
229    target: FusionTarget,
230) -> Vec<&'static dyn Pass> {
231    let opts = opts.apply_native_fk_defaults(target);
232    if opts.skip_fusion {
233        return vec![&LowerControlFlow, &LowerDotGeneral];
234    }
235
236    let mut passes: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
237
238    // ORDER: FuseMatMulBiasAct first, then FuseAttentionBlock. The block-level
239    // pass matches the post-fusion shape
240    //   FusedMatMulBiasAct(qkv) → narrow×3 → Attention → FusedMatMulBiasAct(out)
241    // which is the pattern BERT-family encoders actually present after the
242    // per-layer matmul+bias fusion has collapsed Q, K, V, and out projections.
243    if supports_op(supported, OpKind::FusedMatMulBiasAct) {
244        passes.push(&FuseMatMulBiasAct);
245    }
246    // Block-level fusion: `Op::FusedAttentionBlock`. All backends that claim
247    // this op now produce parity-correct output (the MLX
248    // `Op::FusedAttentionBlock` lowering at `rlx-mlx/src/lower.rs:1689`
249    // historically diverged on `MaskKind::Custom` BERT masks because it
250    // bypassed the binary→additive conversion and the contiguous
251    // materialization the unfused `Op::Attention` path applies — fixed
252    // alongside this pass landing).
253    if supports_op(supported, OpKind::FusedAttentionBlock) {
254        passes.push(&FuseAttentionBlock);
255    }
256    // FuseResidualLN must run BEFORE FuseTransformerLayer: the layer-level
257    // pass matches `FAB → FusedResidualLN → FMBA(GeLU) → FMBA → FusedResidualLN`
258    // and needs the residual+LN ops already collapsed.
259    if supports_op(supported, OpKind::FusedResidualLN) {
260        passes.push(&FuseResidualLN);
261    }
262    if supports_op(supported, OpKind::FusedResidualRmsNorm) {
263        passes.push(&FuseResidualRmsNorm);
264    }
265    passes.push(&FuseRmsNormReshape);
266
267    // Layer-level fusion runs AFTER FuseResidualLN so it can match the
268    // post-fusion shape `FAB → FusedResidualLN → FMBA(GeLU) → FMBA →
269    // FusedResidualLN`. Opt-in via `RLX_ENABLE_FUSE_TRANSFORMER_LAYER`
270    // because backend perf wins are uneven: WGPU un-fuses with no
271    // dispatch reduction; MLX's lowering is correct (per the FAB fix
272    // above) but the MLX `compile()` already collapses sub-ops, so the
273    // extra IR-level fusion doesn't beat the natural pipeline. The pass
274    // exists for backends planning a monolithic transformer-layer kernel.
275    if rlx_ir::env::flag("RLX_ENABLE_FUSE_TRANSFORMER_LAYER")
276        && supports_op(supported, OpKind::FusedTransformerLayer)
277        && supports_op(supported, OpKind::FusedAttentionBlock)
278    {
279        passes.push(&FuseTransformerLayer);
280    }
281
282    if supports_op(supported, OpKind::FusedSwiGLU) {
283        passes.push(&FuseSwiGLUDualMatmul);
284    }
285    if supports_op(supported, OpKind::MatMul) {
286        passes.push(&FuseSharedInputMatMul);
287    }
288    if supports_op(supported, OpKind::FusedSwiGLU) {
289        passes.push(&FuseSwiGLU);
290    }
291
292    // Mark eligible element-wise chains. Backends that don't lower
293    // ElementwiseRegion natively unfuse immediately afterward.
294    passes.push(&MarkElementwiseRegions);
295    if opts.fk_fusion {
296        passes.push(&MarkBatchSliceRegions);
297        passes.push(&MarkTransformRegions);
298        if opts.fuse_region_prologue {
299            passes.push(&FuseRegionPrologue);
300        }
301        if opts.fuse_batch_preprocess {
302            passes.push(&FuseBatchPreprocess);
303        }
304    }
305    let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
306        && supports_op(supported, OpKind::BatchElementwiseRegion);
307    let keep_native_fk = opts.native_fk_regions && backend_native_fk;
308    if opts.decompose_fusion_regions || !keep_native_fk {
309        passes.push(&DecomposeFusionRegions);
310    }
311    let keep_regions =
312        supports_op(supported, OpKind::ElementwiseRegion) && !opts.unfuse_elementwise_regions;
313    if !keep_regions {
314        let unfuse = if matches!(target, FusionTarget::Cpu) {
315            &UnfuseElementwiseRegions::FOR_CPU
316        } else {
317            &UnfuseElementwiseRegions::FOR_GPU
318        };
319        passes.push(unfuse);
320    }
321
322    finish_pipeline(passes)
323}
324
325/// FKL passes to run after [`MarkElementwiseRegions`] (e.g. `TpuExecutable::compile`).
326pub fn fk_passes_after_elementwise_regions(
327    supported: &[OpKind],
328    opts: FusionOptions,
329) -> Vec<&'static dyn Pass> {
330    let mut passes: Vec<&'static dyn Pass> = Vec::new();
331    if !opts.fk_fusion {
332        let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
333            && supports_op(supported, OpKind::BatchElementwiseRegion);
334        let keep_native_fk = opts.native_fk_regions && backend_native_fk;
335        if opts.decompose_fusion_regions || !keep_native_fk {
336            passes.push(&DecomposeFusionRegions);
337        }
338        return finish_pipeline(passes);
339    }
340    passes.push(&MarkBatchSliceRegions);
341    passes.push(&MarkTransformRegions);
342    if opts.fuse_region_prologue {
343        passes.push(&FuseRegionPrologue);
344    }
345    if opts.fuse_batch_preprocess {
346        passes.push(&FuseBatchPreprocess);
347    }
348    let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
349        && supports_op(supported, OpKind::BatchElementwiseRegion);
350    let keep_native_fk = opts.native_fk_regions && backend_native_fk;
351    if opts.decompose_fusion_regions || !keep_native_fk {
352        passes.push(&DecomposeFusionRegions);
353    }
354    finish_pipeline(passes)
355}
356
357/// Phase 3 — IO-aware gate defaults for fusion rewrites on `target`.
358pub fn io_fusion_gate_for_target(target: FusionTarget) -> crate::fusion_benefit::IoFusionGate {
359    use crate::fusion_benefit::IoFusionGate;
360    match target {
361        FusionTarget::Metal | FusionTarget::Mlx => IoFusionGate {
362            dispatch_ns: 500.0,
363            roundtrip_ns: 5_000.0,
364            memory_bw: 200.0,
365            host_readback_bw: 200.0,
366            unified_memory: true,
367            min_gain_ns: 1_000.0,
368        },
369        FusionTarget::Cuda | FusionTarget::Rocm => IoFusionGate {
370            dispatch_ns: 2_000.0,
371            roundtrip_ns: 20_000.0,
372            memory_bw: 800.0,
373            host_readback_bw: 50.0,
374            unified_memory: false,
375            min_gain_ns: 5_000.0,
376        },
377        FusionTarget::Wgpu | FusionTarget::Tpu => IoFusionGate {
378            dispatch_ns: 3_000.0,
379            roundtrip_ns: 30_000.0,
380            memory_bw: 100.0,
381            host_readback_bw: 40.0,
382            unified_memory: false,
383            min_gain_ns: 10_000.0,
384        },
385        FusionTarget::Cpu => IoFusionGate {
386            dispatch_ns: 50.0,
387            roundtrip_ns: 0.0,
388            memory_bw: 50.0,
389            host_readback_bw: 50.0,
390            unified_memory: true,
391            min_gain_ns: 0.0,
392        },
393    }
394}
395
396/// Return the ordered fusion passes for `target`.
397pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
398    let mut opts = opts;
399    // CPU thunks execute element-wise chains per-op. Metal's fused
400    // `ElementwiseRegion` MSL kernel mis-lowers long chains on deep
401    // transformer graphs (NaNs past ~14 blocks); keep FAB/RMSNorm fusions.
402    if !opts.keep_elementwise_regions
403        && matches!(target, FusionTarget::Cpu | FusionTarget::Metal)
404        && !opts.unfuse_elementwise_regions
405    {
406        opts.unfuse_elementwise_regions = true;
407    }
408    if opts.fusion_limits == FusionLimits::default() {
409        opts.fusion_limits = fusion_limits_for_target(target);
410    }
411    opts = opts.apply_native_fk_defaults(target);
412    fusion_passes_for_supported(supported_for_target(target), opts, target)
413}
414
/// Per-target op claims used when a backend doesn't supply an explicit
/// `supported_ops` slice. Must stay aligned with each backend's
/// `*_SUPPORTED_OPS` in `rlx-runtime/src/backend.rs`.
///
/// Every list is non-empty, so [`supports_op`] treats it as an explicit
/// claim set; an empty slice would mean "accept all". Gaps that change
/// which passes [`fusion_passes_for_supported`] selects:
/// - Only CPU and WGPU claim `FusedAttentionBlock`, so `FuseAttentionBlock`
///   is skipped on Metal, MLX, CUDA/ROCm and TPU.
/// - Only WGPU claims `FusedTransformerLayer`.
/// - CUDA/ROCm and TPU omit `FusedSwiGLU`, so neither SwiGLU pass runs there.
/// - CPU omits `TransformRegion` / `BatchElementwiseRegion`, so FKL regions
///   are always decomposed on CPU.
///
/// Order inside each slice does not matter to [`supports_op`], which only
/// tests membership.
pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
    use OpKind::*;
    match target {
        FusionTarget::Cpu => &[
            MatMul,
            DotGeneral,
            ElementwiseRegion,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
        ],
        FusionTarget::Metal => &[
            MatMul,
            DotGeneral,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
        ],
        FusionTarget::Mlx => &[
            MatMul,
            DotGeneral,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
        ],
        FusionTarget::Wgpu => &[
            MatMul,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            FusedTransformerLayer,
        ],
        FusionTarget::Cuda | FusionTarget::Rocm => &[
            MatMul,
            DotGeneral,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
        ],
        FusionTarget::Tpu => &[
            MatMul,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            FusedMatMulBiasAct,
            FusedResidualLN,
        ],
    }
}
485
486fn finish_pipeline(mut passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
487    passes.push(&DeadCodeElimination);
488    passes
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that builds a pipeline. Pipeline selection
    /// reads process-wide env vars such as `RLX_NO_NATIVE_FK_REGIONS`, which
    /// `tpu_native_fk_region_pass_policy` sets temporarily. Without the lock,
    /// a concurrently running test could observe that value and fail
    /// spuriously.
    static ENV_FK_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Take the env lock. Recover from poisoning so that one failing test
    /// does not cascade into failures in every other test.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_FK_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets an env var for the guard's lifetime and unsets it on drop, so
    /// the var is cleared even if the code under test panics.
    struct ScopedEnvVar(&'static str);

    impl ScopedEnvVar {
        fn set(name: &'static str, value: &'static str) -> Self {
            rlx_ir::env::set(name, value);
            Self(name)
        }
    }

    impl Drop for ScopedEnvVar {
        fn drop(&mut self) {
            rlx_ir::env::unset(self.0);
        }
    }

    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let _env = env_lock();
        let passes = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(passes.len(), 18);
        assert_eq!(passes[2].name(), "fuse_matmul_bias_act");
        assert_eq!(passes[3].name(), "fuse_attention_block");
        assert!(
            passes.iter().any(|p| p.name() == "fuse_region_prologue"),
            "default CPU pipeline should run FKL prologue fusion"
        );
        assert_eq!(passes.last().unwrap().name(), "dead_code_elimination");
    }

    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let _env = env_lock();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                skip_fusion: true,
                ..FusionOptions::default()
            },
        );
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "LowerControlFlow");
        assert_eq!(passes[1].name(), "lower_dot_general");
    }

    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let _env = env_lock();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions::default(),
            FusionTarget::Metal,
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let _env = env_lock();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cuda),
            FusionOptions::default(),
            FusionTarget::Cuda,
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let _env = env_lock();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cpu),
            FusionOptions::for_cpu(),
            FusionTarget::Cpu,
        );
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn metal_unfuses_elementwise_regions_by_default() {
        let _env = env_lock();
        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn metal_default_unfuse_preserves_prologue_regions() {
        let _env = env_lock();
        let mut g = rlx_ir::Graph::new("t");
        let shape_in = rlx_ir::Shape::new(&[1, 3, 8, 8], rlx_ir::DType::F32);
        let shape_out = rlx_ir::Shape::new(&[1, 3, 16, 16], rlx_ir::DType::F32);
        let x = g.input("x", shape_in);
        let up = g.add_node(rlx_ir::Op::ResizeNearest2x, vec![x], shape_out.clone());
        let r = g.add_node(
            rlx_ir::Op::Activation(rlx_ir::op::Activation::Relu),
            vec![up],
            shape_out,
        );
        g.set_outputs(vec![r]);

        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        let out = rlx_fusion::pass::run_passes(g, &passes, false);
        assert!(out.nodes().iter().any(|n| {
            matches!(
                n.op,
                rlx_ir::Op::ElementwiseRegion {
                    prologue: rlx_ir::RegionPrologue::ResizeNearest2x,
                    ..
                }
            )
        }));
    }

    #[test]
    fn fk_passes_after_elementwise_includes_batch_fusion() {
        let _env = env_lock();
        let opts = FusionOptions::default().apply_native_fk_defaults(FusionTarget::Tpu);
        let passes =
            fk_passes_after_elementwise_regions(supported_for_target(FusionTarget::Tpu), opts);
        let names: Vec<_> = passes.iter().map(|p| p.name()).collect();
        assert!(names.contains(&"mark_batch_slice_regions"));
        assert!(names.contains(&"fuse_batch_preprocess"));
        assert!(
            !names.contains(&"decompose_fusion_regions"),
            "TPU native FK defaults should keep batch/transform regions"
        );
    }

    #[test]
    fn tpu_native_fk_region_pass_policy() {
        let _env = env_lock();
        let default_passes = fusion_passes(FusionTarget::Tpu, FusionOptions::default());
        assert!(
            !default_passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "default TPU pipeline keeps batch/transform regions via native_fk_defaults"
        );

        // The guard unsets the var at the end of this block, even on panic.
        let opt_out = {
            let _no_native = ScopedEnvVar::set("RLX_NO_NATIVE_FK_REGIONS", "1");
            fusion_passes(FusionTarget::Tpu, FusionOptions::default())
        };
        assert!(
            opt_out
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "RLX_NO_NATIVE_FK_REGIONS should force decompose on TPU"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_tpu() {
        let _env = env_lock();
        let passes = fusion_passes(
            FusionTarget::Tpu,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "native_fk_regions should skip decompose on TPU when batch/transform are supported"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_metal() {
        let _env = env_lock();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "native_fk_regions should skip decompose when backend claims batch/transform ops"
        );
    }

    #[test]
    fn metal_keeps_elementwise_regions_when_requested() {
        let _env = env_lock();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                keep_elementwise_regions: true,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions"),
            "keep_elementwise_regions should skip unfuse pass"
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_region_prologue"),
            "FKL prologue fusion should still run"
        );
    }
}