baracuda_kernels/reduce/reduce_to.rs
1//! Broadcast-reverse reduction plan.
2//!
3//! The autograd primitive that undoes a forward `BroadcastTo`: for
4//! each output cell, reduce every input cell that broadcasts TO it.
5//! The reduced dims are every dim where `output_shape[d] == 1` while
6//! `input_shape[d] != 1` — an arbitrary *set* of axes collapses in a
7//! single launch (contrast [`crate::ReducePlan`], which reduces one
8//! `reduce_axis` per launch). Output keeps the input's rank with the
9//! reduced dims at size 1 (keepdim convention).
10//!
11//! **Wired matrix**: `{Sum, Max, Min, Prod} × {f32, f16, bf16, f64}`
12//! — 16 (op, dtype) cells over the Phase 31 / Phase 37
13//! `baracuda_kernels_reduce_{sum,max,min,prod}_to_*` FFI symbols
14//! (Phase 74 closes the facade gap — the symbols shipped without a
15//! plan-level entry, which hid the capability from plan-surface
16//! audits). The kernel template is shared (one thread per output
17//! cell, sequential walk over the broadcast set); the per-op policy
18//! supplies the identity + combine step.
19//!
20//! **Empty reduce sets** (any reduced `input_shape[d] == 0`): the
21//! kernel writes the op's identity — `0` for Sum, `1` for Prod,
22//! `-FLT_MAX` / `-DBL_MAX` for Max, `+FLT_MAX` / `+DBL_MAX` for Min.
23//! For f32 / f64 outputs that is the most-extreme *finite* value;
24//! for f16 / bf16 the f32 identity overflows the storage dtype on
25//! the final narrowing store and lands as `∓inf`. See [`ReduceToOp`].
26//!
27//! **Layout**: the input may be arbitrarily strided (transposed /
28//! sliced views — common in autograd traces); its strides pass
29//! through to the kernel. The output MUST be contiguous over
30//! `output_shape` (the kernel writes `dst[out_id]` by linear index;
31//! validated in `can_implement`).
32//!
33//! **Workspace**: none — the per-output-cell kernel keeps the running
34//! accumulator in registers.
35//!
36//! **Precision**: deterministic, bit-stable on the same hardware (no
37//! atomic-add; sequential per-cell accumulation has a fixed order).
38//! f16 / bf16 accumulate in f32 (Sum / Prod widen per the PyTorch
39//! convention; Max / Min compare in f32, which is value-preserving);
40//! f64 keeps everything in double.
41
42use core::ffi::c_void;
43use core::marker::PhantomData;
44
45use baracuda_cutlass::{Error, Result};
46use baracuda_driver::Stream;
47use baracuda_kernels_types::{
48 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
49 PlanPreference, PrecisionGuarantee, ReduceToOp, TensorMut, TensorRef, Workspace,
50};
51
52/// Descriptor for a broadcast-reverse reduction.
53///
54/// `input_shape` is the source extents; `output_shape` is the target
55/// extents — same rank, with every reduced dim collapsed to size 1
56/// (the caller left-pads `output_shape` with 1s if the forward
57/// broadcast added leading dims). Per-dim constraint:
58/// `output_shape[d] == 1 || output_shape[d] == input_shape[d]`.
59#[derive(Copy, Clone, Debug)]
60pub struct ReduceToDescriptor<const N: usize> {
61 /// Which reduction to apply over each output cell's broadcast set.
62 pub op: ReduceToOp,
63 /// Input tensor shape.
64 pub input_shape: [i32; N],
65 /// Output tensor shape — `input_shape` with the reduced dims
66 /// collapsed to 1.
67 pub output_shape: [i32; N],
68 /// Element type.
69 pub element: ElementKind,
70}
71
72/// Args bundle for a broadcast-reverse reduction launch.
73///
74/// `x.shape` must match `desc.input_shape`; arbitrary (non-contiguous)
75/// strides are fine — they pass through to the kernel. `y.shape` must
76/// match `desc.output_shape` and `y` MUST be contiguous.
77pub struct ReduceToArgs<'a, T: Element, const N: usize> {
78 /// Input tensor — may be a strided (transposed / sliced) view.
79 pub x: TensorRef<'a, T, N>,
80 /// Output tensor — contiguous over `desc.output_shape`.
81 pub y: TensorMut<'a, T, N>,
82}
83
84/// Broadcast-reverse reduction plan — see module docs for the wired
85/// matrix, workspace, and precision guarantees.
86///
87/// `T: Element` is the element type (`f32` / `f64` / `f16` / `bf16`).
88/// `const N: usize` is the tensor rank (input and output share it).
89pub struct ReduceToPlan<T: Element, const N: usize> {
90 desc: ReduceToDescriptor<N>,
91 sku: KernelSku,
92 _marker: PhantomData<T>,
93}
94
95impl<T: Element, const N: usize> ReduceToPlan<T, N> {
96 /// Pick a kernel for `desc`.
97 pub fn select(
98 _stream: &Stream,
99 desc: &ReduceToDescriptor<N>,
100 _pref: PlanPreference,
101 ) -> Result<Self> {
102 if desc.element != T::KIND {
103 return Err(Error::Unsupported(
104 "baracuda-kernels::ReduceToPlan: descriptor element != type parameter T",
105 ));
106 }
107 if N > 8 {
108 return Err(Error::Unsupported(
109 "baracuda-kernels::ReduceToPlan: tensor rank > 8 not supported \
110 (kernel param block fixes MAX_RANK = 8)",
111 ));
112 }
113 for d in 0..N {
114 if desc.input_shape[d] < 0 || desc.output_shape[d] < 0 {
115 return Err(Error::InvalidProblem(
116 "baracuda-kernels::ReduceToPlan: shape dims must be non-negative",
117 ));
118 }
119 // Broadcast-reverse contract: every output dim is either
120 // kept (== input dim) or reduced (== 1).
121 if desc.output_shape[d] != 1 && desc.output_shape[d] != desc.input_shape[d] {
122 return Err(Error::InvalidProblem(
123 "baracuda-kernels::ReduceToPlan: per-dim contract violated — \
124 output_shape[d] must be 1 (reduced) or equal input_shape[d] (kept)",
125 ));
126 }
127 }
128
129 // Supported matrix:
130 // {Sum, Max, Min, Prod} × {f32, f16, bf16, f64} (16 cells)
131 // The match arms in `run` remain the authoritative dispatch
132 // table; the unreachable `_ =>` arm catches any future drift.
133 let op_in_scope = matches!(
134 desc.op,
135 ReduceToOp::Sum | ReduceToOp::Max | ReduceToOp::Min | ReduceToOp::Prod
136 );
137 let dtype_in_scope = matches!(
138 T::KIND,
139 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
140 );
141 let supported = op_in_scope && dtype_in_scope;
142 if !supported {
143 return Err(Error::Unsupported(
144 "baracuda-kernels::ReduceToPlan: supported matrix is \
145 {Sum, Max, Min, Prod} × {f32, f16, bf16, f64}",
146 ));
147 }
148
149 // One thread per output cell, sequential walk over the
150 // broadcast set in a fixed order — deterministic and
151 // bit-stable on the same hardware. f32 / f16 / bf16 accumulate
152 // in f32; f64 keeps everything in double (see module docs).
153 let (math_precision, accumulator) = if T::KIND == ElementKind::F64 {
154 (MathPrecision::F64, ElementKind::F64)
155 } else {
156 (MathPrecision::F32, ElementKind::F32)
157 };
158 let precision_guarantee = PrecisionGuarantee {
159 math_precision,
160 accumulator,
161 bit_stable_on_same_hardware: true,
162 deterministic: true,
163 };
164 let sku = KernelSku {
165 category: OpCategory::Reduction,
166 op: desc.op as u16,
167 element: T::KIND,
168 aux_element: None,
169 layout: None,
170 epilogue: None,
171 arch: ArchSku::Sm80,
172 backend: BackendKind::Bespoke,
173 precision_guarantee,
174 };
175 Ok(Self {
176 desc: *desc,
177 sku,
178 _marker: PhantomData,
179 })
180 }
181
182 /// Validate args.
183 pub fn can_implement(&self, args: &ReduceToArgs<'_, T, N>) -> Result<()> {
184 if args.x.shape != self.desc.input_shape {
185 return Err(Error::InvalidProblem(
186 "baracuda-kernels::ReduceToPlan: X shape mismatch with descriptor input_shape",
187 ));
188 }
189 if args.y.shape != self.desc.output_shape {
190 return Err(Error::InvalidProblem(
191 "baracuda-kernels::ReduceToPlan: Y shape mismatch with descriptor output_shape",
192 ));
193 }
194 // The kernel writes `dst[out_id]` by linear index — the output
195 // must be a plain contiguous allocation over output_shape. The
196 // input may be arbitrarily strided.
197 if !args.y.is_contiguous() {
198 return Err(Error::InvalidProblem(
199 "baracuda-kernels::ReduceToPlan: Y must be contiguous over output_shape \
200 (the kernel writes by linear output index)",
201 ));
202 }
203 let y_numel = args.y.numel();
204 let y_len = args.y.data.len() as i64;
205 if y_len < y_numel {
206 return Err(Error::BufferTooSmall {
207 needed: y_numel as usize,
208 got: y_len as usize,
209 });
210 }
211 // Input bound: `x` may be an arbitrary strided view — including
212 // stride-0 broadcast dims, where `numel` legitimately exceeds
213 // the distinct storage extent — so the right bound is the
214 // reachable SPAN `1 + Σ_d (shape[d]-1)·stride[d]`, not `numel`.
215 // Negative strides can never index in-bounds (TensorRef has no
216 // base offset; the data pointer IS the slice start).
217 if args.x.numel() > 0 {
218 let mut span: i64 = 1;
219 for d in 0..N {
220 let extent = self.desc.input_shape[d] as i64;
221 if extent > 1 {
222 let stride = args.x.stride[d];
223 if stride < 0 {
224 return Err(Error::InvalidProblem(
225 "baracuda-kernels::ReduceToPlan: negative input strides walk \
226 before the buffer base (TensorRef has no base offset)",
227 ));
228 }
229 span += (extent - 1) * stride;
230 }
231 }
232 let x_len = args.x.data.len() as i64;
233 if x_len < span {
234 return Err(Error::BufferTooSmall {
235 needed: span as usize,
236 got: x_len as usize,
237 });
238 }
239 }
240 Ok(())
241 }
242
243 /// Workspace size in bytes. Always `0` — the per-output-cell
244 /// kernel keeps its accumulator in registers.
245 #[inline]
246 pub fn workspace_size(&self) -> usize {
247 0
248 }
249
250 /// Identity of the kernel this plan picked.
251 #[inline]
252 pub fn sku(&self) -> KernelSku {
253 self.sku
254 }
255
256 /// Numerical guarantees for this plan's kernel.
257 #[inline]
258 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
259 self.sku.precision_guarantee
260 }
261
262 /// Launch.
263 pub fn run(
264 &self,
265 stream: &Stream,
266 _workspace: Workspace<'_>,
267 args: ReduceToArgs<'_, T, N>,
268 ) -> Result<()> {
269 self.can_implement(&args)?;
270 if args.y.numel() == 0 {
271 return Ok(());
272 }
273 let x_ptr = args.x.data.as_raw().0 as *const c_void;
274 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
275 let stream_ptr = stream.as_raw() as *mut c_void;
276
277 let input_shape = self.desc.input_shape;
278 let input_stride = args.x.stride;
279 let output_shape = self.desc.output_shape;
280 let rank = N as i32;
281
282 // Helper: every reduce-to FFI symbol shares the same parameter
283 // shape (the kernel template is shared). The macro picks the
284 // right symbol from (op, dtype).
285 macro_rules! dispatch {
286 ($sym:ident) => {{
287 unsafe {
288 baracuda_kernels_sys::$sym(
289 x_ptr,
290 y_ptr,
291 input_shape.as_ptr(),
292 input_stride.as_ptr(),
293 rank,
294 output_shape.as_ptr(),
295 core::ptr::null_mut(),
296 0,
297 stream_ptr,
298 )
299 }
300 }};
301 }
302
303 let status = match (self.desc.op, T::KIND) {
304 // Sum
305 (ReduceToOp::Sum, ElementKind::F32) => {
306 dispatch!(baracuda_kernels_reduce_sum_to_f32_run)
307 }
308 (ReduceToOp::Sum, ElementKind::F16) => {
309 dispatch!(baracuda_kernels_reduce_sum_to_f16_run)
310 }
311 (ReduceToOp::Sum, ElementKind::Bf16) => {
312 dispatch!(baracuda_kernels_reduce_sum_to_bf16_run)
313 }
314 (ReduceToOp::Sum, ElementKind::F64) => {
315 dispatch!(baracuda_kernels_reduce_sum_to_f64_run)
316 }
317 // Max
318 (ReduceToOp::Max, ElementKind::F32) => {
319 dispatch!(baracuda_kernels_reduce_max_to_f32_run)
320 }
321 (ReduceToOp::Max, ElementKind::F16) => {
322 dispatch!(baracuda_kernels_reduce_max_to_f16_run)
323 }
324 (ReduceToOp::Max, ElementKind::Bf16) => {
325 dispatch!(baracuda_kernels_reduce_max_to_bf16_run)
326 }
327 (ReduceToOp::Max, ElementKind::F64) => {
328 dispatch!(baracuda_kernels_reduce_max_to_f64_run)
329 }
330 // Min
331 (ReduceToOp::Min, ElementKind::F32) => {
332 dispatch!(baracuda_kernels_reduce_min_to_f32_run)
333 }
334 (ReduceToOp::Min, ElementKind::F16) => {
335 dispatch!(baracuda_kernels_reduce_min_to_f16_run)
336 }
337 (ReduceToOp::Min, ElementKind::Bf16) => {
338 dispatch!(baracuda_kernels_reduce_min_to_bf16_run)
339 }
340 (ReduceToOp::Min, ElementKind::F64) => {
341 dispatch!(baracuda_kernels_reduce_min_to_f64_run)
342 }
343 // Prod
344 (ReduceToOp::Prod, ElementKind::F32) => {
345 dispatch!(baracuda_kernels_reduce_prod_to_f32_run)
346 }
347 (ReduceToOp::Prod, ElementKind::F16) => {
348 dispatch!(baracuda_kernels_reduce_prod_to_f16_run)
349 }
350 (ReduceToOp::Prod, ElementKind::Bf16) => {
351 dispatch!(baracuda_kernels_reduce_prod_to_bf16_run)
352 }
353 (ReduceToOp::Prod, ElementKind::F64) => {
354 dispatch!(baracuda_kernels_reduce_prod_to_f64_run)
355 }
356 _ => {
357 return Err(Error::Unsupported(
358 "baracuda-kernels::ReduceToPlan::run reached an unimplemented \
359 (op, dtype) pair — select() should have caught this",
360 ));
361 }
362 };
363 map_status(status)
364 }
365}
366
367fn map_status(code: i32) -> Result<()> {
368 match code {
369 0 => Ok(()),
370 1 => Err(Error::MisalignedOperand),
371 2 => Err(Error::InvalidProblem(
372 "baracuda-kernels-sys reported invalid problem",
373 )),
374 3 => Err(Error::Unsupported(
375 "baracuda-kernels-sys reported unsupported configuration",
376 )),
377 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
378 n => Err(Error::CutlassInternal(n)),
379 }
380}