Skip to main content

baracuda_kernels/attention/
kv_cache.rs

1//! KV-cache append — decoder-inference helper (Milestone 6.5).
2//!
3//! At every autoregressive decoding step the model produces fresh
4//! `K_new` / `V_new` slices that need to be appended to running caches
5//! `K_cache` / `V_cache` shared across the step's attention call. Each
6//! batch sample keeps its own offset (`cache_offsets[b]`) — the next
7//! slot to fill in its slice of the cache — so ragged-batch inference
8//! (different cumulative cache lengths per sample) is natively
9//! supported.
10//!
//! Op semantics, for each `b`, `h`, `l_new` and each `d_k in 0..D_k` / `d_v in 0..D_v`:
12//!
13//! ```text
14//! K_cache[b, h, cache_offsets[b] + l_new, d_k] = K_new[b, h, l_new, d_k]
15//! V_cache[b, h, cache_offsets[b] + l_new, d_v] = V_new[b, h, l_new, d_v]
16//! ```
17//!
18//! Shape conventions (all rank-4, contiguous, row-major):
19//!
20//! | tensor          | shape                                    |
21//! |-----------------|------------------------------------------|
22//! | `K_new`         | `[B, H, L_new, D_k]`                     |
23//! | `V_new`         | `[B, H, L_new, D_v]`                     |
24//! | `cache_offsets` | `[B]` (i64)                              |
25//! | `K_cache`       | `[B, H, L_max, D_k]` (modified in place) |
26//! | `V_cache`       | `[B, H, L_max, D_v]` (modified in place) |
27//!
//! Pure copy — bit-exact across every wired dtype (`{f32, f16, bf16,
//! f64}`). Cells where `cache_offsets[b] + l_new >= L_max` are silently
//! skipped; the caller is responsible for sizing the cache so writes
//! land in bounds. No backward pass — KV-cache is an inference-time op.
32//!
33//! After the call the caller updates `cache_offsets[b] += L_new` (host-
34//! or device-side, the plan doesn't touch the offset vector).
35
36use core::ffi::c_void;
37use core::marker::PhantomData;
38
39use baracuda_cutlass::{Error, Result};
40use baracuda_driver::Stream;
41use baracuda_kernels_types::{
42    ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
43    OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
44};
45
46use super::map_status;
47
48/// Descriptor for a KV-cache append op.
49#[derive(Copy, Clone, Debug)]
50pub struct KvCacheAppendDescriptor {
51    /// Batch size (`B`).
52    pub batch_size: i32,
53    /// Number of attention heads (`H`).
54    pub num_heads: i32,
55    /// Number of new K/V rows to append per sample (`L_new`).
56    pub new_len: i32,
57    /// Capacity of the cache along the sequence axis (`L_max`).
58    pub max_cache_len: i32,
59    /// Head dimension of K (`D_k`).
60    pub d_k: i32,
61    /// Head dimension of V (`D_v`). May differ from `d_k`.
62    pub d_v: i32,
63    /// Element type — must match the plan's type parameter.
64    pub element: ElementKind,
65}
66
67/// Args bundle for a KV-cache append launch.
68pub struct KvCacheAppendArgs<'a, T: Element> {
69    /// New K rows — shape `[B, H, L_new, D_k]`, contiguous.
70    pub k_new: TensorRef<'a, T, 4>,
71    /// New V rows — shape `[B, H, L_new, D_v]`, contiguous.
72    pub v_new: TensorRef<'a, T, 4>,
73    /// Per-sample insert offsets — shape `[B]`, `i64`. Values must
74    /// satisfy `0 <= cache_offsets[b]` and `cache_offsets[b] + L_new
75    /// <= L_max` for every cell that should land (out-of-range cells
76    /// are silently skipped by the kernel).
77    pub cache_offsets: TensorRef<'a, i64, 1>,
78    /// Destination K cache — shape `[B, H, L_max, D_k]`, contiguous;
79    /// modified in place.
80    pub k_cache: TensorMut<'a, T, 4>,
81    /// Destination V cache — shape `[B, H, L_max, D_v]`, contiguous;
82    /// modified in place.
83    pub v_cache: TensorMut<'a, T, 4>,
84}
85
86/// KV-cache append plan.
87///
88/// Writes new `K_new` / `V_new` slices into running `K_cache` /
89/// `V_cache` buffers at per-sample offsets supplied via
90/// `cache_offsets[b]`. Pure copy — bit-exact across all wired dtypes.
91///
92/// **When to use**: autoregressive decoder inference. Call once per
93/// generation step to extend the cache before the attention op for the
94/// next step. Ragged-batch insertion is supported natively because each
95/// sample carries its own offset. No backward — KV-cache is an
96/// inference-time op.
97///
98/// **Dtypes**: `f32`, `f64`, `f16`, `bf16`.
99///
100/// **Shape limits**: rank-4 contiguous K/V tensors. `d_k != d_v` is
101/// allowed. Cells where `cache_offsets[b] + l_new >= max_cache_len`
102/// are silently skipped — the caller sizes the cache so writes land in
103/// bounds.
104///
105/// **Workspace**: zero (pure in-place copy).
106///
107/// **Precision guarantee**: bit-exact (no math at all).
108pub struct KvCacheAppendPlan<T: Element> {
109    desc: KvCacheAppendDescriptor,
110    sku: KernelSku,
111    _marker: PhantomData<T>,
112}
113
114impl<T: Element> KvCacheAppendPlan<T> {
115    /// Pick a kernel.
116    pub fn select(
117        _stream: &Stream,
118        desc: &KvCacheAppendDescriptor,
119        _pref: PlanPreference,
120    ) -> Result<Self> {
121        if desc.element != T::KIND {
122            return Err(Error::Unsupported(
123                "baracuda-kernels::KvCacheAppendPlan: descriptor element != T",
124            ));
125        }
126        if desc.batch_size < 0
127            || desc.num_heads < 0
128            || desc.new_len < 0
129            || desc.max_cache_len < 0
130            || desc.d_k < 0
131            || desc.d_v < 0
132        {
133            return Err(Error::InvalidProblem(
134                "baracuda-kernels::KvCacheAppendPlan: extents must be non-negative",
135            ));
136        }
137        let dtype_in_scope = matches!(
138            T::KIND,
139            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
140        );
141        if !dtype_in_scope {
142            return Err(Error::Unsupported(
143                "baracuda-kernels::KvCacheAppendPlan: wired today: `{f32, f16, bf16, f64}`",
144            ));
145        }
146
147        let precision_guarantee = PrecisionGuarantee {
148            math_precision: MathPrecision::F32,
149            // Pure copy — no accumulator math happens at all. The
150            // value carried here is cosmetic; bit-stability is
151            // guaranteed regardless.
152            accumulator: T::KIND,
153            bit_stable_on_same_hardware: true,
154            deterministic: true,
155        };
156        let sku = KernelSku {
157            category: OpCategory::Attention,
158            op: AttentionKind::KvCache as u16,
159            element: T::KIND,
160            aux_element: None,
161            layout: None,
162            epilogue: None,
163            arch: ArchSku::Sm80,
164            backend: BackendKind::Bespoke,
165            precision_guarantee,
166        };
167        Ok(Self {
168            desc: *desc,
169            sku,
170            _marker: PhantomData,
171        })
172    }
173
174    /// Validate args against the descriptor.
175    pub fn can_implement(&self, args: &KvCacheAppendArgs<'_, T>) -> Result<()> {
176        let shape_k_new = [
177            self.desc.batch_size,
178            self.desc.num_heads,
179            self.desc.new_len,
180            self.desc.d_k,
181        ];
182        let shape_v_new = [
183            self.desc.batch_size,
184            self.desc.num_heads,
185            self.desc.new_len,
186            self.desc.d_v,
187        ];
188        let shape_k_cache = [
189            self.desc.batch_size,
190            self.desc.num_heads,
191            self.desc.max_cache_len,
192            self.desc.d_k,
193        ];
194        let shape_v_cache = [
195            self.desc.batch_size,
196            self.desc.num_heads,
197            self.desc.max_cache_len,
198            self.desc.d_v,
199        ];
200        if args.k_new.shape != shape_k_new {
201            return Err(Error::InvalidProblem(
202                "baracuda-kernels::KvCacheAppendPlan: k_new shape mismatch",
203            ));
204        }
205        if args.v_new.shape != shape_v_new {
206            return Err(Error::InvalidProblem(
207                "baracuda-kernels::KvCacheAppendPlan: v_new shape mismatch",
208            ));
209        }
210        if args.k_cache.shape != shape_k_cache {
211            return Err(Error::InvalidProblem(
212                "baracuda-kernels::KvCacheAppendPlan: k_cache shape mismatch",
213            ));
214        }
215        if args.v_cache.shape != shape_v_cache {
216            return Err(Error::InvalidProblem(
217                "baracuda-kernels::KvCacheAppendPlan: v_cache shape mismatch",
218            ));
219        }
220        if args.cache_offsets.shape != [self.desc.batch_size] {
221            return Err(Error::InvalidProblem(
222                "baracuda-kernels::KvCacheAppendPlan: cache_offsets shape must be [batch_size]",
223            ));
224        }
225        if !args.k_new.is_contiguous()
226            || !args.v_new.is_contiguous()
227            || !args.k_cache.is_contiguous()
228            || !args.v_cache.is_contiguous()
229        {
230            return Err(Error::Unsupported(
231                "baracuda-kernels::KvCacheAppendPlan: trailblazer requires contiguous K/V tensors",
232            ));
233        }
234        if args.cache_offsets.stride != [1] {
235            return Err(Error::Unsupported(
236                "baracuda-kernels::KvCacheAppendPlan: cache_offsets must be unit-stride",
237            ));
238        }
239        let k_new_n = args.k_new.numel();
240        let v_new_n = args.v_new.numel();
241        let k_cache_n = args.k_cache.numel();
242        let v_cache_n = args.v_cache.numel();
243        if (args.k_new.data.len() as i64) < k_new_n
244            || (args.v_new.data.len() as i64) < v_new_n
245            || (args.k_cache.data.len() as i64) < k_cache_n
246            || (args.v_cache.data.len() as i64) < v_cache_n
247        {
248            return Err(Error::BufferTooSmall {
249                needed: k_new_n.max(v_new_n).max(k_cache_n).max(v_cache_n) as usize,
250                got: 0,
251            });
252        }
253        if (args.cache_offsets.data.len() as i64) < self.desc.batch_size as i64 {
254            return Err(Error::BufferTooSmall {
255                needed: self.desc.batch_size as usize,
256                got: args.cache_offsets.data.len(),
257            });
258        }
259        Ok(())
260    }
261
262    /// Workspace size in bytes — zero (pure in-place copy).
263    #[inline]
264    pub fn workspace_size(&self) -> usize {
265        0
266    }
267
268    /// SKU identity.
269    #[inline]
270    pub fn sku(&self) -> KernelSku {
271        self.sku
272    }
273
274    /// Numerical guarantees.
275    #[inline]
276    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
277        self.sku.precision_guarantee
278    }
279
280    /// Launch the K + V copy kernels on the supplied stream.
281    pub fn run(
282        &self,
283        stream: &Stream,
284        _workspace: Workspace<'_>,
285        args: KvCacheAppendArgs<'_, T>,
286    ) -> Result<()> {
287        self.can_implement(&args)?;
288        // Empty problem — nothing to do.
289        if self.desc.batch_size == 0
290            || self.desc.num_heads == 0
291            || self.desc.new_len == 0
292        {
293            return Ok(());
294        }
295        let stream_ptr = stream.as_raw() as *mut c_void;
296        let k_new_ptr = args.k_new.data.as_raw().0 as *const c_void;
297        let v_new_ptr = args.v_new.data.as_raw().0 as *const c_void;
298        let offsets_ptr = args.cache_offsets.data.as_raw().0 as *const c_void;
299        let k_cache_ptr = args.k_cache.data.as_raw().0 as *mut c_void;
300        let v_cache_ptr = args.v_cache.data.as_raw().0 as *mut c_void;
301
302        let status = match T::KIND {
303            ElementKind::F32 => unsafe {
304                baracuda_kernels_sys::baracuda_kernels_kv_cache_append_f32_run(
305                    self.desc.batch_size,
306                    self.desc.num_heads,
307                    self.desc.new_len,
308                    self.desc.max_cache_len,
309                    self.desc.d_k,
310                    self.desc.d_v,
311                    k_new_ptr,
312                    v_new_ptr,
313                    offsets_ptr,
314                    k_cache_ptr,
315                    v_cache_ptr,
316                    core::ptr::null_mut(),
317                    0,
318                    stream_ptr,
319                )
320            },
321            ElementKind::F16 => unsafe {
322                baracuda_kernels_sys::baracuda_kernels_kv_cache_append_f16_run(
323                    self.desc.batch_size,
324                    self.desc.num_heads,
325                    self.desc.new_len,
326                    self.desc.max_cache_len,
327                    self.desc.d_k,
328                    self.desc.d_v,
329                    k_new_ptr,
330                    v_new_ptr,
331                    offsets_ptr,
332                    k_cache_ptr,
333                    v_cache_ptr,
334                    core::ptr::null_mut(),
335                    0,
336                    stream_ptr,
337                )
338            },
339            ElementKind::Bf16 => unsafe {
340                baracuda_kernels_sys::baracuda_kernels_kv_cache_append_bf16_run(
341                    self.desc.batch_size,
342                    self.desc.num_heads,
343                    self.desc.new_len,
344                    self.desc.max_cache_len,
345                    self.desc.d_k,
346                    self.desc.d_v,
347                    k_new_ptr,
348                    v_new_ptr,
349                    offsets_ptr,
350                    k_cache_ptr,
351                    v_cache_ptr,
352                    core::ptr::null_mut(),
353                    0,
354                    stream_ptr,
355                )
356            },
357            ElementKind::F64 => unsafe {
358                baracuda_kernels_sys::baracuda_kernels_kv_cache_append_f64_run(
359                    self.desc.batch_size,
360                    self.desc.num_heads,
361                    self.desc.new_len,
362                    self.desc.max_cache_len,
363                    self.desc.d_k,
364                    self.desc.d_v,
365                    k_new_ptr,
366                    v_new_ptr,
367                    offsets_ptr,
368                    k_cache_ptr,
369                    v_cache_ptr,
370                    core::ptr::null_mut(),
371                    0,
372                    stream_ptr,
373                )
374            },
375            _ => {
376                return Err(Error::Unsupported(
377                    "baracuda-kernels::KvCacheAppendPlan::run reached an unimplemented dtype",
378                ));
379            }
380        };
381        map_status(status)
382    }
383}