Skip to main content

rlx_compile/
fusion_pipeline.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Centralized fusion pass pipelines per backend target.
17//!
18//! [`fusion_passes_for_supported`] selects passes from a backend's
19//! [`rlx_ir::OpKind`] claim set so fusion never emits fused ops the
20//! target cannot lower. [`fusion_passes`] keeps the legacy
21//! [`FusionTarget`] entry point and delegates to the same selector.
22
23use rlx_ir::OpKind;
24
25use crate::DeadCodeElimination;
26use rlx_fusion::control_flow::LowerControlFlow;
27use rlx_fusion::fk_fusion::{
28    DecomposeFusionRegions, FuseBatchPreprocess, FuseRegionPrologue, MarkBatchSliceRegions,
29    MarkTransformRegions,
30};
31use rlx_fusion::fusion::{
32    FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
33    FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, FuseTransformerLayer,
34    MarkElementwiseRegions, UnfuseElementwiseRegions,
35};
36use rlx_fusion::limits::FusionLimits;
37use rlx_fusion::lower_dot_general::LowerDotGeneral;
38use rlx_fusion::pass::Pass;
39
/// Backend a fusion pipeline is being built for.
///
/// The target drives per-backend defaults: op claim sets
/// ([`supported_for_target`]), element-wise limits
/// ([`fusion_limits_for_target`]) and native FKL region defaults
/// ([`FusionOptions::apply_native_fk_defaults`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FusionTarget {
    /// Host CPU (per-op thunks).
    Cpu,
    /// Apple Metal.
    Metal,
    /// Apple MLX.
    Mlx,
    /// wgpu (WebGPU-style portable GPU).
    Wgpu,
    /// NVIDIA CUDA.
    Cuda,
    /// AMD ROCm.
    Rocm,
    /// TPU.
    Tpu,
}
51
52/// Per-target fusion toggles (env-driven on Metal today).
53#[derive(Debug, Clone, Copy, PartialEq, Eq)]
54pub struct FusionOptions {
55    /// Skip all pattern fusions (Metal: `RLX_METAL_NO_FUSION`).
56    pub skip_fusion: bool,
57    /// Break `ElementwiseRegion` back into primitives after marking.
58    pub unfuse_elementwise_regions: bool,
59    /// Keep fused `ElementwiseRegion` through lowering (env: `RLX_KEEP_ELEMENTWISE_REGIONS`).
60    pub keep_elementwise_regions: bool,
61    /// Decompose FKL-style transform / batch regions before backend lowering.
62    pub decompose_fusion_regions: bool,
63    /// Run FKL passes (`MarkTransformRegions`, prologue, batch). Env: `RLX_NO_FK_FUSION=1` disables.
64    pub fk_fusion: bool,
65    /// Fold `ResizeNearest2x` into `ElementwiseRegion` prologue. Env: `RLX_FUSE_REGION_PROLOGUE=0` disables.
66    pub fuse_region_prologue: bool,
67    /// Merge parallel region slices into `BatchElementwiseRegion`. Env: `RLX_FUSE_BATCH_PREPROCESS=0` disables.
68    pub fuse_batch_preprocess: bool,
69    /// Keep `TransformRegion` / `BatchElementwiseRegion` in MIR for native lowering. Env: `RLX_NATIVE_FK_REGIONS=1`.
70    pub native_fk_regions: bool,
71    /// Caps for fused elementwise chains (encoder / scratch limits).
72    pub fusion_limits: FusionLimits,
73}
74
75impl Default for FusionOptions {
76    fn default() -> Self {
77        Self {
78            skip_fusion: false,
79            unfuse_elementwise_regions: false,
80            keep_elementwise_regions: false,
81            decompose_fusion_regions: false,
82            fk_fusion: true,
83            fuse_region_prologue: true,
84            fuse_batch_preprocess: true,
85            native_fk_regions: false,
86            fusion_limits: FusionLimits::default(),
87        }
88    }
89}
90
91impl FusionOptions {
92    /// Read Metal-specific env overrides.
93    pub fn from_metal_env() -> Self {
94        Self {
95            skip_fusion: rlx_ir::env::flag("RLX_METAL_NO_FUSION"),
96            unfuse_elementwise_regions: rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS"),
97            keep_elementwise_regions: rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS"),
98            decompose_fusion_regions: rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS"),
99            fk_fusion: !rlx_ir::env::flag("RLX_NO_FK_FUSION"),
100            fuse_region_prologue: if rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
101                true
102            } else {
103                rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE")
104            },
105            fuse_batch_preprocess: if rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
106                true
107            } else {
108                rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS")
109            },
110            native_fk_regions: rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS"),
111            ..Self::default()
112        }
113    }
114
115    /// Merge session options with compile-time env overrides.
116    pub fn merge_env(mut self) -> Self {
117        if rlx_ir::env::flag("RLX_METAL_NO_FUSION") {
118            self.skip_fusion = true;
119        }
120        if rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS") {
121            self.unfuse_elementwise_regions = true;
122        }
123        if rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS") {
124            self.keep_elementwise_regions = true;
125        }
126        if rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS") {
127            self.decompose_fusion_regions = true;
128        }
129        if rlx_ir::env::flag("RLX_NO_FK_FUSION") {
130            self.fk_fusion = false;
131        }
132        if !rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
133            self.fuse_region_prologue = rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE");
134        }
135        if !rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
136            self.fuse_batch_preprocess = rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS");
137        }
138        if rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
139            self.native_fk_regions = true;
140        }
141        if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
142            self.native_fk_regions = false;
143        }
144        self
145    }
146
147    /// GPU-class targets keep native FKL regions unless opted out.
148    pub fn apply_native_fk_defaults(mut self, target: FusionTarget) -> Self {
149        if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
150            self.native_fk_regions = false;
151            return self;
152        }
153        if self.native_fk_regions || rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
154            self.native_fk_regions = true;
155            return self;
156        }
157        if matches!(
158            target,
159            FusionTarget::Metal
160                | FusionTarget::Cuda
161                | FusionTarget::Rocm
162                | FusionTarget::Wgpu
163                | FusionTarget::Mlx
164                | FusionTarget::Tpu
165        ) {
166            self.native_fk_regions = true;
167        }
168        self
169    }
170
171    /// CPU executes element-wise chains as per-op thunks — mark then unfuse.
172    pub fn for_cpu() -> Self {
173        Self {
174            unfuse_elementwise_regions: true,
175            fusion_limits: FusionLimits::UNBOUNDED,
176            ..Self::default()
177        }
178    }
179
180    /// Metal keeps RMSNorm / matmul fusions but unfuses `ElementwiseRegion`
181    /// (fused MSL mis-lowers long chains on deep transformer graphs).
182    pub fn for_metal() -> Self {
183        let mut opts = Self::from_metal_env();
184        opts.unfuse_elementwise_regions = true;
185        opts
186    }
187
188    /// wgpu region kernel only supports trailing/scalar broadcast via
189    /// modulus — unfuse so LegalizeBroadcast Expand + Binary run separately.
190    pub fn for_wgpu() -> Self {
191        let keep = rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS");
192        Self {
193            unfuse_elementwise_regions: !keep,
194            keep_elementwise_regions: keep,
195            ..Self::default()
196        }
197    }
198}
199
200/// Elementwise-region caps for `target` (matches GPU kernel encoders).
201pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
202    match target {
203        FusionTarget::Cpu => FusionLimits::UNBOUNDED,
204        FusionTarget::Tpu => FusionLimits {
205            max_elementwise_steps: 32,
206            max_elementwise_inputs: 16,
207        },
208        _ => FusionLimits::GPU_NATIVE,
209    }
210}
211
212/// True when `supported` is empty (no claim) or contains `kind`.
213#[inline]
214pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
215    supported.is_empty() || supported.contains(&kind)
216}
217
218/// Return the ordered fusion passes allowed for `supported`.
219///
220/// When `supported` is empty every fusion pass runs (legacy "accept
221/// all" backends). When non-empty, each pattern fusion pass is
222/// included only if the backend claims the fused [`OpKind`] it
223/// emits. Lowering passes (`LowerControlFlow`, `LowerDotGeneral`) and
224/// `FuseRmsNormReshape` (topology-only) always run unless
225/// `skip_fusion` is set.
226pub fn fusion_passes_for_supported(
227    supported: &[OpKind],
228    opts: FusionOptions,
229    target: FusionTarget,
230) -> Vec<&'static dyn Pass> {
231    let opts = opts.apply_native_fk_defaults(target);
232    if opts.skip_fusion {
233        return vec![&LowerControlFlow, &LowerDotGeneral];
234    }
235
236    let mut passes: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
237
238    // ORDER: FuseMatMulBiasAct first, then FuseAttentionBlock. The block-level
239    // pass matches the post-fusion shape
240    //   FusedMatMulBiasAct(qkv) → narrow×3 → Attention → FusedMatMulBiasAct(out)
241    // which is the pattern BERT-family encoders actually present after the
242    // per-layer matmul+bias fusion has collapsed Q, K, V, and out projections.
243    if supports_op(supported, OpKind::FusedMatMulBiasAct) {
244        passes.push(&FuseMatMulBiasAct);
245    }
246    // Block-level fusion: `Op::FusedAttentionBlock`. All backends that claim
247    // this op now produce parity-correct output (the MLX
248    // `Op::FusedAttentionBlock` lowering at `rlx-mlx/src/lower.rs:1689`
249    // historically diverged on `MaskKind::Custom` BERT masks because it
250    // bypassed the binary→additive conversion and the contiguous
251    // materialization the unfused `Op::Attention` path applies — fixed
252    // alongside this pass landing).
253    if supports_op(supported, OpKind::FusedAttentionBlock) {
254        passes.push(&FuseAttentionBlock);
255    }
256    // FuseResidualLN must run BEFORE FuseTransformerLayer: the layer-level
257    // pass matches `FAB → FusedResidualLN → FMBA(GeLU) → FMBA → FusedResidualLN`
258    // and needs the residual+LN ops already collapsed.
259    if supports_op(supported, OpKind::FusedResidualLN) {
260        passes.push(&FuseResidualLN);
261    }
262    if supports_op(supported, OpKind::FusedResidualRmsNorm) {
263        passes.push(&FuseResidualRmsNorm);
264    }
265    passes.push(&FuseRmsNormReshape);
266
267    // Layer-level fusion runs AFTER FuseResidualLN so it can match the
268    // post-fusion shape `FAB → FusedResidualLN → FMBA(GeLU) → FMBA →
269    // FusedResidualLN`. Opt-in via `RLX_ENABLE_FUSE_TRANSFORMER_LAYER`
270    // because backend perf wins are uneven: WGPU un-fuses with no
271    // dispatch reduction; MLX's lowering is correct (per the FAB fix
272    // above) but the MLX `compile()` already collapses sub-ops, so the
273    // extra IR-level fusion doesn't beat the natural pipeline. The pass
274    // exists for backends planning a monolithic transformer-layer kernel.
275    if rlx_ir::env::flag("RLX_ENABLE_FUSE_TRANSFORMER_LAYER")
276        && supports_op(supported, OpKind::FusedTransformerLayer)
277        && supports_op(supported, OpKind::FusedAttentionBlock)
278    {
279        passes.push(&FuseTransformerLayer);
280    }
281
282    if supports_op(supported, OpKind::FusedSwiGLU) {
283        passes.push(&FuseSwiGLUDualMatmul);
284    }
285    if supports_op(supported, OpKind::MatMul) {
286        passes.push(&FuseSharedInputMatMul);
287    }
288    if supports_op(supported, OpKind::FusedSwiGLU) {
289        passes.push(&FuseSwiGLU);
290    }
291
292    // Mark eligible element-wise chains. Backends that don't lower
293    // ElementwiseRegion natively unfuse immediately afterward.
294    passes.push(&MarkElementwiseRegions);
295    if opts.fk_fusion {
296        passes.push(&MarkBatchSliceRegions);
297        passes.push(&MarkTransformRegions);
298        if opts.fuse_region_prologue {
299            passes.push(&FuseRegionPrologue);
300        }
301        if opts.fuse_batch_preprocess {
302            passes.push(&FuseBatchPreprocess);
303        }
304    }
305    let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
306        && supports_op(supported, OpKind::BatchElementwiseRegion);
307    let keep_native_fk = opts.native_fk_regions && backend_native_fk;
308    if opts.decompose_fusion_regions || !keep_native_fk {
309        passes.push(&DecomposeFusionRegions);
310    }
311    let keep_regions =
312        supports_op(supported, OpKind::ElementwiseRegion) && !opts.unfuse_elementwise_regions;
313    if !keep_regions {
314        let unfuse = if matches!(target, FusionTarget::Cpu) {
315            &UnfuseElementwiseRegions::FOR_CPU
316        } else {
317            &UnfuseElementwiseRegions::FOR_GPU
318        };
319        passes.push(unfuse);
320    }
321
322    finish_pipeline(passes)
323}
324
325/// FKL passes to run after [`MarkElementwiseRegions`] (e.g. `TpuExecutable::compile`).
326pub fn fk_passes_after_elementwise_regions(
327    supported: &[OpKind],
328    opts: FusionOptions,
329) -> Vec<&'static dyn Pass> {
330    let mut passes: Vec<&'static dyn Pass> = Vec::new();
331    if !opts.fk_fusion {
332        let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
333            && supports_op(supported, OpKind::BatchElementwiseRegion);
334        let keep_native_fk = opts.native_fk_regions && backend_native_fk;
335        if opts.decompose_fusion_regions || !keep_native_fk {
336            passes.push(&DecomposeFusionRegions);
337        }
338        return finish_pipeline(passes);
339    }
340    passes.push(&MarkBatchSliceRegions);
341    passes.push(&MarkTransformRegions);
342    if opts.fuse_region_prologue {
343        passes.push(&FuseRegionPrologue);
344    }
345    if opts.fuse_batch_preprocess {
346        passes.push(&FuseBatchPreprocess);
347    }
348    let backend_native_fk = supports_op(supported, OpKind::TransformRegion)
349        && supports_op(supported, OpKind::BatchElementwiseRegion);
350    let keep_native_fk = opts.native_fk_regions && backend_native_fk;
351    if opts.decompose_fusion_regions || !keep_native_fk {
352        passes.push(&DecomposeFusionRegions);
353    }
354    finish_pipeline(passes)
355}
356
357/// Return the ordered fusion passes for `target`.
358pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
359    let mut opts = opts;
360    // CPU thunks execute element-wise chains per-op. Metal's fused
361    // `ElementwiseRegion` MSL kernel mis-lowers long chains on deep
362    // transformer graphs (NaNs past ~14 blocks); keep FAB/RMSNorm fusions.
363    if !opts.keep_elementwise_regions
364        && matches!(target, FusionTarget::Cpu | FusionTarget::Metal)
365        && !opts.unfuse_elementwise_regions
366    {
367        opts.unfuse_elementwise_regions = true;
368    }
369    if opts.fusion_limits == FusionLimits::default() {
370        opts.fusion_limits = fusion_limits_for_target(target);
371    }
372    opts = opts.apply_native_fk_defaults(target);
373    fusion_passes_for_supported(supported_for_target(target), opts, target)
374}
375
376/// Per-target op claims used when a backend doesn't supply an explicit
377/// `supported_ops` slice. Must stay aligned with each backend's
378/// `*_SUPPORTED_OPS` in `rlx-runtime/src/backend.rs`.
379pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
380    use OpKind::*;
381    match target {
382        FusionTarget::Cpu => &[
383            MatMul,
384            DotGeneral,
385            ElementwiseRegion,
386            FusedSwiGLU,
387            FusedMatMulBiasAct,
388            FusedResidualLN,
389            FusedResidualRmsNorm,
390            FusedAttentionBlock,
391        ],
392        FusionTarget::Metal => &[
393            MatMul,
394            DotGeneral,
395            ElementwiseRegion,
396            TransformRegion,
397            BatchElementwiseRegion,
398            FusedSwiGLU,
399            FusedMatMulBiasAct,
400            FusedResidualLN,
401            FusedResidualRmsNorm,
402        ],
403        FusionTarget::Mlx => &[
404            MatMul,
405            DotGeneral,
406            ElementwiseRegion,
407            TransformRegion,
408            BatchElementwiseRegion,
409            FusedSwiGLU,
410            FusedMatMulBiasAct,
411            FusedResidualLN,
412            FusedResidualRmsNorm,
413        ],
414        FusionTarget::Wgpu => &[
415            MatMul,
416            ElementwiseRegion,
417            TransformRegion,
418            BatchElementwiseRegion,
419            FusedSwiGLU,
420            FusedMatMulBiasAct,
421            FusedResidualLN,
422            FusedResidualRmsNorm,
423            FusedAttentionBlock,
424            FusedTransformerLayer,
425        ],
426        FusionTarget::Cuda | FusionTarget::Rocm => &[
427            MatMul,
428            DotGeneral,
429            ElementwiseRegion,
430            TransformRegion,
431            BatchElementwiseRegion,
432            FusedMatMulBiasAct,
433            FusedResidualLN,
434            FusedResidualRmsNorm,
435        ],
436        FusionTarget::Tpu => &[
437            MatMul,
438            ElementwiseRegion,
439            TransformRegion,
440            BatchElementwiseRegion,
441            FusedMatMulBiasAct,
442            FusedResidualLN,
443        ],
444    }
445}
446
447fn finish_pipeline(mut passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
448    passes.push(&DeadCodeElimination);
449    passes
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that builds a pipeline.
    ///
    /// Pipeline builders read process-wide `RLX_*` variables. For example,
    /// `apply_native_fk_defaults` reads `RLX_NO_NATIVE_FK_REGIONS`, and
    /// `tpu_native_fk_region_pass_policy` sets that variable temporarily. The
    /// test harness runs tests on parallel threads, so any test that does not
    /// hold this lock can observe the override and fail spuriously.
    static ENV_FK_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Take [`ENV_FK_TEST_LOCK`], recovering the guard if a previous holder
    /// panicked, so one failing test does not cascade into all the others.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_FK_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets an env var on creation and unsets it on drop, so a panic between
    /// the two cannot leak the override into later tests.
    struct ScopedEnvVar(&'static str);

    impl ScopedEnvVar {
        fn set(name: &'static str, value: &'static str) -> Self {
            rlx_ir::env::set(name, value);
            Self(name)
        }
    }

    impl Drop for ScopedEnvVar {
        fn drop(&mut self) {
            rlx_ir::env::unset(self.0);
        }
    }

    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let _lock = env_lock();
        let passes = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(passes.len(), 18);
        assert_eq!(passes[2].name(), "fuse_matmul_bias_act");
        assert_eq!(passes[3].name(), "fuse_attention_block");
        assert!(
            passes.iter().any(|p| p.name() == "fuse_region_prologue"),
            "default CPU pipeline should run FKL prologue fusion"
        );
        assert_eq!(passes.last().unwrap().name(), "dead_code_elimination");
    }

    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let _lock = env_lock();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                skip_fusion: true,
                ..FusionOptions::default()
            },
        );
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "LowerControlFlow");
        assert_eq!(passes[1].name(), "lower_dot_general");
    }

    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let _lock = env_lock();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions::default(),
            FusionTarget::Metal,
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let _lock = env_lock();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cuda),
            FusionOptions::default(),
            FusionTarget::Cuda,
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let _lock = env_lock();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cpu),
            FusionOptions::for_cpu(),
            FusionTarget::Cpu,
        );
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn metal_unfuses_elementwise_regions_by_default() {
        let _lock = env_lock();
        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn metal_default_unfuse_preserves_prologue_regions() {
        let _lock = env_lock();
        let mut g = rlx_ir::Graph::new("t");
        let shape_in = rlx_ir::Shape::new(&[1, 3, 8, 8], rlx_ir::DType::F32);
        let shape_out = rlx_ir::Shape::new(&[1, 3, 16, 16], rlx_ir::DType::F32);
        let x = g.input("x", shape_in);
        let up = g.add_node(rlx_ir::Op::ResizeNearest2x, vec![x], shape_out.clone());
        let r = g.add_node(
            rlx_ir::Op::Activation(rlx_ir::op::Activation::Relu),
            vec![up],
            shape_out,
        );
        g.set_outputs(vec![r]);

        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        let out = rlx_fusion::pass::run_passes(g, &passes, false);
        assert!(out.nodes().iter().any(|n| {
            matches!(
                n.op,
                rlx_ir::Op::ElementwiseRegion {
                    prologue: rlx_ir::RegionPrologue::ResizeNearest2x,
                    ..
                }
            )
        }));
    }

    #[test]
    fn fk_passes_after_elementwise_includes_batch_fusion() {
        let _lock = env_lock();
        let opts = FusionOptions::default().apply_native_fk_defaults(FusionTarget::Tpu);
        let passes =
            fk_passes_after_elementwise_regions(supported_for_target(FusionTarget::Tpu), opts);
        let names: Vec<_> = passes.iter().map(|p| p.name()).collect();
        assert!(names.contains(&"mark_batch_slice_regions"));
        assert!(names.contains(&"fuse_batch_preprocess"));
        assert!(
            !names.contains(&"decompose_fusion_regions"),
            "TPU native FK defaults should keep batch/transform regions"
        );
    }

    #[test]
    fn tpu_native_fk_region_pass_policy() {
        let _lock = env_lock();
        let default_passes = fusion_passes(FusionTarget::Tpu, FusionOptions::default());
        assert!(
            !default_passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "default TPU pipeline keeps batch/transform regions via native_fk_defaults"
        );

        let opt_out = {
            // Unset again when this scope ends, even if the call panics.
            let _env = ScopedEnvVar::set("RLX_NO_NATIVE_FK_REGIONS", "1");
            fusion_passes(FusionTarget::Tpu, FusionOptions::default())
        };
        assert!(
            opt_out
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "RLX_NO_NATIVE_FK_REGIONS should force decompose on TPU"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_tpu() {
        let _lock = env_lock();
        let passes = fusion_passes(
            FusionTarget::Tpu,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "native_fk_regions should skip decompose on TPU when batch/transform are supported"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_metal() {
        let _lock = env_lock();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "decompose_fusion_regions"),
            "native_fk_regions should skip decompose when backend claims batch/transform ops"
        );
    }

    #[test]
    fn metal_keeps_elementwise_regions_when_requested() {
        let _lock = env_lock();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                keep_elementwise_regions: true,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions"),
            "keep_elementwise_regions should skip unfuse pass"
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_region_prologue"),
            "FKL prologue fusion should still run"
        );
    }
}