Skip to main content

baracuda_kernels/reduce/
reduce_to.rs

//! Broadcast-reverse reduction plan.
//!
//! The autograd primitive that undoes a forward `BroadcastTo`: for
//! each output cell, reduce every input cell that broadcasts TO it.
//! The reduced dims are exactly those where `output_shape[d] == 1`
//! while `input_shape[d] != 1` — an arbitrary *set* of axes collapses
//! in a single launch (contrast [`crate::ReducePlan`], which reduces
//! one `reduce_axis` per launch). The output keeps the input's rank,
//! with the reduced dims at size 1 (keepdim convention).
//!
//! **Wired matrix**: `{Sum, Max, Min, Prod} × {f32, f16, bf16, f64}`
//! — 16 (op, dtype) cells over the Phase 31 / Phase 37
//! `baracuda_kernels_reduce_{sum,max,min,prod}_to_*` FFI symbols.
//! Phase 74 closes the facade gap: the symbols had shipped without a
//! plan-level entry, which hid the capability from plan-surface
//! audits. The kernel template is shared (one thread per output
//! cell, sequential walk over the broadcast set); the per-op policy
//! supplies the identity and the combine step.
//!
//! **Empty reduce sets** (any reduced `input_shape[d] == 0`): the
//! kernel writes the op's identity — `0` for Sum, `1` for Prod,
//! `-FLT_MAX` / `-DBL_MAX` for Max, `+FLT_MAX` / `+DBL_MAX` for Min.
//! For f32 / f64 outputs that is the most extreme *finite* value.
//! For f16 / bf16, the f32 identity overflows the storage dtype on
//! the final narrowing store and lands as `-inf` (Max) or `+inf`
//! (Min). See [`ReduceToOp`].
//!
//! **Layout**: the input may be a strided view (transposed / sliced
//! views are common in autograd traces); its strides pass through to
//! the kernel, but a negative stride on a dim of extent > 1 is
//! rejected. The output MUST be contiguous over `output_shape` (the
//! kernel writes `dst[out_id]` by linear index; validated in
//! `can_implement`).
//!
//! **Workspace**: none — the per-output-cell kernel keeps the running
//! accumulator in registers.
//!
//! **Precision**: deterministic and bit-stable on the same hardware
//! (no atomic-add; sequential per-cell accumulation has a fixed order).
//! f16 / bf16 accumulate in f32 (Sum / Prod widen per the PyTorch
//! convention; Max / Min compare in f32, which is value-preserving);
//! f64 keeps everything in double.
41
42use core::ffi::c_void;
43use core::marker::PhantomData;
44
45use baracuda_cutlass::{Error, Result};
46use baracuda_driver::Stream;
47use baracuda_kernels_types::{
48    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
49    PlanPreference, PrecisionGuarantee, ReduceToOp, TensorMut, TensorRef, Workspace,
50};
51
52/// Descriptor for a broadcast-reverse reduction.
53///
54/// `input_shape` is the source extents; `output_shape` is the target
55/// extents — same rank, with every reduced dim collapsed to size 1
56/// (the caller left-pads `output_shape` with 1s if the forward
57/// broadcast added leading dims). Per-dim constraint:
58/// `output_shape[d] == 1 || output_shape[d] == input_shape[d]`.
59#[derive(Copy, Clone, Debug)]
60pub struct ReduceToDescriptor<const N: usize> {
61    /// Which reduction to apply over each output cell's broadcast set.
62    pub op: ReduceToOp,
63    /// Input tensor shape.
64    pub input_shape: [i32; N],
65    /// Output tensor shape — `input_shape` with the reduced dims
66    /// collapsed to 1.
67    pub output_shape: [i32; N],
68    /// Element type.
69    pub element: ElementKind,
70}
71
72/// Args bundle for a broadcast-reverse reduction launch.
73///
74/// `x.shape` must match `desc.input_shape`; arbitrary (non-contiguous)
75/// strides are fine — they pass through to the kernel. `y.shape` must
76/// match `desc.output_shape` and `y` MUST be contiguous.
77pub struct ReduceToArgs<'a, T: Element, const N: usize> {
78    /// Input tensor — may be a strided (transposed / sliced) view.
79    pub x: TensorRef<'a, T, N>,
80    /// Output tensor — contiguous over `desc.output_shape`.
81    pub y: TensorMut<'a, T, N>,
82}
83
84/// Broadcast-reverse reduction plan — see module docs for the wired
85/// matrix, workspace, and precision guarantees.
86///
87/// `T: Element` is the element type (`f32` / `f64` / `f16` / `bf16`).
88/// `const N: usize` is the tensor rank (input and output share it).
89pub struct ReduceToPlan<T: Element, const N: usize> {
90    desc: ReduceToDescriptor<N>,
91    sku: KernelSku,
92    _marker: PhantomData<T>,
93}
94
95impl<T: Element, const N: usize> ReduceToPlan<T, N> {
96    /// Pick a kernel for `desc`.
97    pub fn select(
98        _stream: &Stream,
99        desc: &ReduceToDescriptor<N>,
100        _pref: PlanPreference,
101    ) -> Result<Self> {
102        if desc.element != T::KIND {
103            return Err(Error::Unsupported(
104                "baracuda-kernels::ReduceToPlan: descriptor element != type parameter T",
105            ));
106        }
107        if N > 8 {
108            return Err(Error::Unsupported(
109                "baracuda-kernels::ReduceToPlan: tensor rank > 8 not supported \
110                 (kernel param block fixes MAX_RANK = 8)",
111            ));
112        }
113        for d in 0..N {
114            if desc.input_shape[d] < 0 || desc.output_shape[d] < 0 {
115                return Err(Error::InvalidProblem(
116                    "baracuda-kernels::ReduceToPlan: shape dims must be non-negative",
117                ));
118            }
119            // Broadcast-reverse contract: every output dim is either
120            // kept (== input dim) or reduced (== 1).
121            if desc.output_shape[d] != 1 && desc.output_shape[d] != desc.input_shape[d] {
122                return Err(Error::InvalidProblem(
123                    "baracuda-kernels::ReduceToPlan: per-dim contract violated — \
124                     output_shape[d] must be 1 (reduced) or equal input_shape[d] (kept)",
125                ));
126            }
127        }
128
129        // Supported matrix:
130        //   {Sum, Max, Min, Prod} × {f32, f16, bf16, f64}   (16 cells)
131        // The match arms in `run` remain the authoritative dispatch
132        // table; the unreachable `_ =>` arm catches any future drift.
133        let op_in_scope = matches!(
134            desc.op,
135            ReduceToOp::Sum | ReduceToOp::Max | ReduceToOp::Min | ReduceToOp::Prod
136        );
137        let dtype_in_scope = matches!(
138            T::KIND,
139            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
140        );
141        let supported = op_in_scope && dtype_in_scope;
142        if !supported {
143            return Err(Error::Unsupported(
144                "baracuda-kernels::ReduceToPlan: supported matrix is \
145                 {Sum, Max, Min, Prod} × {f32, f16, bf16, f64}",
146            ));
147        }
148
149        // One thread per output cell, sequential walk over the
150        // broadcast set in a fixed order — deterministic and
151        // bit-stable on the same hardware. f32 / f16 / bf16 accumulate
152        // in f32; f64 keeps everything in double (see module docs).
153        let (math_precision, accumulator) = if T::KIND == ElementKind::F64 {
154            (MathPrecision::F64, ElementKind::F64)
155        } else {
156            (MathPrecision::F32, ElementKind::F32)
157        };
158        let precision_guarantee = PrecisionGuarantee {
159            math_precision,
160            accumulator,
161            bit_stable_on_same_hardware: true,
162            deterministic: true,
163        };
164        let sku = KernelSku {
165            category: OpCategory::Reduction,
166            op: desc.op as u16,
167            element: T::KIND,
168            aux_element: None,
169            layout: None,
170            epilogue: None,
171            arch: ArchSku::Sm80,
172            backend: BackendKind::Bespoke,
173            precision_guarantee,
174        };
175        Ok(Self {
176            desc: *desc,
177            sku,
178            _marker: PhantomData,
179        })
180    }
181
182    /// Validate args.
183    pub fn can_implement(&self, args: &ReduceToArgs<'_, T, N>) -> Result<()> {
184        if args.x.shape != self.desc.input_shape {
185            return Err(Error::InvalidProblem(
186                "baracuda-kernels::ReduceToPlan: X shape mismatch with descriptor input_shape",
187            ));
188        }
189        if args.y.shape != self.desc.output_shape {
190            return Err(Error::InvalidProblem(
191                "baracuda-kernels::ReduceToPlan: Y shape mismatch with descriptor output_shape",
192            ));
193        }
194        // The kernel writes `dst[out_id]` by linear index — the output
195        // must be a plain contiguous allocation over output_shape. The
196        // input may be arbitrarily strided.
197        if !args.y.is_contiguous() {
198            return Err(Error::InvalidProblem(
199                "baracuda-kernels::ReduceToPlan: Y must be contiguous over output_shape \
200                 (the kernel writes by linear output index)",
201            ));
202        }
203        let y_numel = args.y.numel();
204        let y_len = args.y.data.len() as i64;
205        if y_len < y_numel {
206            return Err(Error::BufferTooSmall {
207                needed: y_numel as usize,
208                got: y_len as usize,
209            });
210        }
211        // Input bound: `x` may be an arbitrary strided view — including
212        // stride-0 broadcast dims, where `numel` legitimately exceeds
213        // the distinct storage extent — so the right bound is the
214        // reachable SPAN `1 + Σ_d (shape[d]-1)·stride[d]`, not `numel`.
215        // Negative strides can never index in-bounds (TensorRef has no
216        // base offset; the data pointer IS the slice start).
217        if args.x.numel() > 0 {
218            let mut span: i64 = 1;
219            for d in 0..N {
220                let extent = self.desc.input_shape[d] as i64;
221                if extent > 1 {
222                    let stride = args.x.stride[d];
223                    if stride < 0 {
224                        return Err(Error::InvalidProblem(
225                            "baracuda-kernels::ReduceToPlan: negative input strides walk \
226                             before the buffer base (TensorRef has no base offset)",
227                        ));
228                    }
229                    span += (extent - 1) * stride;
230                }
231            }
232            let x_len = args.x.data.len() as i64;
233            if x_len < span {
234                return Err(Error::BufferTooSmall {
235                    needed: span as usize,
236                    got: x_len as usize,
237                });
238            }
239        }
240        Ok(())
241    }
242
243    /// Workspace size in bytes. Always `0` — the per-output-cell
244    /// kernel keeps its accumulator in registers.
245    #[inline]
246    pub fn workspace_size(&self) -> usize {
247        0
248    }
249
250    /// Identity of the kernel this plan picked.
251    #[inline]
252    pub fn sku(&self) -> KernelSku {
253        self.sku
254    }
255
256    /// Numerical guarantees for this plan's kernel.
257    #[inline]
258    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
259        self.sku.precision_guarantee
260    }
261
262    /// Launch.
263    pub fn run(
264        &self,
265        stream: &Stream,
266        _workspace: Workspace<'_>,
267        args: ReduceToArgs<'_, T, N>,
268    ) -> Result<()> {
269        self.can_implement(&args)?;
270        if args.y.numel() == 0 {
271            return Ok(());
272        }
273        let x_ptr = args.x.data.as_raw().0 as *const c_void;
274        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
275        let stream_ptr = stream.as_raw() as *mut c_void;
276
277        let input_shape = self.desc.input_shape;
278        let input_stride = args.x.stride;
279        let output_shape = self.desc.output_shape;
280        let rank = N as i32;
281
282        // Helper: every reduce-to FFI symbol shares the same parameter
283        // shape (the kernel template is shared). The macro picks the
284        // right symbol from (op, dtype).
285        macro_rules! dispatch {
286            ($sym:ident) => {{
287                unsafe {
288                    baracuda_kernels_sys::$sym(
289                        x_ptr,
290                        y_ptr,
291                        input_shape.as_ptr(),
292                        input_stride.as_ptr(),
293                        rank,
294                        output_shape.as_ptr(),
295                        core::ptr::null_mut(),
296                        0,
297                        stream_ptr,
298                    )
299                }
300            }};
301        }
302
303        let status = match (self.desc.op, T::KIND) {
304            // Sum
305            (ReduceToOp::Sum, ElementKind::F32) => {
306                dispatch!(baracuda_kernels_reduce_sum_to_f32_run)
307            }
308            (ReduceToOp::Sum, ElementKind::F16) => {
309                dispatch!(baracuda_kernels_reduce_sum_to_f16_run)
310            }
311            (ReduceToOp::Sum, ElementKind::Bf16) => {
312                dispatch!(baracuda_kernels_reduce_sum_to_bf16_run)
313            }
314            (ReduceToOp::Sum, ElementKind::F64) => {
315                dispatch!(baracuda_kernels_reduce_sum_to_f64_run)
316            }
317            // Max
318            (ReduceToOp::Max, ElementKind::F32) => {
319                dispatch!(baracuda_kernels_reduce_max_to_f32_run)
320            }
321            (ReduceToOp::Max, ElementKind::F16) => {
322                dispatch!(baracuda_kernels_reduce_max_to_f16_run)
323            }
324            (ReduceToOp::Max, ElementKind::Bf16) => {
325                dispatch!(baracuda_kernels_reduce_max_to_bf16_run)
326            }
327            (ReduceToOp::Max, ElementKind::F64) => {
328                dispatch!(baracuda_kernels_reduce_max_to_f64_run)
329            }
330            // Min
331            (ReduceToOp::Min, ElementKind::F32) => {
332                dispatch!(baracuda_kernels_reduce_min_to_f32_run)
333            }
334            (ReduceToOp::Min, ElementKind::F16) => {
335                dispatch!(baracuda_kernels_reduce_min_to_f16_run)
336            }
337            (ReduceToOp::Min, ElementKind::Bf16) => {
338                dispatch!(baracuda_kernels_reduce_min_to_bf16_run)
339            }
340            (ReduceToOp::Min, ElementKind::F64) => {
341                dispatch!(baracuda_kernels_reduce_min_to_f64_run)
342            }
343            // Prod
344            (ReduceToOp::Prod, ElementKind::F32) => {
345                dispatch!(baracuda_kernels_reduce_prod_to_f32_run)
346            }
347            (ReduceToOp::Prod, ElementKind::F16) => {
348                dispatch!(baracuda_kernels_reduce_prod_to_f16_run)
349            }
350            (ReduceToOp::Prod, ElementKind::Bf16) => {
351                dispatch!(baracuda_kernels_reduce_prod_to_bf16_run)
352            }
353            (ReduceToOp::Prod, ElementKind::F64) => {
354                dispatch!(baracuda_kernels_reduce_prod_to_f64_run)
355            }
356            _ => {
357                return Err(Error::Unsupported(
358                    "baracuda-kernels::ReduceToPlan::run reached an unimplemented \
359                     (op, dtype) pair — select() should have caught this",
360                ));
361            }
362        };
363        map_status(status)
364    }
365}
366
367fn map_status(code: i32) -> Result<()> {
368    match code {
369        0 => Ok(()),
370        1 => Err(Error::MisalignedOperand),
371        2 => Err(Error::InvalidProblem(
372            "baracuda-kernels-sys reported invalid problem",
373        )),
374        3 => Err(Error::Unsupported(
375            "baracuda-kernels-sys reported unsupported configuration",
376        )),
377        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
378        n => Err(Error::CutlassInternal(n)),
379    }
380}