pub struct FlashDecodingArgs<'a, T: Element> {
    pub q: TensorRef<'a, T, 3>,
    pub k: TensorRef<'a, T, 4>,
    pub v: TensorRef<'a, T, 4>,
    pub y: TensorMut<'a, T, 3>,
}
Argument bundle for a FlashDecoding launch.
Q is rank-3 because seq_q == 1 is encoded in the descriptor, so there is
no need to thread a unit axis through the API.
K and V take shape [B, H_kv, K_len, D] (the PHYSICAL layout, not the
broadcast-replicated H_q view). The kernel handles the Q→KV head
mapping via integer division: kv_head = q_head / group_size. For
pure MHA the caller passes H_kv == H_q and the same data shape as
before.
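The head mapping described above can be sketched in a few lines. This is an illustration only, not the kernel's code: `kv_head_for` is a hypothetical helper, and it assumes `group_size = H_q / H_kv` with H_q an exact multiple of H_kv, which is the usual grouped-query convention but is not stated on this page.

```rust
/// Hypothetical helper (not part of the documented API): returns the KV head
/// that a given query head reads from.
/// Assumes group_size = H_q / H_kv and that H_q is a multiple of H_kv.
fn kv_head_for(q_head: usize, h_q: usize, h_kv: usize) -> usize {
    assert!(h_kv > 0 && h_q % h_kv == 0, "H_q must be a positive multiple of H_kv");
    assert!(q_head < h_q, "q_head out of range");
    let group_size = h_q / h_kv;
    // Integer division: consecutive query heads share one KV head.
    q_head / group_size
}

fn main() {
    // GQA: 8 query heads share 2 KV heads, so group_size = 4.
    let gqa: Vec<usize> = (0..8).map(|h| kv_head_for(h, 8, 2)).collect();
    println!("GQA mapping: {:?}", gqa);

    // Pure MHA: H_kv == H_q, so group_size = 1 and the mapping is the identity.
    let mha: Vec<usize> = (0..4).map(|h| kv_head_for(h, 4, 4)).collect();
    println!("MHA mapping: {:?}", mha);
}
```

Because the mapping is computed from indices, K and V never need to be replicated to H_q heads in memory, which is why the struct takes the physical H_kv layout.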
Fields§
q: TensorRef<'a, T, 3>
Query tensor, shape [B, H_q, D]. Arbitrary strides are supported via the
supplied stride array; the typical case is contiguous.
k: TensorRef<'a, T, 4>
Key tensor, shape [B, H_kv, K_len, D], physical layout.
v: TensorRef<'a, T, 4>
Value tensor, shape [B, H_kv, K_len, D], physical layout.
y: TensorMut<'a, T, 3>
Output tensor, shape [B, H_q, D].
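The shape relationships between the four fields can be checked before a launch. The sketch below is a hypothetical pre-flight check written against plain shape arrays, since the accessors of `TensorRef` and `TensorMut` are not shown on this page. The function name `check_shapes` and the requirement that H_q be a multiple of H_kv are assumptions.

```rust
/// Hypothetical shape check on plain shape arrays (not the crate's tensor types).
/// Returns the implied group size H_q / H_kv on success.
fn check_shapes(
    q: [usize; 3], // [B, H_q, D]
    k: [usize; 4], // [B, H_kv, K_len, D]
    v: [usize; 4], // [B, H_kv, K_len, D]
    y: [usize; 3], // [B, H_q, D]
) -> Result<usize, String> {
    let [b, h_q, d] = q;
    let [kb, h_kv, _k_len, kd] = k;
    if v != k {
        return Err(format!("V shape {:?} differs from K shape {:?}", v, k));
    }
    if y != q {
        return Err(format!("Y shape {:?} differs from Q shape {:?}", y, q));
    }
    if kb != b || kd != d {
        return Err(format!("K shape {:?} disagrees with Q shape {:?} on B or D", k, q));
    }
    // Assumption: grouped-query attention needs H_q to be a multiple of H_kv.
    if h_kv == 0 || h_q % h_kv != 0 {
        return Err(format!("H_q = {} is not a multiple of H_kv = {}", h_q, h_kv));
    }
    Ok(h_q / h_kv)
}

fn main() {
    // GQA with B = 2, H_q = 8, H_kv = 2, K_len = 128, D = 64.
    let ok = check_shapes([2, 8, 64], [2, 2, 128, 64], [2, 2, 128, 64], [2, 8, 64]);
    println!("{:?}", ok);

    // Output head dimension does not match the query: rejected.
    let bad = check_shapes([2, 8, 64], [2, 2, 128, 64], [2, 2, 128, 64], [2, 8, 32]);
    println!("{:?}", bad);
}
```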
Auto Trait Implementations§
impl<'a, T> !UnwindSafe for FlashDecodingArgs<'a, T>
impl<'a, T> Freeze for FlashDecodingArgs<'a, T>
impl<'a, T> RefUnwindSafe for FlashDecodingArgs<'a, T> where
T: RefUnwindSafe,
impl<'a, T> Send for FlashDecodingArgs<'a, T>
impl<'a, T> Sync for FlashDecodingArgs<'a, T> where
T: Sync,
impl<'a, T> Unpin for FlashDecodingArgs<'a, T>
impl<'a, T> UnsafeUnpin for FlashDecodingArgs<'a, T>
Blanket Implementations§
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.