use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, BinaryKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
/// Problem description for the backward pass of a binary elementwise op.
///
/// `BinaryBackwardPlan::select` validates it. `BinaryBackwardPlan::can_implement`
/// then requires every tensor argument to have exactly `shape`, so there is no
/// broadcasting between operands.
#[derive(Copy, Clone, Debug)]
pub struct BinaryBackwardDescriptor<const N: usize> {
    /// Which binary op to differentiate. `select` currently accepts only Add, Sub,
    /// Mul, Div, Maximum, Minimum, Pow, Atan2 and Hypot.
    pub kind: BinaryKind,
    /// Logical shape shared by `dy`, `da`, `db` and (when required) `a` / `b`.
    /// `select` rejects any negative dimension.
    pub shape: [i32; N],
    /// Element dtype; `select` requires it to equal the plan's `T::KIND`.
    pub element: ElementKind,
}
/// Device tensors for one `BinaryBackwardPlan::run` call.
///
/// `can_implement` requires `dy`, `da`, `db` (and `a` / `b` when the op needs
/// them) to match the planned shape, be contiguous, and hold at least `numel`
/// elements.
pub struct BinaryBackwardArgs<'a, T: Element, const N: usize> {
    /// Incoming gradient; a read-only input to every backward kernel.
    pub dy: TensorRef<'a, T, N>,
    /// Saved forward operand `a`. It is required when the op reads saved inputs
    /// (Mul, Div, Pow, Maximum, Minimum, Atan2, Hypot) and is ignored for Add/Sub.
    pub a: Option<TensorRef<'a, T, N>>,
    /// Saved forward operand `b`; same requirements as `a`.
    pub b: Option<TensorRef<'a, T, N>>,
    /// Output buffer written by the kernel, presumably the gradient w.r.t. `a`.
    pub da: TensorMut<'a, T, N>,
    /// Output buffer written by the kernel, presumably the gradient w.r.t. `b`.
    pub db: TensorMut<'a, T, N>,
}
/// Validated backward plan for a binary elementwise op over element type `T`.
///
/// Construct it with `BinaryBackwardPlan::select`. It stores a copy of the
/// descriptor and the chosen `KernelSku`; `T` is carried only at the type level.
pub struct BinaryBackwardPlan<T: Element, const N: usize> {
    // Descriptor the plan was selected for; `can_implement` checks args against it.
    desc: BinaryBackwardDescriptor<N>,
    // Kernel identity and precision guarantee reported by `sku()`.
    sku: KernelSku,
    // Ties the plan to `T` without storing a value of it.
    _marker: PhantomData<T>,
}
/// Returns `true` when the backward kernel for `op` reads the saved forward
/// operands `a` and `b`.
///
/// Add and Sub gradients depend on `dy` alone. Every kind listed below needs the
/// original inputs. Any kind not listed (including unwired ones) yields `false`.
#[inline]
fn op_needs_saves(op: BinaryKind) -> bool {
    matches!(
        op,
        BinaryKind::Atan2
            | BinaryKind::Div
            | BinaryKind::Hypot
            | BinaryKind::Maximum
            | BinaryKind::Minimum
            | BinaryKind::Mul
            | BinaryKind::Pow
    )
}
impl<T: Element, const N: usize> BinaryBackwardPlan<T, N> {
pub fn select(
_stream: &Stream,
desc: &BinaryBackwardDescriptor<N>,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::BinaryBackwardPlan: descriptor element != T",
));
}
for &d in desc.shape.iter() {
if d < 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::BinaryBackwardPlan: shape dims must be non-negative",
));
}
}
let supported = matches!(
(desc.kind, T::KIND),
(BinaryKind::Add, ElementKind::F32)
| (BinaryKind::Add, ElementKind::F16)
| (BinaryKind::Add, ElementKind::Bf16)
| (BinaryKind::Add, ElementKind::F64)
| (BinaryKind::Sub, ElementKind::F32)
| (BinaryKind::Sub, ElementKind::F16)
| (BinaryKind::Sub, ElementKind::Bf16)
| (BinaryKind::Sub, ElementKind::F64)
| (BinaryKind::Mul, ElementKind::F32)
| (BinaryKind::Mul, ElementKind::F16)
| (BinaryKind::Mul, ElementKind::Bf16)
| (BinaryKind::Mul, ElementKind::F64)
| (BinaryKind::Div, ElementKind::F32)
| (BinaryKind::Div, ElementKind::F16)
| (BinaryKind::Div, ElementKind::Bf16)
| (BinaryKind::Div, ElementKind::F64)
| (BinaryKind::Maximum, ElementKind::F32)
| (BinaryKind::Maximum, ElementKind::F16)
| (BinaryKind::Maximum, ElementKind::Bf16)
| (BinaryKind::Maximum, ElementKind::F64)
| (BinaryKind::Minimum, ElementKind::F32)
| (BinaryKind::Minimum, ElementKind::F16)
| (BinaryKind::Minimum, ElementKind::Bf16)
| (BinaryKind::Minimum, ElementKind::F64)
| (BinaryKind::Pow, ElementKind::F32)
| (BinaryKind::Pow, ElementKind::F16)
| (BinaryKind::Pow, ElementKind::Bf16)
| (BinaryKind::Pow, ElementKind::F64)
| (BinaryKind::Atan2, ElementKind::F32)
| (BinaryKind::Atan2, ElementKind::F16)
| (BinaryKind::Atan2, ElementKind::Bf16)
| (BinaryKind::Atan2, ElementKind::F64)
| (BinaryKind::Hypot, ElementKind::F32)
| (BinaryKind::Hypot, ElementKind::F16)
| (BinaryKind::Hypot, ElementKind::Bf16)
| (BinaryKind::Hypot, ElementKind::F64)
);
if !supported {
return Err(Error::Unsupported(
"baracuda-kernels::BinaryBackwardPlan: only \
`{Add,Sub,Mul,Div,Maximum,Minimum,Pow,Atan2,Hypot}` × \
`{f32, f16, bf16, f64}` are wired today; other (kind, dtype) \
pairs (e.g. integer family, Lerp) land in later fanout. Lerp \
is reserved-but-deferred pending a parameterized-binary plan \
shape.",
));
}
let precision_guarantee = PrecisionGuarantee {
math_precision: MathPrecision::F32,
accumulator: ElementKind::F32,
bit_stable_on_same_hardware: true,
deterministic: true,
};
let sku = KernelSku {
category: OpCategory::BinaryElementwise,
op: desc.kind as u16,
element: T::KIND,
aux_element: None,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Bespoke,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
_marker: PhantomData,
})
}
pub fn can_implement(&self, args: &BinaryBackwardArgs<'_, T, N>) -> Result<()> {
if args.dy.shape != self.desc.shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::BinaryBackwardPlan: dy shape mismatch",
));
}
if args.da.shape != self.desc.shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::BinaryBackwardPlan: da shape mismatch",
));
}
if args.db.shape != self.desc.shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::BinaryBackwardPlan: db shape mismatch",
));
}
if !args.dy.is_contiguous() || !args.da.is_contiguous() || !args.db.is_contiguous() {
return Err(Error::Unsupported(
"baracuda-kernels::BinaryBackwardPlan: trailblazer requires contiguous \
dy / da / db; strided fanout lands later",
));
}
if op_needs_saves(self.desc.kind) {
let a = args.a.as_ref().ok_or(Error::InvalidProblem(
"baracuda-kernels::BinaryBackwardPlan: this op requires saved input `a`",
))?;
let b = args.b.as_ref().ok_or(Error::InvalidProblem(
"baracuda-kernels::BinaryBackwardPlan: this op requires saved input `b`",
))?;
if a.shape != self.desc.shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::BinaryBackwardPlan: saved a shape mismatch",
));
}
if b.shape != self.desc.shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::BinaryBackwardPlan: saved b shape mismatch",
));
}
if !a.is_contiguous() || !b.is_contiguous() {
return Err(Error::Unsupported(
"baracuda-kernels::BinaryBackwardPlan: saved a/b must be contiguous \
(strided fanout lands later)",
));
}
let numel = args.dy.numel() as usize;
if a.data.len() < numel {
return Err(Error::BufferTooSmall {
needed: numel,
got: a.data.len(),
});
}
if b.data.len() < numel {
return Err(Error::BufferTooSmall {
needed: numel,
got: b.data.len(),
});
}
}
let numel = args.dy.numel();
let dy_len = args.dy.data.len() as i64;
let da_len = args.da.data.len() as i64;
let db_len = args.db.data.len() as i64;
if dy_len < numel || da_len < numel || db_len < numel {
return Err(Error::BufferTooSmall {
needed: numel as usize,
got: dy_len.min(da_len).min(db_len) as usize,
});
}
Ok(())
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: BinaryBackwardArgs<'_, T, N>,
) -> Result<()> {
self.can_implement(&args)?;
let numel = args.dy.numel();
if numel == 0 {
return Ok(());
}
let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
let da_ptr = args.da.data.as_raw().0 as *mut c_void;
let db_ptr = args.db.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let status = match (self.desc.kind, T::KIND) {
(BinaryKind::Add, ElementKind::F32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_add_backward_f32_run(
numel, dy_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(BinaryKind::Add, ElementKind::F16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_add_backward_f16_run(
numel, dy_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(BinaryKind::Add, ElementKind::Bf16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_add_backward_bf16_run(
numel, dy_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(BinaryKind::Add, ElementKind::F64) => unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_add_backward_f64_run(
numel, dy_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(BinaryKind::Sub, ElementKind::F32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_f32_run(
numel, dy_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(BinaryKind::Sub, ElementKind::F16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_f16_run(
numel, dy_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(BinaryKind::Sub, ElementKind::Bf16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_bf16_run(
numel, dy_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(BinaryKind::Sub, ElementKind::F64) => unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_f64_run(
numel, dy_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(BinaryKind::Mul, ElementKind::F32) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_f32_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Mul, ElementKind::F16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_f16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Mul, ElementKind::Bf16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_bf16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Mul, ElementKind::F64) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_f64_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Div, ElementKind::F32) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_div_backward_f32_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Div, ElementKind::F16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_div_backward_f16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Div, ElementKind::Bf16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_div_backward_bf16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Div, ElementKind::F64) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_div_backward_f64_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Maximum, ElementKind::F32) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_f32_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Maximum, ElementKind::F16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_f16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Maximum, ElementKind::Bf16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_bf16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Maximum, ElementKind::F64) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_f64_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Minimum, ElementKind::F32) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_f32_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Minimum, ElementKind::F16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_f16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Minimum, ElementKind::Bf16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_bf16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Minimum, ElementKind::F64) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_f64_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Pow, ElementKind::F32) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_f32_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Pow, ElementKind::F16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_f16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Pow, ElementKind::Bf16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_bf16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Pow, ElementKind::F64) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_f64_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Atan2, ElementKind::F32) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_f32_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Atan2, ElementKind::F16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_f16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Atan2, ElementKind::Bf16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_bf16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Atan2, ElementKind::F64) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_f64_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Hypot, ElementKind::F32) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_f32_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Hypot, ElementKind::F16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_f16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Hypot, ElementKind::Bf16) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_bf16_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
(BinaryKind::Hypot, ElementKind::F64) => {
let (a_ptr, b_ptr) = saved_ptrs(&args);
unsafe {
baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_f64_run(
numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
}
}
_ => {
return Err(Error::Unsupported(
"baracuda-kernels::BinaryBackwardPlan::run reached an unimplemented \
(kind, dtype) pair — select() should have caught this",
));
}
};
map_status(status)
}
}
/// Extracts raw device pointers for the saved forward operands `(a, b)`.
///
/// # Panics
/// Panics if either saved input is `None`. `run` only calls this after
/// `can_implement` has confirmed both are present for ops that need them, so a
/// panic here indicates an internal invariant violation. `a` is checked first.
#[inline]
fn saved_ptrs<T: Element, const N: usize>(
    args: &BinaryBackwardArgs<'_, T, N>,
) -> (*const c_void, *const c_void) {
    let a_ptr = args
        .a
        .as_ref()
        .map(|saved| saved.data.as_raw().0 as *const c_void)
        .expect("Mul/Div/Pow/Maximum/Minimum/Atan2/Hypot backward require saved a");
    let b_ptr = args
        .b
        .as_ref()
        .map(|saved| saved.data.as_raw().0 as *const c_void)
        .expect("Mul/Div/Pow/Maximum/Minimum/Atan2/Hypot backward require saved b");
    (a_ptr, b_ptr)
}
fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}