pub struct FuseTransformerLayer;
Fuses an entire BERT-style transformer layer (attention block + residual+LN +
FFN + residual+LN) into one Op::FusedTransformerLayer node.
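Written as equations, the fused node computes a post-LN BERT block. This is a sketch under assumptions the doc does not state. LayerNorm statistics $\mu(z)$ and $\sigma^2(z)$ are taken over the hidden dimension. Activations are row vectors multiplied on the right by the weights. $\mathrm{Attn}(x)$ stands for whatever FusedAttentionBlock computes from the layer input $x$ (the diagram's skip). The symbol mapping is $W_1, b_1$ = fc1_w, fc1_b; $W_2, b_2$ = fc2_w, fc2_b; $\gamma_i, \beta_i$ = ln{i}_g, ln{i}_b; and $\varepsilon_1, \varepsilon_2$ are assumed to be eps1 and eps2:

```latex
\begin{aligned}
h_1 &= \mathrm{LN}\!\left(\mathrm{Attn}(x) + x;\ \gamma_1, \beta_1, \varepsilon_1\right) \\
u   &= \mathrm{GeLU}\!\left(h_1 W_1 + b_1\right) \\
y   &= \mathrm{LN}\!\left(u W_2 + b_2 + h_1;\ \gamma_2, \beta_2, \varepsilon_2\right) \\
\mathrm{LN}(z;\ \gamma, \beta, \varepsilon) &= \gamma \odot \frac{z - \mu(z)}{\sqrt{\sigma^2(z) + \varepsilon}} + \beta
\end{aligned}
```

Here $h_1$ is h1, $u$ is ffn_int, $u W_2 + b_2$ is ffn_out, and $y$ is out.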
Pattern (FuseMatMulBiasAct, FuseResidualLN, and FuseAttentionBlock must run
first, because this pass matches the fused nodes they produce):
skip ──┬─→ FusedAttentionBlock(qkv_w, out_w, mask, qkv_b, out_b) ─→ attn_out
       └─→ FusedResidualLN(attn_out, skip, ln1_g, ln1_b) ─→ h1
                                                            ├─→ FusedMatMulBiasAct(fc1_w, fc1_b, GeLU) ─→ ffn_int
                                                            │   ↓
                                                            │   FusedMatMulBiasAct(fc2_w, fc2_b, None) ─→ ffn_out
                                                            └────────────────────→ FusedResidualLN(ffn_out, h1, ln2_g, ln2_b) ─→ out

All five nodes collapse into a single FusedTransformerLayer { num_heads, head_dim, intermediate_size, eps1, eps2, activation, has_bias: true }
with the 14-input layout consumed by rlx-mlx’s lowering at
rlx-mlx/src/lower.rs:1528, where hidden is the layer input (skip in the diagram):
[hidden, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask].
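The match can be sketched as a backward walk from the second FusedResidualLN that collects the 14 inputs in the documented order. This is a minimal illustration on a hypothetical mini IR. `Op`, `Node`, `Act`, and `match_transformer_layer` are illustrative names, not rlx's actual types. The sketch also assumes that each fused node's first input is the activation it consumes and that the remaining inputs follow the order shown in the diagram.

```rust
/// Activation attached to a FusedMatMulBiasAct node (hypothetical IR).
#[derive(Clone, Copy, Debug, PartialEq)]
enum Act {
    GeLU,
    None,
}

/// Hypothetical, simplified op set; not rlx's real `Op` enum.
#[derive(Clone, Debug, PartialEq)]
enum Op {
    /// A graph input or weight, identified by name.
    Input(&'static str),
    /// inputs: [x, qkv_w, out_w, mask, qkv_b, out_b]
    FusedAttentionBlock,
    /// inputs: [x, residual, gamma, beta]
    FusedResidualLN,
    /// inputs: [x, w, b]
    FusedMatMulBiasAct(Act),
}

#[derive(Clone, Debug)]
struct Node {
    op: Op,
    inputs: Vec<usize>, // indices of producer nodes; assumed well-formed
}

/// Starting from a candidate final FusedResidualLN, checks the five-node
/// pattern and returns the 14 input ids in the layout
/// [hidden, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b,
///  fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, mask].
/// A real pass must also verify that the intermediates have no users
/// outside the pattern; that check is omitted here.
fn match_transformer_layer(g: &[Node], ln2: usize) -> Option<[usize; 14]> {
    let n = &g[ln2];
    if n.op != Op::FusedResidualLN {
        return None;
    }
    let (ffn_out, h1, ln2_g, ln2_b) = (n.inputs[0], n.inputs[1], n.inputs[2], n.inputs[3]);

    // Second matmul: no activation.
    let fc2 = &g[ffn_out];
    if fc2.op != Op::FusedMatMulBiasAct(Act::None) {
        return None;
    }
    let (ffn_int, fc2_w, fc2_b) = (fc2.inputs[0], fc2.inputs[1], fc2.inputs[2]);

    // First matmul: GeLU, and it must read the same h1 used as the second residual.
    let fc1 = &g[ffn_int];
    if fc1.op != Op::FusedMatMulBiasAct(Act::GeLU) || fc1.inputs[0] != h1 {
        return None;
    }
    let (fc1_w, fc1_b) = (fc1.inputs[1], fc1.inputs[2]);

    let ln1 = &g[h1];
    if ln1.op != Op::FusedResidualLN {
        return None;
    }
    let (attn_out, skip, ln1_g, ln1_b) = (ln1.inputs[0], ln1.inputs[1], ln1.inputs[2], ln1.inputs[3]);

    // The attention block must read the same `skip` used as the first residual.
    let attn = &g[attn_out];
    if attn.op != Op::FusedAttentionBlock || attn.inputs[0] != skip {
        return None;
    }
    let a = &attn.inputs; // [x, qkv_w, out_w, mask, qkv_b, out_b]
    Some([
        skip, a[1], a[4], a[2], a[5], ln1_g, ln1_b,
        fc1_w, fc1_b, fc2_w, fc2_b, ln2_g, ln2_b, a[3],
    ])
}
```

The two identity checks (fc1 reads h1, and the attention block reads skip) are what separate a genuine transformer layer from five unrelated fused ops that happen to be chained.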
This pass uses the same fusion threshold as FuseAttentionBlock
(RLX_FUSE_ATTN_THRESHOLD, default 64). Backends that don’t natively support
FusedTransformerLayer un-fuse it back into primitives at compile time; backends
that do (MLX) can emit one monolithic kernel per layer.
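The two knobs in this paragraph can be sketched as follows. The env-var name and the default of 64 come from the text above. Everything else is a hypothetical stand-in: falling back to 64 on an unparsable value, the `Backend` trait, the `Lowering` enum, and the `Mlx`/`Generic` types. The sketch omits whatever quantity the real pass compares against the threshold, because this doc does not say.

```rust
use std::env;

/// Documented default when RLX_FUSE_ATTN_THRESHOLD is not set.
const DEFAULT_FUSE_ATTN_THRESHOLD: usize = 64;

/// Reads the shared fusion threshold. Treating an unparsable value as
/// "use the default" is an assumption of this sketch.
fn fuse_attn_threshold() -> usize {
    env::var("RLX_FUSE_ATTN_THRESHOLD")
        .ok()
        .and_then(|s| s.trim().parse().ok())
        .unwrap_or(DEFAULT_FUSE_ATTN_THRESHOLD)
}

/// Hypothetical capability query; not rlx's real backend trait.
trait Backend {
    fn supports_fused_transformer_layer(&self) -> bool;
}

#[derive(Debug, PartialEq)]
enum Lowering {
    /// Emit one monolithic kernel for the whole layer.
    Monolithic,
    /// Expand the node back into primitive ops at compile time.
    Unfused,
}

/// Chooses how a FusedTransformerLayer node is lowered on a given backend.
fn lower_layer(backend: &dyn Backend) -> Lowering {
    if backend.supports_fused_transformer_layer() {
        Lowering::Monolithic
    } else {
        Lowering::Unfused
    }
}

/// Stand-in for a backend with native support (like MLX in the text).
struct Mlx;
impl Backend for Mlx {
    fn supports_fused_transformer_layer(&self) -> bool {
        true
    }
}

/// Stand-in for a backend without native support.
struct Generic;
impl Backend for Generic {
    fn supports_fused_transformer_layer(&self) -> bool {
        false
    }
}
```

Keeping the un-fuse fallback in the backend means the graph pass can fuse unconditionally, and only backends that benefit pay attention to the fused op.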