Skip to main content

baracuda_kernels/elementwise/
unary_param.rs

1//! Parameterized unary elementwise plan.
2//!
3//! Sibling of [`crate::UnaryPlan`] for ops that carry one or more scalar
4//! parameters alongside the tensor input. The plan family fixes the
5//! parameter slot count at two `f32` values
6//! (`params: [f32; 2]`) so a single descriptor / kernel ABI shape covers
7//! every op in this family — ops that need fewer params (e.g.
8//! `LeakyRelu(α)` would only consume `p0`) simply ignore the unused
9//! slot.
10//!
11//! Today wired:
12//!   * `Threshold` (`y = (x > t) ? x : v`; `t = params[0]`,
13//!     `v = params[1]`) across `{f32, f16, bf16, f64}` — FW + BW
14//!     (BW lives in [`crate::UnaryParamBackwardPlan`]).
15//!   * `PowI` (`y = x^n` integer exponent; `n = params[0] as i32`,
16//!     `params[1]` unused) across `{f32, f16, bf16, f64}` — Phase 12.1.
17//!     Power-by-squaring; well-defined for negative `x` (no NaN) and
18//!     bit-exact for `n = 2` (which collapses to `Square`).
19//!
20//! The existing single-param activation ops (`LeakyRelu(α)`, `ELU(α)`,
21//! `Hardshrink(λ)`, `Softshrink(λ)`) ship through the plain
22//! [`crate::UnaryPlan`] today with hardcoded PyTorch defaults; they
23//! could later re-emit through this parameterized plan to expose the
24//! coefficient as a runtime argument, but doing so isn't required —
25//! the two paths can coexist.
26//!
27//! Layout constraints: `x.shape == y.shape == desc.shape`. The kernel
28//! does not broadcast.
29//!
30//! `Threshold` is contig-only today; `PowI` got a strided sibling in
31//! Phase 14.2 — the run dispatcher checks `is_contiguous()` of both
32//! operands and routes to `*_strided_run` when either is a non-canonical
33//! view (transposed, sliced, etc.). Future params-bearing ops that need
34//! strided support emit through the same
35//! `BARACUDA_KERNELS_UNARY_PARAM_INSTANTIATE_STRIDED` macro.
36
37use core::ffi::c_void;
38use core::marker::PhantomData;
39
40use baracuda_cutlass::{Error, Result};
41use baracuda_driver::Stream;
42use baracuda_kernels_types::{
43    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
44    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, UnaryKind, Workspace,
45};
46
47/// Descriptor for a parameterized unary elementwise op.
48///
49/// `shape` is both the input and the output shape. `element` must match
50/// `T::KIND` at `select` time. `params` carries op-specific scalars —
51/// the layout is fixed by the op's `kind`:
52///
53/// | Op          | `params[0]` | `params[1]` |
54/// |-------------|-------------|-------------|
55/// | `Threshold` | `t`         | `v`         |
56/// | `PowI`      | `n as f32`  | unused      |
57///
58/// We chose `[f32; 2]` rather than separate `t: f32, v: f32` fields so
59/// the descriptor shape doesn't shift as more 1- or 2-param ops join
60/// (LeakyRelu / ELU / Hardshrink / Softshrink could all re-emit here
61/// with `params[0] = α` or `params[0] = λ` and `params[1]` ignored).
62#[derive(Copy, Clone, Debug)]
63pub struct UnaryParamDescriptor<const N: usize> {
64    /// Which parameterized unary op to apply.
65    pub kind: UnaryKind,
66    /// Tensor shape — input and output share it.
67    pub shape: [i32; N],
68    /// Primary element type. Must match the type parameter `T` of the
69    /// containing plan.
70    pub element: ElementKind,
71    /// Op-specific scalar parameters. Slot semantics depend on `kind`:
72    /// for `Threshold` it's `[t, v]`. Parameters are always `f32` on the
73    /// FFI; integer / `f64` kernels widen the param losslessly at the
74    /// kernel boundary, and half-precision kernels compare in `f32`.
75    pub params: [f32; 2],
76}
77
78/// Args bundle for a parameterized unary elementwise launch.
79pub struct UnaryParamArgs<'a, T: Element, const N: usize> {
80    /// Input.
81    pub x: TensorRef<'a, T, N>,
82    /// Output.
83    pub y: TensorMut<'a, T, N>,
84}
85
86/// Parameterized unary elementwise plan.
87pub struct UnaryParamPlan<T: Element, const N: usize> {
88    desc: UnaryParamDescriptor<N>,
89    sku: KernelSku,
90    _marker: PhantomData<T>,
91}
92
93impl<T: Element, const N: usize> UnaryParamPlan<T, N> {
94    /// Pick a kernel for `desc`.
95    pub fn select(
96        _stream: &Stream,
97        desc: &UnaryParamDescriptor<N>,
98        _pref: PlanPreference,
99    ) -> Result<Self> {
100        if desc.element != T::KIND {
101            return Err(Error::Unsupported(
102                "baracuda-kernels::UnaryParamPlan: descriptor element != type parameter T",
103            ));
104        }
105        for &d in desc.shape.iter() {
106            if d < 0 {
107                return Err(Error::InvalidProblem(
108                    "baracuda-kernels::UnaryParamPlan: shape dims must be non-negative",
109                ));
110            }
111        }
112
113        // Today's wired matrix: {Threshold, PowI} × {f32, f16, bf16, f64}.
114        // Future params-bearing ops can extend `kind_in_scope` and add
115        // match arms in `run`.
116        let kind_in_scope = matches!(desc.kind, UnaryKind::Threshold | UnaryKind::PowI);
117        let dtype_in_scope = matches!(
118            T::KIND,
119            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
120        );
121        if !(kind_in_scope && dtype_in_scope) {
122            return Err(Error::Unsupported(
123                "baracuda-kernels::UnaryParamPlan: today only `{Threshold, PowI} × \
124                 {f32, f16, bf16, f64}` is wired; LeakyRelu / ELU / Hardshrink / Softshrink \
125                 ship via UnaryPlan with hardcoded PyTorch defaults today; PReLU needs a \
126                 distinct (channel-vector) plan.",
127            ));
128        }
129
130        let precision_guarantee = PrecisionGuarantee {
131            math_precision: MathPrecision::F32,
132            accumulator: ElementKind::F32,
133            bit_stable_on_same_hardware: true,
134            deterministic: true,
135        };
136        let sku = KernelSku {
137            category: OpCategory::UnaryElementwise,
138            op: desc.kind as u16,
139            element: T::KIND,
140            aux_element: None,
141            layout: None,
142            epilogue: None,
143            arch: ArchSku::Sm80,
144            backend: BackendKind::Bespoke,
145            precision_guarantee,
146        };
147        Ok(Self {
148            desc: *desc,
149            sku,
150            _marker: PhantomData,
151        })
152    }
153
154    /// Validate args.
155    pub fn can_implement(&self, args: &UnaryParamArgs<'_, T, N>) -> Result<()> {
156        if args.x.shape != self.desc.shape {
157            return Err(Error::InvalidProblem(
158                "baracuda-kernels::UnaryParamPlan: X shape mismatch with descriptor",
159            ));
160        }
161        if args.y.shape != self.desc.shape {
162            return Err(Error::InvalidProblem(
163                "baracuda-kernels::UnaryParamPlan: Y shape mismatch with descriptor",
164            ));
165        }
166        // PowI got a strided sibling in Phase 14.2; Threshold remains
167        // contig-only until its strided launcher is wired.
168        let all_contig = args.x.is_contiguous() && args.y.is_contiguous();
169        if !all_contig && !matches!(self.desc.kind, UnaryKind::PowI) {
170            return Err(Error::Unsupported(
171                "baracuda-kernels::UnaryParamPlan: this op is contig-only today; strided \
172                 fanout lands later (PowI is the trailblazer in Phase 14.2)",
173            ));
174        }
175        let numel = args.y.numel();
176        let x_len = args.x.data.len() as i64;
177        let y_len = args.y.data.len() as i64;
178        if x_len < numel || y_len < numel {
179            return Err(Error::BufferTooSmall {
180                needed: numel as usize,
181                got: x_len.min(y_len) as usize,
182            });
183        }
184        Ok(())
185    }
186
187    /// Workspace size in bytes.
188    #[inline]
189    pub fn workspace_size(&self) -> usize {
190        0
191    }
192
193    /// Identity of the kernel this plan picked.
194    #[inline]
195    pub fn sku(&self) -> KernelSku {
196        self.sku
197    }
198
199    /// Numerical guarantees.
200    #[inline]
201    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
202        self.sku.precision_guarantee
203    }
204
205    /// Launch.
206    pub fn run(
207        &self,
208        stream: &Stream,
209        _workspace: Workspace<'_>,
210        args: UnaryParamArgs<'_, T, N>,
211    ) -> Result<()> {
212        self.can_implement(&args)?;
213        let numel = args.y.numel();
214        if numel == 0 {
215            return Ok(());
216        }
217        let x_ptr = args.x.data.as_raw().0 as *const c_void;
218        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
219        let stream_ptr = stream.as_raw() as *mut c_void;
220        let p0 = self.desc.params[0];
221        let p1 = self.desc.params[1];
222
223        // Strided fast-fall: route to `*_strided_run` when either operand
224        // has a non-canonical layout. Today only PowI has the strided
225        // sibling wired (Phase 14.2); `can_implement` would have already
226        // rejected non-contig args for other kinds.
227        let all_contig = args.x.is_contiguous() && args.y.is_contiguous();
228        if !all_contig && matches!(self.desc.kind, UnaryKind::PowI) {
229            return self.run_strided(stream_ptr, x_ptr, y_ptr, numel, &args, p0, p1);
230        }
231
232        let status = match (self.desc.kind, T::KIND) {
233            (UnaryKind::Threshold, ElementKind::F32) => unsafe {
234                baracuda_kernels_sys::baracuda_kernels_unary_threshold_f32_run(
235                    numel, x_ptr, y_ptr, p0, p1,
236                    core::ptr::null_mut(), 0, stream_ptr,
237                )
238            },
239            (UnaryKind::Threshold, ElementKind::F16) => unsafe {
240                baracuda_kernels_sys::baracuda_kernels_unary_threshold_f16_run(
241                    numel, x_ptr, y_ptr, p0, p1,
242                    core::ptr::null_mut(), 0, stream_ptr,
243                )
244            },
245            (UnaryKind::Threshold, ElementKind::Bf16) => unsafe {
246                baracuda_kernels_sys::baracuda_kernels_unary_threshold_bf16_run(
247                    numel, x_ptr, y_ptr, p0, p1,
248                    core::ptr::null_mut(), 0, stream_ptr,
249                )
250            },
251            (UnaryKind::Threshold, ElementKind::F64) => unsafe {
252                baracuda_kernels_sys::baracuda_kernels_unary_threshold_f64_run(
253                    numel, x_ptr, y_ptr, p0, p1,
254                    core::ptr::null_mut(), 0, stream_ptr,
255                )
256            },
257            (UnaryKind::PowI, ElementKind::F32) => unsafe {
258                baracuda_kernels_sys::baracuda_kernels_unary_powi_f32_run(
259                    numel, x_ptr, y_ptr, p0, p1,
260                    core::ptr::null_mut(), 0, stream_ptr,
261                )
262            },
263            (UnaryKind::PowI, ElementKind::F16) => unsafe {
264                baracuda_kernels_sys::baracuda_kernels_unary_powi_f16_run(
265                    numel, x_ptr, y_ptr, p0, p1,
266                    core::ptr::null_mut(), 0, stream_ptr,
267                )
268            },
269            (UnaryKind::PowI, ElementKind::Bf16) => unsafe {
270                baracuda_kernels_sys::baracuda_kernels_unary_powi_bf16_run(
271                    numel, x_ptr, y_ptr, p0, p1,
272                    core::ptr::null_mut(), 0, stream_ptr,
273                )
274            },
275            (UnaryKind::PowI, ElementKind::F64) => unsafe {
276                baracuda_kernels_sys::baracuda_kernels_unary_powi_f64_run(
277                    numel, x_ptr, y_ptr, p0, p1,
278                    core::ptr::null_mut(), 0, stream_ptr,
279                )
280            },
281            _ => {
282                return Err(Error::Unsupported(
283                    "baracuda-kernels::UnaryParamPlan: dispatcher reached an unimplemented \
284                     (kind, dtype) pair — select() should have caught this",
285                ));
286            }
287        };
288        map_status(status)
289    }
290}
291
292impl<T: Element, const N: usize> UnaryParamPlan<T, N> {
293    /// Strided dispatcher — called by [`Self::run`] when at least one
294    /// operand isn't contiguous. Today only `PowI` reaches this path;
295    /// other kinds are rejected at `can_implement` time.
296    fn run_strided(
297        &self,
298        stream_ptr: *mut c_void,
299        x_ptr: *const c_void,
300        y_ptr: *mut c_void,
301        numel: i64,
302        args: &UnaryParamArgs<'_, T, N>,
303        p0: f32,
304        p1: f32,
305    ) -> Result<()> {
306        let shape = args.y.shape;
307        let stride_x = args.x.stride;
308        let stride_y = args.y.stride;
309        let rank = N as i32;
310
311        let status = match (self.desc.kind, T::KIND) {
312            (UnaryKind::PowI, ElementKind::F32) => unsafe {
313                baracuda_kernels_sys::baracuda_kernels_unary_powi_f32_strided_run(
314                    numel, rank, shape.as_ptr(),
315                    stride_x.as_ptr(), stride_y.as_ptr(),
316                    x_ptr, y_ptr, p0, p1,
317                    core::ptr::null_mut(), 0, stream_ptr,
318                )
319            },
320            (UnaryKind::PowI, ElementKind::F16) => unsafe {
321                baracuda_kernels_sys::baracuda_kernels_unary_powi_f16_strided_run(
322                    numel, rank, shape.as_ptr(),
323                    stride_x.as_ptr(), stride_y.as_ptr(),
324                    x_ptr, y_ptr, p0, p1,
325                    core::ptr::null_mut(), 0, stream_ptr,
326                )
327            },
328            (UnaryKind::PowI, ElementKind::Bf16) => unsafe {
329                baracuda_kernels_sys::baracuda_kernels_unary_powi_bf16_strided_run(
330                    numel, rank, shape.as_ptr(),
331                    stride_x.as_ptr(), stride_y.as_ptr(),
332                    x_ptr, y_ptr, p0, p1,
333                    core::ptr::null_mut(), 0, stream_ptr,
334                )
335            },
336            (UnaryKind::PowI, ElementKind::F64) => unsafe {
337                baracuda_kernels_sys::baracuda_kernels_unary_powi_f64_strided_run(
338                    numel, rank, shape.as_ptr(),
339                    stride_x.as_ptr(), stride_y.as_ptr(),
340                    x_ptr, y_ptr, p0, p1,
341                    core::ptr::null_mut(), 0, stream_ptr,
342                )
343            },
344            _ => {
345                return Err(Error::Unsupported(
346                    "baracuda-kernels::UnaryParamPlan::run_strided: only PowI is wired \
347                     for the strided path today (Phase 14.2 trailblazer)",
348                ));
349            }
350        };
351        map_status(status)
352    }
353}
354
355fn map_status(code: i32) -> Result<()> {
356    match code {
357        0 => Ok(()),
358        1 => Err(Error::MisalignedOperand),
359        2 => Err(Error::InvalidProblem(
360            "baracuda-kernels-sys reported invalid problem",
361        )),
362        3 => Err(Error::Unsupported(
363            "baracuda-kernels-sys reported unsupported configuration",
364        )),
365        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
366        n => Err(Error::CutlassInternal(n)),
367    }
368}