Struct FlashDecodingPlan 

pub struct FlashDecodingPlan<T: Element> { /* private fields */ }

FlashDecoding forward plan (Dao 2023).

Split-K parallel attention decode for seq_q = 1. In the decode regime it replaces both FlashSdpaPlan and FA2: both of those tile the Q dimension and so waste work when seq_q < 64.
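
For one (batch, head) pair, the split-K decomposition can be written out as below. This is a sketch of the standard Flash-Decoding formulation, assuming the usual 1/√d softmax scale with d = head_dim; the symbols (query q, keys k_j, values v_j, S splits over key chunks C_s) are introduced here for illustration and need not match the kernel's internal naming. Each split computes a local max, a local exponent sum, and a locally normalized output; the combine step rescales all splits to a common max.

```latex
\begin{aligned}
x_j &= \frac{q \cdot k_j}{\sqrt{d}}, \qquad j = 1, \dots, K_{\text{len}} \\
m_s &= \max_{j \in C_s} x_j, \qquad
\ell_s = \sum_{j \in C_s} e^{x_j - m_s}, \qquad
o_s = \frac{1}{\ell_s} \sum_{j \in C_s} e^{x_j - m_s}\, v_j \\
m &= \max_{s} m_s, \qquad
\ell = \sum_{s=1}^{S} e^{m_s - m}\, \ell_s, \qquad
o = \frac{1}{\ell} \sum_{s=1}^{S} e^{m_s - m}\, \ell_s\, o_s
\end{aligned}
```

Expanding the last line gives o = Σ_j e^{x_j − m} v_j / Σ_j e^{x_j − m}, which is exactly softmax(x)·V, so the result does not depend on how K_len is partitioned; only the parallelism does.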

When to use: the per-token loop of autoregressive decoder inference. After the prefill step (which uses FlashSdpaPlan with FA2 for the long initial context), each generated token calls this plan with seq_q = 1 and the full, grown KV cache.
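
The decode regime can be illustrated with a toy single-head KV cache. This is a CPU sketch, not this crate's API: KvCache, decode_attention and decode_loop are hypothetical names, attention is computed naively in f32 with an assumed 1/√d scale, and the new token's K/V are appended before it attends (a common convention, assumed here). It shows the access pattern this plan targets: one query vector per step against a cache whose K_len grows by one per token.

```rust
/// Toy single-head KV cache (hypothetical type, not part of this crate).
struct KvCache {
    head_dim: usize,
    keys: Vec<f32>,   // [K_len, head_dim], row-major
    values: Vec<f32>, // [K_len, head_dim], row-major
}

impl KvCache {
    fn new(head_dim: usize) -> Self {
        Self { head_dim, keys: Vec::new(), values: Vec::new() }
    }

    /// Current number of cached positions (K_len).
    fn len(&self) -> usize {
        self.keys.len() / self.head_dim
    }

    fn push(&mut self, k: &[f32], v: &[f32]) {
        assert_eq!(k.len(), self.head_dim);
        assert_eq!(v.len(), self.head_dim);
        self.keys.extend_from_slice(k);
        self.values.extend_from_slice(v);
    }
}

/// Naive single-query (seq_q = 1) attention over the whole cache, in f32.
/// Assumes a non-empty cache.
fn decode_attention(q: &[f32], cache: &KvCache) -> Vec<f32> {
    let d = cache.head_dim;
    let scale = 1.0 / (d as f32).sqrt();
    let scores: Vec<f32> = (0..cache.len())
        .map(|j| scale * q.iter().zip(&cache.keys[j * d..(j + 1) * d]).map(|(a, b)| a * b).sum::<f32>())
        .collect();
    // Numerically stable softmax: subtract the max before exponentiating.
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
    let l: f32 = weights.iter().sum();
    let mut out = vec![0.0f32; d];
    for (j, w) in weights.iter().enumerate() {
        for (o, v) in out.iter_mut().zip(&cache.values[j * d..(j + 1) * d]) {
            *o += w * v;
        }
    }
    out.iter_mut().for_each(|o| *o /= l);
    out
}

/// Prefill `prefill_len` positions, then run `steps` decode steps.
/// Returns the K_len that each decode step attended over.
fn decode_loop(head_dim: usize, prefill_len: usize, steps: usize) -> Vec<usize> {
    let mut cache = KvCache::new(head_dim);
    // Prefill: the initial context fills the cache (FlashSdpaPlan's job in the real pipeline).
    for t in 0..prefill_len {
        let x = t as f32 * 0.1;
        cache.push(&vec![x; head_dim], &vec![x; head_dim]);
    }
    let mut k_lens = Vec::with_capacity(steps);
    for step in 0..steps {
        // Append the new token's K/V first, then attend with a single query.
        let x = (prefill_len + step) as f32 * 0.1;
        cache.push(&vec![x; head_dim], &vec![x; head_dim]);
        let q = vec![0.1f32; head_dim];
        let _out = decode_attention(&q, &cache);
        k_lens.push(cache.len());
    }
    k_lens
}
```

With an 8-position prefill, three decode steps attend over K_len = 9, 10 and 11; the query stays a single row while the key side grows, which is why tiling over Q buys nothing here.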

Dtypes: f16 and bf16 (the only dtypes used on the inference path).

Shape limits: head_dim ≤ 128. B, H and K_len are unrestricted.
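
A host-side check of the documented limits might look like the sketch below. DecodeShape, ShapeError and check_shape are hypothetical names for illustration only; the fields of the real FlashDecodingDescriptor are not shown on this page, and the plan's own validation entry point is can_implement.

```rust
/// Hypothetical shape record for illustration; not this crate's FlashDecodingDescriptor.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
struct DecodeShape {
    batch: usize,    // B: unrestricted
    heads: usize,    // H: unrestricted
    seq_q: usize,    // must be 1 in the decode regime
    k_len: usize,    // K_len: unrestricted
    head_dim: usize, // must be <= 128
}

#[derive(Debug, PartialEq)]
enum ShapeError {
    SeqQNotOne(usize),
    HeadDimTooLarge(usize),
}

const MAX_HEAD_DIM: usize = 128;

/// Checks only the limits stated in the docs: seq_q == 1 and head_dim <= 128.
fn check_shape(s: &DecodeShape) -> Result<(), ShapeError> {
    if s.seq_q != 1 {
        return Err(ShapeError::SeqQNotOne(s.seq_q));
    }
    if s.head_dim > MAX_HEAD_DIM {
        return Err(ShapeError::HeadDimTooLarge(s.head_dim));
    }
    Ok(())
}
```

A prefill-shaped request (seq_q = 64) or a head_dim of 256 is rejected, while a very long K_len passes, matching the statement that only head_dim is capped.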

Workspace: non-zero. See Self::workspace_size.
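
The workspace is non-zero because the split kernel has to hand its per-split partial results (a running max, a running exponent sum and a partial output) to the combine kernel. The sketch below estimates that scratch space under an assumed layout: one f32 partial output of length head_dim plus one f32 max and one f32 sum per (batch, head, split). The function name and the layout are hypothetical; real allocations should be sized with Self::workspace_size.

```rust
/// Hypothetical estimate of split-K scratch space, assuming each
/// (batch, head, split) stores an f32 partial output of length head_dim,
/// plus one f32 running max and one f32 running sum.
/// Size real allocations with FlashDecodingPlan::workspace_size() instead.
fn estimate_workspace_bytes(batch: usize, heads: usize, num_splits: usize, head_dim: usize) -> usize {
    let f32s_per_split = head_dim + 2; // partial output + max + sum
    batch * heads * num_splits * f32s_per_split * std::mem::size_of::<f32>()
}
```

Under this assumed layout the scratch grows with the number of splits rather than directly with K_len; for example, B = 1, H = 32, 16 splits and head_dim = 128 come to 266,240 bytes.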

Precision guarantee: f32 accumulators throughout both the split and the combine kernels. Deterministic: each output element is written by exactly one block, with no atomicAdd.
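
The split and combine phases can be sketched on the CPU for one (batch, head) pair. This is a minimal single-threaded illustration, not this crate's kernels: Partial, split_attend, combine and flash_decode are hypothetical names, inputs are plain f32 slices rather than f16/bf16 tensors, and the 1/√d scale is assumed. It mirrors the two properties stated above: all accumulation is done in f32, and the combine walks the splits in a fixed order and writes each output element exactly once, so repeated runs give bit-identical results.

```rust
use std::ops::Range;

/// Per-split softmax state, kept entirely in f32.
struct Partial {
    m: f32,      // max scaled score within the split
    l: f32,      // sum of exp(score - m) within the split
    o: Vec<f32>, // softmax-weighted sum of V rows, normalized by l
}

/// Split phase: attend one query to keys[range] only.
/// `keys` and `values` are row-major [K_len, d]; `range` must be non-empty.
fn split_attend(q: &[f32], keys: &[f32], values: &[f32], range: Range<usize>) -> Partial {
    let d = q.len();
    let scale = 1.0 / (d as f32).sqrt();
    let scores: Vec<f32> = range
        .clone()
        .map(|j| scale * q.iter().zip(&keys[j * d..(j + 1) * d]).map(|(a, b)| a * b).sum::<f32>())
        .collect();
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut l = 0.0f32;
    let mut o = vec![0.0f32; d];
    for (s, j) in scores.iter().zip(range) {
        let w = (s - m).exp();
        l += w;
        for (acc, v) in o.iter_mut().zip(&values[j * d..(j + 1) * d]) {
            *acc += w * v;
        }
    }
    o.iter_mut().for_each(|x| *x /= l);
    Partial { m, l, o }
}

/// Combine phase: rescale every split to the global max and merge them in a
/// fixed order. Each output element is written exactly once (no atomics).
fn combine(parts: &[Partial]) -> Vec<f32> {
    let d = parts[0].o.len();
    let m = parts.iter().map(|p| p.m).fold(f32::NEG_INFINITY, f32::max);
    let mut l = 0.0f32;
    let mut out = vec![0.0f32; d];
    for p in parts {
        let weight = (p.m - m).exp() * p.l;
        l += weight;
        for (acc, x) in out.iter_mut().zip(&p.o) {
            *acc += weight * x;
        }
    }
    out.iter_mut().for_each(|x| *x /= l);
    out
}

/// Split-K decode for one (batch, head): cut K_len into `num_splits` chunks.
/// Assumes K_len >= 1 and num_splits >= 1.
fn flash_decode(q: &[f32], keys: &[f32], values: &[f32], num_splits: usize) -> Vec<f32> {
    let k_len = keys.len() / q.len();
    let chunk = (k_len + num_splits - 1) / num_splits; // ceil(K_len / num_splits)
    let parts: Vec<Partial> = (0..k_len)
        .step_by(chunk)
        .map(|start| split_attend(q, keys, values, start..(start + chunk).min(k_len)))
        .collect();
    combine(&parts)
}
```

Any num_splits from 1 upward should agree with the single-split result up to f32 rounding. On a GPU each split would map to its own thread block, which is where the parallelism over K_len comes from when seq_q = 1 leaves nothing to parallelize on the Q side.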

Implementations

impl<T: Element> FlashDecodingPlan<T>

pub fn select(
    _stream: &Stream,
    desc: &FlashDecodingDescriptor,
    _pref: PlanPreference,
) -> Result<Self>

Pick a kernel for the supplied descriptor.

pub fn can_implement(&self, args: &FlashDecodingArgs<'_, T>) -> Result<()>

Validate args against the descriptor.

pub fn backend(&self) -> BackendKind

Backend selected by select.

pub fn sku(&self) -> &KernelSku

Kernel SKU descriptor.

pub fn workspace_size(&self) -> usize

Workspace requirement in bytes for the (split + combine) pipeline.

pub fn run(
    &self,
    stream: &Stream,
    workspace: Workspace<'_>,
    args: FlashDecodingArgs<'_, T>,
) -> Result<()>

Run the FlashDecoding pipeline (the split kernel followed by the combine kernel).

Auto Trait Implementations

impl<T> Freeze for FlashDecodingPlan<T>

impl<T> RefUnwindSafe for FlashDecodingPlan<T>
where T: RefUnwindSafe,

impl<T> Send for FlashDecodingPlan<T>
where T: Send,

impl<T> Sync for FlashDecodingPlan<T>
where T: Sync,

impl<T> Unpin for FlashDecodingPlan<T>
where T: Unpin,

impl<T> UnsafeUnpin for FlashDecodingPlan<T>

§

impl<T> UnwindSafe for FlashDecodingPlan<T>
where T: UnwindSafe,

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.