Flash attention vector kernel dispatch — SIMD-vectorized decode-path SDPA.
Ported from llama.cpp’s `flash_attn_ext_vec` kernel. This replaces the naive
SDPA kernel with a workgroup-parallel implementation that splits the KV cache
across `nwg` workgroups, each computing a partial softmax result; a reduce
kernel then combines the partials.
This kernel is optimized for the decode path (`seq_len = 1`) with F32 Q/K/V.
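The split-KV scheme described above can be illustrated with a CPU sketch: each "workgroup" computes a numerically stable partial over its KV slice (running max, exp-sum, and weighted-V accumulator), and a reduce step merges the partials with the usual log-sum-exp rescaling. The function names and signatures here are illustrative, not the crate's actual API.

```rust
// CPU reference of split-KV decode-path SDPA (seq_len = 1).
// `partial_sdpa` plays the role of one workgroup; `reduce` plays the
// role of the reduce kernel that combines per-workgroup partials.
fn partial_sdpa(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], scale: f32) -> (f32, f32, Vec<f32>) {
    let d = q.len();
    let mut m = f32::NEG_INFINITY; // running max of scaled scores
    let mut s = 0.0f32;            // running sum of exp(score - m)
    let mut acc = vec![0.0f32; d]; // running sum of exp(score - m) * v
    for (kv, vv) in k.iter().zip(v) {
        let score = scale * q.iter().zip(kv).map(|(a, b)| a * b).sum::<f32>();
        let m_new = m.max(score);
        let corr = (m - m_new).exp(); // rescale previously accumulated terms
        let w = (score - m_new).exp();
        s = s * corr + w;
        for (a, x) in acc.iter_mut().zip(vv) {
            *a = *a * corr + w * x;
        }
        m = m_new;
    }
    (m, s, acc)
}

// Merge per-workgroup partials into the final attention output row.
fn reduce(partials: &[(f32, f32, Vec<f32>)]) -> Vec<f32> {
    let d = partials[0].2.len();
    let m = partials.iter().fold(f32::NEG_INFINITY, |a, p| a.max(p.0));
    let mut s = 0.0f32;
    let mut acc = vec![0.0f32; d];
    for (pm, ps, pacc) in partials {
        let corr = (pm - m).exp(); // bring each partial onto the global max
        s += ps * corr;
        for (a, x) in acc.iter_mut().zip(pacc) {
            *a += corr * x;
        }
    }
    acc.iter().map(|a| a / s).collect()
}
```

Splitting the KV cache this way gives bit-for-bit-independent partials, so the reduce result matches a single pass over the full cache up to floating-point rounding.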
Structs§
- FlashAttnVecParams - Parameters for the flash attention vector kernel.
Statics§
- FLASH_ATTN_VEC_SHADER_SOURCE - MSL source for the flash attention vector kernel (embedded at compile time).
Functions§
- flash_attn_vec - Dispatch the flash attention vector kernel on the GPU.
- register - Register the flash attention vector shader source with the given kernel registry.
- tmp_buffer_bytes - Compute the size in bytes of the temporary buffer needed for flash_attn_vec.