pub struct FlashDecodingDescriptor {
    pub batch_size: i32,
    pub num_heads: i32,
    pub num_kv_heads: i32,
    pub k_len: i32,
    pub head_dim: i32,
    pub scale: f32,
    pub element: ElementKind,
}
Descriptor for a FlashDecoding op.
num_kv_heads is the GQA grouping signal. When it equals num_heads, the workload is full MHA; when it is smaller (e.g. 8 for Llama 3 8B, where H_q = 32), each K/V head is shared by group_size = num_heads / num_kv_heads Q heads. The launcher uses group_size to choose between the warp-cooperative SIMT kernel (Tier-1) and the GQA-batched WMMA kernel (Tier-2). Tier-2 is used only when group_size ≥ 4 and head_dim is a multiple of 16.
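The grouping rules above can be sketched as plain functions. This is a minimal illustration, not the launcher's code: `AttentionKind`, `KernelTier`, `classify`, `kv_head_for`, and `select_tier` are hypothetical names. Two assumptions are made explicit: `kv_head_for` assumes the common layout in which consecutive Q heads share one K/V head, and `select_tier` treats the documented Tier-2 gate as sufficient as well as necessary.

```rust
/// Hypothetical label for the three head layouts described above.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AttentionKind {
    Mha,
    Mqa,
    Gqa,
}

/// Hypothetical stand-ins for the two kernel tiers named in the docs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum KernelTier {
    /// Tier-1: warp-cooperative SIMT kernel.
    Simt,
    /// Tier-2: GQA-batched WMMA kernel.
    GqaWmma,
}

/// Classifies the head layout, rejecting a num_kv_heads that does not
/// evenly divide num_heads.
fn classify(num_heads: i32, num_kv_heads: i32) -> Result<AttentionKind, String> {
    if num_kv_heads <= 0 || num_heads % num_kv_heads != 0 {
        return Err(format!(
            "num_kv_heads ({num_kv_heads}) must evenly divide num_heads ({num_heads})"
        ));
    }
    Ok(if num_kv_heads == num_heads {
        AttentionKind::Mha
    } else if num_kv_heads == 1 {
        AttentionKind::Mqa
    } else {
        AttentionKind::Gqa
    })
}

/// ASSUMPTION: Q head h reads K/V head h / group_size (consecutive Q heads
/// share a K/V head). The real kernels may index differently.
fn kv_head_for(q_head: i32, num_heads: i32, num_kv_heads: i32) -> i32 {
    let group_size = num_heads / num_kv_heads;
    q_head / group_size
}

/// Tier-2 requires group_size >= 4 and head_dim % 16 == 0; everything else
/// falls back to Tier-1. (ASSUMPTION: the gate is also sufficient.)
fn select_tier(num_heads: i32, num_kv_heads: i32, head_dim: i32) -> KernelTier {
    let group_size = num_heads / num_kv_heads;
    if group_size >= 4 && head_dim % 16 == 0 {
        KernelTier::GqaWmma
    } else {
        KernelTier::Simt
    }
}

fn main() {
    // Llama 3 8B: H_q = 32, H_kv = 8, so group_size = 4.
    println!("{:?}", classify(32, 8)); // Ok(Gqa)
    println!("{}", kv_head_for(13, 32, 8)); // 3
    println!("{:?}", select_tier(32, 8, 128)); // GqaWmma
    println!("{:?}", select_tier(32, 32, 128)); // Simt (pure MHA)
}
```

Pure MHA always lands on Tier-1 in this sketch because its group_size is 1; the WMMA path only pays off when several Q heads can be batched against the same K/V tile.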
Fields
batch_size: i32
Batch size (B).
num_heads: i32
Number of query/output heads (H_q).
num_kv_heads: i32
Number of K/V heads (H_kv). Must divide num_heads evenly. num_kv_heads == num_heads → pure MHA; num_kv_heads == 1 → MQA; 1 < num_kv_heads < num_heads → GQA.
k_len: i32
K/V sequence length: the full attended prefix, not just the new step. Any length is accepted; the split-K factor adapts to it via CHUNK_K.
head_dim: i32
Per-head feature dimension. d_q == d_k == d_v is enforced, because the decode regime does not justify the d_k != d_v complication that the prefill kernel handles.
scale: f32
Score scaling factor, typically 1.0 / sqrt(head_dim).
element: ElementKind
Element type; must match the plan's type parameter.
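The k_len note says the split-K factor adapts to the sequence length via CHUNK_K. The CPU sketch below shows the standard FlashDecoding split-K reduction for a single query head: the K/V prefix is cut into ceil(k_len / CHUNK_K) chunks, each chunk yields a partial softmax state (max score m, normalizer l, unnormalized output o), and the partials are merged by rescaling each with exp(m_i - m). CHUNK_K = 64 is a hypothetical value and every function name here is illustrative; the crate's GPU kernels are not shown in these docs.

```rust
// HYPOTHETICAL chunk length; the crate's real CHUNK_K value is not stated.
const CHUNK_K: usize = 64;

/// Partial softmax state of one chunk: max score m, normalizer
/// l = sum(exp(s - m)), and unnormalized output o = sum(exp(s - m) * v).
struct Partial {
    m: f32,
    l: f32,
    o: Vec<f32>,
}

/// Number of split-K chunks: ceil(k_len / CHUNK_K).
fn num_splits(k_len: usize) -> usize {
    (k_len + CHUNK_K - 1) / CHUNK_K
}

/// Scaled dot-product attention of one query over a slice of K/V rows.
fn chunk_partial(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], scale: f32) -> Partial {
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| scale * q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>())
        .collect();
    let m = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let (mut l, mut o) = (0.0f32, vec![0.0f32; q.len()]);
    for (s, v) in scores.iter().zip(values) {
        let p = (s - m).exp();
        l += p;
        for (oi, vi) in o.iter_mut().zip(v) {
            *oi += p * vi;
        }
    }
    Partial { m, l, o }
}

/// Split-K decode: chunks are independent; merge by rescaling with exp(m_i - m).
fn split_k_decode(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], scale: f32) -> Vec<f32> {
    let partials: Vec<Partial> = (0..num_splits(keys.len()))
        .map(|c| {
            let (lo, hi) = (c * CHUNK_K, ((c + 1) * CHUNK_K).min(keys.len()));
            chunk_partial(q, &keys[lo..hi], &values[lo..hi], scale)
        })
        .collect();
    let m = partials.iter().map(|p| p.m).fold(f32::NEG_INFINITY, f32::max);
    let (mut l, mut o) = (0.0f32, vec![0.0f32; q.len()]);
    for p in &partials {
        let w = (p.m - m).exp();
        l += w * p.l;
        for (oi, pi) in o.iter_mut().zip(&p.o) {
            *oi += w * pi;
        }
    }
    o.iter().map(|x| x / l).collect()
}

/// Reference: one chunk spanning the whole prefix (ordinary stable softmax).
fn reference_decode(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], scale: f32) -> Vec<f32> {
    let p = chunk_partial(q, keys, values, scale);
    p.o.iter().map(|x| x / p.l).collect()
}

/// Deterministic toy inputs in [-0.5, 0.5), so no RNG crate is needed.
fn make_inputs(k_len: usize, d: usize) -> (Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let val = |a: usize, b: usize| ((a * 7 + b * 3) % 11) as f32 / 11.0 - 0.5;
    let q: Vec<f32> = (0..d).map(|i| val(1, i)).collect();
    let keys: Vec<Vec<f32>> = (0..k_len)
        .map(|t| (0..d).map(|i| val(t, i)).collect::<Vec<f32>>())
        .collect();
    let values: Vec<Vec<f32>> = (0..k_len)
        .map(|t| (0..d).map(|i| val(t + 5, i + 2)).collect::<Vec<f32>>())
        .collect();
    (q, keys, values)
}

fn main() {
    let (k_len, d) = (150, 8); // 150 is not a multiple of CHUNK_K: the last chunk is short
    let (q, keys, values) = make_inputs(k_len, d);
    let scale = 1.0 / (d as f32).sqrt();
    let split = split_k_decode(&q, &keys, &values, scale);
    let reference = reference_decode(&q, &keys, &values, scale);
    let max_err = split.iter().zip(&reference).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
    println!("splits = {}, max |split - reference| = {max_err:e}", num_splits(k_len));
}
```

Because the merge only needs each chunk's (m, l, o), the chunks can run on different thread blocks and be reduced afterwards, which is what lets decode use the whole GPU even with a single query token.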
Implementations
impl FlashDecodingDescriptor
pub fn new(
    batch_size: i32,
    num_heads: i32,
    k_len: i32,
    head_dim: i32,
    element: ElementKind,
) -> Self
Convenience constructor for pure MHA (num_kv_heads == num_heads)
with the standard 1.0 / sqrt(head_dim) scale.
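One plausible way for the two convenience constructors to relate is for `new` to forward to `new_gqa` with num_kv_heads = num_heads, and for the shared path to fill in scale = 1.0 / sqrt(head_dim). The sketch below assumes that structure on a local mirror of the documented fields; it is not the crate's source, and `ElementKind` here is a hypothetical stand-in enum because the real variants are not listed on this page.

```rust
/// HYPOTHETICAL stand-in; the crate's real ElementKind variants are not shown.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ElementKind {
    F16,
    F32,
}

/// Local mirror of the documented fields.
#[derive(Debug, Clone, PartialEq)]
struct FlashDecodingDescriptor {
    batch_size: i32,
    num_heads: i32,
    num_kv_heads: i32,
    k_len: i32,
    head_dim: i32,
    scale: f32,
    element: ElementKind,
}

impl FlashDecodingDescriptor {
    /// Pure MHA: every Q head gets its own K/V head.
    fn new(batch_size: i32, num_heads: i32, k_len: i32, head_dim: i32, element: ElementKind) -> Self {
        Self::new_gqa(batch_size, num_heads, num_heads, k_len, head_dim, element)
    }

    /// Shared path; sets the standard 1 / sqrt(head_dim) scale.
    fn new_gqa(
        batch_size: i32,
        num_heads: i32,
        num_kv_heads: i32,
        k_len: i32,
        head_dim: i32,
        element: ElementKind,
    ) -> Self {
        Self {
            batch_size,
            num_heads,
            num_kv_heads,
            k_len,
            head_dim,
            scale: 1.0 / (head_dim as f32).sqrt(),
            element,
        }
    }
}

fn main() {
    let d = FlashDecodingDescriptor::new(1, 32, 4096, 128, ElementKind::F16);
    // num_kv_heads == num_heads == 32; scale == 1 / sqrt(128) ≈ 0.0884
    println!("{d:?}");
}
```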
pub fn new_gqa(
    batch_size: i32,
    num_heads: i32,
    num_kv_heads: i32,
    k_len: i32,
    head_dim: i32,
    element: ElementKind,
) -> Self
Convenience constructor for GQA / MQA. num_kv_heads must
divide num_heads.
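The page states that num_kv_heads must divide num_heads but not how a violation is reported. The sketch below assumes the constructor panics via `assert!`; the real crate might instead validate at plan time or surface an error. The struct is a local mirror and `ElementKind` is a hypothetical one-variant stand-in.

```rust
use std::panic;

/// HYPOTHETICAL stand-in for the crate's ElementKind.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ElementKind {
    F16,
}

/// Local mirror of the documented fields.
#[derive(Debug, Clone, PartialEq)]
struct FlashDecodingDescriptor {
    batch_size: i32,
    num_heads: i32,
    num_kv_heads: i32,
    k_len: i32,
    head_dim: i32,
    scale: f32,
    element: ElementKind,
}

impl FlashDecodingDescriptor {
    fn new_gqa(
        batch_size: i32,
        num_heads: i32,
        num_kv_heads: i32,
        k_len: i32,
        head_dim: i32,
        element: ElementKind,
    ) -> Self {
        // ASSUMPTION: invalid groupings panic at construction time.
        assert!(num_kv_heads > 0, "num_kv_heads must be positive");
        assert!(
            num_heads % num_kv_heads == 0,
            "num_kv_heads ({num_kv_heads}) must divide num_heads ({num_heads})"
        );
        let scale = 1.0 / (head_dim as f32).sqrt();
        Self { batch_size, num_heads, num_kv_heads, k_len, head_dim, scale, element }
    }
}

fn main() {
    // Llama 3 8B-style GQA: 32 Q heads over 8 K/V heads.
    let gqa = FlashDecodingDescriptor::new_gqa(1, 32, 8, 8192, 128, ElementKind::F16);
    println!("group size = {}", gqa.num_heads / gqa.num_kv_heads); // 4
    // MQA is the num_kv_heads == 1 special case.
    let mqa = FlashDecodingDescriptor::new_gqa(1, 32, 1, 8192, 128, ElementKind::F16);
    println!("mqa kv heads = {}", mqa.num_kv_heads); // 1
    // 32 % 6 != 0, so this construction panics.
    let bad = panic::catch_unwind(|| FlashDecodingDescriptor::new_gqa(1, 32, 6, 8192, 128, ElementKind::F16));
    println!("6 K/V heads rejected: {}", bad.is_err()); // true
}
```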
pub fn with_scale(self, scale: f32) -> Self
Builder: overrides the default score scale (e.g. for QK-norm models whose scores are scaled by something other than 1.0 / sqrt(head_dim)).
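Since `with_scale` takes `self` by value and returns `Self`, it chains directly off a constructor. A minimal sketch of that consuming-builder shape, on a local mirror trimmed to the fields the example touches; the two-argument `new` is a simplified hypothetical constructor, and the 0.1 override is an arbitrary illustrative value rather than a recommendation for any model.

```rust
/// Trimmed local mirror of the descriptor (only the fields this sketch uses).
#[derive(Debug, Clone, PartialEq)]
struct FlashDecodingDescriptor {
    num_heads: i32,
    head_dim: i32,
    scale: f32,
}

impl FlashDecodingDescriptor {
    /// HYPOTHETICAL simplified constructor with the default 1 / sqrt(head_dim) scale.
    fn new(num_heads: i32, head_dim: i32) -> Self {
        Self { num_heads, head_dim, scale: 1.0 / (head_dim as f32).sqrt() }
    }

    /// Consuming builder: take `self` by value, patch one field, hand it back.
    fn with_scale(mut self, scale: f32) -> Self {
        self.scale = scale;
        self
    }
}

fn main() {
    let default = FlashDecodingDescriptor::new(32, 128);
    // Arbitrary illustrative override, e.g. for a QK-norm model.
    let custom = FlashDecodingDescriptor::new(32, 128).with_scale(0.1);
    println!("default scale = {}, custom scale = {}", default.scale, custom.scale);
}
```

Taking `self` by value keeps the descriptor a plain value type: no `&mut` borrow is needed, and the unmodified original can still be kept around by cloning it first.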
pub fn group_size(&self) -> i32
GQA group size: the number of Q heads sharing each K/V head (num_heads / num_kv_heads).
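Given group_size = num_heads / num_kv_heads from the type-level description, the accessor reduces to one integer division. A sketch on a local mirror trimmed to the two head counts; the `debug_assert!` on divisibility is an assumption of this sketch, not documented behavior.

```rust
/// Trimmed local mirror (only the two head counts).
struct FlashDecodingDescriptor {
    num_heads: i32,
    num_kv_heads: i32,
}

impl FlashDecodingDescriptor {
    /// Q heads per K/V head. The debug_assert is an assumption of this sketch.
    fn group_size(&self) -> i32 {
        debug_assert!(self.num_heads % self.num_kv_heads == 0);
        self.num_heads / self.num_kv_heads
    }
}

fn main() {
    let shapes = [("MHA", 32, 32), ("GQA (Llama 3 8B)", 32, 8), ("MQA", 32, 1)];
    for (name, h_q, h_kv) in shapes {
        let d = FlashDecodingDescriptor { num_heads: h_q, num_kv_heads: h_kv };
        // Prints 1, 4, and 32 respectively.
        println!("{name}: group_size = {}", d.group_size());
    }
}
```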
Trait Implementations
impl Clone for FlashDecodingDescriptor
fn clone(&self) -> FlashDecodingDescriptor
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.