Skip to main content

baracuda_kernels/scan/
axis.rs

1//! Single-axis scan forward plan.
2//!
3//! Length-preserving prefix scan along a single axis. Output shape ==
4//! input shape; the scan axis is *not* collapsed (unlike a reduction).
5//!
6//! **Formulas** (forward direction; reverse flips the index range):
7//! - `Cumsum[i]`        = `x[0] + ... + x[i]`
8//! - `Cumprod[i]`       = `x[0] · x[1] · ... · x[i]`
9//! - `Cummax[i]`        = `max(x[0], ..., x[i])`
10//! - `Cummin[i]`        = `min(x[0], ..., x[i])`
11//! - `LogCumsumExp[i]`  = `log(Σ_{j≤i} exp(x[j]))` (numerically stable
12//!   via per-cell running max).
13//!
14//! **When to use**: forward prefix scan. Pair with
15//! [`ScanBackwardPlan`](super::ScanBackwardPlan) for autograd; the BW
16//! pass needs saved `x` and/or `y` per the op (see that module's table).
17//!
18//! **Dtypes / shape**: `{Cumsum, Cumprod, Cummax, Cummin, LogCumsumExp}
19//! × {f32, f16, bf16, f64}`, tensor rank `1..=8`.
20//!
21//! **Workspace**: none.
22//!
23//! **Precision**: deterministic, bit-stable on the same hardware. The
24//! single-thread-per-row sequential accumulator has no warp-reduction
25//! ordering dependence. f16 / bf16 accumulate in f32 (FP detour); f64
26//! keeps everything in double.
27
28use core::ffi::c_void;
29use core::marker::PhantomData;
30
31use baracuda_cutlass::{Error, Result};
32use baracuda_driver::Stream;
33use baracuda_kernels_types::{
34    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
35    PlanPreference, PrecisionGuarantee, ScanKind, TensorMut, TensorRef, Workspace,
36};
37
38/// Descriptor for a single-axis scan op.
39#[derive(Copy, Clone, Debug)]
40pub struct ScanDescriptor<const N: usize> {
41    /// Which scan kind to apply.
42    pub kind: ScanKind,
43    /// Tensor shape — input and output share it.
44    pub input_shape: [i32; N],
45    /// Axis along which the scan accumulates. Must be in `[0, N)`.
46    pub scan_axis: u8,
47    /// `true` → scan from the end of the axis toward the start; `false`
48    /// → standard forward scan (PyTorch default).
49    pub reverse: bool,
50    /// Element type.
51    pub element: ElementKind,
52}
53
54/// Args bundle for a scan launch.
55pub struct ScanArgs<'a, T: Element, const N: usize> {
56    /// Input tensor.
57    pub x: TensorRef<'a, T, N>,
58    /// Output tensor — same shape as input.
59    pub y: TensorMut<'a, T, N>,
60}
61
62/// Single-axis scan forward plan — see the module-level docs for
63/// formulas, dtypes, workspace, and precision guarantees.
64///
65/// `T: Element` is the element type (`f32` / `f64` / `f16` / `bf16`).
66/// `const N: usize` is the tensor rank (1..=8).
67pub struct ScanPlan<T: Element, const N: usize> {
68    desc: ScanDescriptor<N>,
69    sku: KernelSku,
70    _marker: PhantomData<T>,
71}
72
73impl<T: Element, const N: usize> ScanPlan<T, N> {
74    /// Pick a kernel for `desc`. Validates `scan_axis < N`, the dtype
75    /// is in the wired FP family, and tensor rank ≤ 8. Returns
76    /// [`Error::Unsupported`] for cells outside the wired matrix.
77    pub fn select(
78        _stream: &Stream,
79        desc: &ScanDescriptor<N>,
80        _pref: PlanPreference,
81    ) -> Result<Self> {
82        if desc.element != T::KIND {
83            return Err(Error::Unsupported(
84                "baracuda-kernels::ScanPlan: descriptor element != T",
85            ));
86        }
87        if (desc.scan_axis as usize) >= N {
88            return Err(Error::InvalidProblem(
89                "baracuda-kernels::ScanPlan: scan_axis out of range for rank N",
90            ));
91        }
92        for &d in desc.input_shape.iter() {
93            if d < 0 {
94                return Err(Error::InvalidProblem(
95                    "baracuda-kernels::ScanPlan: shape dims must be non-negative",
96                ));
97            }
98        }
99        if N > 8 {
100            return Err(Error::Unsupported(
101                "baracuda-kernels::ScanPlan: tensor rank > 8 not supported \
102                 (kernel param block fixes MAX_RANK = 8)",
103            ));
104        }
105
106        // Wired today: `{Cumsum, Cumprod, Cummax, Cummin, LogCumsumExp}
107        // × {f32, f16, bf16, f64}`.
108        let dtype_in_fp_family = matches!(
109            T::KIND,
110            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
111        );
112        let kind_supported = matches!(
113            desc.kind,
114            ScanKind::Cumsum
115                | ScanKind::Cumprod
116                | ScanKind::Cummax
117                | ScanKind::Cummin
118                | ScanKind::LogCumsumExp
119        );
120        let supported = kind_supported && dtype_in_fp_family;
121        if !supported {
122            return Err(Error::Unsupported(
123                "baracuda-kernels::ScanPlan: wired today: \
124                 `{Cumsum, Cumprod, Cummax, Cummin, LogCumsumExp} × {f32, f16, bf16, f64}`",
125            ));
126        }
127
128        let precision_guarantee = PrecisionGuarantee {
129            math_precision: MathPrecision::F32,
130            accumulator: ElementKind::F32,
131            // Bit-stable across runs (deterministic single-thread-per-cell
132            // accumulator; same input → same output).
133            bit_stable_on_same_hardware: true,
134            deterministic: true,
135        };
136        let sku = KernelSku {
137            category: OpCategory::Scan,
138            op: desc.kind as u16,
139            element: T::KIND,
140            aux_element: None,
141            layout: None,
142            epilogue: None,
143            arch: ArchSku::Sm80,
144            backend: BackendKind::Bespoke,
145            precision_guarantee,
146        };
147        Ok(Self {
148            desc: *desc,
149            sku,
150            _marker: PhantomData,
151        })
152    }
153
154    /// Validate args.
155    pub fn can_implement(&self, args: &ScanArgs<'_, T, N>) -> Result<()> {
156        if args.x.shape != self.desc.input_shape {
157            return Err(Error::InvalidProblem(
158                "baracuda-kernels::ScanPlan: x shape mismatch",
159            ));
160        }
161        if args.y.shape != self.desc.input_shape {
162            return Err(Error::InvalidProblem(
163                "baracuda-kernels::ScanPlan: y shape mismatch (scans are \
164                 length-preserving — y.shape must equal x.shape)",
165            ));
166        }
167        let numel = args.x.numel();
168        let x_len = args.x.data.len() as i64;
169        let y_len = args.y.data.len() as i64;
170        if x_len < numel || y_len < numel {
171            return Err(Error::BufferTooSmall {
172                needed: numel as usize,
173                got: x_len.min(y_len) as usize,
174            });
175        }
176        Ok(())
177    }
178
179    /// Workspace size in bytes. Always zero.
180    #[inline]
181    pub fn workspace_size(&self) -> usize {
182        0
183    }
184    /// Identity of the kernel this plan picked.
185    #[inline]
186    pub fn sku(&self) -> KernelSku {
187        self.sku
188    }
189    /// Numerical guarantees for this plan's kernel — deterministic,
190    /// bit-stable on the same hardware, f32 accumulator for f16 / bf16
191    /// inputs (FP detour).
192    #[inline]
193    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
194        self.sku.precision_guarantee
195    }
196
197    /// Launch the kernel against `args`. Calls `can_implement` first;
198    /// returns `Ok(())` for empty tensors.
199    pub fn run(
200        &self,
201        stream: &Stream,
202        _workspace: Workspace<'_>,
203        args: ScanArgs<'_, T, N>,
204    ) -> Result<()> {
205        self.can_implement(&args)?;
206        let numel = args.x.numel();
207        if numel == 0 {
208            return Ok(());
209        }
210        let x_ptr = args.x.data.as_raw().0 as *const c_void;
211        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
212        let stream_ptr = stream.as_raw() as *mut c_void;
213
214        let axis = self.desc.scan_axis as usize;
215        let shape = self.desc.input_shape;
216        let stride_x = args.x.stride;
217        let stride_y = args.y.stride;
218        let rank = N as i32;
219        let scan_extent = shape[axis];
220        let scan_stride_x = stride_x[axis];
221        let reverse = if self.desc.reverse { 1i32 } else { 0 };
222
223        macro_rules! dispatch {
224            ($sym:ident) => {
225                unsafe {
226                    baracuda_kernels_sys::$sym(
227                        numel,
228                        rank,
229                        shape.as_ptr(),
230                        stride_x.as_ptr(),
231                        stride_y.as_ptr(),
232                        axis as i32,
233                        scan_extent,
234                        scan_stride_x,
235                        reverse,
236                        x_ptr,
237                        y_ptr,
238                        core::ptr::null_mut(),
239                        0,
240                        stream_ptr,
241                    )
242                }
243            };
244        }
245
246        let status = match (self.desc.kind, T::KIND) {
247            (ScanKind::Cumsum, ElementKind::F32) => dispatch!(baracuda_kernels_scan_cumsum_f32_run),
248            (ScanKind::Cumsum, ElementKind::F16) => dispatch!(baracuda_kernels_scan_cumsum_f16_run),
249            (ScanKind::Cumsum, ElementKind::Bf16) => {
250                dispatch!(baracuda_kernels_scan_cumsum_bf16_run)
251            }
252            (ScanKind::Cumsum, ElementKind::F64) => dispatch!(baracuda_kernels_scan_cumsum_f64_run),
253            (ScanKind::Cumprod, ElementKind::F32) => {
254                dispatch!(baracuda_kernels_scan_cumprod_f32_run)
255            }
256            (ScanKind::Cumprod, ElementKind::F16) => {
257                dispatch!(baracuda_kernels_scan_cumprod_f16_run)
258            }
259            (ScanKind::Cumprod, ElementKind::Bf16) => {
260                dispatch!(baracuda_kernels_scan_cumprod_bf16_run)
261            }
262            (ScanKind::Cumprod, ElementKind::F64) => {
263                dispatch!(baracuda_kernels_scan_cumprod_f64_run)
264            }
265            (ScanKind::Cummax, ElementKind::F32) => dispatch!(baracuda_kernels_scan_cummax_f32_run),
266            (ScanKind::Cummax, ElementKind::F16) => dispatch!(baracuda_kernels_scan_cummax_f16_run),
267            (ScanKind::Cummax, ElementKind::Bf16) => {
268                dispatch!(baracuda_kernels_scan_cummax_bf16_run)
269            }
270            (ScanKind::Cummax, ElementKind::F64) => dispatch!(baracuda_kernels_scan_cummax_f64_run),
271            (ScanKind::Cummin, ElementKind::F32) => dispatch!(baracuda_kernels_scan_cummin_f32_run),
272            (ScanKind::Cummin, ElementKind::F16) => dispatch!(baracuda_kernels_scan_cummin_f16_run),
273            (ScanKind::Cummin, ElementKind::Bf16) => {
274                dispatch!(baracuda_kernels_scan_cummin_bf16_run)
275            }
276            (ScanKind::Cummin, ElementKind::F64) => dispatch!(baracuda_kernels_scan_cummin_f64_run),
277            (ScanKind::LogCumsumExp, ElementKind::F32) => {
278                dispatch!(baracuda_kernels_scan_log_cumsum_exp_f32_run)
279            }
280            (ScanKind::LogCumsumExp, ElementKind::F16) => {
281                dispatch!(baracuda_kernels_scan_log_cumsum_exp_f16_run)
282            }
283            (ScanKind::LogCumsumExp, ElementKind::Bf16) => {
284                dispatch!(baracuda_kernels_scan_log_cumsum_exp_bf16_run)
285            }
286            (ScanKind::LogCumsumExp, ElementKind::F64) => {
287                dispatch!(baracuda_kernels_scan_log_cumsum_exp_f64_run)
288            }
289            _ => {
290                return Err(Error::Unsupported(
291                    "baracuda-kernels::ScanPlan::run reached an unimplemented \
292                     (kind, dtype) pair — select() should have caught this",
293                ));
294            }
295        };
296        map_status(status)
297    }
298}
299
300fn map_status(code: i32) -> Result<()> {
301    match code {
302        0 => Ok(()),
303        1 => Err(Error::MisalignedOperand),
304        2 => Err(Error::InvalidProblem(
305            "baracuda-kernels-sys reported invalid problem",
306        )),
307        3 => Err(Error::Unsupported(
308            "baracuda-kernels-sys reported unsupported configuration",
309        )),
310        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
311        n => Err(Error::CutlassInternal(n)),
312    }
313}