1use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21 ArchSku, BackendKind, BinaryCmpKind, Element, ElementKind, KernelSku, MathPrecision,
22 OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
23};
24
/// Problem description for an elementwise comparison between two rank-`N`
/// tensors, producing a `u8` output tensor (see [`BinaryCmpArgs`]).
///
/// Consumed by [`BinaryCmpPlan::select`], which checks `element` against the
/// plan's type parameter and rejects negative `shape` entries.
#[derive(Copy, Clone, Debug)]
pub struct BinaryCmpDescriptor<const N: usize> {
    /// Which comparison to evaluate (`Eq`, `Ne`, `Gt`, `Ge`, `Lt`, `Le` are wired).
    pub kind: BinaryCmpKind,
    /// Output shape; `BinaryCmpPlan::can_implement` requires `y.shape` to equal it.
    pub shape: [i32; N],
    /// Element type of the `a`/`b` operands; must equal the plan's `T::KIND`.
    pub element: ElementKind,
}
39
/// Device operands for one [`BinaryCmpPlan::run`] call.
///
/// `a` and `b` may broadcast against the output: on any axis where their
/// extent differs from the output's, `BinaryCmpPlan::can_implement` requires
/// that axis to have extent 1 and stride 0.
pub struct BinaryCmpArgs<'a, T: Element, const N: usize> {
    /// First input operand (presumably the left-hand side of the comparison —
    /// confirm against the kernel).
    pub a: TensorRef<'a, T, N>,
    /// Second input operand.
    pub b: TensorRef<'a, T, N>,
    /// Output tensor, one `u8` per element; its shape must equal the descriptor's.
    pub y: TensorMut<'a, u8, N>,
}
53
/// A validated comparison plan produced by [`BinaryCmpPlan::select`].
///
/// Holds the descriptor it was built from and the [`KernelSku`] describing the
/// selected kernel; `T` is tracked only at the type level.
pub struct BinaryCmpPlan<T: Element, const N: usize> {
    /// Descriptor captured at `select` time; later calls check arguments against it.
    desc: BinaryCmpDescriptor<N>,
    /// Kernel identity and precision metadata returned by `sku()`.
    sku: KernelSku,
    /// Ties the plan to element type `T` without storing a value of it.
    _marker: PhantomData<T>,
}
63
64impl<T: Element, const N: usize> BinaryCmpPlan<T, N> {
65 pub fn select(
68 _stream: &Stream,
69 desc: &BinaryCmpDescriptor<N>,
70 _pref: PlanPreference,
71 ) -> Result<Self> {
72 if desc.element != T::KIND {
73 return Err(Error::Unsupported(
74 "baracuda-kernels::BinaryCmpPlan: descriptor element != type parameter T",
75 ));
76 }
77 for &d in desc.shape.iter() {
78 if d < 0 {
79 return Err(Error::InvalidProblem(
80 "baracuda-kernels::BinaryCmpPlan: shape dims must be non-negative",
81 ));
82 }
83 }
84
85 let kind_in_scope = matches!(
90 desc.kind,
91 BinaryCmpKind::Eq
92 | BinaryCmpKind::Ne
93 | BinaryCmpKind::Gt
94 | BinaryCmpKind::Ge
95 | BinaryCmpKind::Lt
96 | BinaryCmpKind::Le
97 );
98 let dtype_in_scope = matches!(
99 T::KIND,
100 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
101 );
102 let supported = kind_in_scope && dtype_in_scope;
103 if !supported {
104 return Err(Error::Unsupported(
105 "baracuda-kernels::BinaryCmpPlan: this (kind, dtype) cell is not yet \
106 wired; see the dispatcher's kind / dtype scope for the supported set",
107 ));
108 }
109
110 let precision_guarantee = PrecisionGuarantee {
116 math_precision: MathPrecision::F32,
117 accumulator: ElementKind::F32,
118 bit_stable_on_same_hardware: true,
119 deterministic: true,
120 };
121 let sku = KernelSku {
122 category: OpCategory::BinaryElementwise,
123 op: desc.kind as u16,
124 element: T::KIND,
125 aux_element: None,
129 layout: None,
130 epilogue: None,
131 arch: ArchSku::Sm80,
132 backend: BackendKind::Bespoke,
133 precision_guarantee,
134 };
135 Ok(Self {
136 desc: *desc,
137 sku,
138 _marker: PhantomData,
139 })
140 }
141
142 pub fn can_implement(&self, args: &BinaryCmpArgs<'_, T, N>) -> Result<()> {
147 if args.y.shape != self.desc.shape {
148 return Err(Error::InvalidProblem(
149 "baracuda-kernels::BinaryCmpPlan: Y shape mismatch with descriptor",
150 ));
151 }
152
153 for d in 0..N {
154 let y_dim = self.desc.shape[d];
155 let a_dim = args.a.shape[d];
156 let b_dim = args.b.shape[d];
157 if a_dim != y_dim && !(a_dim == 1 && args.a.stride[d] == 0) {
158 return Err(Error::InvalidProblem(
159 "baracuda-kernels::BinaryCmpPlan: A axis not broadcast-compatible with output",
160 ));
161 }
162 if b_dim != y_dim && !(b_dim == 1 && args.b.stride[d] == 0) {
163 return Err(Error::InvalidProblem(
164 "baracuda-kernels::BinaryCmpPlan: B axis not broadcast-compatible with output",
165 ));
166 }
167 }
168
169 if N > 8 {
170 return Err(Error::Unsupported(
171 "baracuda-kernels::BinaryCmpPlan: tensor rank > 8 not supported \
172 (kernel param block fixes MAX_RANK = 8)",
173 ));
174 }
175
176 let y_numel = args.y.numel();
177 let a_numel = args.a.numel();
178 let b_numel = args.b.numel();
179 let a_len = args.a.data.len() as i64;
180 let b_len = args.b.data.len() as i64;
181 let y_len = args.y.data.len() as i64;
182 if y_len < y_numel {
183 return Err(Error::BufferTooSmall {
184 needed: y_numel as usize,
185 got: y_len as usize,
186 });
187 }
188 if a_len < a_numel {
189 return Err(Error::BufferTooSmall {
190 needed: a_numel as usize,
191 got: a_len as usize,
192 });
193 }
194 if b_len < b_numel {
195 return Err(Error::BufferTooSmall {
196 needed: b_numel as usize,
197 got: b_len as usize,
198 });
199 }
200 Ok(())
201 }
202
203 #[inline]
205 pub fn workspace_size(&self) -> usize {
206 0
207 }
208
209 #[inline]
211 pub fn sku(&self) -> KernelSku {
212 self.sku
213 }
214
215 #[inline]
217 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
218 self.sku.precision_guarantee
219 }
220
221 pub fn run(
223 &self,
224 stream: &Stream,
225 _workspace: Workspace<'_>,
226 args: BinaryCmpArgs<'_, T, N>,
227 ) -> Result<()> {
228 self.can_implement(&args)?;
229 let numel = args.y.numel();
230 if numel == 0 {
231 return Ok(());
232 }
233 let a_ptr = args.a.data.as_raw().0 as *const c_void;
234 let b_ptr = args.b.data.as_raw().0 as *const c_void;
235 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
236 let stream_ptr = stream.as_raw() as *mut c_void;
237
238 let all_contig_same_shape = args.a.shape == args.y.shape
239 && args.b.shape == args.y.shape
240 && args.a.is_contiguous()
241 && args.b.is_contiguous()
242 && args.y.is_contiguous();
243
244 if !all_contig_same_shape {
245 return self.run_strided(stream_ptr, a_ptr, b_ptr, y_ptr, numel, &args);
246 }
247
248 let status = match (self.desc.kind, T::KIND) {
249 (BinaryCmpKind::Eq, ElementKind::F32) => unsafe {
251 baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f32_run(
252 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
253 )
254 },
255 (BinaryCmpKind::Eq, ElementKind::F16) => unsafe {
256 baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f16_run(
257 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
258 )
259 },
260 (BinaryCmpKind::Eq, ElementKind::Bf16) => unsafe {
261 baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_bf16_run(
262 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
263 )
264 },
265 (BinaryCmpKind::Eq, ElementKind::F64) => unsafe {
266 baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f64_run(
267 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
268 )
269 },
270 (BinaryCmpKind::Ne, ElementKind::F32) => unsafe {
272 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f32_run(
273 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
274 )
275 },
276 (BinaryCmpKind::Ne, ElementKind::F16) => unsafe {
277 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f16_run(
278 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
279 )
280 },
281 (BinaryCmpKind::Ne, ElementKind::Bf16) => unsafe {
282 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_bf16_run(
283 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
284 )
285 },
286 (BinaryCmpKind::Ne, ElementKind::F64) => unsafe {
287 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f64_run(
288 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
289 )
290 },
291 (BinaryCmpKind::Gt, ElementKind::F32) => unsafe {
293 baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f32_run(
294 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
295 )
296 },
297 (BinaryCmpKind::Gt, ElementKind::F16) => unsafe {
298 baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f16_run(
299 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
300 )
301 },
302 (BinaryCmpKind::Gt, ElementKind::Bf16) => unsafe {
303 baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_bf16_run(
304 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
305 )
306 },
307 (BinaryCmpKind::Gt, ElementKind::F64) => unsafe {
308 baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f64_run(
309 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
310 )
311 },
312 (BinaryCmpKind::Ge, ElementKind::F32) => unsafe {
314 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f32_run(
315 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
316 )
317 },
318 (BinaryCmpKind::Ge, ElementKind::F16) => unsafe {
319 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f16_run(
320 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
321 )
322 },
323 (BinaryCmpKind::Ge, ElementKind::Bf16) => unsafe {
324 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_bf16_run(
325 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
326 )
327 },
328 (BinaryCmpKind::Ge, ElementKind::F64) => unsafe {
329 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f64_run(
330 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
331 )
332 },
333 (BinaryCmpKind::Lt, ElementKind::F32) => unsafe {
335 baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f32_run(
336 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
337 )
338 },
339 (BinaryCmpKind::Lt, ElementKind::F16) => unsafe {
340 baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f16_run(
341 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
342 )
343 },
344 (BinaryCmpKind::Lt, ElementKind::Bf16) => unsafe {
345 baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_bf16_run(
346 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
347 )
348 },
349 (BinaryCmpKind::Lt, ElementKind::F64) => unsafe {
350 baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f64_run(
351 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
352 )
353 },
354 (BinaryCmpKind::Le, ElementKind::F32) => unsafe {
356 baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f32_run(
357 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
358 )
359 },
360 (BinaryCmpKind::Le, ElementKind::F16) => unsafe {
361 baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f16_run(
362 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
363 )
364 },
365 (BinaryCmpKind::Le, ElementKind::Bf16) => unsafe {
366 baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_bf16_run(
367 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
368 )
369 },
370 (BinaryCmpKind::Le, ElementKind::F64) => unsafe {
371 baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f64_run(
372 numel, a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
373 )
374 },
375 _ => {
376 return Err(Error::Unsupported(
377 "baracuda-kernels::BinaryCmpPlan::run reached an unimplemented \
378 (kind, dtype) pair — select() should have caught this",
379 ));
380 }
381 };
382 map_status(status)
383 }
384
385 fn run_strided(
387 &self,
388 stream_ptr: *mut c_void,
389 a_ptr: *const c_void,
390 b_ptr: *const c_void,
391 y_ptr: *mut c_void,
392 numel: i64,
393 args: &BinaryCmpArgs<'_, T, N>,
394 ) -> Result<()> {
395 let shape = args.y.shape;
396 let stride_a = args.a.stride;
397 let stride_b = args.b.stride;
398 let stride_y = args.y.stride;
399 let rank = N as i32;
400
401 let status = match (self.desc.kind, T::KIND) {
402 (BinaryCmpKind::Eq, ElementKind::F32) => unsafe {
404 baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f32_strided_run(
405 numel, rank, shape.as_ptr(),
406 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
407 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
408 )
409 },
410 (BinaryCmpKind::Eq, ElementKind::F16) => unsafe {
411 baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f16_strided_run(
412 numel, rank, shape.as_ptr(),
413 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
414 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
415 )
416 },
417 (BinaryCmpKind::Eq, ElementKind::Bf16) => unsafe {
418 baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_bf16_strided_run(
419 numel, rank, shape.as_ptr(),
420 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
421 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
422 )
423 },
424 (BinaryCmpKind::Eq, ElementKind::F64) => unsafe {
425 baracuda_kernels_sys::baracuda_kernels_binary_cmp_eq_f64_strided_run(
426 numel, rank, shape.as_ptr(),
427 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
428 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
429 )
430 },
431 (BinaryCmpKind::Ne, ElementKind::F32) => unsafe {
433 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f32_strided_run(
434 numel, rank, shape.as_ptr(),
435 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
436 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
437 )
438 },
439 (BinaryCmpKind::Ne, ElementKind::F16) => unsafe {
440 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f16_strided_run(
441 numel, rank, shape.as_ptr(),
442 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
443 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
444 )
445 },
446 (BinaryCmpKind::Ne, ElementKind::Bf16) => unsafe {
447 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_bf16_strided_run(
448 numel, rank, shape.as_ptr(),
449 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
450 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
451 )
452 },
453 (BinaryCmpKind::Ne, ElementKind::F64) => unsafe {
454 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ne_f64_strided_run(
455 numel, rank, shape.as_ptr(),
456 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
457 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
458 )
459 },
460 (BinaryCmpKind::Gt, ElementKind::F32) => unsafe {
462 baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f32_strided_run(
463 numel, rank, shape.as_ptr(),
464 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
465 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
466 )
467 },
468 (BinaryCmpKind::Gt, ElementKind::F16) => unsafe {
469 baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f16_strided_run(
470 numel, rank, shape.as_ptr(),
471 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
472 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
473 )
474 },
475 (BinaryCmpKind::Gt, ElementKind::Bf16) => unsafe {
476 baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_bf16_strided_run(
477 numel, rank, shape.as_ptr(),
478 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
479 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
480 )
481 },
482 (BinaryCmpKind::Gt, ElementKind::F64) => unsafe {
483 baracuda_kernels_sys::baracuda_kernels_binary_cmp_gt_f64_strided_run(
484 numel, rank, shape.as_ptr(),
485 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
486 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
487 )
488 },
489 (BinaryCmpKind::Ge, ElementKind::F32) => unsafe {
491 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f32_strided_run(
492 numel, rank, shape.as_ptr(),
493 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
494 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
495 )
496 },
497 (BinaryCmpKind::Ge, ElementKind::F16) => unsafe {
498 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f16_strided_run(
499 numel, rank, shape.as_ptr(),
500 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
501 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
502 )
503 },
504 (BinaryCmpKind::Ge, ElementKind::Bf16) => unsafe {
505 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_bf16_strided_run(
506 numel, rank, shape.as_ptr(),
507 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
508 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
509 )
510 },
511 (BinaryCmpKind::Ge, ElementKind::F64) => unsafe {
512 baracuda_kernels_sys::baracuda_kernels_binary_cmp_ge_f64_strided_run(
513 numel, rank, shape.as_ptr(),
514 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
515 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
516 )
517 },
518 (BinaryCmpKind::Lt, ElementKind::F32) => unsafe {
520 baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f32_strided_run(
521 numel, rank, shape.as_ptr(),
522 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
523 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
524 )
525 },
526 (BinaryCmpKind::Lt, ElementKind::F16) => unsafe {
527 baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f16_strided_run(
528 numel, rank, shape.as_ptr(),
529 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
530 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
531 )
532 },
533 (BinaryCmpKind::Lt, ElementKind::Bf16) => unsafe {
534 baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_bf16_strided_run(
535 numel, rank, shape.as_ptr(),
536 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
537 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
538 )
539 },
540 (BinaryCmpKind::Lt, ElementKind::F64) => unsafe {
541 baracuda_kernels_sys::baracuda_kernels_binary_cmp_lt_f64_strided_run(
542 numel, rank, shape.as_ptr(),
543 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
544 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
545 )
546 },
547 (BinaryCmpKind::Le, ElementKind::F32) => unsafe {
549 baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f32_strided_run(
550 numel, rank, shape.as_ptr(),
551 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
552 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
553 )
554 },
555 (BinaryCmpKind::Le, ElementKind::F16) => unsafe {
556 baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f16_strided_run(
557 numel, rank, shape.as_ptr(),
558 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
559 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
560 )
561 },
562 (BinaryCmpKind::Le, ElementKind::Bf16) => unsafe {
563 baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_bf16_strided_run(
564 numel, rank, shape.as_ptr(),
565 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
566 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
567 )
568 },
569 (BinaryCmpKind::Le, ElementKind::F64) => unsafe {
570 baracuda_kernels_sys::baracuda_kernels_binary_cmp_le_f64_strided_run(
571 numel, rank, shape.as_ptr(),
572 stride_a.as_ptr(), stride_b.as_ptr(), stride_y.as_ptr(),
573 a_ptr, b_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
574 )
575 },
576 _ => {
577 return Err(Error::Unsupported(
578 "baracuda-kernels::BinaryCmpPlan::run_strided reached an \
579 unimplemented (kind, dtype) pair — select() should have caught this",
580 ));
581 }
582 };
583 map_status(status)
584 }
585}
586
587fn map_status(code: i32) -> Result<()> {
588 match code {
589 0 => Ok(()),
590 1 => Err(Error::MisalignedOperand),
591 2 => Err(Error::InvalidProblem(
592 "baracuda-kernels-sys reported invalid problem",
593 )),
594 3 => Err(Error::Unsupported(
595 "baracuda-kernels-sys reported unsupported configuration",
596 )),
597 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
598 n => Err(Error::CutlassInternal(n)),
599 }
600}