baracuda_kernels/softmax/
axis_backward.rs1use core::ffi::c_void;
23use core::marker::PhantomData;
24
25use baracuda_cutlass::{Error, Result};
26use baracuda_driver::Stream;
27use baracuda_kernels_types::{
28 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
29 PlanPreference, PrecisionGuarantee, SoftmaxKind, TensorMut, TensorRef, Workspace,
30};
31
32#[derive(Copy, Clone, Debug)]
34pub struct SoftmaxBackwardDescriptor<const N: usize> {
35 pub kind: SoftmaxKind,
37 pub input_shape: [i32; N],
39 pub softmax_axis: u8,
41 pub element: ElementKind,
43}
44
45pub struct SoftmaxBackwardArgs<'a, T: Element, const N: usize> {
50 pub dy: TensorRef<'a, T, N>,
52 pub y: TensorRef<'a, T, N>,
54 pub dx: TensorMut<'a, T, N>,
56}
57
58pub struct SoftmaxBackwardPlan<T: Element, const N: usize> {
64 desc: SoftmaxBackwardDescriptor<N>,
65 sku: KernelSku,
66 _marker: PhantomData<T>,
67}
68
69impl<T: Element, const N: usize> SoftmaxBackwardPlan<T, N> {
70 pub fn select(
75 _stream: &Stream,
76 desc: &SoftmaxBackwardDescriptor<N>,
77 _pref: PlanPreference,
78 ) -> Result<Self> {
79 if desc.element != T::KIND {
80 return Err(Error::Unsupported(
81 "baracuda-kernels::SoftmaxBackwardPlan: descriptor element != T",
82 ));
83 }
84 if (desc.softmax_axis as usize) >= N {
85 return Err(Error::InvalidProblem(
86 "baracuda-kernels::SoftmaxBackwardPlan: softmax_axis out of range for rank N",
87 ));
88 }
89 for &d in desc.input_shape.iter() {
90 if d < 0 {
91 return Err(Error::InvalidProblem(
92 "baracuda-kernels::SoftmaxBackwardPlan: shape dims must be non-negative",
93 ));
94 }
95 }
96 if N > 8 {
97 return Err(Error::Unsupported(
98 "baracuda-kernels::SoftmaxBackwardPlan: tensor rank > 8 not supported",
99 ));
100 }
101 let dtype_in_fp_family = matches!(
102 T::KIND,
103 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
104 );
105 let kind_supported = matches!(desc.kind, SoftmaxKind::Softmax | SoftmaxKind::LogSoftmax);
106 if !kind_supported || !dtype_in_fp_family {
107 return Err(Error::Unsupported(
108 "baracuda-kernels::SoftmaxBackwardPlan: wired today: \
109 `{Softmax, LogSoftmax} × {f32, f16, bf16, f64}`",
110 ));
111 }
112
113 let precision_guarantee = PrecisionGuarantee {
114 math_precision: MathPrecision::F32,
115 accumulator: ElementKind::F32,
116 bit_stable_on_same_hardware: true,
117 deterministic: true,
118 };
119 let sku = KernelSku {
120 category: OpCategory::Softmax,
121 op: desc.kind as u16,
122 element: T::KIND,
123 aux_element: None,
124 layout: None,
125 epilogue: None,
126 arch: ArchSku::Sm80,
127 backend: BackendKind::Bespoke,
128 precision_guarantee,
129 };
130 Ok(Self {
131 desc: *desc,
132 sku,
133 _marker: PhantomData,
134 })
135 }
136
137 pub fn can_implement(&self, args: &SoftmaxBackwardArgs<'_, T, N>) -> Result<()> {
139 if args.dy.shape != self.desc.input_shape {
140 return Err(Error::InvalidProblem(
141 "baracuda-kernels::SoftmaxBackwardPlan: dy shape mismatch",
142 ));
143 }
144 if args.y.shape != self.desc.input_shape {
145 return Err(Error::InvalidProblem(
146 "baracuda-kernels::SoftmaxBackwardPlan: y shape mismatch",
147 ));
148 }
149 if args.dx.shape != self.desc.input_shape {
150 return Err(Error::InvalidProblem(
151 "baracuda-kernels::SoftmaxBackwardPlan: dx shape mismatch",
152 ));
153 }
154 let numel = args.dx.numel();
155 let dy_len = args.dy.data.len() as i64;
156 let y_len = args.y.data.len() as i64;
157 let dx_len = args.dx.data.len() as i64;
158 if dy_len < numel || y_len < numel || dx_len < numel {
159 return Err(Error::BufferTooSmall {
160 needed: numel as usize,
161 got: dy_len.min(y_len).min(dx_len) as usize,
162 });
163 }
164 Ok(())
165 }
166
167 #[inline]
169 pub fn workspace_size(&self) -> usize {
170 0
171 }
172 #[inline]
174 pub fn sku(&self) -> KernelSku {
175 self.sku
176 }
177 #[inline]
181 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
182 self.sku.precision_guarantee
183 }
184
185 pub fn run(
188 &self,
189 stream: &Stream,
190 _workspace: Workspace<'_>,
191 args: SoftmaxBackwardArgs<'_, T, N>,
192 ) -> Result<()> {
193 self.can_implement(&args)?;
194 let numel = args.dx.numel();
195 if numel == 0 {
196 return Ok(());
197 }
198 let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
199 let y_ptr = args.y.data.as_raw().0 as *const c_void;
200 let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
201 let stream_ptr = stream.as_raw() as *mut c_void;
202
203 let axis = self.desc.softmax_axis as usize;
204 let shape = self.desc.input_shape;
205 let stride_dy = args.dy.stride;
206 let stride_y = args.y.stride;
207 let stride_dx = args.dx.stride;
208 let rank = N as i32;
209 let extent = shape[axis];
210 let stride_dy_axis = stride_dy[axis];
211 let stride_y_axis = stride_y[axis];
212
213 macro_rules! dispatch {
214 ($sym:ident) => {
215 unsafe {
216 baracuda_kernels_sys::$sym(
217 numel,
218 rank,
219 shape.as_ptr(),
220 stride_dy.as_ptr(),
221 stride_y.as_ptr(),
222 stride_dx.as_ptr(),
223 axis as i32,
224 extent,
225 stride_dy_axis,
226 stride_y_axis,
227 dy_ptr,
228 y_ptr,
229 dx_ptr,
230 core::ptr::null_mut(),
231 0,
232 stream_ptr,
233 )
234 }
235 };
236 }
237
238 let status = match (self.desc.kind, T::KIND) {
239 (SoftmaxKind::Softmax, ElementKind::F32) => {
240 dispatch!(baracuda_kernels_softmax_backward_f32_run)
241 }
242 (SoftmaxKind::Softmax, ElementKind::F16) => {
243 dispatch!(baracuda_kernels_softmax_backward_f16_run)
244 }
245 (SoftmaxKind::Softmax, ElementKind::Bf16) => {
246 dispatch!(baracuda_kernels_softmax_backward_bf16_run)
247 }
248 (SoftmaxKind::Softmax, ElementKind::F64) => {
249 dispatch!(baracuda_kernels_softmax_backward_f64_run)
250 }
251 (SoftmaxKind::LogSoftmax, ElementKind::F32) => {
252 dispatch!(baracuda_kernels_log_softmax_backward_f32_run)
253 }
254 (SoftmaxKind::LogSoftmax, ElementKind::F16) => {
255 dispatch!(baracuda_kernels_log_softmax_backward_f16_run)
256 }
257 (SoftmaxKind::LogSoftmax, ElementKind::Bf16) => {
258 dispatch!(baracuda_kernels_log_softmax_backward_bf16_run)
259 }
260 (SoftmaxKind::LogSoftmax, ElementKind::F64) => {
261 dispatch!(baracuda_kernels_log_softmax_backward_f64_run)
262 }
263 _ => {
264 return Err(Error::Unsupported(
265 "baracuda-kernels::SoftmaxBackwardPlan::run reached an unimplemented \
266 (kind, dtype) pair — select() should have caught this",
267 ));
268 }
269 };
270 map_status(status)
271 }
272}
273
274fn map_status(code: i32) -> Result<()> {
275 match code {
276 0 => Ok(()),
277 1 => Err(Error::MisalignedOperand),
278 2 => Err(Error::InvalidProblem(
279 "baracuda-kernels-sys reported invalid problem",
280 )),
281 3 => Err(Error::Unsupported(
282 "baracuda-kernels-sys reported unsupported configuration",
283 )),
284 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
285 n => Err(Error::CutlassInternal(n)),
286 }
287}