1use core::ffi::c_void;
30use core::marker::PhantomData;
31
32use baracuda_cutlass::{Error, Result};
33use baracuda_driver::Stream;
34use baracuda_kernels_types::{
35 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
36 PlanPreference, PrecisionGuarantee, ReduceKind, TensorMut, TensorRef, Workspace,
37};
38
39#[derive(Copy, Clone, Debug)]
41pub struct ReduceBackwardDescriptor<const N: usize> {
42 pub kind: ReduceKind,
44 pub input_shape: [i32; N],
46 pub reduce_axis: u8,
48 pub element: ElementKind,
50 pub correction: i32,
54}
55
56impl<const N: usize> ReduceBackwardDescriptor<N> {
57 pub fn dy_shape(&self) -> [i32; N] {
59 let mut out = self.input_shape;
60 out[self.reduce_axis as usize] = 1;
61 out
62 }
63}
64
/// Device operands for one [`ReduceBackwardPlan::run`] call.
///
/// Shapes, contiguity and buffer lengths are validated by
/// `ReduceBackwardPlan::can_implement`.
pub struct ReduceBackwardArgs<'a, T: Element, const N: usize> {
    /// Upstream gradient in keepdim form: `input_shape` with `reduce_axis` set to 1.
    pub dy: TensorRef<'a, T, N>,
    /// Saved forward input, shaped like `input_shape`. Required for Max, Min,
    /// Prod, Norm2, Var, Std and LogSumExp; not inspected for Sum and Mean.
    pub x: Option<TensorRef<'a, T, N>>,
    /// Saved forward output in keepdim form. Required for the same kinds as `x`.
    pub y: Option<TensorRef<'a, T, N>>,
    /// Destination gradient, shaped like `input_shape`; written by the kernel.
    pub dx: TensorMut<'a, T, N>,
}
90
/// A validated reduction-backward plan for element type `T` and rank `N`.
///
/// Built by [`ReduceBackwardPlan::select`]; launched with `run`.
pub struct ReduceBackwardPlan<T: Element, const N: usize> {
    /// Descriptor accepted by `select`.
    desc: ReduceBackwardDescriptor<N>,
    /// Kernel identity and precision guarantee reported by `sku()`.
    sku: KernelSku,
    /// Ties the plan to `T` without storing a value of it.
    _marker: PhantomData<T>,
}
97
98#[inline]
99fn op_needs_saves(kind: ReduceKind) -> bool {
100 matches!(
106 kind,
107 ReduceKind::Max
108 | ReduceKind::Min
109 | ReduceKind::Prod
110 | ReduceKind::Norm2
111 | ReduceKind::Var
112 | ReduceKind::Std
113 | ReduceKind::LogSumExp
114 )
115}
116
117impl<T: Element, const N: usize> ReduceBackwardPlan<T, N> {
118 pub fn select(
120 _stream: &Stream,
121 desc: &ReduceBackwardDescriptor<N>,
122 _pref: PlanPreference,
123 ) -> Result<Self> {
124 if desc.element != T::KIND {
125 return Err(Error::Unsupported(
126 "baracuda-kernels::ReduceBackwardPlan: descriptor element != T",
127 ));
128 }
129 if (desc.reduce_axis as usize) >= N {
130 return Err(Error::InvalidProblem(
131 "baracuda-kernels::ReduceBackwardPlan: reduce_axis out of range for rank N",
132 ));
133 }
134 for &d in desc.input_shape.iter() {
135 if d < 0 {
136 return Err(Error::InvalidProblem(
137 "baracuda-kernels::ReduceBackwardPlan: shape dims must be non-negative",
138 ));
139 }
140 }
141 if N > 8 {
142 return Err(Error::Unsupported(
143 "baracuda-kernels::ReduceBackwardPlan: tensor rank > 8 not supported \
144 (kernel param block fixes MAX_RANK = 8)",
145 ));
146 }
147
148 let dtype_in_fp_family = matches!(
158 T::KIND,
159 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
160 );
161 let kind_in_scope = matches!(
162 desc.kind,
163 ReduceKind::Sum
164 | ReduceKind::Mean
165 | ReduceKind::Max
166 | ReduceKind::Min
167 | ReduceKind::Prod
168 | ReduceKind::Norm2
169 | ReduceKind::LogSumExp
170 | ReduceKind::Var
171 | ReduceKind::Std
172 );
173 let supported = kind_in_scope && dtype_in_fp_family;
174 if !supported {
175 return Err(Error::Unsupported(
176 "baracuda-kernels::ReduceBackwardPlan: wired today: \
177 `{Sum, Mean, Max, Min, Prod, Norm2, LogSumExp, Var, Std} \
178 × {f32, f16, bf16, f64}`; \
179 other (kind, dtype) pairs land in later fanout",
180 ));
181 }
182
183 let precision_guarantee = PrecisionGuarantee {
184 math_precision: MathPrecision::F32,
185 accumulator: ElementKind::F32,
186 bit_stable_on_same_hardware: true,
187 deterministic: true,
188 };
189 let sku = KernelSku {
190 category: OpCategory::Reduction,
191 op: desc.kind as u16,
192 element: T::KIND,
193 aux_element: None,
194 layout: None,
195 epilogue: None,
196 arch: ArchSku::Sm80,
197 backend: BackendKind::Bespoke,
198 precision_guarantee,
199 };
200 Ok(Self {
201 desc: *desc,
202 sku,
203 _marker: PhantomData,
204 })
205 }
206
207 pub fn can_implement(&self, args: &ReduceBackwardArgs<'_, T, N>) -> Result<()> {
209 if args.dx.shape != self.desc.input_shape {
210 return Err(Error::InvalidProblem(
211 "baracuda-kernels::ReduceBackwardPlan: dx shape must equal input_shape",
212 ));
213 }
214 let expected_dy_shape = self.desc.dy_shape();
215 if args.dy.shape != expected_dy_shape {
216 return Err(Error::InvalidProblem(
217 "baracuda-kernels::ReduceBackwardPlan: dy shape must equal input_shape \
218 with reduce_axis collapsed to 1 (keepdim form)",
219 ));
220 }
221 if !args.dy.is_contiguous() || !args.dx.is_contiguous() {
222 return Err(Error::Unsupported(
223 "baracuda-kernels::ReduceBackwardPlan: trailblazer requires contiguous \
224 dy / dx; strided fanout lands later",
225 ));
226 }
227 let dx_numel = args.dx.numel();
228 let dy_numel = args.dy.numel();
229 if (args.dx.data.len() as i64) < dx_numel {
230 return Err(Error::BufferTooSmall {
231 needed: dx_numel as usize,
232 got: args.dx.data.len(),
233 });
234 }
235 if (args.dy.data.len() as i64) < dy_numel {
236 return Err(Error::BufferTooSmall {
237 needed: dy_numel as usize,
238 got: args.dy.data.len(),
239 });
240 }
241 if op_needs_saves(self.desc.kind) {
244 let x = args.x.as_ref().ok_or(Error::InvalidProblem(
245 "baracuda-kernels::ReduceBackwardPlan: this op requires saved input `x`",
246 ))?;
247 let y = args.y.as_ref().ok_or(Error::InvalidProblem(
248 "baracuda-kernels::ReduceBackwardPlan: this op requires saved output `y`",
249 ))?;
250 if x.shape != self.desc.input_shape {
251 return Err(Error::InvalidProblem(
252 "baracuda-kernels::ReduceBackwardPlan: saved `x` shape must equal input_shape",
253 ));
254 }
255 if y.shape != expected_dy_shape {
256 return Err(Error::InvalidProblem(
257 "baracuda-kernels::ReduceBackwardPlan: saved `y` shape must equal \
258 keepdim form (input_shape with reduce_axis = 1)",
259 ));
260 }
261 if !x.is_contiguous() || !y.is_contiguous() {
262 return Err(Error::Unsupported(
263 "baracuda-kernels::ReduceBackwardPlan: saved x / y must be contiguous \
264 (strided fanout lands later)",
265 ));
266 }
267 if (x.data.len() as i64) < dx_numel {
268 return Err(Error::BufferTooSmall {
269 needed: dx_numel as usize,
270 got: x.data.len(),
271 });
272 }
273 if (y.data.len() as i64) < dy_numel {
274 return Err(Error::BufferTooSmall {
275 needed: dy_numel as usize,
276 got: y.data.len(),
277 });
278 }
279 }
280 Ok(())
281 }
282
283 #[inline]
285 pub fn workspace_size(&self) -> usize {
286 0
287 }
288
289 #[inline]
291 pub fn sku(&self) -> KernelSku {
292 self.sku
293 }
294
295 #[inline]
297 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
298 self.sku.precision_guarantee
299 }
300
    /// Launches the backward kernel for this plan's `(kind, T)` pair on
    /// `stream`, writing the input gradient into `args.dx`.
    ///
    /// `args` is validated with [`Self::can_implement`] first. If `dx` is
    /// empty (`numel() == 0`), this returns `Ok(())` without launching anything.
    /// The `Workspace` argument is ignored; every launcher below is passed a
    /// null pointer and size 0 in its workspace-like trailing slot.
    ///
    /// # Errors
    ///
    /// Returns any error from `can_implement`. Returns `Error::Unsupported` if
    /// no kernel is wired for the `(kind, dtype)` pair. Returns
    /// `Error::InvalidProblem` on the Mean path when the reduced extent is 0.
    /// Otherwise returns the launcher's status code as translated by
    /// `map_status`.
    ///
    /// # Panics
    ///
    /// The saved-tensor arms `expect` `args.x` / `args.y`. `can_implement`
    /// rejects missing saves for every kind that reaches those arms (Max, Min,
    /// Prod, Var, Std, Norm2, LogSumExp), so these panics are not expected in
    /// practice.
    pub fn run(
        &self,
        stream: &Stream,
        _workspace: Workspace<'_>,
        args: ReduceBackwardArgs<'_, T, N>,
    ) -> Result<()> {
        self.can_implement(&args)?;
        let numel = args.dx.numel();
        // Nothing to write for an empty dx, so skip the launch.
        if numel == 0 {
            return Ok(());
        }
        let axis = self.desc.reduce_axis as usize;
        // dy is in keepdim form (extent 1 on `axis`). Zeroing its stride on that
        // axis makes every dx position along `axis` read the same dy element,
        // assuming the kernel addresses operands as sum(coord * stride).
        let mut stride_dy = args.dy.stride;
        stride_dy[axis] = 0;
        let shape = self.desc.input_shape;
        let stride_dx = args.dx.stride;
        let rank = N as i32;
        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
        let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
        let stream_ptr = stream.as_raw() as *mut c_void;

        // SAFETY (applies to every launcher call below):
        // - `shape` and all `stride_*` arrays are locals of length N. They
        //   outlive each synchronous FFI call, and `rank == N`.
        // - `select` caps N at 8, which is the kernel's MAX_RANK.
        // - `can_implement` has verified the shape, contiguity and buffer
        //   length of dy/dx and, where used, x/y. So every pointer addresses at
        //   least numel() elements of its operand.
        // NOTE(review): kernel-side requirements that are not visible here
        // (device/context of the pointers, alignment) should be confirmed
        // against baracuda-kernels-sys.
        let status = match (self.desc.kind, T::KIND) {
            // Sum: only dy/dx are passed. Presumably dx = dy broadcast along `axis`.
            (ReduceKind::Sum, ElementKind::F32) => unsafe {
                baracuda_kernels_sys::baracuda_kernels_reduce_sum_backward_f32_run(
                    numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
                    dy_ptr, dx_ptr, core::ptr::null_mut(), 0, stream_ptr,
                )
            },
            (ReduceKind::Sum, ElementKind::F16) => unsafe {
                baracuda_kernels_sys::baracuda_kernels_reduce_sum_backward_f16_run(
                    numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
                    dy_ptr, dx_ptr, core::ptr::null_mut(), 0, stream_ptr,
                )
            },
            (ReduceKind::Sum, ElementKind::Bf16) => unsafe {
                baracuda_kernels_sys::baracuda_kernels_reduce_sum_backward_bf16_run(
                    numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
                    dy_ptr, dx_ptr, core::ptr::null_mut(), 0, stream_ptr,
                )
            },
            (ReduceKind::Sum, ElementKind::F64) => unsafe {
                baracuda_kernels_sys::baracuda_kernels_reduce_sum_backward_f64_run(
                    numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
                    dy_ptr, dx_ptr, core::ptr::null_mut(), 0, stream_ptr,
                )
            },
            // Max and Min share one launcher that receives no kind flag, only
            // x, y and dy. Presumably the gradient is routed to positions where
            // x equals the saved extremum y. TODO confirm the tie-breaking rule.
            (ReduceKind::Max, _) | (ReduceKind::Min, _) => {
                let x = args.x.as_ref().expect("Max/Min BW require saved x");
                let y = args.y.as_ref().expect("Max/Min BW require saved y");
                let x_ptr = x.data.as_raw().0 as *const c_void;
                let y_ptr = y.data.as_raw().0 as *const c_void;
                let stride_x = x.stride;
                // y is keepdim-shaped like dy, so its reduced-axis stride is zeroed too.
                let mut stride_y = y.stride;
                stride_y[axis] = 0;
                match T::KIND {
                    ElementKind::F32 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_max_min_backward_f32_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::F16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_max_min_backward_f16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::Bf16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_max_min_backward_bf16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::F64 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_max_min_backward_f64_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    _ => return Err(Error::Unsupported(
                        "baracuda-kernels::ReduceBackwardPlan::run: Max/Min BW reached an \
                         unimplemented dtype — select() should have caught this",
                    )),
                }
            }
            // Mean: the host computes 1 / extent once and passes it to the kernel.
            (ReduceKind::Mean, _) => {
                let extent = self.desc.input_shape[axis] as f64;
                // NOTE(review): if numel() is the product of the shape, a zero
                // extent already took the `numel == 0` early return above, so
                // this guard looks unreachable. It is kept as a defensive check.
                if extent == 0.0 {
                    return Err(Error::InvalidProblem(
                        "baracuda-kernels::ReduceBackwardPlan: Mean BW requires \
                         reduced extent > 0",
                    ));
                }
                let inv_extent = 1.0_f64 / extent;
                match T::KIND {
                    ElementKind::F32 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_mean_backward_f32_run(
                            numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, dx_ptr, inv_extent,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::F16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_mean_backward_f16_run(
                            numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, dx_ptr, inv_extent,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::Bf16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_mean_backward_bf16_run(
                            numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, dx_ptr, inv_extent,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::F64 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_mean_backward_f64_run(
                            numel, rank, shape.as_ptr(), stride_dy.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, dx_ptr, inv_extent,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    _ => return Err(Error::Unsupported(
                        "baracuda-kernels::ReduceBackwardPlan::run: Mean BW reached an \
                         unimplemented dtype — select() should have caught this",
                    )),
                }
            }
            // Prod: same operand layout as Max/Min (dy, saved x, saved y, dx).
            (ReduceKind::Prod, _) => {
                let x = args.x.as_ref().expect("Prod BW require saved x");
                let y = args.y.as_ref().expect("Prod BW require saved y");
                let x_ptr = x.data.as_raw().0 as *const c_void;
                let y_ptr = y.data.as_raw().0 as *const c_void;
                let stride_x = x.stride;
                let mut stride_y = y.stride;
                stride_y[axis] = 0;
                match T::KIND {
                    ElementKind::F32 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_prod_backward_f32_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::F16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_prod_backward_f16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::Bf16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_prod_backward_bf16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::F64 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_prod_backward_f64_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    _ => return Err(Error::Unsupported(
                        "baracuda-kernels::ReduceBackwardPlan::run: Prod BW reached an \
                         unimplemented dtype — select() should have caught this",
                    )),
                }
            }
            // Var/Std: in addition to the common operands, these launchers take
            // the reduced axis, its extent, x's stride along it and `correction`.
            // Presumably the kernel re-derives the mean over the axis from x.
            // TODO confirm.
            (ReduceKind::Var, _) | (ReduceKind::Std, _) => {
                let x = args
                    .x
                    .as_ref()
                    .expect("Var/Std BW require saved x");
                let y = args
                    .y
                    .as_ref()
                    .expect("Var/Std BW require saved y (Var ignores it; passed for ABI uniformity)");
                let x_ptr = x.data.as_raw().0 as *const c_void;
                let y_ptr = y.data.as_raw().0 as *const c_void;
                let stride_x = x.stride;
                let mut stride_y = y.stride;
                stride_y[axis] = 0;
                let reduce_axis_i32 = self.desc.reduce_axis as i32;
                let reduce_extent = self.desc.input_shape[axis];
                let reduce_stride_x = stride_x[axis];
                let correction = self.desc.correction;
                match (self.desc.kind, T::KIND) {
                    (ReduceKind::Var, ElementKind::F32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_var_backward_f32_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ReduceKind::Var, ElementKind::F16) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_var_backward_f16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ReduceKind::Var, ElementKind::Bf16) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_var_backward_bf16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ReduceKind::Var, ElementKind::F64) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_var_backward_f64_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ReduceKind::Std, ElementKind::F32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_std_backward_f32_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ReduceKind::Std, ElementKind::F16) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_std_backward_f16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ReduceKind::Std, ElementKind::Bf16) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_std_backward_bf16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ReduceKind::Std, ElementKind::F64) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_std_backward_f64_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            reduce_axis_i32, reduce_extent, reduce_stride_x, correction,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    _ => return Err(Error::Unsupported(
                        "baracuda-kernels::ReduceBackwardPlan::run: Var/Std BW reached an \
                         unimplemented dtype — select() should have caught this",
                    )),
                }
            }
            // Norm2: same operand layout as Max/Min/Prod.
            (ReduceKind::Norm2, _) => {
                let x = args.x.as_ref().expect("Norm2 BW require saved x");
                let y = args.y.as_ref().expect("Norm2 BW require saved y");
                let x_ptr = x.data.as_raw().0 as *const c_void;
                let y_ptr = y.data.as_raw().0 as *const c_void;
                let stride_x = x.stride;
                let mut stride_y = y.stride;
                stride_y[axis] = 0;
                match T::KIND {
                    ElementKind::F32 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_norm2_backward_f32_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::F16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_norm2_backward_f16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::Bf16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_norm2_backward_bf16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::F64 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_norm2_backward_f64_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    _ => return Err(Error::Unsupported(
                        "baracuda-kernels::ReduceBackwardPlan::run: Norm2 BW reached an \
                         unimplemented dtype — select() should have caught this",
                    )),
                }
            }
            // LogSumExp: same operand layout as Max/Min/Prod/Norm2.
            (ReduceKind::LogSumExp, _) => {
                let x = args.x.as_ref().expect("LogSumExp BW require saved x");
                let y = args.y.as_ref().expect("LogSumExp BW require saved y");
                let x_ptr = x.data.as_raw().0 as *const c_void;
                let y_ptr = y.data.as_raw().0 as *const c_void;
                let stride_x = x.stride;
                let mut stride_y = y.stride;
                stride_y[axis] = 0;
                match T::KIND {
                    ElementKind::F32 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_logsumexp_backward_f32_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::F16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_logsumexp_backward_f16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::Bf16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_logsumexp_backward_bf16_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    ElementKind::F64 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_reduce_logsumexp_backward_f64_run(
                            numel, rank, shape.as_ptr(),
                            stride_dy.as_ptr(), stride_x.as_ptr(),
                            stride_y.as_ptr(), stride_dx.as_ptr(),
                            dy_ptr, x_ptr, y_ptr, dx_ptr,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    _ => return Err(Error::Unsupported(
                        "baracuda-kernels::ReduceBackwardPlan::run: LogSumExp BW reached an \
                         unimplemented dtype — select() should have caught this",
                    )),
                }
            }
            // Any other pair (e.g. Sum with a non-float dtype) is rejected by
            // select(); this arm only guards against a plan/dispatch mismatch.
            _ => {
                return Err(Error::Unsupported(
                    "baracuda-kernels::ReduceBackwardPlan::run reached an unimplemented \
                     (kind, dtype) pair — select() should have caught this",
                ));
            }
        };
        map_status(status)
    }
734}
735
736fn map_status(code: i32) -> Result<()> {
737 match code {
738 0 => Ok(()),
739 1 => Err(Error::MisalignedOperand),
740 2 => Err(Error::InvalidProblem(
741 "baracuda-kernels-sys reported invalid problem",
742 )),
743 3 => Err(Error::Unsupported(
744 "baracuda-kernels-sys reported unsupported configuration",
745 )),
746 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
747 n => Err(Error::CutlassInternal(n)),
748 }
749}