

Function tmp_buffer_bytes 

pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize

Compute the size in bytes of the temporary buffer needed for flash_attn_vec.

The temp buffer stores partial results from NWG workgroups:

  • nrows * head_dim * NWG floats for the partial output vectors
  • nrows * 2 * NWG floats for the S (running softmax sum) and M (running maximum) values
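The two terms above can be combined into a single size computation. The following is a minimal sketch, not the crate's actual implementation. It rests on three assumptions that this documentation does not state: nrows equals num_heads (one query row per head), NWG is a hypothetical constant set to 32, and each "float" is a 4-byte f32. The name tmp_buffer_bytes_sketch is hypothetical.

```rust
/// Hypothetical workgroup count. The real NWG value is not given in this
/// documentation; 32 is an assumption for illustration only.
pub const NWG: usize = 32;

/// Sketch of the temp-buffer size, assuming nrows == num_heads and f32 elements.
pub fn tmp_buffer_bytes_sketch(num_heads: u32, head_dim: u32) -> usize {
    let nrows = num_heads as usize; // assumption: one query row per head
    let head_dim = head_dim as usize;

    // Partial output vectors: head_dim floats per row, per workgroup.
    let output_floats = nrows * head_dim * NWG;
    // Softmax statistics: one S and one M value per row, per workgroup.
    let sm_floats = nrows * 2 * NWG;

    (output_floats + sm_floats) * std::mem::size_of::<f32>()
}
```

Under these assumptions, 32 heads with head_dim = 128 need (32 * 128 * 32 + 32 * 2 * 32) * 4 = 532480 bytes. The S/M term is small next to the output term, growing with 2 rather than head_dim per row.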