Skip to main content

baracuda_kernels/elementwise/
binary.rs

1//! Binary elementwise plan.
2//!
3//! Phase 3 trailblazer surface for the baracuda-kernels elementwise op
4//! family (category C from the comprehensive plan). Mirrors the shape
5//! of [`crate::IntGemmPlan`] (descriptor + args + select/can_implement/
6//! run/sku/precision_guarantee) but for arbitrary-rank tensors with no
7//! GEMM-style accumulator / epilogue chain.
8//!
//! The `Add` op on `f32` over fully-contiguous tensors of matching
//! shape was the Phase 3 trailblazer SKU, and its instantiation in
//! `baracuda-kernels-sys` is the template pattern the other SKUs
//! follow. Further binary ops ([`BinaryKind::Sub`], `Mul`, `Div`, …)
//! and dtypes (`f16`, `bf16`, `f64`, plus an integer / bool family)
//! have since joined in fanout sessions; the authoritative
//! `(kind, dtype)` matrix is the one checked in `BinaryPlan::select`.
15//!
16//! Broadcasting is supported: operands with `stride[d] = 0` on a
17//! broadcast axis route through a strided kernel path that handles
18//! arbitrary per-axis stride (broadcast, transposed views, arbitrary
19//! strided slices). The dispatcher picks contig vs strided at run
20//! time based on `is_contiguous()` of all three operands.
21
22use core::ffi::c_void;
23use core::marker::PhantomData;
24
25use baracuda_cutlass::{Error, Result};
26use baracuda_driver::Stream;
27use baracuda_kernels_types::{
28    ArchSku, BackendKind, BinaryKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
29    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
30};
31
32/// Descriptor for a binary elementwise op.
33///
34/// `shape` describes the **output** tensor shape (== both input shapes
35/// after the caller-side rank-normalization convention — see the crate
36/// docs for the broadcasting contract). `element` must match `T::KIND`
37/// at `select` time.
38#[derive(Copy, Clone, Debug)]
39pub struct BinaryDescriptor<const N: usize> {
40    /// Which binary op to apply.
41    pub kind: BinaryKind,
42    /// Output tensor shape (`= a.shape = b.shape` for the contig case).
43    pub shape: [i32; N],
44    /// Primary element type. Must match the type parameter `T` of the
45    /// containing plan.
46    pub element: ElementKind,
47}
48
49/// Args bundle for a binary elementwise launch.
50///
51/// Lifetime `'a` and rank `N` are shared across all three tensors; the
52/// element type `T` is shared too (heterogeneous-dtype ops like
53/// `compare(f32, f32) -> bool` use a future `BinaryWithOutDtypePlan`,
54/// not this one).
55pub struct BinaryArgs<'a, T: Element, const N: usize> {
56    /// First input.
57    pub a: TensorRef<'a, T, N>,
58    /// Second input.
59    pub b: TensorRef<'a, T, N>,
60    /// Output. Aliasing with either input is allowed (in-place add).
61    pub y: TensorMut<'a, T, N>,
62}
63
64/// Binary elementwise plan.
65///
66/// `T: Element` is the kernel's element type (today: must be `f32`).
67/// `const N: usize` is the tensor rank — fixed at compile time to keep
68/// the descriptor heap-free and the rank invariants type-checked.
69pub struct BinaryPlan<T: Element, const N: usize> {
70    desc: BinaryDescriptor<N>,
71    sku: KernelSku,
72    _marker: PhantomData<T>,
73}
74
75impl<T: Element, const N: usize> BinaryPlan<T, N> {
76    /// Pick a kernel for `desc`. Returns [`Error::Unsupported`] if the
77    /// `(kind, T::KIND)` pair isn't wired today.
78    pub fn select(
79        _stream: &Stream,
80        desc: &BinaryDescriptor<N>,
81        _pref: PlanPreference,
82    ) -> Result<Self> {
83        if desc.element != T::KIND {
84            return Err(Error::Unsupported(
85                "baracuda-kernels::BinaryPlan: descriptor element != type parameter T",
86            ));
87        }
88        for &d in desc.shape.iter() {
89            if d < 0 {
90                return Err(Error::InvalidProblem(
91                    "baracuda-kernels::BinaryPlan: shape dims must be non-negative",
92                ));
93            }
94        }
95
96        // Trailblazer + op-fanout + dtype-fanout matrix: {Add, Sub, Mul,
97        // Div, Pow, Atan2, Hypot} × {F32, F16, Bf16, F64}. Other (kind,
98        // dtype) combinations are reserved discriminants today (Eq /
99        // comparison family, Lerp, ... and the integer dtype family).
100        //
101        // Lerp is reserved-but-deferred: it takes a scalar `weight: f32`
102        // alongside its two tensor inputs, which doesn't fit this Plan's
103        // `BinaryArgs<a, b, y>` shape. A parameterized-binary plan shape
104        // (analogous to the ternary `addcmul`/`addcdiv` family) will
105        // host Lerp in a later milestone.
106        let supported = matches!(
107            (desc.kind, T::KIND),
108            (BinaryKind::Add, ElementKind::F32)
109                | (BinaryKind::Add, ElementKind::F16)
110                | (BinaryKind::Add, ElementKind::Bf16)
111                | (BinaryKind::Add, ElementKind::F64)
112                | (BinaryKind::Sub, ElementKind::F32)
113                | (BinaryKind::Sub, ElementKind::F16)
114                | (BinaryKind::Sub, ElementKind::Bf16)
115                | (BinaryKind::Sub, ElementKind::F64)
116                | (BinaryKind::Mul, ElementKind::F32)
117                | (BinaryKind::Mul, ElementKind::F16)
118                | (BinaryKind::Mul, ElementKind::Bf16)
119                | (BinaryKind::Mul, ElementKind::F64)
120                | (BinaryKind::Div, ElementKind::F32)
121                | (BinaryKind::Div, ElementKind::F16)
122                | (BinaryKind::Div, ElementKind::Bf16)
123                | (BinaryKind::Div, ElementKind::F64)
124                | (BinaryKind::Pow, ElementKind::F32)
125                | (BinaryKind::Pow, ElementKind::F16)
126                | (BinaryKind::Pow, ElementKind::Bf16)
127                | (BinaryKind::Pow, ElementKind::F64)
128                | (BinaryKind::Atan2, ElementKind::F32)
129                | (BinaryKind::Atan2, ElementKind::F16)
130                | (BinaryKind::Atan2, ElementKind::Bf16)
131                | (BinaryKind::Atan2, ElementKind::F64)
132                | (BinaryKind::Hypot, ElementKind::F32)
133                | (BinaryKind::Hypot, ElementKind::F16)
134                | (BinaryKind::Hypot, ElementKind::Bf16)
135                | (BinaryKind::Hypot, ElementKind::F64)
136                | (BinaryKind::Copysign, ElementKind::F32)
137                | (BinaryKind::Copysign, ElementKind::F16)
138                | (BinaryKind::Copysign, ElementKind::Bf16)
139                | (BinaryKind::Copysign, ElementKind::F64)
140                | (BinaryKind::Nextafter, ElementKind::F32)
141                | (BinaryKind::Nextafter, ElementKind::F16)
142                | (BinaryKind::Nextafter, ElementKind::Bf16)
143                | (BinaryKind::Nextafter, ElementKind::F64)
144                | (BinaryKind::Fmin, ElementKind::F32)
145                | (BinaryKind::Fmin, ElementKind::F16)
146                | (BinaryKind::Fmin, ElementKind::Bf16)
147                | (BinaryKind::Fmin, ElementKind::F64)
148                | (BinaryKind::Fmax, ElementKind::F32)
149                | (BinaryKind::Fmax, ElementKind::F16)
150                | (BinaryKind::Fmax, ElementKind::Bf16)
151                | (BinaryKind::Fmax, ElementKind::F64)
152                | (BinaryKind::Maximum, ElementKind::F32)
153                | (BinaryKind::Maximum, ElementKind::F16)
154                | (BinaryKind::Maximum, ElementKind::Bf16)
155                | (BinaryKind::Maximum, ElementKind::F64)
156                | (BinaryKind::Minimum, ElementKind::F32)
157                | (BinaryKind::Minimum, ElementKind::F16)
158                | (BinaryKind::Minimum, ElementKind::Bf16)
159                | (BinaryKind::Minimum, ElementKind::F64)
160                | (BinaryKind::FloorDivide, ElementKind::F32)
161                | (BinaryKind::FloorDivide, ElementKind::F16)
162                | (BinaryKind::FloorDivide, ElementKind::Bf16)
163                | (BinaryKind::FloorDivide, ElementKind::F64)
164                | (BinaryKind::Mod, ElementKind::F32)
165                | (BinaryKind::Mod, ElementKind::F16)
166                | (BinaryKind::Mod, ElementKind::Bf16)
167                | (BinaryKind::Mod, ElementKind::F64)
168                | (BinaryKind::Remainder, ElementKind::F32)
169                | (BinaryKind::Remainder, ElementKind::F16)
170                | (BinaryKind::Remainder, ElementKind::Bf16)
171                | (BinaryKind::Remainder, ElementKind::F64)
172                // Phase 3.3 integer + bool fanout. Five bitwise ops
173                // across {i32, i64} + three logical ops across Bool.
174                // Contig only — strided / broadcast deferred.
175                | (BinaryKind::BitwiseAnd, ElementKind::I32)
176                | (BinaryKind::BitwiseAnd, ElementKind::I64)
177                | (BinaryKind::BitwiseOr, ElementKind::I32)
178                | (BinaryKind::BitwiseOr, ElementKind::I64)
179                | (BinaryKind::BitwiseXor, ElementKind::I32)
180                | (BinaryKind::BitwiseXor, ElementKind::I64)
181                | (BinaryKind::BitwiseLeftShift, ElementKind::I32)
182                | (BinaryKind::BitwiseLeftShift, ElementKind::I64)
183                | (BinaryKind::BitwiseRightShift, ElementKind::I32)
184                | (BinaryKind::BitwiseRightShift, ElementKind::I64)
185                | (BinaryKind::LogicalAnd, ElementKind::Bool)
186                | (BinaryKind::LogicalOr, ElementKind::Bool)
187                | (BinaryKind::LogicalXor, ElementKind::Bool)
188        );
189        if !supported {
190            return Err(Error::Unsupported(
191                "baracuda-kernels::BinaryPlan: today only \
192                 `{Add,Sub,Mul,Div,Pow,Atan2,Hypot,Copysign,Nextafter,Fmin,Fmax,\
193                 Maximum,Minimum,FloorDivide,Mod,Remainder}` \
194                 × `{f32, f16, bf16, f64}` + Phase 3.3 integer / bool fanout \
195                 (`{BitwiseAnd,BitwiseOr,BitwiseXor,BitwiseLeftShift,\
196                 BitwiseRightShift}` × `{i32, i64}` and \
197                 `{LogicalAnd,LogicalOr,LogicalXor}` × Bool — contig only); \
198                 other (kind, dtype) pairs land in fanout sessions. Lerp is \
199                 reserved-but-deferred pending a parameterized-binary plan \
200                 shape.",
201            ));
202        }
203
204        // The chosen kernel is arch-agnostic SIMT (CUDA cores, no tensor
205        // cores). PrecisionGuarantee mirrors what `F32Strict` GEMM
206        // reports: full IEEE 754 binary32, bit-stable on the same
207        // hardware, deterministic across runs (no atomic accumulation,
208        // no warp reduction, no random tile schedule).
209        let precision_guarantee = PrecisionGuarantee {
210            math_precision: MathPrecision::F32,
211            accumulator: ElementKind::F32,
212            bit_stable_on_same_hardware: true,
213            deterministic: true,
214        };
215        let sku = KernelSku {
216            category: OpCategory::BinaryElementwise,
217            op: desc.kind as u16,
218            element: T::KIND,
219            aux_element: None,
220            layout: None,
221            epilogue: None,
222            arch: ArchSku::Sm80,
223            backend: BackendKind::Bespoke,
224            precision_guarantee,
225        };
226        Ok(Self {
227            desc: *desc,
228            sku,
229            _marker: PhantomData,
230        })
231    }
232
233    /// Validate that this plan can launch with `args`.
234    ///
235    /// Accepts both fully-contiguous operands (which take the contig
236    /// fast path) and broadcast / strided operands (which take the
237    /// strided path). For each axis `d`, each input operand must
238    /// satisfy `shape[d] == y.shape[d]` (no broadcast on that axis) or
239    /// `shape[d] == 1 && stride[d] == 0` (broadcast on that axis). The
240    /// output must be exactly `desc.shape` and is conventionally
241    /// contiguous, though the strided kernel accepts arbitrary `y`
242    /// strides too.
243    pub fn can_implement(&self, args: &BinaryArgs<'_, T, N>) -> Result<()> {
244        // Output must match the descriptor exactly. No broadcast on the
245        // output side — `y` is the destination of the broadcast, not
246        // a participant in it.
247        if args.y.shape != self.desc.shape {
248            return Err(Error::InvalidProblem(
249                "baracuda-kernels::BinaryPlan: Y shape mismatch with descriptor",
250            ));
251        }
252
253        // Per-axis broadcast compatibility check. An input axis must
254        // either match the output's axis exactly, or be 1 with stride 0.
255        for d in 0..N {
256            let y_dim = self.desc.shape[d];
257            let a_dim = args.a.shape[d];
258            let b_dim = args.b.shape[d];
259            if a_dim != y_dim && !(a_dim == 1 && args.a.stride[d] == 0) {
260                return Err(Error::InvalidProblem(
261                    "baracuda-kernels::BinaryPlan: A axis is not broadcast-compatible \
262                     with output (require shape[d] == y.shape[d], OR \
263                     shape[d] == 1 AND stride[d] == 0)",
264                ));
265            }
266            if b_dim != y_dim && !(b_dim == 1 && args.b.stride[d] == 0) {
267                return Err(Error::InvalidProblem(
268                    "baracuda-kernels::BinaryPlan: B axis is not broadcast-compatible \
269                     with output",
270                ));
271            }
272        }
273
274        // Strided kernel handles up to MAX_RANK = 8 axes. Reject
275        // larger ranks here so callers see a clean Unsupported instead
276        // of silent truncation in the kernel.
277        if N > 8 {
278            return Err(Error::Unsupported(
279                "baracuda-kernels::BinaryPlan: tensor rank > 8 not supported \
280                 (kernel param block fixes MAX_RANK = 8)",
281            ));
282        }
283
284        // Buffer sizing: the output must cover its own numel; each
285        // input must cover at least the largest gmem offset reachable
286        // by its strides. For broadcast (stride 0) the reachable
287        // offset is 0 along that axis — the simplest safe bound is
288        // `numel(input) = product(input.shape)`, treating broadcast
289        // dims as size-1.
290        let y_numel = args.y.numel();
291        let a_numel = args.a.numel();
292        let b_numel = args.b.numel();
293        let a_len = args.a.data.len() as i64;
294        let b_len = args.b.data.len() as i64;
295        let y_len = args.y.data.len() as i64;
296        if y_len < y_numel {
297            return Err(Error::BufferTooSmall {
298                needed: y_numel as usize,
299                got: y_len as usize,
300            });
301        }
302        if a_len < a_numel {
303            return Err(Error::BufferTooSmall {
304                needed: a_numel as usize,
305                got: a_len as usize,
306            });
307        }
308        if b_len < b_numel {
309            return Err(Error::BufferTooSmall {
310                needed: b_numel as usize,
311                got: b_len as usize,
312            });
313        }
314        Ok(())
315    }
316
317    /// Workspace size in bytes. Always `0` for the trailblazer SKU.
318    #[inline]
319    pub fn workspace_size(&self) -> usize {
320        0
321    }
322
323    /// Identity of the kernel this plan picked.
324    #[inline]
325    pub fn sku(&self) -> KernelSku {
326        self.sku
327    }
328
329    /// Numerical guarantees for this plan's kernel.
330    #[inline]
331    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
332        self.sku.precision_guarantee
333    }
334
335    /// Launch.
336    pub fn run(
337        &self,
338        stream: &Stream,
339        _workspace: Workspace<'_>,
340        args: BinaryArgs<'_, T, N>,
341    ) -> Result<()> {
342        self.can_implement(&args)?;
343        let numel = args.y.numel();
344        if numel == 0 {
345            return Ok(());
346        }
347        let a_ptr = args.a.data.as_raw().0 as *const c_void;
348        let b_ptr = args.b.data.as_raw().0 as *const c_void;
349        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
350        let stream_ptr = stream.as_raw() as *mut c_void;
351
352        // Contig fast path requires all three operands to be fully
353        // contiguous AND have identical shape (no broadcast). Any other
354        // case (broadcast, transposed, strided slice) goes through the
355        // strided kernel.
356        let all_contig_same_shape = args.a.shape == args.y.shape
357            && args.b.shape == args.y.shape
358            && args.a.is_contiguous()
359            && args.b.is_contiguous()
360            && args.y.is_contiguous();
361
362        if !all_contig_same_shape {
363            return self.run_strided(stream_ptr, a_ptr, b_ptr, y_ptr, numel, &args);
364        }
365
366        let status = match (self.desc.kind, T::KIND) {
367            (BinaryKind::Add, ElementKind::F32) => unsafe {
368                baracuda_kernels_sys::baracuda_kernels_binary_add_f32_run(
369                    numel,
370                    a_ptr,
371                    b_ptr,
372                    y_ptr,
373                    core::ptr::null_mut(),
374                    0,
375                    stream_ptr,
376                )
377            },
378            (BinaryKind::Sub, ElementKind::F32) => unsafe {
379                baracuda_kernels_sys::baracuda_kernels_binary_sub_f32_run(
380                    numel,
381                    a_ptr,
382                    b_ptr,
383                    y_ptr,
384                    core::ptr::null_mut(),
385                    0,
386                    stream_ptr,
387                )
388            },
389            (BinaryKind::Mul, ElementKind::F32) => unsafe {
390                baracuda_kernels_sys::baracuda_kernels_binary_mul_f32_run(
391                    numel,
392                    a_ptr,
393                    b_ptr,
394                    y_ptr,
395                    core::ptr::null_mut(),
396                    0,
397                    stream_ptr,
398                )
399            },
400            (BinaryKind::Div, ElementKind::F32) => unsafe {
401                baracuda_kernels_sys::baracuda_kernels_binary_div_f32_run(
402                    numel,
403                    a_ptr,
404                    b_ptr,
405                    y_ptr,
406                    core::ptr::null_mut(),
407                    0,
408                    stream_ptr,
409                )
410            },
411            (BinaryKind::Add, ElementKind::F16) => unsafe {
412                baracuda_kernels_sys::baracuda_kernels_binary_add_f16_run(
413                    numel,
414                    a_ptr,
415                    b_ptr,
416                    y_ptr,
417                    core::ptr::null_mut(),
418                    0,
419                    stream_ptr,
420                )
421            },
422            (BinaryKind::Add, ElementKind::Bf16) => unsafe {
423                baracuda_kernels_sys::baracuda_kernels_binary_add_bf16_run(
424                    numel,
425                    a_ptr,
426                    b_ptr,
427                    y_ptr,
428                    core::ptr::null_mut(),
429                    0,
430                    stream_ptr,
431                )
432            },
433            (BinaryKind::Add, ElementKind::F64) => unsafe {
434                baracuda_kernels_sys::baracuda_kernels_binary_add_f64_run(
435                    numel,
436                    a_ptr,
437                    b_ptr,
438                    y_ptr,
439                    core::ptr::null_mut(),
440                    0,
441                    stream_ptr,
442                )
443            },
444            (BinaryKind::Sub, ElementKind::F16) => unsafe {
445                baracuda_kernels_sys::baracuda_kernels_binary_sub_f16_run(
446                    numel,
447                    a_ptr,
448                    b_ptr,
449                    y_ptr,
450                    core::ptr::null_mut(),
451                    0,
452                    stream_ptr,
453                )
454            },
455            (BinaryKind::Sub, ElementKind::Bf16) => unsafe {
456                baracuda_kernels_sys::baracuda_kernels_binary_sub_bf16_run(
457                    numel,
458                    a_ptr,
459                    b_ptr,
460                    y_ptr,
461                    core::ptr::null_mut(),
462                    0,
463                    stream_ptr,
464                )
465            },
466            (BinaryKind::Sub, ElementKind::F64) => unsafe {
467                baracuda_kernels_sys::baracuda_kernels_binary_sub_f64_run(
468                    numel,
469                    a_ptr,
470                    b_ptr,
471                    y_ptr,
472                    core::ptr::null_mut(),
473                    0,
474                    stream_ptr,
475                )
476            },
477            (BinaryKind::Mul, ElementKind::F16) => unsafe {
478                baracuda_kernels_sys::baracuda_kernels_binary_mul_f16_run(
479                    numel,
480                    a_ptr,
481                    b_ptr,
482                    y_ptr,
483                    core::ptr::null_mut(),
484                    0,
485                    stream_ptr,
486                )
487            },
488            (BinaryKind::Mul, ElementKind::Bf16) => unsafe {
489                baracuda_kernels_sys::baracuda_kernels_binary_mul_bf16_run(
490                    numel,
491                    a_ptr,
492                    b_ptr,
493                    y_ptr,
494                    core::ptr::null_mut(),
495                    0,
496                    stream_ptr,
497                )
498            },
499            (BinaryKind::Mul, ElementKind::F64) => unsafe {
500                baracuda_kernels_sys::baracuda_kernels_binary_mul_f64_run(
501                    numel,
502                    a_ptr,
503                    b_ptr,
504                    y_ptr,
505                    core::ptr::null_mut(),
506                    0,
507                    stream_ptr,
508                )
509            },
510            (BinaryKind::Div, ElementKind::F16) => unsafe {
511                baracuda_kernels_sys::baracuda_kernels_binary_div_f16_run(
512                    numel,
513                    a_ptr,
514                    b_ptr,
515                    y_ptr,
516                    core::ptr::null_mut(),
517                    0,
518                    stream_ptr,
519                )
520            },
521            (BinaryKind::Div, ElementKind::Bf16) => unsafe {
522                baracuda_kernels_sys::baracuda_kernels_binary_div_bf16_run(
523                    numel,
524                    a_ptr,
525                    b_ptr,
526                    y_ptr,
527                    core::ptr::null_mut(),
528                    0,
529                    stream_ptr,
530                )
531            },
532            (BinaryKind::Div, ElementKind::F64) => unsafe {
533                baracuda_kernels_sys::baracuda_kernels_binary_div_f64_run(
534                    numel,
535                    a_ptr,
536                    b_ptr,
537                    y_ptr,
538                    core::ptr::null_mut(),
539                    0,
540                    stream_ptr,
541                )
542            },
543            (BinaryKind::Pow, ElementKind::F32) => unsafe {
544                baracuda_kernels_sys::baracuda_kernels_binary_pow_f32_run(
545                    numel, a_ptr, b_ptr, y_ptr,
546                    core::ptr::null_mut(), 0, stream_ptr,
547                )
548            },
549            (BinaryKind::Pow, ElementKind::F16) => unsafe {
550                baracuda_kernels_sys::baracuda_kernels_binary_pow_f16_run(
551                    numel, a_ptr, b_ptr, y_ptr,
552                    core::ptr::null_mut(), 0, stream_ptr,
553                )
554            },
555            (BinaryKind::Pow, ElementKind::Bf16) => unsafe {
556                baracuda_kernels_sys::baracuda_kernels_binary_pow_bf16_run(
557                    numel, a_ptr, b_ptr, y_ptr,
558                    core::ptr::null_mut(), 0, stream_ptr,
559                )
560            },
561            (BinaryKind::Pow, ElementKind::F64) => unsafe {
562                baracuda_kernels_sys::baracuda_kernels_binary_pow_f64_run(
563                    numel, a_ptr, b_ptr, y_ptr,
564                    core::ptr::null_mut(), 0, stream_ptr,
565                )
566            },
567            (BinaryKind::Atan2, ElementKind::F32) => unsafe {
568                baracuda_kernels_sys::baracuda_kernels_binary_atan2_f32_run(
569                    numel, a_ptr, b_ptr, y_ptr,
570                    core::ptr::null_mut(), 0, stream_ptr,
571                )
572            },
573            (BinaryKind::Atan2, ElementKind::F16) => unsafe {
574                baracuda_kernels_sys::baracuda_kernels_binary_atan2_f16_run(
575                    numel, a_ptr, b_ptr, y_ptr,
576                    core::ptr::null_mut(), 0, stream_ptr,
577                )
578            },
579            (BinaryKind::Atan2, ElementKind::Bf16) => unsafe {
580                baracuda_kernels_sys::baracuda_kernels_binary_atan2_bf16_run(
581                    numel, a_ptr, b_ptr, y_ptr,
582                    core::ptr::null_mut(), 0, stream_ptr,
583                )
584            },
585            (BinaryKind::Atan2, ElementKind::F64) => unsafe {
586                baracuda_kernels_sys::baracuda_kernels_binary_atan2_f64_run(
587                    numel, a_ptr, b_ptr, y_ptr,
588                    core::ptr::null_mut(), 0, stream_ptr,
589                )
590            },
591            (BinaryKind::Hypot, ElementKind::F32) => unsafe {
592                baracuda_kernels_sys::baracuda_kernels_binary_hypot_f32_run(
593                    numel, a_ptr, b_ptr, y_ptr,
594                    core::ptr::null_mut(), 0, stream_ptr,
595                )
596            },
597            (BinaryKind::Hypot, ElementKind::F16) => unsafe {
598                baracuda_kernels_sys::baracuda_kernels_binary_hypot_f16_run(
599                    numel, a_ptr, b_ptr, y_ptr,
600                    core::ptr::null_mut(), 0, stream_ptr,
601                )
602            },
603            (BinaryKind::Hypot, ElementKind::Bf16) => unsafe {
604                baracuda_kernels_sys::baracuda_kernels_binary_hypot_bf16_run(
605                    numel, a_ptr, b_ptr, y_ptr,
606                    core::ptr::null_mut(), 0, stream_ptr,
607                )
608            },
609            (BinaryKind::Hypot, ElementKind::F64) => unsafe {
610                baracuda_kernels_sys::baracuda_kernels_binary_hypot_f64_run(
611                    numel, a_ptr, b_ptr, y_ptr,
612                    core::ptr::null_mut(), 0, stream_ptr,
613                )
614            },
615            (BinaryKind::Copysign, ElementKind::F32) => unsafe {
616                baracuda_kernels_sys::baracuda_kernels_binary_copysign_f32_run(
617                    numel, a_ptr, b_ptr, y_ptr,
618                    core::ptr::null_mut(), 0, stream_ptr,
619                )
620            },
621            (BinaryKind::Copysign, ElementKind::F16) => unsafe {
622                baracuda_kernels_sys::baracuda_kernels_binary_copysign_f16_run(
623                    numel, a_ptr, b_ptr, y_ptr,
624                    core::ptr::null_mut(), 0, stream_ptr,
625                )
626            },
627            (BinaryKind::Copysign, ElementKind::Bf16) => unsafe {
628                baracuda_kernels_sys::baracuda_kernels_binary_copysign_bf16_run(
629                    numel, a_ptr, b_ptr, y_ptr,
630                    core::ptr::null_mut(), 0, stream_ptr,
631                )
632            },
633            (BinaryKind::Copysign, ElementKind::F64) => unsafe {
634                baracuda_kernels_sys::baracuda_kernels_binary_copysign_f64_run(
635                    numel, a_ptr, b_ptr, y_ptr,
636                    core::ptr::null_mut(), 0, stream_ptr,
637                )
638            },
639            (BinaryKind::Nextafter, ElementKind::F32) => unsafe {
640                baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f32_run(
641                    numel, a_ptr, b_ptr, y_ptr,
642                    core::ptr::null_mut(), 0, stream_ptr,
643                )
644            },
645            (BinaryKind::Nextafter, ElementKind::F16) => unsafe {
646                baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f16_run(
647                    numel, a_ptr, b_ptr, y_ptr,
648                    core::ptr::null_mut(), 0, stream_ptr,
649                )
650            },
651            (BinaryKind::Nextafter, ElementKind::Bf16) => unsafe {
652                baracuda_kernels_sys::baracuda_kernels_binary_nextafter_bf16_run(
653                    numel, a_ptr, b_ptr, y_ptr,
654                    core::ptr::null_mut(), 0, stream_ptr,
655                )
656            },
657            (BinaryKind::Nextafter, ElementKind::F64) => unsafe {
658                baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f64_run(
659                    numel, a_ptr, b_ptr, y_ptr,
660                    core::ptr::null_mut(), 0, stream_ptr,
661                )
662            },
663            (BinaryKind::Fmin, ElementKind::F32) => unsafe {
664                baracuda_kernels_sys::baracuda_kernels_binary_fmin_f32_run(
665                    numel, a_ptr, b_ptr, y_ptr,
666                    core::ptr::null_mut(), 0, stream_ptr,
667                )
668            },
669            (BinaryKind::Fmin, ElementKind::F16) => unsafe {
670                baracuda_kernels_sys::baracuda_kernels_binary_fmin_f16_run(
671                    numel, a_ptr, b_ptr, y_ptr,
672                    core::ptr::null_mut(), 0, stream_ptr,
673                )
674            },
675            (BinaryKind::Fmin, ElementKind::Bf16) => unsafe {
676                baracuda_kernels_sys::baracuda_kernels_binary_fmin_bf16_run(
677                    numel, a_ptr, b_ptr, y_ptr,
678                    core::ptr::null_mut(), 0, stream_ptr,
679                )
680            },
681            (BinaryKind::Fmin, ElementKind::F64) => unsafe {
682                baracuda_kernels_sys::baracuda_kernels_binary_fmin_f64_run(
683                    numel, a_ptr, b_ptr, y_ptr,
684                    core::ptr::null_mut(), 0, stream_ptr,
685                )
686            },
687            (BinaryKind::Fmax, ElementKind::F32) => unsafe {
688                baracuda_kernels_sys::baracuda_kernels_binary_fmax_f32_run(
689                    numel, a_ptr, b_ptr, y_ptr,
690                    core::ptr::null_mut(), 0, stream_ptr,
691                )
692            },
693            (BinaryKind::Fmax, ElementKind::F16) => unsafe {
694                baracuda_kernels_sys::baracuda_kernels_binary_fmax_f16_run(
695                    numel, a_ptr, b_ptr, y_ptr,
696                    core::ptr::null_mut(), 0, stream_ptr,
697                )
698            },
699            (BinaryKind::Fmax, ElementKind::Bf16) => unsafe {
700                baracuda_kernels_sys::baracuda_kernels_binary_fmax_bf16_run(
701                    numel, a_ptr, b_ptr, y_ptr,
702                    core::ptr::null_mut(), 0, stream_ptr,
703                )
704            },
705            (BinaryKind::Fmax, ElementKind::F64) => unsafe {
706                baracuda_kernels_sys::baracuda_kernels_binary_fmax_f64_run(
707                    numel, a_ptr, b_ptr, y_ptr,
708                    core::ptr::null_mut(), 0, stream_ptr,
709                )
710            },
711            (BinaryKind::Maximum, ElementKind::F32) => unsafe {
712                baracuda_kernels_sys::baracuda_kernels_binary_maximum_f32_run(
713                    numel, a_ptr, b_ptr, y_ptr,
714                    core::ptr::null_mut(), 0, stream_ptr,
715                )
716            },
717            (BinaryKind::Maximum, ElementKind::F16) => unsafe {
718                baracuda_kernels_sys::baracuda_kernels_binary_maximum_f16_run(
719                    numel, a_ptr, b_ptr, y_ptr,
720                    core::ptr::null_mut(), 0, stream_ptr,
721                )
722            },
723            (BinaryKind::Maximum, ElementKind::Bf16) => unsafe {
724                baracuda_kernels_sys::baracuda_kernels_binary_maximum_bf16_run(
725                    numel, a_ptr, b_ptr, y_ptr,
726                    core::ptr::null_mut(), 0, stream_ptr,
727                )
728            },
729            (BinaryKind::Maximum, ElementKind::F64) => unsafe {
730                baracuda_kernels_sys::baracuda_kernels_binary_maximum_f64_run(
731                    numel, a_ptr, b_ptr, y_ptr,
732                    core::ptr::null_mut(), 0, stream_ptr,
733                )
734            },
735            (BinaryKind::Minimum, ElementKind::F32) => unsafe {
736                baracuda_kernels_sys::baracuda_kernels_binary_minimum_f32_run(
737                    numel, a_ptr, b_ptr, y_ptr,
738                    core::ptr::null_mut(), 0, stream_ptr,
739                )
740            },
741            (BinaryKind::Minimum, ElementKind::F16) => unsafe {
742                baracuda_kernels_sys::baracuda_kernels_binary_minimum_f16_run(
743                    numel, a_ptr, b_ptr, y_ptr,
744                    core::ptr::null_mut(), 0, stream_ptr,
745                )
746            },
747            (BinaryKind::Minimum, ElementKind::Bf16) => unsafe {
748                baracuda_kernels_sys::baracuda_kernels_binary_minimum_bf16_run(
749                    numel, a_ptr, b_ptr, y_ptr,
750                    core::ptr::null_mut(), 0, stream_ptr,
751                )
752            },
753            (BinaryKind::Minimum, ElementKind::F64) => unsafe {
754                baracuda_kernels_sys::baracuda_kernels_binary_minimum_f64_run(
755                    numel, a_ptr, b_ptr, y_ptr,
756                    core::ptr::null_mut(), 0, stream_ptr,
757                )
758            },
759            (BinaryKind::FloorDivide, ElementKind::F32) => unsafe {
760                baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f32_run(
761                    numel, a_ptr, b_ptr, y_ptr,
762                    core::ptr::null_mut(), 0, stream_ptr,
763                )
764            },
765            (BinaryKind::FloorDivide, ElementKind::F16) => unsafe {
766                baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f16_run(
767                    numel, a_ptr, b_ptr, y_ptr,
768                    core::ptr::null_mut(), 0, stream_ptr,
769                )
770            },
771            (BinaryKind::FloorDivide, ElementKind::Bf16) => unsafe {
772                baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_bf16_run(
773                    numel, a_ptr, b_ptr, y_ptr,
774                    core::ptr::null_mut(), 0, stream_ptr,
775                )
776            },
777            (BinaryKind::FloorDivide, ElementKind::F64) => unsafe {
778                baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f64_run(
779                    numel, a_ptr, b_ptr, y_ptr,
780                    core::ptr::null_mut(), 0, stream_ptr,
781                )
782            },
783            (BinaryKind::Mod, ElementKind::F32) => unsafe {
784                baracuda_kernels_sys::baracuda_kernels_binary_mod_f32_run(
785                    numel, a_ptr, b_ptr, y_ptr,
786                    core::ptr::null_mut(), 0, stream_ptr,
787                )
788            },
789            (BinaryKind::Mod, ElementKind::F16) => unsafe {
790                baracuda_kernels_sys::baracuda_kernels_binary_mod_f16_run(
791                    numel, a_ptr, b_ptr, y_ptr,
792                    core::ptr::null_mut(), 0, stream_ptr,
793                )
794            },
795            (BinaryKind::Mod, ElementKind::Bf16) => unsafe {
796                baracuda_kernels_sys::baracuda_kernels_binary_mod_bf16_run(
797                    numel, a_ptr, b_ptr, y_ptr,
798                    core::ptr::null_mut(), 0, stream_ptr,
799                )
800            },
801            (BinaryKind::Mod, ElementKind::F64) => unsafe {
802                baracuda_kernels_sys::baracuda_kernels_binary_mod_f64_run(
803                    numel, a_ptr, b_ptr, y_ptr,
804                    core::ptr::null_mut(), 0, stream_ptr,
805                )
806            },
807            (BinaryKind::Remainder, ElementKind::F32) => unsafe {
808                baracuda_kernels_sys::baracuda_kernels_binary_remainder_f32_run(
809                    numel, a_ptr, b_ptr, y_ptr,
810                    core::ptr::null_mut(), 0, stream_ptr,
811                )
812            },
813            (BinaryKind::Remainder, ElementKind::F16) => unsafe {
814                baracuda_kernels_sys::baracuda_kernels_binary_remainder_f16_run(
815                    numel, a_ptr, b_ptr, y_ptr,
816                    core::ptr::null_mut(), 0, stream_ptr,
817                )
818            },
819            (BinaryKind::Remainder, ElementKind::Bf16) => unsafe {
820                baracuda_kernels_sys::baracuda_kernels_binary_remainder_bf16_run(
821                    numel, a_ptr, b_ptr, y_ptr,
822                    core::ptr::null_mut(), 0, stream_ptr,
823                )
824            },
825            (BinaryKind::Remainder, ElementKind::F64) => unsafe {
826                baracuda_kernels_sys::baracuda_kernels_binary_remainder_f64_run(
827                    numel, a_ptr, b_ptr, y_ptr,
828                    core::ptr::null_mut(), 0, stream_ptr,
829                )
830            },
831            // ---- Phase 3.3 integer + bool fanout (contig only) ----
832            (BinaryKind::BitwiseAnd, ElementKind::I32) => unsafe {
833                baracuda_kernels_sys::baracuda_kernels_binary_bitwise_and_i32_run(
834                    numel, a_ptr, b_ptr, y_ptr,
835                    core::ptr::null_mut(), 0, stream_ptr,
836                )
837            },
838            (BinaryKind::BitwiseAnd, ElementKind::I64) => unsafe {
839                baracuda_kernels_sys::baracuda_kernels_binary_bitwise_and_i64_run(
840                    numel, a_ptr, b_ptr, y_ptr,
841                    core::ptr::null_mut(), 0, stream_ptr,
842                )
843            },
844            (BinaryKind::BitwiseOr, ElementKind::I32) => unsafe {
845                baracuda_kernels_sys::baracuda_kernels_binary_bitwise_or_i32_run(
846                    numel, a_ptr, b_ptr, y_ptr,
847                    core::ptr::null_mut(), 0, stream_ptr,
848                )
849            },
850            (BinaryKind::BitwiseOr, ElementKind::I64) => unsafe {
851                baracuda_kernels_sys::baracuda_kernels_binary_bitwise_or_i64_run(
852                    numel, a_ptr, b_ptr, y_ptr,
853                    core::ptr::null_mut(), 0, stream_ptr,
854                )
855            },
856            (BinaryKind::BitwiseXor, ElementKind::I32) => unsafe {
857                baracuda_kernels_sys::baracuda_kernels_binary_bitwise_xor_i32_run(
858                    numel, a_ptr, b_ptr, y_ptr,
859                    core::ptr::null_mut(), 0, stream_ptr,
860                )
861            },
862            (BinaryKind::BitwiseXor, ElementKind::I64) => unsafe {
863                baracuda_kernels_sys::baracuda_kernels_binary_bitwise_xor_i64_run(
864                    numel, a_ptr, b_ptr, y_ptr,
865                    core::ptr::null_mut(), 0, stream_ptr,
866                )
867            },
868            (BinaryKind::BitwiseLeftShift, ElementKind::I32) => unsafe {
869                baracuda_kernels_sys::baracuda_kernels_binary_bitwise_left_shift_i32_run(
870                    numel, a_ptr, b_ptr, y_ptr,
871                    core::ptr::null_mut(), 0, stream_ptr,
872                )
873            },
874            (BinaryKind::BitwiseLeftShift, ElementKind::I64) => unsafe {
875                baracuda_kernels_sys::baracuda_kernels_binary_bitwise_left_shift_i64_run(
876                    numel, a_ptr, b_ptr, y_ptr,
877                    core::ptr::null_mut(), 0, stream_ptr,
878                )
879            },
880            (BinaryKind::BitwiseRightShift, ElementKind::I32) => unsafe {
881                baracuda_kernels_sys::baracuda_kernels_binary_bitwise_right_shift_i32_run(
882                    numel, a_ptr, b_ptr, y_ptr,
883                    core::ptr::null_mut(), 0, stream_ptr,
884                )
885            },
886            (BinaryKind::BitwiseRightShift, ElementKind::I64) => unsafe {
887                baracuda_kernels_sys::baracuda_kernels_binary_bitwise_right_shift_i64_run(
888                    numel, a_ptr, b_ptr, y_ptr,
889                    core::ptr::null_mut(), 0, stream_ptr,
890                )
891            },
892            (BinaryKind::LogicalAnd, ElementKind::Bool) => unsafe {
893                baracuda_kernels_sys::baracuda_kernels_binary_logical_and_bool_run(
894                    numel, a_ptr, b_ptr, y_ptr,
895                    core::ptr::null_mut(), 0, stream_ptr,
896                )
897            },
898            (BinaryKind::LogicalOr, ElementKind::Bool) => unsafe {
899                baracuda_kernels_sys::baracuda_kernels_binary_logical_or_bool_run(
900                    numel, a_ptr, b_ptr, y_ptr,
901                    core::ptr::null_mut(), 0, stream_ptr,
902                )
903            },
904            (BinaryKind::LogicalXor, ElementKind::Bool) => unsafe {
905                baracuda_kernels_sys::baracuda_kernels_binary_logical_xor_bool_run(
906                    numel, a_ptr, b_ptr, y_ptr,
907                    core::ptr::null_mut(), 0, stream_ptr,
908                )
909            },
910            _ => {
911                return Err(Error::Unsupported(
912                    "baracuda-kernels::BinaryPlan::run reached an unimplemented \
913                     (kind, dtype) pair — select() should have caught this",
914                ))
915            }
916        };
917        map_status(status)
918    }
919}
920
921impl<T: Element, const N: usize> BinaryPlan<T, N> {
922    /// Launch the strided / broadcast kernel path.
923    ///
924    /// Called by [`Self::run`] when at least one operand isn't
925    /// contiguous (broadcast, transposed view, arbitrary strided
926    /// slice). The kernel reads each output coord c via the per-operand
927    /// strides — a stride of 0 along axis d collapses that axis to
928    /// element 0, which is the broadcast semantic.
929    fn run_strided(
930        &self,
931        stream_ptr: *mut c_void,
932        a_ptr: *const c_void,
933        b_ptr: *const c_void,
934        y_ptr: *mut c_void,
935        numel: i64,
936        args: &BinaryArgs<'_, T, N>,
937    ) -> Result<()> {
938        // Output shape (== descriptor shape) drives the kernel's coord
939        // loop. Stride arrays come from each operand.
940        let shape = args.y.shape;
941        let stride_a = args.a.stride;
942        let stride_b = args.b.stride;
943        let stride_y = args.y.stride;
944        let rank = N as i32;
945
946        let status = match (self.desc.kind, T::KIND) {
947            (BinaryKind::Add, ElementKind::F32) => unsafe {
948                baracuda_kernels_sys::baracuda_kernels_binary_add_f32_strided_run(
949                    numel,
950                    rank,
951                    shape.as_ptr(),
952                    stride_a.as_ptr(),
953                    stride_b.as_ptr(),
954                    stride_y.as_ptr(),
955                    a_ptr,
956                    b_ptr,
957                    y_ptr,
958                    core::ptr::null_mut(),
959                    0,
960                    stream_ptr,
961                )
962            },
963            (BinaryKind::Add, ElementKind::F16) => unsafe {
964                baracuda_kernels_sys::baracuda_kernels_binary_add_f16_strided_run(
965                    numel,
966                    rank,
967                    shape.as_ptr(),
968                    stride_a.as_ptr(),
969                    stride_b.as_ptr(),
970                    stride_y.as_ptr(),
971                    a_ptr,
972                    b_ptr,
973                    y_ptr,
974                    core::ptr::null_mut(),
975                    0,
976                    stream_ptr,
977                )
978            },
979            (BinaryKind::Add, ElementKind::Bf16) => unsafe {
980                baracuda_kernels_sys::baracuda_kernels_binary_add_bf16_strided_run(
981                    numel,
982                    rank,
983                    shape.as_ptr(),
984                    stride_a.as_ptr(),
985                    stride_b.as_ptr(),
986                    stride_y.as_ptr(),
987                    a_ptr,
988                    b_ptr,
989                    y_ptr,
990                    core::ptr::null_mut(),
991                    0,
992                    stream_ptr,
993                )
994            },
995            (BinaryKind::Add, ElementKind::F64) => unsafe {
996                baracuda_kernels_sys::baracuda_kernels_binary_add_f64_strided_run(
997                    numel,
998                    rank,
999                    shape.as_ptr(),
1000                    stride_a.as_ptr(),
1001                    stride_b.as_ptr(),
1002                    stride_y.as_ptr(),
1003                    a_ptr,
1004                    b_ptr,
1005                    y_ptr,
1006                    core::ptr::null_mut(),
1007                    0,
1008                    stream_ptr,
1009                )
1010            },
1011            (BinaryKind::Sub, ElementKind::F32) => unsafe {
1012                baracuda_kernels_sys::baracuda_kernels_binary_sub_f32_strided_run(
1013                    numel,
1014                    rank,
1015                    shape.as_ptr(),
1016                    stride_a.as_ptr(),
1017                    stride_b.as_ptr(),
1018                    stride_y.as_ptr(),
1019                    a_ptr,
1020                    b_ptr,
1021                    y_ptr,
1022                    core::ptr::null_mut(),
1023                    0,
1024                    stream_ptr,
1025                )
1026            },
1027            (BinaryKind::Sub, ElementKind::F16) => unsafe {
1028                baracuda_kernels_sys::baracuda_kernels_binary_sub_f16_strided_run(
1029                    numel,
1030                    rank,
1031                    shape.as_ptr(),
1032                    stride_a.as_ptr(),
1033                    stride_b.as_ptr(),
1034                    stride_y.as_ptr(),
1035                    a_ptr,
1036                    b_ptr,
1037                    y_ptr,
1038                    core::ptr::null_mut(),
1039                    0,
1040                    stream_ptr,
1041                )
1042            },
1043            (BinaryKind::Sub, ElementKind::Bf16) => unsafe {
1044                baracuda_kernels_sys::baracuda_kernels_binary_sub_bf16_strided_run(
1045                    numel,
1046                    rank,
1047                    shape.as_ptr(),
1048                    stride_a.as_ptr(),
1049                    stride_b.as_ptr(),
1050                    stride_y.as_ptr(),
1051                    a_ptr,
1052                    b_ptr,
1053                    y_ptr,
1054                    core::ptr::null_mut(),
1055                    0,
1056                    stream_ptr,
1057                )
1058            },
1059            (BinaryKind::Sub, ElementKind::F64) => unsafe {
1060                baracuda_kernels_sys::baracuda_kernels_binary_sub_f64_strided_run(
1061                    numel,
1062                    rank,
1063                    shape.as_ptr(),
1064                    stride_a.as_ptr(),
1065                    stride_b.as_ptr(),
1066                    stride_y.as_ptr(),
1067                    a_ptr,
1068                    b_ptr,
1069                    y_ptr,
1070                    core::ptr::null_mut(),
1071                    0,
1072                    stream_ptr,
1073                )
1074            },
1075            (BinaryKind::Mul, ElementKind::F32) => unsafe {
1076                baracuda_kernels_sys::baracuda_kernels_binary_mul_f32_strided_run(
1077                    numel,
1078                    rank,
1079                    shape.as_ptr(),
1080                    stride_a.as_ptr(),
1081                    stride_b.as_ptr(),
1082                    stride_y.as_ptr(),
1083                    a_ptr,
1084                    b_ptr,
1085                    y_ptr,
1086                    core::ptr::null_mut(),
1087                    0,
1088                    stream_ptr,
1089                )
1090            },
1091            (BinaryKind::Mul, ElementKind::F16) => unsafe {
1092                baracuda_kernels_sys::baracuda_kernels_binary_mul_f16_strided_run(
1093                    numel,
1094                    rank,
1095                    shape.as_ptr(),
1096                    stride_a.as_ptr(),
1097                    stride_b.as_ptr(),
1098                    stride_y.as_ptr(),
1099                    a_ptr,
1100                    b_ptr,
1101                    y_ptr,
1102                    core::ptr::null_mut(),
1103                    0,
1104                    stream_ptr,
1105                )
1106            },
1107            (BinaryKind::Mul, ElementKind::Bf16) => unsafe {
1108                baracuda_kernels_sys::baracuda_kernels_binary_mul_bf16_strided_run(
1109                    numel,
1110                    rank,
1111                    shape.as_ptr(),
1112                    stride_a.as_ptr(),
1113                    stride_b.as_ptr(),
1114                    stride_y.as_ptr(),
1115                    a_ptr,
1116                    b_ptr,
1117                    y_ptr,
1118                    core::ptr::null_mut(),
1119                    0,
1120                    stream_ptr,
1121                )
1122            },
1123            (BinaryKind::Mul, ElementKind::F64) => unsafe {
1124                baracuda_kernels_sys::baracuda_kernels_binary_mul_f64_strided_run(
1125                    numel,
1126                    rank,
1127                    shape.as_ptr(),
1128                    stride_a.as_ptr(),
1129                    stride_b.as_ptr(),
1130                    stride_y.as_ptr(),
1131                    a_ptr,
1132                    b_ptr,
1133                    y_ptr,
1134                    core::ptr::null_mut(),
1135                    0,
1136                    stream_ptr,
1137                )
1138            },
1139            (BinaryKind::Div, ElementKind::F32) => unsafe {
1140                baracuda_kernels_sys::baracuda_kernels_binary_div_f32_strided_run(
1141                    numel,
1142                    rank,
1143                    shape.as_ptr(),
1144                    stride_a.as_ptr(),
1145                    stride_b.as_ptr(),
1146                    stride_y.as_ptr(),
1147                    a_ptr,
1148                    b_ptr,
1149                    y_ptr,
1150                    core::ptr::null_mut(),
1151                    0,
1152                    stream_ptr,
1153                )
1154            },
1155            (BinaryKind::Div, ElementKind::F16) => unsafe {
1156                baracuda_kernels_sys::baracuda_kernels_binary_div_f16_strided_run(
1157                    numel,
1158                    rank,
1159                    shape.as_ptr(),
1160                    stride_a.as_ptr(),
1161                    stride_b.as_ptr(),
1162                    stride_y.as_ptr(),
1163                    a_ptr,
1164                    b_ptr,
1165                    y_ptr,
1166                    core::ptr::null_mut(),
1167                    0,
1168                    stream_ptr,
1169                )
1170            },
1171            (BinaryKind::Div, ElementKind::Bf16) => unsafe {
1172                baracuda_kernels_sys::baracuda_kernels_binary_div_bf16_strided_run(
1173                    numel,
1174                    rank,
1175                    shape.as_ptr(),
1176                    stride_a.as_ptr(),
1177                    stride_b.as_ptr(),
1178                    stride_y.as_ptr(),
1179                    a_ptr,
1180                    b_ptr,
1181                    y_ptr,
1182                    core::ptr::null_mut(),
1183                    0,
1184                    stream_ptr,
1185                )
1186            },
1187            (BinaryKind::Div, ElementKind::F64) => unsafe {
1188                baracuda_kernels_sys::baracuda_kernels_binary_div_f64_strided_run(
1189                    numel,
1190                    rank,
1191                    shape.as_ptr(),
1192                    stride_a.as_ptr(),
1193                    stride_b.as_ptr(),
1194                    stride_y.as_ptr(),
1195                    a_ptr,
1196                    b_ptr,
1197                    y_ptr,
1198                    core::ptr::null_mut(),
1199                    0,
1200                    stream_ptr,
1201                )
1202            },
1203            (BinaryKind::Pow, ElementKind::F32) => unsafe {
1204                baracuda_kernels_sys::baracuda_kernels_binary_pow_f32_strided_run(
1205                    numel, rank, shape.as_ptr(),
1206                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1207                    a_ptr, b_ptr, y_ptr,
1208                    core::ptr::null_mut(), 0, stream_ptr,
1209                )
1210            },
1211            (BinaryKind::Pow, ElementKind::F16) => unsafe {
1212                baracuda_kernels_sys::baracuda_kernels_binary_pow_f16_strided_run(
1213                    numel, rank, shape.as_ptr(),
1214                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1215                    a_ptr, b_ptr, y_ptr,
1216                    core::ptr::null_mut(), 0, stream_ptr,
1217                )
1218            },
1219            (BinaryKind::Pow, ElementKind::Bf16) => unsafe {
1220                baracuda_kernels_sys::baracuda_kernels_binary_pow_bf16_strided_run(
1221                    numel, rank, shape.as_ptr(),
1222                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1223                    a_ptr, b_ptr, y_ptr,
1224                    core::ptr::null_mut(), 0, stream_ptr,
1225                )
1226            },
1227            (BinaryKind::Pow, ElementKind::F64) => unsafe {
1228                baracuda_kernels_sys::baracuda_kernels_binary_pow_f64_strided_run(
1229                    numel, rank, shape.as_ptr(),
1230                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1231                    a_ptr, b_ptr, y_ptr,
1232                    core::ptr::null_mut(), 0, stream_ptr,
1233                )
1234            },
1235            (BinaryKind::Atan2, ElementKind::F32) => unsafe {
1236                baracuda_kernels_sys::baracuda_kernels_binary_atan2_f32_strided_run(
1237                    numel, rank, shape.as_ptr(),
1238                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1239                    a_ptr, b_ptr, y_ptr,
1240                    core::ptr::null_mut(), 0, stream_ptr,
1241                )
1242            },
1243            (BinaryKind::Atan2, ElementKind::F16) => unsafe {
1244                baracuda_kernels_sys::baracuda_kernels_binary_atan2_f16_strided_run(
1245                    numel, rank, shape.as_ptr(),
1246                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1247                    a_ptr, b_ptr, y_ptr,
1248                    core::ptr::null_mut(), 0, stream_ptr,
1249                )
1250            },
1251            (BinaryKind::Atan2, ElementKind::Bf16) => unsafe {
1252                baracuda_kernels_sys::baracuda_kernels_binary_atan2_bf16_strided_run(
1253                    numel, rank, shape.as_ptr(),
1254                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1255                    a_ptr, b_ptr, y_ptr,
1256                    core::ptr::null_mut(), 0, stream_ptr,
1257                )
1258            },
1259            (BinaryKind::Atan2, ElementKind::F64) => unsafe {
1260                baracuda_kernels_sys::baracuda_kernels_binary_atan2_f64_strided_run(
1261                    numel, rank, shape.as_ptr(),
1262                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1263                    a_ptr, b_ptr, y_ptr,
1264                    core::ptr::null_mut(), 0, stream_ptr,
1265                )
1266            },
1267            (BinaryKind::Hypot, ElementKind::F32) => unsafe {
1268                baracuda_kernels_sys::baracuda_kernels_binary_hypot_f32_strided_run(
1269                    numel, rank, shape.as_ptr(),
1270                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1271                    a_ptr, b_ptr, y_ptr,
1272                    core::ptr::null_mut(), 0, stream_ptr,
1273                )
1274            },
1275            (BinaryKind::Hypot, ElementKind::F16) => unsafe {
1276                baracuda_kernels_sys::baracuda_kernels_binary_hypot_f16_strided_run(
1277                    numel, rank, shape.as_ptr(),
1278                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1279                    a_ptr, b_ptr, y_ptr,
1280                    core::ptr::null_mut(), 0, stream_ptr,
1281                )
1282            },
1283            (BinaryKind::Hypot, ElementKind::Bf16) => unsafe {
1284                baracuda_kernels_sys::baracuda_kernels_binary_hypot_bf16_strided_run(
1285                    numel, rank, shape.as_ptr(),
1286                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1287                    a_ptr, b_ptr, y_ptr,
1288                    core::ptr::null_mut(), 0, stream_ptr,
1289                )
1290            },
1291            (BinaryKind::Hypot, ElementKind::F64) => unsafe {
1292                baracuda_kernels_sys::baracuda_kernels_binary_hypot_f64_strided_run(
1293                    numel, rank, shape.as_ptr(),
1294                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1295                    a_ptr, b_ptr, y_ptr,
1296                    core::ptr::null_mut(), 0, stream_ptr,
1297                )
1298            },
1299            (BinaryKind::Copysign, ElementKind::F32) => unsafe {
1300                baracuda_kernels_sys::baracuda_kernels_binary_copysign_f32_strided_run(
1301                    numel, rank, shape.as_ptr(),
1302                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1303                    a_ptr, b_ptr, y_ptr,
1304                    core::ptr::null_mut(), 0, stream_ptr,
1305                )
1306            },
1307            (BinaryKind::Copysign, ElementKind::F16) => unsafe {
1308                baracuda_kernels_sys::baracuda_kernels_binary_copysign_f16_strided_run(
1309                    numel, rank, shape.as_ptr(),
1310                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1311                    a_ptr, b_ptr, y_ptr,
1312                    core::ptr::null_mut(), 0, stream_ptr,
1313                )
1314            },
1315            (BinaryKind::Copysign, ElementKind::Bf16) => unsafe {
1316                baracuda_kernels_sys::baracuda_kernels_binary_copysign_bf16_strided_run(
1317                    numel, rank, shape.as_ptr(),
1318                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1319                    a_ptr, b_ptr, y_ptr,
1320                    core::ptr::null_mut(), 0, stream_ptr,
1321                )
1322            },
1323            (BinaryKind::Copysign, ElementKind::F64) => unsafe {
1324                baracuda_kernels_sys::baracuda_kernels_binary_copysign_f64_strided_run(
1325                    numel, rank, shape.as_ptr(),
1326                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1327                    a_ptr, b_ptr, y_ptr,
1328                    core::ptr::null_mut(), 0, stream_ptr,
1329                )
1330            },
1331            (BinaryKind::Nextafter, ElementKind::F32) => unsafe {
1332                baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f32_strided_run(
1333                    numel, rank, shape.as_ptr(),
1334                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1335                    a_ptr, b_ptr, y_ptr,
1336                    core::ptr::null_mut(), 0, stream_ptr,
1337                )
1338            },
1339            (BinaryKind::Nextafter, ElementKind::F16) => unsafe {
1340                baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f16_strided_run(
1341                    numel, rank, shape.as_ptr(),
1342                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1343                    a_ptr, b_ptr, y_ptr,
1344                    core::ptr::null_mut(), 0, stream_ptr,
1345                )
1346            },
1347            (BinaryKind::Nextafter, ElementKind::Bf16) => unsafe {
1348                baracuda_kernels_sys::baracuda_kernels_binary_nextafter_bf16_strided_run(
1349                    numel, rank, shape.as_ptr(),
1350                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1351                    a_ptr, b_ptr, y_ptr,
1352                    core::ptr::null_mut(), 0, stream_ptr,
1353                )
1354            },
1355            (BinaryKind::Nextafter, ElementKind::F64) => unsafe {
1356                baracuda_kernels_sys::baracuda_kernels_binary_nextafter_f64_strided_run(
1357                    numel, rank, shape.as_ptr(),
1358                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1359                    a_ptr, b_ptr, y_ptr,
1360                    core::ptr::null_mut(), 0, stream_ptr,
1361                )
1362            },
1363            (BinaryKind::Fmin, ElementKind::F32) => unsafe {
1364                baracuda_kernels_sys::baracuda_kernels_binary_fmin_f32_strided_run(
1365                    numel, rank, shape.as_ptr(),
1366                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1367                    a_ptr, b_ptr, y_ptr,
1368                    core::ptr::null_mut(), 0, stream_ptr,
1369                )
1370            },
1371            (BinaryKind::Fmin, ElementKind::F16) => unsafe {
1372                baracuda_kernels_sys::baracuda_kernels_binary_fmin_f16_strided_run(
1373                    numel, rank, shape.as_ptr(),
1374                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1375                    a_ptr, b_ptr, y_ptr,
1376                    core::ptr::null_mut(), 0, stream_ptr,
1377                )
1378            },
1379            (BinaryKind::Fmin, ElementKind::Bf16) => unsafe {
1380                baracuda_kernels_sys::baracuda_kernels_binary_fmin_bf16_strided_run(
1381                    numel, rank, shape.as_ptr(),
1382                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1383                    a_ptr, b_ptr, y_ptr,
1384                    core::ptr::null_mut(), 0, stream_ptr,
1385                )
1386            },
1387            (BinaryKind::Fmin, ElementKind::F64) => unsafe {
1388                baracuda_kernels_sys::baracuda_kernels_binary_fmin_f64_strided_run(
1389                    numel, rank, shape.as_ptr(),
1390                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1391                    a_ptr, b_ptr, y_ptr,
1392                    core::ptr::null_mut(), 0, stream_ptr,
1393                )
1394            },
1395            (BinaryKind::Fmax, ElementKind::F32) => unsafe {
1396                baracuda_kernels_sys::baracuda_kernels_binary_fmax_f32_strided_run(
1397                    numel, rank, shape.as_ptr(),
1398                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1399                    a_ptr, b_ptr, y_ptr,
1400                    core::ptr::null_mut(), 0, stream_ptr,
1401                )
1402            },
1403            (BinaryKind::Fmax, ElementKind::F16) => unsafe {
1404                baracuda_kernels_sys::baracuda_kernels_binary_fmax_f16_strided_run(
1405                    numel, rank, shape.as_ptr(),
1406                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1407                    a_ptr, b_ptr, y_ptr,
1408                    core::ptr::null_mut(), 0, stream_ptr,
1409                )
1410            },
1411            (BinaryKind::Fmax, ElementKind::Bf16) => unsafe {
1412                baracuda_kernels_sys::baracuda_kernels_binary_fmax_bf16_strided_run(
1413                    numel, rank, shape.as_ptr(),
1414                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1415                    a_ptr, b_ptr, y_ptr,
1416                    core::ptr::null_mut(), 0, stream_ptr,
1417                )
1418            },
1419            (BinaryKind::Fmax, ElementKind::F64) => unsafe {
1420                baracuda_kernels_sys::baracuda_kernels_binary_fmax_f64_strided_run(
1421                    numel, rank, shape.as_ptr(),
1422                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1423                    a_ptr, b_ptr, y_ptr,
1424                    core::ptr::null_mut(), 0, stream_ptr,
1425                )
1426            },
1427            (BinaryKind::Maximum, ElementKind::F32) => unsafe {
1428                baracuda_kernels_sys::baracuda_kernels_binary_maximum_f32_strided_run(
1429                    numel, rank, shape.as_ptr(),
1430                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1431                    a_ptr, b_ptr, y_ptr,
1432                    core::ptr::null_mut(), 0, stream_ptr,
1433                )
1434            },
1435            (BinaryKind::Maximum, ElementKind::F16) => unsafe {
1436                baracuda_kernels_sys::baracuda_kernels_binary_maximum_f16_strided_run(
1437                    numel, rank, shape.as_ptr(),
1438                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1439                    a_ptr, b_ptr, y_ptr,
1440                    core::ptr::null_mut(), 0, stream_ptr,
1441                )
1442            },
1443            (BinaryKind::Maximum, ElementKind::Bf16) => unsafe {
1444                baracuda_kernels_sys::baracuda_kernels_binary_maximum_bf16_strided_run(
1445                    numel, rank, shape.as_ptr(),
1446                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1447                    a_ptr, b_ptr, y_ptr,
1448                    core::ptr::null_mut(), 0, stream_ptr,
1449                )
1450            },
1451            (BinaryKind::Maximum, ElementKind::F64) => unsafe {
1452                baracuda_kernels_sys::baracuda_kernels_binary_maximum_f64_strided_run(
1453                    numel, rank, shape.as_ptr(),
1454                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1455                    a_ptr, b_ptr, y_ptr,
1456                    core::ptr::null_mut(), 0, stream_ptr,
1457                )
1458            },
1459            (BinaryKind::Minimum, ElementKind::F32) => unsafe {
1460                baracuda_kernels_sys::baracuda_kernels_binary_minimum_f32_strided_run(
1461                    numel, rank, shape.as_ptr(),
1462                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1463                    a_ptr, b_ptr, y_ptr,
1464                    core::ptr::null_mut(), 0, stream_ptr,
1465                )
1466            },
1467            (BinaryKind::Minimum, ElementKind::F16) => unsafe {
1468                baracuda_kernels_sys::baracuda_kernels_binary_minimum_f16_strided_run(
1469                    numel, rank, shape.as_ptr(),
1470                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1471                    a_ptr, b_ptr, y_ptr,
1472                    core::ptr::null_mut(), 0, stream_ptr,
1473                )
1474            },
1475            (BinaryKind::Minimum, ElementKind::Bf16) => unsafe {
1476                baracuda_kernels_sys::baracuda_kernels_binary_minimum_bf16_strided_run(
1477                    numel, rank, shape.as_ptr(),
1478                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1479                    a_ptr, b_ptr, y_ptr,
1480                    core::ptr::null_mut(), 0, stream_ptr,
1481                )
1482            },
1483            (BinaryKind::Minimum, ElementKind::F64) => unsafe {
1484                baracuda_kernels_sys::baracuda_kernels_binary_minimum_f64_strided_run(
1485                    numel, rank, shape.as_ptr(),
1486                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1487                    a_ptr, b_ptr, y_ptr,
1488                    core::ptr::null_mut(), 0, stream_ptr,
1489                )
1490            },
1491            (BinaryKind::FloorDivide, ElementKind::F32) => unsafe {
1492                baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f32_strided_run(
1493                    numel, rank, shape.as_ptr(),
1494                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1495                    a_ptr, b_ptr, y_ptr,
1496                    core::ptr::null_mut(), 0, stream_ptr,
1497                )
1498            },
1499            (BinaryKind::FloorDivide, ElementKind::F16) => unsafe {
1500                baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f16_strided_run(
1501                    numel, rank, shape.as_ptr(),
1502                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1503                    a_ptr, b_ptr, y_ptr,
1504                    core::ptr::null_mut(), 0, stream_ptr,
1505                )
1506            },
1507            (BinaryKind::FloorDivide, ElementKind::Bf16) => unsafe {
1508                baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_bf16_strided_run(
1509                    numel, rank, shape.as_ptr(),
1510                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1511                    a_ptr, b_ptr, y_ptr,
1512                    core::ptr::null_mut(), 0, stream_ptr,
1513                )
1514            },
1515            (BinaryKind::FloorDivide, ElementKind::F64) => unsafe {
1516                baracuda_kernels_sys::baracuda_kernels_binary_floor_divide_f64_strided_run(
1517                    numel, rank, shape.as_ptr(),
1518                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1519                    a_ptr, b_ptr, y_ptr,
1520                    core::ptr::null_mut(), 0, stream_ptr,
1521                )
1522            },
1523            (BinaryKind::Mod, ElementKind::F32) => unsafe {
1524                baracuda_kernels_sys::baracuda_kernels_binary_mod_f32_strided_run(
1525                    numel, rank, shape.as_ptr(),
1526                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1527                    a_ptr, b_ptr, y_ptr,
1528                    core::ptr::null_mut(), 0, stream_ptr,
1529                )
1530            },
1531            (BinaryKind::Mod, ElementKind::F16) => unsafe {
1532                baracuda_kernels_sys::baracuda_kernels_binary_mod_f16_strided_run(
1533                    numel, rank, shape.as_ptr(),
1534                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1535                    a_ptr, b_ptr, y_ptr,
1536                    core::ptr::null_mut(), 0, stream_ptr,
1537                )
1538            },
1539            (BinaryKind::Mod, ElementKind::Bf16) => unsafe {
1540                baracuda_kernels_sys::baracuda_kernels_binary_mod_bf16_strided_run(
1541                    numel, rank, shape.as_ptr(),
1542                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1543                    a_ptr, b_ptr, y_ptr,
1544                    core::ptr::null_mut(), 0, stream_ptr,
1545                )
1546            },
1547            (BinaryKind::Mod, ElementKind::F64) => unsafe {
1548                baracuda_kernels_sys::baracuda_kernels_binary_mod_f64_strided_run(
1549                    numel, rank, shape.as_ptr(),
1550                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1551                    a_ptr, b_ptr, y_ptr,
1552                    core::ptr::null_mut(), 0, stream_ptr,
1553                )
1554            },
1555            (BinaryKind::Remainder, ElementKind::F32) => unsafe {
1556                baracuda_kernels_sys::baracuda_kernels_binary_remainder_f32_strided_run(
1557                    numel, rank, shape.as_ptr(),
1558                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1559                    a_ptr, b_ptr, y_ptr,
1560                    core::ptr::null_mut(), 0, stream_ptr,
1561                )
1562            },
1563            (BinaryKind::Remainder, ElementKind::F16) => unsafe {
1564                baracuda_kernels_sys::baracuda_kernels_binary_remainder_f16_strided_run(
1565                    numel, rank, shape.as_ptr(),
1566                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1567                    a_ptr, b_ptr, y_ptr,
1568                    core::ptr::null_mut(), 0, stream_ptr,
1569                )
1570            },
1571            (BinaryKind::Remainder, ElementKind::Bf16) => unsafe {
1572                baracuda_kernels_sys::baracuda_kernels_binary_remainder_bf16_strided_run(
1573                    numel, rank, shape.as_ptr(),
1574                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1575                    a_ptr, b_ptr, y_ptr,
1576                    core::ptr::null_mut(), 0, stream_ptr,
1577                )
1578            },
1579            (BinaryKind::Remainder, ElementKind::F64) => unsafe {
1580                baracuda_kernels_sys::baracuda_kernels_binary_remainder_f64_strided_run(
1581                    numel, rank, shape.as_ptr(),
1582                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
1583                    a_ptr, b_ptr, y_ptr,
1584                    core::ptr::null_mut(), 0, stream_ptr,
1585                )
1586            },
1587            _ => {
1588                return Err(Error::Unsupported(
1589                    "baracuda-kernels::BinaryPlan::run_strided reached an \
1590                     unimplemented (kind, dtype) pair — select() should \
1591                     have caught this",
1592                ));
1593            }
1594        };
1595        map_status(status)
1596    }
1597}
1598
1599fn map_status(code: i32) -> Result<()> {
1600    match code {
1601        0 => Ok(()),
1602        1 => Err(Error::MisalignedOperand),
1603        2 => Err(Error::InvalidProblem(
1604            "baracuda-kernels-sys reported invalid problem",
1605        )),
1606        3 => Err(Error::Unsupported(
1607            "baracuda-kernels-sys reported unsupported configuration",
1608        )),
1609        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
1610        n => Err(Error::CutlassInternal(n)),
1611    }
1612}