use rlx_ir::OpKind;
use crate::DeadCodeElimination;
use rlx_fusion::control_flow::LowerControlFlow;
use rlx_fusion::fusion::{
FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, MarkElementwiseRegions,
UnfuseElementwiseRegions,
};
use rlx_fusion::limits::FusionLimits;
use rlx_fusion::lower_dot_general::LowerDotGeneral;
use rlx_fusion::pass::Pass;
/// Backend that a fusion pass pipeline is assembled for.
///
/// The variant selects the claimed-op list returned by [`supported_for_target`]
/// and the default limits returned by [`fusion_limits_for_target`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FusionTarget {
    /// CPU backend. [`fusion_passes`] always unfuses elementwise regions for it.
    Cpu,
    /// Metal backend.
    Metal,
    /// MLX backend.
    Mlx,
    /// wgpu backend.
    Wgpu,
    /// CUDA backend.
    Cuda,
    /// ROCm backend.
    Rocm,
    /// TPU backend.
    Tpu,
}
/// Knobs that adjust how the fusion pipeline is assembled.
///
/// The derived `Default` leaves both flags `false` and uses
/// `FusionLimits::default()` for the limits.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct FusionOptions {
    /// When `true`, `fusion_passes_for_supported` returns only the two lowering
    /// passes (control flow and dot-general) and nothing else.
    pub skip_fusion: bool,
    /// When `true`, `UnfuseElementwiseRegions` is scheduled after
    /// `MarkElementwiseRegions` even if the backend claims `ElementwiseRegion`.
    pub unfuse_elementwise_regions: bool,
    /// Fusion limits. `fusion_passes` swaps in the per-target value whenever
    /// this still equals `FusionLimits::default()`.
    pub fusion_limits: FusionLimits,
}
impl FusionOptions {
    /// Builds options from the Metal environment flags.
    ///
    /// `RLX_METAL_NO_FUSION` feeds `skip_fusion` and `RLX_METAL_UNFUSE_REGIONS`
    /// feeds `unfuse_elementwise_regions`; every other field keeps its default.
    pub fn from_metal_env() -> Self {
        let skip_fusion = rlx_ir::env::flag("RLX_METAL_NO_FUSION");
        let unfuse_elementwise_regions = rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS");
        Self {
            skip_fusion,
            unfuse_elementwise_regions,
            ..Self::default()
        }
    }

    /// Options for the CPU backend: elementwise regions are unfused and the
    /// fusion limits are `FusionLimits::UNBOUNDED`; other fields keep their
    /// defaults.
    pub fn for_cpu() -> Self {
        let defaults = Self::default();
        Self {
            fusion_limits: FusionLimits::UNBOUNDED,
            unfuse_elementwise_regions: true,
            ..defaults
        }
    }
}
/// Returns the default [`FusionLimits`] for `target`.
///
/// * `Cpu` gets `FusionLimits::UNBOUNDED`.
/// * `Tpu` gets explicit caps of 32 elementwise steps and 16 elementwise inputs.
/// * `Metal`, `Mlx`, `Wgpu`, `Cuda` and `Rocm` get `FusionLimits::GPU_NATIVE`.
///
/// The match is deliberately exhaustive (no `_` arm), matching
/// `supported_for_target`: adding a new `FusionTarget` variant must fail to
/// compile here until its limits are chosen, rather than silently inheriting
/// the GPU defaults.
pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
    match target {
        FusionTarget::Cpu => FusionLimits::UNBOUNDED,
        FusionTarget::Tpu => FusionLimits {
            max_elementwise_steps: 32,
            max_elementwise_inputs: 16,
        },
        FusionTarget::Metal
        | FusionTarget::Mlx
        | FusionTarget::Wgpu
        | FusionTarget::Cuda
        | FusionTarget::Rocm => FusionLimits::GPU_NATIVE,
    }
}
/// Reports whether `kind` is claimed by the `supported` list.
///
/// An empty list is treated as "everything is supported"; otherwise `kind`
/// must appear in the list.
#[inline]
pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
    match supported {
        [] => true,
        claimed => claimed.contains(&kind),
    }
}
/// Assembles the ordered pass pipeline for a backend that claims the ops in
/// `supported` (an empty slice claims everything, see `supports_op`).
///
/// The pipeline always begins with `LowerControlFlow` and `LowerDotGeneral`.
/// With `opts.skip_fusion` set, those two passes are the whole pipeline and no
/// dead-code elimination is appended. Otherwise each gated fusion pass is added
/// only when the op kind it is gated on is claimed, `FuseRmsNormReshape` and
/// `MarkElementwiseRegions` are added unconditionally, and
/// `UnfuseElementwiseRegions` follows when the backend does not claim
/// `ElementwiseRegion` or `opts.unfuse_elementwise_regions` is set. The result
/// is closed off by `finish_pipeline`.
///
/// `opts.fusion_limits` is not read by this function.
pub fn fusion_passes_for_supported(
    supported: &[OpKind],
    opts: FusionOptions,
) -> Vec<&'static dyn Pass> {
    // Lowering runs regardless of whether fusion is enabled.
    let mut pipeline: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
    if opts.skip_fusion {
        return pipeline;
    }

    let claims = |kind: OpKind| supports_op(supported, kind);
    let claims_swiglu = claims(OpKind::FusedSwiGLU);

    if claims(OpKind::FusedAttentionBlock) {
        pipeline.push(&FuseAttentionBlock);
    }
    if claims(OpKind::FusedMatMulBiasAct) {
        pipeline.push(&FuseMatMulBiasAct);
    }
    if claims(OpKind::FusedResidualLN) {
        pipeline.push(&FuseResidualLN);
    }
    if claims(OpKind::FusedResidualRmsNorm) {
        pipeline.push(&FuseResidualRmsNorm);
    }
    pipeline.push(&FuseRmsNormReshape);

    // The two SwiGLU passes share one gate but sit on either side of the
    // shared-input matmul pass; this ordering is part of the returned pipeline.
    if claims_swiglu {
        pipeline.push(&FuseSwiGLUDualMatmul);
    }
    if claims(OpKind::MatMul) {
        pipeline.push(&FuseSharedInputMatMul);
    }
    if claims_swiglu {
        pipeline.push(&FuseSwiGLU);
    }

    pipeline.push(&MarkElementwiseRegions);
    // Regions survive only when the backend claims them and the caller did not
    // ask for them to be unfused.
    if opts.unfuse_elementwise_regions || !claims(OpKind::ElementwiseRegion) {
        pipeline.push(&UnfuseElementwiseRegions);
    }

    finish_pipeline(pipeline)
}
pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
let mut opts = opts;
if matches!(target, FusionTarget::Cpu) && !opts.unfuse_elementwise_regions {
opts.unfuse_elementwise_regions = true;
}
if opts.fusion_limits == FusionLimits::default() {
opts.fusion_limits = fusion_limits_for_target(target);
}
fusion_passes_for_supported(supported_for_target(target), opts)
}
/// Returns the static list of op kinds that `target` claims.
///
/// `fusion_passes` feeds this list to `fusion_passes_for_supported`, which uses
/// it to decide which gated fusion passes are scheduled. Every `FusionTarget`
/// variant is listed explicitly (no wildcard arm).
pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
    use OpKind::*;
    match target {
        FusionTarget::Cpu => &[
            MatMul,
            DotGeneral,
            ElementwiseRegion,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
        ],
        // Metal and MLX claim the same set of ops.
        FusionTarget::Metal | FusionTarget::Mlx => &[
            MatMul,
            DotGeneral,
            ElementwiseRegion,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
        ],
        FusionTarget::Wgpu => &[
            MatMul,
            ElementwiseRegion,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            FusedTransformerLayer,
        ],
        FusionTarget::Cuda | FusionTarget::Rocm => &[
            MatMul,
            DotGeneral,
            ElementwiseRegion,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
        ],
        FusionTarget::Tpu => &[
            MatMul,
            ElementwiseRegion,
            FusedMatMulBiasAct,
            FusedResidualLN,
        ],
    }
}
/// Closes a pipeline by appending `DeadCodeElimination` as its last pass and
/// returning the extended vector.
fn finish_pipeline(passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
    let mut finished = passes;
    finished.push(&DeadCodeElimination);
    finished
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when any pass in `pipeline` reports the given name.
    fn has_pass(pipeline: &[&dyn Pass], name: &str) -> bool {
        pipeline.iter().any(|pass| pass.name() == name)
    }

    /// The CPU pipeline has 13 passes, with attention-block fusion right after
    /// the two lowering passes and dead-code elimination last.
    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let pipeline = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(pipeline.len(), 13);
        assert_eq!(pipeline[2].name(), "fuse_attention_block");
        assert_eq!(pipeline.last().unwrap().name(), "dead_code_elimination");
    }

    /// With `skip_fusion` set, only the two lowering passes remain.
    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let opts = FusionOptions {
            skip_fusion: true,
            ..FusionOptions::default()
        };
        let pipeline = fusion_passes(FusionTarget::Metal, opts);
        assert_eq!(pipeline.len(), 2);
        assert_eq!(pipeline[0].name(), "LowerControlFlow");
        assert_eq!(pipeline[1].name(), "lower_dot_general");
    }

    /// Metal does not claim the fused attention block but does claim
    /// matmul+bias+act.
    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let claimed = supported_for_target(FusionTarget::Metal);
        let pipeline = fusion_passes_for_supported(claimed, FusionOptions::default());
        assert!(
            !has_pass(&pipeline, "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            has_pass(&pipeline, "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    /// CUDA claims matmul+bias+act but not SwiGLU.
    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let claimed = supported_for_target(FusionTarget::Cuda);
        let pipeline = fusion_passes_for_supported(claimed, FusionOptions::default());
        assert!(
            has_pass(&pipeline, "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !has_pass(&pipeline, "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    /// `FusionOptions::for_cpu()` schedules the unfuse pass.
    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let claimed = supported_for_target(FusionTarget::Cpu);
        let pipeline = fusion_passes_for_supported(claimed, FusionOptions::for_cpu());
        assert!(has_pass(&pipeline, "unfuse_elementwise_regions"));
    }

    /// Default options on Metal leave elementwise regions fused.
    #[test]
    fn metal_keeps_elementwise_regions_by_default() {
        let claimed = supported_for_target(FusionTarget::Metal);
        let pipeline = fusion_passes_for_supported(claimed, FusionOptions::default());
        assert!(!has_pass(&pipeline, "unfuse_elementwise_regions"));
    }
}