baracuda_kernels/elementwise/
prelu.rs1use core::ffi::c_void;
13use core::marker::PhantomData;
14
15use baracuda_cutlass::{Error, Result};
16use baracuda_driver::Stream;
17use baracuda_kernels_types::{
18 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
19 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, UnaryKind, Workspace,
20};
21
/// Problem description used to select a [`PReluPlan`].
///
/// `PReluPlan::select` stores a copy of this descriptor and later checks
/// `run` operands against it.
#[derive(Copy, Clone, Debug)]
pub struct PReluDescriptor<const N: usize> {
    /// Extent of each of the `N` axes; both `x` and `y` must have exactly this shape.
    pub input_shape: [i32; N],
    /// Axis the per-channel slope vector is indexed along (must be `< N`).
    /// Any negative value selects a single scalar slope for all elements.
    pub channel_axis: i8,
    /// Element type; must equal the plan's `T::KIND`.
    pub element: ElementKind,
}
36
/// Device operands for one `PReluPlan::run` launch.
pub struct PReluArgs<'a, T: Element, const N: usize> {
    /// Input tensor; its shape must equal the plan descriptor's `input_shape`.
    pub x: TensorRef<'a, T, N>,
    /// Slope vector: length 1 for a scalar slope (negative `channel_axis`),
    /// otherwise the extent of `channel_axis`.
    pub weight: TensorRef<'a, T, 1>,
    /// Output tensor; its shape must equal the plan descriptor's `input_shape`.
    pub y: TensorMut<'a, T, N>,
}
46
/// A validated PReLU launch configuration for element type `T` and rank `N`.
///
/// Built by `PReluPlan::select`; `run` borrows the plan immutably, so one plan
/// can be launched repeatedly. The plan itself stores no device buffers.
pub struct PReluPlan<T: Element, const N: usize> {
    /// Copy of the descriptor the plan was selected for.
    desc: PReluDescriptor<N>,
    /// Kernel identity and precision metadata returned by `sku()`.
    sku: KernelSku,
    /// Element distance between consecutive channels: the product of the
    /// extents after `channel_axis`. 1 when a scalar slope is used.
    channel_stride: i64,
    /// Number of channels (extent of `channel_axis`). 1 when a scalar slope is used.
    channel_extent: i32,
    /// True when `channel_axis` was negative, i.e. one slope for every element.
    scalar_weight: bool,
    /// Ties the plan to `T` without storing a value of it.
    _marker: PhantomData<T>,
}
56
57fn check_dtype<T: Element>() -> Result<()> {
58 let ok = matches!(
59 T::KIND,
60 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
61 );
62 if !ok {
63 return Err(Error::Unsupported(
64 "baracuda-kernels::PReluPlan: only {f32, f16, bf16, f64} wired",
65 ));
66 }
67 Ok(())
68}
69
70impl<T: Element, const N: usize> PReluPlan<T, N> {
71 pub fn select(
73 _stream: &Stream,
74 desc: &PReluDescriptor<N>,
75 _pref: PlanPreference,
76 ) -> Result<Self> {
77 if desc.element != T::KIND {
78 return Err(Error::Unsupported(
79 "baracuda-kernels::PReluPlan: descriptor element != T",
80 ));
81 }
82 check_dtype::<T>()?;
83 let rank = N as i8;
84 let scalar_weight = desc.channel_axis < 0;
85 if !scalar_weight && (desc.channel_axis >= rank) {
86 return Err(Error::InvalidProblem(
87 "baracuda-kernels::PReluPlan: channel_axis out of range",
88 ));
89 }
90 let (channel_stride, channel_extent) = if scalar_weight {
91 (1i64, 1i32)
92 } else {
93 let axis = desc.channel_axis as usize;
94 let extent = desc.input_shape[axis];
95 let mut stride: i64 = 1;
97 for d in (axis + 1)..N {
98 stride = stride.saturating_mul(desc.input_shape[d] as i64);
99 }
100 (stride, extent)
101 };
102 let precision_guarantee = PrecisionGuarantee {
103 math_precision: MathPrecision::F32,
104 accumulator: if T::KIND == ElementKind::F64 {
105 ElementKind::F64
106 } else {
107 ElementKind::F32
108 },
109 bit_stable_on_same_hardware: true,
110 deterministic: true,
111 };
112 let sku = KernelSku {
113 category: OpCategory::UnaryElementwise,
114 op: UnaryKind::PReLU as u16,
115 element: T::KIND,
116 aux_element: None,
117 layout: None,
118 epilogue: None,
119 arch: ArchSku::Sm80,
120 backend: BackendKind::Bespoke,
121 precision_guarantee,
122 };
123 Ok(Self {
124 desc: *desc,
125 sku,
126 channel_stride,
127 channel_extent,
128 scalar_weight,
129 _marker: PhantomData,
130 })
131 }
132 #[inline]
134 pub fn workspace_size(&self) -> usize {
135 0
136 }
137 #[inline]
139 pub fn sku(&self) -> KernelSku {
140 self.sku
141 }
142 #[inline]
144 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
145 self.sku.precision_guarantee
146 }
147
148 pub fn run(
150 &self,
151 stream: &Stream,
152 _workspace: Workspace<'_>,
153 args: PReluArgs<'_, T, N>,
154 ) -> Result<()> {
155 if args.x.shape != self.desc.input_shape || args.y.shape != self.desc.input_shape {
156 return Err(Error::InvalidProblem(
157 "baracuda-kernels::PReluPlan: x / y shape mismatch",
158 ));
159 }
160 let expected_weight = if self.scalar_weight { 1 } else { self.channel_extent };
161 if args.weight.shape[0] != expected_weight {
162 return Err(Error::InvalidProblem(
163 "baracuda-kernels::PReluPlan: weight shape mismatch",
164 ));
165 }
166 let numel = args.x.numel();
167 if numel == 0 {
168 return Ok(());
169 }
170 let stream_ptr = stream.as_raw() as *mut c_void;
171 let x_ptr = args.x.data.as_raw().0 as *const c_void;
172 let weight_ptr = args.weight.data.as_raw().0 as *const c_void;
173 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
174 let scalar_flag: i32 = if self.scalar_weight { 1 } else { 0 };
175 let status = match T::KIND {
176 ElementKind::F32 => unsafe {
177 baracuda_kernels_sys::baracuda_kernels_prelu_f32_run(
178 numel,
179 self.channel_stride,
180 self.channel_extent,
181 scalar_flag,
182 x_ptr,
183 weight_ptr,
184 y_ptr,
185 core::ptr::null_mut(),
186 0,
187 stream_ptr,
188 )
189 },
190 ElementKind::F16 => unsafe {
191 baracuda_kernels_sys::baracuda_kernels_prelu_f16_run(
192 numel,
193 self.channel_stride,
194 self.channel_extent,
195 scalar_flag,
196 x_ptr,
197 weight_ptr,
198 y_ptr,
199 core::ptr::null_mut(),
200 0,
201 stream_ptr,
202 )
203 },
204 ElementKind::Bf16 => unsafe {
205 baracuda_kernels_sys::baracuda_kernels_prelu_bf16_run(
206 numel,
207 self.channel_stride,
208 self.channel_extent,
209 scalar_flag,
210 x_ptr,
211 weight_ptr,
212 y_ptr,
213 core::ptr::null_mut(),
214 0,
215 stream_ptr,
216 )
217 },
218 ElementKind::F64 => unsafe {
219 baracuda_kernels_sys::baracuda_kernels_prelu_f64_run(
220 numel,
221 self.channel_stride,
222 self.channel_extent,
223 scalar_flag,
224 x_ptr,
225 weight_ptr,
226 y_ptr,
227 core::ptr::null_mut(),
228 0,
229 stream_ptr,
230 )
231 },
232 _ => {
233 return Err(Error::Unsupported(
234 "baracuda-kernels::PReluPlan::run unwired dtype",
235 ));
236 }
237 };
238 match status {
239 0 => Ok(()),
240 1 => Err(Error::MisalignedOperand),
241 2 => Err(Error::InvalidProblem(
242 "baracuda-kernels-sys reported invalid problem",
243 )),
244 3 => Err(Error::Unsupported(
245 "baracuda-kernels-sys reported unsupported configuration",
246 )),
247 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
248 n => Err(Error::CutlassInternal(n)),
249 }
250 }
251}