use std::mem::MaybeUninit;
use cudarc::curand::result::CurandError;
use cudarc::curand::sys;
/// Which cuRAND generator algorithm to instantiate.
///
/// Every variant maps to exactly one `curandRngType_t` constant (see
/// `RngGeneratorKind::to_sys`). The four `Sobol*` variants are quasi-random
/// generators; all others are pseudo-random. The [`Default`] is
/// [`RngGeneratorKind::PseudoDefault`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum RngGeneratorKind {
    /// cuRAND's default pseudo-random generator (`CURAND_RNG_PSEUDO_DEFAULT`).
    #[default]
    PseudoDefault,
    /// Philox 4x32 with 10 rounds (`CURAND_RNG_PSEUDO_PHILOX4_32_10`).
    Philox4_32_10,
    /// XORWOW pseudo-random generator (`CURAND_RNG_PSEUDO_XORWOW`).
    XorWow,
    /// MRG32k3a pseudo-random generator (`CURAND_RNG_PSEUDO_MRG32K3A`).
    Mrg32K3A,
    /// Mersenne Twister for GPUs (`CURAND_RNG_PSEUDO_MTGP32`).
    Mtgp32,
    /// 32-bit Sobol quasi-random generator (`CURAND_RNG_QUASI_SOBOL32`).
    Sobol32,
    /// Scrambled 32-bit Sobol (`CURAND_RNG_QUASI_SCRAMBLED_SOBOL32`).
    ScrambledSobol32,
    /// 64-bit Sobol quasi-random generator (`CURAND_RNG_QUASI_SOBOL64`).
    Sobol64,
    /// Scrambled 64-bit Sobol (`CURAND_RNG_QUASI_SCRAMBLED_SOBOL64`).
    ScrambledSobol64,
}
impl RngGeneratorKind {
    /// Returns `true` for the quasi-random (Sobol family) kinds: `Sobol32`,
    /// `ScrambledSobol32`, `Sobol64` and `ScrambledSobol64`.
    #[must_use]
    pub const fn is_quasi(self) -> bool {
        matches!(
            self,
            Self::Sobol32 | Self::ScrambledSobol32 | Self::Sobol64 | Self::ScrambledSobol64
        )
    }

    /// Returns `true` only for the 64-bit quasi-random kinds (`Sobol64` and
    /// `ScrambledSobol64`).
    #[must_use]
    pub const fn is_quasi_64(self) -> bool {
        matches!(self, Self::Sobol64 | Self::ScrambledSobol64)
    }

    /// Converts this kind into the raw cuRAND `curandRngType_t` value passed to
    /// the generator-creation functions. The mapping is one-to-one.
    #[must_use]
    pub const fn to_sys(self) -> sys::curandRngType_t {
        match self {
            Self::PseudoDefault => sys::curandRngType_t::CURAND_RNG_PSEUDO_DEFAULT,
            Self::Philox4_32_10 => sys::curandRngType_t::CURAND_RNG_PSEUDO_PHILOX4_32_10,
            Self::XorWow => sys::curandRngType_t::CURAND_RNG_PSEUDO_XORWOW,
            Self::Mrg32K3A => sys::curandRngType_t::CURAND_RNG_PSEUDO_MRG32K3A,
            Self::Mtgp32 => sys::curandRngType_t::CURAND_RNG_PSEUDO_MTGP32,
            Self::Sobol32 => sys::curandRngType_t::CURAND_RNG_QUASI_SOBOL32,
            Self::ScrambledSobol32 => sys::curandRngType_t::CURAND_RNG_QUASI_SCRAMBLED_SOBOL32,
            Self::Sobol64 => sys::curandRngType_t::CURAND_RNG_QUASI_SOBOL64,
            Self::ScrambledSobol64 => sys::curandRngType_t::CURAND_RNG_QUASI_SCRAMBLED_SOBOL64,
        }
    }
}
/// Creates a device-side cuRAND generator of the given `kind` via
/// `curandCreateGenerator`.
///
/// # Errors
///
/// Returns the [`CurandError`] reported by cuRAND if creation fails. No
/// handle is produced in that case.
///
/// # Safety
///
/// This is a raw FFI call. The returned handle is unmanaged: the caller must
/// release it exactly once with [`destroy_generator`] and must not use it
/// afterwards. Presumably a usable CUDA context is required on the calling
/// thread — confirm against the caller's setup.
pub unsafe fn create_generator(
    kind: RngGeneratorKind,
) -> Result<sys::curandGenerator_t, CurandError> {
    let mut handle = MaybeUninit::<sys::curandGenerator_t>::uninit();
    sys::curandCreateGenerator(handle.as_mut_ptr(), kind.to_sys()).result()?;
    // SAFETY: `result()` returned `Ok`, so cuRAND reported success and has
    // written a valid handle through `handle.as_mut_ptr()`.
    Ok(handle.assume_init())
}
/// Creates a host-side cuRAND generator of the given `kind` via
/// `curandCreateGeneratorHost`.
///
/// # Errors
///
/// Returns the [`CurandError`] reported by cuRAND if creation fails. No
/// handle is produced in that case.
///
/// # Safety
///
/// This is a raw FFI call. The returned handle is unmanaged: the caller must
/// release it exactly once with [`destroy_generator`] and must not use it
/// afterwards. Per the cuRAND documentation, host generators write results to
/// host memory, so output pointers later passed with this handle must be host
/// pointers.
#[cfg(feature = "curand-host")]
pub unsafe fn create_generator_host(
    kind: RngGeneratorKind,
) -> Result<sys::curandGenerator_t, CurandError> {
    let mut handle = MaybeUninit::<sys::curandGenerator_t>::uninit();
    sys::curandCreateGeneratorHost(handle.as_mut_ptr(), kind.to_sys()).result()?;
    // SAFETY: `result()` returned `Ok`, so cuRAND reported success and has
    // written a valid handle through `handle.as_mut_ptr()`.
    Ok(handle.assume_init())
}
/// Sets the CUDA stream that cuRAND uses for this generator's work
/// (`curandSetStream`).
///
/// # Errors
///
/// Returns the [`CurandError`] reported by cuRAND.
///
/// # Safety
///
/// `generator` must be a live handle from `create_generator` (or its host
/// variant) that has not been destroyed. `stream` must be a valid stream that
/// outlives any generation issued on it. Check the cuRAND documentation for
/// whether a null (default) stream is accepted.
pub unsafe fn set_stream(
    generator: sys::curandGenerator_t,
    stream: sys::cudaStream_t,
) -> Result<(), CurandError> {
    sys::curandSetStream(generator, stream).result()
}
/// Sets the seed of a pseudo-random generator
/// (`curandSetPseudoRandomGeneratorSeed`).
///
/// # Errors
///
/// Returns the [`CurandError`] reported by cuRAND. cuRAND documents a type
/// error for generators that are not pseudo-random (see
/// `RngGeneratorKind::is_quasi`).
///
/// # Safety
///
/// `generator` must be a live handle from `create_generator` (or its host
/// variant) that has not been destroyed.
pub unsafe fn set_seed(generator: sys::curandGenerator_t, seed: u64) -> Result<(), CurandError> {
    sys::curandSetPseudoRandomGeneratorSeed(generator, seed).result()
}
/// Sets the absolute offset into the generator's output sequence
/// (`curandSetGeneratorOffset`).
///
/// # Errors
///
/// Returns the [`CurandError`] reported by cuRAND.
///
/// # Safety
///
/// `generator` must be a live handle from `create_generator` (or its host
/// variant) that has not been destroyed.
pub unsafe fn set_offset(
    generator: sys::curandGenerator_t,
    offset: u64,
) -> Result<(), CurandError> {
    sys::curandSetGeneratorOffset(generator, offset).result()
}
/// Sets the number of dimensions of a quasi-random generator
/// (`curandSetQuasiRandomGeneratorDimensions`).
///
/// # Errors
///
/// Returns the [`CurandError`] reported by cuRAND. NOTE(review): cuRAND
/// documents an accepted range of 1 to 20,000 dimensions. Out-of-range values
/// are rejected by cuRAND, not by this wrapper.
///
/// # Safety
///
/// `generator` must be a live handle from `create_generator` (or its host
/// variant) that has not been destroyed.
#[cfg(feature = "curand-quasirandom")]
pub unsafe fn set_quasi_random_dimensions(
    generator: sys::curandGenerator_t,
    dimensions: u32,
) -> Result<(), CurandError> {
    sys::curandSetQuasiRandomGeneratorDimensions(generator, dimensions).result()
}
/// Releases a generator created by `create_generator` or its host variant
/// (`curandDestroyGenerator`).
///
/// # Errors
///
/// Returns the [`CurandError`] reported by cuRAND.
///
/// # Safety
///
/// `generator` must be a live handle that has not already been destroyed.
/// After this call the handle is dangling: it must not be used or destroyed
/// again.
pub unsafe fn destroy_generator(generator: sys::curandGenerator_t) -> Result<(), CurandError> {
    sys::curandDestroyGenerator(generator).result()
}
/// Fills `out[0..n]` with Poisson-distributed `u32` samples with mean `lambda`
/// (`curandGeneratePoisson`).
///
/// # Errors
///
/// Returns the [`CurandError`] reported by cuRAND. `lambda` is not validated
/// here; out-of-range values are left for cuRAND to reject.
///
/// # Safety
///
/// - `generator` must be a live, undestroyed handle.
/// - `out` must be valid for writes of `n` `u32` values in the memory space
///   this generator writes to: device memory for `create_generator` handles,
///   host memory for host generators.
/// - The buffer must stay alive and unaliased until the generation completes,
///   which may be asynchronous on the generator's stream.
pub unsafe fn generate_poisson_u32(
    generator: sys::curandGenerator_t,
    out: *mut u32,
    n: usize,
    lambda: f64,
) -> Result<(), CurandError> {
    sys::curandGeneratePoisson(generator, out, n, lambda).result()
}
/// Fills `out[0..n]` with raw 32-bit random values (`curandGenerate`).
///
/// # Errors
///
/// Returns the [`CurandError`] reported by cuRAND. cuRAND documents a type
/// error when the generator is a 64-bit quasi-random one (see
/// `RngGeneratorKind::is_quasi_64`); use `generate_u64` for those.
///
/// # Safety
///
/// - `generator` must be a live, undestroyed handle.
/// - `out` must be valid for writes of `n` `u32` values in the memory space
///   this generator writes to: device memory for `create_generator` handles,
///   host memory for host generators.
/// - The buffer must stay alive and unaliased until the generation completes.
pub unsafe fn generate_u32(
    generator: sys::curandGenerator_t,
    out: *mut u32,
    n: usize,
) -> Result<(), CurandError> {
    sys::curandGenerate(generator, out, n).result()
}
/// Fills `out[0..n]` with raw 64-bit random values
/// (`curandGenerateLongLong`).
///
/// # Errors
///
/// Returns the [`CurandError`] reported by cuRAND. cuRAND documents a type
/// error unless the generator is a 64-bit quasi-random one (see
/// `RngGeneratorKind::is_quasi_64`).
///
/// # Safety
///
/// - `generator` must be a live, undestroyed handle.
/// - `out` must be valid for writes of `n` `u64` values in the memory space
///   this generator writes to: device memory for `create_generator` handles,
///   host memory for host generators.
/// - The buffer must stay alive and unaliased until the generation completes.
pub unsafe fn generate_u64(
    generator: sys::curandGenerator_t,
    out: *mut u64,
    n: usize,
) -> Result<(), CurandError> {
    // `c_ulonglong` is an alias for `u64` on supported targets. `cast` only
    // renames the pointee type and cannot silently alter mutability.
    sys::curandGenerateLongLong(generator, out.cast::<std::os::raw::c_ulonglong>(), n).result()
}