Skip to main content

baracuda_kernels/reduce/
axis.rs

1//! Single-axis reduction plan.
2//!
3//! Output shape == input shape with the reduced axis collapsed to size
4//! 1 (keepdim convention).
5//!
6//! **Wired matrix**: `{Sum, Mean, Max, Min, Prod, Norm2, LogSumExp, Var,
7//! Std} × {f32, f16, bf16, f64}` — 36 (kind, dtype) cells. The simple-
8//! reduce kernel template is shared (one thread per output cell,
9//! sequential walk over the reduced axis); each (op, dtype) has its
10//! own functor + FFI symbol. LogSumExp ships a dedicated two-pass
11//! kernel (max, then sum-exp) for numerical stability. Var / Std ship
12//! a Welford one-pass kernel templated on T.
13//!
14//! **Bessel correction** ([`ReduceDescriptor::correction`]): only Var /
15//! Std consume this; `1` = sample (PyTorch default), `0` = population.
16//!
17//! **Workspace**: none — the per-output-cell kernel keeps the running
18//! accumulator in registers.
19//!
20//! **Precision**: deterministic, bit-stable on the same hardware (no
21//! atomic-add; sequential per-cell accumulation has a fixed order).
22//! f16 / bf16 accumulate in f32 (FP detour); f64 keeps everything in
23//! double.
24//!
25//! **Sibling plans**:
26//! - Argmax / Argmin → [`crate::ArgReducePlan`] (i64 output).
27//! - Any / All → [`crate::BoolReducePlan`].
28//! - CountNonzero → [`crate::CountReducePlan`].
29//! - Trace → [`crate::TracePlan`] (rank-2 only, scalar output).
30
31use core::ffi::c_void;
32use core::marker::PhantomData;
33
34use baracuda_cutlass::{Error, Result};
35use baracuda_driver::Stream;
36use baracuda_kernels_types::{
37    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
38    PlanPreference, PrecisionGuarantee, ReduceKind, TensorMut, TensorRef, Workspace,
39};
40
41/// Descriptor for a single-axis reduction.
42///
43/// `input_shape` is the shape of the input tensor. `reduce_axis` is
44/// the axis to reduce (0 ≤ `reduce_axis` < rank). Output shape is the
45/// input shape with `[reduce_axis]` collapsed to size 1 (keepdim
46/// convention — caller squeezes if they want).
47#[derive(Copy, Clone, Debug)]
48pub struct ReduceDescriptor<const N: usize> {
49    /// Which reduction to apply (Sum / Mean / Max / ...).
50    pub kind: ReduceKind,
51    /// Input tensor shape.
52    pub input_shape: [i32; N],
53    /// Axis to reduce along. Must satisfy `0 <= reduce_axis < N`.
54    pub reduce_axis: u8,
55    /// Element type.
56    pub element: ElementKind,
57    /// Bessel correction for `Var` / `Std` only. `1` = sample
58    /// variance (PyTorch default), `0` = population variance. Ignored
59    /// by other reductions.
60    pub correction: i32,
61}
62
63impl<const N: usize> ReduceDescriptor<N> {
64    /// Compute the output shape (input shape with reduce axis = 1).
65    pub fn output_shape(&self) -> [i32; N] {
66        let mut out = self.input_shape;
67        out[self.reduce_axis as usize] = 1;
68        out
69    }
70}
71
72/// Args bundle for a reduction launch.
73///
74/// `x.shape` must match `desc.input_shape`. `y.shape` must match the
75/// derived output shape. Output is conventionally contiguous; the
76/// kernel accepts arbitrary strides.
77pub struct ReduceArgs<'a, T: Element, const N: usize> {
78    /// Input tensor.
79    pub x: TensorRef<'a, T, N>,
80    /// Output tensor — shape == input with reduced axis collapsed to 1.
81    pub y: TensorMut<'a, T, N>,
82}
83
84/// Single-axis reduction plan — see module docs for the wired matrix,
85/// workspace, and precision guarantees.
86///
87/// `T: Element` is the element type (`f32` / `f64` / `f16` / `bf16`).
88/// `const N: usize` is the tensor rank.
89pub struct ReducePlan<T: Element, const N: usize> {
90    desc: ReduceDescriptor<N>,
91    sku: KernelSku,
92    _marker: PhantomData<T>,
93}
94
95impl<T: Element, const N: usize> ReducePlan<T, N> {
96    /// Pick a kernel for `desc`.
97    pub fn select(
98        _stream: &Stream,
99        desc: &ReduceDescriptor<N>,
100        _pref: PlanPreference,
101    ) -> Result<Self> {
102        if desc.element != T::KIND {
103            return Err(Error::Unsupported(
104                "baracuda-kernels::ReducePlan: descriptor element != type parameter T",
105            ));
106        }
107        if (desc.reduce_axis as usize) >= N {
108            return Err(Error::InvalidProblem(
109                "baracuda-kernels::ReducePlan: reduce_axis must be < rank",
110            ));
111        }
112        for &d in desc.input_shape.iter() {
113            if d < 0 {
114                return Err(Error::InvalidProblem(
115                    "baracuda-kernels::ReducePlan: input_shape dims must be non-negative",
116                ));
117            }
118        }
119
120        // Supported matrix:
121        //   {Sum, Mean, Max, Min, Prod, Norm2, LogSumExp, Var, Std}
122        //                            × {f32, f16, bf16, f64}   (36 cells)
123        // Argmax/Argmin live in `ArgReducePlan` (i64 output); trace
124        // lives in `TracePlan` (scalar output, both axes reduced). The
125        // remaining reserved discriminants (Any / All) land in later
126        // fanout.
127        let kind_in_scope = matches!(
128            desc.kind,
129            ReduceKind::Sum
130                | ReduceKind::Mean
131                | ReduceKind::Max
132                | ReduceKind::Min
133                | ReduceKind::Prod
134                | ReduceKind::Norm2
135                | ReduceKind::LogSumExp
136                | ReduceKind::Var
137                | ReduceKind::Std
138        );
139        let dtype_in_scope = matches!(
140            T::KIND,
141            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
142        );
143        let supported = kind_in_scope && dtype_in_scope;
144        if !supported {
145            return Err(Error::Unsupported(
146                "baracuda-kernels::ReducePlan: supported matrix is \
147                 {Sum, Mean, Max, Min, Prod, Norm2, LogSumExp, Var, Std} × \
148                 {f32, f16, bf16, f64}; other (kind, dtype) pairs land \
149                 in later fanout (Argmax/Argmin via ArgReducePlan; trace \
150                 via TracePlan)",
151            ));
152        }
153
154        // The naive trailblazer kernel sums in input-order (one thread
155        // per output cell, sequential over the reduced axis). Result
156        // is deterministic and bit-stable for f32 on the same hardware.
157        let precision_guarantee = PrecisionGuarantee {
158            math_precision: MathPrecision::F32,
159            accumulator: ElementKind::F32,
160            bit_stable_on_same_hardware: true,
161            deterministic: true,
162        };
163        let sku = KernelSku {
164            category: OpCategory::Reduction,
165            op: desc.kind as u16,
166            element: T::KIND,
167            aux_element: None,
168            layout: None,
169            epilogue: None,
170            arch: ArchSku::Sm80,
171            backend: BackendKind::Bespoke,
172            precision_guarantee,
173        };
174        Ok(Self {
175            desc: *desc,
176            sku,
177            _marker: PhantomData,
178        })
179    }
180
181    /// Validate args.
182    pub fn can_implement(&self, args: &ReduceArgs<'_, T, N>) -> Result<()> {
183        if args.x.shape != self.desc.input_shape {
184            return Err(Error::InvalidProblem(
185                "baracuda-kernels::ReducePlan: X shape mismatch with descriptor input_shape",
186            ));
187        }
188        let expected_out = self.desc.output_shape();
189        if args.y.shape != expected_out {
190            return Err(Error::InvalidProblem(
191                "baracuda-kernels::ReducePlan: Y shape mismatch with derived output shape \
192                 (input shape with reduce_axis collapsed to 1)",
193            ));
194        }
195        if N > 8 {
196            return Err(Error::Unsupported(
197                "baracuda-kernels::ReducePlan: tensor rank > 8 not supported",
198            ));
199        }
200        let y_numel = args.y.numel();
201        let x_numel = args.x.numel();
202        let x_len = args.x.data.len() as i64;
203        let y_len = args.y.data.len() as i64;
204        if y_len < y_numel {
205            return Err(Error::BufferTooSmall {
206                needed: y_numel as usize,
207                got: y_len as usize,
208            });
209        }
210        if x_len < x_numel {
211            return Err(Error::BufferTooSmall {
212                needed: x_numel as usize,
213                got: x_len as usize,
214            });
215        }
216        Ok(())
217    }
218
219    /// Workspace size in bytes. Always `0` for the naive trailblazer.
220    #[inline]
221    pub fn workspace_size(&self) -> usize {
222        0
223    }
224
225    /// Identity of the kernel this plan picked.
226    #[inline]
227    pub fn sku(&self) -> KernelSku {
228        self.sku
229    }
230
231    /// Numerical guarantees for this plan's kernel.
232    #[inline]
233    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
234        self.sku.precision_guarantee
235    }
236
237    /// Launch.
238    pub fn run(
239        &self,
240        stream: &Stream,
241        _workspace: Workspace<'_>,
242        args: ReduceArgs<'_, T, N>,
243    ) -> Result<()> {
244        self.can_implement(&args)?;
245        let output_numel = args.y.numel();
246        if output_numel == 0 {
247            return Ok(());
248        }
249        let x_ptr = args.x.data.as_raw().0 as *const c_void;
250        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
251        let stream_ptr = stream.as_raw() as *mut c_void;
252
253        let output_shape = self.desc.output_shape();
254        let stride_x = args.x.stride;
255        let stride_y = args.y.stride;
256        let rank = N as i32;
257        let reduce_axis = self.desc.reduce_axis as i32;
258        let reduce_extent = self.desc.input_shape[self.desc.reduce_axis as usize];
259        let reduce_stride_x = args.x.stride[self.desc.reduce_axis as usize];
260
261        // Helper: every reduce FFI symbol shares the same parameter
262        // shape (the kernel template is shared). The macro picks the
263        // right symbol from (kind, dtype).
264        macro_rules! dispatch {
265            ($sym:ident) => {{
266                unsafe {
267                    baracuda_kernels_sys::$sym(
268                        output_numel,
269                        rank,
270                        output_shape.as_ptr(),
271                        stride_x.as_ptr(),
272                        stride_y.as_ptr(),
273                        reduce_axis,
274                        reduce_extent,
275                        reduce_stride_x,
276                        x_ptr,
277                        y_ptr,
278                        core::ptr::null_mut(),
279                        0,
280                        stream_ptr,
281                    )
282                }
283            }};
284        }
285
286        let status = match (self.desc.kind, T::KIND) {
287            // Sum
288            (ReduceKind::Sum, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_sum_f32_run),
289            (ReduceKind::Sum, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_sum_f16_run),
290            (ReduceKind::Sum, ElementKind::Bf16) => dispatch!(baracuda_kernels_reduce_sum_bf16_run),
291            (ReduceKind::Sum, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_sum_f64_run),
292            // Mean
293            (ReduceKind::Mean, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_mean_f32_run),
294            (ReduceKind::Mean, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_mean_f16_run),
295            (ReduceKind::Mean, ElementKind::Bf16) => {
296                dispatch!(baracuda_kernels_reduce_mean_bf16_run)
297            }
298            (ReduceKind::Mean, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_mean_f64_run),
299            // Max
300            (ReduceKind::Max, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_max_f32_run),
301            (ReduceKind::Max, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_max_f16_run),
302            (ReduceKind::Max, ElementKind::Bf16) => dispatch!(baracuda_kernels_reduce_max_bf16_run),
303            (ReduceKind::Max, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_max_f64_run),
304            // Min
305            (ReduceKind::Min, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_min_f32_run),
306            (ReduceKind::Min, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_min_f16_run),
307            (ReduceKind::Min, ElementKind::Bf16) => dispatch!(baracuda_kernels_reduce_min_bf16_run),
308            (ReduceKind::Min, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_min_f64_run),
309            // Prod
310            (ReduceKind::Prod, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_prod_f32_run),
311            (ReduceKind::Prod, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_prod_f16_run),
312            (ReduceKind::Prod, ElementKind::Bf16) => {
313                dispatch!(baracuda_kernels_reduce_prod_bf16_run)
314            }
315            (ReduceKind::Prod, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_prod_f64_run),
316            // Norm2 = sqrt(sum(x^2)) — shares the simple-reduce
317            // parameter shape; finalize() does the sqrt.
318            (ReduceKind::Norm2, ElementKind::F32) => {
319                dispatch!(baracuda_kernels_reduce_norm2_f32_run)
320            }
321            (ReduceKind::Norm2, ElementKind::F16) => {
322                dispatch!(baracuda_kernels_reduce_norm2_f16_run)
323            }
324            (ReduceKind::Norm2, ElementKind::Bf16) => {
325                dispatch!(baracuda_kernels_reduce_norm2_bf16_run)
326            }
327            (ReduceKind::Norm2, ElementKind::F64) => {
328                dispatch!(baracuda_kernels_reduce_norm2_f64_run)
329            }
330            // LogSumExp — `y = log(sum(exp(x - max))) + max` via a
331            // two-pass kernel (max, then sum-exp). Same FFI parameter
332            // shape as the simple-reduce family.
333            (ReduceKind::LogSumExp, ElementKind::F32) => {
334                dispatch!(baracuda_kernels_reduce_logsumexp_f32_run)
335            }
336            (ReduceKind::LogSumExp, ElementKind::F16) => {
337                dispatch!(baracuda_kernels_reduce_logsumexp_f16_run)
338            }
339            (ReduceKind::LogSumExp, ElementKind::Bf16) => {
340                dispatch!(baracuda_kernels_reduce_logsumexp_bf16_run)
341            }
342            (ReduceKind::LogSumExp, ElementKind::F64) => {
343                dispatch!(baracuda_kernels_reduce_logsumexp_f64_run)
344            }
345            // Var / Std take an extra `correction` parameter and route
346            // through the Welford-family FFI symbols. Welford state runs
347            // at f32 for f32/f16/bf16 and f64 for f64 (handled by the
348            // `WelfordAcc<T>` trait inside the kernel template).
349            (ReduceKind::Var, ElementKind::F32) => unsafe {
350                baracuda_kernels_sys::baracuda_kernels_reduce_var_f32_run(
351                    output_numel, rank, output_shape.as_ptr(),
352                    stride_x.as_ptr(), stride_y.as_ptr(),
353                    reduce_axis, reduce_extent, reduce_stride_x,
354                    self.desc.correction,
355                    x_ptr, y_ptr,
356                    core::ptr::null_mut(), 0, stream_ptr,
357                )
358            },
359            (ReduceKind::Var, ElementKind::F16) => unsafe {
360                baracuda_kernels_sys::baracuda_kernels_reduce_var_f16_run(
361                    output_numel, rank, output_shape.as_ptr(),
362                    stride_x.as_ptr(), stride_y.as_ptr(),
363                    reduce_axis, reduce_extent, reduce_stride_x,
364                    self.desc.correction,
365                    x_ptr, y_ptr,
366                    core::ptr::null_mut(), 0, stream_ptr,
367                )
368            },
369            (ReduceKind::Var, ElementKind::Bf16) => unsafe {
370                baracuda_kernels_sys::baracuda_kernels_reduce_var_bf16_run(
371                    output_numel, rank, output_shape.as_ptr(),
372                    stride_x.as_ptr(), stride_y.as_ptr(),
373                    reduce_axis, reduce_extent, reduce_stride_x,
374                    self.desc.correction,
375                    x_ptr, y_ptr,
376                    core::ptr::null_mut(), 0, stream_ptr,
377                )
378            },
379            (ReduceKind::Var, ElementKind::F64) => unsafe {
380                baracuda_kernels_sys::baracuda_kernels_reduce_var_f64_run(
381                    output_numel, rank, output_shape.as_ptr(),
382                    stride_x.as_ptr(), stride_y.as_ptr(),
383                    reduce_axis, reduce_extent, reduce_stride_x,
384                    self.desc.correction,
385                    x_ptr, y_ptr,
386                    core::ptr::null_mut(), 0, stream_ptr,
387                )
388            },
389            (ReduceKind::Std, ElementKind::F32) => unsafe {
390                baracuda_kernels_sys::baracuda_kernels_reduce_std_f32_run(
391                    output_numel, rank, output_shape.as_ptr(),
392                    stride_x.as_ptr(), stride_y.as_ptr(),
393                    reduce_axis, reduce_extent, reduce_stride_x,
394                    self.desc.correction,
395                    x_ptr, y_ptr,
396                    core::ptr::null_mut(), 0, stream_ptr,
397                )
398            },
399            (ReduceKind::Std, ElementKind::F16) => unsafe {
400                baracuda_kernels_sys::baracuda_kernels_reduce_std_f16_run(
401                    output_numel, rank, output_shape.as_ptr(),
402                    stride_x.as_ptr(), stride_y.as_ptr(),
403                    reduce_axis, reduce_extent, reduce_stride_x,
404                    self.desc.correction,
405                    x_ptr, y_ptr,
406                    core::ptr::null_mut(), 0, stream_ptr,
407                )
408            },
409            (ReduceKind::Std, ElementKind::Bf16) => unsafe {
410                baracuda_kernels_sys::baracuda_kernels_reduce_std_bf16_run(
411                    output_numel, rank, output_shape.as_ptr(),
412                    stride_x.as_ptr(), stride_y.as_ptr(),
413                    reduce_axis, reduce_extent, reduce_stride_x,
414                    self.desc.correction,
415                    x_ptr, y_ptr,
416                    core::ptr::null_mut(), 0, stream_ptr,
417                )
418            },
419            (ReduceKind::Std, ElementKind::F64) => unsafe {
420                baracuda_kernels_sys::baracuda_kernels_reduce_std_f64_run(
421                    output_numel, rank, output_shape.as_ptr(),
422                    stride_x.as_ptr(), stride_y.as_ptr(),
423                    reduce_axis, reduce_extent, reduce_stride_x,
424                    self.desc.correction,
425                    x_ptr, y_ptr,
426                    core::ptr::null_mut(), 0, stream_ptr,
427                )
428            },
429            _ => {
430                return Err(Error::Unsupported(
431                    "baracuda-kernels::ReducePlan::run: this (kind, dtype) cell is not yet wired",
432                ));
433            }
434        };
435        map_status(status)
436    }
437}
438
439fn map_status(code: i32) -> Result<()> {
440    match code {
441        0 => Ok(()),
442        1 => Err(Error::MisalignedOperand),
443        2 => Err(Error::InvalidProblem(
444            "baracuda-kernels-sys reported invalid problem",
445        )),
446        3 => Err(Error::Unsupported(
447            "baracuda-kernels-sys reported unsupported configuration",
448        )),
449        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
450        n => Err(Error::CutlassInternal(n)),
451    }
452}