GPU SDPA decode kernel — F32 Q/K/V, multi-simdgroup tiled, single-token decode.
The kernel divides the KV sequence across N_SG simdgroups that each scan an independent KV chunk and produce a local (max, sum, unnorm_acc) triple. Simdgroup 0 then merges all N_SG partial results using the log-sum-exp combination rule and writes the final F32 output.
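The merge step can be illustrated on the CPU. The following is a minimal sketch, not the Metal kernel itself: the `Partial` struct and the `scan_chunk` and `merge` functions are hypothetical names introduced here, and the sketch assumes the value vectors have the same length as the query. Each chunk is scanned with an online softmax that tracks a running max, a running sum of exponentials, and an unnormalized accumulator; the merge rescales every partial by `exp(max_i - global_max)` before summing and normalizing.

```rust
/// Hypothetical CPU mirror of one simdgroup's partial result:
/// (max, sum, unnorm_acc) over an independent KV chunk.
#[derive(Clone, Debug)]
struct Partial {
    max: f32,      // largest scaled score seen in the chunk
    sum: f32,      // sum of exp(score - max) over the chunk
    acc: Vec<f32>, // sum of exp(score - max) * value, not yet normalized
}

/// Scan one KV chunk for a single query vector (single-token decode).
/// Assumes every key and value has the same length as `q`.
fn scan_chunk(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], scale: f32) -> Partial {
    let mut p = Partial {
        max: f32::NEG_INFINITY,
        sum: 0.0,
        acc: vec![0.0; q.len()],
    };
    for (k, v) in keys.iter().zip(values) {
        let score = scale * q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>();
        let new_max = p.max.max(score);
        // Rescale what was accumulated so far to the new running max.
        // On the first row p.max is -inf, so the factor is exactly 0.
        let rescale = (p.max - new_max).exp();
        let w = (score - new_max).exp();
        p.sum = p.sum * rescale + w;
        for (a, x) in p.acc.iter_mut().zip(v) {
            *a = *a * rescale + w * x;
        }
        p.max = new_max;
    }
    p
}

/// Merge partial results with the log-sum-exp combination rule and
/// return the normalized output vector.
/// Assumes `partials` is non-empty and at least one chunk saw a row.
fn merge(partials: &[Partial]) -> Vec<f32> {
    let global_max = partials
        .iter()
        .map(|p| p.max)
        .fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0f32;
    let mut out = vec![0.0f32; partials[0].acc.len()];
    for p in partials {
        // An empty chunk has max = -inf, so its weight is 0.
        let w = (p.max - global_max).exp();
        sum += p.sum * w;
        for (o, a) in out.iter_mut().zip(&p.acc) {
            *o += a * w;
        }
    }
    for o in out.iter_mut() {
        *o /= sum;
    }
    out
}
```

Splitting the KV rows into any number of chunks and merging should give the same output, up to floating-point rounding, as scanning all rows in one chunk. This is why the partitioning across simdgroups does not change the result.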
§Constraints
- seq_len must be 1 (decode path only)
- head_dim must be a multiple of 32 (128, 256, 512 are supported)
- Q/K/V must be F32
- n_sg must be 1, 2, or 4
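
A caller may want to check these constraints before dispatching. The sketch below is an assumption about how such a check could look. `ParamError` and `check_decode_params` are hypothetical names and are not part of this module's API. The F32 requirement is not checked here, on the assumption that it is enforced by the buffer types. The sketch tests only that `head_dim` is a multiple of 32; whether values other than 128, 256, and 512 are accepted by the real kernel is not stated above.

```rust
/// Hypothetical error type for the documented constraints.
#[derive(Debug, PartialEq)]
enum ParamError {
    SeqLen(usize),
    HeadDim(usize),
    SimdgroupCount(usize),
}

/// Hypothetical pre-dispatch check mirroring the constraint list.
fn check_decode_params(seq_len: usize, head_dim: usize, n_sg: usize) -> Result<(), ParamError> {
    // Decode path only: exactly one query token.
    if seq_len != 1 {
        return Err(ParamError::SeqLen(seq_len));
    }
    // head_dim must be a non-zero multiple of 32.
    if head_dim == 0 || head_dim % 32 != 0 {
        return Err(ParamError::HeadDim(head_dim));
    }
    // Only 1, 2, or 4 simdgroups are allowed.
    if !matches!(n_sg, 1 | 2 | 4) {
        return Err(ParamError::SimdgroupCount(n_sg));
    }
    Ok(())
}
```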
Statics§
- SDPA_DECODE_SHADER_SOURCE - Metal shader source.
Functions§
- dispatch_sdpa_decode - Dispatch the tiled decode SDPA kernel.
- register - Register sdpa_decode pipeline.