use core::cell::Cell;
use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_sys::{
curandCreateGenerator, curandDestroyGenerator, curandGenerateUniform, curandGenerator_t,
curandSetPseudoRandomGeneratorSeed, curandSetStream,
};
use baracuda_kernels_types::{
ArchSku, BackendKind, Bool, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
/// Problem description for a forward dropout plan over a rank-`N` tensor.
///
/// The fields are validated by `DropoutPlan::select`.
#[derive(Copy, Clone, Debug)]
pub struct DropoutDescriptor<const N: usize> {
/// Tensor extents; `DropoutPlan::select` rejects any negative entry.
pub shape: [i32; N],
/// Element kind; must equal `T::KIND` of the plan being selected.
pub element: ElementKind,
/// Dropout parameter; must lie in `[0, 1]` (NaN is rejected by `select`).
/// `run` derives the kernel scale `1 / (1 - p)` from it (0 when `p == 1`).
pub p: f32,
/// Seed handed to the cuRAND pseudo-random generator when it is created.
pub seed: u64,
}
/// Tensor arguments for `DropoutPlan::run`.
///
/// All three tensors must have exactly the shape stored in the plan's
/// descriptor, and each backing buffer must hold at least `numel` elements.
pub struct DropoutArgs<'a, T: Element, const N: usize> {
/// Read-only tensor passed to the kernel as its input.
pub x: TensorRef<'a, T, N>,
/// Writable tensor passed to the kernel as its output.
pub y: TensorMut<'a, T, N>,
/// Writable `Bool` tensor passed to the kernel
/// (presumably receives the keep/drop mask — confirm against the kernel).
pub mask: TensorMut<'a, Bool, N>,
}
/// Forward dropout plan for element type `T` and tensor rank `N`.
///
/// Built by `DropoutPlan::select`; owns a cuRAND generator that is created
/// on first use and destroyed when the plan is dropped.
pub struct DropoutPlan<T: Element, const N: usize> {
// Validated copy of the descriptor given to `select`.
desc: DropoutDescriptor<N>,
// Kernel identity and precision metadata reported by `sku()`.
sku: KernelSku,
// cuRAND generator handle; null until `ensure_generator` creates it.
// Kept in a `Cell` so that it can be set from `&self`.
generator: Cell<curandGenerator_t>,
// Ties the plan to its element type without storing a `T`.
_marker: PhantomData<T>,
}
impl<T: Element, const N: usize> DropoutPlan<T, N> {
pub fn select(
_stream: &Stream,
desc: &DropoutDescriptor<N>,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::DropoutPlan: descriptor.element != T::KIND",
));
}
if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
return Err(Error::Unsupported(
"baracuda-kernels::DropoutPlan: wired today: f32 + f64",
));
}
for &d in desc.shape.iter() {
if d < 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::DropoutPlan: shape dims must be non-negative",
));
}
}
if N > 8 {
return Err(Error::Unsupported(
"baracuda-kernels::DropoutPlan: tensor rank > 8 not supported",
));
}
if !(desc.p >= 0.0 && desc.p <= 1.0) {
return Err(Error::InvalidProblem(
"baracuda-kernels::DropoutPlan: p must be in [0, 1]",
));
}
let math_precision = match T::KIND {
ElementKind::F64 => MathPrecision::F64,
_ => MathPrecision::F32,
};
let precision_guarantee = PrecisionGuarantee {
math_precision,
accumulator: T::KIND,
bit_stable_on_same_hardware: true,
deterministic: true,
};
let sku = KernelSku {
category: OpCategory::Random,
op: 100, element: T::KIND,
aux_element: Some(ElementKind::Bool),
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Bespoke,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
generator: Cell::new(core::ptr::null_mut()),
_marker: PhantomData,
})
}
#[inline]
pub fn workspace_size(&self) -> usize {
let numel: i64 = self.desc.shape.iter().map(|&d| d as i64).product();
(numel.max(0) as usize) * core::mem::size_of::<f32>()
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
fn ensure_generator(&self) -> Result<curandGenerator_t> {
let g = self.generator.get();
if !g.is_null() {
return Ok(g);
}
let mut handle: curandGenerator_t = core::ptr::null_mut();
let status =
unsafe { curandCreateGenerator(&mut handle as *mut _, 100) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
let status = unsafe { curandSetPseudoRandomGeneratorSeed(handle, self.desc.seed) };
if status != 0 {
unsafe {
let _ = curandDestroyGenerator(handle);
}
return Err(Error::CutlassInternal(-status));
}
self.generator.set(handle);
Ok(handle)
}
fn check_args(&self, args: &DropoutArgs<'_, T, N>) -> Result<i64> {
if args.x.shape != self.desc.shape
|| args.y.shape != self.desc.shape
|| args.mask.shape != self.desc.shape
{
return Err(Error::InvalidProblem(
"baracuda-kernels::DropoutPlan: shape mismatch (x / y / mask)",
));
}
let numel = args.y.numel();
let xlen = args.x.data.len() as i64;
let ylen = args.y.data.len() as i64;
let mlen = args.mask.data.len() as i64;
if xlen < numel || ylen < numel || mlen < numel {
return Err(Error::BufferTooSmall {
needed: numel as usize,
got: xlen.min(ylen).min(mlen) as usize,
});
}
Ok(numel)
}
}
impl<const N: usize> DropoutPlan<f32, N> {
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: DropoutArgs<'_, f32, N>,
) -> Result<()> {
let numel = self.check_args(&args)?;
if numel == 0 {
return Ok(());
}
let needed = self.workspace_size();
let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
Workspace::None => {
return Err(Error::WorkspaceTooSmall {
needed,
got: 0,
})
}
Workspace::Borrowed(slice) => {
if slice.len() < needed {
return Err(Error::WorkspaceTooSmall {
needed,
got: slice.len(),
});
}
(slice.as_raw().0 as *mut c_void, slice.len())
}
};
let stream_ptr = stream.as_raw() as *mut c_void;
let x_ptr = args.x.data.as_raw().0 as *const c_void;
let y_ptr = args.y.data.as_raw().0 as *mut c_void;
let mask_ptr = args.mask.data.as_raw().0 as *mut c_void;
let rand_ptr = ws_ptr as *mut f32;
let gen_handle = self.ensure_generator()?;
let status = unsafe { curandSetStream(gen_handle, stream_ptr) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
let status = unsafe { curandGenerateUniform(gen_handle, rand_ptr, numel as usize) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
let p = self.desc.p;
let scale = if p < 1.0 { 1.0_f32 / (1.0 - p) } else { 0.0_f32 };
let status = unsafe {
baracuda_kernels_sys::baracuda_kernels_dropout_f32_run(
numel,
p,
scale,
x_ptr,
rand_ptr as *const c_void,
y_ptr,
mask_ptr,
core::ptr::null_mut(),
ws_bytes,
stream_ptr,
)
};
map_status(status)
}
}
impl<const N: usize> DropoutPlan<f64, N> {
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: DropoutArgs<'_, f64, N>,
) -> Result<()> {
let numel = self.check_args(&args)?;
if numel == 0 {
return Ok(());
}
let needed = self.workspace_size();
let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
Workspace::None => {
return Err(Error::WorkspaceTooSmall {
needed,
got: 0,
})
}
Workspace::Borrowed(slice) => {
if slice.len() < needed {
return Err(Error::WorkspaceTooSmall {
needed,
got: slice.len(),
});
}
(slice.as_raw().0 as *mut c_void, slice.len())
}
};
let stream_ptr = stream.as_raw() as *mut c_void;
let x_ptr = args.x.data.as_raw().0 as *const c_void;
let y_ptr = args.y.data.as_raw().0 as *mut c_void;
let mask_ptr = args.mask.data.as_raw().0 as *mut c_void;
let rand_ptr = ws_ptr as *mut f32;
let gen_handle = self.ensure_generator()?;
let status = unsafe { curandSetStream(gen_handle, stream_ptr) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
let status = unsafe { curandGenerateUniform(gen_handle, rand_ptr, numel as usize) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
let p = self.desc.p;
let scale = if p < 1.0 { 1.0_f64 / (1.0 - p as f64) } else { 0.0_f64 };
let status = unsafe {
baracuda_kernels_sys::baracuda_kernels_dropout_f64_run(
numel,
p,
scale,
x_ptr,
rand_ptr as *const c_void,
y_ptr,
mask_ptr,
core::ptr::null_mut(),
ws_bytes,
stream_ptr,
)
};
map_status(status)
}
}
impl<T: Element, const N: usize> Drop for DropoutPlan<T, N> {
    /// Destroys the cached cuRAND generator, if one was ever created.
    fn drop(&mut self) {
        // Take the handle out and leave null behind so it cannot be reused.
        let handle = self.generator.replace(core::ptr::null_mut());
        if handle.is_null() {
            return;
        }
        // The destroy status is deliberately ignored: `drop` cannot report errors.
        // SAFETY: a non-null handle is only ever stored by `ensure_generator`
        // after a successful `curandCreateGenerator`.
        unsafe {
            let _ = curandDestroyGenerator(handle);
        }
    }
}
/// Problem description for a backward dropout plan over a rank-`N` tensor.
///
/// The fields are validated by `DropoutBackwardPlan::select`.
#[derive(Copy, Clone, Debug)]
pub struct DropoutBackwardDescriptor<const N: usize> {
/// Tensor extents; `DropoutBackwardPlan::select` rejects any negative entry.
pub shape: [i32; N],
/// Element kind; must equal `T::KIND` of the plan being selected.
pub element: ElementKind,
/// Dropout parameter; must lie in `[0, 1]` (NaN is rejected by `select`).
/// `run` derives the kernel scale `1 / (1 - p)` from it (0 when `p == 1`).
pub p: f32,
}
/// Tensor arguments for `DropoutBackwardPlan::run`.
///
/// All three tensors must have exactly the shape stored in the plan's
/// descriptor, and each backing buffer must hold at least `numel` elements.
pub struct DropoutBackwardArgs<'a, T: Element, const N: usize> {
/// Read-only tensor passed to the kernel
/// (presumably the gradient with respect to the forward output — confirm).
pub dy: TensorRef<'a, T, N>,
/// Read-only `Bool` tensor passed to the kernel
/// (presumably the mask written by the forward pass — confirm).
pub mask: TensorRef<'a, Bool, N>,
/// Writable tensor passed to the kernel as its output.
pub dx: TensorMut<'a, T, N>,
}
/// Backward dropout plan for element type `T` and tensor rank `N`.
///
/// Built by `DropoutBackwardPlan::select`; needs no workspace.
pub struct DropoutBackwardPlan<T: Element, const N: usize> {
// Validated copy of the descriptor given to `select`.
desc: DropoutBackwardDescriptor<N>,
// Kernel identity and precision metadata reported by `sku()`.
sku: KernelSku,
// Ties the plan to its element type without storing a `T`.
_marker: PhantomData<T>,
}
impl<T: Element, const N: usize> DropoutBackwardPlan<T, N> {
    /// Validates `desc` and builds a backward dropout plan for element type `T`.
    ///
    /// `_stream` and `_pref` are accepted but not used.
    ///
    /// # Errors
    ///
    /// * `Error::Unsupported` if `desc.element != T::KIND`, if `T` is neither
    ///   `f32` nor `f64`, or if the rank `N` exceeds 8.
    /// * `Error::InvalidProblem` if any dimension is negative or `desc.p` is
    ///   outside `[0, 1]` (including NaN).
    pub fn select(
        _stream: &Stream,
        desc: &DropoutBackwardDescriptor<N>,
        _pref: PlanPreference,
    ) -> Result<Self> {
        let kind = T::KIND;
        if desc.element != kind {
            return Err(Error::Unsupported(
                "baracuda-kernels::DropoutBackwardPlan: descriptor.element != T::KIND",
            ));
        }
        match kind {
            ElementKind::F32 | ElementKind::F64 => {}
            _ => {
                return Err(Error::Unsupported(
                    "baracuda-kernels::DropoutBackwardPlan: wired today: f32 + f64",
                ));
            }
        }
        if desc.shape.iter().any(|&extent| extent < 0) {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DropoutBackwardPlan: shape dims must be non-negative",
            ));
        }
        if N > 8 {
            return Err(Error::Unsupported(
                "baracuda-kernels::DropoutBackwardPlan: tensor rank > 8 not supported",
            ));
        }
        // NaN is not contained in the range, so it is rejected as well.
        let p_in_range = (0.0..=1.0).contains(&desc.p);
        if !p_in_range {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DropoutBackwardPlan: p must be in [0, 1]",
            ));
        }

        let math_precision = if matches!(kind, ElementKind::F64) {
            MathPrecision::F64
        } else {
            MathPrecision::F32
        };
        let sku = KernelSku {
            category: OpCategory::Random,
            op: 101,
            element: kind,
            aux_element: Some(ElementKind::Bool),
            layout: None,
            epilogue: None,
            arch: ArchSku::Sm80,
            backend: BackendKind::Bespoke,
            precision_guarantee: PrecisionGuarantee {
                math_precision,
                accumulator: kind,
                bit_stable_on_same_hardware: true,
                deterministic: true,
            },
        };
        Ok(Self {
            desc: *desc,
            sku,
            _marker: PhantomData,
        })
    }

    /// Bytes of scratch space `run` needs; the backward pass needs none.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }

    /// Returns the kernel SKU selected for this plan.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }

    /// Returns the precision guarantee recorded in the plan's SKU.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee
    }

    /// Checks `args` against the plan and returns the element count.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidProblem` if the shape of `dy`, `mask` or `dx` differs
    ///   from the descriptor's shape.
    /// * `Error::BufferTooSmall` if any backing buffer holds fewer than
    ///   `numel` elements; `got` reports the smallest of the three lengths.
    fn check_args(&self, args: &DropoutBackwardArgs<'_, T, N>) -> Result<i64> {
        let shapes_agree = args.dy.shape == self.desc.shape
            && args.mask.shape == self.desc.shape
            && args.dx.shape == self.desc.shape;
        if !shapes_agree {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DropoutBackwardPlan: shape mismatch",
            ));
        }

        let numel = args.dy.numel();
        let dy_len = args.dy.data.len() as i64;
        let mask_len = args.mask.data.len() as i64;
        let dx_len = args.dx.data.len() as i64;
        // Some buffer is too short exactly when the shortest one is.
        let shortest = dy_len.min(mask_len).min(dx_len);
        if shortest < numel {
            return Err(Error::BufferTooSmall {
                needed: numel as usize,
                got: shortest as usize,
            });
        }
        Ok(numel)
    }
}
impl<const N: usize> DropoutBackwardPlan<f32, N> {
    /// Runs backward dropout for `f32` tensors on `stream`.
    ///
    /// Launches the `f32` backward kernel with `scale = 1 / (1 - p)` (or `0`
    /// when `p == 1`). `_workspace` is ignored. Returns immediately for empty
    /// tensors.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidProblem` / `Error::BufferTooSmall` if `args` does not
    ///   match the plan.
    /// * Whatever `map_status` reports for a non-zero kernel status.
    pub fn run(
        &self,
        stream: &Stream,
        _workspace: Workspace<'_>,
        args: DropoutBackwardArgs<'_, f32, N>,
    ) -> Result<()> {
        let count = self.check_args(&args)?;
        if count == 0 {
            return Ok(());
        }

        let drop_prob = self.desc.p;
        // `p == 1` would divide by zero, so that case uses a zero scale instead.
        let scale: f32 = if drop_prob < 1.0 {
            1.0_f32 / (1.0 - drop_prob)
        } else {
            0.0_f32
        };

        // SAFETY: `check_args` verified that every tensor buffer holds at
        // least `count` elements; no workspace is passed (null pointer, 0 bytes).
        let code = unsafe {
            baracuda_kernels_sys::baracuda_kernels_dropout_backward_f32_run(
                count,
                scale,
                args.dy.data.as_raw().0 as *const c_void,
                args.mask.data.as_raw().0 as *const c_void,
                args.dx.data.as_raw().0 as *mut c_void,
                core::ptr::null_mut(),
                0,
                stream.as_raw() as *mut c_void,
            )
        };
        map_status(code)
    }
}
impl<const N: usize> DropoutBackwardPlan<f64, N> {
    /// Runs backward dropout for `f64` tensors on `stream`.
    ///
    /// Launches the `f64` backward kernel with `scale = 1 / (1 - p)` computed
    /// in `f64` (or `0` when `p == 1`). `_workspace` is ignored. Returns
    /// immediately for empty tensors.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidProblem` / `Error::BufferTooSmall` if `args` does not
    ///   match the plan.
    /// * Whatever `map_status` reports for a non-zero kernel status.
    pub fn run(
        &self,
        stream: &Stream,
        _workspace: Workspace<'_>,
        args: DropoutBackwardArgs<'_, f64, N>,
    ) -> Result<()> {
        let count = self.check_args(&args)?;
        if count == 0 {
            return Ok(());
        }

        let drop_prob = self.desc.p;
        // `p == 1` would divide by zero, so that case uses a zero scale instead.
        let scale: f64 = if drop_prob < 1.0 {
            1.0_f64 / (1.0 - f64::from(drop_prob))
        } else {
            0.0_f64
        };

        // SAFETY: `check_args` verified that every tensor buffer holds at
        // least `count` elements; no workspace is passed (null pointer, 0 bytes).
        let code = unsafe {
            baracuda_kernels_sys::baracuda_kernels_dropout_backward_f64_run(
                count,
                scale,
                args.dy.data.as_raw().0 as *const c_void,
                args.mask.data.as_raw().0 as *const c_void,
                args.dx.data.as_raw().0 as *mut c_void,
                core::ptr::null_mut(),
                0,
                stream.as_raw() as *mut c_void,
            )
        };
        map_status(code)
    }
}
fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}