Skip to main content

baracuda_kernels/reduce/
count_axis.rs

//! Single-axis `count_nonzero` reduction with a heterogeneous output dtype:
//! the input is any `T: Element`, while the output is always `i64` (PyTorch
//! convention — `torch.count_nonzero` returns int64).
//!
//! Sibling of [`crate::ReducePlan`] (same dtype in and out),
//! [`crate::ArgReducePlan`] (i64 index output), and
//! [`crate::BoolReducePlan`] (Bool output for Any / All). It gets its own
//! plan shape because of the heterogeneous output. PyTorch parity:
//! `torch.count_nonzero(x, dim=k)`, except that the output here keeps the
//! reduced axis with extent 1 (keepdim-style shape).
//!
//! Wired matrix: `CountNonzero × {f32, f16, bf16, f64, i32, i64, Bool}`
//! — 7 SKUs. NaN is counted as non-zero (`NaN != 0` is true, matching
//! PyTorch / IEEE 754). The op is non-differentiable, so no backward plan
//! is provided.
14
15use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
22    PlanPreference, PrecisionGuarantee, ReduceKind, TensorMut, TensorRef, Workspace,
23};
24
25/// Descriptor for a count-nonzero axis reduction.
26#[derive(Copy, Clone, Debug)]
27pub struct CountReduceDescriptor<const N: usize> {
28    /// Must be [`ReduceKind::CountNonzero`].
29    pub kind: ReduceKind,
30    /// Input tensor shape.
31    pub input_shape: [i32; N],
32    /// Axis to reduce along.
33    pub reduce_axis: u8,
34    /// Input element type.
35    pub element: ElementKind,
36}
37
38impl<const N: usize> CountReduceDescriptor<N> {
39    /// Output shape: input shape with reduce axis = 1 (keepdim).
40    pub fn output_shape(&self) -> [i32; N] {
41        let mut out = self.input_shape;
42        out[self.reduce_axis as usize] = 1;
43        out
44    }
45}
46
47/// Args bundle. Output is always `i64` regardless of input dtype.
48pub struct CountReduceArgs<'a, T: Element, const N: usize> {
49    /// Input tensor.
50    pub x: TensorRef<'a, T, N>,
51    /// Output tensor — always i64.
52    pub y: TensorMut<'a, i64, N>,
53}
54
55/// Plan for a count-nonzero axis reduction.
56pub struct CountReducePlan<T: Element, const N: usize> {
57    desc: CountReduceDescriptor<N>,
58    sku: KernelSku,
59    _marker: PhantomData<T>,
60}
61
62impl<T: Element, const N: usize> CountReducePlan<T, N> {
63    /// Pick a kernel for `desc`.
64    pub fn select(
65        _stream: &Stream,
66        desc: &CountReduceDescriptor<N>,
67        _pref: PlanPreference,
68    ) -> Result<Self> {
69        if desc.element != T::KIND {
70            return Err(Error::Unsupported(
71                "baracuda-kernels::CountReducePlan: descriptor element != type parameter T",
72            ));
73        }
74        if (desc.reduce_axis as usize) >= N {
75            return Err(Error::InvalidProblem(
76                "baracuda-kernels::CountReducePlan: reduce_axis must be < rank",
77            ));
78        }
79        for &d in desc.input_shape.iter() {
80            if d < 0 {
81                return Err(Error::InvalidProblem(
82                    "baracuda-kernels::CountReducePlan: input_shape dims must be non-negative",
83                ));
84            }
85        }
86        if !matches!(desc.kind, ReduceKind::CountNonzero) {
87            return Err(Error::Unsupported(
88                "baracuda-kernels::CountReducePlan: kind must be CountNonzero",
89            ));
90        }
91        let dtype_in_scope = matches!(
92            T::KIND,
93            ElementKind::F32
94                | ElementKind::F16
95                | ElementKind::Bf16
96                | ElementKind::F64
97                | ElementKind::I32
98                | ElementKind::I64
99                | ElementKind::Bool
100        );
101        if !dtype_in_scope {
102            return Err(Error::Unsupported(
103                "baracuda-kernels::CountReducePlan: supported input dtypes are \
104                 {f32, f16, bf16, f64, i32, i64, Bool}",
105            ));
106        }
107        // Pure integer accumulation in i64 — bit-stable on the same
108        // hardware (no FP math).
109        let precision_guarantee = PrecisionGuarantee {
110            math_precision: MathPrecision::F32,
111            accumulator: ElementKind::I64,
112            bit_stable_on_same_hardware: true,
113            deterministic: true,
114        };
115        let sku = KernelSku {
116            category: OpCategory::Reduction,
117            op: desc.kind as u16,
118            element: T::KIND,
119            // Output dtype is i64; ElementKind has an I64 variant.
120            aux_element: Some(ElementKind::I64),
121            layout: None,
122            epilogue: None,
123            arch: ArchSku::Sm80,
124            backend: BackendKind::Bespoke,
125            precision_guarantee,
126        };
127        Ok(Self {
128            desc: *desc,
129            sku,
130            _marker: PhantomData,
131        })
132    }
133
134    /// Validate args.
135    pub fn can_implement(&self, args: &CountReduceArgs<'_, T, N>) -> Result<()> {
136        if args.x.shape != self.desc.input_shape {
137            return Err(Error::InvalidProblem(
138                "baracuda-kernels::CountReducePlan: X shape mismatch with descriptor",
139            ));
140        }
141        let expected_out = self.desc.output_shape();
142        if args.y.shape != expected_out {
143            return Err(Error::InvalidProblem(
144                "baracuda-kernels::CountReducePlan: Y shape mismatch with derived output \
145                 shape (input shape with reduce_axis collapsed to 1)",
146            ));
147        }
148        if N > 8 {
149            return Err(Error::Unsupported(
150                "baracuda-kernels::CountReducePlan: tensor rank > 8 not supported",
151            ));
152        }
153        let y_numel = args.y.numel();
154        let x_numel = args.x.numel();
155        let x_len = args.x.data.len() as i64;
156        let y_len = args.y.data.len() as i64;
157        if y_len < y_numel {
158            return Err(Error::BufferTooSmall {
159                needed: y_numel as usize,
160                got: y_len as usize,
161            });
162        }
163        if x_len < x_numel {
164            return Err(Error::BufferTooSmall {
165                needed: x_numel as usize,
166                got: x_len as usize,
167            });
168        }
169        Ok(())
170    }
171
172    /// Workspace size in bytes. Always 0 for the naive trailblazer.
173    #[inline]
174    pub fn workspace_size(&self) -> usize {
175        0
176    }
177    /// Identity of the kernel this plan picked.
178    #[inline]
179    pub fn sku(&self) -> KernelSku {
180        self.sku
181    }
182    /// Numerical guarantees.
183    #[inline]
184    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
185        self.sku.precision_guarantee
186    }
187
188    /// Launch.
189    pub fn run(
190        &self,
191        stream: &Stream,
192        _workspace: Workspace<'_>,
193        args: CountReduceArgs<'_, T, N>,
194    ) -> Result<()> {
195        self.can_implement(&args)?;
196        let output_numel = args.y.numel();
197        if output_numel == 0 {
198            return Ok(());
199        }
200        let x_ptr = args.x.data.as_raw().0 as *const c_void;
201        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
202        let stream_ptr = stream.as_raw() as *mut c_void;
203
204        let output_shape = self.desc.output_shape();
205        let stride_x = args.x.stride;
206        let stride_y = args.y.stride;
207        let rank = N as i32;
208        let reduce_axis = self.desc.reduce_axis as i32;
209        let reduce_extent = self.desc.input_shape[self.desc.reduce_axis as usize];
210        let reduce_stride_x = args.x.stride[self.desc.reduce_axis as usize];
211
212        macro_rules! dispatch {
213            ($sym:ident) => {{
214                unsafe {
215                    baracuda_kernels_sys::$sym(
216                        output_numel,
217                        rank,
218                        output_shape.as_ptr(),
219                        stride_x.as_ptr(),
220                        stride_y.as_ptr(),
221                        reduce_axis,
222                        reduce_extent,
223                        reduce_stride_x,
224                        x_ptr,
225                        y_ptr,
226                        core::ptr::null_mut(),
227                        0,
228                        stream_ptr,
229                    )
230                }
231            }};
232        }
233
234        let status = match (self.desc.kind, T::KIND) {
235            (ReduceKind::CountNonzero, ElementKind::F32) => {
236                dispatch!(baracuda_kernels_reduce_count_nonzero_f32_run)
237            }
238            (ReduceKind::CountNonzero, ElementKind::F16) => {
239                dispatch!(baracuda_kernels_reduce_count_nonzero_f16_run)
240            }
241            (ReduceKind::CountNonzero, ElementKind::Bf16) => {
242                dispatch!(baracuda_kernels_reduce_count_nonzero_bf16_run)
243            }
244            (ReduceKind::CountNonzero, ElementKind::F64) => {
245                dispatch!(baracuda_kernels_reduce_count_nonzero_f64_run)
246            }
247            (ReduceKind::CountNonzero, ElementKind::I32) => {
248                dispatch!(baracuda_kernels_reduce_count_nonzero_i32_run)
249            }
250            (ReduceKind::CountNonzero, ElementKind::I64) => {
251                dispatch!(baracuda_kernels_reduce_count_nonzero_i64_run)
252            }
253            (ReduceKind::CountNonzero, ElementKind::Bool) => {
254                dispatch!(baracuda_kernels_reduce_count_nonzero_bool_run)
255            }
256            _ => {
257                return Err(Error::Unsupported(
258                    "baracuda-kernels::CountReducePlan::run: only `CountNonzero × \
259                     {f32, f16, bf16, f64, i32, i64, Bool}` wired",
260                ));
261            }
262        };
263        map_status(status)
264    }
265}
266
267fn map_status(code: i32) -> Result<()> {
268    match code {
269        0 => Ok(()),
270        1 => Err(Error::MisalignedOperand),
271        2 => Err(Error::InvalidProblem(
272            "baracuda-kernels-sys reported invalid problem",
273        )),
274        3 => Err(Error::Unsupported(
275            "baracuda-kernels-sys reported unsupported configuration",
276        )),
277        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
278        n => Err(Error::CutlassInternal(n)),
279    }
280}