// baracuda_kernels/softmax/sparsemax_backward.rs
use core::ffi::c_void;
13use core::marker::PhantomData;
14
15use baracuda_cutlass::{Error, Result};
16use baracuda_driver::Stream;
17use baracuda_kernels_types::{
18 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
19 PlanPreference, PrecisionGuarantee, SoftmaxKind, TensorMut, TensorRef, Workspace,
20};
21
22#[derive(Copy, Clone, Debug)]
24pub struct SparsemaxBackwardDescriptor<const N: usize> {
25 pub input_shape: [i32; N],
27 pub softmax_axis: u8,
29 pub element: ElementKind,
31}
32
33pub struct SparsemaxBackwardArgs<'a, T: Element, const N: usize> {
37 pub dy: TensorRef<'a, T, N>,
39 pub y: TensorRef<'a, T, N>,
41 pub dx: TensorMut<'a, T, N>,
43}
44
45pub struct SparsemaxBackwardPlan<T: Element, const N: usize> {
47 desc: SparsemaxBackwardDescriptor<N>,
48 sku: KernelSku,
49 _marker: PhantomData<T>,
50}
51
52impl<T: Element, const N: usize> SparsemaxBackwardPlan<T, N> {
53 pub fn select(
55 _stream: &Stream,
56 desc: &SparsemaxBackwardDescriptor<N>,
57 _pref: PlanPreference,
58 ) -> Result<Self> {
59 if desc.element != T::KIND {
60 return Err(Error::Unsupported(
61 "baracuda-kernels::SparsemaxBackwardPlan: descriptor element != T",
62 ));
63 }
64 if (desc.softmax_axis as usize) >= N {
65 return Err(Error::InvalidProblem(
66 "baracuda-kernels::SparsemaxBackwardPlan: softmax_axis out of range",
67 ));
68 }
69 for &d in desc.input_shape.iter() {
70 if d < 0 {
71 return Err(Error::InvalidProblem(
72 "baracuda-kernels::SparsemaxBackwardPlan: shape dims must be non-negative",
73 ));
74 }
75 }
76 if N > 8 {
77 return Err(Error::Unsupported(
78 "baracuda-kernels::SparsemaxBackwardPlan: tensor rank > 8 not supported",
79 ));
80 }
81 let dtype_in_fp_family = matches!(
82 T::KIND,
83 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
84 );
85 if !dtype_in_fp_family {
86 return Err(Error::Unsupported(
87 "baracuda-kernels::SparsemaxBackwardPlan: wired today: {f32, f16, bf16, f64}",
88 ));
89 }
90
91 let math_precision = match T::KIND {
92 ElementKind::F64 => MathPrecision::F64,
93 _ => MathPrecision::F32,
94 };
95 let precision_guarantee = PrecisionGuarantee {
96 math_precision,
97 accumulator: match T::KIND {
98 ElementKind::F64 => ElementKind::F64,
99 _ => ElementKind::F32,
100 },
101 bit_stable_on_same_hardware: true,
102 deterministic: true,
103 };
104 let sku = KernelSku {
105 category: OpCategory::Softmax,
106 op: SoftmaxKind::Sparsemax as u16,
107 element: T::KIND,
108 aux_element: None,
109 layout: None,
110 epilogue: None,
111 arch: ArchSku::Sm80,
112 backend: BackendKind::Bespoke,
113 precision_guarantee,
114 };
115 Ok(Self {
116 desc: *desc,
117 sku,
118 _marker: PhantomData,
119 })
120 }
121
122 pub fn can_implement(&self, args: &SparsemaxBackwardArgs<'_, T, N>) -> Result<()> {
124 if args.dy.shape != self.desc.input_shape {
125 return Err(Error::InvalidProblem(
126 "baracuda-kernels::SparsemaxBackwardPlan: dy shape mismatch",
127 ));
128 }
129 if args.y.shape != self.desc.input_shape {
130 return Err(Error::InvalidProblem(
131 "baracuda-kernels::SparsemaxBackwardPlan: y shape mismatch",
132 ));
133 }
134 if args.dx.shape != self.desc.input_shape {
135 return Err(Error::InvalidProblem(
136 "baracuda-kernels::SparsemaxBackwardPlan: dx shape mismatch",
137 ));
138 }
139 let numel = args.dx.numel();
140 let dy_len = args.dy.data.len() as i64;
141 let y_len = args.y.data.len() as i64;
142 let dx_len = args.dx.data.len() as i64;
143 if dy_len < numel || y_len < numel || dx_len < numel {
144 return Err(Error::BufferTooSmall {
145 needed: numel as usize,
146 got: dy_len.min(y_len).min(dx_len) as usize,
147 });
148 }
149 Ok(())
150 }
151
152 #[inline]
154 pub fn workspace_size(&self) -> usize {
155 0
156 }
157
158 #[inline]
160 pub fn sku(&self) -> KernelSku {
161 self.sku
162 }
163
164 #[inline]
166 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
167 self.sku.precision_guarantee
168 }
169
170 pub fn run(
172 &self,
173 stream: &Stream,
174 _workspace: Workspace<'_>,
175 args: SparsemaxBackwardArgs<'_, T, N>,
176 ) -> Result<()> {
177 self.can_implement(&args)?;
178 let numel = args.dx.numel();
179 if numel == 0 {
180 return Ok(());
181 }
182 let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
183 let y_ptr = args.y.data.as_raw().0 as *const c_void;
184 let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
185 let stream_ptr = stream.as_raw() as *mut c_void;
186
187 let axis = self.desc.softmax_axis as usize;
188 let shape = self.desc.input_shape;
189 let stride_dy = args.dy.stride;
190 let stride_y = args.y.stride;
191 let stride_dx = args.dx.stride;
192 let rank = N as i32;
193 let extent = shape[axis];
194 let stride_dy_axis = stride_dy[axis];
195 let stride_y_axis = stride_y[axis];
196
197 macro_rules! dispatch {
198 ($sym:ident) => {
199 unsafe {
200 baracuda_kernels_sys::$sym(
201 numel,
202 rank,
203 shape.as_ptr(),
204 stride_dy.as_ptr(),
205 stride_y.as_ptr(),
206 stride_dx.as_ptr(),
207 axis as i32,
208 extent,
209 stride_dy_axis,
210 stride_y_axis,
211 dy_ptr,
212 y_ptr,
213 dx_ptr,
214 core::ptr::null_mut(),
215 0,
216 stream_ptr,
217 )
218 }
219 };
220 }
221 let status = match T::KIND {
222 ElementKind::F32 => dispatch!(baracuda_kernels_sparsemax_backward_f32_run),
223 ElementKind::F16 => dispatch!(baracuda_kernels_sparsemax_backward_f16_run),
224 ElementKind::Bf16 => dispatch!(baracuda_kernels_sparsemax_backward_bf16_run),
225 ElementKind::F64 => dispatch!(baracuda_kernels_sparsemax_backward_f64_run),
226 _ => {
227 return Err(Error::Unsupported(
228 "baracuda-kernels::SparsemaxBackwardPlan::run unimplemented dtype",
229 ));
230 }
231 };
232 map_status(status)
233 }
234}
235
236fn map_status(code: i32) -> Result<()> {
237 match code {
238 0 => Ok(()),
239 1 => Err(Error::MisalignedOperand),
240 2 => Err(Error::InvalidProblem(
241 "baracuda-kernels-sys reported invalid problem",
242 )),
243 3 => Err(Error::Unsupported(
244 "baracuda-kernels-sys reported unsupported configuration",
245 )),
246 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
247 n => Err(Error::CutlassInternal(n)),
248 }
249}