
Module sdpa

Scaled dot-product attention (SDPA) host dispatch.

Computes softmax(Q * K^T / sqrt(head_dim)) * V on the GPU in a single fused Metal compute kernel. Causal masking is applied, so each query position attends only to key positions at or before it.
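For reference, the computation above can be written as a plain CPU loop for a single head. This is a minimal sketch, assuming Q, K and V are row-major `[seq_len][head_dim]` `f32` slices with equal query and key lengths (no KV-cache offset). `causal_sdpa_head` is a hypothetical name, not part of this module's API. It only illustrates the math the fused kernel evaluates, and something like it could serve as a test oracle for the GPU path.

```rust
/// Hypothetical CPU reference for one attention head with causal masking.
/// Assumes q, k, v are row-major [seq_len][head_dim] and queries align with keys.
fn causal_sdpa_head(q: &[f32], k: &[f32], v: &[f32], seq_len: usize, head_dim: usize) -> Vec<f32> {
    // Scale factor 1 / sqrt(head_dim) applied to every dot product.
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = vec![0.0f32; seq_len * head_dim];

    for i in 0..seq_len {
        let qi = &q[i * head_dim..(i + 1) * head_dim];

        // Causal mask: query i only scores keys 0..=i; later keys are skipped entirely.
        let mut scores: Vec<f32> = (0..=i)
            .map(|j| {
                let kj = &k[j * head_dim..(j + 1) * head_dim];
                qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale
            })
            .collect();

        // Numerically stable softmax: subtract the row max before exponentiating.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for s in scores.iter_mut() {
            *s = (*s - max).exp();
            sum += *s;
        }

        // Output row i is the softmax-weighted sum of the visible value rows.
        let oi = &mut out[i * head_dim..(i + 1) * head_dim];
        for (j, w) in scores.iter().enumerate() {
            let vj = &v[j * head_dim..(j + 1) * head_dim];
            for d in 0..head_dim {
                oi[d] += (w / sum) * vj[d];
            }
        }
    }
    out
}
```

Because row 0 can only see key 0, its softmax weight is exactly 1 and its output equals the first value row regardless of the other keys. A mismatch there is a quick signal that the mask is wrong.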

Supports grouped-query attention (GQA), where n_heads > n_kv_heads and each key/value head is shared by a group of query heads.
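Under GQA, each query head must be mapped to the key/value head it reads from. The sketch below assumes the common convention that consecutive query heads form a group and that n_heads is a multiple of n_kv_heads. This page does not state which layout the kernel uses, and `kv_head_for` is a hypothetical helper, not part of this module.

```rust
/// Hypothetical helper: returns the KV head read by query head `q_head` under GQA.
/// Assumes consecutive query heads share a KV head and n_heads % n_kv_heads == 0.
fn kv_head_for(q_head: usize, n_heads: usize, n_kv_heads: usize) -> usize {
    assert!(
        n_kv_heads > 0 && n_heads % n_kv_heads == 0,
        "n_heads must be a positive multiple of n_kv_heads"
    );
    assert!(q_head < n_heads, "query head index out of range");
    // Number of query heads sharing one KV head (1 means ordinary multi-head attention).
    let group_size = n_heads / n_kv_heads;
    q_head / group_size
}
```

With n_heads = 8 and n_kv_heads = 2, query heads 0 to 3 read KV head 0 and heads 4 to 7 read KV head 1. When the two counts are equal, the mapping reduces to the identity, which is standard multi-head attention.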

Structs

SdpaParams
Parameters for the SDPA kernel.

Statics

SDPA_SHADER_SOURCE
MSL source for the SDPA kernel (embedded at compile time).

Functions

register
Register SDPA shader source with the given kernel registry.
sdpa
Dispatch scaled dot-product attention on the GPU.