1use core::ffi::c_void;
38use core::marker::PhantomData;
39
40use baracuda_cutlass::{Error, Result};
41use baracuda_driver::Stream;
42use baracuda_kernels_types::{
43 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
44 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, UnaryKind, Workspace,
45};
46
/// Describes one parameterised unary elementwise problem for [`UnaryParamPlan`].
///
/// `N` is the tensor rank, i.e. the number of entries in `shape`.
#[derive(Copy, Clone, Debug)]
pub struct UnaryParamDescriptor<const N: usize> {
    /// Which unary operation to run. `UnaryParamPlan::select` currently accepts only
    /// `UnaryKind::Threshold` and `UnaryKind::PowI`.
    pub kind: UnaryKind,
    /// Extent of each dimension. `UnaryParamPlan::select` rejects negative entries and
    /// `UnaryParamPlan::can_implement` requires both tensors to have exactly this shape.
    pub shape: [i32; N],
    /// Element type of the tensors; `UnaryParamPlan::select` requires it to equal `T::KIND`.
    pub element: ElementKind,
    /// Two scalar parameters forwarded unchanged to the kernel launcher as its two
    /// `f32` arguments.
    // NOTE(review): presumably `[threshold, value]` for Threshold and `[exponent, unused]`
    // for PowI — confirm against the kernel sources.
    pub params: [f32; 2],
}
77
/// Tensor operands for one launch of a [`UnaryParamPlan`].
///
/// `UnaryParamPlan::can_implement` requires both tensors to have the shape stored in
/// the plan's descriptor.
pub struct UnaryParamArgs<'a, T: Element, const N: usize> {
    /// Input tensor; its device address is handed to the kernel as a `*const` pointer.
    pub x: TensorRef<'a, T, N>,
    /// Output tensor; its device address is handed to the kernel as a `*mut` pointer.
    pub y: TensorMut<'a, T, N>,
}
85
/// A validated plan for running one parameterised unary op on tensors of element
/// type `T` and rank `N`.
///
/// Instances are created by `UnaryParamPlan::select`, which checks the descriptor
/// before storing it.
pub struct UnaryParamPlan<T: Element, const N: usize> {
    /// Copy of the descriptor the plan was selected for.
    desc: UnaryParamDescriptor<N>,
    /// Kernel identity and precision guarantee recorded at selection time.
    sku: KernelSku,
    /// Ties the plan to the element type `T` without storing a value of it.
    _marker: PhantomData<T>,
}
92
93impl<T: Element, const N: usize> UnaryParamPlan<T, N> {
94 pub fn select(
96 _stream: &Stream,
97 desc: &UnaryParamDescriptor<N>,
98 _pref: PlanPreference,
99 ) -> Result<Self> {
100 if desc.element != T::KIND {
101 return Err(Error::Unsupported(
102 "baracuda-kernels::UnaryParamPlan: descriptor element != type parameter T",
103 ));
104 }
105 for &d in desc.shape.iter() {
106 if d < 0 {
107 return Err(Error::InvalidProblem(
108 "baracuda-kernels::UnaryParamPlan: shape dims must be non-negative",
109 ));
110 }
111 }
112
113 let kind_in_scope = matches!(desc.kind, UnaryKind::Threshold | UnaryKind::PowI);
117 let dtype_in_scope = matches!(
118 T::KIND,
119 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
120 );
121 if !(kind_in_scope && dtype_in_scope) {
122 return Err(Error::Unsupported(
123 "baracuda-kernels::UnaryParamPlan: today only `{Threshold, PowI} × \
124 {f32, f16, bf16, f64}` is wired; LeakyRelu / ELU / Hardshrink / Softshrink \
125 ship via UnaryPlan with hardcoded PyTorch defaults today; PReLU needs a \
126 distinct (channel-vector) plan.",
127 ));
128 }
129
130 let precision_guarantee = PrecisionGuarantee {
131 math_precision: MathPrecision::F32,
132 accumulator: ElementKind::F32,
133 bit_stable_on_same_hardware: true,
134 deterministic: true,
135 };
136 let sku = KernelSku {
137 category: OpCategory::UnaryElementwise,
138 op: desc.kind as u16,
139 element: T::KIND,
140 aux_element: None,
141 layout: None,
142 epilogue: None,
143 arch: ArchSku::Sm80,
144 backend: BackendKind::Bespoke,
145 precision_guarantee,
146 };
147 Ok(Self {
148 desc: *desc,
149 sku,
150 _marker: PhantomData,
151 })
152 }
153
154 pub fn can_implement(&self, args: &UnaryParamArgs<'_, T, N>) -> Result<()> {
156 if args.x.shape != self.desc.shape {
157 return Err(Error::InvalidProblem(
158 "baracuda-kernels::UnaryParamPlan: X shape mismatch with descriptor",
159 ));
160 }
161 if args.y.shape != self.desc.shape {
162 return Err(Error::InvalidProblem(
163 "baracuda-kernels::UnaryParamPlan: Y shape mismatch with descriptor",
164 ));
165 }
166 let all_contig = args.x.is_contiguous() && args.y.is_contiguous();
169 if !all_contig && !matches!(self.desc.kind, UnaryKind::PowI) {
170 return Err(Error::Unsupported(
171 "baracuda-kernels::UnaryParamPlan: this op is contig-only today; strided \
172 fanout lands later (PowI is the trailblazer in Phase 14.2)",
173 ));
174 }
175 let numel = args.y.numel();
176 let x_len = args.x.data.len() as i64;
177 let y_len = args.y.data.len() as i64;
178 if x_len < numel || y_len < numel {
179 return Err(Error::BufferTooSmall {
180 needed: numel as usize,
181 got: x_len.min(y_len) as usize,
182 });
183 }
184 Ok(())
185 }
186
187 #[inline]
189 pub fn workspace_size(&self) -> usize {
190 0
191 }
192
193 #[inline]
195 pub fn sku(&self) -> KernelSku {
196 self.sku
197 }
198
199 #[inline]
201 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
202 self.sku.precision_guarantee
203 }
204
205 pub fn run(
207 &self,
208 stream: &Stream,
209 _workspace: Workspace<'_>,
210 args: UnaryParamArgs<'_, T, N>,
211 ) -> Result<()> {
212 self.can_implement(&args)?;
213 let numel = args.y.numel();
214 if numel == 0 {
215 return Ok(());
216 }
217 let x_ptr = args.x.data.as_raw().0 as *const c_void;
218 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
219 let stream_ptr = stream.as_raw() as *mut c_void;
220 let p0 = self.desc.params[0];
221 let p1 = self.desc.params[1];
222
223 let all_contig = args.x.is_contiguous() && args.y.is_contiguous();
228 if !all_contig && matches!(self.desc.kind, UnaryKind::PowI) {
229 return self.run_strided(stream_ptr, x_ptr, y_ptr, numel, &args, p0, p1);
230 }
231
232 let status = match (self.desc.kind, T::KIND) {
233 (UnaryKind::Threshold, ElementKind::F32) => unsafe {
234 baracuda_kernels_sys::baracuda_kernels_unary_threshold_f32_run(
235 numel, x_ptr, y_ptr, p0, p1,
236 core::ptr::null_mut(), 0, stream_ptr,
237 )
238 },
239 (UnaryKind::Threshold, ElementKind::F16) => unsafe {
240 baracuda_kernels_sys::baracuda_kernels_unary_threshold_f16_run(
241 numel, x_ptr, y_ptr, p0, p1,
242 core::ptr::null_mut(), 0, stream_ptr,
243 )
244 },
245 (UnaryKind::Threshold, ElementKind::Bf16) => unsafe {
246 baracuda_kernels_sys::baracuda_kernels_unary_threshold_bf16_run(
247 numel, x_ptr, y_ptr, p0, p1,
248 core::ptr::null_mut(), 0, stream_ptr,
249 )
250 },
251 (UnaryKind::Threshold, ElementKind::F64) => unsafe {
252 baracuda_kernels_sys::baracuda_kernels_unary_threshold_f64_run(
253 numel, x_ptr, y_ptr, p0, p1,
254 core::ptr::null_mut(), 0, stream_ptr,
255 )
256 },
257 (UnaryKind::PowI, ElementKind::F32) => unsafe {
258 baracuda_kernels_sys::baracuda_kernels_unary_powi_f32_run(
259 numel, x_ptr, y_ptr, p0, p1,
260 core::ptr::null_mut(), 0, stream_ptr,
261 )
262 },
263 (UnaryKind::PowI, ElementKind::F16) => unsafe {
264 baracuda_kernels_sys::baracuda_kernels_unary_powi_f16_run(
265 numel, x_ptr, y_ptr, p0, p1,
266 core::ptr::null_mut(), 0, stream_ptr,
267 )
268 },
269 (UnaryKind::PowI, ElementKind::Bf16) => unsafe {
270 baracuda_kernels_sys::baracuda_kernels_unary_powi_bf16_run(
271 numel, x_ptr, y_ptr, p0, p1,
272 core::ptr::null_mut(), 0, stream_ptr,
273 )
274 },
275 (UnaryKind::PowI, ElementKind::F64) => unsafe {
276 baracuda_kernels_sys::baracuda_kernels_unary_powi_f64_run(
277 numel, x_ptr, y_ptr, p0, p1,
278 core::ptr::null_mut(), 0, stream_ptr,
279 )
280 },
281 _ => {
282 return Err(Error::Unsupported(
283 "baracuda-kernels::UnaryParamPlan: dispatcher reached an unimplemented \
284 (kind, dtype) pair — select() should have caught this",
285 ));
286 }
287 };
288 map_status(status)
289 }
290}
291
292impl<T: Element, const N: usize> UnaryParamPlan<T, N> {
293 fn run_strided(
297 &self,
298 stream_ptr: *mut c_void,
299 x_ptr: *const c_void,
300 y_ptr: *mut c_void,
301 numel: i64,
302 args: &UnaryParamArgs<'_, T, N>,
303 p0: f32,
304 p1: f32,
305 ) -> Result<()> {
306 let shape = args.y.shape;
307 let stride_x = args.x.stride;
308 let stride_y = args.y.stride;
309 let rank = N as i32;
310
311 let status = match (self.desc.kind, T::KIND) {
312 (UnaryKind::PowI, ElementKind::F32) => unsafe {
313 baracuda_kernels_sys::baracuda_kernels_unary_powi_f32_strided_run(
314 numel, rank, shape.as_ptr(),
315 stride_x.as_ptr(), stride_y.as_ptr(),
316 x_ptr, y_ptr, p0, p1,
317 core::ptr::null_mut(), 0, stream_ptr,
318 )
319 },
320 (UnaryKind::PowI, ElementKind::F16) => unsafe {
321 baracuda_kernels_sys::baracuda_kernels_unary_powi_f16_strided_run(
322 numel, rank, shape.as_ptr(),
323 stride_x.as_ptr(), stride_y.as_ptr(),
324 x_ptr, y_ptr, p0, p1,
325 core::ptr::null_mut(), 0, stream_ptr,
326 )
327 },
328 (UnaryKind::PowI, ElementKind::Bf16) => unsafe {
329 baracuda_kernels_sys::baracuda_kernels_unary_powi_bf16_strided_run(
330 numel, rank, shape.as_ptr(),
331 stride_x.as_ptr(), stride_y.as_ptr(),
332 x_ptr, y_ptr, p0, p1,
333 core::ptr::null_mut(), 0, stream_ptr,
334 )
335 },
336 (UnaryKind::PowI, ElementKind::F64) => unsafe {
337 baracuda_kernels_sys::baracuda_kernels_unary_powi_f64_strided_run(
338 numel, rank, shape.as_ptr(),
339 stride_x.as_ptr(), stride_y.as_ptr(),
340 x_ptr, y_ptr, p0, p1,
341 core::ptr::null_mut(), 0, stream_ptr,
342 )
343 },
344 _ => {
345 return Err(Error::Unsupported(
346 "baracuda-kernels::UnaryParamPlan::run_strided: only PowI is wired \
347 for the strided path today (Phase 14.2 trailblazer)",
348 ));
349 }
350 };
351 map_status(status)
352 }
353}
354
355fn map_status(code: i32) -> Result<()> {
356 match code {
357 0 => Ok(()),
358 1 => Err(Error::MisalignedOperand),
359 2 => Err(Error::InvalidProblem(
360 "baracuda-kernels-sys reported invalid problem",
361 )),
362 3 => Err(Error::Unsupported(
363 "baracuda-kernels-sys reported unsupported configuration",
364 )),
365 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
366 n => Err(Error::CutlassInternal(n)),
367 }
368}