Skip to main content

baracuda_kernels/attention/
flash_decoding.rs

1// SPDX-FileCopyrightText: 2026 Eric Evans and the baracuda contributors
2// SPDX-License-Identifier: MIT OR Apache-2.0
3//
4//! FlashDecoding — split-K parallel attention decode for `seq_q = 1`.
5//!
6//! Phase 73 follow-up. Closes the perf gap that both the bespoke
7//! `FlashSdpaPlan` Phase 10 trailblazer AND FA2 leave at the decode
8//! regime, where the seq_q dimension is too short to fill a 64-row
9//! q-tile and most of the GPU sits idle.
10//!
11//! FlashDecoding flips the parallelism axis: split K into chunks of
12//! 256 rows, launch one block per `(b, h, k_split)`, and combine the
13//! per-split online-softmax partials in a second small reduction kernel.
//! For (B=1, H=32, K=2048, D=128) the split kernel launches
//! `1 × 32 × 8 = 256` blocks (K=2048 in 256-row chunks = 8 splits) vs the
//! FlashAttention kernel's 32 (⌈seq_q / 64⌉ = 1 q-tile × H=32), and each
//! block does meaningful work instead of being mostly q-tile padding.
18//!
19//! See `kernels/include/baracuda_flash_decoding.cuh` for the kernel
20//! body. This file wraps it with the standard descriptor / args / plan
21//! triple.
22//!
23//! ## Tier-1 scope
24//!
25//! - dtypes: `f16`, `bf16`. f32 / f64 are decode-uncommon (typical
26//!   inference is half-precision).
27//! - `head_dim ∈ [1, 128]`.
28//! - `seq_q == 1` strictly (decode contract — the whole point).
//! - GQA / MQA via `FlashDecodingDescriptor::num_kv_heads` (see
//!   [`FlashDecodingDescriptor::new_gqa`]). K/V are passed in their
//!   physical `[B, H_kv, K_len, D]` layout and query head `h` reads K/V
//!   head `h / (H_q / H_kv)`; `H_q` must be a multiple of `H_kv`.
33//! - `is_causal` is irrelevant — there's only one query row, and the
34//!   caller is responsible for slicing the cache to the prefix it
35//!   wants attended to.
36//!
37//! ## Out of scope (deferred)
38//!
39//! - f32 / f64 (no decode workload uses these).
40//! - sliding window / ALiBi / soft-cap — pre-mask the cache.
41//! - BW pass — decode is FW-only.
42//! - Tensor-core MMA inside the Q·K dot product — the first cut uses
43//!   warp-shuffle reduce in fp32. A tensor-core retune is the next
44//!   follow-up phase once perf bench numbers land.
45//!
46//! ## Workspace
47//!
48//! Non-zero. The split kernel emits per-`(B, H, S)` partials (m, l, o)
49//! into the workspace; the combine kernel reads them and writes the
50//! final output. `workspace_size()` returns
51//! `B * H * num_splits * (2 + head_dim) * sizeof(f32)`.
52
53use core::ffi::c_void;
54use core::marker::PhantomData;
55
56use baracuda_cutlass::{Error, Result};
57use baracuda_driver::Stream;
58use baracuda_kernels_types::{
59    ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
60    OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
61};
62
63use super::map_status;
64
/// Largest `head_dim` the Tier-1 kernels accept.
///
/// [`FlashDecodingPlan::select`] rejects any descriptor whose `head_dim`
/// exceeds this value.
pub const FLASH_DECODING_MAX_D: i32 = 128;

/// Number of K/V rows handled by one split (256).
///
/// A sequence of `k_len` rows is processed as `ceil(k_len / CHUNK_K)`
/// splits.
const CHUNK_K: i32 = 1 << 8;
68
69/// Descriptor for a FlashDecoding op.
70///
71/// `num_kv_heads` is the GQA grouping signal: when it equals `num_heads`
72/// the workload is full MHA; when it's smaller (e.g. 8 for Llama 3 8B
73/// at H_q=32) every K/V head is shared by `group_size = num_heads /
74/// num_kv_heads` Q heads. The launcher uses `group_size` to pick
75/// between the warp-cooperative SIMT kernel (Tier-1) and the
76/// GQA-batched WMMA kernel (Tier-2, gated on group_size ≥ 4 +
77/// head_dim aligned to 16).
78#[derive(Copy, Clone, Debug)]
79pub struct FlashDecodingDescriptor {
80    /// Batch size (`B`).
81    pub batch_size: i32,
82    /// Number of query / output heads (`H_q`).
83    pub num_heads: i32,
84    /// Number of K/V heads (`H_kv`). Must divide `num_heads` evenly.
85    /// `num_kv_heads == num_heads` → pure MHA. `num_kv_heads == 1` →
86    /// MQA. `num_kv_heads < num_heads && > 1` → GQA.
87    pub num_kv_heads: i32,
88    /// K/V sequence length (the full attended prefix, not just the new
89    /// step). Arbitrary; the split-K factor adapts via [`CHUNK_K`].
90    pub k_len: i32,
91    /// Per-head feature dimension. `d_q == d_k == d_v` is enforced —
92    /// the decode regime doesn't justify the d_k != d_v complication
93    /// the prefill kernel handles.
94    pub head_dim: i32,
95    /// Score scaling factor — typically `1.0 / sqrt(head_dim)`.
96    pub scale: f32,
97    /// Element type — must match the plan's type parameter.
98    pub element: ElementKind,
99}
100
101impl FlashDecodingDescriptor {
102    /// Convenience constructor for pure MHA (`num_kv_heads == num_heads`)
103    /// with the standard `1/sqrt(D)` scale.
104    #[inline]
105    pub fn new(batch_size: i32, num_heads: i32, k_len: i32, head_dim: i32, element: ElementKind) -> Self {
106        let scale = 1.0_f32 / (head_dim as f32).sqrt();
107        Self {
108            batch_size,
109            num_heads,
110            num_kv_heads: num_heads,
111            k_len,
112            head_dim,
113            scale,
114            element,
115        }
116    }
117
118    /// Convenience constructor for GQA / MQA. `num_kv_heads` must
119    /// divide `num_heads`.
120    #[inline]
121    pub fn new_gqa(
122        batch_size: i32,
123        num_heads: i32,
124        num_kv_heads: i32,
125        k_len: i32,
126        head_dim: i32,
127        element: ElementKind,
128    ) -> Self {
129        let scale = 1.0_f32 / (head_dim as f32).sqrt();
130        Self {
131            batch_size,
132            num_heads,
133            num_kv_heads,
134            k_len,
135            head_dim,
136            scale,
137            element,
138        }
139    }
140
141    /// Builder: override the score scale (e.g. for QK-norm models that
142    /// pre-divide by something other than `sqrt(head_dim)`).
143    #[inline]
144    pub fn with_scale(mut self, scale: f32) -> Self {
145        self.scale = scale;
146        self
147    }
148
149    /// GQA group size — number of Q heads sharing each K/V head.
150    #[inline]
151    pub fn group_size(&self) -> i32 {
152        if self.num_kv_heads == 0 {
153            0
154        } else {
155            self.num_heads / self.num_kv_heads
156        }
157    }
158}
159
160/// Args bundle for a FlashDecoding launch.
161///
162/// Q is rank-3 because `seq_q == 1` is encoded in the descriptor — no
163/// need to thread a unit axis through the API.
164///
165/// K/V take shape `[B, H_kv, K_len, D]` (the PHYSICAL layout, not the
166/// broadcast-replicated H_q view). The kernel handles the Q→KV head
167/// mapping via integer division `kv_head = q_head / group_size`. For
168/// pure MHA the caller just passes `H_kv == H_q` and the same data
169/// shape as before.
170pub struct FlashDecodingArgs<'a, T: Element> {
171    /// Query tensor — shape `[B, H_q, D]`. Arbitrary strides via the
172    /// supplied stride array; typical case is contig.
173    pub q: TensorRef<'a, T, 3>,
174    /// Key tensor — shape `[B, H_kv, K_len, D]`, physical layout.
175    pub k: TensorRef<'a, T, 4>,
176    /// Value tensor — shape `[B, H_kv, K_len, D]`, physical layout.
177    pub v: TensorRef<'a, T, 4>,
178    /// Output tensor — shape `[B, H_q, D]`.
179    pub y: TensorMut<'a, T, 3>,
180}
181
182/// FlashDecoding forward plan (Dao 2023).
183///
184/// Split-K parallel attention decode for `seq_q = 1`. Replaces both
185/// [`FlashSdpaPlan`](crate::FlashSdpaPlan) and FA2 at the decode regime
186/// — both of those tile the Q dimension and waste work when seq_q < 64.
187///
188/// **When to use**: autoregressive decoder inference token loop. After
189/// the prefill step (which uses [`FlashSdpaPlan`] with `fa2` for the
190/// long initial context), each generated token calls this plan with
191/// `seq_q = 1` and the full grown KV cache.
192///
193/// **Dtypes**: `f16`, `bf16` (the only dtypes inference uses).
194///
195/// **Shape limits**: `head_dim ≤ 128`. Arbitrary `B`, `H`, `K_len`.
196///
197/// **Workspace**: non-zero. See [`Self::workspace_size`].
198///
199/// **Precision guarantee**: f32 accumulators throughout the split AND
200/// combine kernels. Deterministic — each output cell is written by
201/// exactly one block; no atomicAdd.
202pub struct FlashDecodingPlan<T: Element> {
203    desc: FlashDecodingDescriptor,
204    sku: KernelSku,
205    _marker: PhantomData<T>,
206}
207
208impl<T: Element> FlashDecodingPlan<T> {
209    /// Pick a kernel for the supplied descriptor.
210    pub fn select(
211        _stream: &Stream,
212        desc: &FlashDecodingDescriptor,
213        _pref: PlanPreference,
214    ) -> Result<Self> {
215        if desc.element != T::KIND {
216            return Err(Error::Unsupported(
217                "baracuda-kernels::FlashDecodingPlan: descriptor element != T",
218            ));
219        }
220        if desc.batch_size <= 0
221            || desc.num_heads <= 0
222            || desc.num_kv_heads <= 0
223            || desc.k_len < 0
224            || desc.head_dim <= 0
225        {
226            return Err(Error::InvalidProblem(
227                "baracuda-kernels::FlashDecodingPlan: extents must be positive (k_len may be 0)",
228            ));
229        }
230        if desc.num_heads % desc.num_kv_heads != 0 {
231            return Err(Error::InvalidProblem(
232                "baracuda-kernels::FlashDecodingPlan: num_heads must be a multiple of num_kv_heads",
233            ));
234        }
235        if desc.head_dim > FLASH_DECODING_MAX_D {
236            return Err(Error::Unsupported(
237                "baracuda-kernels::FlashDecodingPlan: head_dim > 128 not supported",
238            ));
239        }
240        if !matches!(T::KIND, ElementKind::F16 | ElementKind::Bf16) {
241            return Err(Error::Unsupported(
242                "baracuda-kernels::FlashDecodingPlan: wired today: {f16, bf16}",
243            ));
244        }
245
246        let precision_guarantee = PrecisionGuarantee {
247            math_precision: MathPrecision::F32,
248            accumulator: ElementKind::F32,
249            bit_stable_on_same_hardware: true,
250            deterministic: true,
251        };
252        let sku = KernelSku {
253            category: OpCategory::Attention,
254            op: AttentionKind::FlashAttention as u16,
255            element: T::KIND,
256            aux_element: None,
257            layout: None,
258            epilogue: None,
259            arch: ArchSku::Sm80,
260            backend: BackendKind::Bespoke,
261            precision_guarantee,
262        };
263        Ok(Self {
264            desc: *desc,
265            sku,
266            _marker: PhantomData,
267        })
268    }
269
270    /// Validate args against the descriptor.
271    pub fn can_implement(&self, args: &FlashDecodingArgs<'_, T>) -> Result<()> {
272        let d = self.desc.head_dim;
273        let b = self.desc.batch_size;
274        let h_q = self.desc.num_heads;
275        let h_kv = self.desc.num_kv_heads;
276        let k = self.desc.k_len;
277
278        if args.q.shape != [b, h_q, d] {
279            return Err(Error::InvalidProblem(
280                "FlashDecodingPlan: q.shape mismatch (expected [B, H_q, D])",
281            ));
282        }
283        if args.y.shape != [b, h_q, d] {
284            return Err(Error::InvalidProblem(
285                "FlashDecodingPlan: y.shape mismatch (expected [B, H_q, D])",
286            ));
287        }
288        if args.k.shape != [b, h_kv, k, d] {
289            return Err(Error::InvalidProblem(
290                "FlashDecodingPlan: k.shape mismatch (expected [B, H_kv, K_len, D])",
291            ));
292        }
293        if args.v.shape != [b, h_kv, k, d] {
294            return Err(Error::InvalidProblem(
295                "FlashDecodingPlan: v.shape mismatch (expected [B, H_kv, K_len, D])",
296            ));
297        }
298        Ok(())
299    }
300
301    /// Backend selected by `select`.
302    #[inline]
303    pub fn backend(&self) -> BackendKind {
304        BackendKind::Bespoke
305    }
306
307    /// Kernel SKU descriptor.
308    #[inline]
309    pub fn sku(&self) -> &KernelSku {
310        &self.sku
311    }
312
313    /// Workspace requirement in bytes for the (split + combine) pipeline.
314    pub fn workspace_size(&self) -> usize {
315        let b = self.desc.batch_size as i64;
316        let h = self.desc.num_heads as i64;
317        let s = num_splits(self.desc.k_len) as i64;
318        let d = self.desc.head_dim as i64;
319        if s == 0 || b == 0 || h == 0 {
320            return 0;
321        }
322        // partial_m + partial_l + partial_o[D] → (2 + D) * f32 per
323        // (b, h, split).
324        (b * h * s * (2 + d) * 4) as usize
325    }
326
327    /// Run the FlashDecoding pipeline.
328    pub fn run(
329        &self,
330        stream: &Stream,
331        workspace: Workspace<'_>,
332        args: FlashDecodingArgs<'_, T>,
333    ) -> Result<()> {
334        self.can_implement(&args)?;
335
336        let needed = self.workspace_size();
337        let (ws_ptr, ws_bytes) = match workspace {
338            Workspace::None => {
339                if needed > 0 {
340                    return Err(Error::WorkspaceTooSmall {
341                        needed,
342                        got: 0,
343                    });
344                }
345                (core::ptr::null_mut::<c_void>(), 0_usize)
346            }
347            Workspace::Borrowed(buf) => {
348                if buf.len() < needed {
349                    return Err(Error::WorkspaceTooSmall {
350                        needed,
351                        got: buf.len(),
352                    });
353                }
354                (buf.as_raw().0 as *mut c_void, buf.len())
355            }
356        };
357
358        let stream_ptr = stream.as_raw() as *mut c_void;
359        let q_ptr = args.q.data.as_raw().0 as *const c_void;
360        let k_ptr = args.k.data.as_raw().0 as *const c_void;
361        let v_ptr = args.v.data.as_raw().0 as *const c_void;
362        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
363
364        let status = unsafe {
365            match T::KIND {
366                ElementKind::F16 => baracuda_kernels_sys::baracuda_kernels_flash_decoding_f16_run(
367                    q_ptr,
368                    k_ptr,
369                    v_ptr,
370                    y_ptr,
371                    ws_ptr,
372                    ws_bytes,
373                    self.desc.batch_size,
374                    self.desc.num_heads,
375                    self.desc.num_kv_heads,
376                    self.desc.k_len,
377                    self.desc.head_dim,
378                    args.q.stride[0],
379                    args.q.stride[1],
380                    args.k.stride[0],
381                    args.k.stride[1],
382                    args.k.stride[2],
383                    args.v.stride[0],
384                    args.v.stride[1],
385                    args.v.stride[2],
386                    args.y.stride[0],
387                    args.y.stride[1],
388                    self.desc.scale,
389                    stream_ptr,
390                ),
391                ElementKind::Bf16 => baracuda_kernels_sys::baracuda_kernels_flash_decoding_bf16_run(
392                    q_ptr,
393                    k_ptr,
394                    v_ptr,
395                    y_ptr,
396                    ws_ptr,
397                    ws_bytes,
398                    self.desc.batch_size,
399                    self.desc.num_heads,
400                    self.desc.num_kv_heads,
401                    self.desc.k_len,
402                    self.desc.head_dim,
403                    args.q.stride[0],
404                    args.q.stride[1],
405                    args.k.stride[0],
406                    args.k.stride[1],
407                    args.k.stride[2],
408                    args.v.stride[0],
409                    args.v.stride[1],
410                    args.v.stride[2],
411                    args.y.stride[0],
412                    args.y.stride[1],
413                    self.desc.scale,
414                    stream_ptr,
415                ),
416                _ => {
417                    return Err(Error::Unsupported(
418                        "baracuda-kernels::FlashDecodingPlan: only f16 / bf16 wired",
419                    ));
420                }
421            }
422        };
423        map_status(status)
424    }
425}
426
427#[inline]
428fn num_splits(k_len: i32) -> i32 {
429    if k_len <= 0 {
430        return 0;
431    }
432    (k_len + CHUNK_K - 1) / CHUNK_K
433}