Skip to main content

baracuda_kernels/scan/
axis_backward.rs

1//! Single-axis scan backward plan.
2//!
3//! Backward pass for the prefix-scan family.
4//!
5//! - `Cumsum BW`: `dx = cumsum(dy, reverse = !fw.reverse)` — that is,
6//!   the FW kernel applied to `dy` with the scan direction flipped. No
7//!   new CUDA kernel required.
//! - `Cumprod BW`: `dx[j] = Σ_{i ∈ S(j)} dy[i] * y[i] / x[j]`. Needs
//!   the saved FW input `x` and the saved FW output `y`.
//! - `Cummax / Cummin BW`: gradient flows to first-occurrence
//!   argmax/argmin position. Needs the saved FW input `x` only —
//!   running winner is recomputed by the kernel from `x`.
//! - `LogCumsumExp BW`: `dx[j] = Σ_{i ∈ S(j)} dy[i] * exp(x[j] - y[i])`.
//!   Needs both saved FW input `x` and saved FW output `y`.
//!
//!   Here `S(j)` is the set of output positions whose forward value depends
//!   on `x[j]`: `i ≥ j` along the scan axis for a forward scan, `i ≤ j` when
//!   `reverse` is set.
15//!
16//! Today wired: `{Cumsum, Cumprod, Cummax, Cummin, LogCumsumExp} ×
17//! {f32, f16, bf16, f64}`.
18
19use core::ffi::c_void;
20use core::marker::PhantomData;
21
22use baracuda_cutlass::{Error, Result};
23use baracuda_driver::Stream;
24use baracuda_kernels_types::{
25    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
26    PlanPreference, PrecisionGuarantee, ScanKind, TensorMut, TensorRef, Workspace,
27};
28
/// Descriptor for a single-axis scan BW op.
///
/// Mirrors [`crate::ScanDescriptor`] — same `input_shape`, `scan_axis`,
/// and `reverse` as the FW. The BW kernel handles direction internally.
///
/// [`ScanBackwardPlan::select`] rejects a descriptor when any of these hold:
/// - `element` differs from the plan's `T::KIND`;
/// - `scan_axis >= N`;
/// - any dim is negative;
/// - `N > 8`.
#[derive(Copy, Clone, Debug)]
pub struct ScanBackwardDescriptor<const N: usize> {
    /// Which forward scan kind this is the backward of.
    pub kind: ScanKind,
    /// Tensor shape (shared by dy / dx and the optional x / y saves).
    pub input_shape: [i32; N],
    /// Forward scan axis (an index into `input_shape`; must be `< N`).
    pub scan_axis: u8,
    /// Forward direction flag, exactly as used by the FW scan. Cumsum BW
    /// flips it before dispatching to the FW kernel; the other BW kernels
    /// receive it unchanged.
    pub reverse: bool,
    /// Element type. Must equal the plan's `T::KIND`.
    pub element: ElementKind,
}
46
47/// Args bundle for a scan BW launch.
48///
49/// `x` and `y` are saved forward tensors; whether they must be supplied
50/// is op-dependent:
51///
52/// | op            | needs x | needs y |
53/// |---------------|---------|---------|
54/// | Cumsum        |    no   |    no   |
55/// | Cumprod       |   yes   |   yes   |
56/// | Cummax        |   yes   |    no   |
57/// | Cummin        |   yes   |    no   |
58/// | LogCumsumExp  |   yes   |   yes   |
59///
60/// Pass `None` for unused slots. The plan's `can_implement` validates
61/// the op-specific requirement.
62pub struct ScanBackwardArgs<'a, T: Element, const N: usize> {
63    /// Upstream gradient — same shape as the forward output.
64    pub dy: TensorRef<'a, T, N>,
65    /// Gradient w.r.t. the forward input — same shape.
66    pub dx: TensorMut<'a, T, N>,
67    /// Saved forward input. Required for Cumprod / Cummax / Cummin BW.
68    pub x: Option<TensorRef<'a, T, N>>,
69    /// Saved forward output. Required for Cumprod BW.
70    pub y: Option<TensorRef<'a, T, N>>,
71}
72
73/// True iff the scan kind's BW requires the saved forward input `x`.
74#[inline]
75fn op_needs_saved_x(kind: ScanKind) -> bool {
76    matches!(
77        kind,
78        ScanKind::Cumprod | ScanKind::Cummax | ScanKind::Cummin | ScanKind::LogCumsumExp
79    )
80}
81
82/// True iff the scan kind's BW requires the saved forward output `y`.
83#[inline]
84fn op_needs_saved_y(kind: ScanKind) -> bool {
85    matches!(kind, ScanKind::Cumprod | ScanKind::LogCumsumExp)
86}
87
/// Single-axis scan backward plan.
///
/// Built by [`ScanBackwardPlan::select`] and launched with
/// [`ScanBackwardPlan::run`]. Stores the accepted descriptor and the kernel
/// SKU it reports.
pub struct ScanBackwardPlan<T: Element, const N: usize> {
    /// Descriptor accepted by `select` (copied in).
    desc: ScanBackwardDescriptor<N>,
    /// Kernel identity returned by `sku()` / `precision_guarantee()`.
    sku: KernelSku,
    /// Ties the plan to element type `T` without storing a `T`.
    _marker: PhantomData<T>,
}
94
95impl<T: Element, const N: usize> ScanBackwardPlan<T, N> {
96    /// Pick a kernel.
97    pub fn select(
98        _stream: &Stream,
99        desc: &ScanBackwardDescriptor<N>,
100        _pref: PlanPreference,
101    ) -> Result<Self> {
102        if desc.element != T::KIND {
103            return Err(Error::Unsupported(
104                "baracuda-kernels::ScanBackwardPlan: descriptor element != T",
105            ));
106        }
107        if (desc.scan_axis as usize) >= N {
108            return Err(Error::InvalidProblem(
109                "baracuda-kernels::ScanBackwardPlan: scan_axis out of range for rank N",
110            ));
111        }
112        for &d in desc.input_shape.iter() {
113            if d < 0 {
114                return Err(Error::InvalidProblem(
115                    "baracuda-kernels::ScanBackwardPlan: shape dims must be non-negative",
116                ));
117            }
118        }
119        if N > 8 {
120            return Err(Error::Unsupported(
121                "baracuda-kernels::ScanBackwardPlan: tensor rank > 8 not supported",
122            ));
123        }
124        let dtype_in_fp_family = matches!(
125            T::KIND,
126            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
127        );
128        let kind_supported = matches!(
129            desc.kind,
130            ScanKind::Cumsum
131                | ScanKind::Cumprod
132                | ScanKind::Cummax
133                | ScanKind::Cummin
134                | ScanKind::LogCumsumExp
135        );
136        if !kind_supported || !dtype_in_fp_family {
137            return Err(Error::Unsupported(
138                "baracuda-kernels::ScanBackwardPlan: wired today: \
139                 `{Cumsum, Cumprod, Cummax, Cummin, LogCumsumExp} × {f32, f16, bf16, f64}`",
140            ));
141        }
142
143        let precision_guarantee = PrecisionGuarantee {
144            math_precision: MathPrecision::F32,
145            accumulator: ElementKind::F32,
146            bit_stable_on_same_hardware: true,
147            deterministic: true,
148        };
149        let sku = KernelSku {
150            category: OpCategory::Scan,
151            op: desc.kind as u16,
152            element: T::KIND,
153            aux_element: None,
154            layout: None,
155            epilogue: None,
156            arch: ArchSku::Sm80,
157            backend: BackendKind::Bespoke,
158            precision_guarantee,
159        };
160        Ok(Self {
161            desc: *desc,
162            sku,
163            _marker: PhantomData,
164        })
165    }
166
167    /// Validate args.
168    pub fn can_implement(&self, args: &ScanBackwardArgs<'_, T, N>) -> Result<()> {
169        if args.dy.shape != self.desc.input_shape {
170            return Err(Error::InvalidProblem(
171                "baracuda-kernels::ScanBackwardPlan: dy shape mismatch",
172            ));
173        }
174        if args.dx.shape != self.desc.input_shape {
175            return Err(Error::InvalidProblem(
176                "baracuda-kernels::ScanBackwardPlan: dx shape mismatch",
177            ));
178        }
179        let numel = args.dx.numel();
180        let dy_len = args.dy.data.len() as i64;
181        let dx_len = args.dx.data.len() as i64;
182        if dy_len < numel || dx_len < numel {
183            return Err(Error::BufferTooSmall {
184                needed: numel as usize,
185                got: dy_len.min(dx_len) as usize,
186            });
187        }
188        // Op-specific saved-tensor checks.
189        if op_needs_saved_x(self.desc.kind) {
190            let x = args.x.as_ref().ok_or(Error::InvalidProblem(
191                "baracuda-kernels::ScanBackwardPlan: Cumprod / Cummax / Cummin BW \
192                 require args.x (saved forward input)",
193            ))?;
194            if x.shape != self.desc.input_shape {
195                return Err(Error::InvalidProblem(
196                    "baracuda-kernels::ScanBackwardPlan: args.x shape mismatch",
197                ));
198            }
199            if (x.data.len() as i64) < numel {
200                return Err(Error::BufferTooSmall {
201                    needed: numel as usize,
202                    got: x.data.len(),
203                });
204            }
205        }
206        if op_needs_saved_y(self.desc.kind) {
207            let y = args.y.as_ref().ok_or(Error::InvalidProblem(
208                "baracuda-kernels::ScanBackwardPlan: Cumprod BW requires args.y \
209                 (saved forward output)",
210            ))?;
211            if y.shape != self.desc.input_shape {
212                return Err(Error::InvalidProblem(
213                    "baracuda-kernels::ScanBackwardPlan: args.y shape mismatch",
214                ));
215            }
216            if (y.data.len() as i64) < numel {
217                return Err(Error::BufferTooSmall {
218                    needed: numel as usize,
219                    got: y.data.len(),
220                });
221            }
222        }
223        Ok(())
224    }
225
226    /// Workspace size in bytes.
227    #[inline]
228    pub fn workspace_size(&self) -> usize {
229        0
230    }
231    /// Kernel SKU identity.
232    #[inline]
233    pub fn sku(&self) -> KernelSku {
234        self.sku
235    }
236    /// Numerical guarantees.
237    #[inline]
238    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
239        self.sku.precision_guarantee
240    }
241
242    /// Launch.
243    pub fn run(
244        &self,
245        stream: &Stream,
246        _workspace: Workspace<'_>,
247        args: ScanBackwardArgs<'_, T, N>,
248    ) -> Result<()> {
249        self.can_implement(&args)?;
250        let numel = args.dx.numel();
251        if numel == 0 {
252            return Ok(());
253        }
254        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
255        let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
256        let stream_ptr = stream.as_raw() as *mut c_void;
257
258        let axis = self.desc.scan_axis as usize;
259        let shape = self.desc.input_shape;
260        let stride_dy = args.dy.stride;
261        let stride_dx = args.dx.stride;
262        let rank = N as i32;
263        let scan_extent = shape[axis];
264        let reverse_flag = if self.desc.reverse { 1i32 } else { 0 };
265
266        match self.desc.kind {
267            ScanKind::Cumsum => {
268                // BW = FW with reverse flipped. Dispatch to the FW
269                // kernel; dy → x, dx → y.
270                let scan_stride_dy = stride_dy[axis];
271                let cumsum_reverse = if self.desc.reverse { 0i32 } else { 1 };
272                macro_rules! dispatch_cumsum {
273                    ($sym:ident) => {
274                        unsafe {
275                            baracuda_kernels_sys::$sym(
276                                numel,
277                                rank,
278                                shape.as_ptr(),
279                                stride_dy.as_ptr(),
280                                stride_dx.as_ptr(),
281                                axis as i32,
282                                scan_extent,
283                                scan_stride_dy,
284                                cumsum_reverse,
285                                dy_ptr,
286                                dx_ptr,
287                                core::ptr::null_mut(),
288                                0,
289                                stream_ptr,
290                            )
291                        }
292                    };
293                }
294                let status = match T::KIND {
295                    ElementKind::F32 => dispatch_cumsum!(baracuda_kernels_scan_cumsum_f32_run),
296                    ElementKind::F16 => dispatch_cumsum!(baracuda_kernels_scan_cumsum_f16_run),
297                    ElementKind::Bf16 => dispatch_cumsum!(baracuda_kernels_scan_cumsum_bf16_run),
298                    ElementKind::F64 => dispatch_cumsum!(baracuda_kernels_scan_cumsum_f64_run),
299                    _ => {
300                        return Err(Error::Unsupported(
301                            "baracuda-kernels::ScanBackwardPlan::run unsupported dtype for Cumsum",
302                        ));
303                    }
304                };
305                map_status(status)
306            }
307            ScanKind::Cumprod => {
308                let x_ref = args.x.expect("Cumprod BW requires saved x — validated above");
309                let y_ref = args.y.expect("Cumprod BW requires saved y — validated above");
310                let stride_x = x_ref.stride;
311                let stride_y = y_ref.stride;
312                let x_ptr = x_ref.data.as_raw().0 as *const c_void;
313                let y_ptr = y_ref.data.as_raw().0 as *const c_void;
314                macro_rules! dispatch_cumprod_bw {
315                    ($sym:ident) => {
316                        unsafe {
317                            baracuda_kernels_sys::$sym(
318                                numel,
319                                rank,
320                                shape.as_ptr(),
321                                stride_dy.as_ptr(),
322                                stride_x.as_ptr(),
323                                stride_y.as_ptr(),
324                                stride_dx.as_ptr(),
325                                axis as i32,
326                                scan_extent,
327                                reverse_flag,
328                                dy_ptr,
329                                x_ptr,
330                                y_ptr,
331                                dx_ptr,
332                                core::ptr::null_mut(),
333                                0,
334                                stream_ptr,
335                            )
336                        }
337                    };
338                }
339                let status = match T::KIND {
340                    ElementKind::F32 => {
341                        dispatch_cumprod_bw!(baracuda_kernels_scan_cumprod_backward_f32_run)
342                    }
343                    ElementKind::F16 => {
344                        dispatch_cumprod_bw!(baracuda_kernels_scan_cumprod_backward_f16_run)
345                    }
346                    ElementKind::Bf16 => {
347                        dispatch_cumprod_bw!(baracuda_kernels_scan_cumprod_backward_bf16_run)
348                    }
349                    ElementKind::F64 => {
350                        dispatch_cumprod_bw!(baracuda_kernels_scan_cumprod_backward_f64_run)
351                    }
352                    _ => {
353                        return Err(Error::Unsupported(
354                            "baracuda-kernels::ScanBackwardPlan::run unsupported dtype for Cumprod",
355                        ));
356                    }
357                };
358                map_status(status)
359            }
360            ScanKind::LogCumsumExp => {
361                let x_ref = args
362                    .x
363                    .expect("LogCumsumExp BW requires saved x — validated above");
364                let y_ref = args
365                    .y
366                    .expect("LogCumsumExp BW requires saved y — validated above");
367                let stride_x = x_ref.stride;
368                let stride_y = y_ref.stride;
369                let x_ptr = x_ref.data.as_raw().0 as *const c_void;
370                let y_ptr = y_ref.data.as_raw().0 as *const c_void;
371                macro_rules! dispatch_lcse_bw {
372                    ($sym:ident) => {
373                        unsafe {
374                            baracuda_kernels_sys::$sym(
375                                numel,
376                                rank,
377                                shape.as_ptr(),
378                                stride_dy.as_ptr(),
379                                stride_x.as_ptr(),
380                                stride_y.as_ptr(),
381                                stride_dx.as_ptr(),
382                                axis as i32,
383                                scan_extent,
384                                reverse_flag,
385                                dy_ptr,
386                                x_ptr,
387                                y_ptr,
388                                dx_ptr,
389                                core::ptr::null_mut(),
390                                0,
391                                stream_ptr,
392                            )
393                        }
394                    };
395                }
396                let status = match T::KIND {
397                    ElementKind::F32 => {
398                        dispatch_lcse_bw!(baracuda_kernels_scan_log_cumsum_exp_backward_f32_run)
399                    }
400                    ElementKind::F16 => {
401                        dispatch_lcse_bw!(baracuda_kernels_scan_log_cumsum_exp_backward_f16_run)
402                    }
403                    ElementKind::Bf16 => {
404                        dispatch_lcse_bw!(baracuda_kernels_scan_log_cumsum_exp_backward_bf16_run)
405                    }
406                    ElementKind::F64 => {
407                        dispatch_lcse_bw!(baracuda_kernels_scan_log_cumsum_exp_backward_f64_run)
408                    }
409                    _ => {
410                        return Err(Error::Unsupported(
411                            "baracuda-kernels::ScanBackwardPlan::run unsupported dtype for LogCumsumExp",
412                        ));
413                    }
414                };
415                map_status(status)
416            }
417            ScanKind::Cummax | ScanKind::Cummin => {
418                let x_ref = args
419                    .x
420                    .expect("Cummax/Cummin BW requires saved x — validated above");
421                let stride_x = x_ref.stride;
422                let x_ptr = x_ref.data.as_raw().0 as *const c_void;
423                macro_rules! dispatch_extrema_bw {
424                    ($sym:ident) => {
425                        unsafe {
426                            baracuda_kernels_sys::$sym(
427                                numel,
428                                rank,
429                                shape.as_ptr(),
430                                stride_dy.as_ptr(),
431                                stride_x.as_ptr(),
432                                stride_dx.as_ptr(),
433                                axis as i32,
434                                scan_extent,
435                                reverse_flag,
436                                dy_ptr,
437                                x_ptr,
438                                dx_ptr,
439                                core::ptr::null_mut(),
440                                0,
441                                stream_ptr,
442                            )
443                        }
444                    };
445                }
446                let status = match (self.desc.kind, T::KIND) {
447                    (ScanKind::Cummax, ElementKind::F32) => {
448                        dispatch_extrema_bw!(baracuda_kernels_scan_cummax_backward_f32_run)
449                    }
450                    (ScanKind::Cummax, ElementKind::F16) => {
451                        dispatch_extrema_bw!(baracuda_kernels_scan_cummax_backward_f16_run)
452                    }
453                    (ScanKind::Cummax, ElementKind::Bf16) => {
454                        dispatch_extrema_bw!(baracuda_kernels_scan_cummax_backward_bf16_run)
455                    }
456                    (ScanKind::Cummax, ElementKind::F64) => {
457                        dispatch_extrema_bw!(baracuda_kernels_scan_cummax_backward_f64_run)
458                    }
459                    (ScanKind::Cummin, ElementKind::F32) => {
460                        dispatch_extrema_bw!(baracuda_kernels_scan_cummin_backward_f32_run)
461                    }
462                    (ScanKind::Cummin, ElementKind::F16) => {
463                        dispatch_extrema_bw!(baracuda_kernels_scan_cummin_backward_f16_run)
464                    }
465                    (ScanKind::Cummin, ElementKind::Bf16) => {
466                        dispatch_extrema_bw!(baracuda_kernels_scan_cummin_backward_bf16_run)
467                    }
468                    (ScanKind::Cummin, ElementKind::F64) => {
469                        dispatch_extrema_bw!(baracuda_kernels_scan_cummin_backward_f64_run)
470                    }
471                    _ => {
472                        return Err(Error::Unsupported(
473                            "baracuda-kernels::ScanBackwardPlan::run reached an unimplemented \
474                             (kind, dtype) pair for Cummax/Cummin",
475                        ));
476                    }
477                };
478                map_status(status)
479            }
480            // Defensive arm — `ScanKind` is `#[non_exhaustive]`, so a
481            // newly-added variant surfaces here as an explicit
482            // `Unsupported` until the kernel dispatch is wired.
483            _ => Err(Error::Unsupported(
484                "baracuda-kernels::ScanBackwardPlan::run reached an unimplemented ScanKind variant",
485            )),
486        }
487    }
488}
489
490fn map_status(code: i32) -> Result<()> {
491    match code {
492        0 => Ok(()),
493        1 => Err(Error::MisalignedOperand),
494        2 => Err(Error::InvalidProblem(
495            "baracuda-kernels-sys reported invalid problem",
496        )),
497        3 => Err(Error::Unsupported(
498            "baracuda-kernels-sys reported unsupported configuration",
499        )),
500        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
501        n => Err(Error::CutlassInternal(n)),
502    }
503}