Skip to main content

rlx_compile/
fusion_pipeline.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Centralized fusion pass pipelines per backend target.
17//!
18//! [`fusion_passes_for_supported`] selects passes from a backend's
19//! [`rlx_ir::OpKind`] claim set so fusion never emits fused ops the
20//! target cannot lower. [`fusion_passes`] keeps the legacy
21//! [`FusionTarget`] entry point and delegates to the same selector.
22
23use rlx_ir::OpKind;
24
25use crate::DeadCodeElimination;
26use rlx_fusion::control_flow::LowerControlFlow;
27use rlx_fusion::fusion::{
28    FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
29    FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, MarkElementwiseRegions,
30    UnfuseElementwiseRegions,
31};
32use rlx_fusion::limits::FusionLimits;
33use rlx_fusion::lower_dot_general::LowerDotGeneral;
34use rlx_fusion::pass::Pass;
35
/// Backend family a compile is aimed at.
///
/// The target picks the default op claim set ([`supported_for_target`]) and the
/// element-wise region caps ([`fusion_limits_for_target`]) used by the pipeline.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum FusionTarget {
    /// CPU backend.
    Cpu,
    /// Metal backend.
    Metal,
    /// MLX backend.
    Mlx,
    /// Wgpu backend.
    Wgpu,
    /// CUDA backend.
    Cuda,
    /// ROCm backend.
    Rocm,
    /// TPU backend.
    Tpu,
}
47
48/// Per-target fusion toggles (env-driven on Metal today).
49#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
50pub struct FusionOptions {
51    /// Skip all pattern fusions (Metal: `RLX_METAL_NO_FUSION`).
52    pub skip_fusion: bool,
53    /// Break `ElementwiseRegion` back into primitives after marking.
54    pub unfuse_elementwise_regions: bool,
55    /// Caps for fused elementwise chains (encoder / scratch limits).
56    pub fusion_limits: FusionLimits,
57}
58
59impl FusionOptions {
60    /// Read Metal-specific env overrides.
61    pub fn from_metal_env() -> Self {
62        Self {
63            skip_fusion: rlx_ir::env::flag("RLX_METAL_NO_FUSION"),
64            unfuse_elementwise_regions: rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS"),
65            ..Self::default()
66        }
67    }
68
69    /// CPU executes element-wise chains as per-op thunks — mark then unfuse.
70    pub fn for_cpu() -> Self {
71        Self {
72            unfuse_elementwise_regions: true,
73            fusion_limits: FusionLimits::UNBOUNDED,
74            ..Self::default()
75        }
76    }
77}
78
79/// Elementwise-region caps for `target` (matches GPU kernel encoders).
80pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
81    match target {
82        FusionTarget::Cpu => FusionLimits::UNBOUNDED,
83        FusionTarget::Tpu => FusionLimits {
84            max_elementwise_steps: 32,
85            max_elementwise_inputs: 16,
86        },
87        _ => FusionLimits::GPU_NATIVE,
88    }
89}
90
91/// True when `supported` is empty (no claim) or contains `kind`.
92#[inline]
93pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
94    supported.is_empty() || supported.contains(&kind)
95}
96
97/// Return the ordered fusion passes allowed for `supported`.
98///
99/// When `supported` is empty every fusion pass runs (legacy "accept
100/// all" backends). When non-empty, each pattern fusion pass is
101/// included only if the backend claims the fused [`OpKind`] it
102/// emits. Lowering passes (`LowerControlFlow`, `LowerDotGeneral`) and
103/// `FuseRmsNormReshape` (topology-only) always run unless
104/// `skip_fusion` is set.
105pub fn fusion_passes_for_supported(
106    supported: &[OpKind],
107    opts: FusionOptions,
108) -> Vec<&'static dyn Pass> {
109    if opts.skip_fusion {
110        return vec![&LowerControlFlow, &LowerDotGeneral];
111    }
112
113    let mut passes: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
114
115    if supports_op(supported, OpKind::FusedAttentionBlock) {
116        passes.push(&FuseAttentionBlock);
117    }
118    if supports_op(supported, OpKind::FusedMatMulBiasAct) {
119        passes.push(&FuseMatMulBiasAct);
120    }
121    if supports_op(supported, OpKind::FusedResidualLN) {
122        passes.push(&FuseResidualLN);
123    }
124    if supports_op(supported, OpKind::FusedResidualRmsNorm) {
125        passes.push(&FuseResidualRmsNorm);
126    }
127    passes.push(&FuseRmsNormReshape);
128
129    if supports_op(supported, OpKind::FusedSwiGLU) {
130        passes.push(&FuseSwiGLUDualMatmul);
131    }
132    if supports_op(supported, OpKind::MatMul) {
133        passes.push(&FuseSharedInputMatMul);
134    }
135    if supports_op(supported, OpKind::FusedSwiGLU) {
136        passes.push(&FuseSwiGLU);
137    }
138
139    // Mark eligible element-wise chains. Backends that don't lower
140    // ElementwiseRegion natively unfuse immediately afterward.
141    passes.push(&MarkElementwiseRegions);
142    let keep_regions =
143        supports_op(supported, OpKind::ElementwiseRegion) && !opts.unfuse_elementwise_regions;
144    if !keep_regions {
145        passes.push(&UnfuseElementwiseRegions);
146    }
147
148    finish_pipeline(passes)
149}
150
151/// Return the ordered fusion passes for `target`.
152pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
153    let mut opts = opts;
154    if matches!(target, FusionTarget::Cpu) && !opts.unfuse_elementwise_regions {
155        opts.unfuse_elementwise_regions = true;
156    }
157    if opts.fusion_limits == FusionLimits::default() {
158        opts.fusion_limits = fusion_limits_for_target(target);
159    }
160    fusion_passes_for_supported(supported_for_target(target), opts)
161}
162
163/// Per-target op claims used when a backend doesn't supply an explicit
164/// `supported_ops` slice. Must stay aligned with each backend's
165/// `*_SUPPORTED_OPS` in `rlx-runtime/src/backend.rs`.
166pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
167    use OpKind::*;
168    match target {
169        FusionTarget::Cpu => &[
170            MatMul,
171            DotGeneral,
172            ElementwiseRegion,
173            FusedSwiGLU,
174            FusedMatMulBiasAct,
175            FusedResidualLN,
176            FusedResidualRmsNorm,
177            FusedAttentionBlock,
178        ],
179        FusionTarget::Metal => &[
180            MatMul,
181            DotGeneral,
182            ElementwiseRegion,
183            FusedSwiGLU,
184            FusedMatMulBiasAct,
185            FusedResidualLN,
186            FusedResidualRmsNorm,
187        ],
188        FusionTarget::Mlx => &[
189            MatMul,
190            DotGeneral,
191            ElementwiseRegion,
192            FusedSwiGLU,
193            FusedMatMulBiasAct,
194            FusedResidualLN,
195            FusedResidualRmsNorm,
196        ],
197        FusionTarget::Wgpu => &[
198            MatMul,
199            ElementwiseRegion,
200            FusedSwiGLU,
201            FusedMatMulBiasAct,
202            FusedResidualLN,
203            FusedResidualRmsNorm,
204            FusedAttentionBlock,
205            FusedTransformerLayer,
206        ],
207        FusionTarget::Cuda | FusionTarget::Rocm => &[
208            MatMul,
209            DotGeneral,
210            ElementwiseRegion,
211            FusedMatMulBiasAct,
212            FusedResidualLN,
213            FusedResidualRmsNorm,
214        ],
215        FusionTarget::Tpu => &[
216            MatMul,
217            ElementwiseRegion,
218            FusedMatMulBiasAct,
219            FusedResidualLN,
220        ],
221    }
222}
223
224fn finish_pipeline(mut passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
225    passes.push(&DeadCodeElimination);
226    passes
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let passes = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(passes.len(), 13);
        assert_eq!(passes[2].name(), "fuse_attention_block");
        assert_eq!(passes.last().unwrap().name(), "dead_code_elimination");
    }

    /// `skip_fusion` keeps only the two lowering passes (control flow and
    /// dot-general), with no trailing dead-code elimination.
    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                skip_fusion: true,
                ..FusionOptions::default()
            },
        );
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "LowerControlFlow");
        assert_eq!(passes[1].name(), "lower_dot_general");
    }

    #[test]
    fn skip_fusion_ignores_claims() {
        let passes = fusion_passes_for_supported(
            &[],
            FusionOptions {
                skip_fusion: true,
                ..FusionOptions::default()
            },
        );
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "LowerControlFlow");
        assert_eq!(passes[1].name(), "lower_dot_general");
    }

    #[test]
    fn supports_op_treats_empty_claims_as_accept_all() {
        assert!(supports_op(&[], OpKind::FusedAttentionBlock));
        assert!(supports_op(&[OpKind::MatMul], OpKind::MatMul));
        assert!(!supports_op(&[OpKind::MatMul], OpKind::FusedSwiGLU));
    }

    #[test]
    fn empty_claims_run_every_fusion_and_keep_regions() {
        let passes = fusion_passes_for_supported(&[], FusionOptions::default());
        // 2 lowerings + 4 early fusions + rms reshape + 3 swiglu/matmul
        // fusions + region marking + DCE.
        assert_eq!(passes.len(), 12);
        assert!(passes.iter().any(|p| p.name() == "fuse_attention_block"));
        assert!(passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"));
        assert!(passes.iter().any(|p| p.name() == "fuse_swiglu"));
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions"),
            "accept-all backends keep ElementwiseRegion unless asked to unfuse"
        );
        assert_eq!(passes.last().unwrap().name(), "dead_code_elimination");
    }

    #[test]
    fn wgpu_still_lowers_dot_general_without_claiming_it() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Wgpu),
            FusionOptions::default(),
        );
        assert!(passes.iter().any(|p| p.name() == "lower_dot_general"));
        assert!(passes.iter().any(|p| p.name() == "fuse_attention_block"));
    }

    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions::default(),
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cuda),
            FusionOptions::default(),
        );
        assert!(
            passes.iter().any(|p| p.name() == "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !passes.iter().any(|p| p.name() == "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cpu),
            FusionOptions::for_cpu(),
        );
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn unfuse_option_overrides_region_claim() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions {
                unfuse_elementwise_regions: true,
                ..FusionOptions::default()
            },
        );
        assert!(
            passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }

    #[test]
    fn metal_keeps_elementwise_regions_by_default() {
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions::default(),
        );
        assert!(
            !passes
                .iter()
                .any(|p| p.name() == "unfuse_elementwise_regions")
        );
    }
}