Skip to main content

baracuda_kernels/reduce/
axis_backward.rs

1//! Backward plan for single-axis reductions.
2//!
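//! Sibling of [`crate::ReducePlan`] for gradient computation.
//! [`ReduceKind::Sum`] was the Phase 4 reduction BW trailblazer; the
//! plan now also wires `Mean`, `Max`, `Min`, `Prod`, `Norm2`,
//! `LogSumExp`, `Var` and `Std` for f32 / f16 / bf16 / f64.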
6//!
7//! **Sum BW** is the simplest reduction backward: the gradient broadcasts
8//! `dy` across the reduced axis. With keepdim convention,
9//! `dy.shape[reduce_axis] = 1` and we want
10//! `dx[c] = dy[c with c[reduce_axis] = 0]` for every coord `c` in dx.
11//!
12//! Implementation: a strided-copy kernel that uses
13//! `stride_dy[reduce_axis] = 0` so reading varies-coord-on-reduced-axis
14//! collapses to the singleton dy slot. The Rust dispatcher constructs
15//! this stride layout from the args' natural strides — the caller hands
16//! in dy with whatever strides their contig allocator gave it, and the
17//! plan overrides the reduce-axis stride to 0 before launch.
18//!
//! The other reductions ([`ReduceKind::Mean`], `Max`, `Min`, `Prod`,
//! `Norm2`, `LogSumExp`, `Var`, `Std`) build on the same broadcast
//! layout. Mean BW is `Sum BW × (1/k)` where k is the reduced extent.
//! Max/Min BW mask by `(x == y)` and share one kernel; Prod BW needs
//! `dy * y / x` per cell. Each of the remaining ops has its own
//! kernel template.
24//!
//! Trailblazer constraints: contig-only on dx (the kernel writes
//! linearly into dx's coord space). The kernels take explicit dy
//! strides, but `can_implement` currently also rejects non-contiguous
//! dy (and non-contiguous saved x / y), so callers must pass contig
//! keepdim dy.
28
29use core::ffi::c_void;
30use core::marker::PhantomData;
31
32use baracuda_cutlass::{Error, Result};
33use baracuda_driver::Stream;
34use baracuda_kernels_types::{
35    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
36    PlanPreference, PrecisionGuarantee, ReduceKind, TensorMut, TensorRef, Workspace,
37};
38
/// Descriptor for a single-axis reduction backward.
///
/// Plain-data description of the problem (which reduction, over which
/// axis, on what shape and dtype). It carries no device state; the
/// constraints noted on the fields are enforced by
/// `ReduceBackwardPlan::select`, not by this type.
#[derive(Copy, Clone, Debug)]
pub struct ReduceBackwardDescriptor<const N: usize> {
    /// Which forward reduction this is the backward of.
    pub kind: ReduceKind,
    /// Shape of the forward input (= shape of dx). Every dim must be
    /// non-negative; `select` rejects negative dims.
    pub input_shape: [i32; N],
    /// Axis that was reduced. Must satisfy `0 <= reduce_axis < N`;
    /// `select` rejects out-of-range values.
    pub reduce_axis: u8,
    /// Element type. Must match the plan's `T::KIND`, otherwise
    /// `select` returns `Error::Unsupported`.
    pub element: ElementKind,
    /// Bessel correction for `Var` / `Std` BW only. `1` = sample
    /// variance (PyTorch default), `0` = population variance. Ignored
    /// by other reductions.
    pub correction: i32,
}
55
56impl<const N: usize> ReduceBackwardDescriptor<N> {
57    /// Compute the keepdim dy shape (input shape with reduce_axis = 1).
58    pub fn dy_shape(&self) -> [i32; N] {
59        let mut out = self.input_shape;
60        out[self.reduce_axis as usize] = 1;
61        out
62    }
63}
64
/// Args bundle for a reduction-backward launch.
///
/// `dy.shape` must equal the keepdim form (input shape with the reduced
/// axis collapsed to 1). `dx.shape` must equal `input_shape`. Both
/// fully contiguous (trailblazer constraint).
///
/// Save requirements vary by op (the authoritative list is
/// `op_needs_saves`):
/// - Sum, Mean: neither save needed; pass `x = None, y = None`.
/// - Max, Min: BOTH saves required — `x` is the forward input (full
///   shape), `y` is the forward output (keepdim shape). Gradient flows
///   to every position where `x[c] == y[c_reduced]` (split-across-ties
///   semantic; matches JAX, differs from PyTorch's first-index pick).
/// - Prod, Norm2, LogSumExp, Std: same dual-save requirement.
/// - Var: both slots must be `Some` as well. Per the comments on
///   `op_needs_saves`, the formula only references `x`; `y` is
///   required for ABI uniformity with Std.
///
/// When saves are required they must have the same shapes as above
/// (`x` = input shape, `y` = keepdim shape) and be contiguous.
pub struct ReduceBackwardArgs<'a, T: Element, const N: usize> {
    /// Upstream gradient — keepdim shape matching forward output.
    pub dy: TensorRef<'a, T, N>,
    /// Saved forward input — full input shape. Required by Max/Min/
    /// Prod/Norm2/Var/Std/LogSumExp; pass `None` for Sum/Mean.
    pub x: Option<TensorRef<'a, T, N>>,
    /// Saved forward output — keepdim shape (= dy.shape). Required by
    /// Max/Min/Prod/Norm2/Var/Std/LogSumExp; pass `None` for Sum/Mean.
    pub y: Option<TensorRef<'a, T, N>>,
    /// Gradient w.r.t. the forward input — full input shape.
    pub dx: TensorMut<'a, T, N>,
}
90
/// Single-axis reduction backward plan.
///
/// Built by `select`, which validates the descriptor; `can_implement`
/// then validates a concrete set of args against it before `run`
/// launches the kernel.
pub struct ReduceBackwardPlan<T: Element, const N: usize> {
    /// Copy of the descriptor that passed validation in `select`.
    desc: ReduceBackwardDescriptor<N>,
    /// Kernel identity assembled in `select`; also carries the
    /// precision guarantee.
    sku: KernelSku,
    /// Ties the plan to element type `T` without storing a value of it.
    _marker: PhantomData<T>,
}
97
98#[inline]
99fn op_needs_saves(kind: ReduceKind) -> bool {
100    // Max/Min/Prod/Norm2/Std/LogSumExp reference both forward input
101    // and forward output in their BW formulas. Var references only
102    // saved x but takes a `y` slot for ABI uniformity with Std — we
103    // still require a non-null `y` so callers stage both consistently.
104    // Sum/Mean need neither.
105    matches!(
106        kind,
107        ReduceKind::Max
108            | ReduceKind::Min
109            | ReduceKind::Prod
110            | ReduceKind::Norm2
111            | ReduceKind::Var
112            | ReduceKind::Std
113            | ReduceKind::LogSumExp
114    )
115}
116
117impl<T: Element, const N: usize> ReduceBackwardPlan<T, N> {
118    /// Pick a kernel.
119    pub fn select(
120        _stream: &Stream,
121        desc: &ReduceBackwardDescriptor<N>,
122        _pref: PlanPreference,
123    ) -> Result<Self> {
124        if desc.element != T::KIND {
125            return Err(Error::Unsupported(
126                "baracuda-kernels::ReduceBackwardPlan: descriptor element != T",
127            ));
128        }
129        if (desc.reduce_axis as usize) >= N {
130            return Err(Error::InvalidProblem(
131                "baracuda-kernels::ReduceBackwardPlan: reduce_axis out of range for rank N",
132            ));
133        }
134        for &d in desc.input_shape.iter() {
135            if d < 0 {
136                return Err(Error::InvalidProblem(
137                    "baracuda-kernels::ReduceBackwardPlan: shape dims must be non-negative",
138                ));
139            }
140        }
141        if N > 8 {
142            return Err(Error::Unsupported(
143                "baracuda-kernels::ReduceBackwardPlan: tensor rank > 8 not supported \
144                 (kernel param block fixes MAX_RANK = 8)",
145            ));
146        }
147
148        // Wired today:
149        //   `{Sum, Mean, Max, Min, Prod, Norm2, LogSumExp, Var, Std}
150        //      × {f32, f16, bf16, f64}`
151        // Max/Min use a single unified kernel (the routing logic is
152        // identical: `x[c] == y[c_reduced]`). Prod, Norm2, and
153        // LogSumExp each have their own dual-save kernel with a
154        // different formula (LogSumExp computes `dy * exp(x - y)`).
155        // Var / Std (Welford BW) are templated on T; internal
156        // accumulation is f32 for f32/f16/bf16 and f64 for f64.
157        let dtype_in_fp_family = matches!(
158            T::KIND,
159            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
160        );
161        let kind_in_scope = matches!(
162            desc.kind,
163            ReduceKind::Sum
164                | ReduceKind::Mean
165                | ReduceKind::Max
166                | ReduceKind::Min
167                | ReduceKind::Prod
168                | ReduceKind::Norm2
169                | ReduceKind::LogSumExp
170                | ReduceKind::Var
171                | ReduceKind::Std
172        );
173        let supported = kind_in_scope && dtype_in_fp_family;
174        if !supported {
175            return Err(Error::Unsupported(
176                "baracuda-kernels::ReduceBackwardPlan: wired today: \
177                 `{Sum, Mean, Max, Min, Prod, Norm2, LogSumExp, Var, Std} \
178                  × {f32, f16, bf16, f64}`; \
179                 other (kind, dtype) pairs land in later fanout",
180            ));
181        }
182
183        let precision_guarantee = PrecisionGuarantee {
184            math_precision: MathPrecision::F32,
185            accumulator: ElementKind::F32,
186            bit_stable_on_same_hardware: true,
187            deterministic: true,
188        };
189        let sku = KernelSku {
190            category: OpCategory::Reduction,
191            op: desc.kind as u16,
192            element: T::KIND,
193            aux_element: None,
194            layout: None,
195            epilogue: None,
196            arch: ArchSku::Sm80,
197            backend: BackendKind::Bespoke,
198            precision_guarantee,
199        };
200        Ok(Self {
201            desc: *desc,
202            sku,
203            _marker: PhantomData,
204        })
205    }
206
207    /// Validate args.
208    pub fn can_implement(&self, args: &ReduceBackwardArgs<'_, T, N>) -> Result<()> {
209        if args.dx.shape != self.desc.input_shape {
210            return Err(Error::InvalidProblem(
211                "baracuda-kernels::ReduceBackwardPlan: dx shape must equal input_shape",
212            ));
213        }
214        let expected_dy_shape = self.desc.dy_shape();
215        if args.dy.shape != expected_dy_shape {
216            return Err(Error::InvalidProblem(
217                "baracuda-kernels::ReduceBackwardPlan: dy shape must equal input_shape \
218                 with reduce_axis collapsed to 1 (keepdim form)",
219            ));
220        }
221        if !args.dy.is_contiguous() || !args.dx.is_contiguous() {
222            return Err(Error::Unsupported(
223                "baracuda-kernels::ReduceBackwardPlan: trailblazer requires contiguous \
224                 dy / dx; strided fanout lands later",
225            ));
226        }
227        let dx_numel = args.dx.numel();
228        let dy_numel = args.dy.numel();
229        if (args.dx.data.len() as i64) < dx_numel {
230            return Err(Error::BufferTooSmall {
231                needed: dx_numel as usize,
232                got: args.dx.data.len(),
233            });
234        }
235        if (args.dy.data.len() as i64) < dy_numel {
236            return Err(Error::BufferTooSmall {
237                needed: dy_numel as usize,
238                got: args.dy.data.len(),
239            });
240        }
241        // Max/Min require BOTH saved-x (forward input, full shape) and
242        // saved-y (forward output, keepdim shape).
243        if op_needs_saves(self.desc.kind) {
244            let x = args.x.as_ref().ok_or(Error::InvalidProblem(
245                "baracuda-kernels::ReduceBackwardPlan: this op requires saved input `x`",
246            ))?;
247            let y = args.y.as_ref().ok_or(Error::InvalidProblem(
248                "baracuda-kernels::ReduceBackwardPlan: this op requires saved output `y`",
249            ))?;
250            if x.shape != self.desc.input_shape {
251                return Err(Error::InvalidProblem(
252                    "baracuda-kernels::ReduceBackwardPlan: saved `x` shape must equal input_shape",
253                ));
254            }
255            if y.shape != expected_dy_shape {
256                return Err(Error::InvalidProblem(
257                    "baracuda-kernels::ReduceBackwardPlan: saved `y` shape must equal \
258                     keepdim form (input_shape with reduce_axis = 1)",
259                ));
260            }
261            if !x.is_contiguous() || !y.is_contiguous() {
262                return Err(Error::Unsupported(
263                    "baracuda-kernels::ReduceBackwardPlan: saved x / y must be contiguous \
264                     (strided fanout lands later)",
265                ));
266            }
267            if (x.data.len() as i64) < dx_numel {
268                return Err(Error::BufferTooSmall {
269                    needed: dx_numel as usize,
270                    got: x.data.len(),
271                });
272            }
273            if (y.data.len() as i64) < dy_numel {
274                return Err(Error::BufferTooSmall {
275                    needed: dy_numel as usize,
276                    got: y.data.len(),
277                });
278            }
279        }
280        Ok(())
281    }
282
283    /// Workspace size in bytes.
284    #[inline]
285    pub fn workspace_size(&self) -> usize {
286        0
287    }
288
289    /// Kernel SKU identity.
290    #[inline]
291    pub fn sku(&self) -> KernelSku {
292        self.sku
293    }
294
295    /// Numerical guarantees.
296    #[inline]
297    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
298        self.sku.precision_guarantee
299    }
300
301    /// Launch.
302    pub fn run(
303        &self,
304        stream: &Stream,
305        _workspace: Workspace<'_>,
306        args: ReduceBackwardArgs<'_, T, N>,
307    ) -> Result<()> {
308        self.can_implement(&args)?;
309        let numel = args.dx.numel();
310        if numel == 0 {
311            return Ok(());
312        }
313        // Construct the broadcast dy stride layout: take dy's natural
314        // strides and zero out the reduced axis. The kernel walks the
315        // full dx coord space; reading dy with stride 0 on the reduce
316        // axis collapses every reduce-axis coord to the singleton dy
317        // slot.
318        let axis = self.desc.reduce_axis as usize;
319        let mut stride_dy = args.dy.stride;
320        stride_dy[axis] = 0;
321        let shape = self.desc.input_shape;
322        let stride_dx = args.dx.stride;
323        let rank = N as i32;
324        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
325        let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
326        let stream_ptr = stream.as_raw() as *mut c_void;
327
328        let status = match (self.desc.kind, T::KIND) {
329            (ReduceKind::Sum, ElementKind::F32) => unsafe {
330                baracuda_kernels_sys::baracuda_kernels_reduce_sum_backward_f32_run(
331                    numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
332                    dy_ptr, dx_ptr, core::ptr::null_mut(), 0, stream_ptr,
333                )
334            },
335            (ReduceKind::Sum, ElementKind::F16) => unsafe {
336                baracuda_kernels_sys::baracuda_kernels_reduce_sum_backward_f16_run(
337                    numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
338                    dy_ptr, dx_ptr, core::ptr::null_mut(), 0, stream_ptr,
339                )
340            },
341            (ReduceKind::Sum, ElementKind::Bf16) => unsafe {
342                baracuda_kernels_sys::baracuda_kernels_reduce_sum_backward_bf16_run(
343                    numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
344                    dy_ptr, dx_ptr, core::ptr::null_mut(), 0, stream_ptr,
345                )
346            },
347            (ReduceKind::Sum, ElementKind::F64) => unsafe {
348                baracuda_kernels_sys::baracuda_kernels_reduce_sum_backward_f64_run(
349                    numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
350                    dy_ptr, dx_ptr, core::ptr::null_mut(), 0, stream_ptr,
351                )
352            },
353            (ReduceKind::Max, _) | (ReduceKind::Min, _) => {
354                // Both ops share one kernel: `x[c] == y[c_reduced]`
355                // identifies recipient positions regardless of whether
356                // y is a max or a min.
357                let x = args.x.as_ref().expect("Max/Min BW require saved x");
358                let y = args.y.as_ref().expect("Max/Min BW require saved y");
359                let x_ptr = x.data.as_raw().0 as *const c_void;
360                let y_ptr = y.data.as_raw().0 as *const c_void;
361                let stride_x = x.stride;
362                let mut stride_y = y.stride;
363                stride_y[axis] = 0;
364                match T::KIND {
365                    ElementKind::F32 => unsafe {
366                        baracuda_kernels_sys::baracuda_kernels_reduce_max_min_backward_f32_run(
367                            numel, rank, shape.as_ptr(),
368                            stride_dy.as_ptr(), stride_x.as_ptr(),
369                            stride_y.as_ptr(), stride_dx.as_ptr(),
370                            dy_ptr, x_ptr, y_ptr, dx_ptr,
371                            core::ptr::null_mut(), 0, stream_ptr,
372                        )
373                    },
374                    ElementKind::F16 => unsafe {
375                        baracuda_kernels_sys::baracuda_kernels_reduce_max_min_backward_f16_run(
376                            numel, rank, shape.as_ptr(),
377                            stride_dy.as_ptr(), stride_x.as_ptr(),
378                            stride_y.as_ptr(), stride_dx.as_ptr(),
379                            dy_ptr, x_ptr, y_ptr, dx_ptr,
380                            core::ptr::null_mut(), 0, stream_ptr,
381                        )
382                    },
383                    ElementKind::Bf16 => unsafe {
384                        baracuda_kernels_sys::baracuda_kernels_reduce_max_min_backward_bf16_run(
385                            numel, rank, shape.as_ptr(),
386                            stride_dy.as_ptr(), stride_x.as_ptr(),
387                            stride_y.as_ptr(), stride_dx.as_ptr(),
388                            dy_ptr, x_ptr, y_ptr, dx_ptr,
389                            core::ptr::null_mut(), 0, stream_ptr,
390                        )
391                    },
392                    ElementKind::F64 => unsafe {
393                        baracuda_kernels_sys::baracuda_kernels_reduce_max_min_backward_f64_run(
394                            numel, rank, shape.as_ptr(),
395                            stride_dy.as_ptr(), stride_x.as_ptr(),
396                            stride_y.as_ptr(), stride_dx.as_ptr(),
397                            dy_ptr, x_ptr, y_ptr, dx_ptr,
398                            core::ptr::null_mut(), 0, stream_ptr,
399                        )
400                    },
401                    _ => return Err(Error::Unsupported(
402                        "baracuda-kernels::ReduceBackwardPlan::run: Max/Min BW reached an \
403                         unimplemented dtype — select() should have caught this",
404                    )),
405                }
406            }
407            (ReduceKind::Mean, _) => {
408                // `1/k` where k = reduced extent. Computed in f64 on the
409                // host and cast to T inside the kernel.
410                let extent = self.desc.input_shape[axis] as f64;
411                if extent == 0.0 {
412                    return Err(Error::InvalidProblem(
413                        "baracuda-kernels::ReduceBackwardPlan: Mean BW requires \
414                         reduced extent > 0",
415                    ));
416                }
417                let inv_extent = 1.0_f64 / extent;
418                match T::KIND {
419                    ElementKind::F32 => unsafe {
420                        baracuda_kernels_sys::baracuda_kernels_reduce_mean_backward_f32_run(
421                            numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
422                            dy_ptr, dx_ptr, inv_extent,
423                            core::ptr::null_mut(), 0, stream_ptr,
424                        )
425                    },
426                    ElementKind::F16 => unsafe {
427                        baracuda_kernels_sys::baracuda_kernels_reduce_mean_backward_f16_run(
428                            numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
429                            dy_ptr, dx_ptr, inv_extent,
430                            core::ptr::null_mut(), 0, stream_ptr,
431                        )
432                    },
433                    ElementKind::Bf16 => unsafe {
434                        baracuda_kernels_sys::baracuda_kernels_reduce_mean_backward_bf16_run(
435                            numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
436                            dy_ptr, dx_ptr, inv_extent,
437                            core::ptr::null_mut(), 0, stream_ptr,
438                        )
439                    },
440                    ElementKind::F64 => unsafe {
441                        baracuda_kernels_sys::baracuda_kernels_reduce_mean_backward_f64_run(
442                            numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
443                            dy_ptr, dx_ptr, inv_extent,
444                            core::ptr::null_mut(), 0, stream_ptr,
445                        )
446                    },
447                    _ => return Err(Error::Unsupported(
448                        "baracuda-kernels::ReduceBackwardPlan::run: Mean BW reached an \
449                         unimplemented dtype — select() should have caught this",
450                    )),
451                }
452            }
453            (ReduceKind::Prod, _) => {
454                // `dx[c] = dy[c_reduced] * y[c_reduced] / x[c]`. Dual-save.
455                let x = args.x.as_ref().expect("Prod BW require saved x");
456                let y = args.y.as_ref().expect("Prod BW require saved y");
457                let x_ptr = x.data.as_raw().0 as *const c_void;
458                let y_ptr = y.data.as_raw().0 as *const c_void;
459                let stride_x = x.stride;
460                let mut stride_y = y.stride;
461                stride_y[axis] = 0;
462                match T::KIND {
463                    ElementKind::F32 => unsafe {
464                        baracuda_kernels_sys::baracuda_kernels_reduce_prod_backward_f32_run(
465                            numel, rank, shape.as_ptr(),
466                            stride_dy.as_ptr(), stride_x.as_ptr(),
467                            stride_y.as_ptr(), stride_dx.as_ptr(),
468                            dy_ptr, x_ptr, y_ptr, dx_ptr,
469                            core::ptr::null_mut(), 0, stream_ptr,
470                        )
471                    },
472                    ElementKind::F16 => unsafe {
473                        baracuda_kernels_sys::baracuda_kernels_reduce_prod_backward_f16_run(
474                            numel, rank, shape.as_ptr(),
475                            stride_dy.as_ptr(), stride_x.as_ptr(),
476                            stride_y.as_ptr(), stride_dx.as_ptr(),
477                            dy_ptr, x_ptr, y_ptr, dx_ptr,
478                            core::ptr::null_mut(), 0, stream_ptr,
479                        )
480                    },
481                    ElementKind::Bf16 => unsafe {
482                        baracuda_kernels_sys::baracuda_kernels_reduce_prod_backward_bf16_run(
483                            numel, rank, shape.as_ptr(),
484                            stride_dy.as_ptr(), stride_x.as_ptr(),
485                            stride_y.as_ptr(), stride_dx.as_ptr(),
486                            dy_ptr, x_ptr, y_ptr, dx_ptr,
487                            core::ptr::null_mut(), 0, stream_ptr,
488                        )
489                    },
490                    ElementKind::F64 => unsafe {
491                        baracuda_kernels_sys::baracuda_kernels_reduce_prod_backward_f64_run(
492                            numel, rank, shape.as_ptr(),
493                            stride_dy.as_ptr(), stride_x.as_ptr(),
494                            stride_y.as_ptr(), stride_dx.as_ptr(),
495                            dy_ptr, x_ptr, y_ptr, dx_ptr,
496                            core::ptr::null_mut(), 0, stream_ptr,
497                        )
498                    },
499                    _ => return Err(Error::Unsupported(
500                        "baracuda-kernels::ReduceBackwardPlan::run: Prod BW reached an \
501                         unimplemented dtype — select() should have caught this",
502                    )),
503                }
504            }
505            (ReduceKind::Var, _) | (ReduceKind::Std, _) => {
506                // Welford BW. `mean[c_reduced]` is recomputed inside the
507                // kernel (single-pass sum/n on the saved-x reduce axis).
508                // Var BW: `dx[c] = dy[c_reduced] * 2 * (x[c] - mean) / m`
509                // Std BW: `dx[c] = dy[c_reduced] * (x[c] - mean) /
510                //                  (m * y[c_reduced])`
511                // where `m = max(n - correction, 1)`. Internal Welford
512                // accumulator runs at f32 for f32/f16/bf16 and f64 for
513                // f64 (see `WelfordAcc<T>` in the kernel header).
514                let x = args
515                    .x
516                    .as_ref()
517                    .expect("Var/Std BW require saved x");
518                let y = args
519                    .y
520                    .as_ref()
521                    .expect("Var/Std BW require saved y (Var ignores it; passed for ABI uniformity)");
522                let x_ptr = x.data.as_raw().0 as *const c_void;
523                let y_ptr = y.data.as_raw().0 as *const c_void;
524                let stride_x = x.stride;
525                let mut stride_y = y.stride;
526                stride_y[axis] = 0;
527                let reduce_axis_i32 = self.desc.reduce_axis as i32;
528                let reduce_extent = self.desc.input_shape[axis];
529                let reduce_stride_x = stride_x[axis];
530                let correction = self.desc.correction;
531                match (self.desc.kind, T::KIND) {
532                    (ReduceKind::Var, ElementKind::F32) => unsafe {
533                        baracuda_kernels_sys::baracuda_kernels_reduce_var_backward_f32_run(
534                            numel, rank, shape.as_ptr(),
535                            stride_dy.as_ptr(), stride_x.as_ptr(),
536                            stride_y.as_ptr(), stride_dx.as_ptr(),
537                            dy_ptr, x_ptr, y_ptr, dx_ptr,
538                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
539                            core::ptr::null_mut(), 0, stream_ptr,
540                        )
541                    },
542                    (ReduceKind::Var, ElementKind::F16) => unsafe {
543                        baracuda_kernels_sys::baracuda_kernels_reduce_var_backward_f16_run(
544                            numel, rank, shape.as_ptr(),
545                            stride_dy.as_ptr(), stride_x.as_ptr(),
546                            stride_y.as_ptr(), stride_dx.as_ptr(),
547                            dy_ptr, x_ptr, y_ptr, dx_ptr,
548                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
549                            core::ptr::null_mut(), 0, stream_ptr,
550                        )
551                    },
552                    (ReduceKind::Var, ElementKind::Bf16) => unsafe {
553                        baracuda_kernels_sys::baracuda_kernels_reduce_var_backward_bf16_run(
554                            numel, rank, shape.as_ptr(),
555                            stride_dy.as_ptr(), stride_x.as_ptr(),
556                            stride_y.as_ptr(), stride_dx.as_ptr(),
557                            dy_ptr, x_ptr, y_ptr, dx_ptr,
558                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
559                            core::ptr::null_mut(), 0, stream_ptr,
560                        )
561                    },
562                    (ReduceKind::Var, ElementKind::F64) => unsafe {
563                        baracuda_kernels_sys::baracuda_kernels_reduce_var_backward_f64_run(
564                            numel, rank, shape.as_ptr(),
565                            stride_dy.as_ptr(), stride_x.as_ptr(),
566                            stride_y.as_ptr(), stride_dx.as_ptr(),
567                            dy_ptr, x_ptr, y_ptr, dx_ptr,
568                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
569                            core::ptr::null_mut(), 0, stream_ptr,
570                        )
571                    },
572                    (ReduceKind::Std, ElementKind::F32) => unsafe {
573                        baracuda_kernels_sys::baracuda_kernels_reduce_std_backward_f32_run(
574                            numel, rank, shape.as_ptr(),
575                            stride_dy.as_ptr(), stride_x.as_ptr(),
576                            stride_y.as_ptr(), stride_dx.as_ptr(),
577                            dy_ptr, x_ptr, y_ptr, dx_ptr,
578                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
579                            core::ptr::null_mut(), 0, stream_ptr,
580                        )
581                    },
582                    (ReduceKind::Std, ElementKind::F16) => unsafe {
583                        baracuda_kernels_sys::baracuda_kernels_reduce_std_backward_f16_run(
584                            numel, rank, shape.as_ptr(),
585                            stride_dy.as_ptr(), stride_x.as_ptr(),
586                            stride_y.as_ptr(), stride_dx.as_ptr(),
587                            dy_ptr, x_ptr, y_ptr, dx_ptr,
588                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
589                            core::ptr::null_mut(), 0, stream_ptr,
590                        )
591                    },
592                    (ReduceKind::Std, ElementKind::Bf16) => unsafe {
593                        baracuda_kernels_sys::baracuda_kernels_reduce_std_backward_bf16_run(
594                            numel, rank, shape.as_ptr(),
595                            stride_dy.as_ptr(), stride_x.as_ptr(),
596                            stride_y.as_ptr(), stride_dx.as_ptr(),
597                            dy_ptr, x_ptr, y_ptr, dx_ptr,
598                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
599                            core::ptr::null_mut(), 0, stream_ptr,
600                        )
601                    },
602                    (ReduceKind::Std, ElementKind::F64) => unsafe {
603                        baracuda_kernels_sys::baracuda_kernels_reduce_std_backward_f64_run(
604                            numel, rank, shape.as_ptr(),
605                            stride_dy.as_ptr(), stride_x.as_ptr(),
606                            stride_y.as_ptr(), stride_dx.as_ptr(),
607                            dy_ptr, x_ptr, y_ptr, dx_ptr,
608                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
609                            core::ptr::null_mut(), 0, stream_ptr,
610                        )
611                    },
612                    _ => return Err(Error::Unsupported(
613                        "baracuda-kernels::ReduceBackwardPlan::run: Var/Std BW reached an \
614                         unimplemented dtype — select() should have caught this",
615                    )),
616                }
617            }
618            (ReduceKind::Norm2, _) => {
619                // `dx[c] = dy[c_reduced] * x[c] / y[c_reduced]`. Dual-save.
620                let x = args.x.as_ref().expect("Norm2 BW require saved x");
621                let y = args.y.as_ref().expect("Norm2 BW require saved y");
622                let x_ptr = x.data.as_raw().0 as *const c_void;
623                let y_ptr = y.data.as_raw().0 as *const c_void;
624                let stride_x = x.stride;
625                let mut stride_y = y.stride;
626                stride_y[axis] = 0;
627                match T::KIND {
628                    ElementKind::F32 => unsafe {
629                        baracuda_kernels_sys::baracuda_kernels_reduce_norm2_backward_f32_run(
630                            numel, rank, shape.as_ptr(),
631                            stride_dy.as_ptr(), stride_x.as_ptr(),
632                            stride_y.as_ptr(), stride_dx.as_ptr(),
633                            dy_ptr, x_ptr, y_ptr, dx_ptr,
634                            core::ptr::null_mut(), 0, stream_ptr,
635                        )
636                    },
637                    ElementKind::F16 => unsafe {
638                        baracuda_kernels_sys::baracuda_kernels_reduce_norm2_backward_f16_run(
639                            numel, rank, shape.as_ptr(),
640                            stride_dy.as_ptr(), stride_x.as_ptr(),
641                            stride_y.as_ptr(), stride_dx.as_ptr(),
642                            dy_ptr, x_ptr, y_ptr, dx_ptr,
643                            core::ptr::null_mut(), 0, stream_ptr,
644                        )
645                    },
646                    ElementKind::Bf16 => unsafe {
647                        baracuda_kernels_sys::baracuda_kernels_reduce_norm2_backward_bf16_run(
648                            numel, rank, shape.as_ptr(),
649                            stride_dy.as_ptr(), stride_x.as_ptr(),
650                            stride_y.as_ptr(), stride_dx.as_ptr(),
651                            dy_ptr, x_ptr, y_ptr, dx_ptr,
652                            core::ptr::null_mut(), 0, stream_ptr,
653                        )
654                    },
655                    ElementKind::F64 => unsafe {
656                        baracuda_kernels_sys::baracuda_kernels_reduce_norm2_backward_f64_run(
657                            numel, rank, shape.as_ptr(),
658                            stride_dy.as_ptr(), stride_x.as_ptr(),
659                            stride_y.as_ptr(), stride_dx.as_ptr(),
660                            dy_ptr, x_ptr, y_ptr, dx_ptr,
661                            core::ptr::null_mut(), 0, stream_ptr,
662                        )
663                    },
664                    _ => return Err(Error::Unsupported(
665                        "baracuda-kernels::ReduceBackwardPlan::run: Norm2 BW reached an \
666                         unimplemented dtype — select() should have caught this",
667                    )),
668                }
669            }
670            (ReduceKind::LogSumExp, _) => {
671                // `dx[c] = dy[c_reduced] * exp(x[c] - y[c_reduced])`.
672                // Dual-save. `y = lse(x) ≥ max(x) ≥ x[c]`, so the exp
673                // arg is `≤ 0` and the result is bounded in `(0, 1]` —
674                // no overflow possible at any dtype.
675                let x = args.x.as_ref().expect("LogSumExp BW require saved x");
676                let y = args.y.as_ref().expect("LogSumExp BW require saved y");
677                let x_ptr = x.data.as_raw().0 as *const c_void;
678                let y_ptr = y.data.as_raw().0 as *const c_void;
679                let stride_x = x.stride;
680                let mut stride_y = y.stride;
681                stride_y[axis] = 0;
682                match T::KIND {
683                    ElementKind::F32 => unsafe {
684                        baracuda_kernels_sys::baracuda_kernels_reduce_logsumexp_backward_f32_run(
685                            numel, rank, shape.as_ptr(),
686                            stride_dy.as_ptr(), stride_x.as_ptr(),
687                            stride_y.as_ptr(), stride_dx.as_ptr(),
688                            dy_ptr, x_ptr, y_ptr, dx_ptr,
689                            core::ptr::null_mut(), 0, stream_ptr,
690                        )
691                    },
692                    ElementKind::F16 => unsafe {
693                        baracuda_kernels_sys::baracuda_kernels_reduce_logsumexp_backward_f16_run(
694                            numel, rank, shape.as_ptr(),
695                            stride_dy.as_ptr(), stride_x.as_ptr(),
696                            stride_y.as_ptr(), stride_dx.as_ptr(),
697                            dy_ptr, x_ptr, y_ptr, dx_ptr,
698                            core::ptr::null_mut(), 0, stream_ptr,
699                        )
700                    },
701                    ElementKind::Bf16 => unsafe {
702                        baracuda_kernels_sys::baracuda_kernels_reduce_logsumexp_backward_bf16_run(
703                            numel, rank, shape.as_ptr(),
704                            stride_dy.as_ptr(), stride_x.as_ptr(),
705                            stride_y.as_ptr(), stride_dx.as_ptr(),
706                            dy_ptr, x_ptr, y_ptr, dx_ptr,
707                            core::ptr::null_mut(), 0, stream_ptr,
708                        )
709                    },
710                    ElementKind::F64 => unsafe {
711                        baracuda_kernels_sys::baracuda_kernels_reduce_logsumexp_backward_f64_run(
712                            numel, rank, shape.as_ptr(),
713                            stride_dy.as_ptr(), stride_x.as_ptr(),
714                            stride_y.as_ptr(), stride_dx.as_ptr(),
715                            dy_ptr, x_ptr, y_ptr, dx_ptr,
716                            core::ptr::null_mut(), 0, stream_ptr,
717                        )
718                    },
719                    _ => return Err(Error::Unsupported(
720                        "baracuda-kernels::ReduceBackwardPlan::run: LogSumExp BW reached an \
721                         unimplemented dtype — select() should have caught this",
722                    )),
723                }
724            }
725            _ => {
726                return Err(Error::Unsupported(
727                    "baracuda-kernels::ReduceBackwardPlan::run reached an unimplemented \
728                     (kind, dtype) pair — select() should have caught this",
729                ));
730            }
731        };
732        map_status(status)
733    }
734}
735
736fn map_status(code: i32) -> Result<()> {
737    match code {
738        0 => Ok(()),
739        1 => Err(Error::MisalignedOperand),
740        2 => Err(Error::InvalidProblem(
741            "baracuda-kernels-sys reported invalid problem",
742        )),
743        3 => Err(Error::Unsupported(
744            "baracuda-kernels-sys reported unsupported configuration",
745        )),
746        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
747        n => Err(Error::CutlassInternal(n)),
748    }
749}