
Module flash_attn_train


Flash-attention training forward kernel — host dispatch.

FA-2 forward pass that emits BOTH the attention output O AND the per-row natural-log logsumexp L required by the Phase 2 backward.

§Algorithm

Identical to super::flash_attn_prefill (online softmax, simdgroup MMA, same tile geometry, same causal / additive-mask handling, same GQA).

The only addition is the L_out [B, H_q, qL] f32 buffer at buffer(8). After the K-tile sweep each thread with sn == 0 writes one f32:

L[b, h, i] = max_score_b2 * ln(2) + ln(sum_score_b2)

where max_score_b2 and sum_score_b2 are the per-row base-2 running max / unnormalized exp2 sum from the K-sweep (Q is pre-scaled by scale * log2(e) so all accumulators live in base-2 space).

This equals the FA-2 paper Algorithm 1 logsumexp: L_i = m_i + log( sum_j exp(s_ij - m_i) ) in natural-log units.
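The base-2 to natural-log conversion can be checked numerically on the host. The sketch below is not the kernel code: it processes one score at a time rather than one K tile at a time, the function names are hypothetical, and it assumes every score is finite (no masked `-inf` entries).

```rust
/// Natural-log logsumexp computed with base-2 accumulators, mirroring
/// L = max_score_b2 * ln(2) + ln(sum_score_b2).
/// `scores` holds the already-scaled scores s_ij of one query row.
/// Hypothetical helper; assumes all scores are finite.
fn lse_via_base2(scores: &[f32]) -> f32 {
    let log2_e = std::f32::consts::LOG2_E;
    let ln_2 = std::f32::consts::LN_2;

    let mut max_b2 = f32::NEG_INFINITY; // running max in base-2 units
    let mut sum_b2 = 0.0f32; // running sum of exp2(s_b2 - max_b2)

    for &s in scores {
        let s_b2 = s * log2_e; // move the score into base-2 space
        let new_max = max_b2.max(s_b2);
        // Rescale the old sum to the new max, then add the new term.
        sum_b2 = sum_b2 * (max_b2 - new_max).exp2() + (s_b2 - new_max).exp2();
        max_b2 = new_max;
    }
    max_b2 * ln_2 + sum_b2.ln()
}

/// Reference form: L_i = m_i + ln(sum_j exp(s_ij - m_i)).
fn lse_reference(scores: &[f32]) -> f32 {
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    m + scores.iter().map(|&s| (s - m).exp()).sum::<f32>().ln()
}
```

Both functions should agree to within f32 rounding, because exp2(x * log2(e)) = exp(x) and max_score_b2 * ln(2) recovers the natural-unit row max.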

§Buffer layout

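| Index | Name | Shape | DType |
|-------|------|-------|-------|
| 0 | Q | [B, H_q, qL, D] | BF16 |
| 1 | K | [B, H_kv, kL, D] | BF16 |
| 2 | V | [B, H_kv, kL, D] | BF16 |
| 3 | O (out) | [B, H_q, qL, D] | BF16 |
| 4 | params | 160-byte ABI struct | |
| 5 | mask_params | 24-byte struct | — (when has_mask) |
| 6 | mask | [B, H_q, qL, kL] | BF16 or bool (when has_mask) |
| 8 | L_out | [B, H_q, qL] | F32 |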
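When allocating these buffers, the byte sizes follow directly from the shapes and dtypes listed above. The sketch below assumes contiguous row-major storage, which this page does not state; `AttnShapes` and its methods are hypothetical names, not part of this module.

```rust
/// Hypothetical shape descriptor; field names follow the table's symbols.
struct AttnShapes {
    b: usize,     // B
    h_q: usize,   // H_q
    h_kv: usize,  // H_kv
    q_len: usize, // qL
    k_len: usize, // kL
    d: usize,     // D
}

const BF16_BYTES: usize = 2;
const F32_BYTES: usize = 4;

impl AttnShapes {
    /// Bytes for Q or O: [B, H_q, qL, D] in BF16.
    fn q_bytes(&self) -> usize {
        self.b * self.h_q * self.q_len * self.d * BF16_BYTES
    }

    /// Bytes for K or V: [B, H_kv, kL, D] in BF16.
    fn kv_bytes(&self) -> usize {
        self.b * self.h_kv * self.k_len * self.d * BF16_BYTES
    }

    /// Bytes for L_out: [B, H_q, qL] in F32.
    fn l_out_bytes(&self) -> usize {
        self.b * self.h_q * self.q_len * F32_BYTES
    }

    /// Flat element index of L[b, h, i], assuming row-major layout.
    fn l_index(&self, b: usize, h: usize, i: usize) -> usize {
        (b * self.h_q + h) * self.q_len + i
    }
}
```

Note that L_out holds one f32 per query row, so it is smaller than O by a factor of D / 2.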

§Function constants

Same 4 constants as flash_attn_prefill.metal:

| Index | Name | Semantics |
|-------|------|-----------|
| 200 | align_Q | qL % BQ == 0 |
| 201 | align_K | kL % BK == 0 |
| 300 | has_mask | additive/bool mask buffer bound |
| 301 | do_causal | in-kernel causal masking |
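A host-side sketch of how the four constant values could be derived from the problem size is shown here. The tile sizes BQ and BK are not given on this page, so they are taken as parameters; `FnConstants` and `derive_constants` are hypothetical names, not this module's API.

```rust
/// Hypothetical host-side mirror of the four function constants.
#[derive(Debug, PartialEq)]
struct FnConstants {
    align_q: bool,   // index 200: qL % BQ == 0
    align_k: bool,   // index 201: kL % BK == 0
    has_mask: bool,  // index 300: mask buffer is bound
    do_causal: bool, // index 301: in-kernel causal masking
}

/// `bq` and `bk` are the query and key tile sizes (assumed inputs).
fn derive_constants(
    q_len: usize,
    k_len: usize,
    bq: usize,
    bk: usize,
    has_mask: bool,
    do_causal: bool,
) -> FnConstants {
    FnConstants {
        align_q: q_len % bq == 0,
        align_k: k_len % bk == 0,
        has_mask,
        do_causal,
    }
}
```

The alignment flags let the kernel skip bounds handling when the sequence lengths are exact multiples of the tile sizes.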

§Kernel variants

| Name | D | I/O dtype | Mask kind |
|------|---|-----------|-----------|
| flash_attn_train_fwd_bf16_d64 | 64 | bf16 | bf16 additive |
| flash_attn_train_fwd_bf16_d64_boolmask | 64 | bf16 | bool |
| flash_attn_train_fwd_bf16_d256 | 256 | bf16 | bf16 additive |
| flash_attn_train_fwd_bf16_d256_boolmask | 256 | bf16 | bool |
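The variant table can be expressed as a lookup from head dimension and mask kind to kernel name. This is an illustrative sketch only: `MaskKind` and `train_fwd_kernel_name` are hypothetical, and the module's own dispatch functions may select kernels differently.

```rust
/// Mask kinds listed in the variant table (hypothetical enum).
#[derive(Clone, Copy)]
enum MaskKind {
    Bf16Additive,
    Bool,
}

/// Returns the kernel name for a (head_dim, mask kind) pair, or `None`
/// when the table lists no variant for that head dimension.
fn train_fwd_kernel_name(head_dim: usize, mask: MaskKind) -> Option<&'static str> {
    match (head_dim, mask) {
        (64, MaskKind::Bf16Additive) => Some("flash_attn_train_fwd_bf16_d64"),
        (64, MaskKind::Bool) => Some("flash_attn_train_fwd_bf16_d64_boolmask"),
        (256, MaskKind::Bf16Additive) => Some("flash_attn_train_fwd_bf16_d256"),
        (256, MaskKind::Bool) => Some("flash_attn_train_fwd_bf16_d256_boolmask"),
        _ => None, // only D = 64 and D = 256 appear in the table
    }
}
```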

§Scale convention

Pass scale = 1.0 / sqrt(head_dim). The kernel multiplies internally by log2(e). Do NOT pre-multiply by log2(e) on the host.
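As a small illustration of this convention, the host computes only the natural-unit scale. The second helper below imitates what the kernel is described as doing internally and exists only to show why pre-multiplying on the host would apply log2(e) twice; both function names are hypothetical.

```rust
/// Host-side scale to pass in: 1 / sqrt(head_dim). No log2(e) factor.
fn attention_scale(head_dim: usize) -> f32 {
    1.0 / (head_dim as f32).sqrt()
}

/// Illustration of the kernel-internal factor (not host code):
/// folding log2(e) into the scale lets the kernel use exp2 instead of exp.
fn kernel_side_prescale(scale: f32) -> f32 {
    scale * std::f32::consts::LOG2_E
}
```

With this split, exp2(x * kernel_side_prescale(scale)) matches exp(x * scale) up to rounding, so the host-visible semantics stay in natural-log units.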

Structs§

FlashAttnTrainParams
Host-side parameters for the flash-attention training forward dispatcher.

Statics§

FLASH_ATTN_TRAIN_BWD_COMPUTE_D_SHADER_SOURCE
MSL source for the compute-D pre-pass kernel.
FLASH_ATTN_TRAIN_BWD_SHADER_SOURCE
MSL source for the backward kernels (embedded at compile time).
FLASH_ATTN_TRAIN_FWD_SHADER_SOURCE
MSL source for the forward kernels (embedded at compile time).

Functions§

dispatch_flash_attn_train_bwd_bf16_d64
Dispatch the FA-2 backward pass for bf16 I/O, head_dim=64.
dispatch_flash_attn_train_bwd_bf16_d256
Dispatch the FA-2 backward pass for bf16 I/O, head_dim=256.
dispatch_flash_attn_train_fwd_bf16_d64
Dispatch the FA-2 forward pass for bf16 Q/K/V/O, head_dim=64.
dispatch_flash_attn_train_fwd_bf16_d256
Dispatch the FA-2 forward pass for bf16 Q/K/V/O, head_dim=256.
register
Register all 4 training-forward kernel entry points with the registry.
register_bwd
Register all backward kernel entry points with the registry.