1use core::ffi::c_void;
21use core::marker::PhantomData;
22
23use baracuda_cutlass::{Error, Result};
24use baracuda_driver::Stream;
25use baracuda_kernels_types::{
26 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
27 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, TernaryKind, Workspace,
28};
29
30#[derive(Copy, Clone, Debug)]
37pub struct TernaryDescriptor<const N: usize> {
38 pub kind: TernaryKind,
40 pub shape: [i32; N],
42 pub element: ElementKind,
44 pub scale: f32,
47}
48
49pub struct TernaryArgs<'a, T: Element, const N: usize> {
56 pub a: TensorRef<'a, T, N>,
58 pub b: TensorRef<'a, T, N>,
60 pub c: TensorRef<'a, T, N>,
62 pub y: TensorMut<'a, T, N>,
64}
65
66pub struct TernaryPlan<T: Element, const N: usize> {
71 desc: TernaryDescriptor<N>,
72 sku: KernelSku,
73 _marker: PhantomData<T>,
74}
75
76impl<T: Element, const N: usize> TernaryPlan<T, N> {
77 pub fn select(
80 _stream: &Stream,
81 desc: &TernaryDescriptor<N>,
82 _pref: PlanPreference,
83 ) -> Result<Self> {
84 if desc.element != T::KIND {
85 return Err(Error::Unsupported(
86 "baracuda-kernels::TernaryPlan: descriptor element != type parameter T",
87 ));
88 }
89 for &d in desc.shape.iter() {
90 if d < 0 {
91 return Err(Error::InvalidProblem(
92 "baracuda-kernels::TernaryPlan: shape dims must be non-negative",
93 ));
94 }
95 }
96
97 let kind_in_scope = matches!(
102 desc.kind,
103 TernaryKind::Clamp | TernaryKind::Fma | TernaryKind::Addcmul | TernaryKind::Addcdiv
104 );
105 let dtype_in_scope = matches!(
106 T::KIND,
107 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
108 );
109 let supported = kind_in_scope && dtype_in_scope;
110 if !supported {
111 return Err(Error::Unsupported(
112 "baracuda-kernels::TernaryPlan: this (kind, dtype) cell is not yet \
113 wired; see the dispatcher's kind / dtype scope for the supported set. \
114 Note: `Where` requires a separate heterogeneous-dtype plan \
115 (`crate::WherePlan`).",
116 ));
117 }
118
119 let precision_guarantee = PrecisionGuarantee {
120 math_precision: MathPrecision::F32,
121 accumulator: ElementKind::F32,
122 bit_stable_on_same_hardware: true,
123 deterministic: true,
124 };
125 let sku = KernelSku {
126 category: OpCategory::TernaryElementwise,
127 op: desc.kind as u16,
128 element: T::KIND,
129 aux_element: None,
130 layout: None,
131 epilogue: None,
132 arch: ArchSku::Sm80,
133 backend: BackendKind::Bespoke,
134 precision_guarantee,
135 };
136 Ok(Self {
137 desc: *desc,
138 sku,
139 _marker: PhantomData,
140 })
141 }
142
143 pub fn can_implement(&self, args: &TernaryArgs<'_, T, N>) -> Result<()> {
148 if args.y.shape != self.desc.shape {
149 return Err(Error::InvalidProblem(
150 "baracuda-kernels::TernaryPlan: Y shape mismatch with descriptor",
151 ));
152 }
153
154 for d in 0..N {
155 let y_dim = self.desc.shape[d];
156 for (name, (op_dim, op_stride)) in [
157 ("A", (args.a.shape[d], args.a.stride[d])),
158 ("B", (args.b.shape[d], args.b.stride[d])),
159 ("C", (args.c.shape[d], args.c.stride[d])),
160 ] {
161 if op_dim != y_dim && !(op_dim == 1 && op_stride == 0) {
162 let _ = name; return Err(Error::InvalidProblem(
165 "baracuda-kernels::TernaryPlan: input axis not broadcast-compatible \
166 with output (require shape[d] == y.shape[d], OR \
167 shape[d] == 1 AND stride[d] == 0)",
168 ));
169 }
170 }
171 }
172
173 if N > 8 {
174 return Err(Error::Unsupported(
175 "baracuda-kernels::TernaryPlan: tensor rank > 8 not supported",
176 ));
177 }
178
179 let y_numel = args.y.numel();
180 let a_numel = args.a.numel();
181 let b_numel = args.b.numel();
182 let c_numel = args.c.numel();
183 let a_len = args.a.data.len() as i64;
184 let b_len = args.b.data.len() as i64;
185 let c_len = args.c.data.len() as i64;
186 let y_len = args.y.data.len() as i64;
187 if y_len < y_numel {
188 return Err(Error::BufferTooSmall {
189 needed: y_numel as usize,
190 got: y_len as usize,
191 });
192 }
193 if a_len < a_numel {
194 return Err(Error::BufferTooSmall {
195 needed: a_numel as usize,
196 got: a_len as usize,
197 });
198 }
199 if b_len < b_numel {
200 return Err(Error::BufferTooSmall {
201 needed: b_numel as usize,
202 got: b_len as usize,
203 });
204 }
205 if c_len < c_numel {
206 return Err(Error::BufferTooSmall {
207 needed: c_numel as usize,
208 got: c_len as usize,
209 });
210 }
211 Ok(())
212 }
213
214 #[inline]
216 pub fn workspace_size(&self) -> usize {
217 0
218 }
219
220 #[inline]
222 pub fn sku(&self) -> KernelSku {
223 self.sku
224 }
225
226 #[inline]
228 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
229 self.sku.precision_guarantee
230 }
231
232 pub fn run(
234 &self,
235 stream: &Stream,
236 _workspace: Workspace<'_>,
237 args: TernaryArgs<'_, T, N>,
238 ) -> Result<()> {
239 self.can_implement(&args)?;
240 let numel = args.y.numel();
241 if numel == 0 {
242 return Ok(());
243 }
244 let a_ptr = args.a.data.as_raw().0 as *const c_void;
245 let b_ptr = args.b.data.as_raw().0 as *const c_void;
246 let c_ptr = args.c.data.as_raw().0 as *const c_void;
247 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
248 let stream_ptr = stream.as_raw() as *mut c_void;
249
250 let all_contig_same_shape = args.a.shape == args.y.shape
251 && args.b.shape == args.y.shape
252 && args.c.shape == args.y.shape
253 && args.a.is_contiguous()
254 && args.b.is_contiguous()
255 && args.c.is_contiguous()
256 && args.y.is_contiguous();
257
258 if !all_contig_same_shape {
259 return self.run_strided(stream_ptr, a_ptr, b_ptr, c_ptr, y_ptr, numel, &args);
260 }
261
262 let status = match (self.desc.kind, T::KIND) {
263 (TernaryKind::Clamp, ElementKind::F32) => unsafe {
265 baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f32_run(
266 numel, a_ptr, b_ptr, c_ptr, y_ptr,
267 core::ptr::null_mut(), 0, stream_ptr,
268 )
269 },
270 (TernaryKind::Clamp, ElementKind::F16) => unsafe {
271 baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f16_run(
272 numel, a_ptr, b_ptr, c_ptr, y_ptr,
273 core::ptr::null_mut(), 0, stream_ptr,
274 )
275 },
276 (TernaryKind::Clamp, ElementKind::Bf16) => unsafe {
277 baracuda_kernels_sys::baracuda_kernels_ternary_clamp_bf16_run(
278 numel, a_ptr, b_ptr, c_ptr, y_ptr,
279 core::ptr::null_mut(), 0, stream_ptr,
280 )
281 },
282 (TernaryKind::Clamp, ElementKind::F64) => unsafe {
283 baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f64_run(
284 numel, a_ptr, b_ptr, c_ptr, y_ptr,
285 core::ptr::null_mut(), 0, stream_ptr,
286 )
287 },
288 (TernaryKind::Fma, ElementKind::F32) => unsafe {
290 baracuda_kernels_sys::baracuda_kernels_ternary_fma_f32_run(
291 numel, a_ptr, b_ptr, c_ptr, y_ptr,
292 core::ptr::null_mut(), 0, stream_ptr,
293 )
294 },
295 (TernaryKind::Fma, ElementKind::F16) => unsafe {
296 baracuda_kernels_sys::baracuda_kernels_ternary_fma_f16_run(
297 numel, a_ptr, b_ptr, c_ptr, y_ptr,
298 core::ptr::null_mut(), 0, stream_ptr,
299 )
300 },
301 (TernaryKind::Fma, ElementKind::Bf16) => unsafe {
302 baracuda_kernels_sys::baracuda_kernels_ternary_fma_bf16_run(
303 numel, a_ptr, b_ptr, c_ptr, y_ptr,
304 core::ptr::null_mut(), 0, stream_ptr,
305 )
306 },
307 (TernaryKind::Fma, ElementKind::F64) => unsafe {
308 baracuda_kernels_sys::baracuda_kernels_ternary_fma_f64_run(
309 numel, a_ptr, b_ptr, c_ptr, y_ptr,
310 core::ptr::null_mut(), 0, stream_ptr,
311 )
312 },
313 (TernaryKind::Addcmul, ElementKind::F32) => unsafe {
315 baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f32_run(
316 numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
317 core::ptr::null_mut(), 0, stream_ptr,
318 )
319 },
320 (TernaryKind::Addcmul, ElementKind::F16) => unsafe {
321 baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f16_run(
322 numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
323 core::ptr::null_mut(), 0, stream_ptr,
324 )
325 },
326 (TernaryKind::Addcmul, ElementKind::Bf16) => unsafe {
327 baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_bf16_run(
328 numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
329 core::ptr::null_mut(), 0, stream_ptr,
330 )
331 },
332 (TernaryKind::Addcmul, ElementKind::F64) => unsafe {
333 baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f64_run(
334 numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
335 core::ptr::null_mut(), 0, stream_ptr,
336 )
337 },
338 (TernaryKind::Addcdiv, ElementKind::F32) => unsafe {
340 baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f32_run(
341 numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
342 core::ptr::null_mut(), 0, stream_ptr,
343 )
344 },
345 (TernaryKind::Addcdiv, ElementKind::F16) => unsafe {
346 baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f16_run(
347 numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
348 core::ptr::null_mut(), 0, stream_ptr,
349 )
350 },
351 (TernaryKind::Addcdiv, ElementKind::Bf16) => unsafe {
352 baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_bf16_run(
353 numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
354 core::ptr::null_mut(), 0, stream_ptr,
355 )
356 },
357 (TernaryKind::Addcdiv, ElementKind::F64) => unsafe {
358 baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f64_run(
359 numel, a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
360 core::ptr::null_mut(), 0, stream_ptr,
361 )
362 },
363 _ => {
364 return Err(Error::Unsupported(
365 "baracuda-kernels::TernaryPlan::run reached an unimplemented \
366 (kind, dtype) — select() should have caught this",
367 ));
368 }
369 };
370 map_status(status)
371 }
372
373 fn run_strided(
375 &self,
376 stream_ptr: *mut c_void,
377 a_ptr: *const c_void,
378 b_ptr: *const c_void,
379 c_ptr: *const c_void,
380 y_ptr: *mut c_void,
381 numel: i64,
382 args: &TernaryArgs<'_, T, N>,
383 ) -> Result<()> {
384 let shape = args.y.shape;
385 let stride_a = args.a.stride;
386 let stride_b = args.b.stride;
387 let stride_c = args.c.stride;
388 let stride_y = args.y.stride;
389 let rank = N as i32;
390
391 let status = match (self.desc.kind, T::KIND) {
392 (TernaryKind::Clamp, ElementKind::F32) => unsafe {
394 baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f32_strided_run(
395 numel, rank, shape.as_ptr(),
396 stride_a.as_ptr(), stride_b.as_ptr(),
397 stride_c.as_ptr(), stride_y.as_ptr(),
398 a_ptr, b_ptr, c_ptr, y_ptr,
399 core::ptr::null_mut(), 0, stream_ptr,
400 )
401 },
402 (TernaryKind::Clamp, ElementKind::F16) => unsafe {
403 baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f16_strided_run(
404 numel, rank, shape.as_ptr(),
405 stride_a.as_ptr(), stride_b.as_ptr(),
406 stride_c.as_ptr(), stride_y.as_ptr(),
407 a_ptr, b_ptr, c_ptr, y_ptr,
408 core::ptr::null_mut(), 0, stream_ptr,
409 )
410 },
411 (TernaryKind::Clamp, ElementKind::Bf16) => unsafe {
412 baracuda_kernels_sys::baracuda_kernels_ternary_clamp_bf16_strided_run(
413 numel, rank, shape.as_ptr(),
414 stride_a.as_ptr(), stride_b.as_ptr(),
415 stride_c.as_ptr(), stride_y.as_ptr(),
416 a_ptr, b_ptr, c_ptr, y_ptr,
417 core::ptr::null_mut(), 0, stream_ptr,
418 )
419 },
420 (TernaryKind::Clamp, ElementKind::F64) => unsafe {
421 baracuda_kernels_sys::baracuda_kernels_ternary_clamp_f64_strided_run(
422 numel, rank, shape.as_ptr(),
423 stride_a.as_ptr(), stride_b.as_ptr(),
424 stride_c.as_ptr(), stride_y.as_ptr(),
425 a_ptr, b_ptr, c_ptr, y_ptr,
426 core::ptr::null_mut(), 0, stream_ptr,
427 )
428 },
429 (TernaryKind::Fma, ElementKind::F32) => unsafe {
431 baracuda_kernels_sys::baracuda_kernels_ternary_fma_f32_strided_run(
432 numel, rank, shape.as_ptr(),
433 stride_a.as_ptr(), stride_b.as_ptr(),
434 stride_c.as_ptr(), stride_y.as_ptr(),
435 a_ptr, b_ptr, c_ptr, y_ptr,
436 core::ptr::null_mut(), 0, stream_ptr,
437 )
438 },
439 (TernaryKind::Fma, ElementKind::F16) => unsafe {
440 baracuda_kernels_sys::baracuda_kernels_ternary_fma_f16_strided_run(
441 numel, rank, shape.as_ptr(),
442 stride_a.as_ptr(), stride_b.as_ptr(),
443 stride_c.as_ptr(), stride_y.as_ptr(),
444 a_ptr, b_ptr, c_ptr, y_ptr,
445 core::ptr::null_mut(), 0, stream_ptr,
446 )
447 },
448 (TernaryKind::Fma, ElementKind::Bf16) => unsafe {
449 baracuda_kernels_sys::baracuda_kernels_ternary_fma_bf16_strided_run(
450 numel, rank, shape.as_ptr(),
451 stride_a.as_ptr(), stride_b.as_ptr(),
452 stride_c.as_ptr(), stride_y.as_ptr(),
453 a_ptr, b_ptr, c_ptr, y_ptr,
454 core::ptr::null_mut(), 0, stream_ptr,
455 )
456 },
457 (TernaryKind::Fma, ElementKind::F64) => unsafe {
458 baracuda_kernels_sys::baracuda_kernels_ternary_fma_f64_strided_run(
459 numel, rank, shape.as_ptr(),
460 stride_a.as_ptr(), stride_b.as_ptr(),
461 stride_c.as_ptr(), stride_y.as_ptr(),
462 a_ptr, b_ptr, c_ptr, y_ptr,
463 core::ptr::null_mut(), 0, stream_ptr,
464 )
465 },
466 (TernaryKind::Addcmul, ElementKind::F32) => unsafe {
468 baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f32_strided_run(
469 numel, rank, shape.as_ptr(),
470 stride_a.as_ptr(), stride_b.as_ptr(),
471 stride_c.as_ptr(), stride_y.as_ptr(),
472 a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
473 core::ptr::null_mut(), 0, stream_ptr,
474 )
475 },
476 (TernaryKind::Addcmul, ElementKind::F16) => unsafe {
477 baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f16_strided_run(
478 numel, rank, shape.as_ptr(),
479 stride_a.as_ptr(), stride_b.as_ptr(),
480 stride_c.as_ptr(), stride_y.as_ptr(),
481 a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
482 core::ptr::null_mut(), 0, stream_ptr,
483 )
484 },
485 (TernaryKind::Addcmul, ElementKind::Bf16) => unsafe {
486 baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_bf16_strided_run(
487 numel, rank, shape.as_ptr(),
488 stride_a.as_ptr(), stride_b.as_ptr(),
489 stride_c.as_ptr(), stride_y.as_ptr(),
490 a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
491 core::ptr::null_mut(), 0, stream_ptr,
492 )
493 },
494 (TernaryKind::Addcmul, ElementKind::F64) => unsafe {
495 baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_f64_strided_run(
496 numel, rank, shape.as_ptr(),
497 stride_a.as_ptr(), stride_b.as_ptr(),
498 stride_c.as_ptr(), stride_y.as_ptr(),
499 a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
500 core::ptr::null_mut(), 0, stream_ptr,
501 )
502 },
503 (TernaryKind::Addcdiv, ElementKind::F32) => unsafe {
505 baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f32_strided_run(
506 numel, rank, shape.as_ptr(),
507 stride_a.as_ptr(), stride_b.as_ptr(),
508 stride_c.as_ptr(), stride_y.as_ptr(),
509 a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
510 core::ptr::null_mut(), 0, stream_ptr,
511 )
512 },
513 (TernaryKind::Addcdiv, ElementKind::F16) => unsafe {
514 baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f16_strided_run(
515 numel, rank, shape.as_ptr(),
516 stride_a.as_ptr(), stride_b.as_ptr(),
517 stride_c.as_ptr(), stride_y.as_ptr(),
518 a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
519 core::ptr::null_mut(), 0, stream_ptr,
520 )
521 },
522 (TernaryKind::Addcdiv, ElementKind::Bf16) => unsafe {
523 baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_bf16_strided_run(
524 numel, rank, shape.as_ptr(),
525 stride_a.as_ptr(), stride_b.as_ptr(),
526 stride_c.as_ptr(), stride_y.as_ptr(),
527 a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
528 core::ptr::null_mut(), 0, stream_ptr,
529 )
530 },
531 (TernaryKind::Addcdiv, ElementKind::F64) => unsafe {
532 baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_f64_strided_run(
533 numel, rank, shape.as_ptr(),
534 stride_a.as_ptr(), stride_b.as_ptr(),
535 stride_c.as_ptr(), stride_y.as_ptr(),
536 a_ptr, b_ptr, c_ptr, y_ptr, self.desc.scale,
537 core::ptr::null_mut(), 0, stream_ptr,
538 )
539 },
540 _ => {
541 return Err(Error::Unsupported(
542 "baracuda-kernels::TernaryPlan::run_strided reached an \
543 unimplemented (kind, dtype) pair — select() should have caught this",
544 ));
545 }
546 };
547 map_status(status)
548 }
549}
550
551fn map_status(code: i32) -> Result<()> {
552 match code {
553 0 => Ok(()),
554 1 => Err(Error::MisalignedOperand),
555 2 => Err(Error::InvalidProblem(
556 "baracuda-kernels-sys reported invalid problem",
557 )),
558 3 => Err(Error::Unsupported(
559 "baracuda-kernels-sys reported unsupported configuration",
560 )),
561 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
562 n => Err(Error::CutlassInternal(n)),
563 }
564}