Skip to main content

baracuda_kernels/elementwise/
ternary.rs

1//! Ternary elementwise plan.
2//!
3//! 3→1 sibling of [`crate::BinaryPlan`]. Same-dtype-input, same-dtype-
4//! output ops with three inputs (a, b, c) and one output (y).
5//!
6//! Wired matrix: {[`TernaryKind::Clamp`], [`TernaryKind::Fma`],
7//! [`TernaryKind::Addcmul`], [`TernaryKind::Addcdiv`]} × {f32, f16,
8//! bf16, f64} = 16 (kind, dtype) cells, each with both the contig fast
9//! path and the strided / broadcast path (32 launchers total).
10//!
11//! Addcmul / Addcdiv read [`TernaryDescriptor::scale`] (PyTorch's
12//! `value` parameter); Clamp / Fma ignore it. The dispatcher routes
13//! parameterized ops through a separate FFI family that threads the
14//! `scale` parameter through to the kernel.
15//!
16//! Reserved-but-deferred: [`TernaryKind::Where`] needs a
17//! heterogeneous-dtype plan shape (its bool cond input is dtype `u8`,
18//! not `T`) — see [`crate::WherePlan`].
19
20use core::ffi::c_void;
21use core::marker::PhantomData;
22
23use baracuda_cutlass::{Error, Result};
24use baracuda_driver::Stream;
25use baracuda_kernels_types::{
26    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
27    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, TernaryKind, Workspace,
28};
29
/// Descriptor for a ternary elementwise op.
///
/// Carries what [`TernaryPlan::select`] inspects when picking a kernel:
/// the op kind, the output shape, and the element type shared by all
/// four operands.
///
/// `scale` is used by parameterized ops (`Addcmul`, `Addcdiv`) — set to
/// the `value` multiplier from PyTorch's `torch.addcmul(c, a, b,
/// value=k)` convention. Ignored by unparameterized ops (`Clamp`,
/// `Fma`); pass `1.0` for those. No `Default` impl is derived here, so
/// callers always spell `scale` out explicitly.
#[derive(Copy, Clone, Debug)]
pub struct TernaryDescriptor<const N: usize> {
    /// Which ternary op to apply.
    pub kind: TernaryKind,
    /// Output tensor shape. Every dim must be non-negative;
    /// [`TernaryPlan::select`] rejects negative entries.
    pub shape: [i32; N],
    /// Element type (shared across a, b, c, y). Must equal `T::KIND` of
    /// the [`TernaryPlan<T, N>`] the descriptor is used with.
    pub element: ElementKind,
    /// Scalar multiplier for parameterized ops (`Addcmul`, `Addcdiv`).
    /// Unused by `Clamp` / `Fma` — pass `1.0` for those.
    pub scale: f32,
}
48
/// Args bundle for a ternary elementwise launch.
///
/// All four operands share dtype `T`. Each input may be broadcast to
/// `y.shape` via stride-0 axes (typical use case for `clamp(x, lo, hi)`
/// where `lo` and `hi` are scalars: pass them as rank-N tensors with
/// `shape[d] = 1` and `stride[d] = 0` on every axis).
///
/// [`TernaryPlan::can_implement`] enforces the per-axis rule: an input's
/// `shape[d]` must equal `y.shape[d]`, or be `1` with `stride[d] == 0`.
///
/// Only the `clamp` operand roles are spelled out below; how `a`, `b`
/// and `c` map onto `Fma` / `Addcmul` / `Addcdiv` is decided by the
/// kernels and is not visible in this module — TODO confirm against the
/// kernel sources before relying on a particular order.
pub struct TernaryArgs<'a, T: Element, const N: usize> {
    /// First input. For `clamp`, this is `x` (the value to clamp).
    pub a: TensorRef<'a, T, N>,
    /// Second input. For `clamp`, this is `lo` (the lower bound).
    pub b: TensorRef<'a, T, N>,
    /// Third input. For `clamp`, this is `hi` (the upper bound).
    pub c: TensorRef<'a, T, N>,
    /// Output. Its shape must equal the plan descriptor's `shape`.
    pub y: TensorMut<'a, T, N>,
}
65
/// Ternary elementwise plan.
///
/// `T: Element` is the kernel's element type; [`TernaryPlan::select`]
/// only succeeds when `T::KIND` is `F32`, `F16`, `Bf16` or `F64`.
/// `const N: usize` is the tensor rank (launches reject `N > 8`).
pub struct TernaryPlan<T: Element, const N: usize> {
    /// Descriptor the plan was selected for (copied at `select` time).
    desc: TernaryDescriptor<N>,
    /// Identity of the chosen kernel, including its precision guarantee.
    sku: KernelSku,
    /// Ties the plan to `T` without storing a value of that type.
    _marker: PhantomData<T>,
}
75
76impl<T: Element, const N: usize> TernaryPlan<T, N> {
77    /// Pick a kernel for `desc`. Returns [`Error::Unsupported`] if the
78    /// `(kind, T::KIND)` pair isn't wired today.
79    pub fn select(
80        _stream: &Stream,
81        desc: &TernaryDescriptor<N>,
82        _pref: PlanPreference,
83    ) -> Result<Self> {
84        if desc.element != T::KIND {
85            return Err(Error::Unsupported(
86                "baracuda-kernels::TernaryPlan: descriptor element != type parameter T",
87            ));
88        }
89        for &d in desc.shape.iter() {
90            if d < 0 {
91                return Err(Error::InvalidProblem(
92                    "baracuda-kernels::TernaryPlan: shape dims must be non-negative",
93                ));
94            }
95        }
96
97        // Wired matrix: {Clamp, Fma, Addcmul, Addcdiv} × {f32, f16,
98        // bf16, f64}. Addcmul / Addcdiv read `desc.scale`; Clamp / Fma
99        // ignore it. Where stays reserved-but-deferred — it requires a
100        // separate heterogeneous-dtype plan ([`crate::WherePlan`]).
101        let kind_in_scope = matches!(
102            desc.kind,
103            TernaryKind::Clamp | TernaryKind::Fma | TernaryKind::Addcmul | TernaryKind::Addcdiv
104        );
105        let dtype_in_scope = matches!(
106            T::KIND,
107            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
108        );
109        let supported = kind_in_scope && dtype_in_scope;
110        if !supported {
111            return Err(Error::Unsupported(
112                "baracuda-kernels::TernaryPlan: this (kind, dtype) cell is not yet \
113                 wired; see the dispatcher's kind / dtype scope for the supported set. \
114                 Note: `Where` requires a separate heterogeneous-dtype plan \
115                 (`crate::WherePlan`).",
116            ));
117        }
118
119        let precision_guarantee = PrecisionGuarantee {
120            math_precision: MathPrecision::F32,
121            accumulator: ElementKind::F32,
122            bit_stable_on_same_hardware: true,
123            deterministic: true,
124        };
125        let sku = KernelSku {
126            category: OpCategory::TernaryElementwise,
127            op: desc.kind as u16,
128            element: T::KIND,
129            aux_element: None,
130            layout: None,
131            epilogue: None,
132            arch: ArchSku::Sm80,
133            backend: BackendKind::Bespoke,
134            precision_guarantee,
135        };
136        Ok(Self {
137            desc: *desc,
138            sku,
139            _marker: PhantomData,
140        })
141    }
142
143    /// Validate that this plan can launch with `args`.
144    ///
145    /// Per-axis broadcast compatibility: each input's `shape[d]` must
146    /// match `y.shape[d]` or be 1 with `stride[d] == 0`.
147    pub fn can_implement(&self, args: &TernaryArgs<'_, T, N>) -> Result<()> {
148        if args.y.shape != self.desc.shape {
149            return Err(Error::InvalidProblem(
150                "baracuda-kernels::TernaryPlan: Y shape mismatch with descriptor",
151            ));
152        }
153
154        for d in 0..N {
155            let y_dim = self.desc.shape[d];
156            for (name, (op_dim, op_stride)) in [
157                ("A", (args.a.shape[d], args.a.stride[d])),
158                ("B", (args.b.shape[d], args.b.stride[d])),
159                ("C", (args.c.shape[d], args.c.stride[d])),
160            ] {
161                if op_dim != y_dim && !(op_dim == 1 && op_stride == 0) {
162                    let _ = name; // Error variant takes a static string;
163                    // log the bad operand via a single shared message.
164                    return Err(Error::InvalidProblem(
165                        "baracuda-kernels::TernaryPlan: input axis not broadcast-compatible \
166                         with output (require shape[d] == y.shape[d], OR \
167                         shape[d] == 1 AND stride[d] == 0)",
168                    ));
169                }
170            }
171        }
172
173        if N > 8 {
174            return Err(Error::Unsupported(
175                "baracuda-kernels::TernaryPlan: tensor rank > 8 not supported",
176            ));
177        }
178
179        let y_numel = args.y.numel();
180        let a_numel = args.a.numel();
181        let b_numel = args.b.numel();
182        let c_numel = args.c.numel();
183        let a_len = args.a.data.len() as i64;
184        let b_len = args.b.data.len() as i64;
185        let c_len = args.c.data.len() as i64;
186        let y_len = args.y.data.len() as i64;
187        if y_len < y_numel {
188            return Err(Error::BufferTooSmall {
189                needed: y_numel as usize,
190                got: y_len as usize,
191            });
192        }
193        if a_len < a_numel {
194            return Err(Error::BufferTooSmall {
195                needed: a_numel as usize,
196                got: a_len as usize,
197            });
198        }
199        if b_len < b_numel {
200            return Err(Error::BufferTooSmall {
201                needed: b_numel as usize,
202                got: b_len as usize,
203            });
204        }
205        if c_len < c_numel {
206            return Err(Error::BufferTooSmall {
207                needed: c_numel as usize,
208                got: c_len as usize,
209            });
210        }
211        Ok(())
212    }
213
    /// Workspace size in bytes. Always `0`: `run` never reads its
    /// workspace argument, so callers need not allocate scratch memory.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }
219
    /// Identity of the kernel this plan picked, returned as a copy of the
    /// [`KernelSku`] built in `select`.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }
225
    /// Numerical guarantees for this plan's kernel — the
    /// `precision_guarantee` field of the SKU recorded in `select`.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee
    }
231
232    /// Launch.
233    pub fn run(
234        &self,
235        stream: &Stream,
236        _workspace: Workspace<'_>,
237        args: TernaryArgs<'_, T, N>,
238    ) -> Result<()> {
239        self.can_implement(&args)?;
240        let numel = args.y.numel();
241        if numel == 0 {
242            return Ok(());
243        }
244        let a_ptr = args.a.data.as_raw().0 as *const c_void;
245        let b_ptr = args.b.data.as_raw().0 as *const c_void;
246        let c_ptr = args.c.data.as_raw().0 as *const c_void;
247        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
248        let stream_ptr = stream.as_raw() as *mut c_void;
249
250        let all_contig_same_shape = args.a.shape == args.y.shape
251            && args.b.shape == args.y.shape
252            && args.c.shape == args.y.shape
253            && args.a.is_contiguous()
254            && args.b.is_contiguous()
255            && args.c.is_contiguous()
256            && args.y.is_contiguous();
257
258        if !all_contig_same_shape {
259            return self.run_strided(stream_ptr, a_ptr, b_ptr, c_ptr, y_ptr, numel, &args);
260        }
261
262        let status = match (self.desc.kind, T::KIND) {
263            // --- Clamp --------------------------------------------------
264            (TernaryKind::Clamp, ElementKind::F32) => unsafe {
265                baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f32_run(
266                    numel, a_ptr, b_ptr, c_ptr, y_ptr,
267                    core::ptr::null_mut(), 0, stream_ptr,
268                )
269            },
270            (TernaryKind::Clamp, ElementKind::F16) => unsafe {
271                baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f16_run(
272                    numel, a_ptr, b_ptr, c_ptr, y_ptr,
273                    core::ptr::null_mut(), 0, stream_ptr,
274                )
275            },
276            (TernaryKind::Clamp, ElementKind::Bf16) => unsafe {
277                baracuda_kernels_sys::baracuda_kernels_ternary_clamp_bf16_run(
278                    numel, a_ptr, b_ptr, c_ptr, y_ptr,
279                    core::ptr::null_mut(), 0, stream_ptr,
280                )
281            },
282            (TernaryKind::Clamp, ElementKind::F64) => unsafe {
283                baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f64_run(
284                    numel, a_ptr, b_ptr, c_ptr, y_ptr,
285                    core::ptr::null_mut(), 0, stream_ptr,
286                )
287            },
288            // --- Fma ----------------------------------------------------
289            (TernaryKind::Fma, ElementKind::F32) => unsafe {
290                baracuda_kernels_sys::baracuda_kernels_ternary_fma_f32_run(
291                    numel, a_ptr, b_ptr, c_ptr, y_ptr,
292                    core::ptr::null_mut(), 0, stream_ptr,
293                )
294            },
295            (TernaryKind::Fma, ElementKind::F16) => unsafe {
296                baracuda_kernels_sys::baracuda_kernels_ternary_fma_f16_run(
297                    numel, a_ptr, b_ptr, c_ptr, y_ptr,
298                    core::ptr::null_mut(), 0, stream_ptr,
299                )
300            },
301            (TernaryKind::Fma, ElementKind::Bf16) => unsafe {
302                baracuda_kernels_sys::baracuda_kernels_ternary_fma_bf16_run(
303                    numel, a_ptr, b_ptr, c_ptr, y_ptr,
304                    core::ptr::null_mut(), 0, stream_ptr,
305                )
306            },
307            (TernaryKind::Fma, ElementKind::F64) => unsafe {
308                baracuda_kernels_sys::baracuda_kernels_ternary_fma_f64_run(
309                    numel, a_ptr, b_ptr, c_ptr, y_ptr,
310                    core::ptr::null_mut(), 0, stream_ptr,
311                )
312            },
313            // --- Addcmul (reads desc.scale) ------------------------------
314            (TernaryKind::Addcmul, ElementKind::F32) => unsafe {
315                baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f32_run(
316                    numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
317                    core::ptr::null_mut(), 0, stream_ptr,
318                )
319            },
320            (TernaryKind::Addcmul, ElementKind::F16) => unsafe {
321                baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f16_run(
322                    numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
323                    core::ptr::null_mut(), 0, stream_ptr,
324                )
325            },
326            (TernaryKind::Addcmul, ElementKind::Bf16) => unsafe {
327                baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_bf16_run(
328                    numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
329                    core::ptr::null_mut(), 0, stream_ptr,
330                )
331            },
332            (TernaryKind::Addcmul, ElementKind::F64) => unsafe {
333                baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f64_run(
334                    numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
335                    core::ptr::null_mut(), 0, stream_ptr,
336                )
337            },
338            // --- Addcdiv (reads desc.scale) ------------------------------
339            (TernaryKind::Addcdiv, ElementKind::F32) => unsafe {
340                baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f32_run(
341                    numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
342                    core::ptr::null_mut(), 0, stream_ptr,
343                )
344            },
345            (TernaryKind::Addcdiv, ElementKind::F16) => unsafe {
346                baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f16_run(
347                    numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
348                    core::ptr::null_mut(), 0, stream_ptr,
349                )
350            },
351            (TernaryKind::Addcdiv, ElementKind::Bf16) => unsafe {
352                baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_bf16_run(
353                    numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
354                    core::ptr::null_mut(), 0, stream_ptr,
355                )
356            },
357            (TernaryKind::Addcdiv, ElementKind::F64) => unsafe {
358                baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f64_run(
359                    numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
360                    core::ptr::null_mut(), 0, stream_ptr,
361                )
362            },
363            _ => {
364                return Err(Error::Unsupported(
365                    "baracuda-kernels::TernaryPlan::run reached an unimplemented \
366                     (kind, dtype) — select() should have caught this",
367                ));
368            }
369        };
370        map_status(status)
371    }
372
373    /// Strided / broadcast kernel path.
374    fn run_strided(
375        &self,
376        stream_ptr: *mut c_void,
377        a_ptr: *const c_void,
378        b_ptr: *const c_void,
379        c_ptr: *const c_void,
380        y_ptr: *mut c_void,
381        numel: i64,
382        args: &TernaryArgs<'_, T, N>,
383    ) -> Result<()> {
384        let shape = args.y.shape;
385        let stride_a = args.a.stride;
386        let stride_b = args.b.stride;
387        let stride_c = args.c.stride;
388        let stride_y = args.y.stride;
389        let rank = N as i32;
390
391        let status = match (self.desc.kind, T::KIND) {
392            // --- Clamp --------------------------------------------------
393            (TernaryKind::Clamp, ElementKind::F32) => unsafe {
394                baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f32_strided_run(
395                    numel, rank, shape.as_ptr(),
396                    stride_a.as_ptr(), stride_b.as_ptr(),
397                    stride_c.as_ptr(), stride_y.as_ptr(),
398                    a_ptr, b_ptr, c_ptr, y_ptr,
399                    core::ptr::null_mut(), 0, stream_ptr,
400                )
401            },
402            (TernaryKind::Clamp, ElementKind::F16) => unsafe {
403                baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f16_strided_run(
404                    numel, rank, shape.as_ptr(),
405                    stride_a.as_ptr(), stride_b.as_ptr(),
406                    stride_c.as_ptr(), stride_y.as_ptr(),
407                    a_ptr, b_ptr, c_ptr, y_ptr,
408                    core::ptr::null_mut(), 0, stream_ptr,
409                )
410            },
411            (TernaryKind::Clamp, ElementKind::Bf16) => unsafe {
412                baracuda_kernels_sys::baracuda_kernels_ternary_clamp_bf16_strided_run(
413                    numel, rank, shape.as_ptr(),
414                    stride_a.as_ptr(), stride_b.as_ptr(),
415                    stride_c.as_ptr(), stride_y.as_ptr(),
416                    a_ptr, b_ptr, c_ptr, y_ptr,
417                    core::ptr::null_mut(), 0, stream_ptr,
418                )
419            },
420            (TernaryKind::Clamp, ElementKind::F64) => unsafe {
421                baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f64_strided_run(
422                    numel, rank, shape.as_ptr(),
423                    stride_a.as_ptr(), stride_b.as_ptr(),
424                    stride_c.as_ptr(), stride_y.as_ptr(),
425                    a_ptr, b_ptr, c_ptr, y_ptr,
426                    core::ptr::null_mut(), 0, stream_ptr,
427                )
428            },
429            // --- Fma ----------------------------------------------------
430            (TernaryKind::Fma, ElementKind::F32) => unsafe {
431                baracuda_kernels_sys::baracuda_kernels_ternary_fma_f32_strided_run(
432                    numel, rank, shape.as_ptr(),
433                    stride_a.as_ptr(), stride_b.as_ptr(),
434                    stride_c.as_ptr(), stride_y.as_ptr(),
435                    a_ptr, b_ptr, c_ptr, y_ptr,
436                    core::ptr::null_mut(), 0, stream_ptr,
437                )
438            },
439            (TernaryKind::Fma, ElementKind::F16) => unsafe {
440                baracuda_kernels_sys::baracuda_kernels_ternary_fma_f16_strided_run(
441                    numel, rank, shape.as_ptr(),
442                    stride_a.as_ptr(), stride_b.as_ptr(),
443                    stride_c.as_ptr(), stride_y.as_ptr(),
444                    a_ptr, b_ptr, c_ptr, y_ptr,
445                    core::ptr::null_mut(), 0, stream_ptr,
446                )
447            },
448            (TernaryKind::Fma, ElementKind::Bf16) => unsafe {
449                baracuda_kernels_sys::baracuda_kernels_ternary_fma_bf16_strided_run(
450                    numel, rank, shape.as_ptr(),
451                    stride_a.as_ptr(), stride_b.as_ptr(),
452                    stride_c.as_ptr(), stride_y.as_ptr(),
453                    a_ptr, b_ptr, c_ptr, y_ptr,
454                    core::ptr::null_mut(), 0, stream_ptr,
455                )
456            },
457            (TernaryKind::Fma, ElementKind::F64) => unsafe {
458                baracuda_kernels_sys::baracuda_kernels_ternary_fma_f64_strided_run(
459                    numel, rank, shape.as_ptr(),
460                    stride_a.as_ptr(), stride_b.as_ptr(),
461                    stride_c.as_ptr(), stride_y.as_ptr(),
462                    a_ptr, b_ptr, c_ptr, y_ptr,
463                    core::ptr::null_mut(), 0, stream_ptr,
464                )
465            },
466            // --- Addcmul (reads desc.scale) ------------------------------
467            (TernaryKind::Addcmul, ElementKind::F32) => unsafe {
468                baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f32_strided_run(
469                    numel, rank, shape.as_ptr(),
470                    stride_a.as_ptr(), stride_b.as_ptr(),
471                    stride_c.as_ptr(), stride_y.as_ptr(),
472                    a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
473                    core::ptr::null_mut(), 0, stream_ptr,
474                )
475            },
476            (TernaryKind::Addcmul, ElementKind::F16) => unsafe {
477                baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f16_strided_run(
478                    numel, rank, shape.as_ptr(),
479                    stride_a.as_ptr(), stride_b.as_ptr(),
480                    stride_c.as_ptr(), stride_y.as_ptr(),
481                    a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
482                    core::ptr::null_mut(), 0, stream_ptr,
483                )
484            },
485            (TernaryKind::Addcmul, ElementKind::Bf16) => unsafe {
486                baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_bf16_strided_run(
487                    numel, rank, shape.as_ptr(),
488                    stride_a.as_ptr(), stride_b.as_ptr(),
489                    stride_c.as_ptr(), stride_y.as_ptr(),
490                    a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
491                    core::ptr::null_mut(), 0, stream_ptr,
492                )
493            },
494            (TernaryKind::Addcmul, ElementKind::F64) => unsafe {
495                baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f64_strided_run(
496                    numel, rank, shape.as_ptr(),
497                    stride_a.as_ptr(), stride_b.as_ptr(),
498                    stride_c.as_ptr(), stride_y.as_ptr(),
499                    a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
500                    core::ptr::null_mut(), 0, stream_ptr,
501                )
502            },
503            // --- Addcdiv (reads desc.scale) ------------------------------
504            (TernaryKind::Addcdiv, ElementKind::F32) => unsafe {
505                baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f32_strided_run(
506                    numel, rank, shape.as_ptr(),
507                    stride_a.as_ptr(), stride_b.as_ptr(),
508                    stride_c.as_ptr(), stride_y.as_ptr(),
509                    a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
510                    core::ptr::null_mut(), 0, stream_ptr,
511                )
512            },
513            (TernaryKind::Addcdiv, ElementKind::F16) => unsafe {
514                baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f16_strided_run(
515                    numel, rank, shape.as_ptr(),
516                    stride_a.as_ptr(), stride_b.as_ptr(),
517                    stride_c.as_ptr(), stride_y.as_ptr(),
518                    a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
519                    core::ptr::null_mut(), 0, stream_ptr,
520                )
521            },
522            (TernaryKind::Addcdiv, ElementKind::Bf16) => unsafe {
523                baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_bf16_strided_run(
524                    numel, rank, shape.as_ptr(),
525                    stride_a.as_ptr(), stride_b.as_ptr(),
526                    stride_c.as_ptr(), stride_y.as_ptr(),
527                    a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
528                    core::ptr::null_mut(), 0, stream_ptr,
529                )
530            },
531            (TernaryKind::Addcdiv, ElementKind::F64) => unsafe {
532                baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f64_strided_run(
533                    numel, rank, shape.as_ptr(),
534                    stride_a.as_ptr(), stride_b.as_ptr(),
535                    stride_c.as_ptr(), stride_y.as_ptr(),
536                    a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
537                    core::ptr::null_mut(), 0, stream_ptr,
538                )
539            },
540            _ => {
541                return Err(Error::Unsupported(
542                    "baracuda-kernels::TernaryPlan::run_strided reached an \
543                     unimplemented (kind, dtype) pair — select() should have caught this",
544                ));
545            }
546        };
547        map_status(status)
548    }
549}
550
551fn map_status(code: i32) -> Result<()> {
552    match code {
553        0 => Ok(()),
554        1 => Err(Error::MisalignedOperand),
555        2 => Err(Error::InvalidProblem(
556            "baracuda-kernels-sys reported invalid problem",
557        )),
558        3 => Err(Error::Unsupported(
559            "baracuda-kernels-sys reported unsupported configuration",
560        )),
561        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
562        n => Err(Error::CutlassInternal(n)),
563    }
564}