1use core::ffi::c_void;
29use core::marker::PhantomData;
30
31use baracuda_cutlass::{Error, Result};
32use baracuda_driver::Stream;
33use baracuda_kernels_types::{
34 ArchSku, ArgReduceKind, BackendKind, Element, ElementKind, IndexOutputElement,
35 IndexOutputKind, KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee,
36 TensorMut, TensorRef, Workspace,
37};
38
39#[derive(Copy, Clone, Debug)]
41pub struct ArgReduceDescriptor<const N: usize> {
42 pub kind: ArgReduceKind,
44 pub input_shape: [i32; N],
46 pub reduce_axis: u8,
48 pub element: ElementKind,
50}
51
52impl<const N: usize> ArgReduceDescriptor<N> {
53 pub fn output_shape(&self) -> [i32; N] {
55 let mut out = self.input_shape;
56 out[self.reduce_axis as usize] = 1;
57 out
58 }
59}
60
61pub struct ArgReduceArgs<'a, T: Element, const N: usize, I: IndexOutputElement = i64> {
66 pub x: TensorRef<'a, T, N>,
68 pub y: TensorMut<'a, I, N>,
71}
72
73pub struct ArgReducePlan<T: Element, const N: usize, I: IndexOutputElement = i64> {
84 desc: ArgReduceDescriptor<N>,
85 sku: KernelSku,
86 _marker: PhantomData<(T, I)>,
87}
88
89impl<T: Element, const N: usize, I: IndexOutputElement> ArgReducePlan<T, N, I> {
90 pub fn select(
92 _stream: &Stream,
93 desc: &ArgReduceDescriptor<N>,
94 _pref: PlanPreference,
95 ) -> Result<Self> {
96 if desc.element != T::KIND {
97 return Err(Error::Unsupported(
98 "baracuda-kernels::ArgReducePlan: descriptor element != type parameter T",
99 ));
100 }
101 if (desc.reduce_axis as usize) >= N {
102 return Err(Error::InvalidProblem(
103 "baracuda-kernels::ArgReducePlan: reduce_axis must be < rank",
104 ));
105 }
106 for &d in desc.input_shape.iter() {
107 if d < 0 {
108 return Err(Error::InvalidProblem(
109 "baracuda-kernels::ArgReducePlan: input_shape dims must be non-negative",
110 ));
111 }
112 }
113 if desc.input_shape[desc.reduce_axis as usize] <= 0 {
114 return Err(Error::InvalidProblem(
115 "baracuda-kernels::ArgReducePlan: cannot arg-reduce over an empty axis",
116 ));
117 }
118 let supported = matches!(
119 T::KIND,
120 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
121 );
122 if !supported {
123 return Err(Error::Unsupported(
124 "baracuda-kernels::ArgReducePlan: today only `f32`, `f16`, `bf16`, `f64` \
125 value dtypes are wired; other dtypes land in future fanout",
126 ));
127 }
128 let precision_guarantee = PrecisionGuarantee {
129 math_precision: MathPrecision::F32,
130 accumulator: ElementKind::F32,
131 bit_stable_on_same_hardware: true,
132 deterministic: true,
133 };
134 let aux_element = match I::KIND {
140 IndexOutputKind::U32 => None,
141 IndexOutputKind::I32 => Some(ElementKind::I32),
142 IndexOutputKind::I64 => Some(ElementKind::I64),
143 _ => None,
147 };
148 let sku = KernelSku {
149 category: OpCategory::Reduction,
150 op: desc.kind as u16,
151 element: T::KIND,
152 aux_element,
153 layout: None,
154 epilogue: None,
155 arch: ArchSku::Sm80,
156 backend: BackendKind::Bespoke,
157 precision_guarantee,
158 };
159 Ok(Self {
160 desc: *desc,
161 sku,
162 _marker: PhantomData,
163 })
164 }
165
166 pub fn can_implement(&self, args: &ArgReduceArgs<'_, T, N, I>) -> Result<()> {
168 if args.x.shape != self.desc.input_shape {
169 return Err(Error::InvalidProblem(
170 "baracuda-kernels::ArgReducePlan: X shape mismatch with descriptor",
171 ));
172 }
173 let expected_out = self.desc.output_shape();
174 if args.y.shape != expected_out {
175 return Err(Error::InvalidProblem(
176 "baracuda-kernels::ArgReducePlan: Y shape mismatch with derived output \
177 shape (input shape with reduce_axis collapsed to 1)",
178 ));
179 }
180 if N > 8 {
181 return Err(Error::Unsupported(
182 "baracuda-kernels::ArgReducePlan: tensor rank > 8 not supported",
183 ));
184 }
185 let y_numel = args.y.numel();
186 let x_numel = args.x.numel();
187 let x_len = args.x.data.len() as i64;
188 let y_len = args.y.data.len() as i64;
189 if y_len < y_numel {
190 return Err(Error::BufferTooSmall {
191 needed: y_numel as usize,
192 got: y_len as usize,
193 });
194 }
195 if x_len < x_numel {
196 return Err(Error::BufferTooSmall {
197 needed: x_numel as usize,
198 got: x_len as usize,
199 });
200 }
201 Ok(())
202 }
203
204 #[inline]
206 pub fn workspace_size(&self) -> usize {
207 0
208 }
209 #[inline]
211 pub fn sku(&self) -> KernelSku {
212 self.sku
213 }
214 #[inline]
216 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
217 self.sku.precision_guarantee
218 }
219
220 pub fn run(
222 &self,
223 stream: &Stream,
224 _workspace: Workspace<'_>,
225 args: ArgReduceArgs<'_, T, N, I>,
226 ) -> Result<()> {
227 self.can_implement(&args)?;
228 let output_numel = args.y.numel();
229 if output_numel == 0 {
230 return Ok(());
231 }
232 let x_ptr = args.x.data.as_raw().0 as *const c_void;
233 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
234 let stream_ptr = stream.as_raw() as *mut c_void;
235
236 let output_shape = self.desc.output_shape();
237 let stride_x = args.x.stride;
238 let stride_y = args.y.stride;
239 let rank = N as i32;
240 let reduce_axis = self.desc.reduce_axis as i32;
241 let reduce_extent = self.desc.input_shape[self.desc.reduce_axis as usize];
242 let reduce_stride_x = args.x.stride[self.desc.reduce_axis as usize];
243
244 let status = match (self.desc.kind, T::KIND, I::KIND) {
245 (ArgReduceKind::Argmax, ElementKind::F32, IndexOutputKind::I64) => unsafe {
249 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f32_run(
250 output_numel, rank, output_shape.as_ptr(),
251 stride_x.as_ptr(), stride_y.as_ptr(),
252 reduce_axis, reduce_extent, reduce_stride_x,
253 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
254 )
255 },
256 (ArgReduceKind::Argmin, ElementKind::F32, IndexOutputKind::I64) => unsafe {
257 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f32_run(
258 output_numel, rank, output_shape.as_ptr(),
259 stride_x.as_ptr(), stride_y.as_ptr(),
260 reduce_axis, reduce_extent, reduce_stride_x,
261 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
262 )
263 },
264 (ArgReduceKind::Argmax, ElementKind::F16, IndexOutputKind::I64) => unsafe {
265 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f16_run(
266 output_numel, rank, output_shape.as_ptr(),
267 stride_x.as_ptr(), stride_y.as_ptr(),
268 reduce_axis, reduce_extent, reduce_stride_x,
269 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
270 )
271 },
272 (ArgReduceKind::Argmin, ElementKind::F16, IndexOutputKind::I64) => unsafe {
273 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f16_run(
274 output_numel, rank, output_shape.as_ptr(),
275 stride_x.as_ptr(), stride_y.as_ptr(),
276 reduce_axis, reduce_extent, reduce_stride_x,
277 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
278 )
279 },
280 (ArgReduceKind::Argmax, ElementKind::Bf16, IndexOutputKind::I64) => unsafe {
281 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_bf16_run(
282 output_numel, rank, output_shape.as_ptr(),
283 stride_x.as_ptr(), stride_y.as_ptr(),
284 reduce_axis, reduce_extent, reduce_stride_x,
285 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
286 )
287 },
288 (ArgReduceKind::Argmin, ElementKind::Bf16, IndexOutputKind::I64) => unsafe {
289 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_bf16_run(
290 output_numel, rank, output_shape.as_ptr(),
291 stride_x.as_ptr(), stride_y.as_ptr(),
292 reduce_axis, reduce_extent, reduce_stride_x,
293 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
294 )
295 },
296 (ArgReduceKind::Argmax, ElementKind::F64, IndexOutputKind::I64) => unsafe {
297 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f64_run(
298 output_numel, rank, output_shape.as_ptr(),
299 stride_x.as_ptr(), stride_y.as_ptr(),
300 reduce_axis, reduce_extent, reduce_stride_x,
301 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
302 )
303 },
304 (ArgReduceKind::Argmin, ElementKind::F64, IndexOutputKind::I64) => unsafe {
305 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f64_run(
306 output_numel, rank, output_shape.as_ptr(),
307 stride_x.as_ptr(), stride_y.as_ptr(),
308 reduce_axis, reduce_extent, reduce_stride_x,
309 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
310 )
311 },
312 (ArgReduceKind::Argmax, ElementKind::F32, IndexOutputKind::U32) => unsafe {
316 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f32_u32_run(
317 output_numel, rank, output_shape.as_ptr(),
318 stride_x.as_ptr(), stride_y.as_ptr(),
319 reduce_axis, reduce_extent, reduce_stride_x,
320 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
321 )
322 },
323 (ArgReduceKind::Argmin, ElementKind::F32, IndexOutputKind::U32) => unsafe {
324 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f32_u32_run(
325 output_numel, rank, output_shape.as_ptr(),
326 stride_x.as_ptr(), stride_y.as_ptr(),
327 reduce_axis, reduce_extent, reduce_stride_x,
328 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
329 )
330 },
331 (ArgReduceKind::Argmax, ElementKind::F16, IndexOutputKind::U32) => unsafe {
332 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f16_u32_run(
333 output_numel, rank, output_shape.as_ptr(),
334 stride_x.as_ptr(), stride_y.as_ptr(),
335 reduce_axis, reduce_extent, reduce_stride_x,
336 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
337 )
338 },
339 (ArgReduceKind::Argmin, ElementKind::F16, IndexOutputKind::U32) => unsafe {
340 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f16_u32_run(
341 output_numel, rank, output_shape.as_ptr(),
342 stride_x.as_ptr(), stride_y.as_ptr(),
343 reduce_axis, reduce_extent, reduce_stride_x,
344 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
345 )
346 },
347 (ArgReduceKind::Argmax, ElementKind::Bf16, IndexOutputKind::U32) => unsafe {
348 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_bf16_u32_run(
349 output_numel, rank, output_shape.as_ptr(),
350 stride_x.as_ptr(), stride_y.as_ptr(),
351 reduce_axis, reduce_extent, reduce_stride_x,
352 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
353 )
354 },
355 (ArgReduceKind::Argmin, ElementKind::Bf16, IndexOutputKind::U32) => unsafe {
356 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_bf16_u32_run(
357 output_numel, rank, output_shape.as_ptr(),
358 stride_x.as_ptr(), stride_y.as_ptr(),
359 reduce_axis, reduce_extent, reduce_stride_x,
360 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
361 )
362 },
363 (ArgReduceKind::Argmax, ElementKind::F64, IndexOutputKind::U32) => unsafe {
364 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f64_u32_run(
365 output_numel, rank, output_shape.as_ptr(),
366 stride_x.as_ptr(), stride_y.as_ptr(),
367 reduce_axis, reduce_extent, reduce_stride_x,
368 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
369 )
370 },
371 (ArgReduceKind::Argmin, ElementKind::F64, IndexOutputKind::U32) => unsafe {
372 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f64_u32_run(
373 output_numel, rank, output_shape.as_ptr(),
374 stride_x.as_ptr(), stride_y.as_ptr(),
375 reduce_axis, reduce_extent, reduce_stride_x,
376 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
377 )
378 },
379 (ArgReduceKind::Argmax, ElementKind::F32, IndexOutputKind::I32) => unsafe {
383 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f32_i32_run(
384 output_numel, rank, output_shape.as_ptr(),
385 stride_x.as_ptr(), stride_y.as_ptr(),
386 reduce_axis, reduce_extent, reduce_stride_x,
387 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
388 )
389 },
390 (ArgReduceKind::Argmin, ElementKind::F32, IndexOutputKind::I32) => unsafe {
391 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f32_i32_run(
392 output_numel, rank, output_shape.as_ptr(),
393 stride_x.as_ptr(), stride_y.as_ptr(),
394 reduce_axis, reduce_extent, reduce_stride_x,
395 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
396 )
397 },
398 (ArgReduceKind::Argmax, ElementKind::F16, IndexOutputKind::I32) => unsafe {
399 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f16_i32_run(
400 output_numel, rank, output_shape.as_ptr(),
401 stride_x.as_ptr(), stride_y.as_ptr(),
402 reduce_axis, reduce_extent, reduce_stride_x,
403 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
404 )
405 },
406 (ArgReduceKind::Argmin, ElementKind::F16, IndexOutputKind::I32) => unsafe {
407 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f16_i32_run(
408 output_numel, rank, output_shape.as_ptr(),
409 stride_x.as_ptr(), stride_y.as_ptr(),
410 reduce_axis, reduce_extent, reduce_stride_x,
411 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
412 )
413 },
414 (ArgReduceKind::Argmax, ElementKind::Bf16, IndexOutputKind::I32) => unsafe {
415 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_bf16_i32_run(
416 output_numel, rank, output_shape.as_ptr(),
417 stride_x.as_ptr(), stride_y.as_ptr(),
418 reduce_axis, reduce_extent, reduce_stride_x,
419 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
420 )
421 },
422 (ArgReduceKind::Argmin, ElementKind::Bf16, IndexOutputKind::I32) => unsafe {
423 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_bf16_i32_run(
424 output_numel, rank, output_shape.as_ptr(),
425 stride_x.as_ptr(), stride_y.as_ptr(),
426 reduce_axis, reduce_extent, reduce_stride_x,
427 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
428 )
429 },
430 (ArgReduceKind::Argmax, ElementKind::F64, IndexOutputKind::I32) => unsafe {
431 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f64_i32_run(
432 output_numel, rank, output_shape.as_ptr(),
433 stride_x.as_ptr(), stride_y.as_ptr(),
434 reduce_axis, reduce_extent, reduce_stride_x,
435 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
436 )
437 },
438 (ArgReduceKind::Argmin, ElementKind::F64, IndexOutputKind::I32) => unsafe {
439 baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f64_i32_run(
440 output_numel, rank, output_shape.as_ptr(),
441 stride_x.as_ptr(), stride_y.as_ptr(),
442 reduce_axis, reduce_extent, reduce_stride_x,
443 x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
444 )
445 },
446 _ => {
447 return Err(Error::Unsupported(
448 "baracuda-kernels::ArgReducePlan::run: only `{Argmax,Argmin} × \
449 {f32,f16,bf16,f64} × {u32,i32,i64}` wired today",
450 ));
451 }
452 };
453 map_status(status)
454 }
455}
456
457fn map_status(code: i32) -> Result<()> {
458 match code {
459 0 => Ok(()),
460 1 => Err(Error::MisalignedOperand),
461 2 => Err(Error::InvalidProblem(
462 "baracuda-kernels-sys reported invalid problem",
463 )),
464 3 => Err(Error::Unsupported(
465 "baracuda-kernels-sys reported unsupported configuration",
466 )),
467 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
468 n => Err(Error::CutlassInternal(n)),
469 }
470}