#![cfg(feature = "curand-host")]
use cudarc::curand::result::CurandError;
use cudarc::curand::sys;
use crate::sys::curand as csys;
use crate::sys::curand::RngGeneratorKind;
/// A cuRAND generator created through the host API (`create_generator_host`), whose
/// `fill_*` methods write generated values into ordinary Rust slices.
///
/// The raw generator handle is owned by this value and released in its `Drop` impl.
pub struct HostRng {
// Raw cuRAND generator handle obtained from `csys::create_generator_host`.
gen: sys::curandGenerator_t,
// Generator algorithm; consulted to decide whether seeding applies
// (quasi-random kinds are never seeded).
kind: RngGeneratorKind,
}
// SAFETY: `HostRng` exclusively owns its generator handle, and every method that uses the
// handle takes `&mut self` (or runs in `drop`), so moving a `HostRng` to another thread
// cannot produce concurrent access to it. `Sync` is not implemented.
// NOTE(review): this also assumes a cuRAND host generator is not bound to the thread that
// created it — confirm against the cuRAND documentation.
unsafe impl Send for HostRng {}
impl HostRng {
    /// Creates a host-API cuRAND generator of the given `kind`.
    ///
    /// Pseudorandom kinds are seeded with `seed`. For quasi-random kinds (`kind.is_quasi()`),
    /// `seed` is ignored.
    ///
    /// # Errors
    ///
    /// Returns the cuRAND error if creating or seeding the generator fails. If seeding fails,
    /// the already-created generator is destroyed before the error is returned.
    pub fn new(kind: RngGeneratorKind, seed: u64) -> Result<Self, CurandError> {
        // SAFETY: no caller memory is passed in; on success the returned handle is
        // immediately taken over by `rng`, whose `Drop` releases it.
        let gen = unsafe { csys::create_generator_host(kind)? };
        // Wrap the handle *before* any further fallible call. If seeding fails, `rng` is
        // dropped and the generator is destroyed instead of leaking.
        let mut rng = Self { gen, kind };
        // `set_seed` is a no-op for quasi-random kinds.
        rng.set_seed(seed)?;
        Ok(rng)
    }

    /// Returns the generator algorithm this instance was created with.
    pub fn kind(&self) -> RngGeneratorKind {
        self.kind
    }

    /// Re-seeds the generator with `seed`.
    ///
    /// Quasi-random generators have no seed. For them this does nothing and returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns the cuRAND error reported while setting the seed.
    pub fn set_seed(&mut self, seed: u64) -> Result<(), CurandError> {
        if self.kind.is_quasi() {
            return Ok(());
        }
        // SAFETY: `self.gen` is a live generator handle owned by `self`.
        unsafe { csys::set_seed(self.gen, seed) }
    }

    /// Fills `out` with raw 32-bit values from the generator.
    ///
    /// # Errors
    ///
    /// Returns the cuRAND error reported by the generation call.
    pub fn fill_u32(&mut self, out: &mut [u32]) -> Result<(), CurandError> {
        let n = out.len();
        // SAFETY: `out` is an exclusively borrowed host slice of exactly `n` elements, and
        // `self.gen` is a live host-API generator, so results are written to host memory.
        unsafe { csys::generate_u32(self.gen, out.as_mut_ptr(), n) }
    }

    /// Fills `out` with raw 64-bit values from the generator.
    ///
    /// # Errors
    ///
    /// Returns the cuRAND error reported by the generation call.
    /// NOTE: this presumably wraps `curandGenerateLongLong`. Per the cuRAND documentation,
    /// that call is only valid for 64-bit quasi-random generators and reports a type error
    /// for other kinds. Verify against `csys::generate_u64`.
    pub fn fill_u64(&mut self, out: &mut [u64]) -> Result<(), CurandError> {
        let n = out.len();
        // SAFETY: `out` is an exclusively borrowed host slice of exactly `n` elements, and
        // `self.gen` is a live host-API generator, so results are written to host memory.
        unsafe { csys::generate_u64(self.gen, out.as_mut_ptr(), n) }
    }

    /// Fills `out` with uniformly distributed `f32` values via `curandGenerateUniform`.
    ///
    /// # Errors
    ///
    /// Returns the cuRAND error reported by the generation call.
    pub fn fill_uniform_f32(&mut self, out: &mut [f32]) -> Result<(), CurandError> {
        let n = out.len();
        // SAFETY: `out` is an exclusively borrowed host slice of exactly `n` elements, and
        // `self.gen` is a live host-API generator, so results are written to host memory.
        unsafe { sys::curandGenerateUniform(self.gen, out.as_mut_ptr(), n).result() }
    }

    /// Fills `out` with uniformly distributed `f64` values via `curandGenerateUniformDouble`.
    ///
    /// # Errors
    ///
    /// Returns the cuRAND error reported by the generation call.
    pub fn fill_uniform_f64(&mut self, out: &mut [f64]) -> Result<(), CurandError> {
        let n = out.len();
        // SAFETY: `out` is an exclusively borrowed host slice of exactly `n` elements, and
        // `self.gen` is a live host-API generator, so results are written to host memory.
        unsafe { sys::curandGenerateUniformDouble(self.gen, out.as_mut_ptr(), n).result() }
    }

    /// Fills `out` with normally distributed `f32` values with the given `mean` and
    /// standard deviation `std`, via `curandGenerateNormal`.
    ///
    /// # Errors
    ///
    /// Returns the cuRAND error reported by the generation call.
    /// NOTE: per the cuRAND documentation, pseudorandom generators require an even
    /// `out.len()` here, because they use a Box-Muller transform. An odd length is
    /// reported as an error.
    pub fn fill_normal_f32(
        &mut self,
        out: &mut [f32],
        mean: f32,
        std: f32,
    ) -> Result<(), CurandError> {
        let n = out.len();
        // SAFETY: `out` is an exclusively borrowed host slice of exactly `n` elements, and
        // `self.gen` is a live host-API generator, so results are written to host memory.
        unsafe { sys::curandGenerateNormal(self.gen, out.as_mut_ptr(), n, mean, std).result() }
    }

    /// Fills `out` with normally distributed `f64` values with the given `mean` and
    /// standard deviation `std`, via `curandGenerateNormalDouble`.
    ///
    /// # Errors
    ///
    /// Returns the cuRAND error reported by the generation call.
    /// NOTE: per the cuRAND documentation, pseudorandom generators require an even
    /// `out.len()` here, because they use a Box-Muller transform. An odd length is
    /// reported as an error.
    pub fn fill_normal_f64(
        &mut self,
        out: &mut [f64],
        mean: f64,
        std: f64,
    ) -> Result<(), CurandError> {
        let n = out.len();
        // SAFETY: `out` is an exclusively borrowed host slice of exactly `n` elements, and
        // `self.gen` is a live host-API generator, so results are written to host memory.
        unsafe {
            sys::curandGenerateNormalDouble(self.gen, out.as_mut_ptr(), n, mean, std).result()
        }
    }

    /// Fills `out` with Poisson-distributed values with mean `lambda`.
    ///
    /// # Errors
    ///
    /// Returns the cuRAND error reported by the generation call.
    pub fn fill_poisson_u32(&mut self, out: &mut [u32], lambda: f64) -> Result<(), CurandError> {
        let n = out.len();
        // SAFETY: `out` is an exclusively borrowed host slice of exactly `n` elements, and
        // `self.gen` is a live host-API generator, so results are written to host memory.
        unsafe { csys::generate_poisson_u32(self.gen, out.as_mut_ptr(), n, lambda) }
    }
}
impl Drop for HostRng {
fn drop(&mut self) {
if !self.gen.is_null() {
let _ = unsafe { csys::destroy_generator(self.gen) };
self.gen = std::ptr::null_mut();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-only check that `HostRng::new` keeps the signature
    /// `fn(RngGeneratorKind, u64) -> Result<HostRng, CurandError>`.
    ///
    /// `HostRng::new` is only coerced to a function pointer and never called, so no
    /// generator is created.
    #[test]
    fn host_api_generator_constructs() {
        type Ctor = fn(RngGeneratorKind, u64) -> Result<HostRng, CurandError>;
        let _ctor: Ctor = HostRng::new;
    }
}