Module sdpa_decode

GPU scaled dot-product attention (SDPA) decode kernel: F32 Q/K/V, multi-simdgroup tiled, single-token decode.

The kernel divides the KV sequence across N_SG simdgroups that each scan an independent KV chunk and produce a local (max, sum, unnorm_acc) triple. Simdgroup 0 then merges all N_SG partial results using the log-sum-exp combination rule and writes the final F32 output.
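The split-and-merge scheme can be checked on the CPU with the minimal Rust sketch below. It is not the Metal kernel. Each "simdgroup" here is a hypothetical `scan_chunk` call over one contiguous KV chunk. The sketch makes a few assumptions that this page does not state: the chunk is scanned with an online softmax (running max with rescaling), the score scale is `1/sqrt(head_dim)`, and chunks are contiguous. `merge` applies the log-sum-exp rule by rescaling each partial `(max, sum, acc)` by `exp(max_i - global_max)` before the single final normalization. The result is compared against a naive full-softmax reference.

```rust
/// Partial softmax state for one KV chunk: the (max, sum, unnorm_acc) triple.
struct Partial {
    max: f32,      // largest scaled score seen in this chunk
    sum: f32,      // sum of exp(score - max) over the chunk
    acc: Vec<f32>, // unnormalized sum of exp(score - max) * V[j]
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// One hypothetical "simdgroup": online-softmax scan over its KV chunk.
fn scan_chunk(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], scale: f32) -> Partial {
    let mut p = Partial { max: f32::NEG_INFINITY, sum: 0.0, acc: vec![0.0; q.len()] };
    for (kj, vj) in k.iter().zip(v) {
        let s = scale * dot(q, kj);
        let new_max = p.max.max(s);
        // Rescale terms accumulated under the old max; exp(-inf) = 0 on the first key.
        let c = (p.max - new_max).exp();
        let w = (s - new_max).exp();
        p.sum = p.sum * c + w;
        for (a, x) in p.acc.iter_mut().zip(vj) {
            *a = *a * c + w * x;
        }
        p.max = new_max;
    }
    p
}

/// Log-sum-exp merge of all partials (the role simdgroup 0 plays in the kernel).
fn merge(parts: &[Partial]) -> Vec<f32> {
    let m = parts.iter().fold(f32::NEG_INFINITY, |acc, p| acc.max(p.max));
    let mut total = 0.0f32;
    let mut out = vec![0.0f32; parts[0].acc.len()];
    for p in parts {
        let f = (p.max - m).exp(); // bring this chunk onto the global max
        total += p.sum * f;
        for (o, a) in out.iter_mut().zip(&p.acc) {
            *o += a * f;
        }
    }
    out.iter().map(|o| o / total).collect()
}

fn sdpa_decode_split(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], n_sg: usize) -> Vec<f32> {
    assert!(matches!(n_sg, 1 | 2 | 4) && !k.is_empty() && k.len() == v.len());
    let scale = 1.0 / (q.len() as f32).sqrt(); // assumed scale
    let chunk = (k.len() + n_sg - 1) / n_sg; // ceil(kv_len / n_sg)
    let parts: Vec<Partial> = k.chunks(chunk).zip(v.chunks(chunk))
        .map(|(kc, vc)| scan_chunk(q, kc, vc, scale))
        .collect();
    merge(&parts)
}

/// Naive reference: full softmax over all scores, then weighted sum of V.
fn sdpa_reference(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let scores: Vec<f32> = k.iter().map(|kj| scale * dot(q, kj)).collect();
    let m = scores.iter().fold(f32::NEG_INFINITY, |a, &s| a.max(s));
    let w: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
    let z: f32 = w.iter().sum();
    let mut out = vec![0.0f32; q.len()];
    for (wj, vj) in w.iter().zip(v) {
        for (o, x) in out.iter_mut().zip(vj) {
            *o += wj * x / z;
        }
    }
    out
}

/// Deterministic test data: one query row plus kv_len key and value rows.
fn make_inputs(head_dim: usize, kv_len: usize) -> (Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let val = |seed: usize| (seed as f32 * 0.37).sin();
    let q: Vec<f32> = (0..head_dim).map(val).collect();
    let k: Vec<Vec<f32>> = (0..kv_len)
        .map(|j| (0..head_dim).map(|i| val(1000 + j * head_dim + i)).collect::<Vec<f32>>())
        .collect();
    let v: Vec<Vec<f32>> = (0..kv_len)
        .map(|j| (0..head_dim).map(|i| val(50000 + j * head_dim + i)).collect::<Vec<f32>>())
        .collect();
    (q, k, v)
}

fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0f32, f32::max)
}

fn main() {
    let (q, k, v) = make_inputs(128, 37);
    let want = sdpa_reference(&q, &k, &v);
    for n_sg in [1, 2, 4] {
        let got = sdpa_decode_split(&q, &k, &v, n_sg);
        println!("n_sg={n_sg}: max |split - reference| = {:e}", max_abs_diff(&got, &want));
    }
}
```

Because every partial is rescaled onto the same global max before the single division by `total`, the result matches the unsplit softmax up to floating-point rounding for any `n_sg`.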

Constraints

  • seq_len must be 1 (decode path only)
  • head_dim must be a multiple of 32 (128, 256, 512 are supported)
  • Q/K/V must be F32
  • n_sg must be 1, 2, or 4
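A caller could validate these constraints before dispatch, as in the Rust sketch below. `DType` and `check_decode_params` are hypothetical names used only for illustration. This page does not show the real signature of `dispatch_sdpa_decode` or the crate's dtype type, and the sketch does not claim the crate performs these checks this way. The comment on `head_dim` gives one plausible reason for the multiple-of-32 rule: an Apple GPU simdgroup is 32 threads wide, so such a head dimension divides evenly across its lanes. That reason is an assumption.

```rust
/// Hypothetical element-type tag; the crate's real dtype type is not shown here.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DType {
    F32,
    F16,
}

/// Hypothetical pre-dispatch check mirroring the documented constraints.
fn check_decode_params(
    seq_len: usize,
    head_dim: usize,
    n_sg: usize,
    qkv: [DType; 3],
) -> Result<(), String> {
    if seq_len != 1 {
        return Err(format!("decode path needs seq_len == 1, got {seq_len}"));
    }
    // Assumed rationale: 32 lanes per simdgroup, so head_dim splits evenly across lanes.
    if head_dim == 0 || head_dim % 32 != 0 {
        return Err(format!("head_dim must be a non-zero multiple of 32, got {head_dim}"));
    }
    if !matches!(n_sg, 1 | 2 | 4) {
        return Err(format!("n_sg must be 1, 2, or 4, got {n_sg}"));
    }
    if qkv.iter().any(|&d| d != DType::F32) {
        return Err(format!("Q/K/V must all be F32, got {qkv:?}"));
    }
    Ok(())
}

fn main() {
    println!("{:?}", check_decode_params(1, 128, 4, [DType::F32; 3]));
    println!("{:?}", check_decode_params(1, 128, 3, [DType::F32; 3]));
    println!("{:?}", check_decode_params(1, 256, 2, [DType::F32, DType::F16, DType::F32]));
}
```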

Statics

SDPA_DECODE_SHADER_SOURCE
Metal shader source.

Functions

dispatch_sdpa_decode
Dispatch the tiled decode SDPA kernel.
register
Register the sdpa_decode pipeline.