1use core::ffi::c_void;
29use core::marker::PhantomData;
30
31use baracuda_cutlass::{Error, Result};
32use baracuda_driver::Stream;
33use baracuda_kernels_types::{
34 ArchSku, BackendKind, BinaryKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
35 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
36};
37
38#[derive(Copy, Clone, Debug)]
40pub struct BinaryBackwardDescriptor<const N: usize> {
41 pub kind: BinaryKind,
43 pub shape: [i32; N],
45 pub element: ElementKind,
47}
48
49pub struct BinaryBackwardArgs<'a, T: Element, const N: usize> {
56 pub dy: TensorRef<'a, T, N>,
58 pub a: Option<TensorRef<'a, T, N>>,
60 pub b: Option<TensorRef<'a, T, N>>,
62 pub da: TensorMut<'a, T, N>,
64 pub db: TensorMut<'a, T, N>,
66}
67
68pub struct BinaryBackwardPlan<T: Element, const N: usize> {
70 desc: BinaryBackwardDescriptor<N>,
71 sku: KernelSku,
72 _marker: PhantomData<T>,
73}
74
75#[inline]
76fn op_needs_saves(kind: BinaryKind) -> bool {
77 matches!(
78 kind,
79 BinaryKind::Mul
80 | BinaryKind::Div
81 | BinaryKind::Pow
82 | BinaryKind::Maximum
83 | BinaryKind::Minimum
84 | BinaryKind::Atan2
85 | BinaryKind::Hypot
86 )
87}
88
89impl<T: Element, const N: usize> BinaryBackwardPlan<T, N> {
90 pub fn select(
92 _stream: &Stream,
93 desc: &BinaryBackwardDescriptor<N>,
94 _pref: PlanPreference,
95 ) -> Result<Self> {
96 if desc.element != T::KIND {
97 return Err(Error::Unsupported(
98 "baracuda-kernels::BinaryBackwardPlan: descriptor element != T",
99 ));
100 }
101 for &d in desc.shape.iter() {
102 if d < 0 {
103 return Err(Error::InvalidProblem(
104 "baracuda-kernels::BinaryBackwardPlan: shape dims must be non-negative",
105 ));
106 }
107 }
108 let supported = matches!(
111 (desc.kind, T::KIND),
112 (BinaryKind::Add, ElementKind::F32)
113 | (BinaryKind::Add, ElementKind::F16)
114 | (BinaryKind::Add, ElementKind::Bf16)
115 | (BinaryKind::Add, ElementKind::F64)
116 | (BinaryKind::Sub, ElementKind::F32)
117 | (BinaryKind::Sub, ElementKind::F16)
118 | (BinaryKind::Sub, ElementKind::Bf16)
119 | (BinaryKind::Sub, ElementKind::F64)
120 | (BinaryKind::Mul, ElementKind::F32)
121 | (BinaryKind::Mul, ElementKind::F16)
122 | (BinaryKind::Mul, ElementKind::Bf16)
123 | (BinaryKind::Mul, ElementKind::F64)
124 | (BinaryKind::Div, ElementKind::F32)
125 | (BinaryKind::Div, ElementKind::F16)
126 | (BinaryKind::Div, ElementKind::Bf16)
127 | (BinaryKind::Div, ElementKind::F64)
128 | (BinaryKind::Maximum, ElementKind::F32)
129 | (BinaryKind::Maximum, ElementKind::F16)
130 | (BinaryKind::Maximum, ElementKind::Bf16)
131 | (BinaryKind::Maximum, ElementKind::F64)
132 | (BinaryKind::Minimum, ElementKind::F32)
133 | (BinaryKind::Minimum, ElementKind::F16)
134 | (BinaryKind::Minimum, ElementKind::Bf16)
135 | (BinaryKind::Minimum, ElementKind::F64)
136 | (BinaryKind::Pow, ElementKind::F32)
137 | (BinaryKind::Pow, ElementKind::F16)
138 | (BinaryKind::Pow, ElementKind::Bf16)
139 | (BinaryKind::Pow, ElementKind::F64)
140 | (BinaryKind::Atan2, ElementKind::F32)
141 | (BinaryKind::Atan2, ElementKind::F16)
142 | (BinaryKind::Atan2, ElementKind::Bf16)
143 | (BinaryKind::Atan2, ElementKind::F64)
144 | (BinaryKind::Hypot, ElementKind::F32)
145 | (BinaryKind::Hypot, ElementKind::F16)
146 | (BinaryKind::Hypot, ElementKind::Bf16)
147 | (BinaryKind::Hypot, ElementKind::F64)
148 );
149 if !supported {
150 return Err(Error::Unsupported(
151 "baracuda-kernels::BinaryBackwardPlan: only \
152 `{Add,Sub,Mul,Div,Maximum,Minimum,Pow,Atan2,Hypot}` × \
153 `{f32, f16, bf16, f64}` are wired today; other (kind, dtype) \
154 pairs (e.g. integer family, Lerp) land in later fanout. Lerp \
155 is reserved-but-deferred pending a parameterized-binary plan \
156 shape.",
157 ));
158 }
159
160 let precision_guarantee = PrecisionGuarantee {
161 math_precision: MathPrecision::F32,
162 accumulator: ElementKind::F32,
163 bit_stable_on_same_hardware: true,
164 deterministic: true,
165 };
166 let sku = KernelSku {
167 category: OpCategory::BinaryElementwise,
168 op: desc.kind as u16,
171 element: T::KIND,
172 aux_element: None,
173 layout: None,
174 epilogue: None,
175 arch: ArchSku::Sm80,
176 backend: BackendKind::Bespoke,
177 precision_guarantee,
178 };
179 Ok(Self {
180 desc: *desc,
181 sku,
182 _marker: PhantomData,
183 })
184 }
185
186 pub fn can_implement(&self, args: &BinaryBackwardArgs<'_, T, N>) -> Result<()> {
188 if args.dy.shape != self.desc.shape {
189 return Err(Error::InvalidProblem(
190 "baracuda-kernels::BinaryBackwardPlan: dy shape mismatch",
191 ));
192 }
193 if args.da.shape != self.desc.shape {
194 return Err(Error::InvalidProblem(
195 "baracuda-kernels::BinaryBackwardPlan: da shape mismatch",
196 ));
197 }
198 if args.db.shape != self.desc.shape {
199 return Err(Error::InvalidProblem(
200 "baracuda-kernels::BinaryBackwardPlan: db shape mismatch",
201 ));
202 }
203 if !args.dy.is_contiguous() || !args.da.is_contiguous() || !args.db.is_contiguous() {
205 return Err(Error::Unsupported(
206 "baracuda-kernels::BinaryBackwardPlan: trailblazer requires contiguous \
207 dy / da / db; strided fanout lands later",
208 ));
209 }
210 if op_needs_saves(self.desc.kind) {
212 let a = args.a.as_ref().ok_or(Error::InvalidProblem(
213 "baracuda-kernels::BinaryBackwardPlan: this op requires saved input `a`",
214 ))?;
215 let b = args.b.as_ref().ok_or(Error::InvalidProblem(
216 "baracuda-kernels::BinaryBackwardPlan: this op requires saved input `b`",
217 ))?;
218 if a.shape != self.desc.shape {
219 return Err(Error::InvalidProblem(
220 "baracuda-kernels::BinaryBackwardPlan: saved a shape mismatch",
221 ));
222 }
223 if b.shape != self.desc.shape {
224 return Err(Error::InvalidProblem(
225 "baracuda-kernels::BinaryBackwardPlan: saved b shape mismatch",
226 ));
227 }
228 if !a.is_contiguous() || !b.is_contiguous() {
229 return Err(Error::Unsupported(
230 "baracuda-kernels::BinaryBackwardPlan: saved a/b must be contiguous \
231 (strided fanout lands later)",
232 ));
233 }
234 let numel = args.dy.numel() as usize;
235 if a.data.len() < numel {
236 return Err(Error::BufferTooSmall {
237 needed: numel,
238 got: a.data.len(),
239 });
240 }
241 if b.data.len() < numel {
242 return Err(Error::BufferTooSmall {
243 needed: numel,
244 got: b.data.len(),
245 });
246 }
247 }
248 let numel = args.dy.numel();
249 let dy_len = args.dy.data.len() as i64;
250 let da_len = args.da.data.len() as i64;
251 let db_len = args.db.data.len() as i64;
252 if dy_len < numel || da_len < numel || db_len < numel {
253 return Err(Error::BufferTooSmall {
254 needed: numel as usize,
255 got: dy_len.min(da_len).min(db_len) as usize,
256 });
257 }
258 Ok(())
259 }
260
261 #[inline]
263 pub fn workspace_size(&self) -> usize {
264 0
265 }
266 #[inline]
268 pub fn sku(&self) -> KernelSku {
269 self.sku
270 }
271 #[inline]
273 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
274 self.sku.precision_guarantee
275 }
276
277 pub fn run(
279 &self,
280 stream: &Stream,
281 _workspace: Workspace<'_>,
282 args: BinaryBackwardArgs<'_, T, N>,
283 ) -> Result<()> {
284 self.can_implement(&args)?;
285 let numel = args.dy.numel();
286 if numel == 0 {
287 return Ok(());
288 }
289 let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
290 let da_ptr = args.da.data.as_raw().0 as *mut c_void;
291 let db_ptr = args.db.data.as_raw().0 as *mut c_void;
292 let stream_ptr = stream.as_raw() as *mut c_void;
293
294 let status = match (self.desc.kind, T::KIND) {
295 (BinaryKind::Add, ElementKind::F32) => unsafe {
297 baracuda_kernels_sys::baracuda_kernels_binary_add_backward_f32_run(
298 numel, dy_ptr, da_ptr, db_ptr,
299 core::ptr::null_mut(), 0, stream_ptr,
300 )
301 },
302 (BinaryKind::Add, ElementKind::F16) => unsafe {
303 baracuda_kernels_sys::baracuda_kernels_binary_add_backward_f16_run(
304 numel, dy_ptr, da_ptr, db_ptr,
305 core::ptr::null_mut(), 0, stream_ptr,
306 )
307 },
308 (BinaryKind::Add, ElementKind::Bf16) => unsafe {
309 baracuda_kernels_sys::baracuda_kernels_binary_add_backward_bf16_run(
310 numel, dy_ptr, da_ptr, db_ptr,
311 core::ptr::null_mut(), 0, stream_ptr,
312 )
313 },
314 (BinaryKind::Add, ElementKind::F64) => unsafe {
315 baracuda_kernels_sys::baracuda_kernels_binary_add_backward_f64_run(
316 numel, dy_ptr, da_ptr, db_ptr,
317 core::ptr::null_mut(), 0, stream_ptr,
318 )
319 },
320 (BinaryKind::Sub, ElementKind::F32) => unsafe {
322 baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_f32_run(
323 numel, dy_ptr, da_ptr, db_ptr,
324 core::ptr::null_mut(), 0, stream_ptr,
325 )
326 },
327 (BinaryKind::Sub, ElementKind::F16) => unsafe {
328 baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_f16_run(
329 numel, dy_ptr, da_ptr, db_ptr,
330 core::ptr::null_mut(), 0, stream_ptr,
331 )
332 },
333 (BinaryKind::Sub, ElementKind::Bf16) => unsafe {
334 baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_bf16_run(
335 numel, dy_ptr, da_ptr, db_ptr,
336 core::ptr::null_mut(), 0, stream_ptr,
337 )
338 },
339 (BinaryKind::Sub, ElementKind::F64) => unsafe {
340 baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_f64_run(
341 numel, dy_ptr, da_ptr, db_ptr,
342 core::ptr::null_mut(), 0, stream_ptr,
343 )
344 },
345 (BinaryKind::Mul, ElementKind::F32) => {
347 let (a_ptr, b_ptr) = saved_ptrs(&args);
348 unsafe {
349 baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_f32_run(
350 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
351 core::ptr::null_mut(), 0, stream_ptr,
352 )
353 }
354 }
355 (BinaryKind::Mul, ElementKind::F16) => {
356 let (a_ptr, b_ptr) = saved_ptrs(&args);
357 unsafe {
358 baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_f16_run(
359 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
360 core::ptr::null_mut(), 0, stream_ptr,
361 )
362 }
363 }
364 (BinaryKind::Mul, ElementKind::Bf16) => {
365 let (a_ptr, b_ptr) = saved_ptrs(&args);
366 unsafe {
367 baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_bf16_run(
368 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
369 core::ptr::null_mut(), 0, stream_ptr,
370 )
371 }
372 }
373 (BinaryKind::Mul, ElementKind::F64) => {
374 let (a_ptr, b_ptr) = saved_ptrs(&args);
375 unsafe {
376 baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_f64_run(
377 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
378 core::ptr::null_mut(), 0, stream_ptr,
379 )
380 }
381 }
382 (BinaryKind::Div, ElementKind::F32) => {
384 let (a_ptr, b_ptr) = saved_ptrs(&args);
385 unsafe {
386 baracuda_kernels_sys::baracuda_kernels_binary_div_backward_f32_run(
387 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
388 core::ptr::null_mut(), 0, stream_ptr,
389 )
390 }
391 }
392 (BinaryKind::Div, ElementKind::F16) => {
393 let (a_ptr, b_ptr) = saved_ptrs(&args);
394 unsafe {
395 baracuda_kernels_sys::baracuda_kernels_binary_div_backward_f16_run(
396 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
397 core::ptr::null_mut(), 0, stream_ptr,
398 )
399 }
400 }
401 (BinaryKind::Div, ElementKind::Bf16) => {
402 let (a_ptr, b_ptr) = saved_ptrs(&args);
403 unsafe {
404 baracuda_kernels_sys::baracuda_kernels_binary_div_backward_bf16_run(
405 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
406 core::ptr::null_mut(), 0, stream_ptr,
407 )
408 }
409 }
410 (BinaryKind::Div, ElementKind::F64) => {
411 let (a_ptr, b_ptr) = saved_ptrs(&args);
412 unsafe {
413 baracuda_kernels_sys::baracuda_kernels_binary_div_backward_f64_run(
414 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
415 core::ptr::null_mut(), 0, stream_ptr,
416 )
417 }
418 }
419 (BinaryKind::Maximum, ElementKind::F32) => {
421 let (a_ptr, b_ptr) = saved_ptrs(&args);
422 unsafe {
423 baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_f32_run(
424 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
425 core::ptr::null_mut(), 0, stream_ptr,
426 )
427 }
428 }
429 (BinaryKind::Maximum, ElementKind::F16) => {
430 let (a_ptr, b_ptr) = saved_ptrs(&args);
431 unsafe {
432 baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_f16_run(
433 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
434 core::ptr::null_mut(), 0, stream_ptr,
435 )
436 }
437 }
438 (BinaryKind::Maximum, ElementKind::Bf16) => {
439 let (a_ptr, b_ptr) = saved_ptrs(&args);
440 unsafe {
441 baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_bf16_run(
442 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
443 core::ptr::null_mut(), 0, stream_ptr,
444 )
445 }
446 }
447 (BinaryKind::Maximum, ElementKind::F64) => {
448 let (a_ptr, b_ptr) = saved_ptrs(&args);
449 unsafe {
450 baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_f64_run(
451 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
452 core::ptr::null_mut(), 0, stream_ptr,
453 )
454 }
455 }
456 (BinaryKind::Minimum, ElementKind::F32) => {
458 let (a_ptr, b_ptr) = saved_ptrs(&args);
459 unsafe {
460 baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_f32_run(
461 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
462 core::ptr::null_mut(), 0, stream_ptr,
463 )
464 }
465 }
466 (BinaryKind::Minimum, ElementKind::F16) => {
467 let (a_ptr, b_ptr) = saved_ptrs(&args);
468 unsafe {
469 baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_f16_run(
470 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
471 core::ptr::null_mut(), 0, stream_ptr,
472 )
473 }
474 }
475 (BinaryKind::Minimum, ElementKind::Bf16) => {
476 let (a_ptr, b_ptr) = saved_ptrs(&args);
477 unsafe {
478 baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_bf16_run(
479 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
480 core::ptr::null_mut(), 0, stream_ptr,
481 )
482 }
483 }
484 (BinaryKind::Minimum, ElementKind::F64) => {
485 let (a_ptr, b_ptr) = saved_ptrs(&args);
486 unsafe {
487 baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_f64_run(
488 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
489 core::ptr::null_mut(), 0, stream_ptr,
490 )
491 }
492 }
493 (BinaryKind::Pow, ElementKind::F32) => {
495 let (a_ptr, b_ptr) = saved_ptrs(&args);
496 unsafe {
497 baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_f32_run(
498 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
499 core::ptr::null_mut(), 0, stream_ptr,
500 )
501 }
502 }
503 (BinaryKind::Pow, ElementKind::F16) => {
504 let (a_ptr, b_ptr) = saved_ptrs(&args);
505 unsafe {
506 baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_f16_run(
507 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
508 core::ptr::null_mut(), 0, stream_ptr,
509 )
510 }
511 }
512 (BinaryKind::Pow, ElementKind::Bf16) => {
513 let (a_ptr, b_ptr) = saved_ptrs(&args);
514 unsafe {
515 baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_bf16_run(
516 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
517 core::ptr::null_mut(), 0, stream_ptr,
518 )
519 }
520 }
521 (BinaryKind::Pow, ElementKind::F64) => {
522 let (a_ptr, b_ptr) = saved_ptrs(&args);
523 unsafe {
524 baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_f64_run(
525 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
526 core::ptr::null_mut(), 0, stream_ptr,
527 )
528 }
529 }
530 (BinaryKind::Atan2, ElementKind::F32) => {
532 let (a_ptr, b_ptr) = saved_ptrs(&args);
533 unsafe {
534 baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_f32_run(
535 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
536 core::ptr::null_mut(), 0, stream_ptr,
537 )
538 }
539 }
540 (BinaryKind::Atan2, ElementKind::F16) => {
541 let (a_ptr, b_ptr) = saved_ptrs(&args);
542 unsafe {
543 baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_f16_run(
544 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
545 core::ptr::null_mut(), 0, stream_ptr,
546 )
547 }
548 }
549 (BinaryKind::Atan2, ElementKind::Bf16) => {
550 let (a_ptr, b_ptr) = saved_ptrs(&args);
551 unsafe {
552 baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_bf16_run(
553 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
554 core::ptr::null_mut(), 0, stream_ptr,
555 )
556 }
557 }
558 (BinaryKind::Atan2, ElementKind::F64) => {
559 let (a_ptr, b_ptr) = saved_ptrs(&args);
560 unsafe {
561 baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_f64_run(
562 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
563 core::ptr::null_mut(), 0, stream_ptr,
564 )
565 }
566 }
567 (BinaryKind::Hypot, ElementKind::F32) => {
569 let (a_ptr, b_ptr) = saved_ptrs(&args);
570 unsafe {
571 baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_f32_run(
572 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
573 core::ptr::null_mut(), 0, stream_ptr,
574 )
575 }
576 }
577 (BinaryKind::Hypot, ElementKind::F16) => {
578 let (a_ptr, b_ptr) = saved_ptrs(&args);
579 unsafe {
580 baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_f16_run(
581 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
582 core::ptr::null_mut(), 0, stream_ptr,
583 )
584 }
585 }
586 (BinaryKind::Hypot, ElementKind::Bf16) => {
587 let (a_ptr, b_ptr) = saved_ptrs(&args);
588 unsafe {
589 baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_bf16_run(
590 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
591 core::ptr::null_mut(), 0, stream_ptr,
592 )
593 }
594 }
595 (BinaryKind::Hypot, ElementKind::F64) => {
596 let (a_ptr, b_ptr) = saved_ptrs(&args);
597 unsafe {
598 baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_f64_run(
599 numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
600 core::ptr::null_mut(), 0, stream_ptr,
601 )
602 }
603 }
604 _ => {
605 return Err(Error::Unsupported(
606 "baracuda-kernels::BinaryBackwardPlan::run reached an unimplemented \
607 (kind, dtype) pair — select() should have caught this",
608 ));
609 }
610 };
611 map_status(status)
612 }
613}
614
615#[inline]
616fn saved_ptrs<T: Element, const N: usize>(
617 args: &BinaryBackwardArgs<'_, T, N>,
618) -> (*const c_void, *const c_void) {
619 let a = args
621 .a
622 .as_ref()
623 .expect("Mul/Div/Pow/Maximum/Minimum/Atan2/Hypot backward require saved a");
624 let b = args
625 .b
626 .as_ref()
627 .expect("Mul/Div/Pow/Maximum/Minimum/Atan2/Hypot backward require saved b");
628 (
629 a.data.as_raw().0 as *const c_void,
630 b.data.as_raw().0 as *const c_void,
631 )
632}
633
634fn map_status(code: i32) -> Result<()> {
635 match code {
636 0 => Ok(()),
637 1 => Err(Error::MisalignedOperand),
638 2 => Err(Error::InvalidProblem(
639 "baracuda-kernels-sys reported invalid problem",
640 )),
641 3 => Err(Error::Unsupported(
642 "baracuda-kernels-sys reported unsupported configuration",
643 )),
644 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
645 n => Err(Error::CutlassInternal(n)),
646 }
647}