baracuda_kernels/attention/flash_decoding.rs
1// SPDX-FileCopyrightText: 2026 Eric Evans and the baracuda contributors
2// SPDX-License-Identifier: MIT OR Apache-2.0
3//
4//! FlashDecoding — split-K parallel attention decode for `seq_q = 1`.
5//!
6//! Phase 73 follow-up. Closes the perf gap that both the bespoke
7//! `FlashSdpaPlan` Phase 10 trailblazer AND FA2 leave at the decode
8//! regime, where the seq_q dimension is too short to fill a 64-row
9//! q-tile and most of the GPU sits idle.
10//!
11//! FlashDecoding flips the parallelism axis: split K into chunks of
12//! 256 rows, launch one block per `(b, h, k_split)`, and combine the
13//! per-split online-softmax partials in a second small reduction kernel.
//! For (B=1, H=32, K=2048, D=128) the split kernel launches `1 × 32 × 8
//! = 256` blocks vs the FlashAttention kernel's 32 (`⌈seq_q/64⌉ = 1`
//! q-tile × `H = 32` heads),
16//! and each block does meaningful work instead of being mostly
17//! q-tile padding.
18//!
19//! See `kernels/include/baracuda_flash_decoding.cuh` for the kernel
20//! body. This file wraps it with the standard descriptor / args / plan
21//! triple.
22//!
23//! ## Tier-1 scope
24//!
25//! - dtypes: `f16`, `bf16`. f32 / f64 are decode-uncommon (typical
26//! inference is half-precision).
27//! - `head_dim ∈ [1, 128]`.
28//! - `seq_q == 1` strictly (decode contract — the whole point).
//! - GQA / MQA via `FlashDecodingDescriptor::num_kv_heads`. K and V are
//!   passed in their physical `[B, H_kv, K_len, D]` layout (no stride-0
//!   broadcast view with `shape[1] = H` is needed); the kernel maps each
//!   query head to its K/V head as `kv_head = q_head / group_size`.
33//! - `is_causal` is irrelevant — there's only one query row, and the
34//! caller is responsible for slicing the cache to the prefix it
35//! wants attended to.
36//!
37//! ## Out of scope (deferred)
38//!
39//! - f32 / f64 (no decode workload uses these).
40//! - sliding window / ALiBi / soft-cap — pre-mask the cache.
41//! - BW pass — decode is FW-only.
42//! - Tensor-core MMA inside the Q·K dot product — the first cut uses
43//! warp-shuffle reduce in fp32. A tensor-core retune is the next
44//! follow-up phase once perf bench numbers land.
45//!
46//! ## Workspace
47//!
48//! Non-zero. The split kernel emits per-`(B, H, S)` partials (m, l, o)
49//! into the workspace; the combine kernel reads them and writes the
50//! final output. `workspace_size()` returns
51//! `B * H * num_splits * (2 + head_dim) * sizeof(f32)`.
52
53use core::ffi::c_void;
54use core::marker::PhantomData;
55
56use baracuda_cutlass::{Error, Result};
57use baracuda_driver::Stream;
58use baracuda_kernels_types::{
59 ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
60 OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
61};
62
63use super::map_status;
64
/// Largest per-head feature dimension (`head_dim`) the Tier-1
/// trailblazer kernel is wired for.
pub const FLASH_DECODING_MAX_D: i32 = 128;
/// Number of K/V rows handled by one split along the key axis; the
/// split count for a launch is `ceil(k_len / CHUNK_K)`.
const CHUNK_K: i32 = 256;
68
69/// Descriptor for a FlashDecoding op.
70///
71/// `num_kv_heads` is the GQA grouping signal: when it equals `num_heads`
72/// the workload is full MHA; when it's smaller (e.g. 8 for Llama 3 8B
73/// at H_q=32) every K/V head is shared by `group_size = num_heads /
74/// num_kv_heads` Q heads. The launcher uses `group_size` to pick
75/// between the warp-cooperative SIMT kernel (Tier-1) and the
76/// GQA-batched WMMA kernel (Tier-2, gated on group_size ≥ 4 +
77/// head_dim aligned to 16).
78#[derive(Copy, Clone, Debug)]
79pub struct FlashDecodingDescriptor {
80 /// Batch size (`B`).
81 pub batch_size: i32,
82 /// Number of query / output heads (`H_q`).
83 pub num_heads: i32,
84 /// Number of K/V heads (`H_kv`). Must divide `num_heads` evenly.
85 /// `num_kv_heads == num_heads` → pure MHA. `num_kv_heads == 1` →
86 /// MQA. `num_kv_heads < num_heads && > 1` → GQA.
87 pub num_kv_heads: i32,
88 /// K/V sequence length (the full attended prefix, not just the new
89 /// step). Arbitrary; the split-K factor adapts via [`CHUNK_K`].
90 pub k_len: i32,
91 /// Per-head feature dimension. `d_q == d_k == d_v` is enforced —
92 /// the decode regime doesn't justify the d_k != d_v complication
93 /// the prefill kernel handles.
94 pub head_dim: i32,
95 /// Score scaling factor — typically `1.0 / sqrt(head_dim)`.
96 pub scale: f32,
97 /// Element type — must match the plan's type parameter.
98 pub element: ElementKind,
99}
100
101impl FlashDecodingDescriptor {
102 /// Convenience constructor for pure MHA (`num_kv_heads == num_heads`)
103 /// with the standard `1/sqrt(D)` scale.
104 #[inline]
105 pub fn new(batch_size: i32, num_heads: i32, k_len: i32, head_dim: i32, element: ElementKind) -> Self {
106 let scale = 1.0_f32 / (head_dim as f32).sqrt();
107 Self {
108 batch_size,
109 num_heads,
110 num_kv_heads: num_heads,
111 k_len,
112 head_dim,
113 scale,
114 element,
115 }
116 }
117
118 /// Convenience constructor for GQA / MQA. `num_kv_heads` must
119 /// divide `num_heads`.
120 #[inline]
121 pub fn new_gqa(
122 batch_size: i32,
123 num_heads: i32,
124 num_kv_heads: i32,
125 k_len: i32,
126 head_dim: i32,
127 element: ElementKind,
128 ) -> Self {
129 let scale = 1.0_f32 / (head_dim as f32).sqrt();
130 Self {
131 batch_size,
132 num_heads,
133 num_kv_heads,
134 k_len,
135 head_dim,
136 scale,
137 element,
138 }
139 }
140
141 /// Builder: override the score scale (e.g. for QK-norm models that
142 /// pre-divide by something other than `sqrt(head_dim)`).
143 #[inline]
144 pub fn with_scale(mut self, scale: f32) -> Self {
145 self.scale = scale;
146 self
147 }
148
149 /// GQA group size — number of Q heads sharing each K/V head.
150 #[inline]
151 pub fn group_size(&self) -> i32 {
152 if self.num_kv_heads == 0 {
153 0
154 } else {
155 self.num_heads / self.num_kv_heads
156 }
157 }
158}
159
160/// Args bundle for a FlashDecoding launch.
161///
162/// Q is rank-3 because `seq_q == 1` is encoded in the descriptor — no
163/// need to thread a unit axis through the API.
164///
165/// K/V take shape `[B, H_kv, K_len, D]` (the PHYSICAL layout, not the
166/// broadcast-replicated H_q view). The kernel handles the Q→KV head
167/// mapping via integer division `kv_head = q_head / group_size`. For
168/// pure MHA the caller just passes `H_kv == H_q` and the same data
169/// shape as before.
170pub struct FlashDecodingArgs<'a, T: Element> {
171 /// Query tensor — shape `[B, H_q, D]`. Arbitrary strides via the
172 /// supplied stride array; typical case is contig.
173 pub q: TensorRef<'a, T, 3>,
174 /// Key tensor — shape `[B, H_kv, K_len, D]`, physical layout.
175 pub k: TensorRef<'a, T, 4>,
176 /// Value tensor — shape `[B, H_kv, K_len, D]`, physical layout.
177 pub v: TensorRef<'a, T, 4>,
178 /// Output tensor — shape `[B, H_q, D]`.
179 pub y: TensorMut<'a, T, 3>,
180}
181
182/// FlashDecoding forward plan (Dao 2023).
183///
184/// Split-K parallel attention decode for `seq_q = 1`. Replaces both
185/// [`FlashSdpaPlan`](crate::FlashSdpaPlan) and FA2 at the decode regime
186/// — both of those tile the Q dimension and waste work when seq_q < 64.
187///
188/// **When to use**: autoregressive decoder inference token loop. After
189/// the prefill step (which uses [`FlashSdpaPlan`] with `fa2` for the
190/// long initial context), each generated token calls this plan with
191/// `seq_q = 1` and the full grown KV cache.
192///
193/// **Dtypes**: `f16`, `bf16` (the only dtypes inference uses).
194///
195/// **Shape limits**: `head_dim ≤ 128`. Arbitrary `B`, `H`, `K_len`.
196///
197/// **Workspace**: non-zero. See [`Self::workspace_size`].
198///
199/// **Precision guarantee**: f32 accumulators throughout the split AND
200/// combine kernels. Deterministic — each output cell is written by
201/// exactly one block; no atomicAdd.
202pub struct FlashDecodingPlan<T: Element> {
203 desc: FlashDecodingDescriptor,
204 sku: KernelSku,
205 _marker: PhantomData<T>,
206}
207
208impl<T: Element> FlashDecodingPlan<T> {
209 /// Pick a kernel for the supplied descriptor.
210 pub fn select(
211 _stream: &Stream,
212 desc: &FlashDecodingDescriptor,
213 _pref: PlanPreference,
214 ) -> Result<Self> {
215 if desc.element != T::KIND {
216 return Err(Error::Unsupported(
217 "baracuda-kernels::FlashDecodingPlan: descriptor element != T",
218 ));
219 }
220 if desc.batch_size <= 0
221 || desc.num_heads <= 0
222 || desc.num_kv_heads <= 0
223 || desc.k_len < 0
224 || desc.head_dim <= 0
225 {
226 return Err(Error::InvalidProblem(
227 "baracuda-kernels::FlashDecodingPlan: extents must be positive (k_len may be 0)",
228 ));
229 }
230 if desc.num_heads % desc.num_kv_heads != 0 {
231 return Err(Error::InvalidProblem(
232 "baracuda-kernels::FlashDecodingPlan: num_heads must be a multiple of num_kv_heads",
233 ));
234 }
235 if desc.head_dim > FLASH_DECODING_MAX_D {
236 return Err(Error::Unsupported(
237 "baracuda-kernels::FlashDecodingPlan: head_dim > 128 not supported",
238 ));
239 }
240 if !matches!(T::KIND, ElementKind::F16 | ElementKind::Bf16) {
241 return Err(Error::Unsupported(
242 "baracuda-kernels::FlashDecodingPlan: wired today: {f16, bf16}",
243 ));
244 }
245
246 let precision_guarantee = PrecisionGuarantee {
247 math_precision: MathPrecision::F32,
248 accumulator: ElementKind::F32,
249 bit_stable_on_same_hardware: true,
250 deterministic: true,
251 };
252 let sku = KernelSku {
253 category: OpCategory::Attention,
254 op: AttentionKind::FlashAttention as u16,
255 element: T::KIND,
256 aux_element: None,
257 layout: None,
258 epilogue: None,
259 arch: ArchSku::Sm80,
260 backend: BackendKind::Bespoke,
261 precision_guarantee,
262 };
263 Ok(Self {
264 desc: *desc,
265 sku,
266 _marker: PhantomData,
267 })
268 }
269
270 /// Validate args against the descriptor.
271 pub fn can_implement(&self, args: &FlashDecodingArgs<'_, T>) -> Result<()> {
272 let d = self.desc.head_dim;
273 let b = self.desc.batch_size;
274 let h_q = self.desc.num_heads;
275 let h_kv = self.desc.num_kv_heads;
276 let k = self.desc.k_len;
277
278 if args.q.shape != [b, h_q, d] {
279 return Err(Error::InvalidProblem(
280 "FlashDecodingPlan: q.shape mismatch (expected [B, H_q, D])",
281 ));
282 }
283 if args.y.shape != [b, h_q, d] {
284 return Err(Error::InvalidProblem(
285 "FlashDecodingPlan: y.shape mismatch (expected [B, H_q, D])",
286 ));
287 }
288 if args.k.shape != [b, h_kv, k, d] {
289 return Err(Error::InvalidProblem(
290 "FlashDecodingPlan: k.shape mismatch (expected [B, H_kv, K_len, D])",
291 ));
292 }
293 if args.v.shape != [b, h_kv, k, d] {
294 return Err(Error::InvalidProblem(
295 "FlashDecodingPlan: v.shape mismatch (expected [B, H_kv, K_len, D])",
296 ));
297 }
298 Ok(())
299 }
300
301 /// Backend selected by `select`.
302 #[inline]
303 pub fn backend(&self) -> BackendKind {
304 BackendKind::Bespoke
305 }
306
307 /// Kernel SKU descriptor.
308 #[inline]
309 pub fn sku(&self) -> &KernelSku {
310 &self.sku
311 }
312
313 /// Workspace requirement in bytes for the (split + combine) pipeline.
314 pub fn workspace_size(&self) -> usize {
315 let b = self.desc.batch_size as i64;
316 let h = self.desc.num_heads as i64;
317 let s = num_splits(self.desc.k_len) as i64;
318 let d = self.desc.head_dim as i64;
319 if s == 0 || b == 0 || h == 0 {
320 return 0;
321 }
322 // partial_m + partial_l + partial_o[D] → (2 + D) * f32 per
323 // (b, h, split).
324 (b * h * s * (2 + d) * 4) as usize
325 }
326
327 /// Run the FlashDecoding pipeline.
328 pub fn run(
329 &self,
330 stream: &Stream,
331 workspace: Workspace<'_>,
332 args: FlashDecodingArgs<'_, T>,
333 ) -> Result<()> {
334 self.can_implement(&args)?;
335
336 let needed = self.workspace_size();
337 let (ws_ptr, ws_bytes) = match workspace {
338 Workspace::None => {
339 if needed > 0 {
340 return Err(Error::WorkspaceTooSmall {
341 needed,
342 got: 0,
343 });
344 }
345 (core::ptr::null_mut::<c_void>(), 0_usize)
346 }
347 Workspace::Borrowed(buf) => {
348 if buf.len() < needed {
349 return Err(Error::WorkspaceTooSmall {
350 needed,
351 got: buf.len(),
352 });
353 }
354 (buf.as_raw().0 as *mut c_void, buf.len())
355 }
356 };
357
358 let stream_ptr = stream.as_raw() as *mut c_void;
359 let q_ptr = args.q.data.as_raw().0 as *const c_void;
360 let k_ptr = args.k.data.as_raw().0 as *const c_void;
361 let v_ptr = args.v.data.as_raw().0 as *const c_void;
362 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
363
364 let status = unsafe {
365 match T::KIND {
366 ElementKind::F16 => baracuda_kernels_sys::baracuda_kernels_flash_decoding_f16_run(
367 q_ptr,
368 k_ptr,
369 v_ptr,
370 y_ptr,
371 ws_ptr,
372 ws_bytes,
373 self.desc.batch_size,
374 self.desc.num_heads,
375 self.desc.num_kv_heads,
376 self.desc.k_len,
377 self.desc.head_dim,
378 args.q.stride[0],
379 args.q.stride[1],
380 args.k.stride[0],
381 args.k.stride[1],
382 args.k.stride[2],
383 args.v.stride[0],
384 args.v.stride[1],
385 args.v.stride[2],
386 args.y.stride[0],
387 args.y.stride[1],
388 self.desc.scale,
389 stream_ptr,
390 ),
391 ElementKind::Bf16 => baracuda_kernels_sys::baracuda_kernels_flash_decoding_bf16_run(
392 q_ptr,
393 k_ptr,
394 v_ptr,
395 y_ptr,
396 ws_ptr,
397 ws_bytes,
398 self.desc.batch_size,
399 self.desc.num_heads,
400 self.desc.num_kv_heads,
401 self.desc.k_len,
402 self.desc.head_dim,
403 args.q.stride[0],
404 args.q.stride[1],
405 args.k.stride[0],
406 args.k.stride[1],
407 args.k.stride[2],
408 args.v.stride[0],
409 args.v.stride[1],
410 args.v.stride[2],
411 args.y.stride[0],
412 args.y.stride[1],
413 self.desc.scale,
414 stream_ptr,
415 ),
416 _ => {
417 return Err(Error::Unsupported(
418 "baracuda-kernels::FlashDecodingPlan: only f16 / bf16 wired",
419 ));
420 }
421 }
422 };
423 map_status(status)
424 }
425}
426
427#[inline]
428fn num_splits(k_len: i32) -> i32 {
429 if k_len <= 0 {
430 return 0;
431 }
432 (k_len + CHUNK_K - 1) / CHUNK_K
433}