pub struct SdpaParams {
    pub n_heads: u32,
    pub n_kv_heads: u32,
    pub head_dim: u32,
    pub seq_len: u32,
    pub kv_seq_len: u32,
    pub scale: f32,
    pub kv_capacity: u32,
}
Parameters for the SDPA (scaled dot-product attention) kernel.
These describe the tensor shapes and head configuration for the attention computation.
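The constraints implied by these fields can be made concrete with a small sanity check. The sketch below mirrors the struct locally so it compiles on its own. `check_params` is a hypothetical helper, and the rules it enforces (non-zero dimensions, n_heads divisible by n_kv_heads, a non-zero kv_capacity no smaller than kv_seq_len) are assumptions drawn from the field descriptions, not documented kernel requirements.

```rust
/// Local mirror of `SdpaParams` so this sketch compiles on its own.
#[derive(Clone, Debug)]
pub struct SdpaParams {
    pub n_heads: u32,
    pub n_kv_heads: u32,
    pub head_dim: u32,
    pub seq_len: u32,
    pub kv_seq_len: u32,
    pub scale: f32,
    pub kv_capacity: u32,
}

/// Hypothetical pre-dispatch check. The invariants below are assumptions
/// inferred from the field docs, not rules the kernel is known to enforce.
pub fn check_params(p: &SdpaParams) -> Result<(), String> {
    if p.n_heads == 0 || p.n_kv_heads == 0 || p.head_dim == 0 {
        return Err("head counts and head_dim must be non-zero".to_string());
    }
    // Assumed GQA constraint: every KV head serves the same number of query heads.
    if p.n_heads % p.n_kv_heads != 0 {
        return Err(format!(
            "n_heads ({}) must be a multiple of n_kv_heads ({})",
            p.n_heads, p.n_kv_heads
        ));
    }
    // kv_capacity == 0 means "use kv_seq_len", so only a non-zero value is checked.
    if p.kv_capacity != 0 && p.kv_capacity < p.kv_seq_len {
        return Err(format!(
            "kv_capacity ({}) is smaller than kv_seq_len ({})",
            p.kv_capacity, p.kv_seq_len
        ));
    }
    Ok(())
}
```

A caller might run a check like this before dispatching the kernel, since a mismatched head count or an undersized capacity could otherwise show up as wrong memory offsets rather than as a clear error.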
Fields
n_heads: u32
Number of query attention heads (e.g. 16 for Gemma 4).
n_kv_heads: u32
Number of key/value attention heads. May be smaller than n_heads under grouped-query attention (GQA).
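Under grouped-query attention, several query heads share one key/value head. A minimal sketch of the usual mapping, assuming contiguous grouping and that n_kv_heads divides n_heads (the kernel's actual indexing is not documented here); `kv_head_for` is a hypothetical helper name:

```rust
/// Map a query head to the KV head it reads under grouped-query attention.
/// Assumes contiguous grouping: query heads 0..g-1 use KV head 0,
/// heads g..2g-1 use KV head 1, and so on, where g = n_heads / n_kv_heads.
pub fn kv_head_for(q_head: u32, n_heads: u32, n_kv_heads: u32) -> u32 {
    assert!(
        n_kv_heads > 0 && n_heads % n_kv_heads == 0,
        "n_heads must be a multiple of n_kv_heads"
    );
    assert!(q_head < n_heads, "q_head out of range");
    let group_size = n_heads / n_kv_heads; // query heads per KV head
    q_head / group_size
}
```

With n_heads = 16 and n_kv_heads = 8, query heads 0 and 1 both read KV head 0. Setting n_kv_heads = n_heads reduces this to ordinary multi-head attention, and n_kv_heads = 1 to multi-query attention.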
head_dim: u32
Dimension of each attention head.
seq_len: u32
Query sequence length.
kv_seq_len: u32
Key/value sequence length (may differ from seq_len in decode mode).
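Because SdpaParams implements Clone, a caller can keep one template holding the model's head configuration and derive per-call parameters from it. The sketch below assumes the common convention that a decode step issues a single query after appending that token's key/value to the cache, so kv_seq_len is one more than the number of previously cached positions. `prefill` and `decode_step` are hypothetical helpers, and the struct is mirrored locally.

```rust
/// Local mirror of `SdpaParams` so this sketch compiles on its own.
#[derive(Clone)]
pub struct SdpaParams {
    pub n_heads: u32,
    pub n_kv_heads: u32,
    pub head_dim: u32,
    pub seq_len: u32,
    pub kv_seq_len: u32,
    pub scale: f32,
    pub kv_capacity: u32,
}

/// Hypothetical: prefill over a prompt, where every prompt token is a query
/// and the keys/values cover the same tokens, so seq_len == kv_seq_len.
pub fn prefill(base: &SdpaParams, prompt_len: u32) -> SdpaParams {
    SdpaParams { seq_len: prompt_len, kv_seq_len: prompt_len, ..base.clone() }
}

/// Hypothetical: one decode step. A single query attends over the
/// `cached_len` earlier positions plus its own freshly appended entry
/// (assumed convention), so kv_seq_len = cached_len + 1.
pub fn decode_step(base: &SdpaParams, cached_len: u32) -> SdpaParams {
    SdpaParams { seq_len: 1, kv_seq_len: cached_len + 1, ..base.clone() }
}
```

The head configuration, scale, and kv_capacity carry over unchanged from the template. Only the two length fields vary between calls.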
scale: f32
Attention score scaling factor. Typically 1.0 / sqrt(head_dim), but
models like Gemma 4 (which use QK norms) require scale = 1.0.
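Choosing scale according to the rule above might look like the following. `uses_qk_norm` is a hypothetical flag the caller would set from the model configuration; it is not a field of SdpaParams.

```rust
/// Pick the attention score scale: 1/sqrt(head_dim) by default, or 1.0 for
/// models that apply QK norms. `uses_qk_norm` is a hypothetical caller-side flag.
pub fn attention_scale(head_dim: u32, uses_qk_norm: bool) -> f32 {
    if uses_qk_norm {
        1.0
    } else {
        1.0 / (head_dim as f32).sqrt()
    }
}
```

For head_dim = 64 the default gives 0.125; a QK-norm model passes 1.0 regardless of head_dim.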
kv_capacity: u32
KV cache capacity: the stride, in positions, between KV heads in the
cache buffer. When the KV cache is pre-allocated to a fixed capacity
larger than kv_seq_len, set this to the capacity so the kernel reads
the correct memory offsets. When the KV buffers are tightly packed (no
extra capacity), set it equal to kv_seq_len. The default, 0, means “use
kv_seq_len as capacity”, for backwards compatibility.
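The fallback rule and the per-head stride can be written out as offset arithmetic. This sketch assumes each K (or V) cache buffer is laid out row-major as [n_kv_heads][capacity][head_dim]. The description above only states that heads are `capacity` positions apart, so treat the innermost layout as an assumption. `effective_kv_capacity` and `kv_offset` are hypothetical helpers.

```rust
/// Stride actually used between KV heads: 0 falls back to kv_seq_len,
/// following the backwards-compatibility rule documented for `kv_capacity`.
pub fn effective_kv_capacity(kv_capacity: u32, kv_seq_len: u32) -> u32 {
    if kv_capacity == 0 { kv_seq_len } else { kv_capacity }
}

/// Element offset of (kv_head, pos, 0) in a K or V cache buffer, assuming a
/// row-major [n_kv_heads][capacity][head_dim] layout (assumed, not documented).
pub fn kv_offset(kv_head: u32, pos: u32, head_dim: u32, kv_capacity: u32, kv_seq_len: u32) -> usize {
    let cap = effective_kv_capacity(kv_capacity, kv_seq_len) as usize;
    debug_assert!((pos as usize) < cap, "position beyond cache capacity");
    // Skip `kv_head` whole head slabs of `cap` positions, then `pos` rows.
    (kv_head as usize * cap + pos as usize) * head_dim as usize
}
```

With head_dim = 128 and kv_seq_len = 10, KV head 1 starts at element 1280 in a packed buffer (kv_capacity = 0) but at element 524288 when the cache is pre-allocated to 4096 positions. Passing the wrong capacity therefore makes every head after the first read from the wrong place.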
Trait Implementations
impl Clone for SdpaParams
fn clone(&self) -> SdpaParams
Returns a duplicate of the value.
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source. Stable since Rust 1.0.0.