1use rlx_ir::OpKind;
24
25use crate::DeadCodeElimination;
26use rlx_fusion::control_flow::LowerControlFlow;
27use rlx_fusion::fusion::{
28 FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
29 FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, MarkElementwiseRegions,
30 UnfuseElementwiseRegions,
31};
32use rlx_fusion::limits::FusionLimits;
33use rlx_fusion::lower_dot_general::LowerDotGeneral;
34use rlx_fusion::pass::Pass;
35
/// Backend selector used to choose a fusion pipeline.
///
/// [`supported_for_target`] maps each target to the op kinds it claims, and
/// [`fusion_limits_for_target`] maps it to elementwise-fusion limits. `Cuda` and `Rocm`
/// share the same op set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionTarget {
    /// [`fusion_passes`] always forces `unfuse_elementwise_regions` for this target, and it
    /// gets `FusionLimits::UNBOUNDED` by default.
    Cpu,
    Metal,
    Mlx,
    Wgpu,
    Cuda,
    Rocm,
    /// Gets custom default limits (`max_elementwise_steps = 32`, `max_elementwise_inputs = 16`).
    Tpu,
}
47
/// Knobs that adjust the pass list built by [`fusion_passes`] and
/// [`fusion_passes_for_supported`].
///
/// The derived `Default` leaves fusion enabled, does not force elementwise regions to be
/// unfused, and uses `FusionLimits::default()` for the limits.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct FusionOptions {
    /// When true, only the lowering passes (`LowerControlFlow`, `LowerDotGeneral`) are
    /// emitted: no fusion passes and no trailing dead-code elimination.
    pub skip_fusion: bool,
    /// When true, `UnfuseElementwiseRegions` always runs after `MarkElementwiseRegions`,
    /// even on backends that list `ElementwiseRegion`. [`fusion_passes`] forces this on for
    /// `FusionTarget::Cpu`.
    pub unfuse_elementwise_regions: bool,
    /// Elementwise-fusion limits. [`fusion_passes`] treats a value equal to
    /// `FusionLimits::default()` as "unset" and substitutes the target's limits.
    /// NOTE(review): not read by `fusion_passes_for_supported` in this file — presumably
    /// consumed elsewhere; confirm.
    pub fusion_limits: FusionLimits,
}
58
59impl FusionOptions {
60 pub fn from_metal_env() -> Self {
62 Self {
63 skip_fusion: rlx_ir::env::flag("RLX_METAL_NO_FUSION"),
64 unfuse_elementwise_regions: rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS"),
65 ..Self::default()
66 }
67 }
68
69 pub fn for_cpu() -> Self {
71 Self {
72 unfuse_elementwise_regions: true,
73 fusion_limits: FusionLimits::UNBOUNDED,
74 ..Self::default()
75 }
76 }
77}
78
79pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
81 match target {
82 FusionTarget::Cpu => FusionLimits::UNBOUNDED,
83 FusionTarget::Tpu => FusionLimits {
84 max_elementwise_steps: 32,
85 max_elementwise_inputs: 16,
86 },
87 _ => FusionLimits::GPU_NATIVE,
88 }
89}
90
91#[inline]
93pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
94 supported.is_empty() || supported.contains(&kind)
95}
96
97pub fn fusion_passes_for_supported(
106 supported: &[OpKind],
107 opts: FusionOptions,
108) -> Vec<&'static dyn Pass> {
109 if opts.skip_fusion {
110 return vec![&LowerControlFlow, &LowerDotGeneral];
111 }
112
113 let mut passes: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
114
115 if supports_op(supported, OpKind::FusedAttentionBlock) {
116 passes.push(&FuseAttentionBlock);
117 }
118 if supports_op(supported, OpKind::FusedMatMulBiasAct) {
119 passes.push(&FuseMatMulBiasAct);
120 }
121 if supports_op(supported, OpKind::FusedResidualLN) {
122 passes.push(&FuseResidualLN);
123 }
124 if supports_op(supported, OpKind::FusedResidualRmsNorm) {
125 passes.push(&FuseResidualRmsNorm);
126 }
127 passes.push(&FuseRmsNormReshape);
128
129 if supports_op(supported, OpKind::FusedSwiGLU) {
130 passes.push(&FuseSwiGLUDualMatmul);
131 }
132 if supports_op(supported, OpKind::MatMul) {
133 passes.push(&FuseSharedInputMatMul);
134 }
135 if supports_op(supported, OpKind::FusedSwiGLU) {
136 passes.push(&FuseSwiGLU);
137 }
138
139 passes.push(&MarkElementwiseRegions);
142 let keep_regions =
143 supports_op(supported, OpKind::ElementwiseRegion) && !opts.unfuse_elementwise_regions;
144 if !keep_regions {
145 passes.push(&UnfuseElementwiseRegions);
146 }
147
148 finish_pipeline(passes)
149}
150
151pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
153 let mut opts = opts;
154 if matches!(target, FusionTarget::Cpu) && !opts.unfuse_elementwise_regions {
155 opts.unfuse_elementwise_regions = true;
156 }
157 if opts.fusion_limits == FusionLimits::default() {
158 opts.fusion_limits = fusion_limits_for_target(target);
159 }
160 fusion_passes_for_supported(supported_for_target(target), opts)
161}
162
163pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
167 use OpKind::*;
168 match target {
169 FusionTarget::Cpu => &[
170 MatMul,
171 DotGeneral,
172 ElementwiseRegion,
173 FusedSwiGLU,
174 FusedMatMulBiasAct,
175 FusedResidualLN,
176 FusedResidualRmsNorm,
177 FusedAttentionBlock,
178 ],
179 FusionTarget::Metal => &[
180 MatMul,
181 DotGeneral,
182 ElementwiseRegion,
183 FusedSwiGLU,
184 FusedMatMulBiasAct,
185 FusedResidualLN,
186 FusedResidualRmsNorm,
187 ],
188 FusionTarget::Mlx => &[
189 MatMul,
190 DotGeneral,
191 ElementwiseRegion,
192 FusedSwiGLU,
193 FusedMatMulBiasAct,
194 FusedResidualLN,
195 FusedResidualRmsNorm,
196 ],
197 FusionTarget::Wgpu => &[
198 MatMul,
199 ElementwiseRegion,
200 FusedSwiGLU,
201 FusedMatMulBiasAct,
202 FusedResidualLN,
203 FusedResidualRmsNorm,
204 FusedAttentionBlock,
205 FusedTransformerLayer,
206 ],
207 FusionTarget::Cuda | FusionTarget::Rocm => &[
208 MatMul,
209 DotGeneral,
210 ElementwiseRegion,
211 FusedMatMulBiasAct,
212 FusedResidualLN,
213 FusedResidualRmsNorm,
214 ],
215 FusionTarget::Tpu => &[
216 MatMul,
217 ElementwiseRegion,
218 FusedMatMulBiasAct,
219 FusedResidualLN,
220 ],
221 }
222}
223
224fn finish_pipeline(mut passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
225 passes.push(&DeadCodeElimination);
226 passes
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// True when a pass whose `name()` equals `name` appears anywhere in `passes`.
    fn has_pass(passes: &[&'static dyn Pass], name: &str) -> bool {
        passes.iter().any(|p| p.name() == name)
    }

    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let pipeline = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(pipeline.len(), 13);
        assert_eq!(pipeline[2].name(), "fuse_attention_block");
        let last = pipeline.last().unwrap();
        assert_eq!(last.name(), "dead_code_elimination");
    }

    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let opts = FusionOptions {
            skip_fusion: true,
            ..FusionOptions::default()
        };
        let pipeline = fusion_passes(FusionTarget::Metal, opts);
        assert_eq!(pipeline.len(), 2);
        assert_eq!(pipeline[0].name(), "LowerControlFlow");
        assert_eq!(pipeline[1].name(), "lower_dot_general");
    }

    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let metal = supported_for_target(FusionTarget::Metal);
        let pipeline = fusion_passes_for_supported(metal, FusionOptions::default());
        assert!(
            !has_pass(&pipeline, "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            has_pass(&pipeline, "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let cuda = supported_for_target(FusionTarget::Cuda);
        let pipeline = fusion_passes_for_supported(cuda, FusionOptions::default());
        assert!(
            has_pass(&pipeline, "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !has_pass(&pipeline, "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let cpu = supported_for_target(FusionTarget::Cpu);
        let pipeline = fusion_passes_for_supported(cpu, FusionOptions::for_cpu());
        assert!(has_pass(&pipeline, "unfuse_elementwise_regions"));
    }

    #[test]
    fn metal_keeps_elementwise_regions_by_default() {
        let metal = supported_for_target(FusionTarget::Metal);
        let pipeline = fusion_passes_for_supported(metal, FusionOptions::default());
        assert!(!has_pass(&pipeline, "unfuse_elementwise_regions"));
    }
}