#![allow(non_camel_case_types)]
#![allow(clippy::too_many_arguments)]
use core::ffi::c_void;
use core::ptr;
use super::{
baracuda_kernels_affine_inplace_f32_run, baracuda_kernels_affine_inplace_f64_run,
curandCreateGenerator, curandDestroyGenerator, curandGenerateNormal,
curandGenerateNormalDouble, curandGenerateUniform, curandGenerateUniformDouble,
curandGenerator_t, curandSetPseudoRandomGeneratorSeed, curandSetStream,
CURAND_RNG_PSEUDO_DEFAULT,
};
// Status codes returned to C callers by the entry points in this file.
/// Returned when a call completes successfully.
const OK: i32 = 0;
/// Returned when an argument fails validation (null pointer, negative count,
/// unusable distribution parameters).
const INVALID: i32 = 2;
/// Returned when a cuRAND call reports a non-zero status.
const INTERNAL: i32 = 5;
/// Collapses a raw cuRAND status into this module's status codes.
///
/// A status of `0` becomes `OK`; every other value becomes `INTERNAL`.
#[inline]
fn map_curand(status: i32) -> i32 {
    match status {
        0 => OK,
        _ => INTERNAL,
    }
}
/// Owns a cuRAND generator handle so that it is destroyed when the wrapper is dropped.
///
/// The handle starts out null (`Generator::new`) and is filled in by `setup_generator`;
/// the `Drop` impl destroys it only if it is non-null.
struct Generator {
/// Raw cuRAND generator handle; null until a generator has been created.
g: curandGenerator_t,
}
impl Generator {
#[inline]
fn new() -> Self {
Self { g: ptr::null_mut() }
}
}
impl Drop for Generator {
    /// Destroys the wrapped cuRAND generator, if one was ever created.
    fn drop(&mut self) {
        // Nothing to release when no generator was created.
        if self.g.is_null() {
            return;
        }
        // SAFETY: the handle is non-null, so it was written by `curandCreateGenerator`
        // and has not been destroyed yet (this is the only place that destroys it).
        // The returned status is deliberately ignored: `drop` cannot report failure.
        let _ = unsafe { curandDestroyGenerator(self.g) };
    }
}
/// Creates a default pseudo-random cuRAND generator in `g`, seeds it and binds it to `stream`.
///
/// The three cuRAND calls run in order (create, seed, set stream) and the first non-zero
/// status aborts the sequence with `INTERNAL`; `OK` is returned when all succeed. A
/// generator created before a later step fails stays in `g` and is released by its `Drop`.
///
/// # Safety
/// `stream` is passed to `curandSetStream` unchanged; the caller must supply a value that
/// cuRAND accepts as a stream handle.
#[inline]
unsafe fn setup_generator(g: &mut Generator, seed: u64, stream: *mut c_void) -> i32 {
    // SAFETY: the out-pointer refers to the live `g.g` field for the duration of the call.
    if unsafe { curandCreateGenerator(&mut g.g as *mut _, CURAND_RNG_PSEUDO_DEFAULT) } != 0 {
        return INTERNAL;
    }
    // SAFETY: `g.g` was just initialised by a successful `curandCreateGenerator`.
    if unsafe { curandSetPseudoRandomGeneratorSeed(g.g, seed) } != 0 {
        return INTERNAL;
    }
    // SAFETY: `g.g` is a live generator; validity of `stream` is the caller's obligation.
    match unsafe { curandSetStream(g.g, stream) } {
        0 => OK,
        _ => INTERNAL,
    }
}
/// Reports the scratch space needed by `baracuda_kernels_curand_uniform_f32_run`.
///
/// Always writes `0` to `*out_bytes`; `_numel` is ignored.
///
/// # Returns
/// `OK` after writing the size, or `INVALID` if `out_bytes` is null.
///
/// # Safety
/// A non-null `out_bytes` must be valid for a write of one `usize`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_curand_uniform_f32_workspace_size(
    _numel: i64,
    out_bytes: *mut usize,
) -> i32 {
    if !out_bytes.is_null() {
        // SAFETY: non-null was just checked; validity for writes is the caller's contract.
        unsafe { out_bytes.write(0) };
        return OK;
    }
    INVALID
}
/// Reports the scratch space needed by `baracuda_kernels_curand_uniform_f64_run`.
///
/// Always writes `0` to `*out_bytes`; `_numel` is ignored.
///
/// # Returns
/// `OK` after writing the size, or `INVALID` if `out_bytes` is null.
///
/// # Safety
/// A non-null `out_bytes` must be valid for a write of one `usize`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_curand_uniform_f64_workspace_size(
    _numel: i64,
    out_bytes: *mut usize,
) -> i32 {
    if !out_bytes.is_null() {
        // SAFETY: non-null was just checked; validity for writes is the caller's contract.
        unsafe { out_bytes.write(0) };
        return OK;
    }
    INVALID
}
/// Fills `y` with `numel` uniformly distributed `f32` samples mapped onto `low`..`high`.
///
/// A temporary cuRAND pseudo-random generator is created, seeded with `seed`, bound to
/// `stream`, and destroyed again when this function returns (through `Generator`'s `Drop`).
/// cuRAND writes the raw samples; unless the requested range is exactly `low == 0.0` and
/// `high == 1.0`, the buffer is then rescaled in place by
/// `baracuda_kernels_affine_inplace_f32_run` with scale `high - low` and offset `low`
/// (presumably `y = scale * y + low` — confirm against that kernel).
///
/// `_workspace` and `_workspace_bytes` are ignored.
///
/// # Returns
/// - `OK` on success, including the `numel == 0` no-op (the bounds are not checked then).
/// - `INVALID` if `numel` is negative, `y` is null, or — for `numel > 0` — the bounds are
///   unusable: `high` is not strictly greater than `low` (this also rejects NaN), or the
///   span `high - low` is not finite (infinite bounds, or overflow of the subtraction).
/// - `INTERNAL` if a cuRAND call fails. A non-`OK` status from the affine kernel is
///   returned unchanged.
///
/// # Safety
/// `y` is handed to cuRAND as the output buffer for `numel` `f32` values and `stream` is
/// passed through as the stream handle; the caller must guarantee both are valid for that.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_curand_uniform_f32_run(
    numel: i64,
    low: f32,
    high: f32,
    seed: u64,
    y: *mut c_void,
    _workspace: *mut c_void,
    _workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    if numel < 0 || y.is_null() {
        return INVALID;
    }
    if numel == 0 {
        return OK;
    }
    // `!(high > low)` is written this way on purpose so that NaN bounds are rejected too.
    // The span must additionally be finite: an infinite bound or an overflowing
    // subtraction would otherwise rescale every sample with a non-finite factor while
    // the call still reported success.
    let scale = high - low;
    if !(high > low) || !scale.is_finite() {
        return INVALID;
    }

    let mut generator = Generator::new();
    // SAFETY: `stream` validity is part of this function's own safety contract.
    let status = unsafe { setup_generator(&mut generator, seed, stream) };
    if status != OK {
        return status;
    }

    // `numel` is known to be positive here, so the cast to `usize` cannot wrap a negative.
    // SAFETY: the generator was set up successfully; `y` is non-null and the caller
    // guarantees it can hold `numel` `f32` values.
    if unsafe { curandGenerateUniform(generator.g, y as *mut f32, numel as usize) } != 0 {
        return INTERNAL;
    }

    // The raw samples already cover the unit range; skip the extra kernel launch.
    if low == 0.0 && high == 1.0 {
        return OK;
    }

    // SAFETY: same buffer and stream as above; the kernel is given no workspace.
    unsafe {
        baracuda_kernels_affine_inplace_f32_run(numel, scale, low, y, ptr::null_mut(), 0, stream)
    }
}
/// Fills `y` with `numel` uniformly distributed `f64` samples mapped onto `low`..`high`.
///
/// A temporary cuRAND pseudo-random generator is created, seeded with `seed`, bound to
/// `stream`, and destroyed again when this function returns (through `Generator`'s `Drop`).
/// cuRAND writes the raw samples; unless the requested range is exactly `low == 0.0` and
/// `high == 1.0`, the buffer is then rescaled in place by
/// `baracuda_kernels_affine_inplace_f64_run` with scale `high - low` and offset `low`
/// (presumably `y = scale * y + low` — confirm against that kernel).
///
/// `_workspace` and `_workspace_bytes` are ignored.
///
/// # Returns
/// - `OK` on success, including the `numel == 0` no-op (the bounds are not checked then).
/// - `INVALID` if `numel` is negative, `y` is null, or — for `numel > 0` — the bounds are
///   unusable: `high` is not strictly greater than `low` (this also rejects NaN), or the
///   span `high - low` is not finite (infinite bounds, or overflow of the subtraction).
/// - `INTERNAL` if a cuRAND call fails. A non-`OK` status from the affine kernel is
///   returned unchanged.
///
/// # Safety
/// `y` is handed to cuRAND as the output buffer for `numel` `f64` values and `stream` is
/// passed through as the stream handle; the caller must guarantee both are valid for that.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_curand_uniform_f64_run(
    numel: i64,
    low: f64,
    high: f64,
    seed: u64,
    y: *mut c_void,
    _workspace: *mut c_void,
    _workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    if numel < 0 || y.is_null() {
        return INVALID;
    }
    if numel == 0 {
        return OK;
    }
    // `!(high > low)` is written this way on purpose so that NaN bounds are rejected too.
    // The span must additionally be finite: an infinite bound or an overflowing
    // subtraction would otherwise rescale every sample with a non-finite factor while
    // the call still reported success.
    let scale = high - low;
    if !(high > low) || !scale.is_finite() {
        return INVALID;
    }

    let mut generator = Generator::new();
    // SAFETY: `stream` validity is part of this function's own safety contract.
    let status = unsafe { setup_generator(&mut generator, seed, stream) };
    if status != OK {
        return status;
    }

    // `numel` is known to be positive here, so the cast to `usize` cannot wrap a negative.
    // SAFETY: the generator was set up successfully; `y` is non-null and the caller
    // guarantees it can hold `numel` `f64` values.
    if unsafe { curandGenerateUniformDouble(generator.g, y as *mut f64, numel as usize) } != 0 {
        return INTERNAL;
    }

    // The raw samples already cover the unit range; skip the extra kernel launch.
    if low == 0.0 && high == 1.0 {
        return OK;
    }

    // SAFETY: same buffer and stream as above; the kernel is given no workspace.
    unsafe {
        baracuda_kernels_affine_inplace_f64_run(numel, scale, low, y, ptr::null_mut(), 0, stream)
    }
}
/// Reports the scratch space needed by `baracuda_kernels_curand_normal_f32_run`.
///
/// Always writes `0` to `*out_bytes`; `_numel` is ignored.
///
/// # Returns
/// `OK` after writing the size, or `INVALID` if `out_bytes` is null.
///
/// # Safety
/// A non-null `out_bytes` must be valid for a write of one `usize`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_curand_normal_f32_workspace_size(
    _numel: i64,
    out_bytes: *mut usize,
) -> i32 {
    if !out_bytes.is_null() {
        // SAFETY: non-null was just checked; validity for writes is the caller's contract.
        unsafe { out_bytes.write(0) };
        return OK;
    }
    INVALID
}
/// Reports the scratch space needed by `baracuda_kernels_curand_normal_f64_run`.
///
/// Always writes `0` to `*out_bytes`; `_numel` is ignored.
///
/// # Returns
/// `OK` after writing the size, or `INVALID` if `out_bytes` is null.
///
/// # Safety
/// A non-null `out_bytes` must be valid for a write of one `usize`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_curand_normal_f64_workspace_size(
    _numel: i64,
    out_bytes: *mut usize,
) -> i32 {
    if !out_bytes.is_null() {
        // SAFETY: non-null was just checked; validity for writes is the caller's contract.
        unsafe { out_bytes.write(0) };
        return OK;
    }
    INVALID
}
/// Fills `y` with `numel` normally distributed `f32` samples with the given `mean` and
/// `stddev`.
///
/// A temporary cuRAND pseudo-random generator is created, seeded with `seed`, bound to
/// `stream`, and destroyed again when this function returns (through `Generator`'s `Drop`).
/// `mean` and `stddev` are forwarded to `curandGenerateNormal`, so no post-processing
/// kernel is launched.
///
/// `_workspace` and `_workspace_bytes` are ignored.
///
/// # Returns
/// - `OK` on success, including the `numel == 0` no-op.
/// - `INVALID` if `numel` is negative, `y` is null, `stddev` is not a finite value
///   strictly greater than zero (NaN is rejected), or `mean` is not finite. The
///   distribution parameters are checked even when `numel == 0`.
/// - `INTERNAL` if a cuRAND call fails.
///
/// NOTE(review): the cuRAND documentation states that pseudorandom normal generation
/// requires an even sample count, so an odd `numel` is expected to come back as
/// `INTERNAL` — confirm, and consider padding through a workspace if odd sizes matter.
///
/// # Safety
/// `y` is handed to cuRAND as the output buffer for `numel` `f32` values and `stream` is
/// passed through as the stream handle; the caller must guarantee both are valid for that.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_curand_normal_f32_run(
    numel: i64,
    mean: f32,
    stddev: f32,
    seed: u64,
    y: *mut c_void,
    _workspace: *mut c_void,
    _workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    if numel < 0 || y.is_null() {
        return INVALID;
    }
    // `!(stddev > 0.0)` also rejects NaN. Infinite `stddev` and non-finite `mean` are
    // rejected as well: they would fill the buffer with inf/NaN while reporting success.
    if !(stddev > 0.0) || !stddev.is_finite() || !mean.is_finite() {
        return INVALID;
    }
    if numel == 0 {
        return OK;
    }

    let mut generator = Generator::new();
    // SAFETY: `stream` validity is part of this function's own safety contract.
    let status = unsafe { setup_generator(&mut generator, seed, stream) };
    if status != OK {
        return status;
    }

    // `numel` is known to be positive here, so the cast to `usize` cannot wrap a negative.
    // SAFETY: the generator was set up successfully; `y` is non-null and the caller
    // guarantees it can hold `numel` `f32` values.
    let raw = unsafe {
        curandGenerateNormal(generator.g, y as *mut f32, numel as usize, mean, stddev)
    };
    map_curand(raw)
}
/// Fills `y` with `numel` normally distributed `f64` samples with the given `mean` and
/// `stddev`.
///
/// A temporary cuRAND pseudo-random generator is created, seeded with `seed`, bound to
/// `stream`, and destroyed again when this function returns (through `Generator`'s `Drop`).
/// `mean` and `stddev` are forwarded to `curandGenerateNormalDouble`, so no
/// post-processing kernel is launched.
///
/// `_workspace` and `_workspace_bytes` are ignored.
///
/// # Returns
/// - `OK` on success, including the `numel == 0` no-op.
/// - `INVALID` if `numel` is negative, `y` is null, `stddev` is not a finite value
///   strictly greater than zero (NaN is rejected), or `mean` is not finite. The
///   distribution parameters are checked even when `numel == 0`.
/// - `INTERNAL` if a cuRAND call fails.
///
/// NOTE(review): the cuRAND documentation states that pseudorandom normal generation
/// requires an even sample count, so an odd `numel` is expected to come back as
/// `INTERNAL` — confirm, and consider padding through a workspace if odd sizes matter.
///
/// # Safety
/// `y` is handed to cuRAND as the output buffer for `numel` `f64` values and `stream` is
/// passed through as the stream handle; the caller must guarantee both are valid for that.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_curand_normal_f64_run(
    numel: i64,
    mean: f64,
    stddev: f64,
    seed: u64,
    y: *mut c_void,
    _workspace: *mut c_void,
    _workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    if numel < 0 || y.is_null() {
        return INVALID;
    }
    // `!(stddev > 0.0)` also rejects NaN. Infinite `stddev` and non-finite `mean` are
    // rejected as well: they would fill the buffer with inf/NaN while reporting success.
    if !(stddev > 0.0) || !stddev.is_finite() || !mean.is_finite() {
        return INVALID;
    }
    if numel == 0 {
        return OK;
    }

    let mut generator = Generator::new();
    // SAFETY: `stream` validity is part of this function's own safety contract.
    let status = unsafe { setup_generator(&mut generator, seed, stream) };
    if status != OK {
        return status;
    }

    // `numel` is known to be positive here, so the cast to `usize` cannot wrap a negative.
    // SAFETY: the generator was set up successfully; `y` is non-null and the caller
    // guarantees it can hold `numel` `f64` values.
    let raw = unsafe {
        curandGenerateNormalDouble(generator.g, y as *mut f64, numel as usize, mean, stddev)
    };
    map_curand(raw)
}