Skip to main content

baracuda_kernels/elementwise/
binary_cmp.rs

//! Binary comparison plan.
//!
//! Sibling of [`crate::BinaryPlan`] for ops where the output dtype
//! differs from the input dtype: comparisons produce `u8` (PyTorch /
//! NumPy bool storage convention: 0 = false, 1 = true) regardless of
//! the input element type.
//!
//! Fully wired matrix: {Eq, Ne, Gt, Ge, Lt, Le} × {f32, f16, bf16,
//! f64} = 24 (kind, dtype) cells, each with both the contiguous fast
//! path and the strided / broadcast path (48 launchers total). The
//! dispatcher's supported check reduces to a straight cross product
//! `kind_in_scope && dtype_in_scope`; the match arms in
//! `run` / `run_strided` remain the authoritative dispatch table.
14
15use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21    ArchSku, BackendKind, BinaryCmpKind, Element, ElementKind, KernelSku, MathPrecision,
22    OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
23};
24
/// Descriptor for a binary comparison op.
///
/// `shape` is the OUTPUT tensor shape. `element` is the INPUT dtype —
/// the output is always `u8`. `element` must match the type parameter
/// `T` of the containing plan at `select` time.
///
/// `BinaryCmpPlan::select` validates the descriptor: a mismatch between
/// `element` and `T::KIND` is reported as `Error::Unsupported`, and any
/// negative entry in `shape` as `Error::InvalidProblem`.
#[derive(Copy, Clone, Debug)]
pub struct BinaryCmpDescriptor<const N: usize> {
    /// Which comparison op to apply.
    pub kind: BinaryCmpKind,
    /// Output tensor shape. Every dimension must be non-negative.
    pub shape: [i32; N],
    /// Input element type (output is always `u8`).
    pub element: ElementKind,
}
39
/// Args bundle for a binary comparison launch.
///
/// Inputs are `T`; output is `u8` (0 / 1). Aliasing `y` with `a` or `b`
/// is unsafe because `y` has a different element size than the inputs;
/// the kernel does NOT alias-check.
///
/// `y` must have exactly the descriptor's shape. On each axis an input
/// must either match the output dimension or be a broadcast axis, i.e.
/// have size `1` and stride `0` (see `BinaryCmpPlan::can_implement`).
pub struct BinaryCmpArgs<'a, T: Element, const N: usize> {
    /// First input.
    pub a: TensorRef<'a, T, N>,
    /// Second input.
    pub b: TensorRef<'a, T, N>,
    /// Output. `u8` storage: 0 = false, 1 = true.
    pub y: TensorMut<'a, u8, N>,
}
53
/// Binary comparison plan.
///
/// `T: Element` is the input element type; `select` accepts `T::KIND`
/// in {`F32`, `F16`, `Bf16`, `F64`} and rejects everything else.
/// Output is always `u8`. `const N: usize` is the tensor rank.
pub struct BinaryCmpPlan<T: Element, const N: usize> {
    // Copy of the descriptor that was handed to `select`.
    desc: BinaryCmpDescriptor<N>,
    // Kernel identity (including the precision guarantee) built by `select`.
    sku: KernelSku,
    // Binds the plan to the input element type without storing a `T`.
    _marker: PhantomData<T>,
}
63
64impl<T: Element, const N: usize> BinaryCmpPlan<T, N> {
65    /// Pick a kernel for `desc`. Returns [`Error::Unsupported`] if the
66    /// `(kind, T::KIND)` pair isn't wired today.
67    pub fn select(
68        _stream: &Stream,
69        desc: &BinaryCmpDescriptor<N>,
70        _pref: PlanPreference,
71    ) -> Result<Self> {
72        if desc.element != T::KIND {
73            return Err(Error::Unsupported(
74                "baracuda-kernels::BinaryCmpPlan: descriptor element != type parameter T",
75            ));
76        }
77        for &d in desc.shape.iter() {
78            if d < 0 {
79                return Err(Error::InvalidProblem(
80                    "baracuda-kernels::BinaryCmpPlan: shape dims must be non-negative",
81                ));
82            }
83        }
84
85        // Supported matrix: all 6 BinaryCmpKind variants across the
86        // 4 FP dtypes. The match arms in `run` / `run_strided` remain
87        // the authoritative dispatch table; the unreachable `_ =>` arm
88        // catches any future drift if new variants are added upstream.
89        let kind_in_scope = matches!(
90            desc.kind,
91            BinaryCmpKind::Eq
92                | BinaryCmpKind::Ne
93                | BinaryCmpKind::Gt
94                | BinaryCmpKind::Ge
95                | BinaryCmpKind::Lt
96                | BinaryCmpKind::Le
97        );
98        let dtype_in_scope = matches!(
99            T::KIND,
100            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
101        );
102        let supported = kind_in_scope && dtype_in_scope;
103        if !supported {
104            return Err(Error::Unsupported(
105                "baracuda-kernels::BinaryCmpPlan: this (kind, dtype) cell is not yet \
106                 wired; see the dispatcher's kind / dtype scope for the supported set",
107            ));
108        }
109
110        // Comparisons are bit-stable + deterministic — no math, no
111        // ordering ambiguity. Output dtype is u8 but we tag the SKU's
112        // primary `element` as the INPUT dtype (matches the input
113        // tensor's dtype, drives the kernel selection); `aux_element`
114        // captures the output dtype.
115        let precision_guarantee = PrecisionGuarantee {
116            math_precision: MathPrecision::F32,
117            accumulator: ElementKind::F32,
118            bit_stable_on_same_hardware: true,
119            deterministic: true,
120        };
121        let sku = KernelSku {
122            category: OpCategory::BinaryElementwise,
123            op: desc.kind as u16,
124            element: T::KIND,
125            // u8 output isn't an ElementKind variant today; encode as
126            // None and rely on the kind tag to disambiguate from
127            // same-dtype binary ops.
128            aux_element: None,
129            layout: None,
130            epilogue: None,
131            arch: ArchSku::Sm80,
132            backend: BackendKind::Bespoke,
133            precision_guarantee,
134        };
135        Ok(Self {
136            desc: *desc,
137            sku,
138            _marker: PhantomData,
139        })
140    }
141
142    /// Validate that this plan can launch with `args`.
143    ///
144    /// Accepts contig and strided / broadcast operands. Same broadcast
145    /// rules as [`crate::BinaryPlan::can_implement`].
146    pub fn can_implement(&self, args: &BinaryCmpArgs<'_, T, N>) -> Result<()> {
147        if args.y.shape != self.desc.shape {
148            return Err(Error::InvalidProblem(
149                "baracuda-kernels::BinaryCmpPlan: Y shape mismatch with descriptor",
150            ));
151        }
152
153        for d in 0..N {
154            let y_dim = self.desc.shape[d];
155            let a_dim = args.a.shape[d];
156            let b_dim = args.b.shape[d];
157            if a_dim != y_dim && !(a_dim == 1 && args.a.stride[d] == 0) {
158                return Err(Error::InvalidProblem(
159                    "baracuda-kernels::BinaryCmpPlan: A axis not broadcast-compatible with output",
160                ));
161            }
162            if b_dim != y_dim && !(b_dim == 1 && args.b.stride[d] == 0) {
163                return Err(Error::InvalidProblem(
164                    "baracuda-kernels::BinaryCmpPlan: B axis not broadcast-compatible with output",
165                ));
166            }
167        }
168
169        if N > 8 {
170            return Err(Error::Unsupported(
171                "baracuda-kernels::BinaryCmpPlan: tensor rank > 8 not supported \
172                 (kernel param block fixes MAX_RANK = 8)",
173            ));
174        }
175
176        let y_numel = args.y.numel();
177        let a_numel = args.a.numel();
178        let b_numel = args.b.numel();
179        let a_len = args.a.data.len() as i64;
180        let b_len = args.b.data.len() as i64;
181        let y_len = args.y.data.len() as i64;
182        if y_len < y_numel {
183            return Err(Error::BufferTooSmall {
184                needed: y_numel as usize,
185                got: y_len as usize,
186            });
187        }
188        if a_len < a_numel {
189            return Err(Error::BufferTooSmall {
190                needed: a_numel as usize,
191                got: a_len as usize,
192            });
193        }
194        if b_len < b_numel {
195            return Err(Error::BufferTooSmall {
196                needed: b_numel as usize,
197                got: b_len as usize,
198            });
199        }
200        Ok(())
201    }
202
    /// Workspace size in bytes. Always `0` for the trailblazer.
    ///
    /// No scratch buffer is required: `run` ignores its `Workspace`
    /// argument.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }
208
    /// Identity of the kernel this plan picked.
    ///
    /// Returns a copy of the [`KernelSku`] assembled by `select`.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }
214
    /// Numerical guarantees for this plan's kernel.
    ///
    /// Returns the guarantee stored in the plan's SKU; `select` marks it
    /// as bit-stable on the same hardware and deterministic.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee
    }
220
221    /// Launch.
222    pub fn run(
223        &self,
224        stream: &Stream,
225        _workspace: Workspace<'_>,
226        args: BinaryCmpArgs<'_, T, N>,
227    ) -> Result<()> {
228        self.can_implement(&args)?;
229        let numel = args.y.numel();
230        if numel == 0 {
231            return Ok(());
232        }
233        let a_ptr = args.a.data.as_raw().0 as *const c_void;
234        let b_ptr = args.b.data.as_raw().0 as *const c_void;
235        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
236        let stream_ptr = stream.as_raw() as *mut c_void;
237
238        let all_contig_same_shape = args.a.shape == args.y.shape
239            && args.b.shape == args.y.shape
240            && args.a.is_contiguous()
241            && args.b.is_contiguous()
242            && args.y.is_contiguous();
243
244        if !all_contig_same_shape {
245            return self.run_strided(stream_ptr, a_ptr, b_ptr, y_ptr, numel, &args);
246        }
247
248        let status = match (self.desc.kind, T::KIND) {
249            // --- Eq -----------------------------------------------------
250            (BinaryCmpKind::Eq, ElementKind::F32) => unsafe {
251                baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f32_run(
252                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
253                )
254            },
255            (BinaryCmpKind::Eq, ElementKind::F16) => unsafe {
256                baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f16_run(
257                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
258                )
259            },
260            (BinaryCmpKind::Eq, ElementKind::Bf16) => unsafe {
261                baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_bf16_run(
262                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
263                )
264            },
265            (BinaryCmpKind::Eq, ElementKind::F64) => unsafe {
266                baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f64_run(
267                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
268                )
269            },
270            // --- Ne -----------------------------------------------------
271            (BinaryCmpKind::Ne, ElementKind::F32) => unsafe {
272                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f32_run(
273                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
274                )
275            },
276            (BinaryCmpKind::Ne, ElementKind::F16) => unsafe {
277                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f16_run(
278                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
279                )
280            },
281            (BinaryCmpKind::Ne, ElementKind::Bf16) => unsafe {
282                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_bf16_run(
283                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
284                )
285            },
286            (BinaryCmpKind::Ne, ElementKind::F64) => unsafe {
287                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f64_run(
288                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
289                )
290            },
291            // --- Gt -----------------------------------------------------
292            (BinaryCmpKind::Gt, ElementKind::F32) => unsafe {
293                baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f32_run(
294                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
295                )
296            },
297            (BinaryCmpKind::Gt, ElementKind::F16) => unsafe {
298                baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f16_run(
299                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
300                )
301            },
302            (BinaryCmpKind::Gt, ElementKind::Bf16) => unsafe {
303                baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_bf16_run(
304                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
305                )
306            },
307            (BinaryCmpKind::Gt, ElementKind::F64) => unsafe {
308                baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f64_run(
309                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
310                )
311            },
312            // --- Ge -----------------------------------------------------
313            (BinaryCmpKind::Ge, ElementKind::F32) => unsafe {
314                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f32_run(
315                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
316                )
317            },
318            (BinaryCmpKind::Ge, ElementKind::F16) => unsafe {
319                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f16_run(
320                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
321                )
322            },
323            (BinaryCmpKind::Ge, ElementKind::Bf16) => unsafe {
324                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_bf16_run(
325                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
326                )
327            },
328            (BinaryCmpKind::Ge, ElementKind::F64) => unsafe {
329                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f64_run(
330                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
331                )
332            },
333            // --- Lt -----------------------------------------------------
334            (BinaryCmpKind::Lt, ElementKind::F32) => unsafe {
335                baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f32_run(
336                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
337                )
338            },
339            (BinaryCmpKind::Lt, ElementKind::F16) => unsafe {
340                baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f16_run(
341                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
342                )
343            },
344            (BinaryCmpKind::Lt, ElementKind::Bf16) => unsafe {
345                baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_bf16_run(
346                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
347                )
348            },
349            (BinaryCmpKind::Lt, ElementKind::F64) => unsafe {
350                baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f64_run(
351                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
352                )
353            },
354            // --- Le -----------------------------------------------------
355            (BinaryCmpKind::Le, ElementKind::F32) => unsafe {
356                baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f32_run(
357                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
358                )
359            },
360            (BinaryCmpKind::Le, ElementKind::F16) => unsafe {
361                baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f16_run(
362                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
363                )
364            },
365            (BinaryCmpKind::Le, ElementKind::Bf16) => unsafe {
366                baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_bf16_run(
367                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
368                )
369            },
370            (BinaryCmpKind::Le, ElementKind::F64) => unsafe {
371                baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f64_run(
372                    numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
373                )
374            },
375            _ => {
376                return Err(Error::Unsupported(
377                    "baracuda-kernels::BinaryCmpPlan::run reached an unimplemented \
378                     (kind, dtype) pair — select() should have caught this",
379                ));
380            }
381        };
382        map_status(status)
383    }
384
385    /// Strided / broadcast kernel path.
386    fn run_strided(
387        &self,
388        stream_ptr: *mut c_void,
389        a_ptr: *const c_void,
390        b_ptr: *const c_void,
391        y_ptr: *mut c_void,
392        numel: i64,
393        args: &BinaryCmpArgs<'_, T, N>,
394    ) -> Result<()> {
395        let shape = args.y.shape;
396        let stride_a = args.a.stride;
397        let stride_b = args.b.stride;
398        let stride_y = args.y.stride;
399        let rank = N as i32;
400
401        let status = match (self.desc.kind, T::KIND) {
402            // --- Eq -----------------------------------------------------
403            (BinaryCmpKind::Eq, ElementKind::F32) => unsafe {
404                baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f32_strided_run(
405                    numel, rank, shape.as_ptr(),
406                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
407                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
408                )
409            },
410            (BinaryCmpKind::Eq, ElementKind::F16) => unsafe {
411                baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f16_strided_run(
412                    numel, rank, shape.as_ptr(),
413                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
414                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
415                )
416            },
417            (BinaryCmpKind::Eq, ElementKind::Bf16) => unsafe {
418                baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_bf16_strided_run(
419                    numel, rank, shape.as_ptr(),
420                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
421                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
422                )
423            },
424            (BinaryCmpKind::Eq, ElementKind::F64) => unsafe {
425                baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f64_strided_run(
426                    numel, rank, shape.as_ptr(),
427                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
428                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
429                )
430            },
431            // --- Ne -----------------------------------------------------
432            (BinaryCmpKind::Ne, ElementKind::F32) => unsafe {
433                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f32_strided_run(
434                    numel, rank, shape.as_ptr(),
435                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
436                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
437                )
438            },
439            (BinaryCmpKind::Ne, ElementKind::F16) => unsafe {
440                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f16_strided_run(
441                    numel, rank, shape.as_ptr(),
442                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
443                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
444                )
445            },
446            (BinaryCmpKind::Ne, ElementKind::Bf16) => unsafe {
447                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_bf16_strided_run(
448                    numel, rank, shape.as_ptr(),
449                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
450                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
451                )
452            },
453            (BinaryCmpKind::Ne, ElementKind::F64) => unsafe {
454                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f64_strided_run(
455                    numel, rank, shape.as_ptr(),
456                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
457                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
458                )
459            },
460            // --- Gt -----------------------------------------------------
461            (BinaryCmpKind::Gt, ElementKind::F32) => unsafe {
462                baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f32_strided_run(
463                    numel, rank, shape.as_ptr(),
464                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
465                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
466                )
467            },
468            (BinaryCmpKind::Gt, ElementKind::F16) => unsafe {
469                baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f16_strided_run(
470                    numel, rank, shape.as_ptr(),
471                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
472                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
473                )
474            },
475            (BinaryCmpKind::Gt, ElementKind::Bf16) => unsafe {
476                baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_bf16_strided_run(
477                    numel, rank, shape.as_ptr(),
478                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
479                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
480                )
481            },
482            (BinaryCmpKind::Gt, ElementKind::F64) => unsafe {
483                baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f64_strided_run(
484                    numel, rank, shape.as_ptr(),
485                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
486                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
487                )
488            },
489            // --- Ge -----------------------------------------------------
490            (BinaryCmpKind::Ge, ElementKind::F32) => unsafe {
491                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f32_strided_run(
492                    numel, rank, shape.as_ptr(),
493                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
494                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
495                )
496            },
497            (BinaryCmpKind::Ge, ElementKind::F16) => unsafe {
498                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f16_strided_run(
499                    numel, rank, shape.as_ptr(),
500                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
501                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
502                )
503            },
504            (BinaryCmpKind::Ge, ElementKind::Bf16) => unsafe {
505                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_bf16_strided_run(
506                    numel, rank, shape.as_ptr(),
507                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
508                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
509                )
510            },
511            (BinaryCmpKind::Ge, ElementKind::F64) => unsafe {
512                baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f64_strided_run(
513                    numel, rank, shape.as_ptr(),
514                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
515                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
516                )
517            },
518            // --- Lt -----------------------------------------------------
519            (BinaryCmpKind::Lt, ElementKind::F32) => unsafe {
520                baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f32_strided_run(
521                    numel, rank, shape.as_ptr(),
522                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
523                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
524                )
525            },
526            (BinaryCmpKind::Lt, ElementKind::F16) => unsafe {
527                baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f16_strided_run(
528                    numel, rank, shape.as_ptr(),
529                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
530                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
531                )
532            },
533            (BinaryCmpKind::Lt, ElementKind::Bf16) => unsafe {
534                baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_bf16_strided_run(
535                    numel, rank, shape.as_ptr(),
536                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
537                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
538                )
539            },
540            (BinaryCmpKind::Lt, ElementKind::F64) => unsafe {
541                baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f64_strided_run(
542                    numel, rank, shape.as_ptr(),
543                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
544                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
545                )
546            },
547            // --- Le -----------------------------------------------------
548            (BinaryCmpKind::Le, ElementKind::F32) => unsafe {
549                baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f32_strided_run(
550                    numel, rank, shape.as_ptr(),
551                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
552                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
553                )
554            },
555            (BinaryCmpKind::Le, ElementKind::F16) => unsafe {
556                baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f16_strided_run(
557                    numel, rank, shape.as_ptr(),
558                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
559                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
560                )
561            },
562            (BinaryCmpKind::Le, ElementKind::Bf16) => unsafe {
563                baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_bf16_strided_run(
564                    numel, rank, shape.as_ptr(),
565                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
566                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
567                )
568            },
569            (BinaryCmpKind::Le, ElementKind::F64) => unsafe {
570                baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f64_strided_run(
571                    numel, rank, shape.as_ptr(),
572                    stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
573                    a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
574                )
575            },
576            _ => {
577                return Err(Error::Unsupported(
578                    "baracuda-kernels::BinaryCmpPlan::run_strided reached an \
579                     unimplemented (kind, dtype) pair — select() should have caught this",
580                ));
581            }
582        };
583        map_status(status)
584    }
585}
586
587fn map_status(code: i32) -> Result<()> {
588    match code {
589        0 => Ok(()),
590        1 => Err(Error::MisalignedOperand),
591        2 => Err(Error::InvalidProblem(
592            "baracuda-kernels-sys reported invalid problem",
593        )),
594        3 => Err(Error::Unsupported(
595            "baracuda-kernels-sys reported unsupported configuration",
596        )),
597        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
598        n => Err(Error::CutlassInternal(n)),
599    }
600}