1use core::ffi::c_void;
20use core::marker::PhantomData;
21
22use baracuda_cutlass::{Error, Result};
23use baracuda_driver::Stream;
24use baracuda_kernels_types::{
25 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
26 PlanPreference, PrecisionGuarantee, ScanKind, TensorMut, TensorRef, Workspace,
27};
28
/// Problem description for the backward pass of a scan along one axis of a
/// rank-`N` tensor. Validated by [`ScanBackwardPlan::select`].
#[derive(Copy, Clone, Debug)]
pub struct ScanBackwardDescriptor<const N: usize> {
    /// Which scan's gradient to compute. `select` accepts `Cumsum`, `Cumprod`,
    /// `Cummax`, `Cummin` and `LogCumsumExp`.
    pub kind: ScanKind,
    /// Shape shared by `dy`, `dx` and any saved `x` / `y` tensor; every
    /// dimension must be non-negative.
    pub input_shape: [i32; N],
    /// Axis the forward scan ran along; must be `< N`.
    pub scan_axis: u8,
    /// Whether the forward scan ran in reverse along `scan_axis`.
    pub reverse: bool,
    /// Element dtype; must equal the plan's `T::KIND`.
    pub element: ElementKind,
}
46
/// Device tensors for one invocation of [`ScanBackwardPlan::run`].
///
/// Every tensor present must have shape `desc.input_shape`.
pub struct ScanBackwardArgs<'a, T: Element, const N: usize> {
    /// Upstream gradient (gradient w.r.t. the forward scan output).
    pub dy: TensorRef<'a, T, N>,
    /// Output: gradient w.r.t. the forward scan input; written by `run`.
    pub dx: TensorMut<'a, T, N>,
    /// Saved forward input. Required for `Cumprod`, `Cummax`, `Cummin` and
    /// `LogCumsumExp`; not read for `Cumsum`.
    pub x: Option<TensorRef<'a, T, N>>,
    /// Saved forward output. Required for `Cumprod` and `LogCumsumExp`; not read
    /// for the other kinds.
    pub y: Option<TensorRef<'a, T, N>>,
}
72
73#[inline]
75fn op_needs_saved_x(kind: ScanKind) -> bool {
76 matches!(
77 kind,
78 ScanKind::Cumprod | ScanKind::Cummax | ScanKind::Cummin | ScanKind::LogCumsumExp
79 )
80}
81
82#[inline]
84fn op_needs_saved_y(kind: ScanKind) -> bool {
85 matches!(kind, ScanKind::Cumprod | ScanKind::LogCumsumExp)
86}
87
/// A validated backward-scan plan for element type `T` and rank `N`, built by
/// [`ScanBackwardPlan::select`] and executed with [`ScanBackwardPlan::run`].
pub struct ScanBackwardPlan<T: Element, const N: usize> {
    /// Copy of the descriptor accepted by `select`.
    desc: ScanBackwardDescriptor<N>,
    /// Kernel identity reported by `sku()` / `precision_guarantee()`.
    sku: KernelSku,
    /// Ties the plan to `T` without storing a value of it.
    _marker: PhantomData<T>,
}
94
95impl<T: Element, const N: usize> ScanBackwardPlan<T, N> {
96 pub fn select(
98 _stream: &Stream,
99 desc: &ScanBackwardDescriptor<N>,
100 _pref: PlanPreference,
101 ) -> Result<Self> {
102 if desc.element != T::KIND {
103 return Err(Error::Unsupported(
104 "baracuda-kernels::ScanBackwardPlan: descriptor element != T",
105 ));
106 }
107 if (desc.scan_axis as usize) >= N {
108 return Err(Error::InvalidProblem(
109 "baracuda-kernels::ScanBackwardPlan: scan_axis out of range for rank N",
110 ));
111 }
112 for &d in desc.input_shape.iter() {
113 if d < 0 {
114 return Err(Error::InvalidProblem(
115 "baracuda-kernels::ScanBackwardPlan: shape dims must be non-negative",
116 ));
117 }
118 }
119 if N > 8 {
120 return Err(Error::Unsupported(
121 "baracuda-kernels::ScanBackwardPlan: tensor rank > 8 not supported",
122 ));
123 }
124 let dtype_in_fp_family = matches!(
125 T::KIND,
126 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
127 );
128 let kind_supported = matches!(
129 desc.kind,
130 ScanKind::Cumsum
131 | ScanKind::Cumprod
132 | ScanKind::Cummax
133 | ScanKind::Cummin
134 | ScanKind::LogCumsumExp
135 );
136 if !kind_supported || !dtype_in_fp_family {
137 return Err(Error::Unsupported(
138 "baracuda-kernels::ScanBackwardPlan: wired today: \
139 `{Cumsum, Cumprod, Cummax, Cummin, LogCumsumExp} × {f32, f16, bf16, f64}`",
140 ));
141 }
142
143 let precision_guarantee = PrecisionGuarantee {
144 math_precision: MathPrecision::F32,
145 accumulator: ElementKind::F32,
146 bit_stable_on_same_hardware: true,
147 deterministic: true,
148 };
149 let sku = KernelSku {
150 category: OpCategory::Scan,
151 op: desc.kind as u16,
152 element: T::KIND,
153 aux_element: None,
154 layout: None,
155 epilogue: None,
156 arch: ArchSku::Sm80,
157 backend: BackendKind::Bespoke,
158 precision_guarantee,
159 };
160 Ok(Self {
161 desc: *desc,
162 sku,
163 _marker: PhantomData,
164 })
165 }
166
167 pub fn can_implement(&self, args: &ScanBackwardArgs<'_, T, N>) -> Result<()> {
169 if args.dy.shape != self.desc.input_shape {
170 return Err(Error::InvalidProblem(
171 "baracuda-kernels::ScanBackwardPlan: dy shape mismatch",
172 ));
173 }
174 if args.dx.shape != self.desc.input_shape {
175 return Err(Error::InvalidProblem(
176 "baracuda-kernels::ScanBackwardPlan: dx shape mismatch",
177 ));
178 }
179 let numel = args.dx.numel();
180 let dy_len = args.dy.data.len() as i64;
181 let dx_len = args.dx.data.len() as i64;
182 if dy_len < numel || dx_len < numel {
183 return Err(Error::BufferTooSmall {
184 needed: numel as usize,
185 got: dy_len.min(dx_len) as usize,
186 });
187 }
188 if op_needs_saved_x(self.desc.kind) {
190 let x = args.x.as_ref().ok_or(Error::InvalidProblem(
191 "baracuda-kernels::ScanBackwardPlan: Cumprod / Cummax / Cummin BW \
192 require args.x (saved forward input)",
193 ))?;
194 if x.shape != self.desc.input_shape {
195 return Err(Error::InvalidProblem(
196 "baracuda-kernels::ScanBackwardPlan: args.x shape mismatch",
197 ));
198 }
199 if (x.data.len() as i64) < numel {
200 return Err(Error::BufferTooSmall {
201 needed: numel as usize,
202 got: x.data.len(),
203 });
204 }
205 }
206 if op_needs_saved_y(self.desc.kind) {
207 let y = args.y.as_ref().ok_or(Error::InvalidProblem(
208 "baracuda-kernels::ScanBackwardPlan: Cumprod BW requires args.y \
209 (saved forward output)",
210 ))?;
211 if y.shape != self.desc.input_shape {
212 return Err(Error::InvalidProblem(
213 "baracuda-kernels::ScanBackwardPlan: args.y shape mismatch",
214 ));
215 }
216 if (y.data.len() as i64) < numel {
217 return Err(Error::BufferTooSmall {
218 needed: numel as usize,
219 got: y.data.len(),
220 });
221 }
222 }
223 Ok(())
224 }
225
226 #[inline]
228 pub fn workspace_size(&self) -> usize {
229 0
230 }
231 #[inline]
233 pub fn sku(&self) -> KernelSku {
234 self.sku
235 }
236 #[inline]
238 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
239 self.sku.precision_guarantee
240 }
241
242 pub fn run(
244 &self,
245 stream: &Stream,
246 _workspace: Workspace<'_>,
247 args: ScanBackwardArgs<'_, T, N>,
248 ) -> Result<()> {
249 self.can_implement(&args)?;
250 let numel = args.dx.numel();
251 if numel == 0 {
252 return Ok(());
253 }
254 let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
255 let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
256 let stream_ptr = stream.as_raw() as *mut c_void;
257
258 let axis = self.desc.scan_axis as usize;
259 let shape = self.desc.input_shape;
260 let stride_dy = args.dy.stride;
261 let stride_dx = args.dx.stride;
262 let rank = N as i32;
263 let scan_extent = shape[axis];
264 let reverse_flag = if self.desc.reverse { 1i32 } else { 0 };
265
266 match self.desc.kind {
267 ScanKind::Cumsum => {
268 let scan_stride_dy = stride_dy[axis];
271 let cumsum_reverse = if self.desc.reverse { 0i32 } else { 1 };
272 macro_rules! dispatch_cumsum {
273 ($sym:ident) => {
274 unsafe {
275 baracuda_kernels_sys::$sym(
276 numel,
277 rank,
278 shape.as_ptr(),
279 stride_dy.as_ptr(),
280 stride_dx.as_ptr(),
281 axis as i32,
282 scan_extent,
283 scan_stride_dy,
284 cumsum_reverse,
285 dy_ptr,
286 dx_ptr,
287 core::ptr::null_mut(),
288 0,
289 stream_ptr,
290 )
291 }
292 };
293 }
294 let status = match T::KIND {
295 ElementKind::F32 => dispatch_cumsum!(baracuda_kernels_scan_cumsum_f32_run),
296 ElementKind::F16 => dispatch_cumsum!(baracuda_kernels_scan_cumsum_f16_run),
297 ElementKind::Bf16 => dispatch_cumsum!(baracuda_kernels_scan_cumsum_bf16_run),
298 ElementKind::F64 => dispatch_cumsum!(baracuda_kernels_scan_cumsum_f64_run),
299 _ => {
300 return Err(Error::Unsupported(
301 "baracuda-kernels::ScanBackwardPlan::run unsupported dtype for Cumsum",
302 ));
303 }
304 };
305 map_status(status)
306 }
307 ScanKind::Cumprod => {
308 let x_ref = args.x.expect("Cumprod BW requires saved x — validated above");
309 let y_ref = args.y.expect("Cumprod BW requires saved y — validated above");
310 let stride_x = x_ref.stride;
311 let stride_y = y_ref.stride;
312 let x_ptr = x_ref.data.as_raw().0 as *const c_void;
313 let y_ptr = y_ref.data.as_raw().0 as *const c_void;
314 macro_rules! dispatch_cumprod_bw {
315 ($sym:ident) => {
316 unsafe {
317 baracuda_kernels_sys::$sym(
318 numel,
319 rank,
320 shape.as_ptr(),
321 stride_dy.as_ptr(),
322 stride_x.as_ptr(),
323 stride_y.as_ptr(),
324 stride_dx.as_ptr(),
325 axis as i32,
326 scan_extent,
327 reverse_flag,
328 dy_ptr,
329 x_ptr,
330 y_ptr,
331 dx_ptr,
332 core::ptr::null_mut(),
333 0,
334 stream_ptr,
335 )
336 }
337 };
338 }
339 let status = match T::KIND {
340 ElementKind::F32 => {
341 dispatch_cumprod_bw!(baracuda_kernels_scan_cumprod_backward_f32_run)
342 }
343 ElementKind::F16 => {
344 dispatch_cumprod_bw!(baracuda_kernels_scan_cumprod_backward_f16_run)
345 }
346 ElementKind::Bf16 => {
347 dispatch_cumprod_bw!(baracuda_kernels_scan_cumprod_backward_bf16_run)
348 }
349 ElementKind::F64 => {
350 dispatch_cumprod_bw!(baracuda_kernels_scan_cumprod_backward_f64_run)
351 }
352 _ => {
353 return Err(Error::Unsupported(
354 "baracuda-kernels::ScanBackwardPlan::run unsupported dtype for Cumprod",
355 ));
356 }
357 };
358 map_status(status)
359 }
360 ScanKind::LogCumsumExp => {
361 let x_ref = args
362 .x
363 .expect("LogCumsumExp BW requires saved x — validated above");
364 let y_ref = args
365 .y
366 .expect("LogCumsumExp BW requires saved y — validated above");
367 let stride_x = x_ref.stride;
368 let stride_y = y_ref.stride;
369 let x_ptr = x_ref.data.as_raw().0 as *const c_void;
370 let y_ptr = y_ref.data.as_raw().0 as *const c_void;
371 macro_rules! dispatch_lcse_bw {
372 ($sym:ident) => {
373 unsafe {
374 baracuda_kernels_sys::$sym(
375 numel,
376 rank,
377 shape.as_ptr(),
378 stride_dy.as_ptr(),
379 stride_x.as_ptr(),
380 stride_y.as_ptr(),
381 stride_dx.as_ptr(),
382 axis as i32,
383 scan_extent,
384 reverse_flag,
385 dy_ptr,
386 x_ptr,
387 y_ptr,
388 dx_ptr,
389 core::ptr::null_mut(),
390 0,
391 stream_ptr,
392 )
393 }
394 };
395 }
396 let status = match T::KIND {
397 ElementKind::F32 => {
398 dispatch_lcse_bw!(baracuda_kernels_scan_log_cumsum_exp_backward_f32_run)
399 }
400 ElementKind::F16 => {
401 dispatch_lcse_bw!(baracuda_kernels_scan_log_cumsum_exp_backward_f16_run)
402 }
403 ElementKind::Bf16 => {
404 dispatch_lcse_bw!(baracuda_kernels_scan_log_cumsum_exp_backward_bf16_run)
405 }
406 ElementKind::F64 => {
407 dispatch_lcse_bw!(baracuda_kernels_scan_log_cumsum_exp_backward_f64_run)
408 }
409 _ => {
410 return Err(Error::Unsupported(
411 "baracuda-kernels::ScanBackwardPlan::run unsupported dtype for LogCumsumExp",
412 ));
413 }
414 };
415 map_status(status)
416 }
417 ScanKind::Cummax | ScanKind::Cummin => {
418 let x_ref = args
419 .x
420 .expect("Cummax/Cummin BW requires saved x — validated above");
421 let stride_x = x_ref.stride;
422 let x_ptr = x_ref.data.as_raw().0 as *const c_void;
423 macro_rules! dispatch_extrema_bw {
424 ($sym:ident) => {
425 unsafe {
426 baracuda_kernels_sys::$sym(
427 numel,
428 rank,
429 shape.as_ptr(),
430 stride_dy.as_ptr(),
431 stride_x.as_ptr(),
432 stride_dx.as_ptr(),
433 axis as i32,
434 scan_extent,
435 reverse_flag,
436 dy_ptr,
437 x_ptr,
438 dx_ptr,
439 core::ptr::null_mut(),
440 0,
441 stream_ptr,
442 )
443 }
444 };
445 }
446 let status = match (self.desc.kind, T::KIND) {
447 (ScanKind::Cummax, ElementKind::F32) => {
448 dispatch_extrema_bw!(baracuda_kernels_scan_cummax_backward_f32_run)
449 }
450 (ScanKind::Cummax, ElementKind::F16) => {
451 dispatch_extrema_bw!(baracuda_kernels_scan_cummax_backward_f16_run)
452 }
453 (ScanKind::Cummax, ElementKind::Bf16) => {
454 dispatch_extrema_bw!(baracuda_kernels_scan_cummax_backward_bf16_run)
455 }
456 (ScanKind::Cummax, ElementKind::F64) => {
457 dispatch_extrema_bw!(baracuda_kernels_scan_cummax_backward_f64_run)
458 }
459 (ScanKind::Cummin, ElementKind::F32) => {
460 dispatch_extrema_bw!(baracuda_kernels_scan_cummin_backward_f32_run)
461 }
462 (ScanKind::Cummin, ElementKind::F16) => {
463 dispatch_extrema_bw!(baracuda_kernels_scan_cummin_backward_f16_run)
464 }
465 (ScanKind::Cummin, ElementKind::Bf16) => {
466 dispatch_extrema_bw!(baracuda_kernels_scan_cummin_backward_bf16_run)
467 }
468 (ScanKind::Cummin, ElementKind::F64) => {
469 dispatch_extrema_bw!(baracuda_kernels_scan_cummin_backward_f64_run)
470 }
471 _ => {
472 return Err(Error::Unsupported(
473 "baracuda-kernels::ScanBackwardPlan::run reached an unimplemented \
474 (kind, dtype) pair for Cummax/Cummin",
475 ));
476 }
477 };
478 map_status(status)
479 }
480 _ => Err(Error::Unsupported(
484 "baracuda-kernels::ScanBackwardPlan::run reached an unimplemented ScanKind variant",
485 )),
486 }
487 }
488}
489
490fn map_status(code: i32) -> Result<()> {
491 match code {
492 0 => Ok(()),
493 1 => Err(Error::MisalignedOperand),
494 2 => Err(Error::InvalidProblem(
495 "baracuda-kernels-sys reported invalid problem",
496 )),
497 3 => Err(Error::Unsupported(
498 "baracuda-kernels-sys reported unsupported configuration",
499 )),
500 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
501 n => Err(Error::CutlassInternal(n)),
502 }
503}