1use core::ffi::c_void;
55use core::marker::PhantomData;
56
57use baracuda_cutlass::{Error, Result};
58use baracuda_driver::Stream;
59use baracuda_kernels_types::{
60 ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
61 OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
62};
63
64
65#[derive(Copy, Clone, Debug)]
68pub struct PagedKvCacheDescriptor {
69 pub page_size: i32,
71 pub num_total_pages: i32,
73 pub num_kv_heads: i32,
75 pub head_dim: i32,
77 pub element: ElementKind,
79}
80
81#[derive(Copy, Clone, Debug)]
83pub struct BatchPagedDecodeDescriptor {
84 pub batch_size: i32,
86 pub num_qo_heads: i32,
88 pub sm_scale: f32,
90 pub paged_kv: PagedKvCacheDescriptor,
92}
93
94pub struct BatchPagedDecodeArgs<'a, T: Element> {
96 pub q: TensorRef<'a, T, 3>,
98 pub k_data: TensorRef<'a, T, 4>,
102 pub v_data: TensorRef<'a, T, 4>,
104 pub indices: TensorRef<'a, i32, 1>,
106 pub indptr: TensorRef<'a, i32, 1>,
108 pub last_page_len: TensorRef<'a, i32, 1>,
110 pub o: TensorMut<'a, T, 3>,
112 pub lse: TensorMut<'a, f32, 2>,
114}
115
116pub struct BatchPagedDecodePlan<T: Element> {
128 desc: BatchPagedDecodeDescriptor,
129 sku: KernelSku,
130 _marker: PhantomData<T>,
131}
132
133impl<T: Element> BatchPagedDecodePlan<T> {
134 pub fn select(
136 _stream: &Stream,
137 desc: &BatchPagedDecodeDescriptor,
138 _pref: PlanPreference,
139 ) -> Result<Self> {
140 if desc.paged_kv.element != T::KIND {
141 return Err(Error::Unsupported(
142 "BatchPagedDecodePlan: descriptor element != T",
143 ));
144 }
145 if desc.batch_size <= 0
146 || desc.num_qo_heads <= 0
147 || desc.paged_kv.num_kv_heads <= 0
148 || desc.paged_kv.page_size <= 0
149 || desc.paged_kv.num_total_pages <= 0
150 {
151 return Err(Error::InvalidProblem(
152 "BatchPagedDecodePlan: extents must be positive",
153 ));
154 }
155 if desc.num_qo_heads % desc.paged_kv.num_kv_heads != 0 {
156 return Err(Error::InvalidProblem(
157 "BatchPagedDecodePlan: num_qo_heads must be a multiple of num_kv_heads (GQA group size must be integer)",
158 ));
159 }
160 let head_dim = desc.paged_kv.head_dim;
161 if !matches!(head_dim, 64 | 128 | 256) {
162 return Err(Error::Unsupported(
163 "BatchPagedDecodePlan: head_dim must be 64, 128, or 256",
164 ));
165 }
166 if !matches!(T::KIND, ElementKind::F16 | ElementKind::Bf16 | ElementKind::F32) {
167 return Err(Error::Unsupported(
168 "BatchPagedDecodePlan: element type must be f16, bf16, or f32",
169 ));
170 }
171 let precision_guarantee = PrecisionGuarantee {
172 math_precision: MathPrecision::F32,
173 accumulator: ElementKind::F32,
174 bit_stable_on_same_hardware: true,
175 deterministic: true,
176 };
177 let sku = KernelSku {
178 category: OpCategory::Attention,
179 op: AttentionKind::PagedAttention as u16,
180 element: T::KIND,
181 aux_element: None,
182 layout: None,
183 epilogue: None,
184 arch: ArchSku::Sm80,
185 backend: BackendKind::FlashInfer,
186 precision_guarantee,
187 };
188 Ok(Self {
189 desc: *desc,
190 sku,
191 _marker: PhantomData,
192 })
193 }
194
195 pub fn can_implement(&self, args: &BatchPagedDecodeArgs<'_, T>) -> Result<()> {
197 let q_shape = [
198 self.desc.batch_size,
199 self.desc.num_qo_heads,
200 self.desc.paged_kv.head_dim,
201 ];
202 if args.q.shape != q_shape {
203 return Err(Error::InvalidProblem("BatchPagedDecodePlan: q shape mismatch"));
204 }
205 let cache_shape = [
206 self.desc.paged_kv.num_total_pages,
207 self.desc.paged_kv.num_kv_heads,
208 self.desc.paged_kv.page_size,
209 self.desc.paged_kv.head_dim,
210 ];
211 if args.k_data.shape != cache_shape || args.v_data.shape != cache_shape {
212 return Err(Error::InvalidProblem(
213 "BatchPagedDecodePlan: k_data/v_data shape mismatch",
214 ));
215 }
216 if args.indptr.shape != [self.desc.batch_size + 1] {
217 return Err(Error::InvalidProblem(
218 "BatchPagedDecodePlan: indptr shape must be [batch + 1]",
219 ));
220 }
221 if args.last_page_len.shape != [self.desc.batch_size] {
222 return Err(Error::InvalidProblem(
223 "BatchPagedDecodePlan: last_page_len shape must be [batch]",
224 ));
225 }
226 if args.o.shape != q_shape {
227 return Err(Error::InvalidProblem("BatchPagedDecodePlan: o shape mismatch"));
228 }
229 if args.lse.shape != [self.desc.batch_size, self.desc.num_qo_heads] {
230 return Err(Error::InvalidProblem(
231 "BatchPagedDecodePlan: lse shape must be [batch, num_qo_heads]",
232 ));
233 }
234 if !args.q.is_contiguous()
235 || !args.k_data.is_contiguous()
236 || !args.v_data.is_contiguous()
237 || !args.o.is_contiguous()
238 || !args.lse.is_contiguous()
239 {
240 return Err(Error::Unsupported(
241 "BatchPagedDecodePlan: tensors must be contiguous (Tier 1)",
242 ));
243 }
244 Ok(())
245 }
246
247 #[inline]
250 pub fn workspace_size(&self) -> usize {
251 ((3 * self.desc.batch_size as usize) + 2) * core::mem::size_of::<i32>()
254 }
255
256 #[inline]
258 pub fn sku(&self) -> KernelSku {
259 self.sku
260 }
261
262 #[inline]
264 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
265 self.sku.precision_guarantee
266 }
267
268 pub fn run(
271 &self,
272 stream: &Stream,
273 workspace: Workspace<'_>,
274 args: BatchPagedDecodeArgs<'_, T>,
275 ) -> Result<()> {
276 self.can_implement(&args)?;
277 let need = self.workspace_size();
278 let (ws_ptr, ws_bytes) = match workspace {
279 Workspace::None => {
280 return Err(Error::WorkspaceTooSmall { needed: need, got: 0 });
281 }
282 Workspace::Borrowed(slice) => {
283 if slice.len() < need {
284 return Err(Error::WorkspaceTooSmall {
285 needed: need,
286 got: slice.len(),
287 });
288 }
289 (slice.as_raw().0 as *mut c_void, slice.len())
290 }
291 };
292 #[cfg(not(feature = "flashinfer"))]
293 {
294 let _ = (stream, ws_ptr, ws_bytes, &args);
295 Err(Error::Unsupported(
296 "BatchPagedDecodePlan: `flashinfer` cargo feature is not enabled",
297 ))
298 }
299 #[cfg(feature = "flashinfer")]
300 {
301 let stream_ptr = stream.as_raw() as *mut c_void;
302 let q_ptr = args.q.data.as_raw().0 as *const c_void;
303 let k_ptr = args.k_data.data.as_raw().0 as *mut c_void;
304 let v_ptr = args.v_data.data.as_raw().0 as *mut c_void;
305 let indices_ptr = args.indices.data.as_raw().0 as *mut c_void;
306 let indptr_ptr = args.indptr.data.as_raw().0 as *mut c_void;
307 let last_page_len_ptr = args.last_page_len.data.as_raw().0 as *mut c_void;
308 let o_ptr = args.o.data.as_raw().0 as *mut c_void;
309 let lse_ptr = args.lse.data.as_raw().0 as *mut c_void;
310
311 let status = match T::KIND {
312 ElementKind::F16 => unsafe {
313 baracuda_kernels_sys::baracuda_kernels_flashinfer_paged_decode_f16_run(
314 self.desc.batch_size,
315 self.desc.paged_kv.page_size,
316 self.desc.paged_kv.head_dim,
317 self.desc.num_qo_heads,
318 self.desc.paged_kv.num_kv_heads,
319 self.desc.sm_scale,
320 k_ptr, v_ptr, indices_ptr, indptr_ptr, last_page_len_ptr,
321 q_ptr, o_ptr, lse_ptr,
322 ws_ptr, ws_bytes, stream_ptr,
323 )
324 },
325 ElementKind::Bf16 => unsafe {
326 baracuda_kernels_sys::baracuda_kernels_flashinfer_paged_decode_bf16_run(
327 self.desc.batch_size,
328 self.desc.paged_kv.page_size,
329 self.desc.paged_kv.head_dim,
330 self.desc.num_qo_heads,
331 self.desc.paged_kv.num_kv_heads,
332 self.desc.sm_scale,
333 k_ptr, v_ptr, indices_ptr, indptr_ptr, last_page_len_ptr,
334 q_ptr, o_ptr, lse_ptr,
335 ws_ptr, ws_bytes, stream_ptr,
336 )
337 },
338 ElementKind::F32 => unsafe {
339 baracuda_kernels_sys::baracuda_kernels_flashinfer_paged_decode_f32_run(
340 self.desc.batch_size,
341 self.desc.paged_kv.page_size,
342 self.desc.paged_kv.head_dim,
343 self.desc.num_qo_heads,
344 self.desc.paged_kv.num_kv_heads,
345 self.desc.sm_scale,
346 k_ptr, v_ptr, indices_ptr, indptr_ptr, last_page_len_ptr,
347 q_ptr, o_ptr, lse_ptr,
348 ws_ptr, ws_bytes, stream_ptr,
349 )
350 },
351 _ => {
352 return Err(Error::Unsupported(
353 "BatchPagedDecodePlan::run reached an unimplemented dtype",
354 ));
355 }
356 };
357 map_status(status)
358 }
359 }
360}