1use core::ffi::c_void;
23use core::marker::PhantomData;
24
25use baracuda_cutlass::{Error, Result};
26use baracuda_driver::Stream;
27use baracuda_kernels_types::{
28 ArchSku, BackendKind, BinaryKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
29 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
30};
31
32#[derive(Copy, Clone, Debug)]
39pub struct BinaryDescriptor<const N: usize> {
40 pub kind: BinaryKind,
42 pub shape: [i32; N],
44 pub element: ElementKind,
47}
48
49pub struct BinaryArgs<'a, T: Element, const N: usize> {
56 pub a: TensorRef<'a, T, N>,
58 pub b: TensorRef<'a, T, N>,
60 pub y: TensorMut<'a, T, N>,
62}
63
64pub struct BinaryPlan<T: Element, const N: usize> {
70 desc: BinaryDescriptor<N>,
71 sku: KernelSku,
72 _marker: PhantomData<T>,
73}
74
75impl<T: Element, const N: usize> BinaryPlan<T, N> {
76 pub fn select(
79 _stream: &Stream,
80 desc: &BinaryDescriptor<N>,
81 _pref: PlanPreference,
82 ) -> Result<Self> {
83 if desc.element != T::KIND {
84 return Err(Error::Unsupported(
85 "baracuda-kernels::BinaryPlan: descriptor element != type parameter T",
86 ));
87 }
88 for &d in desc.shape.iter() {
89 if d < 0 {
90 return Err(Error::InvalidProblem(
91 "baracuda-kernels::BinaryPlan: shape dims must be non-negative",
92 ));
93 }
94 }
95
96 let supported = matches!(
107 (desc.kind, T::KIND),
108 (BinaryKind::Add, ElementKind::F32)
109 | (BinaryKind::Add, ElementKind::F16)
110 | (BinaryKind::Add, ElementKind::Bf16)
111 | (BinaryKind::Add, ElementKind::F64)
112 | (BinaryKind::Sub, ElementKind::F32)
113 | (BinaryKind::Sub, ElementKind::F16)
114 | (BinaryKind::Sub, ElementKind::Bf16)
115 | (BinaryKind::Sub, ElementKind::F64)
116 | (BinaryKind::Mul, ElementKind::F32)
117 | (BinaryKind::Mul, ElementKind::F16)
118 | (BinaryKind::Mul, ElementKind::Bf16)
119 | (BinaryKind::Mul, ElementKind::F64)
120 | (BinaryKind::Div, ElementKind::F32)
121 | (BinaryKind::Div, ElementKind::F16)
122 | (BinaryKind::Div, ElementKind::Bf16)
123 | (BinaryKind::Div, ElementKind::F64)
124 | (BinaryKind::Pow, ElementKind::F32)
125 | (BinaryKind::Pow, ElementKind::F16)
126 | (BinaryKind::Pow, ElementKind::Bf16)
127 | (BinaryKind::Pow, ElementKind::F64)
128 | (BinaryKind::Atan2, ElementKind::F32)
129 | (BinaryKind::Atan2, ElementKind::F16)
130 | (BinaryKind::Atan2, ElementKind::Bf16)
131 | (BinaryKind::Atan2, ElementKind::F64)
132 | (BinaryKind::Hypot, ElementKind::F32)
133 | (BinaryKind::Hypot, ElementKind::F16)
134 | (BinaryKind::Hypot, ElementKind::Bf16)
135 | (BinaryKind::Hypot, ElementKind::F64)
136 | (BinaryKind::Copysign, ElementKind::F32)
137 | (BinaryKind::Copysign, ElementKind::F16)
138 | (BinaryKind::Copysign, ElementKind::Bf16)
139 | (BinaryKind::Copysign, ElementKind::F64)
140 | (BinaryKind::Nextafter, ElementKind::F32)
141 | (BinaryKind::Nextafter, ElementKind::F16)
142 | (BinaryKind::Nextafter, ElementKind::Bf16)
143 | (BinaryKind::Nextafter, ElementKind::F64)
144 | (BinaryKind::Fmin, ElementKind::F32)
145 | (BinaryKind::Fmin, ElementKind::F16)
146 | (BinaryKind::Fmin, ElementKind::Bf16)
147 | (BinaryKind::Fmin, ElementKind::F64)
148 | (BinaryKind::Fmax, ElementKind::F32)
149 | (BinaryKind::Fmax, ElementKind::F16)
150 | (BinaryKind::Fmax, ElementKind::Bf16)
151 | (BinaryKind::Fmax, ElementKind::F64)
152 | (BinaryKind::Maximum, ElementKind::F32)
153 | (BinaryKind::Maximum, ElementKind::F16)
154 | (BinaryKind::Maximum, ElementKind::Bf16)
155 | (BinaryKind::Maximum, ElementKind::F64)
156 | (BinaryKind::Minimum, ElementKind::F32)
157 | (BinaryKind::Minimum, ElementKind::F16)
158 | (BinaryKind::Minimum, ElementKind::Bf16)
159 | (BinaryKind::Minimum, ElementKind::F64)
160 | (BinaryKind::FloorDivide, ElementKind::F32)
161 | (BinaryKind::FloorDivide, ElementKind::F16)
162 | (BinaryKind::FloorDivide, ElementKind::Bf16)
163 | (BinaryKind::FloorDivide, ElementKind::F64)
164 | (BinaryKind::Mod, ElementKind::F32)
165 | (BinaryKind::Mod, ElementKind::F16)
166 | (BinaryKind::Mod, ElementKind::Bf16)
167 | (BinaryKind::Mod, ElementKind::F64)
168 | (BinaryKind::Remainder, ElementKind::F32)
169 | (BinaryKind::Remainder, ElementKind::F16)
170 | (BinaryKind::Remainder, ElementKind::Bf16)
171 | (BinaryKind::Remainder, ElementKind::F64)
172 | (BinaryKind::BitwiseAnd, ElementKind::I32)
176 | (BinaryKind::BitwiseAnd, ElementKind::I64)
177 | (BinaryKind::BitwiseOr, ElementKind::I32)
178 | (BinaryKind::BitwiseOr, ElementKind::I64)
179 | (BinaryKind::BitwiseXor, ElementKind::I32)
180 | (BinaryKind::BitwiseXor, ElementKind::I64)
181 | (BinaryKind::BitwiseLeftShift, ElementKind::I32)
182 | (BinaryKind::BitwiseLeftShift, ElementKind::I64)
183 | (BinaryKind::BitwiseRightShift, ElementKind::I32)
184 | (BinaryKind::BitwiseRightShift, ElementKind::I64)
185 | (BinaryKind::LogicalAnd, ElementKind::Bool)
186 | (BinaryKind::LogicalOr, ElementKind::Bool)
187 | (BinaryKind::LogicalXor, ElementKind::Bool)
188 );
189 if !supported {
190 return Err(Error::Unsupported(
191 "baracuda-kernels::BinaryPlan: today only \
192 `{Add,Sub,Mul,Div,Pow,Atan2,Hypot,Copysign,Nextafter,Fmin,Fmax,\
193 Maximum,Minimum,FloorDivide,Mod,Remainder}` \
194 × `{f32, f16, bf16, f64}` + Phase 3.3 integer / bool fanout \
195 (`{BitwiseAnd,BitwiseOr,BitwiseXor,BitwiseLeftShift,\
196 BitwiseRightShift}` × `{i32, i64}` and \
197 `{LogicalAnd,LogicalOr,LogicalXor}` × Bool — contig only); \
198 other (kind, dtype) pairs land in fanout sessions. Lerp is \
199 reserved-but-deferred pending a parameterized-binary plan \
200 shape.",
201 ));
202 }
203
204 let precision_guarantee = PrecisionGuarantee {
210 math_precision: MathPrecision::F32,
211 accumulator: ElementKind::F32,
212 bit_stable_on_same_hardware: true,
213 deterministic: true,
214 };
215 let sku = KernelSku {
216 category: OpCategory::BinaryElementwise,
217 op: desc.kind as u16,
218 element: T::KIND,
219 aux_element: None,
220 layout: None,
221 epilogue: None,
222 arch: ArchSku::Sm80,
223 backend: BackendKind::Bespoke,
224 precision_guarantee,
225 };
226 Ok(Self {
227 desc: *desc,
228 sku,
229 _marker: PhantomData,
230 })
231 }
232
233 pub fn can_implement(&self, args: &BinaryArgs<'_, T, N>) -> Result<()> {
244 if args.y.shape != self.desc.shape {
248 return Err(Error::InvalidProblem(
249 "baracuda-kernels::BinaryPlan: Y shape mismatch with descriptor",
250 ));
251 }
252
253 for d in 0..N {
256 let y_dim = self.desc.shape[d];
257 let a_dim = args.a.shape[d];
258 let b_dim = args.b.shape[d];
259 if a_dim != y_dim && !(a_dim == 1 && args.a.stride[d] == 0) {
260 return Err(Error::InvalidProblem(
261 "baracuda-kernels::BinaryPlan: A axis is not broadcast-compatible \
262 with output (require shape[d] == y.shape[d], OR \
263 shape[d] == 1 AND stride[d] == 0)",
264 ));
265 }
266 if b_dim != y_dim && !(b_dim == 1 && args.b.stride[d] == 0) {
267 return Err(Error::InvalidProblem(
268 "baracuda-kernels::BinaryPlan: B axis is not broadcast-compatible \
269 with output",
270 ));
271 }
272 }
273
274 if N > 8 {
278 return Err(Error::Unsupported(
279 "baracuda-kernels::BinaryPlan: tensor rank > 8 not supported \
280 (kernel param block fixes MAX_RANK = 8)",
281 ));
282 }
283
284 let y_numel = args.y.numel();
291 let a_numel = args.a.numel();
292 let b_numel = args.b.numel();
293 let a_len = args.a.data.len() as i64;
294 let b_len = args.b.data.len() as i64;
295 let y_len = args.y.data.len() as i64;
296 if y_len < y_numel {
297 return Err(Error::BufferTooSmall {
298 needed: y_numel as usize,
299 got: y_len as usize,
300 });
301 }
302 if a_len < a_numel {
303 return Err(Error::BufferTooSmall {
304 needed: a_numel as usize,
305 got: a_len as usize,
306 });
307 }
308 if b_len < b_numel {
309 return Err(Error::BufferTooSmall {
310 needed: b_numel as usize,
311 got: b_len as usize,
312 });
313 }
314 Ok(())
315 }
316
317 #[inline]
319 pub fn workspace_size(&self) -> usize {
320 0
321 }
322
323 #[inline]
325 pub fn sku(&self) -> KernelSku {
326 self.sku
327 }
328
329 #[inline]
331 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
332 self.sku.precision_guarantee
333 }
334
335 pub fn run(
337 &self,
338 stream: &Stream,
339 _workspace: Workspace<'_>,
340 args: BinaryArgs<'_, T, N>,
341 ) -> Result<()> {
342 self.can_implement(&args)?;
343 let numel = args.y.numel();
344 if numel == 0 {
345 return Ok(());
346 }
347 let a_ptr = args.a.data.as_raw().0 as *const c_void;
348 let b_ptr = args.b.data.as_raw().0 as *const c_void;
349 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
350 let stream_ptr = stream.as_raw() as *mut c_void;
351
352 let all_contig_same_shape = args.a.shape == args.y.shape
357 && args.b.shape == args.y.shape
358 && args.a.is_contiguous()
359 && args.b.is_contiguous()
360 && args.y.is_contiguous();
361
362 if !all_contig_same_shape {
363 return self.run_strided(stream_ptr, a_ptr, b_ptr, y_ptr, numel, &args);
364 }
365
366 let status = match (self.desc.kind, T::KIND) {
367 (BinaryKind::Add, ElementKind::F32) => unsafe {
368 baracuda_kernels_sys::baracuda_kernels_binary_add_f32_run(
369 numel,
370 a_ptr,
371 b_ptr,
372 y_ptr,
373 core::ptr::null_mut(),
374 0,
375 stream_ptr,
376 )
377 },
378 (BinaryKind::Sub, ElementKind::F32) => unsafe {
379 baracuda_kernels_sys::baracuda_kernels_binary_sub_f32_run(
380 numel,
381 a_ptr,
382 b_ptr,
383 y_ptr,
384 core::ptr::null_mut(),
385 0,
386 stream_ptr,
387 )
388 },
389 (BinaryKind::Mul, ElementKind::F32) => unsafe {
390 baracuda_kernels_sys::baracuda_kernels_binary_mul_f32_run(
391 numel,
392 a_ptr,
393 b_ptr,
394 y_ptr,
395 core::ptr::null_mut(),
396 0,
397 stream_ptr,
398 )
399 },
400 (BinaryKind::Div, ElementKind::F32) => unsafe {
401 baracuda_kernels_sys::baracuda_kernels_binary_div_f32_run(
402 numel,
403 a_ptr,
404 b_ptr,
405 y_ptr,
406 core::ptr::null_mut(),
407 0,
408 stream_ptr,
409 )
410 },
411 (BinaryKind::Add, ElementKind::F16) => unsafe {
412 baracuda_kernels_sys::baracuda_kernels_binary_add_f16_run(
413 numel,
414 a_ptr,
415 b_ptr,
416 y_ptr,
417 core::ptr::null_mut(),
418 0,
419 stream_ptr,
420 )
421 },
422 (BinaryKind::Add, ElementKind::Bf16) => unsafe {
423 baracuda_kernels_sys::baracuda_kernels_binary_add_bf16_run(
424 numel,
425 a_ptr,
426 b_ptr,
427 y_ptr,
428 core::ptr::null_mut(),
429 0,
430 stream_ptr,
431 )
432 },
433 (BinaryKind::Add, ElementKind::F64) => unsafe {
434 baracuda_kernels_sys::baracuda_kernels_binary_add_f64_run(
435 numel,
436 a_ptr,
437 b_ptr,
438 y_ptr,
439 core::ptr::null_mut(),
440 0,
441 stream_ptr,
442 )
443 },
444 (BinaryKind::Sub, ElementKind::F16) => unsafe {
445 baracuda_kernels_sys::baracuda_kernels_binary_sub_f16_run(
446 numel,
447 a_ptr,
448 b_ptr,
449 y_ptr,
450 core::ptr::null_mut(),
451 0,
452 stream_ptr,
453 )
454 },
455 (BinaryKind::Sub, ElementKind::Bf16) => unsafe {
456 baracuda_kernels_sys::baracuda_kernels_binary_sub_bf16_run(
457 numel,
458 a_ptr,
459 b_ptr,
460 y_ptr,
461 core::ptr::null_mut(),
462 0,
463 stream_ptr,
464 )
465 },
466 (BinaryKind::Sub, ElementKind::F64) => unsafe {
467 baracuda_kernels_sys::baracuda_kernels_binary_sub_f64_run(
468 numel,
469 a_ptr,
470 b_ptr,
471 y_ptr,
472 core::ptr::null_mut(),
473 0,
474 stream_ptr,
475 )
476 },
477 (BinaryKind::Mul, ElementKind::F16) => unsafe {
478 baracuda_kernels_sys::baracuda_kernels_binary_mul_f16_run(
479 numel,
480 a_ptr,
481 b_ptr,
482 y_ptr,
483 core::ptr::null_mut(),
484 0,
485 stream_ptr,
486 )
487 },
488 (BinaryKind::Mul, ElementKind::Bf16) => unsafe {
489 baracuda_kernels_sys::baracuda_kernels_binary_mul_bf16_run(
490 numel,
491 a_ptr,
492 b_ptr,
493 y_ptr,
494 core::ptr::null_mut(),
495 0,
496 stream_ptr,
497 )
498 },
499 (BinaryKind::Mul, ElementKind::F64) => unsafe {
500 baracuda_kernels_sys::baracuda_kernels_binary_mul_f64_run(
501 numel,
502 a_ptr,
503 b_ptr,
504 y_ptr,
505 core::ptr::null_mut(),
506 0,
507 stream_ptr,
508 )
509 },
510 (BinaryKind::Div, ElementKind::F16) => unsafe {
511 baracuda_kernels_sys::baracuda_kernels_binary_div_f16_run(
512 numel,
513 a_ptr,
514 b_ptr,
515 y_ptr,
516 core::ptr::null_mut(),
517 0,
518 stream_ptr,
519 )
520 },
521 (BinaryKind::Div, ElementKind::Bf16) => unsafe {
522 baracuda_kernels_sys::baracuda_kernels_binary_div_bf16_run(
523 numel,
524 a_ptr,
525 b_ptr,
526 y_ptr,
527 core::ptr::null_mut(),
528 0,
529 stream_ptr,
530 )
531 },
532 (BinaryKind::Div, ElementKind::F64) => unsafe {
533 baracuda_kernels_sys::baracuda_kernels_binary_div_f64_run(
534 numel,
535 a_ptr,
536 b_ptr,
537 y_ptr,
538 core::ptr::null_mut(),
539 0,
540 stream_ptr,
541 )
542 },
543 (BinaryKind::Pow, ElementKind::F32) => unsafe {
544 baracuda_kernels_sys::baracuda_kernels_binary_pow_f32_run(
545 numel, a_ptr, b_ptr, y_ptr,
546 core::ptr::null_mut(), 0, stream_ptr,
547 )
548 },
549 (BinaryKind::Pow, ElementKind::F16) => unsafe {
550 baracuda_kernels_sys::baracuda_kernels_binary_pow_f16_run(
551 numel, a_ptr, b_ptr, y_ptr,
552 core::ptr::null_mut(), 0, stream_ptr,
553 )
554 },
555 (BinaryKind::Pow, ElementKind::Bf16) => unsafe {
556 baracuda_kernels_sys::baracuda_kernels_binary_pow_bf16_run(
557 numel, a_ptr, b_ptr, y_ptr,
558 core::ptr::null_mut(), 0, stream_ptr,
559 )
560 },
561 (BinaryKind::Pow, ElementKind::F64) => unsafe {
562 baracuda_kernels_sys::baracuda_kernels_binary_pow_f64_run(
563 numel, a_ptr, b_ptr, y_ptr,
564 core::ptr::null_mut(), 0, stream_ptr,
565 )
566 },
567 (BinaryKind::Atan2, ElementKind::F32) => unsafe {
568 baracuda_kernels_sys::baracuda_kernels_binary_atan2_f32_run(
569 numel, a_ptr, b_ptr, y_ptr,
570 core::ptr::null_mut(), 0, stream_ptr,
571 )
572 },
573 (BinaryKind::Atan2, ElementKind::F16) => unsafe {
574 baracuda_kernels_sys::baracuda_kernels_binary_atan2_f16_run(
575 numel, a_ptr, b_ptr, y_ptr,
576 core::ptr::null_mut(), 0, stream_ptr,
577 )
578 },
579 (BinaryKind::Atan2, ElementKind::Bf16) => unsafe {
580 baracuda_kernels_sys::baracuda_kernels_binary_atan2_bf16_run(
581 numel, a_ptr, b_ptr, y_ptr,
582 core::ptr::null_mut(), 0, stream_ptr,
583 )
584 },
585 (BinaryKind::Atan2, ElementKind::F64) => unsafe {
586 baracuda_kernels_sys::baracuda_kernels_binary_atan2_f64_run(
587 numel, a_ptr, b_ptr, y_ptr,
588 core::ptr::null_mut(), 0, stream_ptr,
589 )
590 },
591 (BinaryKind::Hypot, ElementKind::F32) => unsafe {
592 baracuda_kernels_sys::baracuda_kernels_binary_hypot_f32_run(
593 numel, a_ptr, b_ptr, y_ptr,
594 core::ptr::null_mut(), 0, stream_ptr,
595 )
596 },
597 (BinaryKind::Hypot, ElementKind::F16) => unsafe {
598 baracuda_kernels_sys::baracuda_kernels_binary_hypot_f16_run(
599 numel, a_ptr, b_ptr, y_ptr,
600 core::ptr::null_mut(), 0, stream_ptr,
601 )
602 },
603 (BinaryKind::Hypot, ElementKind::Bf16) => unsafe {
604 baracuda_kernels_sys::baracuda_kernels_binary_hypot_bf16_run(
605 numel, a_ptr, b_ptr, y_ptr,
606 core::ptr::null_mut(), 0, stream_ptr,
607 )
608 },
609 (BinaryKind::Hypot, ElementKind::F64) => unsafe {
610 baracuda_kernels_sys::baracuda_kernels_binary_hypot_f64_run(
611 numel, a_ptr, b_ptr, y_ptr,
612 core::ptr::null_mut(), 0, stream_ptr,
613 )
614 },
615 (BinaryKind::Copysign, ElementKind::F32) => unsafe {
616 baracuda_kernels_sys::baracuda_kernels_binary_copysign_f32_run(
617 numel, a_ptr, b_ptr, y_ptr,
618 core::ptr::null_mut(), 0, stream_ptr,
619 )
620 },
621 (BinaryKind::Copysign, ElementKind::F16) => unsafe {
622 baracuda_kernels_sys::baracuda_kernels_binary_copysign_f16_run(
623 numel, a_ptr, b_ptr, y_ptr,
624 core::ptr::null_mut(), 0, stream_ptr,
625 )
626 },
627 (BinaryKind::Copysign, ElementKind::Bf16) => unsafe {
628 baracuda_kernels_sys::baracuda_kernels_binary_copysign_bf16_run(
629 numel, a_ptr, b_ptr, y_ptr,
630 core::ptr::null_mut(), 0, stream_ptr,
631 )
632 },
633 (BinaryKind::Copysign, ElementKind::F64) => unsafe {
634 baracuda_kernels_sys::baracuda_kernels_binary_copysign_f64_run(
635 numel, a_ptr, b_ptr, y_ptr,
636 core::ptr::null_mut(), 0, stream_ptr,
637 )
638 },
639 (BinaryKind::Nextafter, ElementKind::F32) => unsafe {
640 baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f32_run(
641 numel, a_ptr, b_ptr, y_ptr,
642 core::ptr::null_mut(), 0, stream_ptr,
643 )
644 },
645 (BinaryKind::Nextafter, ElementKind::F16) => unsafe {
646 baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f16_run(
647 numel, a_ptr, b_ptr, y_ptr,
648 core::ptr::null_mut(), 0, stream_ptr,
649 )
650 },
651 (BinaryKind::Nextafter, ElementKind::Bf16) => unsafe {
652 baracuda_kernels_sys::baracuda_kernels_binary_nextafter_bf16_run(
653 numel, a_ptr, b_ptr, y_ptr,
654 core::ptr::null_mut(), 0, stream_ptr,
655 )
656 },
657 (BinaryKind::Nextafter, ElementKind::F64) => unsafe {
658 baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f64_run(
659 numel, a_ptr, b_ptr, y_ptr,
660 core::ptr::null_mut(), 0, stream_ptr,
661 )
662 },
663 (BinaryKind::Fmin, ElementKind::F32) => unsafe {
664 baracuda_kernels_sys::baracuda_kernels_binary_fmin_f32_run(
665 numel, a_ptr, b_ptr, y_ptr,
666 core::ptr::null_mut(), 0, stream_ptr,
667 )
668 },
669 (BinaryKind::Fmin, ElementKind::F16) => unsafe {
670 baracuda_kernels_sys::baracuda_kernels_binary_fmin_f16_run(
671 numel, a_ptr, b_ptr, y_ptr,
672 core::ptr::null_mut(), 0, stream_ptr,
673 )
674 },
675 (BinaryKind::Fmin, ElementKind::Bf16) => unsafe {
676 baracuda_kernels_sys::baracuda_kernels_binary_fmin_bf16_run(
677 numel, a_ptr, b_ptr, y_ptr,
678 core::ptr::null_mut(), 0, stream_ptr,
679 )
680 },
681 (BinaryKind::Fmin, ElementKind::F64) => unsafe {
682 baracuda_kernels_sys::baracuda_kernels_binary_fmin_f64_run(
683 numel, a_ptr, b_ptr, y_ptr,
684 core::ptr::null_mut(), 0, stream_ptr,
685 )
686 },
687 (BinaryKind::Fmax, ElementKind::F32) => unsafe {
688 baracuda_kernels_sys::baracuda_kernels_binary_fmax_f32_run(
689 numel, a_ptr, b_ptr, y_ptr,
690 core::ptr::null_mut(), 0, stream_ptr,
691 )
692 },
693 (BinaryKind::Fmax, ElementKind::F16) => unsafe {
694 baracuda_kernels_sys::baracuda_kernels_binary_fmax_f16_run(
695 numel, a_ptr, b_ptr, y_ptr,
696 core::ptr::null_mut(), 0, stream_ptr,
697 )
698 },
699 (BinaryKind::Fmax, ElementKind::Bf16) => unsafe {
700 baracuda_kernels_sys::baracuda_kernels_binary_fmax_bf16_run(
701 numel, a_ptr, b_ptr, y_ptr,
702 core::ptr::null_mut(), 0, stream_ptr,
703 )
704 },
705 (BinaryKind::Fmax, ElementKind::F64) => unsafe {
706 baracuda_kernels_sys::baracuda_kernels_binary_fmax_f64_run(
707 numel, a_ptr, b_ptr, y_ptr,
708 core::ptr::null_mut(), 0, stream_ptr,
709 )
710 },
711 (BinaryKind::Maximum, ElementKind::F32) => unsafe {
712 baracuda_kernels_sys::baracuda_kernels_binary_maximum_f32_run(
713 numel, a_ptr, b_ptr, y_ptr,
714 core::ptr::null_mut(), 0, stream_ptr,
715 )
716 },
717 (BinaryKind::Maximum, ElementKind::F16) => unsafe {
718 baracuda_kernels_sys::baracuda_kernels_binary_maximum_f16_run(
719 numel, a_ptr, b_ptr, y_ptr,
720 core::ptr::null_mut(), 0, stream_ptr,
721 )
722 },
723 (BinaryKind::Maximum, ElementKind::Bf16) => unsafe {
724 baracuda_kernels_sys::baracuda_kernels_binary_maximum_bf16_run(
725 numel, a_ptr, b_ptr, y_ptr,
726 core::ptr::null_mut(), 0, stream_ptr,
727 )
728 },
729 (BinaryKind::Maximum, ElementKind::F64) => unsafe {
730 baracuda_kernels_sys::baracuda_kernels_binary_maximum_f64_run(
731 numel, a_ptr, b_ptr, y_ptr,
732 core::ptr::null_mut(), 0, stream_ptr,
733 )
734 },
735 (BinaryKind::Minimum, ElementKind::F32) => unsafe {
736 baracuda_kernels_sys::baracuda_kernels_binary_minimum_f32_run(
737 numel, a_ptr, b_ptr, y_ptr,
738 core::ptr::null_mut(), 0, stream_ptr,
739 )
740 },
741 (BinaryKind::Minimum, ElementKind::F16) => unsafe {
742 baracuda_kernels_sys::baracuda_kernels_binary_minimum_f16_run(
743 numel, a_ptr, b_ptr, y_ptr,
744 core::ptr::null_mut(), 0, stream_ptr,
745 )
746 },
747 (BinaryKind::Minimum, ElementKind::Bf16) => unsafe {
748 baracuda_kernels_sys::baracuda_kernels_binary_minimum_bf16_run(
749 numel, a_ptr, b_ptr, y_ptr,
750 core::ptr::null_mut(), 0, stream_ptr,
751 )
752 },
753 (BinaryKind::Minimum, ElementKind::F64) => unsafe {
754 baracuda_kernels_sys::baracuda_kernels_binary_minimum_f64_run(
755 numel, a_ptr, b_ptr, y_ptr,
756 core::ptr::null_mut(), 0, stream_ptr,
757 )
758 },
759 (BinaryKind::FloorDivide, ElementKind::F32) => unsafe {
760 baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f32_run(
761 numel, a_ptr, b_ptr, y_ptr,
762 core::ptr::null_mut(), 0, stream_ptr,
763 )
764 },
765 (BinaryKind::FloorDivide, ElementKind::F16) => unsafe {
766 baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f16_run(
767 numel, a_ptr, b_ptr, y_ptr,
768 core::ptr::null_mut(), 0, stream_ptr,
769 )
770 },
771 (BinaryKind::FloorDivide, ElementKind::Bf16) => unsafe {
772 baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_bf16_run(
773 numel, a_ptr, b_ptr, y_ptr,
774 core::ptr::null_mut(), 0, stream_ptr,
775 )
776 },
777 (BinaryKind::FloorDivide, ElementKind::F64) => unsafe {
778 baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f64_run(
779 numel, a_ptr, b_ptr, y_ptr,
780 core::ptr::null_mut(), 0, stream_ptr,
781 )
782 },
783 (BinaryKind::Mod, ElementKind::F32) => unsafe {
784 baracuda_kernels_sys::baracuda_kernels_binary_mod_f32_run(
785 numel, a_ptr, b_ptr, y_ptr,
786 core::ptr::null_mut(), 0, stream_ptr,
787 )
788 },
789 (BinaryKind::Mod, ElementKind::F16) => unsafe {
790 baracuda_kernels_sys::baracuda_kernels_binary_mod_f16_run(
791 numel, a_ptr, b_ptr, y_ptr,
792 core::ptr::null_mut(), 0, stream_ptr,
793 )
794 },
795 (BinaryKind::Mod, ElementKind::Bf16) => unsafe {
796 baracuda_kernels_sys::baracuda_kernels_binary_mod_bf16_run(
797 numel, a_ptr, b_ptr, y_ptr,
798 core::ptr::null_mut(), 0, stream_ptr,
799 )
800 },
801 (BinaryKind::Mod, ElementKind::F64) => unsafe {
802 baracuda_kernels_sys::baracuda_kernels_binary_mod_f64_run(
803 numel, a_ptr, b_ptr, y_ptr,
804 core::ptr::null_mut(), 0, stream_ptr,
805 )
806 },
807 (BinaryKind::Remainder, ElementKind::F32) => unsafe {
808 baracuda_kernels_sys::baracuda_kernels_binary_remainder_f32_run(
809 numel, a_ptr, b_ptr, y_ptr,
810 core::ptr::null_mut(), 0, stream_ptr,
811 )
812 },
813 (BinaryKind::Remainder, ElementKind::F16) => unsafe {
814 baracuda_kernels_sys::baracuda_kernels_binary_remainder_f16_run(
815 numel, a_ptr, b_ptr, y_ptr,
816 core::ptr::null_mut(), 0, stream_ptr,
817 )
818 },
819 (BinaryKind::Remainder, ElementKind::Bf16) => unsafe {
820 baracuda_kernels_sys::baracuda_kernels_binary_remainder_bf16_run(
821 numel, a_ptr, b_ptr, y_ptr,
822 core::ptr::null_mut(), 0, stream_ptr,
823 )
824 },
825 (BinaryKind::Remainder, ElementKind::F64) => unsafe {
826 baracuda_kernels_sys::baracuda_kernels_binary_remainder_f64_run(
827 numel, a_ptr, b_ptr, y_ptr,
828 core::ptr::null_mut(), 0, stream_ptr,
829 )
830 },
831 (BinaryKind::BitwiseAnd, ElementKind::I32) => unsafe {
833 baracuda_kernels_sys::baracuda_kernels_binary_bitwise_and_i32_run(
834 numel, a_ptr, b_ptr, y_ptr,
835 core::ptr::null_mut(), 0, stream_ptr,
836 )
837 },
838 (BinaryKind::BitwiseAnd, ElementKind::I64) => unsafe {
839 baracuda_kernels_sys::baracuda_kernels_binary_bitwise_and_i64_run(
840 numel, a_ptr, b_ptr, y_ptr,
841 core::ptr::null_mut(), 0, stream_ptr,
842 )
843 },
844 (BinaryKind::BitwiseOr, ElementKind::I32) => unsafe {
845 baracuda_kernels_sys::baracuda_kernels_binary_bitwise_or_i32_run(
846 numel, a_ptr, b_ptr, y_ptr,
847 core::ptr::null_mut(), 0, stream_ptr,
848 )
849 },
850 (BinaryKind::BitwiseOr, ElementKind::I64) => unsafe {
851 baracuda_kernels_sys::baracuda_kernels_binary_bitwise_or_i64_run(
852 numel, a_ptr, b_ptr, y_ptr,
853 core::ptr::null_mut(), 0, stream_ptr,
854 )
855 },
856 (BinaryKind::BitwiseXor, ElementKind::I32) => unsafe {
857 baracuda_kernels_sys::baracuda_kernels_binary_bitwise_xor_i32_run(
858 numel, a_ptr, b_ptr, y_ptr,
859 core::ptr::null_mut(), 0, stream_ptr,
860 )
861 },
862 (BinaryKind::BitwiseXor, ElementKind::I64) => unsafe {
863 baracuda_kernels_sys::baracuda_kernels_binary_bitwise_xor_i64_run(
864 numel, a_ptr, b_ptr, y_ptr,
865 core::ptr::null_mut(), 0, stream_ptr,
866 )
867 },
868 (BinaryKind::BitwiseLeftShift, ElementKind::I32) => unsafe {
869 baracuda_kernels_sys::baracuda_kernels_binary_bitwise_left_shift_i32_run(
870 numel, a_ptr, b_ptr, y_ptr,
871 core::ptr::null_mut(), 0, stream_ptr,
872 )
873 },
874 (BinaryKind::BitwiseLeftShift, ElementKind::I64) => unsafe {
875 baracuda_kernels_sys::baracuda_kernels_binary_bitwise_left_shift_i64_run(
876 numel, a_ptr, b_ptr, y_ptr,
877 core::ptr::null_mut(), 0, stream_ptr,
878 )
879 },
880 (BinaryKind::BitwiseRightShift, ElementKind::I32) => unsafe {
881 baracuda_kernels_sys::baracuda_kernels_binary_bitwise_right_shift_i32_run(
882 numel, a_ptr, b_ptr, y_ptr,
883 core::ptr::null_mut(), 0, stream_ptr,
884 )
885 },
886 (BinaryKind::BitwiseRightShift, ElementKind::I64) => unsafe {
887 baracuda_kernels_sys::baracuda_kernels_binary_bitwise_right_shift_i64_run(
888 numel, a_ptr, b_ptr, y_ptr,
889 core::ptr::null_mut(), 0, stream_ptr,
890 )
891 },
892 (BinaryKind::LogicalAnd, ElementKind::Bool) => unsafe {
893 baracuda_kernels_sys::baracuda_kernels_binary_logical_and_bool_run(
894 numel, a_ptr, b_ptr, y_ptr,
895 core::ptr::null_mut(), 0, stream_ptr,
896 )
897 },
898 (BinaryKind::LogicalOr, ElementKind::Bool) => unsafe {
899 baracuda_kernels_sys::baracuda_kernels_binary_logical_or_bool_run(
900 numel, a_ptr, b_ptr, y_ptr,
901 core::ptr::null_mut(), 0, stream_ptr,
902 )
903 },
904 (BinaryKind::LogicalXor, ElementKind::Bool) => unsafe {
905 baracuda_kernels_sys::baracuda_kernels_binary_logical_xor_bool_run(
906 numel, a_ptr, b_ptr, y_ptr,
907 core::ptr::null_mut(), 0, stream_ptr,
908 )
909 },
910 _ => {
911 return Err(Error::Unsupported(
912 "baracuda-kernels::BinaryPlan::run reached an unimplemented \
913 (kind, dtype) pair — select() should have caught this",
914 ))
915 }
916 };
917 map_status(status)
918 }
919}
920
921impl<T: Element, const N: usize> BinaryPlan<T, N> {
922 fn run_strided(
930 &self,
931 stream_ptr: *mut c_void,
932 a_ptr: *const c_void,
933 b_ptr: *const c_void,
934 y_ptr: *mut c_void,
935 numel: i64,
936 args: &BinaryArgs<'_, T, N>,
937 ) -> Result<()> {
938 let shape = args.y.shape;
941 let stride_a = args.a.stride;
942 let stride_b = args.b.stride;
943 let stride_y = args.y.stride;
944 let rank = N as i32;
945
946 let status = match (self.desc.kind, T::KIND) {
947 (BinaryKind::Add, ElementKind::F32) => unsafe {
948 baracuda_kernels_sys::baracuda_kernels_binary_add_f32_strided_run(
949 numel,
950 rank,
951 shape.as_ptr(),
952 stride_a.as_ptr(),
953 stride_b.as_ptr(),
954 stride_y.as_ptr(),
955 a_ptr,
956 b_ptr,
957 y_ptr,
958 core::ptr::null_mut(),
959 0,
960 stream_ptr,
961 )
962 },
963 (BinaryKind::Add, ElementKind::F16) => unsafe {
964 baracuda_kernels_sys::baracuda_kernels_binary_add_f16_strided_run(
965 numel,
966 rank,
967 shape.as_ptr(),
968 stride_a.as_ptr(),
969 stride_b.as_ptr(),
970 stride_y.as_ptr(),
971 a_ptr,
972 b_ptr,
973 y_ptr,
974 core::ptr::null_mut(),
975 0,
976 stream_ptr,
977 )
978 },
979 (BinaryKind::Add, ElementKind::Bf16) => unsafe {
980 baracuda_kernels_sys::baracuda_kernels_binary_add_bf16_strided_run(
981 numel,
982 rank,
983 shape.as_ptr(),
984 stride_a.as_ptr(),
985 stride_b.as_ptr(),
986 stride_y.as_ptr(),
987 a_ptr,
988 b_ptr,
989 y_ptr,
990 core::ptr::null_mut(),
991 0,
992 stream_ptr,
993 )
994 },
995 (BinaryKind::Add, ElementKind::F64) => unsafe {
996 baracuda_kernels_sys::baracuda_kernels_binary_add_f64_strided_run(
997 numel,
998 rank,
999 shape.as_ptr(),
1000 stride_a.as_ptr(),
1001 stride_b.as_ptr(),
1002 stride_y.as_ptr(),
1003 a_ptr,
1004 b_ptr,
1005 y_ptr,
1006 core::ptr::null_mut(),
1007 0,
1008 stream_ptr,
1009 )
1010 },
1011 (BinaryKind::Sub, ElementKind::F32) => unsafe {
1012 baracuda_kernels_sys::baracuda_kernels_binary_sub_f32_strided_run(
1013 numel,
1014 rank,
1015 shape.as_ptr(),
1016 stride_a.as_ptr(),
1017 stride_b.as_ptr(),
1018 stride_y.as_ptr(),
1019 a_ptr,
1020 b_ptr,
1021 y_ptr,
1022 core::ptr::null_mut(),
1023 0,
1024 stream_ptr,
1025 )
1026 },
1027 (BinaryKind::Sub, ElementKind::F16) => unsafe {
1028 baracuda_kernels_sys::baracuda_kernels_binary_sub_f16_strided_run(
1029 numel,
1030 rank,
1031 shape.as_ptr(),
1032 stride_a.as_ptr(),
1033 stride_b.as_ptr(),
1034 stride_y.as_ptr(),
1035 a_ptr,
1036 b_ptr,
1037 y_ptr,
1038 core::ptr::null_mut(),
1039 0,
1040 stream_ptr,
1041 )
1042 },
1043 (BinaryKind::Sub, ElementKind::Bf16) => unsafe {
1044 baracuda_kernels_sys::baracuda_kernels_binary_sub_bf16_strided_run(
1045 numel,
1046 rank,
1047 shape.as_ptr(),
1048 stride_a.as_ptr(),
1049 stride_b.as_ptr(),
1050 stride_y.as_ptr(),
1051 a_ptr,
1052 b_ptr,
1053 y_ptr,
1054 core::ptr::null_mut(),
1055 0,
1056 stream_ptr,
1057 )
1058 },
1059 (BinaryKind::Sub, ElementKind::F64) => unsafe {
1060 baracuda_kernels_sys::baracuda_kernels_binary_sub_f64_strided_run(
1061 numel,
1062 rank,
1063 shape.as_ptr(),
1064 stride_a.as_ptr(),
1065 stride_b.as_ptr(),
1066 stride_y.as_ptr(),
1067 a_ptr,
1068 b_ptr,
1069 y_ptr,
1070 core::ptr::null_mut(),
1071 0,
1072 stream_ptr,
1073 )
1074 },
1075 (BinaryKind::Mul, ElementKind::F32) => unsafe {
1076 baracuda_kernels_sys::baracuda_kernels_binary_mul_f32_strided_run(
1077 numel,
1078 rank,
1079 shape.as_ptr(),
1080 stride_a.as_ptr(),
1081 stride_b.as_ptr(),
1082 stride_y.as_ptr(),
1083 a_ptr,
1084 b_ptr,
1085 y_ptr,
1086 core::ptr::null_mut(),
1087 0,
1088 stream_ptr,
1089 )
1090 },
1091 (BinaryKind::Mul, ElementKind::F16) => unsafe {
1092 baracuda_kernels_sys::baracuda_kernels_binary_mul_f16_strided_run(
1093 numel,
1094 rank,
1095 shape.as_ptr(),
1096 stride_a.as_ptr(),
1097 stride_b.as_ptr(),
1098 stride_y.as_ptr(),
1099 a_ptr,
1100 b_ptr,
1101 y_ptr,
1102 core::ptr::null_mut(),
1103 0,
1104 stream_ptr,
1105 )
1106 },
1107 (BinaryKind::Mul, ElementKind::Bf16) => unsafe {
1108 baracuda_kernels_sys::baracuda_kernels_binary_mul_bf16_strided_run(
1109 numel,
1110 rank,
1111 shape.as_ptr(),
1112 stride_a.as_ptr(),
1113 stride_b.as_ptr(),
1114 stride_y.as_ptr(),
1115 a_ptr,
1116 b_ptr,
1117 y_ptr,
1118 core::ptr::null_mut(),
1119 0,
1120 stream_ptr,
1121 )
1122 },
1123 (BinaryKind::Mul, ElementKind::F64) => unsafe {
1124 baracuda_kernels_sys::baracuda_kernels_binary_mul_f64_strided_run(
1125 numel,
1126 rank,
1127 shape.as_ptr(),
1128 stride_a.as_ptr(),
1129 stride_b.as_ptr(),
1130 stride_y.as_ptr(),
1131 a_ptr,
1132 b_ptr,
1133 y_ptr,
1134 core::ptr::null_mut(),
1135 0,
1136 stream_ptr,
1137 )
1138 },
1139 (BinaryKind::Div, ElementKind::F32) => unsafe {
1140 baracuda_kernels_sys::baracuda_kernels_binary_div_f32_strided_run(
1141 numel,
1142 rank,
1143 shape.as_ptr(),
1144 stride_a.as_ptr(),
1145 stride_b.as_ptr(),
1146 stride_y.as_ptr(),
1147 a_ptr,
1148 b_ptr,
1149 y_ptr,
1150 core::ptr::null_mut(),
1151 0,
1152 stream_ptr,
1153 )
1154 },
1155 (BinaryKind::Div, ElementKind::F16) => unsafe {
1156 baracuda_kernels_sys::baracuda_kernels_binary_div_f16_strided_run(
1157 numel,
1158 rank,
1159 shape.as_ptr(),
1160 stride_a.as_ptr(),
1161 stride_b.as_ptr(),
1162 stride_y.as_ptr(),
1163 a_ptr,
1164 b_ptr,
1165 y_ptr,
1166 core::ptr::null_mut(),
1167 0,
1168 stream_ptr,
1169 )
1170 },
1171 (BinaryKind::Div, ElementKind::Bf16) => unsafe {
1172 baracuda_kernels_sys::baracuda_kernels_binary_div_bf16_strided_run(
1173 numel,
1174 rank,
1175 shape.as_ptr(),
1176 stride_a.as_ptr(),
1177 stride_b.as_ptr(),
1178 stride_y.as_ptr(),
1179 a_ptr,
1180 b_ptr,
1181 y_ptr,
1182 core::ptr::null_mut(),
1183 0,
1184 stream_ptr,
1185 )
1186 },
1187 (BinaryKind::Div, ElementKind::F64) => unsafe {
1188 baracuda_kernels_sys::baracuda_kernels_binary_div_f64_strided_run(
1189 numel,
1190 rank,
1191 shape.as_ptr(),
1192 stride_a.as_ptr(),
1193 stride_b.as_ptr(),
1194 stride_y.as_ptr(),
1195 a_ptr,
1196 b_ptr,
1197 y_ptr,
1198 core::ptr::null_mut(),
1199 0,
1200 stream_ptr,
1201 )
1202 },
1203 (BinaryKind::Pow, ElementKind::F32) => unsafe {
1204 baracuda_kernels_sys::baracuda_kernels_binary_pow_f32_strided_run(
1205 numel, rank, shape.as_ptr(),
1206 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1207 a_ptr, b_ptr, y_ptr,
1208 core::ptr::null_mut(), 0, stream_ptr,
1209 )
1210 },
1211 (BinaryKind::Pow, ElementKind::F16) => unsafe {
1212 baracuda_kernels_sys::baracuda_kernels_binary_pow_f16_strided_run(
1213 numel, rank, shape.as_ptr(),
1214 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1215 a_ptr, b_ptr, y_ptr,
1216 core::ptr::null_mut(), 0, stream_ptr,
1217 )
1218 },
1219 (BinaryKind::Pow, ElementKind::Bf16) => unsafe {
1220 baracuda_kernels_sys::baracuda_kernels_binary_pow_bf16_strided_run(
1221 numel, rank, shape.as_ptr(),
1222 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1223 a_ptr, b_ptr, y_ptr,
1224 core::ptr::null_mut(), 0, stream_ptr,
1225 )
1226 },
1227 (BinaryKind::Pow, ElementKind::F64) => unsafe {
1228 baracuda_kernels_sys::baracuda_kernels_binary_pow_f64_strided_run(
1229 numel, rank, shape.as_ptr(),
1230 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1231 a_ptr, b_ptr, y_ptr,
1232 core::ptr::null_mut(), 0, stream_ptr,
1233 )
1234 },
1235 (BinaryKind::Atan2, ElementKind::F32) => unsafe {
1236 baracuda_kernels_sys::baracuda_kernels_binary_atan2_f32_strided_run(
1237 numel, rank, shape.as_ptr(),
1238 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1239 a_ptr, b_ptr, y_ptr,
1240 core::ptr::null_mut(), 0, stream_ptr,
1241 )
1242 },
1243 (BinaryKind::Atan2, ElementKind::F16) => unsafe {
1244 baracuda_kernels_sys::baracuda_kernels_binary_atan2_f16_strided_run(
1245 numel, rank, shape.as_ptr(),
1246 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1247 a_ptr, b_ptr, y_ptr,
1248 core::ptr::null_mut(), 0, stream_ptr,
1249 )
1250 },
1251 (BinaryKind::Atan2, ElementKind::Bf16) => unsafe {
1252 baracuda_kernels_sys::baracuda_kernels_binary_atan2_bf16_strided_run(
1253 numel, rank, shape.as_ptr(),
1254 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1255 a_ptr, b_ptr, y_ptr,
1256 core::ptr::null_mut(), 0, stream_ptr,
1257 )
1258 },
1259 (BinaryKind::Atan2, ElementKind::F64) => unsafe {
1260 baracuda_kernels_sys::baracuda_kernels_binary_atan2_f64_strided_run(
1261 numel, rank, shape.as_ptr(),
1262 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1263 a_ptr, b_ptr, y_ptr,
1264 core::ptr::null_mut(), 0, stream_ptr,
1265 )
1266 },
1267 (BinaryKind::Hypot, ElementKind::F32) => unsafe {
1268 baracuda_kernels_sys::baracuda_kernels_binary_hypot_f32_strided_run(
1269 numel, rank, shape.as_ptr(),
1270 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1271 a_ptr, b_ptr, y_ptr,
1272 core::ptr::null_mut(), 0, stream_ptr,
1273 )
1274 },
1275 (BinaryKind::Hypot, ElementKind::F16) => unsafe {
1276 baracuda_kernels_sys::baracuda_kernels_binary_hypot_f16_strided_run(
1277 numel, rank, shape.as_ptr(),
1278 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1279 a_ptr, b_ptr, y_ptr,
1280 core::ptr::null_mut(), 0, stream_ptr,
1281 )
1282 },
1283 (BinaryKind::Hypot, ElementKind::Bf16) => unsafe {
1284 baracuda_kernels_sys::baracuda_kernels_binary_hypot_bf16_strided_run(
1285 numel, rank, shape.as_ptr(),
1286 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1287 a_ptr, b_ptr, y_ptr,
1288 core::ptr::null_mut(), 0, stream_ptr,
1289 )
1290 },
1291 (BinaryKind::Hypot, ElementKind::F64) => unsafe {
1292 baracuda_kernels_sys::baracuda_kernels_binary_hypot_f64_strided_run(
1293 numel, rank, shape.as_ptr(),
1294 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1295 a_ptr, b_ptr, y_ptr,
1296 core::ptr::null_mut(), 0, stream_ptr,
1297 )
1298 },
1299 (BinaryKind::Copysign, ElementKind::F32) => unsafe {
1300 baracuda_kernels_sys::baracuda_kernels_binary_copysign_f32_strided_run(
1301 numel, rank, shape.as_ptr(),
1302 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1303 a_ptr, b_ptr, y_ptr,
1304 core::ptr::null_mut(), 0, stream_ptr,
1305 )
1306 },
1307 (BinaryKind::Copysign, ElementKind::F16) => unsafe {
1308 baracuda_kernels_sys::baracuda_kernels_binary_copysign_f16_strided_run(
1309 numel, rank, shape.as_ptr(),
1310 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1311 a_ptr, b_ptr, y_ptr,
1312 core::ptr::null_mut(), 0, stream_ptr,
1313 )
1314 },
1315 (BinaryKind::Copysign, ElementKind::Bf16) => unsafe {
1316 baracuda_kernels_sys::baracuda_kernels_binary_copysign_bf16_strided_run(
1317 numel, rank, shape.as_ptr(),
1318 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1319 a_ptr, b_ptr, y_ptr,
1320 core::ptr::null_mut(), 0, stream_ptr,
1321 )
1322 },
1323 (BinaryKind::Copysign, ElementKind::F64) => unsafe {
1324 baracuda_kernels_sys::baracuda_kernels_binary_copysign_f64_strided_run(
1325 numel, rank, shape.as_ptr(),
1326 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1327 a_ptr, b_ptr, y_ptr,
1328 core::ptr::null_mut(), 0, stream_ptr,
1329 )
1330 },
1331 (BinaryKind::Nextafter, ElementKind::F32) => unsafe {
1332 baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f32_strided_run(
1333 numel, rank, shape.as_ptr(),
1334 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1335 a_ptr, b_ptr, y_ptr,
1336 core::ptr::null_mut(), 0, stream_ptr,
1337 )
1338 },
1339 (BinaryKind::Nextafter, ElementKind::F16) => unsafe {
1340 baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f16_strided_run(
1341 numel, rank, shape.as_ptr(),
1342 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1343 a_ptr, b_ptr, y_ptr,
1344 core::ptr::null_mut(), 0, stream_ptr,
1345 )
1346 },
1347 (BinaryKind::Nextafter, ElementKind::Bf16) => unsafe {
1348 baracuda_kernels_sys::baracuda_kernels_binary_nextafter_bf16_strided_run(
1349 numel, rank, shape.as_ptr(),
1350 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1351 a_ptr, b_ptr, y_ptr,
1352 core::ptr::null_mut(), 0, stream_ptr,
1353 )
1354 },
1355 (BinaryKind::Nextafter, ElementKind::F64) => unsafe {
1356 baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f64_strided_run(
1357 numel, rank, shape.as_ptr(),
1358 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1359 a_ptr, b_ptr, y_ptr,
1360 core::ptr::null_mut(), 0, stream_ptr,
1361 )
1362 },
1363 (BinaryKind::Fmin, ElementKind::F32) => unsafe {
1364 baracuda_kernels_sys::baracuda_kernels_binary_fmin_f32_strided_run(
1365 numel, rank, shape.as_ptr(),
1366 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1367 a_ptr, b_ptr, y_ptr,
1368 core::ptr::null_mut(), 0, stream_ptr,
1369 )
1370 },
1371 (BinaryKind::Fmin, ElementKind::F16) => unsafe {
1372 baracuda_kernels_sys::baracuda_kernels_binary_fmin_f16_strided_run(
1373 numel, rank, shape.as_ptr(),
1374 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1375 a_ptr, b_ptr, y_ptr,
1376 core::ptr::null_mut(), 0, stream_ptr,
1377 )
1378 },
1379 (BinaryKind::Fmin, ElementKind::Bf16) => unsafe {
1380 baracuda_kernels_sys::baracuda_kernels_binary_fmin_bf16_strided_run(
1381 numel, rank, shape.as_ptr(),
1382 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1383 a_ptr, b_ptr, y_ptr,
1384 core::ptr::null_mut(), 0, stream_ptr,
1385 )
1386 },
1387 (BinaryKind::Fmin, ElementKind::F64) => unsafe {
1388 baracuda_kernels_sys::baracuda_kernels_binary_fmin_f64_strided_run(
1389 numel, rank, shape.as_ptr(),
1390 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1391 a_ptr, b_ptr, y_ptr,
1392 core::ptr::null_mut(), 0, stream_ptr,
1393 )
1394 },
1395 (BinaryKind::Fmax, ElementKind::F32) => unsafe {
1396 baracuda_kernels_sys::baracuda_kernels_binary_fmax_f32_strided_run(
1397 numel, rank, shape.as_ptr(),
1398 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1399 a_ptr, b_ptr, y_ptr,
1400 core::ptr::null_mut(), 0, stream_ptr,
1401 )
1402 },
1403 (BinaryKind::Fmax, ElementKind::F16) => unsafe {
1404 baracuda_kernels_sys::baracuda_kernels_binary_fmax_f16_strided_run(
1405 numel, rank, shape.as_ptr(),
1406 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1407 a_ptr, b_ptr, y_ptr,
1408 core::ptr::null_mut(), 0, stream_ptr,
1409 )
1410 },
1411 (BinaryKind::Fmax, ElementKind::Bf16) => unsafe {
1412 baracuda_kernels_sys::baracuda_kernels_binary_fmax_bf16_strided_run(
1413 numel, rank, shape.as_ptr(),
1414 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1415 a_ptr, b_ptr, y_ptr,
1416 core::ptr::null_mut(), 0, stream_ptr,
1417 )
1418 },
1419 (BinaryKind::Fmax, ElementKind::F64) => unsafe {
1420 baracuda_kernels_sys::baracuda_kernels_binary_fmax_f64_strided_run(
1421 numel, rank, shape.as_ptr(),
1422 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1423 a_ptr, b_ptr, y_ptr,
1424 core::ptr::null_mut(), 0, stream_ptr,
1425 )
1426 },
1427 (BinaryKind::Maximum, ElementKind::F32) => unsafe {
1428 baracuda_kernels_sys::baracuda_kernels_binary_maximum_f32_strided_run(
1429 numel, rank, shape.as_ptr(),
1430 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1431 a_ptr, b_ptr, y_ptr,
1432 core::ptr::null_mut(), 0, stream_ptr,
1433 )
1434 },
1435 (BinaryKind::Maximum, ElementKind::F16) => unsafe {
1436 baracuda_kernels_sys::baracuda_kernels_binary_maximum_f16_strided_run(
1437 numel, rank, shape.as_ptr(),
1438 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1439 a_ptr, b_ptr, y_ptr,
1440 core::ptr::null_mut(), 0, stream_ptr,
1441 )
1442 },
1443 (BinaryKind::Maximum, ElementKind::Bf16) => unsafe {
1444 baracuda_kernels_sys::baracuda_kernels_binary_maximum_bf16_strided_run(
1445 numel, rank, shape.as_ptr(),
1446 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1447 a_ptr, b_ptr, y_ptr,
1448 core::ptr::null_mut(), 0, stream_ptr,
1449 )
1450 },
1451 (BinaryKind::Maximum, ElementKind::F64) => unsafe {
1452 baracuda_kernels_sys::baracuda_kernels_binary_maximum_f64_strided_run(
1453 numel, rank, shape.as_ptr(),
1454 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1455 a_ptr, b_ptr, y_ptr,
1456 core::ptr::null_mut(), 0, stream_ptr,
1457 )
1458 },
1459 (BinaryKind::Minimum, ElementKind::F32) => unsafe {
1460 baracuda_kernels_sys::baracuda_kernels_binary_minimum_f32_strided_run(
1461 numel, rank, shape.as_ptr(),
1462 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1463 a_ptr, b_ptr, y_ptr,
1464 core::ptr::null_mut(), 0, stream_ptr,
1465 )
1466 },
1467 (BinaryKind::Minimum, ElementKind::F16) => unsafe {
1468 baracuda_kernels_sys::baracuda_kernels_binary_minimum_f16_strided_run(
1469 numel, rank, shape.as_ptr(),
1470 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1471 a_ptr, b_ptr, y_ptr,
1472 core::ptr::null_mut(), 0, stream_ptr,
1473 )
1474 },
1475 (BinaryKind::Minimum, ElementKind::Bf16) => unsafe {
1476 baracuda_kernels_sys::baracuda_kernels_binary_minimum_bf16_strided_run(
1477 numel, rank, shape.as_ptr(),
1478 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1479 a_ptr, b_ptr, y_ptr,
1480 core::ptr::null_mut(), 0, stream_ptr,
1481 )
1482 },
1483 (BinaryKind::Minimum, ElementKind::F64) => unsafe {
1484 baracuda_kernels_sys::baracuda_kernels_binary_minimum_f64_strided_run(
1485 numel, rank, shape.as_ptr(),
1486 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1487 a_ptr, b_ptr, y_ptr,
1488 core::ptr::null_mut(), 0, stream_ptr,
1489 )
1490 },
1491 (BinaryKind::FloorDivide, ElementKind::F32) => unsafe {
1492 baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f32_strided_run(
1493 numel, rank, shape.as_ptr(),
1494 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1495 a_ptr, b_ptr, y_ptr,
1496 core::ptr::null_mut(), 0, stream_ptr,
1497 )
1498 },
1499 (BinaryKind::FloorDivide, ElementKind::F16) => unsafe {
1500 baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f16_strided_run(
1501 numel, rank, shape.as_ptr(),
1502 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1503 a_ptr, b_ptr, y_ptr,
1504 core::ptr::null_mut(), 0, stream_ptr,
1505 )
1506 },
1507 (BinaryKind::FloorDivide, ElementKind::Bf16) => unsafe {
1508 baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_bf16_strided_run(
1509 numel, rank, shape.as_ptr(),
1510 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1511 a_ptr, b_ptr, y_ptr,
1512 core::ptr::null_mut(), 0, stream_ptr,
1513 )
1514 },
1515 (BinaryKind::FloorDivide, ElementKind::F64) => unsafe {
1516 baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f64_strided_run(
1517 numel, rank, shape.as_ptr(),
1518 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1519 a_ptr, b_ptr, y_ptr,
1520 core::ptr::null_mut(), 0, stream_ptr,
1521 )
1522 },
1523 (BinaryKind::Mod, ElementKind::F32) => unsafe {
1524 baracuda_kernels_sys::baracuda_kernels_binary_mod_f32_strided_run(
1525 numel, rank, shape.as_ptr(),
1526 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1527 a_ptr, b_ptr, y_ptr,
1528 core::ptr::null_mut(), 0, stream_ptr,
1529 )
1530 },
1531 (BinaryKind::Mod, ElementKind::F16) => unsafe {
1532 baracuda_kernels_sys::baracuda_kernels_binary_mod_f16_strided_run(
1533 numel, rank, shape.as_ptr(),
1534 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1535 a_ptr, b_ptr, y_ptr,
1536 core::ptr::null_mut(), 0, stream_ptr,
1537 )
1538 },
1539 (BinaryKind::Mod, ElementKind::Bf16) => unsafe {
1540 baracuda_kernels_sys::baracuda_kernels_binary_mod_bf16_strided_run(
1541 numel, rank, shape.as_ptr(),
1542 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1543 a_ptr, b_ptr, y_ptr,
1544 core::ptr::null_mut(), 0, stream_ptr,
1545 )
1546 },
1547 (BinaryKind::Mod, ElementKind::F64) => unsafe {
1548 baracuda_kernels_sys::baracuda_kernels_binary_mod_f64_strided_run(
1549 numel, rank, shape.as_ptr(),
1550 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1551 a_ptr, b_ptr, y_ptr,
1552 core::ptr::null_mut(), 0, stream_ptr,
1553 )
1554 },
1555 (BinaryKind::Remainder, ElementKind::F32) => unsafe {
1556 baracuda_kernels_sys::baracuda_kernels_binary_remainder_f32_strided_run(
1557 numel, rank, shape.as_ptr(),
1558 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1559 a_ptr, b_ptr, y_ptr,
1560 core::ptr::null_mut(), 0, stream_ptr,
1561 )
1562 },
1563 (BinaryKind::Remainder, ElementKind::F16) => unsafe {
1564 baracuda_kernels_sys::baracuda_kernels_binary_remainder_f16_strided_run(
1565 numel, rank, shape.as_ptr(),
1566 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1567 a_ptr, b_ptr, y_ptr,
1568 core::ptr::null_mut(), 0, stream_ptr,
1569 )
1570 },
1571 (BinaryKind::Remainder, ElementKind::Bf16) => unsafe {
1572 baracuda_kernels_sys::baracuda_kernels_binary_remainder_bf16_strided_run(
1573 numel, rank, shape.as_ptr(),
1574 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1575 a_ptr, b_ptr, y_ptr,
1576 core::ptr::null_mut(), 0, stream_ptr,
1577 )
1578 },
1579 (BinaryKind::Remainder, ElementKind::F64) => unsafe {
1580 baracuda_kernels_sys::baracuda_kernels_binary_remainder_f64_strided_run(
1581 numel, rank, shape.as_ptr(),
1582 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1583 a_ptr, b_ptr, y_ptr,
1584 core::ptr::null_mut(), 0, stream_ptr,
1585 )
1586 },
1587 _ => {
1588 return Err(Error::Unsupported(
1589 "baracuda-kernels::BinaryPlan::run_strided reached an \
1590 unimplemented (kind, dtype) pair — select() should \
1591 have caught this",
1592 ));
1593 }
1594 };
1595 map_status(status)
1596 }
1597}
1598
1599fn map_status(code: i32) -> Result<()> {
1600 match code {
1601 0 => Ok(()),
1602 1 => Err(Error::MisalignedOperand),
1603 2 => Err(Error::InvalidProblem(
1604 "baracuda-kernels-sys reported invalid problem",
1605 )),
1606 3 => Err(Error::Unsupported(
1607 "baracuda-kernels-sys reported unsupported configuration",
1608 )),
1609 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
1610 n => Err(Error::CutlassInternal(n)),
1611 }
1612}