pub struct FlashDecodingPlan<T: Element> { /* private fields */ }
FlashDecoding forward plan (Dao 2023).
Split-K parallel attention decode for seq_q = 1. Replaces both FlashSdpaPlan and FA2 in the decode regime: both of those tile the Q dimension and therefore waste work when seq_q < 64.
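A minimal CPU reference sketch of the split-K scheme follows. It assumes a single (batch, head) pair, f32 standing in for f16/bf16, and keys/values stored as row-major `Vec<Vec<f32>>`. Every function name here is hypothetical and not part of this crate's API. Each split computes its own running max `m`, softmax denominator `l`, and unnormalized output. The combine step rescales every split to the global max (the standard log-sum-exp merge) and normalizes once, so the result matches ordinary softmax attention up to rounding.

```rust
/// Dot product of two equal-length vectors.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Naive single-query attention for one (batch, head): softmax(q·Kᵀ·scale)·V.
fn attention_reference(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], scale: f32) -> Vec<f32> {
    let scores: Vec<f32> = k.iter().map(|kr| dot(q, kr) * scale).collect();
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let w: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
    let l: f32 = w.iter().sum();
    let mut out = vec![0.0f32; q.len()];
    for (wi, vr) in w.iter().zip(v) {
        for (o, x) in out.iter_mut().zip(vr) {
            *o += wi * x;
        }
    }
    out.iter().map(|o| o / l).collect()
}

/// Partial result of one split: local max, local denominator, unnormalized output.
struct Partial {
    m: f32,
    l: f32,
    acc: Vec<f32>,
}

/// "Split" phase: each contiguous chunk of keys is reduced independently.
/// Assumes k.len() >= 1 and num_splits >= 1.
fn split_phase(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], scale: f32, num_splits: usize) -> Vec<Partial> {
    let chunk = (k.len() + num_splits - 1) / num_splits;
    k.chunks(chunk)
        .zip(v.chunks(chunk))
        .map(|(kc, vc)| {
            let scores: Vec<f32> = kc.iter().map(|kr| dot(q, kr) * scale).collect();
            let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let mut l = 0.0f32;
            let mut acc = vec![0.0f32; q.len()];
            for (s, vr) in scores.iter().zip(vc) {
                let p = (s - m).exp();
                l += p;
                for (a, x) in acc.iter_mut().zip(vr) {
                    *a += p * x;
                }
            }
            Partial { m, l, acc }
        })
        .collect()
}

/// "Combine" phase: rescale each split to the global max, then normalize once.
fn combine_phase(parts: &[Partial]) -> Vec<f32> {
    let m = parts.iter().map(|p| p.m).fold(f32::NEG_INFINITY, f32::max);
    let mut l = 0.0f32;
    let mut out = vec![0.0f32; parts[0].acc.len()];
    for p in parts {
        let c = (p.m - m).exp();
        l += c * p.l;
        for (o, a) in out.iter_mut().zip(&p.acc) {
            *o += c * a;
        }
    }
    out.iter().map(|o| o / l).collect()
}
```

In a GPU version each chunk in `split_phase` would presumably map to its own thread block, and the loop over `parts` in `combine_phase` corresponds to the combine kernel's work for one output row.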
When to use: the token loop of autoregressive decoder inference. After the prefill step (which uses FlashSdpaPlan with fa2 for the long initial context), each generated token calls this plan with seq_q = 1 and the full KV cache accumulated so far.
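The decode loop described above can be sketched as follows, assuming a hypothetical per-(batch, head) `KvCache` type and f32 data. For brevity `attend` uses a naive single-pass softmax rather than split-K; the point is the shape pattern. Prefill fills the cache, then every generated token appends one K/V row and runs a seq_q = 1 attention over all K_len cached rows.

```rust
/// Hypothetical per-(batch, head) KV cache; each row holds head_dim values.
struct KvCache {
    head_dim: usize,
    k: Vec<f32>,
    v: Vec<f32>,
}

impl KvCache {
    /// Assumes head_dim >= 1.
    fn new(head_dim: usize) -> Self {
        Self { head_dim, k: Vec::new(), v: Vec::new() }
    }

    /// K_len: number of cached positions.
    fn len(&self) -> usize {
        self.k.len() / self.head_dim
    }

    /// Append one position (one generated or prompt token).
    fn append(&mut self, k_row: &[f32], v_row: &[f32]) {
        assert_eq!(k_row.len(), self.head_dim);
        assert_eq!(v_row.len(), self.head_dim);
        self.k.extend_from_slice(k_row);
        self.v.extend_from_slice(v_row);
    }

    /// One decode step: a single query row (seq_q = 1) attends to every cached row.
    fn attend(&self, q: &[f32]) -> Vec<f32> {
        let d = self.head_dim;
        let scale = 1.0 / (d as f32).sqrt();
        let scores: Vec<f32> = self
            .k
            .chunks(d)
            .map(|kr| q.iter().zip(kr).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();
        let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let w: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
        let l: f32 = w.iter().sum();
        let mut out = vec![0.0f32; d];
        for (wi, vr) in w.iter().zip(self.v.chunks(d)) {
            for (o, x) in out.iter_mut().zip(vr) {
                *o += wi * x / l;
            }
        }
        out
    }
}
```

K_len grows by one per generated token, which is why the decode plan has to accept arbitrary K_len.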
Dtypes: f16 and bf16 (the only dtypes the inference path uses).
Shape limits: head_dim ≤ 128. Arbitrary B, H, K_len.
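The documented limits can be expressed as a pre-flight check. This is a hypothetical sketch: `Dtype`, `ShapeError`, and `check_decode_shape` are illustrative names, not this crate's types. Only dtype and head_dim are constrained by the text above; rejecting zero-sized dimensions is an extra assumption of this sketch.

```rust
/// Hypothetical element-type tag.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Dtype {
    F16,
    Bf16,
    F32,
}

#[derive(Debug, PartialEq)]
enum ShapeError {
    UnsupportedDtype(Dtype),
    HeadDimTooLarge(usize),
    EmptyDim(&'static str),
}

const MAX_HEAD_DIM: usize = 128;

fn check_decode_shape(dtype: Dtype, batch: usize, heads: usize, k_len: usize, head_dim: usize) -> Result<(), ShapeError> {
    // Documented: only f16 and bf16 are supported.
    if !matches!(dtype, Dtype::F16 | Dtype::Bf16) {
        return Err(ShapeError::UnsupportedDtype(dtype));
    }
    // Documented: head_dim <= 128.
    if head_dim > MAX_HEAD_DIM {
        return Err(ShapeError::HeadDimTooLarge(head_dim));
    }
    // Assumption of this sketch: B, H, K_len and head_dim are unbounded above but non-zero.
    for (name, n) in [("batch", batch), ("heads", heads), ("k_len", k_len), ("head_dim", head_dim)] {
        if n == 0 {
            return Err(ShapeError::EmptyDim(name));
        }
    }
    Ok(())
}
```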
Workspace: non-zero. See Self::workspace_size.
Precision guarantee: f32 accumulators throughout both the split and the combine kernels. Deterministic: each output element is written by exactly one block, and no atomicAdd is used.
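Why "no atomicAdd" implies determinism: floating-point addition is not associative, so if several blocks atomically added partial results into one output cell, the final value could depend on which block finished first. The plain-Rust sketch below (a hypothetical `fixed_order_sum`, not kernel code) shows two orderings of the same three f32 values producing different sums; a single writer that always reduces in the same order avoids this.

```rust
/// Reduce per-split partial values in a fixed order (split 0, 1, 2, ...).
/// One owner running this loop always produces the same bits; concurrent
/// atomic adds could apply the same additions in any order.
fn fixed_order_sum(parts: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for &p in parts {
        acc += p;
    }
    acc
}
```

`fixed_order_sum(&[1.0e8, 1.0, -1.0e8])` returns 0.0, because 1.0e8 + 1.0 rounds back to 1.0e8 in f32 (the spacing between adjacent f32 values near 1e8 is 8), while `fixed_order_sum(&[1.0e8, -1.0e8, 1.0])` returns 1.0.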
Implementations
impl<T: Element> FlashDecodingPlan<T>
pub fn select( _stream: &Stream, desc: &FlashDecodingDescriptor, _pref: PlanPreference, ) -> Result<Self>
Pick a kernel for the supplied descriptor.
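The documentation does not say how `select` chooses a kernel, and the leading underscores on `_stream` and `_pref` conventionally mark parameters that are currently unused. One decision a split-K decoder typically has to make is how many KV splits to launch. The heuristic below is a hypothetical illustration of that trade-off, not this crate's logic: it adds splits until the B * H grid covers every SM, but never makes a split shorter than `min_keys_per_split` keys. `num_sms` and `min_keys_per_split` are assumed inputs.

```rust
/// Hypothetical split-count heuristic for split-K decode.
fn choose_num_splits(batch: usize, heads: usize, k_len: usize, num_sms: usize, min_keys_per_split: usize) -> usize {
    // One block per (batch, head) pair before any splitting.
    let work_items = (batch * heads).max(1);
    // Splits needed for the grid to touch every SM at least once (ceiling division).
    let want = (num_sms + work_items - 1) / work_items;
    // Upper bound: keep every split at least `min_keys_per_split` keys long.
    let max_by_len = (k_len / min_keys_per_split.max(1)).max(1);
    want.clamp(1, max_by_len)
}
```

With B = 1, H = 32, K_len = 8192, 108 SMs (an A100, for example) and a 256-key minimum, it picks 4 splits; with B = 64 the (b, h) grid alone already fills the GPU and it picks 1.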
pub fn can_implement(&self, args: &FlashDecodingArgs<'_, T>) -> Result<()>
Validate args against the descriptor.
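A sketch of the kind of checks `can_implement` might perform, assuming contiguous tensors laid out as [B, H, seq, head_dim] with seq_q = 1. `Desc`, `ArgLens`, `validate`, and `expect_len` are hypothetical stand-ins; the real `FlashDecodingDescriptor` and `FlashDecodingArgs` fields are not shown in this documentation.

```rust
/// Hypothetical stand-in for the descriptor's shape fields.
struct Desc {
    batch: usize,
    heads: usize,
    k_len: usize,
    head_dim: usize,
}

/// Hypothetical view of the argument buffers: element counts plus workspace bytes.
#[derive(Clone, Copy)]
struct ArgLens {
    q: usize,
    k: usize,
    v: usize,
    out: usize,
    workspace_bytes: usize,
}

fn expect_len(name: &str, got: usize, want: usize) -> Result<(), String> {
    if got == want {
        Ok(())
    } else {
        Err(format!("{}: expected {}, got {}", name, want, got))
    }
}

/// Check every buffer against the shapes the plan was built for.
fn validate(desc: &Desc, args: &ArgLens, required_workspace: usize) -> Result<(), String> {
    // seq_q = 1, so Q and O hold B * H * head_dim elements.
    let q_elems = desc.batch * desc.heads * desc.head_dim;
    // K and V hold the whole cache: B * H * K_len * head_dim elements.
    let kv_elems = q_elems * desc.k_len;
    expect_len("q", args.q, q_elems)?;
    expect_len("k", args.k, kv_elems)?;
    expect_len("v", args.v, kv_elems)?;
    expect_len("out", args.out, q_elems)?;
    if args.workspace_bytes < required_workspace {
        return Err(format!(
            "workspace: need at least {} bytes, got {}",
            required_workspace, args.workspace_bytes
        ));
    }
    Ok(())
}
```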
pub fn backend(&self) -> BackendKind
Backend chosen by Self::select.
pub fn workspace_size(&self) -> usize
Workspace requirement in bytes for the (split + combine) pipeline.
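The exact layout behind `workspace_size` is not documented. A plausible accounting for a split-K pipeline with f32 accumulators, offered here as an assumption, reserves per (batch, head, split) one f32 partial output of head_dim elements plus one f32 running max and one f32 softmax denominator for the combine kernel to read. The function name is hypothetical, and a real implementation may add padding or alignment.

```rust
/// Hypothetical workspace size for a split-K decode pipeline (f32 partials).
fn split_k_workspace_bytes(batch: usize, heads: usize, num_splits: usize, head_dim: usize) -> usize {
    // Per split: head_dim partial outputs + running max + softmax denominator.
    let floats_per_split = head_dim + 2;
    batch * heads * num_splits * floats_per_split * std::mem::size_of::<f32>()
}
```

For B = 1, H = 32, 4 splits and head_dim = 128 this gives 32 × 4 × 130 × 4 = 66,560 bytes.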