Skip to main content

baracuda_kernels/elementwise/
where_op.rs

//! Heterogeneous-dtype ternary plan: `where(cond, a, b)`.
//!
//! Distinct from [`crate::TernaryPlan`] because the `cond` input has a
//! different dtype (`u8`, the PyTorch / NumPy bool storage convention)
//! from the value inputs and the output. Computes `y = cond ? a : b`
//! elementwise, with full broadcast support on every input, including
//! `cond`.
//!
//! All four floating-point value dtypes are wired: {f32, f16, bf16, f64}
//! × {contiguous, strided}. The op does no arithmetic (it is pure element
//! selection), so the output is bit-exact against a host reference
//! regardless of dtype.
//!
//! Module name: `where_op` (rather than `where`) because `where` is a
//! Rust keyword.
14
15use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
22    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
23};
24
25/// Descriptor for a `where` op.
26///
27/// `shape` is the output tensor shape. `element` is the **value** dtype
28/// — cond is always `u8`. `element` must match the type parameter `T`
29/// of the containing plan at `select` time.
30///
31/// No `kind` field because `where` is the only heterogeneous-dtype
32/// ternary op in this Plan today. If future ops join (e.g., a
33/// `masked_fill` variant), they get their own plan or a kind enum gets
34/// introduced — choice deferred until that lands.
35#[derive(Copy, Clone, Debug)]
36pub struct WhereDescriptor<const N: usize> {
37    /// Output tensor shape.
38    pub shape: [i32; N],
39    /// Value element type (a / b / y dtype; cond is always `u8`).
40    pub element: ElementKind,
41}
42
43/// Args bundle for a `where` launch.
44///
45/// `cond` is `u8` (0 = false, non-zero = true). `a`, `b`, `y` share
46/// dtype `T`. All four operands can broadcast independently to
47/// `y.shape` via stride-0 axes.
48pub struct WhereArgs<'a, T: Element, const N: usize> {
49    /// Boolean mask. `0u8` selects `b`, any other value selects `a`.
50    pub cond: TensorRef<'a, u8, N>,
51    /// Value selected where `cond != 0`.
52    pub a: TensorRef<'a, T, N>,
53    /// Value selected where `cond == 0`.
54    pub b: TensorRef<'a, T, N>,
55    /// Output.
56    pub y: TensorMut<'a, T, N>,
57}
58
59/// `where(cond, a, b)` plan with heterogeneous-dtype inputs.
60///
61/// `T: Element` is the value dtype (a / b / y). The cond is always `u8`.
62/// `const N: usize` is the tensor rank.
63pub struct WherePlan<T: Element, const N: usize> {
64    desc: WhereDescriptor<N>,
65    sku: KernelSku,
66    _marker: PhantomData<T>,
67}
68
69impl<T: Element, const N: usize> WherePlan<T, N> {
70    /// Pick a kernel for `desc`. Returns [`Error::Unsupported`] if the
71    /// value dtype isn't wired today.
72    pub fn select(
73        _stream: &Stream,
74        desc: &WhereDescriptor<N>,
75        _pref: PlanPreference,
76    ) -> Result<Self> {
77        if desc.element != T::KIND {
78            return Err(Error::Unsupported(
79                "baracuda-kernels::WherePlan: descriptor element != type parameter T",
80            ));
81        }
82        for &d in desc.shape.iter() {
83            if d < 0 {
84                return Err(Error::InvalidProblem(
85                    "baracuda-kernels::WherePlan: shape dims must be non-negative",
86                ));
87            }
88        }
89
90        // All 4 FP value dtypes wired.
91        let supported = matches!(
92            T::KIND,
93            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
94        );
95        if !supported {
96            return Err(Error::Unsupported(
97                "baracuda-kernels::WherePlan: value dtype must be one of \
98                 {F32, F16, Bf16, F64}",
99            ));
100        }
101
102        // `where` is a pure select — no math, fully deterministic and
103        // bit-stable on the same hardware. The MathPrecision tag
104        // mirrors the value dtype by convention even though no
105        // arithmetic happens.
106        let (math_precision, accumulator) = match T::KIND {
107            ElementKind::F16 => (MathPrecision::F16, ElementKind::F16),
108            ElementKind::Bf16 => (MathPrecision::Bf16, ElementKind::Bf16),
109            ElementKind::F64 => (MathPrecision::F64, ElementKind::F64),
110            _ => (MathPrecision::F32, ElementKind::F32),
111        };
112        let precision_guarantee = PrecisionGuarantee {
113            math_precision,
114            accumulator,
115            bit_stable_on_same_hardware: true,
116            deterministic: true,
117        };
118        let sku = KernelSku {
119            category: OpCategory::TernaryElementwise,
120            // `op` discriminant: `TernaryKind::Where` lives in the
121            // shared op enum but is intentionally not routed via
122            // `TernaryPlan` — we tag the SKU with its discriminant
123            // value (4) so telemetry / autotuner-cache keys
124            // distinguish this from same-dtype ternary ops.
125            op: 4,
126            element: T::KIND,
127            // `aux_element` captures cond's dtype — but ElementKind
128            // doesn't have a `U8` variant today, so leave None and
129            // rely on the `Where`-specific op discriminant.
130            aux_element: None,
131            layout: None,
132            epilogue: None,
133            arch: ArchSku::Sm80,
134            backend: BackendKind::Bespoke,
135            precision_guarantee,
136        };
137        Ok(Self {
138            desc: *desc,
139            sku,
140            _marker: PhantomData,
141        })
142    }
143
144    /// Validate that this plan can launch with `args`.
145    pub fn can_implement(&self, args: &WhereArgs<'_, T, N>) -> Result<()> {
146        if args.y.shape != self.desc.shape {
147            return Err(Error::InvalidProblem(
148                "baracuda-kernels::WherePlan: Y shape mismatch with descriptor",
149            ));
150        }
151
152        // Per-axis broadcast compatibility check for all four operands.
153        for d in 0..N {
154            let y_dim = self.desc.shape[d];
155            let checks = [
156                (args.cond.shape[d], args.cond.stride[d]),
157                (args.a.shape[d], args.a.stride[d]),
158                (args.b.shape[d], args.b.stride[d]),
159            ];
160            for (op_dim, op_stride) in checks {
161                if op_dim != y_dim && !(op_dim == 1 && op_stride == 0) {
162                    return Err(Error::InvalidProblem(
163                        "baracuda-kernels::WherePlan: input axis not broadcast-compatible \
164                         with output (require shape[d] == y.shape[d], OR \
165                         shape[d] == 1 AND stride[d] == 0)",
166                    ));
167                }
168            }
169        }
170
171        if N > 8 {
172            return Err(Error::Unsupported(
173                "baracuda-kernels::WherePlan: tensor rank > 8 not supported",
174            ));
175        }
176
177        let y_numel = args.y.numel();
178        let cond_numel = args.cond.numel();
179        let a_numel = args.a.numel();
180        let b_numel = args.b.numel();
181        let cond_len = args.cond.data.len() as i64;
182        let a_len = args.a.data.len() as i64;
183        let b_len = args.b.data.len() as i64;
184        let y_len = args.y.data.len() as i64;
185        if y_len < y_numel {
186            return Err(Error::BufferTooSmall {
187                needed: y_numel as usize,
188                got: y_len as usize,
189            });
190        }
191        if cond_len < cond_numel {
192            return Err(Error::BufferTooSmall {
193                needed: cond_numel as usize,
194                got: cond_len as usize,
195            });
196        }
197        if a_len < a_numel {
198            return Err(Error::BufferTooSmall {
199                needed: a_numel as usize,
200                got: a_len as usize,
201            });
202        }
203        if b_len < b_numel {
204            return Err(Error::BufferTooSmall {
205                needed: b_numel as usize,
206                got: b_len as usize,
207            });
208        }
209        Ok(())
210    }
211
212    /// Workspace size in bytes. Always `0` for the trailblazer.
213    #[inline]
214    pub fn workspace_size(&self) -> usize {
215        0
216    }
217
218    /// Identity of the kernel this plan picked.
219    #[inline]
220    pub fn sku(&self) -> KernelSku {
221        self.sku
222    }
223
224    /// Numerical guarantees for this plan's kernel.
225    #[inline]
226    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
227        self.sku.precision_guarantee
228    }
229
230    /// Launch.
231    pub fn run(
232        &self,
233        stream: &Stream,
234        _workspace: Workspace<'_>,
235        args: WhereArgs<'_, T, N>,
236    ) -> Result<()> {
237        self.can_implement(&args)?;
238        let numel = args.y.numel();
239        if numel == 0 {
240            return Ok(());
241        }
242        let cond_ptr = args.cond.data.as_raw().0 as *const c_void;
243        let a_ptr = args.a.data.as_raw().0 as *const c_void;
244        let b_ptr = args.b.data.as_raw().0 as *const c_void;
245        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
246        let stream_ptr = stream.as_raw() as *mut c_void;
247
248        let all_contig_same_shape = args.cond.shape == args.y.shape
249            && args.a.shape == args.y.shape
250            && args.b.shape == args.y.shape
251            && args.cond.is_contiguous()
252            && args.a.is_contiguous()
253            && args.b.is_contiguous()
254            && args.y.is_contiguous();
255
256        if !all_contig_same_shape {
257            return self.run_strided(
258                stream_ptr, cond_ptr, a_ptr, b_ptr, y_ptr, numel, &args,
259            );
260        }
261
262        let status = match T::KIND {
263            ElementKind::F32 => unsafe {
264                baracuda_kernels_sys::baracuda_kernels_where_f32_run(
265                    numel,
266                    cond_ptr,
267                    a_ptr,
268                    b_ptr,
269                    y_ptr,
270                    core::ptr::null_mut(),
271                    0,
272                    stream_ptr,
273                )
274            },
275            ElementKind::F16 => unsafe {
276                baracuda_kernels_sys::baracuda_kernels_where_f16_run(
277                    numel,
278                    cond_ptr,
279                    a_ptr,
280                    b_ptr,
281                    y_ptr,
282                    core::ptr::null_mut(),
283                    0,
284                    stream_ptr,
285                )
286            },
287            ElementKind::Bf16 => unsafe {
288                baracuda_kernels_sys::baracuda_kernels_where_bf16_run(
289                    numel,
290                    cond_ptr,
291                    a_ptr,
292                    b_ptr,
293                    y_ptr,
294                    core::ptr::null_mut(),
295                    0,
296                    stream_ptr,
297                )
298            },
299            ElementKind::F64 => unsafe {
300                baracuda_kernels_sys::baracuda_kernels_where_f64_run(
301                    numel,
302                    cond_ptr,
303                    a_ptr,
304                    b_ptr,
305                    y_ptr,
306                    core::ptr::null_mut(),
307                    0,
308                    stream_ptr,
309                )
310            },
311            _ => {
312                return Err(Error::Unsupported(
313                    "baracuda-kernels::WherePlan::run reached an unimplemented dtype \
314                     — select() should have caught this",
315                ));
316            }
317        };
318        map_status(status)
319    }
320
321    /// Strided / broadcast kernel path.
322    fn run_strided(
323        &self,
324        stream_ptr: *mut c_void,
325        cond_ptr: *const c_void,
326        a_ptr: *const c_void,
327        b_ptr: *const c_void,
328        y_ptr: *mut c_void,
329        numel: i64,
330        args: &WhereArgs<'_, T, N>,
331    ) -> Result<()> {
332        let shape = args.y.shape;
333        let stride_cond = args.cond.stride;
334        let stride_a = args.a.stride;
335        let stride_b = args.b.stride;
336        let stride_y = args.y.stride;
337        let rank = N as i32;
338
339        let status = match T::KIND {
340            ElementKind::F32 => unsafe {
341                baracuda_kernels_sys::baracuda_kernels_where_f32_strided_run(
342                    numel,
343                    rank,
344                    shape.as_ptr(),
345                    stride_cond.as_ptr(),
346                    stride_a.as_ptr(),
347                    stride_b.as_ptr(),
348                    stride_y.as_ptr(),
349                    cond_ptr,
350                    a_ptr,
351                    b_ptr,
352                    y_ptr,
353                    core::ptr::null_mut(),
354                    0,
355                    stream_ptr,
356                )
357            },
358            ElementKind::F16 => unsafe {
359                baracuda_kernels_sys::baracuda_kernels_where_f16_strided_run(
360                    numel,
361                    rank,
362                    shape.as_ptr(),
363                    stride_cond.as_ptr(),
364                    stride_a.as_ptr(),
365                    stride_b.as_ptr(),
366                    stride_y.as_ptr(),
367                    cond_ptr,
368                    a_ptr,
369                    b_ptr,
370                    y_ptr,
371                    core::ptr::null_mut(),
372                    0,
373                    stream_ptr,
374                )
375            },
376            ElementKind::Bf16 => unsafe {
377                baracuda_kernels_sys::baracuda_kernels_where_bf16_strided_run(
378                    numel,
379                    rank,
380                    shape.as_ptr(),
381                    stride_cond.as_ptr(),
382                    stride_a.as_ptr(),
383                    stride_b.as_ptr(),
384                    stride_y.as_ptr(),
385                    cond_ptr,
386                    a_ptr,
387                    b_ptr,
388                    y_ptr,
389                    core::ptr::null_mut(),
390                    0,
391                    stream_ptr,
392                )
393            },
394            ElementKind::F64 => unsafe {
395                baracuda_kernels_sys::baracuda_kernels_where_f64_strided_run(
396                    numel,
397                    rank,
398                    shape.as_ptr(),
399                    stride_cond.as_ptr(),
400                    stride_a.as_ptr(),
401                    stride_b.as_ptr(),
402                    stride_y.as_ptr(),
403                    cond_ptr,
404                    a_ptr,
405                    b_ptr,
406                    y_ptr,
407                    core::ptr::null_mut(),
408                    0,
409                    stream_ptr,
410                )
411            },
412            _ => {
413                return Err(Error::Unsupported(
414                    "baracuda-kernels::WherePlan: strided path reached unimplemented dtype \
415                     — select() should have caught this",
416                ));
417            }
418        };
419        map_status(status)
420    }
421}
422
423fn map_status(code: i32) -> Result<()> {
424    match code {
425        0 => Ok(()),
426        1 => Err(Error::MisalignedOperand),
427        2 => Err(Error::InvalidProblem(
428            "baracuda-kernels-sys reported invalid problem",
429        )),
430        3 => Err(Error::Unsupported(
431            "baracuda-kernels-sys reported unsupported configuration",
432        )),
433        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
434        n => Err(Error::CutlassInternal(n)),
435    }
436}