use rlx_ir::OpKind;
use crate::DeadCodeElimination;
use rlx_fusion::control_flow::LowerControlFlow;
use rlx_fusion::fk_fusion::{
DecomposeFusionRegions, FuseBatchPreprocess, FuseRegionPrologue, MarkBatchSliceRegions,
MarkTransformRegions,
};
use rlx_fusion::fusion::{
FuseAttentionBlock, FuseMatMulBiasAct, FuseResidualLN, FuseResidualRmsNorm, FuseRmsNormReshape,
FuseSharedInputMatMul, FuseSwiGLU, FuseSwiGLUDualMatmul, FuseTransformerLayer,
MarkElementwiseRegions, UnfuseElementwiseRegions,
};
use rlx_fusion::limits::FusionLimits;
use rlx_fusion::lower_dot_general::LowerDotGeneral;
use rlx_fusion::pass::Pass;
/// Backend family a fusion pipeline is built for.
///
/// The target selects:
/// - the op list claimed by [`supported_for_target`];
/// - the elementwise limits from [`fusion_limits_for_target`];
/// - the default for native FK regions in [`FusionOptions::apply_native_fk_defaults`];
/// - which `UnfuseElementwiseRegions` variant (CPU or GPU) is used when
///   elementwise regions are unfused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionTarget {
    /// CPU backend. This is the only target that does not default to native FK
    /// regions. It uses unbounded fusion limits and the CPU flavour of
    /// `UnfuseElementwiseRegions`.
    Cpu,
    Metal,
    /// Claims the same op list as `Metal`.
    Mlx,
    Wgpu,
    Cuda,
    /// Claims the same op list as `Cuda`.
    Rocm,
    /// Uses its own elementwise fusion limits (32 steps, 16 inputs).
    Tpu,
}
/// Knobs controlling which passes [`fusion_passes_for_supported`] and
/// [`fk_passes_after_elementwise_regions`] put into a pipeline.
///
/// Start from [`FusionOptions::default`] or one of the presets:
/// [`FusionOptions::for_cpu`], [`FusionOptions::for_metal`],
/// [`FusionOptions::for_wgpu`] or [`FusionOptions::from_metal_env`].
/// Optionally layer `RLX_*` environment overrides on top with
/// [`FusionOptions::merge_env`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusionOptions {
    /// Only lower control flow and `DotGeneral`; add no fusion passes
    /// (env: `RLX_METAL_NO_FUSION`).
    pub skip_fusion: bool,
    /// Run `UnfuseElementwiseRegions` after region marking
    /// (env: `RLX_METAL_UNFUSE_REGIONS`).
    pub unfuse_elementwise_regions: bool,
    /// Request that elementwise regions be kept. This stops [`fusion_passes`]
    /// from switching on `unfuse_elementwise_regions` for CPU/Metal.
    /// [`FusionOptions::for_wgpu`] derives its unfuse setting from this flag
    /// (env: `RLX_KEEP_ELEMENTWISE_REGIONS`).
    pub keep_elementwise_regions: bool,
    /// Always run `DecomposeFusionRegions`, even when native FK regions would
    /// otherwise be kept (env: `RLX_DECOMPOSE_FUSION_REGIONS`).
    pub decompose_fusion_regions: bool,
    /// Run FK fusion: batch-slice and transform region marking, plus the two
    /// optional sub-steps below (disable with env `RLX_NO_FK_FUSION`).
    pub fk_fusion: bool,
    /// Run `FuseRegionPrologue`. Only honoured when `fk_fusion` is on
    /// (env: `RLX_FUSE_REGION_PROLOGUE`; defaults to on when unset).
    pub fuse_region_prologue: bool,
    /// Run `FuseBatchPreprocess`. Only honoured when `fk_fusion` is on
    /// (env: `RLX_FUSE_BATCH_PREPROCESS`; defaults to on when unset).
    pub fuse_batch_preprocess: bool,
    /// Skip `DecomposeFusionRegions` (keeping batch/transform regions), provided
    /// the backend claims both `TransformRegion` and `BatchElementwiseRegion`.
    /// Env: `RLX_NATIVE_FK_REGIONS`; opt out with `RLX_NO_NATIVE_FK_REGIONS`.
    pub native_fk_regions: bool,
    /// Elementwise-region size limits. [`fusion_passes`] replaces a value equal
    /// to `FusionLimits::default()` with [`fusion_limits_for_target`].
    /// NOTE(review): the pass-selection functions in this module never read
    /// this field. Confirm where it is consumed.
    pub fusion_limits: FusionLimits,
}
impl Default for FusionOptions {
fn default() -> Self {
Self {
skip_fusion: false,
unfuse_elementwise_regions: false,
keep_elementwise_regions: false,
decompose_fusion_regions: false,
fk_fusion: true,
fuse_region_prologue: true,
fuse_batch_preprocess: true,
native_fk_regions: false,
fusion_limits: FusionLimits::default(),
}
}
}
impl FusionOptions {
pub fn from_metal_env() -> Self {
Self {
skip_fusion: rlx_ir::env::flag("RLX_METAL_NO_FUSION"),
unfuse_elementwise_regions: rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS"),
keep_elementwise_regions: rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS"),
decompose_fusion_regions: rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS"),
fk_fusion: !rlx_ir::env::flag("RLX_NO_FK_FUSION"),
fuse_region_prologue: if rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
true
} else {
rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE")
},
fuse_batch_preprocess: if rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
true
} else {
rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS")
},
native_fk_regions: rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS"),
..Self::default()
}
}
pub fn merge_env(mut self) -> Self {
if rlx_ir::env::flag("RLX_METAL_NO_FUSION") {
self.skip_fusion = true;
}
if rlx_ir::env::flag("RLX_METAL_UNFUSE_REGIONS") {
self.unfuse_elementwise_regions = true;
}
if rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS") {
self.keep_elementwise_regions = true;
}
if rlx_ir::env::flag("RLX_DECOMPOSE_FUSION_REGIONS") {
self.decompose_fusion_regions = true;
}
if rlx_ir::env::flag("RLX_NO_FK_FUSION") {
self.fk_fusion = false;
}
if !rlx_ir::env::is_unset("RLX_FUSE_REGION_PROLOGUE") {
self.fuse_region_prologue = rlx_ir::env::flag("RLX_FUSE_REGION_PROLOGUE");
}
if !rlx_ir::env::is_unset("RLX_FUSE_BATCH_PREPROCESS") {
self.fuse_batch_preprocess = rlx_ir::env::flag("RLX_FUSE_BATCH_PREPROCESS");
}
if rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
self.native_fk_regions = true;
}
if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
self.native_fk_regions = false;
}
self
}
pub fn apply_native_fk_defaults(mut self, target: FusionTarget) -> Self {
if rlx_ir::env::flag("RLX_NO_NATIVE_FK_REGIONS") {
self.native_fk_regions = false;
return self;
}
if self.native_fk_regions || rlx_ir::env::flag("RLX_NATIVE_FK_REGIONS") {
self.native_fk_regions = true;
return self;
}
if matches!(
target,
FusionTarget::Metal
| FusionTarget::Cuda
| FusionTarget::Rocm
| FusionTarget::Wgpu
| FusionTarget::Mlx
| FusionTarget::Tpu
) {
self.native_fk_regions = true;
}
self
}
pub fn for_cpu() -> Self {
Self {
unfuse_elementwise_regions: true,
fusion_limits: FusionLimits::UNBOUNDED,
..Self::default()
}
}
pub fn for_metal() -> Self {
let mut opts = Self::from_metal_env();
opts.unfuse_elementwise_regions = true;
opts
}
pub fn for_wgpu() -> Self {
let keep = rlx_ir::env::flag("RLX_KEEP_ELEMENTWISE_REGIONS");
Self {
unfuse_elementwise_regions: !keep,
keep_elementwise_regions: keep,
..Self::default()
}
}
}
/// Per-target elementwise-region fusion limits.
///
/// - CPU is unbounded.
/// - TPU caps regions at 32 elementwise steps and 16 inputs.
/// - Every other target uses `FusionLimits::GPU_NATIVE`.
pub fn fusion_limits_for_target(target: FusionTarget) -> FusionLimits {
    const TPU_LIMITS: FusionLimits = FusionLimits {
        max_elementwise_steps: 32,
        max_elementwise_inputs: 16,
    };
    match target {
        FusionTarget::Cpu => FusionLimits::UNBOUNDED,
        FusionTarget::Tpu => TPU_LIMITS,
        FusionTarget::Metal
        | FusionTarget::Mlx
        | FusionTarget::Wgpu
        | FusionTarget::Cuda
        | FusionTarget::Rocm => FusionLimits::GPU_NATIVE,
    }
}
/// Returns whether a backend advertising `supported` can execute `kind`.
///
/// An empty list means "no restrictions": every op counts as supported.
#[inline]
pub fn supports_op(supported: &[OpKind], kind: OpKind) -> bool {
    match supported {
        [] => true,
        claimed => claimed.contains(&kind),
    }
}
/// Builds the fusion pass pipeline for a backend that claims the ops in `supported`.
///
/// An empty `supported` slice means every op is supported (see [`supports_op`]).
/// Native FK-region defaults for `target` are applied to `opts` first.
///
/// Pipeline shape:
/// 1. `LowerControlFlow` and `LowerDotGeneral` always run. With
///    `opts.skip_fusion` the pipeline stops here; no dead-code elimination is
///    appended.
/// 2. Pattern fusions run, each only if the backend claims the fused op it
///    produces. `FuseTransformerLayer` is additionally opt-in via
///    `RLX_ENABLE_FUSE_TRANSFORMER_LAYER`.
/// 3. `MarkElementwiseRegions` runs, then the FK fusion passes when
///    `opts.fk_fusion` is set.
/// 4. `DecomposeFusionRegions` runs unless native FK regions are kept.
/// 5. `UnfuseElementwiseRegions` (CPU or GPU flavour) runs unless elementwise
///    regions are kept. `opts.keep_elementwise_regions` overrides
///    `opts.unfuse_elementwise_regions`.
/// 6. `DeadCodeElimination` runs last.
pub fn fusion_passes_for_supported(
    supported: &[OpKind],
    opts: FusionOptions,
    target: FusionTarget,
) -> Vec<&'static dyn Pass> {
    let opts = opts.apply_native_fk_defaults(target);
    let supports = |kind: OpKind| supports_op(supported, kind);

    let mut passes: Vec<&'static dyn Pass> = vec![&LowerControlFlow, &LowerDotGeneral];
    if opts.skip_fusion {
        return passes;
    }

    if supports(OpKind::FusedMatMulBiasAct) {
        passes.push(&FuseMatMulBiasAct);
    }
    if supports(OpKind::FusedAttentionBlock) {
        passes.push(&FuseAttentionBlock);
    }
    if supports(OpKind::FusedResidualLN) {
        passes.push(&FuseResidualLN);
    }
    if supports(OpKind::FusedResidualRmsNorm) {
        passes.push(&FuseResidualRmsNorm);
    }
    passes.push(&FuseRmsNormReshape);
    // Transformer-layer fusion is opt-in and also needs attention-block support.
    if rlx_ir::env::flag("RLX_ENABLE_FUSE_TRANSFORMER_LAYER")
        && supports(OpKind::FusedTransformerLayer)
        && supports(OpKind::FusedAttentionBlock)
    {
        passes.push(&FuseTransformerLayer);
    }
    // Keep this order: the dual-matmul SwiGLU pass runs before shared-input
    // matmul merging, and plain SwiGLU runs after it.
    // NOTE(review): presumably merging shared-input matmuls first would hide
    // the dual-matmul pattern. Confirm.
    if supports(OpKind::FusedSwiGLU) {
        passes.push(&FuseSwiGLUDualMatmul);
    }
    if supports(OpKind::MatMul) {
        passes.push(&FuseSharedInputMatMul);
    }
    if supports(OpKind::FusedSwiGLU) {
        passes.push(&FuseSwiGLU);
    }

    passes.push(&MarkElementwiseRegions);
    if opts.fk_fusion {
        passes.push(&MarkBatchSliceRegions);
        passes.push(&MarkTransformRegions);
        if opts.fuse_region_prologue {
            passes.push(&FuseRegionPrologue);
        }
        if opts.fuse_batch_preprocess {
            passes.push(&FuseBatchPreprocess);
        }
    }

    // Batch/transform regions survive only when requested AND the backend can
    // execute them natively; otherwise they are decomposed.
    let backend_native_fk =
        supports(OpKind::TransformRegion) && supports(OpKind::BatchElementwiseRegion);
    let keep_native_fk = opts.native_fk_regions && backend_native_fk;
    if opts.decompose_fusion_regions || !keep_native_fk {
        passes.push(&DecomposeFusionRegions);
    }

    // `keep_elementwise_regions` wins over `unfuse_elementwise_regions`.
    // Presets such as `FusionOptions::for_metal` force unfuse on, and
    // `merge_env` sets keep without clearing unfuse. Checking only unfuse would
    // silently ignore RLX_KEEP_ELEMENTWISE_REGIONS for them. Regions are still
    // only kept if the backend claims `ElementwiseRegion`.
    let wants_regions = opts.keep_elementwise_regions || !opts.unfuse_elementwise_regions;
    let keep_regions = supports(OpKind::ElementwiseRegion) && wants_regions;
    if !keep_regions {
        let unfuse = if matches!(target, FusionTarget::Cpu) {
            &UnfuseElementwiseRegions::FOR_CPU
        } else {
            &UnfuseElementwiseRegions::FOR_GPU
        };
        passes.push(unfuse);
    }
    finish_pipeline(passes)
}
/// Builds the FK-fusion tail of a pipeline whose elementwise regions are
/// already marked.
///
/// When `opts.fk_fusion` is set, this first adds batch-slice and transform
/// region marking. It then adds the optional prologue and batch-preprocess
/// fusion steps.
///
/// Independently of `fk_fusion`, `DecomposeFusionRegions` is appended unless
/// both of these hold:
/// - `opts.native_fk_regions` is set;
/// - the backend claims `TransformRegion` and `BatchElementwiseRegion`.
///
/// `opts.decompose_fusion_regions` forces it regardless. The pipeline ends with
/// dead-code elimination.
///
/// This function does not call `apply_native_fk_defaults`. Pass options that
/// already have the target defaults applied if those are wanted.
pub fn fk_passes_after_elementwise_regions(
    supported: &[OpKind],
    opts: FusionOptions,
) -> Vec<&'static dyn Pass> {
    let mut pipeline: Vec<&'static dyn Pass> = Vec::new();

    if opts.fk_fusion {
        pipeline.push(&MarkBatchSliceRegions);
        pipeline.push(&MarkTransformRegions);
        if opts.fuse_region_prologue {
            pipeline.push(&FuseRegionPrologue);
        }
        if opts.fuse_batch_preprocess {
            pipeline.push(&FuseBatchPreprocess);
        }
    }

    let native_fk_capable = supports_op(supported, OpKind::TransformRegion)
        && supports_op(supported, OpKind::BatchElementwiseRegion);
    let regions_stay_native = opts.native_fk_regions && native_fk_capable;
    if opts.decompose_fusion_regions || !regions_stay_native {
        pipeline.push(&DecomposeFusionRegions);
    }

    finish_pipeline(pipeline)
}
pub fn fusion_passes(target: FusionTarget, opts: FusionOptions) -> Vec<&'static dyn Pass> {
let mut opts = opts;
if !opts.keep_elementwise_regions
&& matches!(target, FusionTarget::Cpu | FusionTarget::Metal)
&& !opts.unfuse_elementwise_regions
{
opts.unfuse_elementwise_regions = true;
}
if opts.fusion_limits == FusionLimits::default() {
opts.fusion_limits = fusion_limits_for_target(target);
}
opts = opts.apply_native_fk_defaults(target);
fusion_passes_for_supported(supported_for_target(target), opts, target)
}
/// Ops each backend family claims it can execute natively.
///
/// Fusion passes that produce an op missing from this list are left out of the
/// pipeline (see [`fusion_passes_for_supported`]).
pub fn supported_for_target(target: FusionTarget) -> &'static [OpKind] {
    use OpKind::*;
    match target {
        FusionTarget::Cpu => &[
            MatMul,
            DotGeneral,
            ElementwiseRegion,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
        ],
        // Metal and MLX advertise identical op sets.
        FusionTarget::Metal | FusionTarget::Mlx => &[
            MatMul,
            DotGeneral,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
        ],
        FusionTarget::Wgpu => &[
            MatMul,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            FusedTransformerLayer,
        ],
        FusionTarget::Cuda | FusionTarget::Rocm => &[
            MatMul,
            DotGeneral,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
        ],
        FusionTarget::Tpu => &[
            MatMul,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            FusedMatMulBiasAct,
            FusedResidualLN,
        ],
    }
}
/// Appends `DeadCodeElimination` as the final pass and returns the pipeline.
fn finish_pipeline(passes: Vec<&'static dyn Pass>) -> Vec<&'static dyn Pass> {
    let mut pipeline = passes;
    pipeline.push(&DeadCodeElimination);
    pipeline
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes pipeline-building tests against tests that mutate the process
    /// environment.
    ///
    /// Pipeline construction reads env flags (for example
    /// `RLX_NO_NATIVE_FK_REGIONS` via `apply_native_fk_defaults`). The test
    /// harness runs tests on parallel threads, so every test that builds a
    /// pipeline must hold this lock, not just the one that writes the env.
    static ENV_FK_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_FK_TEST_LOCK`], recovering from poisoning.
    ///
    /// Recovery matters because a single failing test would otherwise cascade
    /// into spurious failures everywhere.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_FK_TEST_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Unsets an environment variable when dropped.
    ///
    /// This ensures a panic after `set` cannot leak the variable into later tests.
    struct UnsetOnDrop(&'static str);

    impl Drop for UnsetOnDrop {
        fn drop(&mut self) {
            rlx_ir::env::unset(self.0);
        }
    }

    /// True when a pass named `name` appears anywhere in `passes`.
    fn has_pass(passes: &[&'static dyn Pass], name: &str) -> bool {
        passes.iter().any(|p| p.name() == name)
    }

    #[test]
    fn cpu_pipeline_includes_attention_block() {
        let _env = lock_env();
        let passes = fusion_passes(FusionTarget::Cpu, FusionOptions::default());
        assert_eq!(passes.len(), 18);
        assert_eq!(passes[2].name(), "fuse_matmul_bias_act");
        assert_eq!(passes[3].name(), "fuse_attention_block");
        assert!(
            has_pass(&passes, "fuse_region_prologue"),
            "default CPU pipeline should run FKL prologue fusion"
        );
        assert_eq!(passes.last().unwrap().name(), "dead_code_elimination");
    }

    #[test]
    fn metal_skip_fusion_only_lowers_dot() {
        let _env = lock_env();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                skip_fusion: true,
                ..FusionOptions::default()
            },
        );
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "LowerControlFlow");
        assert_eq!(passes[1].name(), "lower_dot_general");
    }

    #[test]
    fn metal_supported_ops_omit_attention_block_fusion() {
        let _env = lock_env();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Metal),
            FusionOptions::default(),
            FusionTarget::Metal,
        );
        assert!(
            !has_pass(&passes, "fuse_attention_block"),
            "Metal should not run FuseAttentionBlock"
        );
        assert!(
            has_pass(&passes, "fuse_matmul_bias_act"),
            "Metal should fuse matmul+bias+act"
        );
    }

    #[test]
    fn cuda_supported_ops_fuse_matmul_bias_act() {
        let _env = lock_env();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cuda),
            FusionOptions::default(),
            FusionTarget::Cuda,
        );
        assert!(
            has_pass(&passes, "fuse_matmul_bias_act"),
            "CUDA should fuse matmul+bias+act when claimed"
        );
        assert!(
            !has_pass(&passes, "fuse_swiglu"),
            "CUDA should not fuse SwiGLU"
        );
    }

    #[test]
    fn cpu_unfuses_elementwise_regions() {
        let _env = lock_env();
        let passes = fusion_passes_for_supported(
            supported_for_target(FusionTarget::Cpu),
            FusionOptions::for_cpu(),
            FusionTarget::Cpu,
        );
        assert!(has_pass(&passes, "unfuse_elementwise_regions"));
    }

    #[test]
    fn metal_unfuses_elementwise_regions_by_default() {
        let _env = lock_env();
        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        assert!(has_pass(&passes, "unfuse_elementwise_regions"));
    }

    #[test]
    fn metal_default_unfuse_preserves_prologue_regions() {
        let _env = lock_env();
        let mut g = rlx_ir::Graph::new("t");
        let shape_in = rlx_ir::Shape::new(&[1, 3, 8, 8], rlx_ir::DType::F32);
        let shape_out = rlx_ir::Shape::new(&[1, 3, 16, 16], rlx_ir::DType::F32);
        let x = g.input("x", shape_in);
        let up = g.add_node(rlx_ir::Op::ResizeNearest2x, vec![x], shape_out.clone());
        let r = g.add_node(
            rlx_ir::Op::Activation(rlx_ir::op::Activation::Relu),
            vec![up],
            shape_out,
        );
        g.set_outputs(vec![r]);
        let passes = fusion_passes(FusionTarget::Metal, FusionOptions::default());
        let out = rlx_fusion::pass::run_passes(g, &passes, false);
        assert!(out.nodes().iter().any(|n| {
            matches!(
                n.op,
                rlx_ir::Op::ElementwiseRegion {
                    prologue: rlx_ir::RegionPrologue::ResizeNearest2x,
                    ..
                }
            )
        }));
    }

    #[test]
    fn fk_passes_after_elementwise_includes_batch_fusion() {
        let _env = lock_env();
        let opts = FusionOptions::default().apply_native_fk_defaults(FusionTarget::Tpu);
        let passes =
            fk_passes_after_elementwise_regions(supported_for_target(FusionTarget::Tpu), opts);
        let names: Vec<_> = passes.iter().map(|p| p.name()).collect();
        assert!(names.contains(&"mark_batch_slice_regions"));
        assert!(names.contains(&"fuse_batch_preprocess"));
        assert!(
            !names.contains(&"decompose_fusion_regions"),
            "TPU native FK defaults should keep batch/transform regions"
        );
    }

    #[test]
    fn tpu_native_fk_region_pass_policy() {
        let _env = lock_env();
        let default_passes = fusion_passes(FusionTarget::Tpu, FusionOptions::default());
        assert!(
            !has_pass(&default_passes, "decompose_fusion_regions"),
            "default TPU pipeline keeps batch/transform regions via native_fk_defaults"
        );
        let opt_out = {
            rlx_ir::env::set("RLX_NO_NATIVE_FK_REGIONS", "1");
            // Restores the environment even if pipeline construction panics.
            let _unset = UnsetOnDrop("RLX_NO_NATIVE_FK_REGIONS");
            fusion_passes(FusionTarget::Tpu, FusionOptions::default())
        };
        assert!(
            has_pass(&opt_out, "decompose_fusion_regions"),
            "RLX_NO_NATIVE_FK_REGIONS should force decompose on TPU"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_tpu() {
        let _env = lock_env();
        let passes = fusion_passes(
            FusionTarget::Tpu,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "decompose_fusion_regions"),
            "native_fk_regions should skip decompose on TPU when batch/transform are supported"
        );
    }

    #[test]
    fn native_fk_regions_skips_decompose_on_metal() {
        let _env = lock_env();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                native_fk_regions: true,
                decompose_fusion_regions: false,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "decompose_fusion_regions"),
            "native_fk_regions should skip decompose when backend claims batch/transform ops"
        );
    }

    #[test]
    fn metal_keeps_elementwise_regions_when_requested() {
        let _env = lock_env();
        let passes = fusion_passes(
            FusionTarget::Metal,
            FusionOptions {
                keep_elementwise_regions: true,
                unfuse_elementwise_regions: false,
                ..FusionOptions::default()
            },
        );
        assert!(
            !has_pass(&passes, "unfuse_elementwise_regions"),
            "keep_elementwise_regions should skip unfuse pass"
        );
        assert!(
            has_pass(&passes, "fuse_region_prologue"),
            "FKL prologue fusion should still run"
        );
    }
}