Flash-attention training forward kernel — host dispatch.
FA-2 forward pass that emits BOTH the attention output O AND the
per-row natural-log logsumexp L required by the Phase 2 backward.
§Algorithm
Identical to super::flash_attn_prefill (online softmax, simdgroup MMA,
same tile geometry, same causal / additive-mask handling, same GQA).
The only addition is the L_out [B, H_q, qL] f32 buffer at buffer(8).
After the K-tile sweep, each thread with sn == 0 writes one f32:
L[b, h, i] = max_score_b2 * ln(2) + ln(sum_score_b2)
where max_score_b2 and sum_score_b2 are the per-row base-2 running max and unnormalized exp2 sum from the K sweep (Q is pre-scaled by scale * log2(e), so all accumulators live in base-2 space).
This equals the FA-2 paper Algorithm 1 logsumexp:
L_i = m_i + log( sum_j exp(s_ij - m_i) ) in natural-log units.
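The base-2 bookkeeping can be checked on the host with a small reference. The sketch below uses hypothetical helper names, and `bk` stands in for the K-tile width, which this page does not give. It runs the same online max / exp2-sum update tile by tile, converts back with max_score_b2 * ln(2) + ln(sum_score_b2), and compares the result with a direct natural-log logsumexp. It assumes each row has at least one finite score; a fully masked row yields NaN here.

```rust
use std::f32::consts::{LN_2, LOG2_E};

/// Direct natural-log logsumexp of one row of (already scaled) scores.
fn logsumexp_direct(scores: &[f32]) -> f32 {
    let m = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    m + scores.iter().map(|&s| (s - m).exp()).sum::<f32>().ln()
}

/// Tiled online sweep in base-2 units, mirroring the running max and
/// unnormalized exp2 sum described above. `bk` is a hypothetical tile width.
fn logsumexp_online_base2(scores: &[f32], bk: usize) -> f32 {
    let mut max_b2 = f32::NEG_INFINITY; // running max, base-2 units
    let mut sum_b2 = 0.0f32; // running sum of exp2(x - max_b2)
    for tile in scores.chunks(bk) {
        // Convert to base-2 units (the kernel folds log2(e) into Q instead).
        let t: Vec<f32> = tile.iter().map(|&s| s * LOG2_E).collect();
        let tile_max = t.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let new_max = max_b2.max(tile_max);
        // Rescale the previous sum to the new max, then add this tile's terms.
        sum_b2 = sum_b2 * (max_b2 - new_max).exp2()
            + t.iter().map(|&x| (x - new_max).exp2()).sum::<f32>();
        max_b2 = new_max;
    }
    // Back to natural-log units.
    max_b2 * LN_2 + sum_b2.ln()
}

fn main() {
    let scores = [0.3f32, -1.2, 2.5, 0.0, 1.1, -0.4, 3.0];
    let direct = logsumexp_direct(&scores);
    let online = logsumexp_online_base2(&scores, 3);
    println!("direct = {direct:.6}, online base-2 = {online:.6}");
    assert!((direct - online).abs() < 1e-5);
}
```

Because exp(s) = 2^(s * log2(e)), the base-2 accumulators only need the ln(2) factor on the max term at the end; the sum term is already a plain ratio and just takes ln.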
§Buffer layout
| Index | Name | Shape | DType |
|---|---|---|---|
| 0 | Q | [B, H_q, qL, D] | BF16 |
| 1 | K | [B, H_kv, kL, D] | BF16 |
| 2 | V | [B, H_kv, kL, D] | BF16 |
| 3 | O (out) | [B, H_q, qL, D] | BF16 |
| 4 | params | 160-byte ABI struct | — |
| 5 | mask_params | 24-byte struct | — (when has_mask) |
| 6 | mask | [B, H_q, qL, kL] | BF16 or bool (when has_mask) |
| 8 | L_out | [B, H_q, qL] | F32 |
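When allocating or validating buffers on the host, the table translates into byte counts roughly as sketched below. `Shapes`, `MaskKind`, and `buffer_bytes` are hypothetical names. The mask is assumed dense and contiguous, and bool mask elements are assumed to take 1 byte; check both against the real dispatcher before relying on them.

```rust
/// Hypothetical shape descriptor; field names are illustrative only.
struct Shapes {
    b: usize,
    h_q: usize,
    h_kv: usize,
    q_len: usize,
    k_len: usize,
    d: usize,
}

enum MaskKind {
    NoMask,
    Bf16Additive,
    Bool,
}

/// Returns (buffer index, byte length) for each binding in the layout table.
fn buffer_bytes(s: &Shapes, mask: MaskKind) -> Vec<(u32, usize)> {
    const BF16: usize = 2;
    const F32: usize = 4;
    let mut v = vec![
        (0, s.b * s.h_q * s.q_len * s.d * BF16),  // Q
        (1, s.b * s.h_kv * s.k_len * s.d * BF16), // K
        (2, s.b * s.h_kv * s.k_len * s.d * BF16), // V
        (3, s.b * s.h_q * s.q_len * s.d * BF16),  // O (out)
        (4, 160),                                 // params ABI struct
    ];
    let mask_elem = match mask {
        MaskKind::NoMask => None,
        MaskKind::Bf16Additive => Some(BF16),
        MaskKind::Bool => Some(1), // assumption: 1-byte bool elements
    };
    if let Some(elem) = mask_elem {
        v.push((5, 24)); // mask_params
        v.push((6, s.b * s.h_q * s.q_len * s.k_len * elem)); // mask
    }
    v.push((8, s.b * s.h_q * s.q_len * F32)); // L_out (buffer 7 is not listed)
    v
}

fn main() {
    let s = Shapes { b: 1, h_q: 8, h_kv: 2, q_len: 128, k_len: 128, d: 64 };
    for (idx, bytes) in buffer_bytes(&s, MaskKind::Bool) {
        println!("buffer({idx}): {bytes} bytes");
    }
}
```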
§Function constants
Same 4 constants as flash_attn_prefill.metal:
| Index | Name | Semantics |
|---|---|---|
| 200 | align_Q | qL % BQ == 0 |
| 201 | align_K | kL % BK == 0 |
| 300 | has_mask | additive/bool mask buffer bound |
| 301 | do_causal | in-kernel causal masking |
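The alignment constants depend only on the sequence lengths and the tile sizes. This page does not state BQ or BK, so the sketch below takes them as parameters; `function_constants` is a hypothetical helper that returns (index, value) pairs matching the table. It does not call any Metal API.

```rust
/// Function-constant values keyed by the indices in the table above.
/// `bq` / `bk` are the kernel's tile sizes (not given on this page).
fn function_constants(
    q_len: usize,
    k_len: usize,
    bq: usize,
    bk: usize,
    has_mask: bool,
    do_causal: bool,
) -> [(u32, bool); 4] {
    [
        (200, q_len % bq == 0), // align_Q
        (201, k_len % bk == 0), // align_K
        (300, has_mask),        // has_mask
        (301, do_causal),       // do_causal
    ]
}

fn main() {
    // Illustrative tile sizes only; not taken from the kernel source.
    for (idx, val) in function_constants(100, 128, 32, 16, false, true) {
        println!("function constant {idx} = {val}");
    }
}
```

When align_Q or align_K is false, the kernel has to bounds-check the last partial tile, which is why these are compile-time specializations rather than runtime branches.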
§Kernel variants
| Name | D | I/O dtype | Mask kind |
|---|---|---|---|
| flash_attn_train_fwd_bf16_d64 | 64 | bf16 | bf16 additive |
| flash_attn_train_fwd_bf16_d64_boolmask | 64 | bf16 | bool |
| flash_attn_train_fwd_bf16_d256 | 256 | bf16 | bf16 additive |
| flash_attn_train_fwd_bf16_d256_boolmask | 256 | bf16 | bool |
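Choosing an entry point reduces to a lookup on head_dim and mask kind. In the sketch below, `fwd_kernel_name` is a hypothetical helper. It assumes an unmasked call uses the non-boolmask variant with has_mask = false, which this page does not state explicitly.

```rust
/// Picks a training-forward entry-point name from the variant table.
/// `bool_mask` selects the `_boolmask` variant. Returns None for a head_dim
/// with no listed variant (only 64 and 256 appear in the table).
fn fwd_kernel_name(head_dim: usize, bool_mask: bool) -> Option<&'static str> {
    let name = match (head_dim, bool_mask) {
        (64, false) => "flash_attn_train_fwd_bf16_d64",
        (64, true) => "flash_attn_train_fwd_bf16_d64_boolmask",
        (256, false) => "flash_attn_train_fwd_bf16_d256",
        (256, true) => "flash_attn_train_fwd_bf16_d256_boolmask",
        _ => return None,
    };
    Some(name)
}

fn main() {
    for (d, b) in [(64, false), (256, true), (128, false)] {
        match fwd_kernel_name(d, b) {
            Some(n) => println!("head_dim={d}, bool_mask={b} -> {n}"),
            None => println!("head_dim={d}: no training-forward variant"),
        }
    }
}
```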
§Scale convention
Pass scale = 1.0 / sqrt(head_dim). The kernel multiplies internally by
log2(e). Do NOT pre-multiply by log2(e) on the host.
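The identity behind this convention is exp(x * scale) = exp2(x * scale * log2(e)). The sketch below, using a hypothetical `host_scale` helper, computes the host-side scale, checks that identity for one example dot product, and shows that folding log2(e) in on the host would apply it twice.

```rust
use std::f32::consts::LOG2_E;

/// Host-side softmax scale as this section prescribes: 1 / sqrt(head_dim).
fn host_scale(head_dim: usize) -> f32 {
    1.0 / (head_dim as f32).sqrt()
}

fn main() {
    let scale = host_scale(64); // 1/8 for head_dim = 64
    // The kernel multiplies Q by scale * log2(e) internally so it can use exp2.
    let kernel_factor = scale * LOG2_E;
    let raw_dot = 3.7f32; // example q.k value
    let via_exp = (raw_dot * scale).exp();
    let via_exp2 = (raw_dot * kernel_factor).exp2();
    assert!((via_exp - via_exp2).abs() < 1e-5 * via_exp);
    // Pre-multiplying on the host would make the kernel apply log2(e) twice.
    let double_scaled = (raw_dot * scale * LOG2_E * LOG2_E).exp2();
    println!("exp = {via_exp}, exp2 = {via_exp2}, double-scaled = {double_scaled}");
}
```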
§Structs
- FlashAttnTrainParams - Host-side parameters for the flash-attention training forward dispatcher.
§Statics
- FLASH_ATTN_TRAIN_BWD_COMPUTE_D_SHADER_SOURCE - MSL source for the compute-D pre-pass kernel.
- FLASH_ATTN_TRAIN_BWD_SHADER_SOURCE - MSL source for the backward kernels (embedded at compile time).
- FLASH_ATTN_TRAIN_FWD_SHADER_SOURCE - MSL source for the forward kernels (embedded at compile time).
§Functions
- dispatch_flash_attn_train_bwd_bf16_d64 - Dispatch the FA-2 backward pass for bf16 I/O, head_dim=64.
- dispatch_flash_attn_train_bwd_bf16_d256 - Dispatch the FA-2 backward pass for bf16 I/O, head_dim=256.
- dispatch_flash_attn_train_fwd_bf16_d64 - Dispatch the FA-2 forward pass for bf16 Q/K/V/O, head_dim=64.
- dispatch_flash_attn_train_fwd_bf16_d256 - Dispatch the FA-2 forward pass for bf16 Q/K/V/O, head_dim=256.
- register - Register all 4 training-forward kernel entry points with the registry.
- register_bwd - Register all backward kernel entry points with the registry.