pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize
Compute the size in bytes of the temporary buffer needed for flash_attn_vec.
The temp buffer stores partial results from NWG workgroups:
- nrows * head_dim * NWG floats for the partial output vectors
- nrows * 2 * NWG floats for the S and M values
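
The two terms above can be combined into a single byte count. The following is a minimal sketch of that arithmetic, not the crate's actual implementation. It assumes that nrows equals num_heads and that each float is a 4-byte f32; neither is stated on this page. The name tmp_buffer_bytes_sketch and the extra nwg parameter are hypothetical, introduced only because the value of NWG is not given here.

```rust
/// Hypothetical sketch of the size calculation described above.
/// Assumptions (not confirmed by the documentation):
///   - `nrows` is equal to `num_heads`
///   - each float is an `f32` (4 bytes)
///   - `nwg` stands in for the NWG constant, whose real value is not stated
fn tmp_buffer_bytes_sketch(num_heads: u32, head_dim: u32, nwg: u32) -> usize {
    let nrows = num_heads as usize;
    let head_dim = head_dim as usize;
    let nwg = nwg as usize;

    // Partial output vectors: one head_dim-wide vector per row per workgroup.
    let partial_out_floats = nrows * head_dim * nwg;
    // S and M values: two scalars per row per workgroup.
    let s_and_m_floats = nrows * 2 * nwg;

    (partial_out_floats + s_and_m_floats) * std::mem::size_of::<f32>()
}
```

Under these assumptions, 8 heads with a head dimension of 128 and an NWG of 32 would need (8 * 128 * 32 + 8 * 2 * 32) * 4 = 133120 bytes.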