Skip to main content

baracuda_kernels/elementwise/
prelu.rs

//! PReLU forward (FW) plan — Milestone 5.3.
//!
//! `y[..., c, ...] = x[..., c, ...]` when `x[..., c, ...]` is positive,
//! otherwise `weight[c] * x[..., c, ...]`.
//!
//! `weight` is either per-channel (`shape == [C]`, where `C` is the size of the
//! channel axis) or a single learnable scalar (`shape == [1]`).
//!
//! This is distinct from [`crate::UnaryParamPlan`] because the parameter is a
//! tensor operand rather than a scalar, so it needs its own plan type.

12use core::ffi::c_void;
13use core::marker::PhantomData;
14
15use baracuda_cutlass::{Error, Result};
16use baracuda_driver::Stream;
17use baracuda_kernels_types::{
18    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
19    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, UnaryKind, Workspace,
20};
21
22/// Descriptor for a PReLU FW op.
23///
24/// `channel_axis` is the axis index where `weight` is indexed by the channel
25/// coordinate. Use `-1` to signal a single scalar weight (`weight.shape ==
26/// [1]`) which is applied to every cell of `x`.
27#[derive(Copy, Clone, Debug)]
28pub struct PReluDescriptor<const N: usize> {
29    /// Input tensor shape.
30    pub input_shape: [i32; N],
31    /// Channel axis (where `weight` indexes); `-1` for scalar weight.
32    pub channel_axis: i8,
33    /// Element type.
34    pub element: ElementKind,
35}
36
37/// Args bundle for a PReLU FW launch.
38pub struct PReluArgs<'a, T: Element, const N: usize> {
39    /// Input tensor.
40    pub x: TensorRef<'a, T, N>,
41    /// Weight tensor — shape `[C]` (per-channel) or `[1]` (scalar).
42    pub weight: TensorRef<'a, T, 1>,
43    /// Output tensor.
44    pub y: TensorMut<'a, T, N>,
45}
46
47/// PReLU forward plan.
48pub struct PReluPlan<T: Element, const N: usize> {
49    desc: PReluDescriptor<N>,
50    sku: KernelSku,
51    channel_stride: i64,
52    channel_extent: i32,
53    scalar_weight: bool,
54    _marker: PhantomData<T>,
55}
56
57fn check_dtype<T: Element>() -> Result<()> {
58    let ok = matches!(
59        T::KIND,
60        ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
61    );
62    if !ok {
63        return Err(Error::Unsupported(
64            "baracuda-kernels::PReluPlan: only {f32, f16, bf16, f64} wired",
65        ));
66    }
67    Ok(())
68}
69
70impl<T: Element, const N: usize> PReluPlan<T, N> {
71    /// Pick a kernel.
72    pub fn select(
73        _stream: &Stream,
74        desc: &PReluDescriptor<N>,
75        _pref: PlanPreference,
76    ) -> Result<Self> {
77        if desc.element != T::KIND {
78            return Err(Error::Unsupported(
79                "baracuda-kernels::PReluPlan: descriptor element != T",
80            ));
81        }
82        check_dtype::<T>()?;
83        let rank = N as i8;
84        let scalar_weight = desc.channel_axis < 0;
85        if !scalar_weight && (desc.channel_axis >= rank) {
86            return Err(Error::InvalidProblem(
87                "baracuda-kernels::PReluPlan: channel_axis out of range",
88            ));
89        }
90        let (channel_stride, channel_extent) = if scalar_weight {
91            (1i64, 1i32)
92        } else {
93            let axis = desc.channel_axis as usize;
94            let extent = desc.input_shape[axis];
95            // channel_stride = product of shape dims AFTER channel axis (row-major).
96            let mut stride: i64 = 1;
97            for d in (axis + 1)..N {
98                stride = stride.saturating_mul(desc.input_shape[d] as i64);
99            }
100            (stride, extent)
101        };
102        let precision_guarantee = PrecisionGuarantee {
103            math_precision: MathPrecision::F32,
104            accumulator: if T::KIND == ElementKind::F64 {
105                ElementKind::F64
106            } else {
107                ElementKind::F32
108            },
109            bit_stable_on_same_hardware: true,
110            deterministic: true,
111        };
112        let sku = KernelSku {
113            category: OpCategory::UnaryElementwise,
114            op: UnaryKind::PReLU as u16,
115            element: T::KIND,
116            aux_element: None,
117            layout: None,
118            epilogue: None,
119            arch: ArchSku::Sm80,
120            backend: BackendKind::Bespoke,
121            precision_guarantee,
122        };
123        Ok(Self {
124            desc: *desc,
125            sku,
126            channel_stride,
127            channel_extent,
128            scalar_weight,
129            _marker: PhantomData,
130        })
131    }
132    /// Workspace size in bytes.
133    #[inline]
134    pub fn workspace_size(&self) -> usize {
135        0
136    }
137    /// Kernel SKU identity.
138    #[inline]
139    pub fn sku(&self) -> KernelSku {
140        self.sku
141    }
142    /// Numerical guarantees.
143    #[inline]
144    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
145        self.sku.precision_guarantee
146    }
147
148    /// Launch.
149    pub fn run(
150        &self,
151        stream: &Stream,
152        _workspace: Workspace<'_>,
153        args: PReluArgs<'_, T, N>,
154    ) -> Result<()> {
155        if args.x.shape != self.desc.input_shape || args.y.shape != self.desc.input_shape {
156            return Err(Error::InvalidProblem(
157                "baracuda-kernels::PReluPlan: x / y shape mismatch",
158            ));
159        }
160        let expected_weight = if self.scalar_weight { 1 } else { self.channel_extent };
161        if args.weight.shape[0] != expected_weight {
162            return Err(Error::InvalidProblem(
163                "baracuda-kernels::PReluPlan: weight shape mismatch",
164            ));
165        }
166        let numel = args.x.numel();
167        if numel == 0 {
168            return Ok(());
169        }
170        let stream_ptr = stream.as_raw() as *mut c_void;
171        let x_ptr = args.x.data.as_raw().0 as *const c_void;
172        let weight_ptr = args.weight.data.as_raw().0 as *const c_void;
173        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
174        let scalar_flag: i32 = if self.scalar_weight { 1 } else { 0 };
175        let status = match T::KIND {
176            ElementKind::F32 => unsafe {
177                baracuda_kernels_sys::baracuda_kernels_prelu_f32_run(
178                    numel,
179                    self.channel_stride,
180                    self.channel_extent,
181                    scalar_flag,
182                    x_ptr,
183                    weight_ptr,
184                    y_ptr,
185                    core::ptr::null_mut(),
186                    0,
187                    stream_ptr,
188                )
189            },
190            ElementKind::F16 => unsafe {
191                baracuda_kernels_sys::baracuda_kernels_prelu_f16_run(
192                    numel,
193                    self.channel_stride,
194                    self.channel_extent,
195                    scalar_flag,
196                    x_ptr,
197                    weight_ptr,
198                    y_ptr,
199                    core::ptr::null_mut(),
200                    0,
201                    stream_ptr,
202                )
203            },
204            ElementKind::Bf16 => unsafe {
205                baracuda_kernels_sys::baracuda_kernels_prelu_bf16_run(
206                    numel,
207                    self.channel_stride,
208                    self.channel_extent,
209                    scalar_flag,
210                    x_ptr,
211                    weight_ptr,
212                    y_ptr,
213                    core::ptr::null_mut(),
214                    0,
215                    stream_ptr,
216                )
217            },
218            ElementKind::F64 => unsafe {
219                baracuda_kernels_sys::baracuda_kernels_prelu_f64_run(
220                    numel,
221                    self.channel_stride,
222                    self.channel_extent,
223                    scalar_flag,
224                    x_ptr,
225                    weight_ptr,
226                    y_ptr,
227                    core::ptr::null_mut(),
228                    0,
229                    stream_ptr,
230                )
231            },
232            _ => {
233                return Err(Error::Unsupported(
234                    "baracuda-kernels::PReluPlan::run unwired dtype",
235                ));
236            }
237        };
238        match status {
239            0 => Ok(()),
240            1 => Err(Error::MisalignedOperand),
241            2 => Err(Error::InvalidProblem(
242                "baracuda-kernels-sys reported invalid problem",
243            )),
244            3 => Err(Error::Unsupported(
245                "baracuda-kernels-sys reported unsupported configuration",
246            )),
247            4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
248            n => Err(Error::CutlassInternal(n)),
249        }
250    }
251}