Crate rlx_fusion

MIR fusion passes and fused-op decomposition.

Provides pattern-matching fusion passes (FuseMatMulBiasAct, FuseSwiGLU, …) and the inverse rewrite, unfuse_fused_for_autodiff, which decomposes fused ops back into primitives before autodiff.
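
To make the fuse/unfuse relationship concrete, here is a minimal sketch on a toy IR. Everything in it is an assumption made for illustration: the `Op` enum, `fuse_matmul_bias_relu`, and `unfuse` are hypothetical and are not the crate's MIR types or the actual signatures of FuseMatMulBiasAct or unfuse_fused_for_autodiff, which this page does not show. Nodes live in a topologically ordered list and refer to their operands by index.

```rust
// Toy IR (hypothetical; not the rlx_fusion MIR). Operands are node indices.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Op {
    Input(&'static str),
    MatMul(usize, usize),
    Add(usize, usize),
    Relu(usize),
    // Fused form of relu(matmul(a, b) + bias).
    FusedMatMulBiasRelu { a: usize, b: usize, bias: usize },
}

// Pattern-match Relu(Add(MatMul(a, b), bias)) and replace the Relu node
// with the fused op. The MatMul and Add nodes are left in place; if nothing
// else uses them, a later dead-code pass could drop them.
// Only the operand order Add(matmul, bias) is matched in this sketch.
fn fuse_matmul_bias_relu(nodes: &mut [Op]) -> usize {
    let mut fused = 0;
    for i in 0..nodes.len() {
        if let Op::Relu(x) = nodes[i] {
            if let Op::Add(m, bias) = nodes[x] {
                if let Op::MatMul(a, b) = nodes[m] {
                    nodes[i] = Op::FusedMatMulBiasRelu { a, b, bias };
                    fused += 1;
                }
            }
        }
    }
    fused
}

fn push(out: &mut Vec<Op>, op: Op) -> usize {
    out.push(op);
    out.len() - 1
}

// Inverse rewrite: expand each fused node into primitives so that a
// differentiation pass only ever sees MatMul, Add, and Relu.
// `remap` translates old node indices to indices in the new list.
fn unfuse(nodes: &[Op]) -> Vec<Op> {
    let mut out = Vec::new();
    let mut remap: Vec<usize> = Vec::with_capacity(nodes.len());
    for op in nodes {
        let new_id = match *op {
            Op::Input(_) => push(&mut out, *op),
            Op::MatMul(a, b) => push(&mut out, Op::MatMul(remap[a], remap[b])),
            Op::Add(a, b) => push(&mut out, Op::Add(remap[a], remap[b])),
            Op::Relu(x) => push(&mut out, Op::Relu(remap[x])),
            Op::FusedMatMulBiasRelu { a, b, bias } => {
                let mm = push(&mut out, Op::MatMul(remap[a], remap[b]));
                let add = push(&mut out, Op::Add(mm, remap[bias]));
                push(&mut out, Op::Relu(add))
            }
        };
        remap.push(new_id);
    }
    out
}

fn main() {
    let mut graph = vec![
        Op::Input("x"),
        Op::Input("w"),
        Op::Input("bias"),
        Op::MatMul(0, 1),
        Op::Add(3, 2),
        Op::Relu(4),
    ];
    let count = fuse_matmul_bias_relu(&mut graph);
    println!("fused {} pattern(s); output node is {:?}", count, graph[5]);
    println!("after unfuse: {:?}", unfuse(&graph));
}
```

The sketch rewrites in place when fusing (one node replaces one node) but rebuilds the list when unfusing, because one fused node expands into three and every later index has to shift.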

Re-exports

pub use control_flow::LowerControlFlow;
pub use control_flow::inline_if;
pub use control_flow::inline_subgraph_into;
pub use control_flow::inline_subgraph_into_outputs;
pub use control_flow::unroll_while;
pub use fk_fusion::DecomposeFusionRegions;
pub use fk_fusion::FuseBatchPreprocess;
pub use fk_fusion::FuseRegionPrologue;
pub use fk_fusion::MarkBatchSliceRegions;
pub use fk_fusion::MarkTransformRegions;
pub use fk_graphs::batch_narrow_relu_primitive_graph;
pub use fk_graphs::batch_narrow_relu_regions_graph;
pub use fk_graphs::nchw;
pub use fk_graphs::resize_relu_graph;
pub use fk_graphs::resize_relu_region_graph;
pub use fusion::FuseAttentionBlock;
pub use fusion::FuseMatMulBiasAct;
pub use fusion::FuseResidualLN;
pub use fusion::FuseResidualRmsNorm;
pub use fusion::FuseRmsNormReshape;
pub use fusion::FuseSharedInputMatMul;
pub use fusion::FuseSwiGLU;
pub use fusion::FuseSwiGLUDualMatmul;
pub use fusion::FuseTransformerLayer;
pub use fusion::MarkElementwiseRegions;
pub use fusion::UnfuseElementwiseRegions;
pub use fusion::clip_elementwise_regions;
pub use fusion_fragment::FusionFragment;
pub use fusion_fragment::FusionRole;
pub use fusion_fragment::fusion_fragments;
pub use fusion_fragment::is_registered_transform_op;
pub use fusion_fragment::prologue_for_transform_op;
pub use fusion_fragment::register_fusion_fragment;
pub use fusion_fragment::transform_chain_eligible;
pub use fusion_report::FusionReport;
pub use fusion_report::MissReason;
pub use fusion_report::MissedFusion;
pub use limits::FusionLimits;
pub use limits::active_fusion_limits;
pub use limits::with_fusion_limits;
pub use lower_backward_ops::LowerBackwardOps;
pub use lower_dot_general::LowerDotGeneral;
pub use lower_logical_kernels::lower_logical_kernels;
pub use lower_loss_ops::LowerSoftmaxCrossEntropy;
pub use lower_reduce_axes::LowerNonLastAxisReduce;
pub use lower_vae_ops::LowerBatchNormInference;
pub use lower_vae_ops::LowerGroupNorm;
pub use lower_vae_ops::LowerResizeNearest2x;
pub use pass::Pass;
pub use pass::run_passes;
pub use unfuse::unfuse_fused_for_autodiff;

Modules

control_flow
Control-flow lowering passes: Op::If → Where + inlined branches; Op::While → bounded unroll of body replicas.
fk_fusion
FKL-inspired transform / prologue / batch fusion passes.
fk_graphs
Shared FKL-style benchmark / test graphs.
fusion
Fusion passes — pattern-match and replace subgraphs with fused ops.
fusion_fragment
FKL-style fusion fragment registry - extensible op roles for region passes.
fusion_report
Fusion diagnostics — what fused, what missed, and why.
graph_rewrite
Shared graph rewriter for fusion passes.
limits
Per-backend caps for fused IR (elementwise region chains, etc.).
lower_backward_ops
Lower dedicated backward ops (ReluBackward, ActivationBackward) to primitives (Compare, Where, Binary, Activation) for backends that do not implement closed-form gradient kernels (e.g. Metal).
lower_dot_general
Lower Op::DotGeneral to primitive ops (MatMul + Transpose + Reshape).
lower_logical_kernels
Lower logical kernels to common IR when native backend ops are unavailable.
lower_loss_ops
Lower SoftmaxCrossEntropyWithLogits / SoftmaxCrossEntropyBackward to primitives for backends (CUDA, Metal) that lack native kernels.
lower_reduce_axes
Lower Op::Reduce on non-last axes (and multi-axis reduce) for backends that only implement reduction along the trailing dimension (e.g. wgpu).
lower_vae_ops
Lower VAE-specific ops (GroupNorm, BatchNormInference, ResizeNearest2x) to primitives.
pass
Pass infrastructure — trait + pipeline runner.
unfuse
Decompose tier-2 fused MIR ops into primitives for autodiff and backends.
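
As an illustration of how a pass trait, a pipeline runner, and a lowering pass such as the one described for lower_reduce_axes could fit together, here is a minimal sketch. All names in it (`ToyOp`, `ToyPass`, `run_toy_passes`, `LowerNonLastReduce`, `DropIdentityTranspose`) are hypothetical; the real `Pass` trait and `run_passes` signatures are not shown on this page and may differ. The lowering assumes a reduce that drops the reduced dimension, so moving the reduced axis to the end and reducing the trailing axis leaves the remaining dimensions in their original order.

```rust
// Toy IR and pass API (hypothetical; not the rlx_fusion types).
#[derive(Debug, Clone, PartialEq)]
enum ToyOp {
    // Reduce a rank-`rank` tensor along `axis`, dropping that dimension.
    Reduce { axis: usize, rank: usize },
    // Permute dimensions according to `perm`.
    Transpose { perm: Vec<usize> },
}

trait ToyPass {
    fn name(&self) -> &'static str;
    // Rewrites the op chain in place; returns true if anything changed.
    fn run(&self, ops: &mut Vec<ToyOp>) -> bool;
}

// Rewrites a reduce over a non-trailing axis as a transpose followed by a
// reduce over the last axis, for backends that only reduce the last axis.
struct LowerNonLastReduce;

impl ToyPass for LowerNonLastReduce {
    fn name(&self) -> &'static str {
        "lower_non_last_reduce"
    }

    fn run(&self, ops: &mut Vec<ToyOp>) -> bool {
        let mut changed = false;
        let mut out = Vec::with_capacity(ops.len());
        for op in ops.drain(..) {
            match op {
                ToyOp::Reduce { axis, rank } if axis + 1 < rank => {
                    // Move `axis` to the end, keeping the other dims in order.
                    let mut perm: Vec<usize> = (0..rank).filter(|&d| d != axis).collect();
                    perm.push(axis);
                    out.push(ToyOp::Transpose { perm });
                    out.push(ToyOp::Reduce { axis: rank - 1, rank });
                    changed = true;
                }
                other => out.push(other),
            }
        }
        *ops = out;
        changed
    }
}

// Removes transposes whose permutation is the identity.
struct DropIdentityTranspose;

impl ToyPass for DropIdentityTranspose {
    fn name(&self) -> &'static str {
        "drop_identity_transpose"
    }

    fn run(&self, ops: &mut Vec<ToyOp>) -> bool {
        let before = ops.len();
        ops.retain(|op| match op {
            ToyOp::Transpose { perm } => perm.iter().enumerate().any(|(i, &p)| i != p),
            _ => true,
        });
        ops.len() != before
    }
}

// Runs each pass once, in order, and reports which ones changed the chain.
fn run_toy_passes(passes: &[Box<dyn ToyPass>], ops: &mut Vec<ToyOp>) -> Vec<&'static str> {
    let mut changed = Vec::new();
    for pass in passes {
        if pass.run(ops) {
            changed.push(pass.name());
        }
    }
    changed
}

fn main() {
    let mut ops = vec![
        ToyOp::Reduce { axis: 0, rank: 3 },
        ToyOp::Transpose { perm: vec![0, 1] },
        ToyOp::Reduce { axis: 1, rank: 2 },
    ];
    let passes: Vec<Box<dyn ToyPass>> =
        vec![Box::new(LowerNonLastReduce), Box::new(DropIdentityTranspose)];
    let changed = run_toy_passes(&passes, &mut ops);
    println!("passes that changed the chain: {:?}", changed);
    println!("result: {:?}", ops);
}
```

Returning a changed flag from each pass is one common design choice: it lets a runner report which passes fired, or iterate to a fixed point, without comparing whole graphs.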