Module flash_attn_vec


Flash attention vector kernel dispatch — SIMD-vectorized decode-path SDPA.

Ported from llama.cpp’s flash_attn_ext_vec kernel. This replaces the naive SDPA kernel with a workgroup-parallel implementation: the KV cache is split across nwg workgroups, each of which computes a partial softmax result over its slice, and a reduce kernel then combines the partials into the final output.

This kernel is optimized for the decode path (seq_len=1) with F32 Q/K/V.
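The split-KV scheme above relies on the fact that partial softmax results can be merged exactly after the fact. A hedged CPU-side sketch of that math (names and the scalar value dimension are illustrative, not the crate's actual API; the real kernel operates on GPU buffers):

```rust
/// Partial softmax state produced by one "workgroup" over its KV slice.
struct Partial {
    m: f32, // max score seen in this slice
    l: f32, // sum of exp(score - m)
    o: f32, // unnormalized weighted sum of values (scalar for brevity)
}

/// Phase 1: each workgroup computes a partial result over its slice.
fn partial_sdpa(scores: &[f32], values: &[f32]) -> Partial {
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let (mut l, mut o) = (0.0, 0.0);
    for (&s, &v) in scores.iter().zip(values) {
        let e = (s - m).exp();
        l += e;
        o += e * v;
    }
    Partial { m, l, o }
}

/// Phase 2: the reduce step rescales each partial to a common max
/// and combines them — numerically identical to a single-pass softmax.
fn reduce(parts: &[Partial]) -> f32 {
    let m = parts.iter().fold(f32::NEG_INFINITY, |a, p| a.max(p.m));
    let (mut l, mut o) = (0.0, 0.0);
    for p in parts {
        let scale = (p.m - m).exp();
        l += p.l * scale;
        o += p.o * scale;
    }
    o / l
}

fn main() {
    let scores = [0.5_f32, -1.0, 2.0, 0.0];
    let values = [1.0_f32, 2.0, 3.0, 4.0];

    // Reference: one partial over the whole KV range.
    let full = reduce(&[partial_sdpa(&scores, &values)]);

    // Split across two "workgroups", then reduce.
    let split = reduce(&[
        partial_sdpa(&scores[..2], &values[..2]),
        partial_sdpa(&scores[2..], &values[2..]),
    ]);

    assert!((full - split).abs() < 1e-6);
}
```

The rescale by `exp(p.m - m)` is what makes the merge stable: each partial's exponentials were taken relative to its local max, so shifting them to the global max before summing reproduces the single-pass result without overflow.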

Structs§

FlashAttnVecParams
Parameters for the flash attention vector kernel.

Statics§

FLASH_ATTN_VEC_SHADER_SOURCE
MSL source for the flash attention vector kernel (embedded at compile time).

Functions§

flash_attn_vec
Dispatch flash attention vector kernel on the GPU.
register
Register flash attention vector shader source with the given kernel registry.
tmp_buffer_bytes
Compute the size in bytes of the temporary buffer needed for flash_attn_vec.