Skip to main content

baracuda_kernels/reduce/
arg_axis.rs

1//! Argmax / Argmin single-axis reduction.
2//!
3//! A separate plan type from [`crate::ReducePlan`] because the output dtype
4//! differs from the input dtype: input is `T: Element` (value), output
5//! is `I: IndexOutputElement` (defaults to `i64` — PyTorch convention).
6//!
7//! **When to use**: forward argmax / argmin. No backward — `argmax` /
8//! `argmin` are non-differentiable (gradient is zero almost everywhere).
9//!
10//! **Dtypes / shape**: `{Argmax, Argmin} × {f32, f16, bf16, f64}` value
11//! input × `{u32, i32, i64}` index output; tensor rank `1..=8`; reduce
12//! axis must be non-empty.
13//!
14//! **Tie-breaking**: returns the first-occurrence index along the
15//! reduce axis (PyTorch convention).
16//!
17//! **Workspace**: none.
18//!
19//! **Precision**: deterministic, bit-stable on the same hardware (one-
20//! thread-per-output-cell sequential scan over the reduce axis).
21//!
22//! Phase 12.2 (Fuel team feedback): output index dtype is now generic
23//! over [`IndexOutputElement`] (`u32` / `i32` / `i64`). The legacy
24//! default is `i64` so pre-Phase-12.2 callers compile unchanged; opt
25//! into `u32` / `i32` via the third type parameter, e.g.
26//! `ArgReducePlan::<f32, 3, u32>::select(...)`.
27
28use core::ffi::c_void;
29use core::marker::PhantomData;
30
31use baracuda_cutlass::{Error, Result};
32use baracuda_driver::Stream;
33use baracuda_kernels_types::{
34    ArchSku, ArgReduceKind, BackendKind, Element, ElementKind, IndexOutputElement,
35    IndexOutputKind, KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee,
36    TensorMut, TensorRef, Workspace,
37};
38
39/// Descriptor for an argmax / argmin axis reduction.
40#[derive(Copy, Clone, Debug)]
41pub struct ArgReduceDescriptor<const N: usize> {
42    /// Which arg-reduction to apply.
43    pub kind: ArgReduceKind,
44    /// Input tensor shape.
45    pub input_shape: [i32; N],
46    /// Axis to reduce along.
47    pub reduce_axis: u8,
48    /// Input value element type.
49    pub element: ElementKind,
50}
51
52impl<const N: usize> ArgReduceDescriptor<N> {
53    /// Compute the output shape: input shape with reduce axis = 1.
54    pub fn output_shape(&self) -> [i32; N] {
55        let mut out = self.input_shape;
56        out[self.reduce_axis as usize] = 1;
57        out
58    }
59}
60
61/// Args bundle for an arg-reduction launch.
62///
63/// Note the asymmetric dtypes: `x` is the value dtype `T`, `y` is the
64/// index dtype `I` (defaults to `i64` — PyTorch convention).
65pub struct ArgReduceArgs<'a, T: Element, const N: usize, I: IndexOutputElement = i64> {
66    /// Input.
67    pub x: TensorRef<'a, T, N>,
68    /// Output indices — shape matches input with reduce axis = 1. Type
69    /// parameter `I` selects `u32`, `i32`, or `i64` (default).
70    pub y: TensorMut<'a, I, N>,
71}
72
73/// Arg-reduce plan (argmax / argmin) — see module docs for dtypes,
74/// tie-breaking, and precision.
75///
76/// `T: Element` is the value (input) dtype; `I: IndexOutputElement` is
77/// the output index dtype (defaults to `i64`). `const N: usize` is the
78/// tensor rank (1..=8).
79///
80/// The `I = i64` default preserves source-compat for pre-Phase-12.2
81/// callers; new callers opt into narrower output dtypes via
82/// `ArgReducePlan::<T, N, u32>::select(...)` or `<T, N, i32>`.
83pub struct ArgReducePlan<T: Element, const N: usize, I: IndexOutputElement = i64> {
84    desc: ArgReduceDescriptor<N>,
85    sku: KernelSku,
86    _marker: PhantomData<(T, I)>,
87}
88
89impl<T: Element, const N: usize, I: IndexOutputElement> ArgReducePlan<T, N, I> {
90    /// Pick a kernel for `desc`.
91    pub fn select(
92        _stream: &Stream,
93        desc: &ArgReduceDescriptor<N>,
94        _pref: PlanPreference,
95    ) -> Result<Self> {
96        if desc.element != T::KIND {
97            return Err(Error::Unsupported(
98                "baracuda-kernels::ArgReducePlan: descriptor element != type parameter T",
99            ));
100        }
101        if (desc.reduce_axis as usize) >= N {
102            return Err(Error::InvalidProblem(
103                "baracuda-kernels::ArgReducePlan: reduce_axis must be < rank",
104            ));
105        }
106        for &d in desc.input_shape.iter() {
107            if d < 0 {
108                return Err(Error::InvalidProblem(
109                    "baracuda-kernels::ArgReducePlan: input_shape dims must be non-negative",
110                ));
111            }
112        }
113        if desc.input_shape[desc.reduce_axis as usize] <= 0 {
114            return Err(Error::InvalidProblem(
115                "baracuda-kernels::ArgReducePlan: cannot arg-reduce over an empty axis",
116            ));
117        }
118        let supported = matches!(
119            T::KIND,
120            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
121        );
122        if !supported {
123            return Err(Error::Unsupported(
124                "baracuda-kernels::ArgReducePlan: today only `f32`, `f16`, `bf16`, `f64` \
125                 value dtypes are wired; other dtypes land in future fanout",
126            ));
127        }
128        let precision_guarantee = PrecisionGuarantee {
129            math_precision: MathPrecision::F32,
130            accumulator: ElementKind::F32,
131            bit_stable_on_same_hardware: true,
132            deterministic: true,
133        };
134        // Distinguish the three output-dtype SKUs via `aux_element`.
135        // `ElementKind` has I32 / I64 variants; for u32 we fall back to
136        // `None` (the only output dtype without a matching ElementKind
137        // — kernel selection is still uniquely keyed by `I::KIND` in
138        // `run`, this tag is informational).
139        let aux_element = match I::KIND {
140            IndexOutputKind::U32 => None,
141            IndexOutputKind::I32 => Some(ElementKind::I32),
142            IndexOutputKind::I64 => Some(ElementKind::I64),
143            // Defensive arm — `IndexOutputKind` is `#[non_exhaustive]`,
144            // so unrecognized variants surface as a `None` aux tag
145            // until a wired case is added.
146            _ => None,
147        };
148        let sku = KernelSku {
149            category: OpCategory::Reduction,
150            op: desc.kind as u16,
151            element: T::KIND,
152            aux_element,
153            layout: None,
154            epilogue: None,
155            arch: ArchSku::Sm80,
156            backend: BackendKind::Bespoke,
157            precision_guarantee,
158        };
159        Ok(Self {
160            desc: *desc,
161            sku,
162            _marker: PhantomData,
163        })
164    }
165
166    /// Validate args.
167    pub fn can_implement(&self, args: &ArgReduceArgs<'_, T, N, I>) -> Result<()> {
168        if args.x.shape != self.desc.input_shape {
169            return Err(Error::InvalidProblem(
170                "baracuda-kernels::ArgReducePlan: X shape mismatch with descriptor",
171            ));
172        }
173        let expected_out = self.desc.output_shape();
174        if args.y.shape != expected_out {
175            return Err(Error::InvalidProblem(
176                "baracuda-kernels::ArgReducePlan: Y shape mismatch with derived output \
177                 shape (input shape with reduce_axis collapsed to 1)",
178            ));
179        }
180        if N > 8 {
181            return Err(Error::Unsupported(
182                "baracuda-kernels::ArgReducePlan: tensor rank > 8 not supported",
183            ));
184        }
185        let y_numel = args.y.numel();
186        let x_numel = args.x.numel();
187        let x_len = args.x.data.len() as i64;
188        let y_len = args.y.data.len() as i64;
189        if y_len < y_numel {
190            return Err(Error::BufferTooSmall {
191                needed: y_numel as usize,
192                got: y_len as usize,
193            });
194        }
195        if x_len < x_numel {
196            return Err(Error::BufferTooSmall {
197                needed: x_numel as usize,
198                got: x_len as usize,
199            });
200        }
201        Ok(())
202    }
203
204    /// Workspace size in bytes.
205    #[inline]
206    pub fn workspace_size(&self) -> usize {
207        0
208    }
209    /// Identity of the kernel this plan picked.
210    #[inline]
211    pub fn sku(&self) -> KernelSku {
212        self.sku
213    }
214    /// Numerical guarantees.
215    #[inline]
216    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
217        self.sku.precision_guarantee
218    }
219
220    /// Launch.
221    pub fn run(
222        &self,
223        stream: &Stream,
224        _workspace: Workspace<'_>,
225        args: ArgReduceArgs<'_, T, N, I>,
226    ) -> Result<()> {
227        self.can_implement(&args)?;
228        let output_numel = args.y.numel();
229        if output_numel == 0 {
230            return Ok(());
231        }
232        let x_ptr = args.x.data.as_raw().0 as *const c_void;
233        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
234        let stream_ptr = stream.as_raw() as *mut c_void;
235
236        let output_shape = self.desc.output_shape();
237        let stride_x = args.x.stride;
238        let stride_y = args.y.stride;
239        let rank = N as i32;
240        let reduce_axis = self.desc.reduce_axis as i32;
241        let reduce_extent = self.desc.input_shape[self.desc.reduce_axis as usize];
242        let reduce_stride_x = args.x.stride[self.desc.reduce_axis as usize];
243
244        let status = match (self.desc.kind, T::KIND, I::KIND) {
245            // -----------------------------------------------------------------
246            // i64 output (legacy / default).
247            // -----------------------------------------------------------------
248            (ArgReduceKind::Argmax, ElementKind::F32, IndexOutputKind::I64) => unsafe {
249                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f32_run(
250                    output_numel, rank, output_shape.as_ptr(),
251                    stride_x.as_ptr(), stride_y.as_ptr(),
252                    reduce_axis, reduce_extent, reduce_stride_x,
253                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
254                )
255            },
256            (ArgReduceKind::Argmin, ElementKind::F32, IndexOutputKind::I64) => unsafe {
257                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f32_run(
258                    output_numel, rank, output_shape.as_ptr(),
259                    stride_x.as_ptr(), stride_y.as_ptr(),
260                    reduce_axis, reduce_extent, reduce_stride_x,
261                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
262                )
263            },
264            (ArgReduceKind::Argmax, ElementKind::F16, IndexOutputKind::I64) => unsafe {
265                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f16_run(
266                    output_numel, rank, output_shape.as_ptr(),
267                    stride_x.as_ptr(), stride_y.as_ptr(),
268                    reduce_axis, reduce_extent, reduce_stride_x,
269                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
270                )
271            },
272            (ArgReduceKind::Argmin, ElementKind::F16, IndexOutputKind::I64) => unsafe {
273                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f16_run(
274                    output_numel, rank, output_shape.as_ptr(),
275                    stride_x.as_ptr(), stride_y.as_ptr(),
276                    reduce_axis, reduce_extent, reduce_stride_x,
277                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
278                )
279            },
280            (ArgReduceKind::Argmax, ElementKind::Bf16, IndexOutputKind::I64) => unsafe {
281                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_bf16_run(
282                    output_numel, rank, output_shape.as_ptr(),
283                    stride_x.as_ptr(), stride_y.as_ptr(),
284                    reduce_axis, reduce_extent, reduce_stride_x,
285                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
286                )
287            },
288            (ArgReduceKind::Argmin, ElementKind::Bf16, IndexOutputKind::I64) => unsafe {
289                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_bf16_run(
290                    output_numel, rank, output_shape.as_ptr(),
291                    stride_x.as_ptr(), stride_y.as_ptr(),
292                    reduce_axis, reduce_extent, reduce_stride_x,
293                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
294                )
295            },
296            (ArgReduceKind::Argmax, ElementKind::F64, IndexOutputKind::I64) => unsafe {
297                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f64_run(
298                    output_numel, rank, output_shape.as_ptr(),
299                    stride_x.as_ptr(), stride_y.as_ptr(),
300                    reduce_axis, reduce_extent, reduce_stride_x,
301                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
302                )
303            },
304            (ArgReduceKind::Argmin, ElementKind::F64, IndexOutputKind::I64) => unsafe {
305                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f64_run(
306                    output_numel, rank, output_shape.as_ptr(),
307                    stride_x.as_ptr(), stride_y.as_ptr(),
308                    reduce_axis, reduce_extent, reduce_stride_x,
309                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
310                )
311            },
312            // -----------------------------------------------------------------
313            // u32 output (Phase 12.2).
314            // -----------------------------------------------------------------
315            (ArgReduceKind::Argmax, ElementKind::F32, IndexOutputKind::U32) => unsafe {
316                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f32_u32_run(
317                    output_numel, rank, output_shape.as_ptr(),
318                    stride_x.as_ptr(), stride_y.as_ptr(),
319                    reduce_axis, reduce_extent, reduce_stride_x,
320                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
321                )
322            },
323            (ArgReduceKind::Argmin, ElementKind::F32, IndexOutputKind::U32) => unsafe {
324                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f32_u32_run(
325                    output_numel, rank, output_shape.as_ptr(),
326                    stride_x.as_ptr(), stride_y.as_ptr(),
327                    reduce_axis, reduce_extent, reduce_stride_x,
328                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
329                )
330            },
331            (ArgReduceKind::Argmax, ElementKind::F16, IndexOutputKind::U32) => unsafe {
332                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f16_u32_run(
333                    output_numel, rank, output_shape.as_ptr(),
334                    stride_x.as_ptr(), stride_y.as_ptr(),
335                    reduce_axis, reduce_extent, reduce_stride_x,
336                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
337                )
338            },
339            (ArgReduceKind::Argmin, ElementKind::F16, IndexOutputKind::U32) => unsafe {
340                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f16_u32_run(
341                    output_numel, rank, output_shape.as_ptr(),
342                    stride_x.as_ptr(), stride_y.as_ptr(),
343                    reduce_axis, reduce_extent, reduce_stride_x,
344                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
345                )
346            },
347            (ArgReduceKind::Argmax, ElementKind::Bf16, IndexOutputKind::U32) => unsafe {
348                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_bf16_u32_run(
349                    output_numel, rank, output_shape.as_ptr(),
350                    stride_x.as_ptr(), stride_y.as_ptr(),
351                    reduce_axis, reduce_extent, reduce_stride_x,
352                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
353                )
354            },
355            (ArgReduceKind::Argmin, ElementKind::Bf16, IndexOutputKind::U32) => unsafe {
356                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_bf16_u32_run(
357                    output_numel, rank, output_shape.as_ptr(),
358                    stride_x.as_ptr(), stride_y.as_ptr(),
359                    reduce_axis, reduce_extent, reduce_stride_x,
360                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
361                )
362            },
363            (ArgReduceKind::Argmax, ElementKind::F64, IndexOutputKind::U32) => unsafe {
364                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f64_u32_run(
365                    output_numel, rank, output_shape.as_ptr(),
366                    stride_x.as_ptr(), stride_y.as_ptr(),
367                    reduce_axis, reduce_extent, reduce_stride_x,
368                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
369                )
370            },
371            (ArgReduceKind::Argmin, ElementKind::F64, IndexOutputKind::U32) => unsafe {
372                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f64_u32_run(
373                    output_numel, rank, output_shape.as_ptr(),
374                    stride_x.as_ptr(), stride_y.as_ptr(),
375                    reduce_axis, reduce_extent, reduce_stride_x,
376                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
377                )
378            },
379            // -----------------------------------------------------------------
380            // i32 output (Phase 12.2).
381            // -----------------------------------------------------------------
382            (ArgReduceKind::Argmax, ElementKind::F32, IndexOutputKind::I32) => unsafe {
383                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f32_i32_run(
384                    output_numel, rank, output_shape.as_ptr(),
385                    stride_x.as_ptr(), stride_y.as_ptr(),
386                    reduce_axis, reduce_extent, reduce_stride_x,
387                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
388                )
389            },
390            (ArgReduceKind::Argmin, ElementKind::F32, IndexOutputKind::I32) => unsafe {
391                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f32_i32_run(
392                    output_numel, rank, output_shape.as_ptr(),
393                    stride_x.as_ptr(), stride_y.as_ptr(),
394                    reduce_axis, reduce_extent, reduce_stride_x,
395                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
396                )
397            },
398            (ArgReduceKind::Argmax, ElementKind::F16, IndexOutputKind::I32) => unsafe {
399                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f16_i32_run(
400                    output_numel, rank, output_shape.as_ptr(),
401                    stride_x.as_ptr(), stride_y.as_ptr(),
402                    reduce_axis, reduce_extent, reduce_stride_x,
403                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
404                )
405            },
406            (ArgReduceKind::Argmin, ElementKind::F16, IndexOutputKind::I32) => unsafe {
407                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f16_i32_run(
408                    output_numel, rank, output_shape.as_ptr(),
409                    stride_x.as_ptr(), stride_y.as_ptr(),
410                    reduce_axis, reduce_extent, reduce_stride_x,
411                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
412                )
413            },
414            (ArgReduceKind::Argmax, ElementKind::Bf16, IndexOutputKind::I32) => unsafe {
415                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_bf16_i32_run(
416                    output_numel, rank, output_shape.as_ptr(),
417                    stride_x.as_ptr(), stride_y.as_ptr(),
418                    reduce_axis, reduce_extent, reduce_stride_x,
419                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
420                )
421            },
422            (ArgReduceKind::Argmin, ElementKind::Bf16, IndexOutputKind::I32) => unsafe {
423                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_bf16_i32_run(
424                    output_numel, rank, output_shape.as_ptr(),
425                    stride_x.as_ptr(), stride_y.as_ptr(),
426                    reduce_axis, reduce_extent, reduce_stride_x,
427                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
428                )
429            },
430            (ArgReduceKind::Argmax, ElementKind::F64, IndexOutputKind::I32) => unsafe {
431                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmax_f64_i32_run(
432                    output_numel, rank, output_shape.as_ptr(),
433                    stride_x.as_ptr(), stride_y.as_ptr(),
434                    reduce_axis, reduce_extent, reduce_stride_x,
435                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
436                )
437            },
438            (ArgReduceKind::Argmin, ElementKind::F64, IndexOutputKind::I32) => unsafe {
439                baracuda_kernels_sys::baracuda_kernels_arg_reduce_argmin_f64_i32_run(
440                    output_numel, rank, output_shape.as_ptr(),
441                    stride_x.as_ptr(), stride_y.as_ptr(),
442                    reduce_axis, reduce_extent, reduce_stride_x,
443                    x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
444                )
445            },
446            _ => {
447                return Err(Error::Unsupported(
448                    "baracuda-kernels::ArgReducePlan::run: only `{Argmax,Argmin} × \
449                     {f32,f16,bf16,f64} × {u32,i32,i64}` wired today",
450                ));
451            }
452        };
453        map_status(status)
454    }
455}
456
457fn map_status(code: i32) -> Result<()> {
458    match code {
459        0 => Ok(()),
460        1 => Err(Error::MisalignedOperand),
461        2 => Err(Error::InvalidProblem(
462            "baracuda-kernels-sys reported invalid problem",
463        )),
464        3 => Err(Error::Unsupported(
465            "baracuda-kernels-sys reported unsupported configuration",
466        )),
467        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
468        n => Err(Error::CutlassInternal(n)),
469    }
470}