Skip to main content

baracuda_kernels/softmax/
sparsemax_backward.rs

//! Sparsemax backward plan — Jacobian-vector product.
//!
//! For active positions (`y > 0`):
//!   `dx[i] = dy[i] - sum_dy_active / n_active`
//! where `sum_dy_active = Σ_{j: y[j] > 0} dy[j]` and `n_active` counts
//! the active positions in the row (the 1-D slice along the sparsemax axis).
//! Inactive positions get `dx[i] = 0`.
//!
//! Requires the saved forward output `y`, which determines the active mask.
//!
//! Currently supported element types: `T ∈ {f32, f16, bf16, f64}`.
11
12use core::ffi::c_void;
13use core::marker::PhantomData;
14
15use baracuda_cutlass::{Error, Result};
16use baracuda_driver::Stream;
17use baracuda_kernels_types::{
18    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
19    PlanPreference, PrecisionGuarantee, SoftmaxKind, TensorMut, TensorRef, Workspace,
20};
21
22/// Descriptor for a Sparsemax backward op.
23#[derive(Copy, Clone, Debug)]
24pub struct SparsemaxBackwardDescriptor<const N: usize> {
25    /// Tensor shape (dy / y / dx share it).
26    pub input_shape: [i32; N],
27    /// Forward sparsemax axis.
28    pub softmax_axis: u8,
29    /// Element type.
30    pub element: ElementKind,
31}
32
33/// Args bundle for a Sparsemax backward launch.
34///
35/// `y` is the SAVED forward output (used to derive the active mask).
36pub struct SparsemaxBackwardArgs<'a, T: Element, const N: usize> {
37    /// Upstream gradient.
38    pub dy: TensorRef<'a, T, N>,
39    /// Saved forward output.
40    pub y: TensorRef<'a, T, N>,
41    /// Gradient w.r.t. the forward input.
42    pub dx: TensorMut<'a, T, N>,
43}
44
45/// Sparsemax backward plan.
46pub struct SparsemaxBackwardPlan<T: Element, const N: usize> {
47    desc: SparsemaxBackwardDescriptor<N>,
48    sku: KernelSku,
49    _marker: PhantomData<T>,
50}
51
52impl<T: Element, const N: usize> SparsemaxBackwardPlan<T, N> {
53    /// Pick a kernel.
54    pub fn select(
55        _stream: &Stream,
56        desc: &SparsemaxBackwardDescriptor<N>,
57        _pref: PlanPreference,
58    ) -> Result<Self> {
59        if desc.element != T::KIND {
60            return Err(Error::Unsupported(
61                "baracuda-kernels::SparsemaxBackwardPlan: descriptor element != T",
62            ));
63        }
64        if (desc.softmax_axis as usize) >= N {
65            return Err(Error::InvalidProblem(
66                "baracuda-kernels::SparsemaxBackwardPlan: softmax_axis out of range",
67            ));
68        }
69        for &d in desc.input_shape.iter() {
70            if d < 0 {
71                return Err(Error::InvalidProblem(
72                    "baracuda-kernels::SparsemaxBackwardPlan: shape dims must be non-negative",
73                ));
74            }
75        }
76        if N > 8 {
77            return Err(Error::Unsupported(
78                "baracuda-kernels::SparsemaxBackwardPlan: tensor rank > 8 not supported",
79            ));
80        }
81        let dtype_in_fp_family = matches!(
82            T::KIND,
83            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
84        );
85        if !dtype_in_fp_family {
86            return Err(Error::Unsupported(
87                "baracuda-kernels::SparsemaxBackwardPlan: wired today: {f32, f16, bf16, f64}",
88            ));
89        }
90
91        let math_precision = match T::KIND {
92            ElementKind::F64 => MathPrecision::F64,
93            _ => MathPrecision::F32,
94        };
95        let precision_guarantee = PrecisionGuarantee {
96            math_precision,
97            accumulator: match T::KIND {
98                ElementKind::F64 => ElementKind::F64,
99                _ => ElementKind::F32,
100            },
101            bit_stable_on_same_hardware: true,
102            deterministic: true,
103        };
104        let sku = KernelSku {
105            category: OpCategory::Softmax,
106            op: SoftmaxKind::Sparsemax as u16,
107            element: T::KIND,
108            aux_element: None,
109            layout: None,
110            epilogue: None,
111            arch: ArchSku::Sm80,
112            backend: BackendKind::Bespoke,
113            precision_guarantee,
114        };
115        Ok(Self {
116            desc: *desc,
117            sku,
118            _marker: PhantomData,
119        })
120    }
121
122    /// Validate args.
123    pub fn can_implement(&self, args: &SparsemaxBackwardArgs<'_, T, N>) -> Result<()> {
124        if args.dy.shape != self.desc.input_shape {
125            return Err(Error::InvalidProblem(
126                "baracuda-kernels::SparsemaxBackwardPlan: dy shape mismatch",
127            ));
128        }
129        if args.y.shape != self.desc.input_shape {
130            return Err(Error::InvalidProblem(
131                "baracuda-kernels::SparsemaxBackwardPlan: y shape mismatch",
132            ));
133        }
134        if args.dx.shape != self.desc.input_shape {
135            return Err(Error::InvalidProblem(
136                "baracuda-kernels::SparsemaxBackwardPlan: dx shape mismatch",
137            ));
138        }
139        let numel = args.dx.numel();
140        let dy_len = args.dy.data.len() as i64;
141        let y_len = args.y.data.len() as i64;
142        let dx_len = args.dx.data.len() as i64;
143        if dy_len < numel || y_len < numel || dx_len < numel {
144            return Err(Error::BufferTooSmall {
145                needed: numel as usize,
146                got: dy_len.min(y_len).min(dx_len) as usize,
147            });
148        }
149        Ok(())
150    }
151
152    /// Workspace size in bytes.
153    #[inline]
154    pub fn workspace_size(&self) -> usize {
155        0
156    }
157
158    /// Kernel SKU identity.
159    #[inline]
160    pub fn sku(&self) -> KernelSku {
161        self.sku
162    }
163
164    /// Numerical guarantees.
165    #[inline]
166    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
167        self.sku.precision_guarantee
168    }
169
170    /// Launch.
171    pub fn run(
172        &self,
173        stream: &Stream,
174        _workspace: Workspace<'_>,
175        args: SparsemaxBackwardArgs<'_, T, N>,
176    ) -> Result<()> {
177        self.can_implement(&args)?;
178        let numel = args.dx.numel();
179        if numel == 0 {
180            return Ok(());
181        }
182        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
183        let y_ptr = args.y.data.as_raw().0 as *const c_void;
184        let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
185        let stream_ptr = stream.as_raw() as *mut c_void;
186
187        let axis = self.desc.softmax_axis as usize;
188        let shape = self.desc.input_shape;
189        let stride_dy = args.dy.stride;
190        let stride_y = args.y.stride;
191        let stride_dx = args.dx.stride;
192        let rank = N as i32;
193        let extent = shape[axis];
194        let stride_dy_axis = stride_dy[axis];
195        let stride_y_axis = stride_y[axis];
196
197        macro_rules! dispatch {
198            ($sym:ident) => {
199                unsafe {
200                    baracuda_kernels_sys::$sym(
201                        numel,
202                        rank,
203                        shape.as_ptr(),
204                        stride_dy.as_ptr(),
205                        stride_y.as_ptr(),
206                        stride_dx.as_ptr(),
207                        axis as i32,
208                        extent,
209                        stride_dy_axis,
210                        stride_y_axis,
211                        dy_ptr,
212                        y_ptr,
213                        dx_ptr,
214                        core::ptr::null_mut(),
215                        0,
216                        stream_ptr,
217                    )
218                }
219            };
220        }
221        let status = match T::KIND {
222            ElementKind::F32 => dispatch!(baracuda_kernels_sparsemax_backward_f32_run),
223            ElementKind::F16 => dispatch!(baracuda_kernels_sparsemax_backward_f16_run),
224            ElementKind::Bf16 => dispatch!(baracuda_kernels_sparsemax_backward_bf16_run),
225            ElementKind::F64 => dispatch!(baracuda_kernels_sparsemax_backward_f64_run),
226            _ => {
227                return Err(Error::Unsupported(
228                    "baracuda-kernels::SparsemaxBackwardPlan::run unimplemented dtype",
229                ));
230            }
231        };
232        map_status(status)
233    }
234}
235
236fn map_status(code: i32) -> Result<()> {
237    match code {
238        0 => Ok(()),
239        1 => Err(Error::MisalignedOperand),
240        2 => Err(Error::InvalidProblem(
241            "baracuda-kernels-sys reported invalid problem",
242        )),
243        3 => Err(Error::Unsupported(
244            "baracuda-kernels-sys reported unsupported configuration",
245        )),
246        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
247        n => Err(Error::CutlassInternal(n)),
248    }
249}