Skip to main content

baracuda_kernels/softmax/
axis.rs

1//! Softmax forward plan — single-axis numerically-stable softmax /
2//! log-softmax.
3//!
4//! **Formulas**:
5//! - `Softmax`:    `y[k] = exp(x[k] - max(x)) / Σ_j exp(x[j] - max(x))`
6//! - `LogSoftmax`: `y[k] = x[k] - logsumexp(x)`
7//!
8//! Numerically stable via max subtraction.
9//!
10//! **When to use**: forward pass of softmax / log-softmax along a single
11//! axis. Pair with [`SoftmaxBackwardPlan`](super::SoftmaxBackwardPlan)
12//! for autograd (the BW kernel needs the saved forward output `y`).
13//!
14//! **Dtypes / shape**: `{Softmax, LogSoftmax} × {f32, f16, bf16, f64}`,
15//! tensor rank `1..=8`. Half-precision (`f16` / `bf16`) reduces / exps in
16//! `f32` (FP detour) then casts back; `f64` keeps everything in double.
17//!
18//! **Workspace**: none.
19//!
20//! **Precision**: deterministic, bit-stable on the same hardware. The
21//! per-output-cell two-pass scan has no atomic-add / warp-reduction
22//! ordering dependence.
23
24use core::ffi::c_void;
25use core::marker::PhantomData;
26
27use baracuda_cutlass::{Error, Result};
28use baracuda_driver::Stream;
29use baracuda_kernels_types::{
30    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
31    PlanPreference, PrecisionGuarantee, SoftmaxKind, TensorMut, TensorRef, Workspace,
32};
33
34/// Descriptor for a softmax-family op.
35#[derive(Copy, Clone, Debug)]
36pub struct SoftmaxDescriptor<const N: usize> {
37    /// Which softmax variant.
38    pub kind: SoftmaxKind,
39    /// Tensor shape — input and output share it.
40    pub input_shape: [i32; N],
41    /// Axis along which to compute softmax. Must be in `[0, N)`.
42    pub softmax_axis: u8,
43    /// Element type.
44    pub element: ElementKind,
45}
46
47/// Args bundle for a softmax launch.
48pub struct SoftmaxArgs<'a, T: Element, const N: usize> {
49    /// Input tensor.
50    pub x: TensorRef<'a, T, N>,
51    /// Output tensor — same shape as input.
52    pub y: TensorMut<'a, T, N>,
53}
54
55/// Softmax forward plan — see the module-level docs for formulas,
56/// dtypes, workspace, and precision guarantees.
57///
58/// `T: Element` is the element type (`f32` / `f64` / `f16` / `bf16`).
59/// `const N: usize` is the tensor rank (1..=8).
60pub struct SoftmaxPlan<T: Element, const N: usize> {
61    desc: SoftmaxDescriptor<N>,
62    sku: KernelSku,
63    _marker: PhantomData<T>,
64}
65
66impl<T: Element, const N: usize> SoftmaxPlan<T, N> {
67    /// Pick a kernel for `desc`. Validates `softmax_axis < N`, the dtype
68    /// is in the wired FP family, and tensor rank ≤ 8. Returns
69    /// [`Error::Unsupported`] for cells outside the matrix and
70    /// [`Error::InvalidProblem`] for malformed shapes / axes.
71    pub fn select(
72        _stream: &Stream,
73        desc: &SoftmaxDescriptor<N>,
74        _pref: PlanPreference,
75    ) -> Result<Self> {
76        if desc.element != T::KIND {
77            return Err(Error::Unsupported(
78                "baracuda-kernels::SoftmaxPlan: descriptor element != T",
79            ));
80        }
81        if (desc.softmax_axis as usize) >= N {
82            return Err(Error::InvalidProblem(
83                "baracuda-kernels::SoftmaxPlan: softmax_axis out of range for rank N",
84            ));
85        }
86        for &d in desc.input_shape.iter() {
87            if d < 0 {
88                return Err(Error::InvalidProblem(
89                    "baracuda-kernels::SoftmaxPlan: shape dims must be non-negative",
90                ));
91            }
92        }
93        if N > 8 {
94            return Err(Error::Unsupported(
95                "baracuda-kernels::SoftmaxPlan: tensor rank > 8 not supported",
96            ));
97        }
98        let dtype_in_fp_family = matches!(
99            T::KIND,
100            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
101        );
102        let kind_supported = matches!(desc.kind, SoftmaxKind::Softmax | SoftmaxKind::LogSoftmax);
103        if !kind_supported || !dtype_in_fp_family {
104            return Err(Error::Unsupported(
105                "baracuda-kernels::SoftmaxPlan: wired today: \
106                 `{Softmax, LogSoftmax} × {f32, f16, bf16, f64}`",
107            ));
108        }
109
110        let precision_guarantee = PrecisionGuarantee {
111            math_precision: MathPrecision::F32,
112            accumulator: ElementKind::F32,
113            // Bit-stable across runs (deterministic per-cell two-pass scan).
114            bit_stable_on_same_hardware: true,
115            deterministic: true,
116        };
117        let sku = KernelSku {
118            category: OpCategory::Softmax,
119            op: desc.kind as u16,
120            element: T::KIND,
121            aux_element: None,
122            layout: None,
123            epilogue: None,
124            arch: ArchSku::Sm80,
125            backend: BackendKind::Bespoke,
126            precision_guarantee,
127        };
128        Ok(Self {
129            desc: *desc,
130            sku,
131            _marker: PhantomData,
132        })
133    }
134
135    /// Validate args.
136    pub fn can_implement(&self, args: &SoftmaxArgs<'_, T, N>) -> Result<()> {
137        if args.x.shape != self.desc.input_shape {
138            return Err(Error::InvalidProblem(
139                "baracuda-kernels::SoftmaxPlan: x shape mismatch",
140            ));
141        }
142        if args.y.shape != self.desc.input_shape {
143            return Err(Error::InvalidProblem(
144                "baracuda-kernels::SoftmaxPlan: y shape mismatch",
145            ));
146        }
147        let numel = args.x.numel();
148        let x_len = args.x.data.len() as i64;
149        let y_len = args.y.data.len() as i64;
150        if x_len < numel || y_len < numel {
151            return Err(Error::BufferTooSmall {
152                needed: numel as usize,
153                got: x_len.min(y_len) as usize,
154            });
155        }
156        Ok(())
157    }
158
159    /// Workspace size in bytes. Always zero — the kernel does its
160    /// two-pass scan in registers.
161    #[inline]
162    pub fn workspace_size(&self) -> usize {
163        0
164    }
165    /// Identity of the kernel this plan picked (for telemetry +
166    /// autotuner cache keying).
167    #[inline]
168    pub fn sku(&self) -> KernelSku {
169        self.sku
170    }
171    /// Numerical guarantees for this plan's kernel — deterministic,
172    /// bit-stable on the same hardware, f32 accumulator for the FP-detour
173    /// half / bf16 inputs and f32 / f64 native for those dtypes.
174    #[inline]
175    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
176        self.sku.precision_guarantee
177    }
178
179    /// Launch the kernel against `args`. Calls `can_implement` first;
180    /// returns `Ok(())` for empty tensors.
181    pub fn run(
182        &self,
183        stream: &Stream,
184        _workspace: Workspace<'_>,
185        args: SoftmaxArgs<'_, T, N>,
186    ) -> Result<()> {
187        self.can_implement(&args)?;
188        let numel = args.x.numel();
189        if numel == 0 {
190            return Ok(());
191        }
192        let x_ptr = args.x.data.as_raw().0 as *const c_void;
193        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
194        let stream_ptr = stream.as_raw() as *mut c_void;
195
196        let axis = self.desc.softmax_axis as usize;
197        let shape = self.desc.input_shape;
198        let stride_x = args.x.stride;
199        let stride_y = args.y.stride;
200        let rank = N as i32;
201        let extent = shape[axis];
202        let stride_x_axis = stride_x[axis];
203        let stride_y_axis = stride_y[axis];
204
205        macro_rules! dispatch {
206            ($sym:ident) => {
207                unsafe {
208                    baracuda_kernels_sys::$sym(
209                        numel,
210                        rank,
211                        shape.as_ptr(),
212                        stride_x.as_ptr(),
213                        stride_y.as_ptr(),
214                        axis as i32,
215                        extent,
216                        stride_x_axis,
217                        stride_y_axis,
218                        x_ptr,
219                        y_ptr,
220                        core::ptr::null_mut(),
221                        0,
222                        stream_ptr,
223                    )
224                }
225            };
226        }
227
228        let status = match (self.desc.kind, T::KIND) {
229            (SoftmaxKind::Softmax, ElementKind::F32) => dispatch!(baracuda_kernels_softmax_f32_run),
230            (SoftmaxKind::Softmax, ElementKind::F16) => dispatch!(baracuda_kernels_softmax_f16_run),
231            (SoftmaxKind::Softmax, ElementKind::Bf16) => {
232                dispatch!(baracuda_kernels_softmax_bf16_run)
233            }
234            (SoftmaxKind::Softmax, ElementKind::F64) => dispatch!(baracuda_kernels_softmax_f64_run),
235            (SoftmaxKind::LogSoftmax, ElementKind::F32) => {
236                dispatch!(baracuda_kernels_log_softmax_f32_run)
237            }
238            (SoftmaxKind::LogSoftmax, ElementKind::F16) => {
239                dispatch!(baracuda_kernels_log_softmax_f16_run)
240            }
241            (SoftmaxKind::LogSoftmax, ElementKind::Bf16) => {
242                dispatch!(baracuda_kernels_log_softmax_bf16_run)
243            }
244            (SoftmaxKind::LogSoftmax, ElementKind::F64) => {
245                dispatch!(baracuda_kernels_log_softmax_f64_run)
246            }
247            _ => {
248                return Err(Error::Unsupported(
249                    "baracuda-kernels::SoftmaxPlan::run reached an unimplemented \
250                     (kind, dtype) pair — select() should have caught this",
251                ));
252            }
253        };
254        map_status(status)
255    }
256}
257
258fn map_status(code: i32) -> Result<()> {
259    match code {
260        0 => Ok(()),
261        1 => Err(Error::MisalignedOperand),
262        2 => Err(Error::InvalidProblem(
263            "baracuda-kernels-sys reported invalid problem",
264        )),
265        3 => Err(Error::Unsupported(
266            "baracuda-kernels-sys reported unsupported configuration",
267        )),
268        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
269        n => Err(Error::CutlassInternal(n)),
270    }
271}