baracuda_kernels/attention/
kv_cache.rs1use core::ffi::c_void;
37use core::marker::PhantomData;
38
39use baracuda_cutlass::{Error, Result};
40use baracuda_driver::Stream;
41use baracuda_kernels_types::{
42 ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
43 OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
44};
45
46use super::map_status;
47
48#[derive(Copy, Clone, Debug)]
50pub struct KvCacheAppendDescriptor {
51 pub batch_size: i32,
53 pub num_heads: i32,
55 pub new_len: i32,
57 pub max_cache_len: i32,
59 pub d_k: i32,
61 pub d_v: i32,
63 pub element: ElementKind,
65}
66
67pub struct KvCacheAppendArgs<'a, T: Element> {
69 pub k_new: TensorRef<'a, T, 4>,
71 pub v_new: TensorRef<'a, T, 4>,
73 pub cache_offsets: TensorRef<'a, i64, 1>,
78 pub k_cache: TensorMut<'a, T, 4>,
81 pub v_cache: TensorMut<'a, T, 4>,
84}
85
86pub struct KvCacheAppendPlan<T: Element> {
109 desc: KvCacheAppendDescriptor,
110 sku: KernelSku,
111 _marker: PhantomData<T>,
112}
113
114impl<T: Element> KvCacheAppendPlan<T> {
115 pub fn select(
117 _stream: &Stream,
118 desc: &KvCacheAppendDescriptor,
119 _pref: PlanPreference,
120 ) -> Result<Self> {
121 if desc.element != T::KIND {
122 return Err(Error::Unsupported(
123 "baracuda-kernels::KvCacheAppendPlan: descriptor element != T",
124 ));
125 }
126 if desc.batch_size < 0
127 || desc.num_heads < 0
128 || desc.new_len < 0
129 || desc.max_cache_len < 0
130 || desc.d_k < 0
131 || desc.d_v < 0
132 {
133 return Err(Error::InvalidProblem(
134 "baracuda-kernels::KvCacheAppendPlan: extents must be non-negative",
135 ));
136 }
137 let dtype_in_scope = matches!(
138 T::KIND,
139 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
140 );
141 if !dtype_in_scope {
142 return Err(Error::Unsupported(
143 "baracuda-kernels::KvCacheAppendPlan: wired today: `{f32, f16, bf16, f64}`",
144 ));
145 }
146
147 let precision_guarantee = PrecisionGuarantee {
148 math_precision: MathPrecision::F32,
149 accumulator: T::KIND,
153 bit_stable_on_same_hardware: true,
154 deterministic: true,
155 };
156 let sku = KernelSku {
157 category: OpCategory::Attention,
158 op: AttentionKind::KvCache as u16,
159 element: T::KIND,
160 aux_element: None,
161 layout: None,
162 epilogue: None,
163 arch: ArchSku::Sm80,
164 backend: BackendKind::Bespoke,
165 precision_guarantee,
166 };
167 Ok(Self {
168 desc: *desc,
169 sku,
170 _marker: PhantomData,
171 })
172 }
173
174 pub fn can_implement(&self, args: &KvCacheAppendArgs<'_, T>) -> Result<()> {
176 let shape_k_new = [
177 self.desc.batch_size,
178 self.desc.num_heads,
179 self.desc.new_len,
180 self.desc.d_k,
181 ];
182 let shape_v_new = [
183 self.desc.batch_size,
184 self.desc.num_heads,
185 self.desc.new_len,
186 self.desc.d_v,
187 ];
188 let shape_k_cache = [
189 self.desc.batch_size,
190 self.desc.num_heads,
191 self.desc.max_cache_len,
192 self.desc.d_k,
193 ];
194 let shape_v_cache = [
195 self.desc.batch_size,
196 self.desc.num_heads,
197 self.desc.max_cache_len,
198 self.desc.d_v,
199 ];
200 if args.k_new.shape != shape_k_new {
201 return Err(Error::InvalidProblem(
202 "baracuda-kernels::KvCacheAppendPlan: k_new shape mismatch",
203 ));
204 }
205 if args.v_new.shape != shape_v_new {
206 return Err(Error::InvalidProblem(
207 "baracuda-kernels::KvCacheAppendPlan: v_new shape mismatch",
208 ));
209 }
210 if args.k_cache.shape != shape_k_cache {
211 return Err(Error::InvalidProblem(
212 "baracuda-kernels::KvCacheAppendPlan: k_cache shape mismatch",
213 ));
214 }
215 if args.v_cache.shape != shape_v_cache {
216 return Err(Error::InvalidProblem(
217 "baracuda-kernels::KvCacheAppendPlan: v_cache shape mismatch",
218 ));
219 }
220 if args.cache_offsets.shape != [self.desc.batch_size] {
221 return Err(Error::InvalidProblem(
222 "baracuda-kernels::KvCacheAppendPlan: cache_offsets shape must be [batch_size]",
223 ));
224 }
225 if !args.k_new.is_contiguous()
226 || !args.v_new.is_contiguous()
227 || !args.k_cache.is_contiguous()
228 || !args.v_cache.is_contiguous()
229 {
230 return Err(Error::Unsupported(
231 "baracuda-kernels::KvCacheAppendPlan: trailblazer requires contiguous K/V tensors",
232 ));
233 }
234 if args.cache_offsets.stride != [1] {
235 return Err(Error::Unsupported(
236 "baracuda-kernels::KvCacheAppendPlan: cache_offsets must be unit-stride",
237 ));
238 }
239 let k_new_n = args.k_new.numel();
240 let v_new_n = args.v_new.numel();
241 let k_cache_n = args.k_cache.numel();
242 let v_cache_n = args.v_cache.numel();
243 if (args.k_new.data.len() as i64) < k_new_n
244 || (args.v_new.data.len() as i64) < v_new_n
245 || (args.k_cache.data.len() as i64) < k_cache_n
246 || (args.v_cache.data.len() as i64) < v_cache_n
247 {
248 return Err(Error::BufferTooSmall {
249 needed: k_new_n.max(v_new_n).max(k_cache_n).max(v_cache_n) as usize,
250 got: 0,
251 });
252 }
253 if (args.cache_offsets.data.len() as i64) < self.desc.batch_size as i64 {
254 return Err(Error::BufferTooSmall {
255 needed: self.desc.batch_size as usize,
256 got: args.cache_offsets.data.len(),
257 });
258 }
259 Ok(())
260 }
261
262 #[inline]
264 pub fn workspace_size(&self) -> usize {
265 0
266 }
267
268 #[inline]
270 pub fn sku(&self) -> KernelSku {
271 self.sku
272 }
273
274 #[inline]
276 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
277 self.sku.precision_guarantee
278 }
279
280 pub fn run(
282 &self,
283 stream: &Stream,
284 _workspace: Workspace<'_>,
285 args: KvCacheAppendArgs<'_, T>,
286 ) -> Result<()> {
287 self.can_implement(&args)?;
288 if self.desc.batch_size == 0
290 || self.desc.num_heads == 0
291 || self.desc.new_len == 0
292 {
293 return Ok(());
294 }
295 let stream_ptr = stream.as_raw() as *mut c_void;
296 let k_new_ptr = args.k_new.data.as_raw().0 as *const c_void;
297 let v_new_ptr = args.v_new.data.as_raw().0 as *const c_void;
298 let offsets_ptr = args.cache_offsets.data.as_raw().0 as *const c_void;
299 let k_cache_ptr = args.k_cache.data.as_raw().0 as *mut c_void;
300 let v_cache_ptr = args.v_cache.data.as_raw().0 as *mut c_void;
301
302 let status = match T::KIND {
303 ElementKind::F32 => unsafe {
304 baracuda_kernels_sys::baracuda_kernels_kv_cache_append_f32_run(
305 self.desc.batch_size,
306 self.desc.num_heads,
307 self.desc.new_len,
308 self.desc.max_cache_len,
309 self.desc.d_k,
310 self.desc.d_v,
311 k_new_ptr,
312 v_new_ptr,
313 offsets_ptr,
314 k_cache_ptr,
315 v_cache_ptr,
316 core::ptr::null_mut(),
317 0,
318 stream_ptr,
319 )
320 },
321 ElementKind::F16 => unsafe {
322 baracuda_kernels_sys::baracuda_kernels_kv_cache_append_f16_run(
323 self.desc.batch_size,
324 self.desc.num_heads,
325 self.desc.new_len,
326 self.desc.max_cache_len,
327 self.desc.d_k,
328 self.desc.d_v,
329 k_new_ptr,
330 v_new_ptr,
331 offsets_ptr,
332 k_cache_ptr,
333 v_cache_ptr,
334 core::ptr::null_mut(),
335 0,
336 stream_ptr,
337 )
338 },
339 ElementKind::Bf16 => unsafe {
340 baracuda_kernels_sys::baracuda_kernels_kv_cache_append_bf16_run(
341 self.desc.batch_size,
342 self.desc.num_heads,
343 self.desc.new_len,
344 self.desc.max_cache_len,
345 self.desc.d_k,
346 self.desc.d_v,
347 k_new_ptr,
348 v_new_ptr,
349 offsets_ptr,
350 k_cache_ptr,
351 v_cache_ptr,
352 core::ptr::null_mut(),
353 0,
354 stream_ptr,
355 )
356 },
357 ElementKind::F64 => unsafe {
358 baracuda_kernels_sys::baracuda_kernels_kv_cache_append_f64_run(
359 self.desc.batch_size,
360 self.desc.num_heads,
361 self.desc.new_len,
362 self.desc.max_cache_len,
363 self.desc.d_k,
364 self.desc.d_v,
365 k_new_ptr,
366 v_new_ptr,
367 offsets_ptr,
368 k_cache_ptr,
369 v_cache_ptr,
370 core::ptr::null_mut(),
371 0,
372 stream_ptr,
373 )
374 },
375 _ => {
376 return Err(Error::Unsupported(
377 "baracuda-kernels::KvCacheAppendPlan::run reached an unimplemented dtype",
378 ));
379 }
380 };
381 map_status(status)
382 }
383}