1use core::ffi::c_void;
19use core::marker::PhantomData;
20
21use baracuda_cutlass::{Error, Result};
22use baracuda_driver::Stream;
23use baracuda_kernels_types::{
24 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
25 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, UnaryKind, Workspace,
26};
27use half::{bf16, f16};
28
29#[derive(Copy, Clone, Debug)]
35pub struct AffineDescriptor<T: Element> {
36 pub numel: i32,
38 pub a: T,
40 pub b: T,
42 pub element: ElementKind,
44}
45
46pub struct AffineArgs<'a, T: Element> {
48 pub input: TensorRef<'a, T, 1>,
50 pub output: TensorMut<'a, T, 1>,
52}
53
54pub struct AffinePlan<T: Element> {
56 desc: AffineDescriptor<T>,
57 sku: KernelSku,
58 _marker: PhantomData<T>,
59}
60
61impl<T: Element> AffinePlan<T> {
62 pub fn select(
64 _stream: &Stream,
65 desc: &AffineDescriptor<T>,
66 _pref: PlanPreference,
67 ) -> Result<Self> {
68 if desc.element != T::KIND {
69 return Err(Error::Unsupported(
70 "baracuda-kernels::AffinePlan: descriptor element != type parameter T",
71 ));
72 }
73 if desc.numel < 0 {
74 return Err(Error::InvalidProblem(
75 "baracuda-kernels::AffinePlan: numel must be non-negative",
76 ));
77 }
78 if !dtype_in_scope(T::KIND) {
79 return Err(Error::Unsupported(
80 "baracuda-kernels::AffinePlan: dtype not wired today; supported set is \
81 {f32, f64, f16, bf16, i32, i64}",
82 ));
83 }
84
85 let precision_guarantee = PrecisionGuarantee {
86 math_precision: MathPrecision::F32,
87 accumulator: ElementKind::F32,
88 bit_stable_on_same_hardware: true,
89 deterministic: true,
90 };
91 let sku = KernelSku {
92 category: OpCategory::UnaryElementwise,
93 op: UnaryKind::Affine as u16,
94 element: T::KIND,
95 aux_element: None,
96 layout: None,
97 epilogue: None,
98 arch: ArchSku::Sm80,
99 backend: BackendKind::Bespoke,
100 precision_guarantee,
101 };
102 Ok(Self {
103 desc: *desc,
104 sku,
105 _marker: PhantomData,
106 })
107 }
108
109 pub fn can_implement(&self, args: &AffineArgs<'_, T>) -> Result<()> {
111 let expected = self.desc.numel as i64;
112 if args.input.numel() != expected {
113 return Err(Error::InvalidProblem(
114 "baracuda-kernels::AffinePlan: input numel mismatch with descriptor",
115 ));
116 }
117 if args.output.numel() != expected {
118 return Err(Error::InvalidProblem(
119 "baracuda-kernels::AffinePlan: output numel mismatch with descriptor",
120 ));
121 }
122 if (args.input.data.len() as i64) < expected {
123 return Err(Error::BufferTooSmall {
124 needed: expected as usize,
125 got: args.input.data.len(),
126 });
127 }
128 if (args.output.data.len() as i64) < expected {
129 return Err(Error::BufferTooSmall {
130 needed: expected as usize,
131 got: args.output.data.len(),
132 });
133 }
134 Ok(())
135 }
136
137 #[inline]
139 pub fn workspace_size(&self) -> usize {
140 0
141 }
142
143 #[inline]
145 pub fn sku(&self) -> KernelSku {
146 self.sku
147 }
148
149 #[inline]
151 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
152 self.sku.precision_guarantee
153 }
154
155 pub fn run(
163 &self,
164 stream: &Stream,
165 _workspace: Workspace<'_>,
166 args: AffineArgs<'_, T>,
167 ) -> Result<()> {
168 self.can_implement(&args)?;
169 let numel = self.desc.numel as i64;
170 if numel == 0 {
171 return Ok(());
172 }
173 let x_ptr = args.input.data.as_raw().0 as *const c_void;
174 let y_ptr = args.output.data.as_raw().0 as *mut c_void;
175 let stream_ptr = stream.as_raw() as *mut c_void;
176
177 let contig =
179 is_canonical_contig(&args.input.shape, &args.input.stride)
180 && is_canonical_contig(&args.output.shape, &args.output.stride);
181
182 let status = unsafe {
188 if contig {
189 match T::KIND {
190 ElementKind::F32 => {
191 let a: f32 = core::mem::transmute_copy(&self.desc.a);
192 let b: f32 = core::mem::transmute_copy(&self.desc.b);
193 baracuda_kernels_sys::baracuda_kernels_affine_f32_run(
194 numel, x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
195 )
196 }
197 ElementKind::F64 => {
198 let a: f64 = core::mem::transmute_copy(&self.desc.a);
199 let b: f64 = core::mem::transmute_copy(&self.desc.b);
200 baracuda_kernels_sys::baracuda_kernels_affine_f64_run(
201 numel, x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
202 )
203 }
204 ElementKind::I32 => {
205 let a: i32 = core::mem::transmute_copy(&self.desc.a);
206 let b: i32 = core::mem::transmute_copy(&self.desc.b);
207 baracuda_kernels_sys::baracuda_kernels_affine_i32_run(
208 numel, x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
209 )
210 }
211 ElementKind::I64 => {
212 let a: i64 = core::mem::transmute_copy(&self.desc.a);
213 let b: i64 = core::mem::transmute_copy(&self.desc.b);
214 baracuda_kernels_sys::baracuda_kernels_affine_i64_run(
215 numel, x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
216 )
217 }
218 ElementKind::F16 => {
219 let a: f16 = core::mem::transmute_copy(&self.desc.a);
220 let b: f16 = core::mem::transmute_copy(&self.desc.b);
221 baracuda_kernels_sys::baracuda_kernels_affine_f16_run(
222 numel, x_ptr, y_ptr, a.to_f32(), b.to_f32(),
223 core::ptr::null_mut(), 0, stream_ptr,
224 )
225 }
226 ElementKind::Bf16 => {
227 let a: bf16 = core::mem::transmute_copy(&self.desc.a);
228 let b: bf16 = core::mem::transmute_copy(&self.desc.b);
229 baracuda_kernels_sys::baracuda_kernels_affine_bf16_run(
230 numel, x_ptr, y_ptr, a.to_f32(), b.to_f32(),
231 core::ptr::null_mut(), 0, stream_ptr,
232 )
233 }
234 _ => {
235 return Err(Error::Unsupported(
236 "baracuda-kernels::AffinePlan::run reached an unimplemented dtype \
237 — select() should have caught this",
238 ));
239 }
240 }
241 } else {
242 let shape_ptr = args.input.shape.as_ptr();
245 let stride_x_ptr = args.input.stride.as_ptr();
246 let stride_y_ptr = args.output.stride.as_ptr();
247 let rank: i32 = 1;
248 match T::KIND {
249 ElementKind::F32 => {
250 let a: f32 = core::mem::transmute_copy(&self.desc.a);
251 let b: f32 = core::mem::transmute_copy(&self.desc.b);
252 baracuda_kernels_sys::baracuda_kernels_affine_f32_strided_run(
253 numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
254 x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
255 )
256 }
257 ElementKind::F64 => {
258 let a: f64 = core::mem::transmute_copy(&self.desc.a);
259 let b: f64 = core::mem::transmute_copy(&self.desc.b);
260 baracuda_kernels_sys::baracuda_kernels_affine_f64_strided_run(
261 numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
262 x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
263 )
264 }
265 ElementKind::I32 => {
266 let a: i32 = core::mem::transmute_copy(&self.desc.a);
267 let b: i32 = core::mem::transmute_copy(&self.desc.b);
268 baracuda_kernels_sys::baracuda_kernels_affine_i32_strided_run(
269 numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
270 x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
271 )
272 }
273 ElementKind::I64 => {
274 let a: i64 = core::mem::transmute_copy(&self.desc.a);
275 let b: i64 = core::mem::transmute_copy(&self.desc.b);
276 baracuda_kernels_sys::baracuda_kernels_affine_i64_strided_run(
277 numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
278 x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
279 )
280 }
281 ElementKind::F16 => {
282 let a: f16 = core::mem::transmute_copy(&self.desc.a);
283 let b: f16 = core::mem::transmute_copy(&self.desc.b);
284 baracuda_kernels_sys::baracuda_kernels_affine_f16_strided_run(
285 numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
286 x_ptr, y_ptr, a.to_f32(), b.to_f32(),
287 core::ptr::null_mut(), 0, stream_ptr,
288 )
289 }
290 ElementKind::Bf16 => {
291 let a: bf16 = core::mem::transmute_copy(&self.desc.a);
292 let b: bf16 = core::mem::transmute_copy(&self.desc.b);
293 baracuda_kernels_sys::baracuda_kernels_affine_bf16_strided_run(
294 numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
295 x_ptr, y_ptr, a.to_f32(), b.to_f32(),
296 core::ptr::null_mut(), 0, stream_ptr,
297 )
298 }
299 _ => {
300 return Err(Error::Unsupported(
301 "baracuda-kernels::AffinePlan::run reached an unimplemented dtype \
302 — select() should have caught this",
303 ));
304 }
305 }
306 }
307 };
308 map_status(status)
309 }
310}
311
/// Reports whether `stride` is the dense row-major stride for `shape`.
///
/// Walking from the innermost dimension outward, each stride must equal the
/// product of all inner extents (1 for the last dimension). A rank-0 view is
/// trivially contiguous. The running product saturates instead of
/// overflowing.
#[inline]
fn is_canonical_contig<const N: usize>(shape: &[i32; N], stride: &[i64; N]) -> bool {
    let mut want: i64 = 1;
    for (&extent, &step) in shape.iter().zip(stride.iter()).rev() {
        if step != want {
            return false;
        }
        want = want.saturating_mul(i64::from(extent));
    }
    true
}
334
335fn dtype_in_scope(k: ElementKind) -> bool {
336 matches!(
337 k,
338 ElementKind::F32
339 | ElementKind::F64
340 | ElementKind::F16
341 | ElementKind::Bf16
342 | ElementKind::I32
343 | ElementKind::I64
344 )
345}
346
347fn map_status(code: i32) -> Result<()> {
348 match code {
349 0 => Ok(()),
350 1 => Err(Error::MisalignedOperand),
351 2 => Err(Error::InvalidProblem(
352 "baracuda-kernels-sys reported invalid problem",
353 )),
354 3 => Err(Error::Unsupported(
355 "baracuda-kernels-sys reported unsupported configuration",
356 )),
357 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
358 n => Err(Error::CutlassInternal(n)),
359 }
360}