use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, TernaryKind, Workspace,
};
/// Problem description from which a [`TernaryBackwardPlan`] is selected.
///
/// The descriptor is copied into the plan by `TernaryBackwardPlan::select`
/// and later used to validate the tensors handed to `can_implement` / `run`.
#[derive(Copy, Clone, Debug)]
pub struct TernaryBackwardDescriptor<const N: usize> {
/// Which ternary op's backward kernel to run. `select` rejects
/// `TernaryKind::Where`; only Fma, Clamp, Addcmul and Addcdiv are accepted.
pub kind: TernaryKind,
/// Extent of each of the `N` dimensions. `select` rejects negative entries,
/// and `can_implement` requires every tensor argument to have exactly this shape.
pub shape: [i32; N],
/// Runtime dtype tag. `select` requires it to equal the plan's `T::KIND`.
pub element: ElementKind,
/// Scalar forwarded to the Addcmul and Addcdiv kernels. It is not passed to
/// the Fma or Clamp kernels.
pub scale: f32,
}
/// Tensor arguments for one `TernaryBackwardPlan::run` call: four read-only
/// tensors and three writable ones, all with element type `T` and rank `N`.
///
/// `can_implement` requires every tensor to match the plan's shape, to be
/// contiguous, and to have a backing buffer of at least `dy.numel()` elements.
pub struct TernaryBackwardArgs<'a, T: Element, const N: usize> {
/// Read-only input; its `numel()` determines the element count handed to the kernel.
/// NOTE(review): presumably the gradient w.r.t. the forward output — confirm.
pub dy: TensorRef<'a, T, N>,
/// Read-only operand `a` (the plan's error messages call a/b/c the "saved" tensors).
pub a: TensorRef<'a, T, N>,
/// Read-only operand `b`.
pub b: TensorRef<'a, T, N>,
/// Read-only operand `c`.
pub c: TensorRef<'a, T, N>,
/// Writable output. NOTE(review): presumably receives the gradient w.r.t. `a` — confirm.
pub da: TensorMut<'a, T, N>,
/// Writable output. NOTE(review): presumably receives the gradient w.r.t. `b` — confirm.
pub db: TensorMut<'a, T, N>,
/// Writable output. NOTE(review): presumably receives the gradient w.r.t. `c` — confirm.
pub dc: TensorMut<'a, T, N>,
}
/// A validated backward plan for a homogeneous-dtype ternary elementwise op.
///
/// Instances are created by `select`, which checks the descriptor and records
/// the kernel SKU; `run` then dispatches on `(desc.kind, T::KIND)`.
pub struct TernaryBackwardPlan<T: Element, const N: usize> {
// Copy of the descriptor the plan was selected for.
desc: TernaryBackwardDescriptor<N>,
// Kernel identity record (category, op, dtype, arch, backend, precision) built by `select`.
sku: KernelSku,
// Ties the plan to element type `T` without storing a `T`.
_marker: PhantomData<T>,
}
impl<T: Element, const N: usize> TernaryBackwardPlan<T, N> {
/// Validates `desc` against the element type `T` and builds the plan.
///
/// `_stream` and `_pref` are currently unused.
///
/// The recorded SKU is always `OpCategory::TernaryElementwise` on
/// `ArchSku::Sm80` with the `Bespoke` backend, and advertises an F32
/// math-precision / F32-accumulator guarantee that is bit-stable and
/// deterministic.
/// NOTE(review): the same F32 guarantee is recorded for every dtype,
/// including f64 — confirm that this is intended for the f64 cells.
///
/// # Errors
///
/// * `Error::Unsupported` if `desc.element` differs from `T::KIND`, if the
///   kind is `TernaryKind::Where`, or if the (kind, dtype) pair lies outside
///   {Fma, Clamp, Addcmul, Addcdiv} × {f32, f16, bf16, f64}.
/// * `Error::InvalidProblem` if any entry of `desc.shape` is negative.
pub fn select(
    _stream: &Stream,
    desc: &TernaryBackwardDescriptor<N>,
    _pref: PlanPreference,
) -> Result<Self> {
    // The runtime dtype tag must agree with the compile-time element type.
    if desc.element != T::KIND {
        return Err(Error::Unsupported(
            "baracuda-kernels::TernaryBackwardPlan: descriptor element != type parameter T",
        ));
    }

    if desc.shape.iter().any(|&extent| extent < 0) {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::TernaryBackwardPlan: shape dims must be non-negative",
        ));
    }

    // `Where` gets its own, more detailed rejection before the generic scope check.
    if let TernaryKind::Where = desc.kind {
        return Err(Error::Unsupported(
            "baracuda-kernels::TernaryBackwardPlan: `Where` backward needs a \
             heterogeneous-dtype plan shape (cond is u8, value tensors are T) — \
             it will land as a separate `WhereBackwardPlan` in a future session; \
             this plan only handles the homogeneous-dtype ternary family \
             (Fma, Clamp, Addcmul, Addcdiv).",
        ));
    }

    // Only the floating-point cells of the four homogeneous ops are wired.
    let cell_is_wired = matches!(
        (desc.kind, T::KIND),
        (
            TernaryKind::Fma | TernaryKind::Clamp | TernaryKind::Addcmul | TernaryKind::Addcdiv,
            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64,
        )
    );
    if !cell_is_wired {
        return Err(Error::Unsupported(
            "baracuda-kernels::TernaryBackwardPlan: this (kind, dtype) cell is not \
             wired today — the trailblazer covers {Fma, Clamp, Addcmul, Addcdiv} × \
             {f32, f16, bf16, f64}. Integer / other dtype cells land in later fanout.",
        ));
    }

    Ok(Self {
        desc: *desc,
        sku: KernelSku {
            category: OpCategory::TernaryElementwise,
            op: desc.kind as u16,
            element: T::KIND,
            aux_element: None,
            layout: None,
            epilogue: None,
            arch: ArchSku::Sm80,
            backend: BackendKind::Bespoke,
            precision_guarantee: PrecisionGuarantee {
                math_precision: MathPrecision::F32,
                accumulator: ElementKind::F32,
                bit_stable_on_same_hardware: true,
                deterministic: true,
            },
        },
        _marker: PhantomData,
    })
}
pub fn can_implement(&self, args: &TernaryBackwardArgs<'_, T, N>) -> Result<()> {
if args.dy.shape != self.desc.shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::TernaryBackwardPlan: dy shape mismatch",
));
}
if args.a.shape != self.desc.shape
|| args.b.shape != self.desc.shape
|| args.c.shape != self.desc.shape
{
return Err(Error::InvalidProblem(
"baracuda-kernels::TernaryBackwardPlan: saved a/b/c shape mismatch",
));
}
if args.da.shape != self.desc.shape
|| args.db.shape != self.desc.shape
|| args.dc.shape != self.desc.shape
{
return Err(Error::InvalidProblem(
"baracuda-kernels::TernaryBackwardPlan: da/db/dc shape mismatch",
));
}
if !args.dy.is_contiguous()
|| !args.a.is_contiguous()
|| !args.b.is_contiguous()
|| !args.c.is_contiguous()
|| !args.da.is_contiguous()
|| !args.db.is_contiguous()
|| !args.dc.is_contiguous()
{
return Err(Error::Unsupported(
"baracuda-kernels::TernaryBackwardPlan: trailblazer requires contiguous \
dy / a / b / c / da / db / dc; strided fanout lands later",
));
}
let numel = args.dy.numel();
let needed = numel as usize;
let lens = [
args.dy.data.len(),
args.a.data.len(),
args.b.data.len(),
args.c.data.len(),
args.da.data.len(),
args.db.data.len(),
args.dc.data.len(),
];
if let Some(&min_len) = lens.iter().min() {
if min_len < needed {
return Err(Error::BufferTooSmall {
needed,
got: min_len,
});
}
}
Ok(())
}
/// Size of the scratch workspace this plan requires; always `0`.
///
/// `run` ignores its workspace argument and passes a null pointer with size 0
/// to the kernel entry points.
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
/// Returns a copy of the kernel SKU recorded when the plan was selected.
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
/// Returns the precision guarantee stored in this plan's SKU.
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: TernaryBackwardArgs<'_, T, N>,
) -> Result<()> {
self.can_implement(&args)?;
let numel = args.dy.numel();
if numel == 0 {
return Ok(());
}
let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let c_ptr = args.c.data.as_raw().0 as *const c_void;
let da_ptr = args.da.data.as_raw().0 as *mut c_void;
let db_ptr = args.db.data.as_raw().0 as *mut c_void;
let dc_ptr = args.dc.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let scale = self.desc.scale;
let status = match (self.desc.kind, T::KIND) {
(TernaryKind::Fma, ElementKind::F32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_fma_backward_f32_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Fma, ElementKind::F16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_fma_backward_f16_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Fma, ElementKind::Bf16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_fma_backward_bf16_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Fma, ElementKind::F64) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_fma_backward_f64_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Clamp, ElementKind::F32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_clamp_backward_f32_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Clamp, ElementKind::F16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_clamp_backward_f16_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Clamp, ElementKind::Bf16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_clamp_backward_bf16_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Clamp, ElementKind::F64) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_clamp_backward_f64_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Addcmul, ElementKind::F32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_backward_f32_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr, scale,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Addcmul, ElementKind::F16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_backward_f16_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr, scale,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Addcmul, ElementKind::Bf16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_backward_bf16_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr, scale,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Addcmul, ElementKind::F64) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_addcmul_backward_f64_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr, scale,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Addcdiv, ElementKind::F32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_backward_f32_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr, scale,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Addcdiv, ElementKind::F16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_backward_f16_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr, scale,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Addcdiv, ElementKind::Bf16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_backward_bf16_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr, scale,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(TernaryKind::Addcdiv, ElementKind::F64) => unsafe {
baracuda_kernels_sys::baracuda_kernels_ternary_addcdiv_backward_f64_run(
numel, dy_ptr, a_ptr, b_ptr, c_ptr, da_ptr, db_ptr, dc_ptr, scale,
core::ptr::null_mut(), 0, stream_ptr,
)
},
_ => {
return Err(Error::Unsupported(
"baracuda-kernels::TernaryBackwardPlan::run reached an \
unimplemented (kind, dtype) pair — select() should have caught this",
));
}
};
map_status(status)
}
}
fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}