Scaled dot-product attention (SDPA) host dispatch.
Computes softmax(Q * K^T / sqrt(head_dim)) * V on the GPU using a fused
Metal compute kernel with causal masking.
Supports grouped-query attention (GQA), where n_heads > n_kv_heads and each KV head is shared by a group of query heads.
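The computation described above can be sketched on the CPU as a reference for checking kernel output. This is a minimal illustration, not this module's API: the function name `sdpa_reference`, the head-major `[head, seq, head_dim]` layout, and the rule that query head `h` reads KV head `h / (n_heads / n_kv_heads)` are all assumptions made for the example, and the actual Metal kernel may lay out or group its buffers differently.

```rust
/// Hypothetical CPU reference for causal SDPA with GQA (not part of this crate).
/// Assumed layouts: q is [n_heads, seq_len, head_dim]; k and v are
/// [n_kv_heads, seq_len, head_dim], all row-major f32.
fn sdpa_reference(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    n_heads: usize,
    n_kv_heads: usize,
    seq_len: usize,
    head_dim: usize,
) -> Vec<f32> {
    assert!(n_kv_heads > 0 && n_heads % n_kv_heads == 0);
    assert_eq!(q.len(), n_heads * seq_len * head_dim);
    assert_eq!(k.len(), n_kv_heads * seq_len * head_dim);
    assert_eq!(v.len(), k.len());

    // Number of query heads that share one KV head (assumed contiguous grouping).
    let group = n_heads / n_kv_heads;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = vec![0.0f32; q.len()];
    let mut scores = vec![0.0f32; seq_len];

    for h in 0..n_heads {
        let kv = h / group;
        for i in 0..seq_len {
            let q_row = &q[(h * seq_len + i) * head_dim..][..head_dim];

            // Causal mask: position i only sees keys 0..=i.
            let mut max = f32::NEG_INFINITY;
            for j in 0..=i {
                let k_row = &k[(kv * seq_len + j) * head_dim..][..head_dim];
                let dot: f32 = q_row.iter().zip(k_row).map(|(a, b)| a * b).sum();
                scores[j] = dot * scale;
                max = max.max(scores[j]);
            }

            // Numerically stable softmax over the unmasked scores.
            let mut denom = 0.0f32;
            for s in &mut scores[..=i] {
                *s = (*s - max).exp();
                denom += *s;
            }

            // Weighted sum of the value rows.
            let o_row = &mut out[(h * seq_len + i) * head_dim..][..head_dim];
            for j in 0..=i {
                let w = scores[j] / denom;
                let v_row = &v[(kv * seq_len + j) * head_dim..][..head_dim];
                for (o, x) in o_row.iter_mut().zip(v_row) {
                    *o += w * x;
                }
            }
        }
    }
    out
}
```

Calling `sdpa_reference(&q, &k, &v, 2, 1, 2, 2)` runs two query heads against a single shared KV head. Under causal masking, the first position of each head attends only to itself, so its output equals the first value row.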
Structs
- SdpaParams - Parameters for the SDPA kernel.
Statics
- SDPA_SHADER_SOURCE - MSL source for the SDPA kernel (embedded at compile time).