use core::cell::Cell;
use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_sys::{
curandCreateGenerator, curandDestroyGenerator, curandGenerateNormal,
curandGenerateNormalDouble, curandGenerateUniform, curandGenerateUniformDouble,
curandGenerator_t, curandSetPseudoRandomGeneratorSeed, curandSetStream,
};
use baracuda_kernels_types::{
ArchSku, BackendKind, Bool, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, RandomKind, TensorMut, Workspace,
};
/// Problem description consumed by [`RandomPlan::select`].
///
/// `param1` / `param2` are interpreted per `kind` by the code in this file:
/// - `Uniform`: `param1` is the low bound and `param2` the high bound.
/// - `Normal`: `param1` is the mean and `param2` the standard deviation
///   (`select` requires it to be strictly positive).
/// - `Bernoulli`: `param1` is the parameter `p` (`select` requires it to lie in
///   `[0, 1]`); `param2` is not read.
#[derive(Copy, Clone, Debug)]
pub struct RandomDescriptor<const N: usize> {
/// Distribution to sample from.
pub kind: RandomKind,
/// Output extents; `select` rejects negative entries and `run` requires the
/// output tensor's shape to match exactly.
pub shape: [i32; N],
/// Element kind of the output; must equal `T::KIND` of the plan built from it.
pub element: ElementKind,
/// First distribution parameter (low / mean / `p`, see the struct docs).
pub param1: f32,
/// Second distribution parameter (high / stddev, see the struct docs).
pub param2: f32,
/// Seed installed on the cuRAND generator when it is first created.
pub seed: u64,
}
/// Arguments for the floating-point `RandomPlan::run` methods.
pub struct RandomArgs<'a, T: Element, const N: usize> {
/// Output tensor that `run` fills with samples; its shape must equal the
/// descriptor's shape.
pub y: TensorMut<'a, T, N>,
}
/// Arguments for the Bernoulli (`Bool` output) `RandomPlan::run` method.
pub struct RandomBoolArgs<'a, const N: usize> {
/// Output tensor that `run` fills with Bernoulli draws; its shape must equal
/// the descriptor's shape.
pub y: TensorMut<'a, Bool, N>,
}
/// A validated random-fill plan for element type `T` and rank `N`.
///
/// Built by `select`; executed by the type-specific `run` methods.
pub struct RandomPlan<T: Element, const N: usize> {
// Copy of the descriptor that `select` validated.
desc: RandomDescriptor<N>,
// Kernel identity and precision contract reported by `sku()`.
sku: KernelSku,
// cuRAND generator handle: null until the first `run` creates it, destroyed in
// `Drop`. Held in a `Cell` so the `&self` run methods can install it.
generator: Cell<curandGenerator_t>,
// Ties the plan to its element type without storing a `T`.
_marker: PhantomData<T>,
}
impl<T: Element, const N: usize> RandomPlan<T, N> {
    /// Validates `desc` for element type `T` and builds a plan from it.
    ///
    /// `_stream` and `_pref` are currently unused. No cuRAND generator is created
    /// here; it is created lazily by the first `run` that needs one.
    ///
    /// # Errors
    ///
    /// * [`Error::Unsupported`] if `desc.element != T::KIND`, if `N > 8`, or if the
    ///   `(kind, element)` pair is not one of `{Uniform, Normal} × {f32, f64}` or
    ///   `Bernoulli × Bool`.
    /// * [`Error::InvalidProblem`] if any dimension is negative, if a Bernoulli
    ///   `param1` is outside `[0, 1]` (NaN included), or if a Normal `param2` is not
    ///   strictly positive (NaN included).
    pub fn select(
        _stream: &Stream,
        desc: &RandomDescriptor<N>,
        _pref: PlanPreference,
    ) -> Result<Self> {
        if desc.element != T::KIND {
            return Err(Error::Unsupported(
                "baracuda-kernels::RandomPlan: descriptor.element != T::KIND",
            ));
        }
        if desc.shape.iter().any(|&dim| dim < 0) {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::RandomPlan: shape dims must be non-negative",
            ));
        }
        if N > 8 {
            return Err(Error::Unsupported(
                "baracuda-kernels::RandomPlan: tensor rank > 8 not supported",
            ));
        }

        let supported = matches!(
            (desc.kind, T::KIND),
            (RandomKind::Uniform, ElementKind::F32)
                | (RandomKind::Uniform, ElementKind::F64)
                | (RandomKind::Normal, ElementKind::F32)
                | (RandomKind::Normal, ElementKind::F64)
                | (RandomKind::Bernoulli, ElementKind::Bool)
        );
        if !supported {
            return Err(Error::Unsupported(
                "baracuda-kernels::RandomPlan: wired today: \
                 `{Uniform, Normal} × {f32, f64}` and `Bernoulli × Bool`",
            ));
        }

        // `contains` is false for NaN, so a NaN probability is rejected as well.
        if matches!(desc.kind, RandomKind::Bernoulli) && !(0.0..=1.0).contains(&desc.param1) {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::RandomPlan(Bernoulli): p must be in [0, 1]",
            ));
        }
        // Written as `!(x > 0)` rather than `x <= 0` so that NaN is rejected too.
        if matches!(desc.kind, RandomKind::Normal) && !(desc.param2 > 0.0) {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::RandomPlan(Normal): stddev (param2) must be > 0",
            ));
        }

        let backend = match desc.kind {
            RandomKind::Uniform | RandomKind::Normal => BackendKind::Curand,
            RandomKind::Bernoulli => BackendKind::Bespoke,
            _ => BackendKind::Bespoke,
        };
        let math_precision = match T::KIND {
            ElementKind::F64 => MathPrecision::F64,
            _ => MathPrecision::F32,
        };
        let precision_guarantee = PrecisionGuarantee {
            math_precision,
            accumulator: T::KIND,
            bit_stable_on_same_hardware: true,
            deterministic: true,
        };
        let sku = KernelSku {
            category: OpCategory::Random,
            op: desc.kind as u16,
            element: T::KIND,
            aux_element: None,
            layout: None,
            epilogue: None,
            arch: ArchSku::Sm80,
            backend,
            precision_guarantee,
        };

        Ok(Self {
            desc: *desc,
            sku,
            // Null means "not created yet"; see `ensure_generator`.
            generator: Cell::new(core::ptr::null_mut()),
            _marker: PhantomData,
        })
    }

    /// Scratch bytes the caller must supply to `run`.
    ///
    /// Bernoulli plans need one `f32` per output element (the buffer that holds
    /// the intermediate uniform draws); every other kind needs none.
    ///
    /// The byte count is computed with saturating arithmetic: a shape whose
    /// element count does not fit in `usize` reports `usize::MAX` instead of
    /// overflowing, so the size check in `run` fails cleanly.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        if !matches!(self.desc.kind, RandomKind::Bernoulli) {
            return 0;
        }
        let numel = self.desc.shape.iter().fold(1usize, |acc, &dim| {
            // `select` already rejects negative dims; clamp defensively to zero.
            acc.saturating_mul(usize::try_from(dim).unwrap_or(0))
        });
        numel.saturating_mul(core::mem::size_of::<f32>())
    }

    /// Kernel identity recorded at `select` time.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }

    /// Precision contract recorded at `select` time.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee
    }

    /// Returns the plan's cuRAND generator, creating and seeding it on first use.
    ///
    /// If seeding fails the freshly created generator is destroyed again so that
    /// nothing leaks and the next call retries from scratch.
    ///
    /// # Errors
    ///
    /// [`Error::CutlassInternal`] carrying the negated cuRAND status when creation
    /// or seeding fails.
    fn ensure_generator(&self) -> Result<curandGenerator_t> {
        let existing = self.generator.get();
        if !existing.is_null() {
            return Ok(existing);
        }

        let mut handle: curandGenerator_t = core::ptr::null_mut();
        // 100 is `CURAND_RNG_PSEUDO_DEFAULT` in cuRAND's `curandRngType` enum.
        // SAFETY: `handle` is a live local, so the out-pointer is valid for the call.
        let status = unsafe { curandCreateGenerator(&mut handle as *mut _, 100) };
        if status != 0 {
            return Err(Error::CutlassInternal(curand_to_status(status)));
        }

        // SAFETY: `handle` was just produced by a successful `curandCreateGenerator`.
        let status = unsafe { curandSetPseudoRandomGeneratorSeed(handle, self.desc.seed) };
        if status != 0 {
            // SAFETY: `handle` is valid and has not been stored anywhere else.
            unsafe {
                let _ = curandDestroyGenerator(handle);
            }
            return Err(Error::CutlassInternal(curand_to_status(status)));
        }

        self.generator.set(handle);
        Ok(handle)
    }

    /// Points `gen_handle` at `stream` so subsequent generation is issued there.
    ///
    /// # Errors
    ///
    /// [`Error::CutlassInternal`] carrying the negated cuRAND status on failure.
    fn bind_stream(&self, gen_handle: curandGenerator_t, stream: &Stream) -> Result<()> {
        let stream_ptr = stream.as_raw() as *mut c_void;
        // SAFETY: callers pass the handle returned by `ensure_generator`; the stream
        // pointer comes from a live `&Stream`.
        let status = unsafe { curandSetStream(gen_handle, stream_ptr) };
        if status != 0 {
            return Err(Error::CutlassInternal(curand_to_status(status)));
        }
        Ok(())
    }

    /// Checks `y` against the descriptor and returns its element count.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidProblem`] if `y.shape` differs from the descriptor's shape.
    /// * [`Error::BufferTooSmall`] if `y.data` holds fewer than `y.numel()` elements.
    fn check_shape<U: baracuda_types::DeviceRepr + Copy + 'static>(
        &self,
        y: &TensorMut<'_, U, N>,
    ) -> Result<i64> {
        if y.shape != self.desc.shape {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::RandomPlan: y shape != descriptor shape",
            ));
        }
        let numel = y.numel();
        let len = y.data.len() as i64;
        if len < numel {
            return Err(Error::BufferTooSmall {
                needed: numel as usize,
                got: len as usize,
            });
        }
        Ok(numel)
    }
}
impl<const N: usize> RandomPlan<f32, N> {
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: RandomArgs<'_, f32, N>,
) -> Result<()> {
let numel = self.check_shape(&args.y)?;
if numel == 0 {
return Ok(());
}
let gen_handle = self.ensure_generator()?;
self.bind_stream(gen_handle, stream)?;
let ptr = args.y.data.as_raw().0 as *mut f32;
let n = numel as usize;
match self.desc.kind {
RandomKind::Uniform => {
let status = unsafe { curandGenerateUniform(gen_handle, ptr, n) };
if status != 0 {
return Err(Error::CutlassInternal(curand_to_status(status)));
}
let low = self.desc.param1;
let high = self.desc.param2;
if (low, high) != (0.0, 1.0) {
affine_transform_f32(stream, ptr, n, high - low, low)?;
}
Ok(())
}
RandomKind::Normal => {
let mean = self.desc.param1;
let stddev = self.desc.param2;
let status = unsafe { curandGenerateNormal(gen_handle, ptr, n, mean, stddev) };
if status != 0 {
return Err(Error::CutlassInternal(curand_to_status(status)));
}
Ok(())
}
RandomKind::Bernoulli => Err(Error::Unsupported(
"baracuda-kernels::RandomPlan<f32>: Bernoulli has Bool output — use RandomPlan<Bool>",
)),
_ => Err(Error::Unsupported(
"baracuda-kernels::RandomPlan<f32>::run reached an unimplemented RandomKind variant",
)),
}
}
}
impl<const N: usize> RandomPlan<f64, N> {
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: RandomArgs<'_, f64, N>,
) -> Result<()> {
let numel = self.check_shape(&args.y)?;
if numel == 0 {
return Ok(());
}
let gen_handle = self.ensure_generator()?;
self.bind_stream(gen_handle, stream)?;
let ptr = args.y.data.as_raw().0 as *mut f64;
let n = numel as usize;
match self.desc.kind {
RandomKind::Uniform => {
let status = unsafe { curandGenerateUniformDouble(gen_handle, ptr, n) };
if status != 0 {
return Err(Error::CutlassInternal(curand_to_status(status)));
}
let low = self.desc.param1 as f64;
let high = self.desc.param2 as f64;
if (low, high) != (0.0, 1.0) {
affine_transform_f64(stream, ptr, n, high - low, low)?;
}
Ok(())
}
RandomKind::Normal => {
let mean = self.desc.param1 as f64;
let stddev = self.desc.param2 as f64;
let status = unsafe { curandGenerateNormalDouble(gen_handle, ptr, n, mean, stddev) };
if status != 0 {
return Err(Error::CutlassInternal(curand_to_status(status)));
}
Ok(())
}
RandomKind::Bernoulli => Err(Error::Unsupported(
"baracuda-kernels::RandomPlan<f64>: Bernoulli has Bool output — use RandomPlan<Bool>",
)),
_ => Err(Error::Unsupported(
"baracuda-kernels::RandomPlan<f64>::run reached an unimplemented RandomKind variant",
)),
}
}
}
impl<const N: usize> RandomPlan<Bool, N> {
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: RandomBoolArgs<'_, N>,
) -> Result<()> {
if !matches!(self.desc.kind, RandomKind::Bernoulli) {
return Err(Error::Unsupported(
"baracuda-kernels::RandomPlan<Bool>: only Bernoulli is wired \
(Uniform / Normal use the FP variants)",
));
}
let numel = self.check_shape(&args.y)?;
if numel == 0 {
return Ok(());
}
let needed = self.workspace_size();
let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
Workspace::None => {
return Err(Error::WorkspaceTooSmall {
needed,
got: 0,
})
}
Workspace::Borrowed(slice) => {
if slice.len() < needed {
return Err(Error::WorkspaceTooSmall {
needed,
got: slice.len(),
});
}
(slice.as_raw().0 as *mut c_void, slice.len())
}
};
let gen_handle = self.ensure_generator()?;
self.bind_stream(gen_handle, stream)?;
let rand_ptr = ws_ptr as *mut f32;
let n = numel as usize;
let status = unsafe { curandGenerateUniform(gen_handle, rand_ptr, n) };
if status != 0 {
return Err(Error::CutlassInternal(curand_to_status(status)));
}
let y_ptr = args.y.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let status = unsafe {
baracuda_kernels_sys::baracuda_kernels_bernoulli_run(
numel,
self.desc.param1,
rand_ptr as *const c_void,
y_ptr,
core::ptr::null_mut(),
ws_bytes, stream_ptr,
)
};
map_status(status)
}
}
impl<T: Element, const N: usize> Drop for RandomPlan<T, N> {
    /// Destroys the lazily created cuRAND generator, if one exists.
    ///
    /// The destroy status is deliberately ignored: there is no way to report a
    /// failure from `drop`.
    fn drop(&mut self) {
        // Swap the handle out first so the cell never holds a destroyed generator.
        let handle = self.generator.replace(core::ptr::null_mut());
        if handle.is_null() {
            return;
        }
        // SAFETY: a non-null handle is only ever stored by `ensure_generator` after
        // successful creation, and it has just been removed from the cell.
        unsafe {
            let _ = curandDestroyGenerator(handle);
        }
    }
}
/// Maps a cuRAND status code into this crate's internal status space.
///
/// Success (`0`) stays `0`; every other code is negated.
fn curand_to_status(curand_code: i32) -> i32 {
    match curand_code {
        0 => 0,
        failure => -failure,
    }
}
fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}
/// Launches the in-place `f32` affine kernel over `n` elements at `ptr` on `stream`.
///
/// `scale` and `offset` are forwarded unchanged (presumably the kernel computes
/// `x * scale + offset` — confirm against the launcher). No workspace is passed.
/// The launcher's return code is translated by `map_status`.
fn affine_transform_f32(
    stream: &Stream,
    ptr: *mut f32,
    n: usize,
    scale: f32,
    offset: f32,
) -> Result<()> {
    let count = n as i64;
    let data = ptr.cast::<c_void>();
    let raw_stream = stream.as_raw() as *mut c_void;
    // SAFETY: callers pass a device buffer they have validated to hold at least
    // `n` elements; the launcher receives a null workspace of size 0.
    let code = unsafe {
        baracuda_kernels_sys::baracuda_kernels_affine_inplace_f32_run(
            count,
            scale,
            offset,
            data,
            core::ptr::null_mut(),
            0,
            raw_stream,
        )
    };
    map_status(code)
}
/// Launches the in-place `f64` affine kernel over `n` elements at `ptr` on `stream`.
///
/// `scale` and `offset` are forwarded unchanged (presumably the kernel computes
/// `x * scale + offset` — confirm against the launcher). No workspace is passed.
/// The launcher's return code is translated by `map_status`.
fn affine_transform_f64(
    stream: &Stream,
    ptr: *mut f64,
    n: usize,
    scale: f64,
    offset: f64,
) -> Result<()> {
    let count = n as i64;
    let data = ptr.cast::<c_void>();
    let raw_stream = stream.as_raw() as *mut c_void;
    // SAFETY: callers pass a device buffer they have validated to hold at least
    // `n` elements; the launcher receives a null workspace of size 0.
    let code = unsafe {
        baracuda_kernels_sys::baracuda_kernels_affine_inplace_f64_run(
            count,
            scale,
            offset,
            data,
            core::ptr::null_mut(),
            0,
            raw_stream,
        )
    };
    map_status(code)
}