use std::ptr::{self, NonNull};
use crate::hip::DeviceMemory;
use crate::rocrand::bindings;
use crate::rocrand::error::{Error, Result};
/// Operations shared by every rocRAND generator wrapper.
///
/// Implementors only expose their raw handle through [`Generator::as_ptr`]; every
/// provided method forwards that handle to the matching rocRAND call and converts the
/// returned status with `Error::from_status`.
pub trait Generator {
    /// Returns the raw rocRAND generator handle owned by `self`.
    ///
    /// The provided methods pass this handle straight to rocRAND, so it must remain a
    /// live generator for as long as `self` exists.
    fn as_ptr(&self) -> bindings::rocrand_generator;

    /// Associates the generator with `stream` via `rocrand_set_stream`.
    ///
    /// # Safety
    ///
    /// `stream` is handed to rocRAND unchecked. The caller must guarantee it is a valid
    /// HIP stream handle for as long as this generator may use it.
    /// NOTE(review): the exact lifetime/null-stream rules are defined by rocRAND/HIP,
    /// not visible here — confirm against their documentation.
    ///
    /// # Errors
    ///
    /// Propagates the error `Error::from_status` derives from the returned status.
    unsafe fn set_stream(&mut self, stream: bindings::hipStream_t) -> Result<()> {
        // SAFETY: `as_ptr` yields the live handle owned by `self`; the validity of
        // `stream` is the caller's obligation documented above.
        unsafe { Error::from_status(bindings::rocrand_set_stream(self.as_ptr(), stream)) }
    }

    /// Sets the generator's output ordering via `rocrand_set_ordering`.
    ///
    /// # Errors
    ///
    /// Propagates the error `Error::from_status` derives from the returned status.
    fn set_ordering(&mut self, ordering: bindings::rocrand_ordering) -> Result<()> {
        // SAFETY: `as_ptr` yields the live handle owned by `self`.
        unsafe { Error::from_status(bindings::rocrand_set_ordering(self.as_ptr(), ordering)) }
    }

    /// Explicitly initializes the generator via `rocrand_initialize_generator`.
    ///
    /// # Errors
    ///
    /// Propagates the error `Error::from_status` derives from the returned status.
    fn initialize(&mut self) -> Result<()> {
        // SAFETY: `as_ptr` yields the live handle owned by `self`.
        unsafe { Error::from_status(bindings::rocrand_initialize_generator(self.as_ptr())) }
    }

    /// Returns the version number reported by `rocrand_get_version`.
    ///
    /// The `Self: Sized` bound keeps this receiver-less function from making the trait
    /// unusable as `dyn Generator`. Concrete callers such as `PseudoRng::get_version()`
    /// are unaffected.
    ///
    /// # Errors
    ///
    /// Propagates the error `Error::from_status` derives from the returned status.
    fn get_version() -> Result<i32>
    where
        Self: Sized,
    {
        let mut version = 0;
        // SAFETY: `version` is a valid, writable out-parameter for the whole call.
        unsafe {
            Error::from_status(bindings::rocrand_get_version(&mut version))?;
        }
        Ok(version)
    }
}
/// Owning wrapper around a rocRAND pseudo-random generator handle.
///
/// Built by [`PseudoRng::new`] or [`PseudoRng::new_host`]; the handle is released with
/// `rocrand_destroy_generator` when the wrapper is dropped.
pub struct PseudoRng {
    // `NonNull` encodes that a constructed wrapper never holds a null handle.
    generator: NonNull<bindings::rocrand_generator_base_type>,
}
impl PseudoRng {
pub fn new(rng_type: bindings::rocrand_rng_type) -> Result<Self> {
let mut generator = ptr::null_mut();
unsafe {
Error::from_status(bindings::rocrand_create_generator(&mut generator, rng_type))?;
Ok(Self {
generator: NonNull::new(generator).unwrap(),
})
}
}
pub fn new_host(rng_type: bindings::rocrand_rng_type) -> Result<Self> {
let mut generator = ptr::null_mut();
unsafe {
Error::from_status(bindings::rocrand_create_generator_host(
&mut generator,
rng_type,
))?;
Ok(Self {
generator: NonNull::new(generator).unwrap(),
})
}
}
pub fn set_seed(&mut self, seed: u64) -> Result<()> {
unsafe { Error::from_status(bindings::rocrand_set_seed(self.generator.as_ptr(), seed)) }
}
pub fn set_seed_array(&mut self, seed: u128) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_set_seed_uint4(
self.generator.as_ptr(),
seed,
))
}
}
pub fn set_offset(&mut self, offset: u64) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_set_offset(
self.generator.as_ptr(),
offset,
))
}
}
pub fn generate_u32(&mut self, output: &mut DeviceMemory<u32>) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
))
}
}
pub fn generate_u64(&mut self, output: &mut DeviceMemory<u64>) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_long_long(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
))
}
}
pub fn generate_u8(&mut self, output: &mut DeviceMemory<u8>) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_char(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
))
}
}
pub fn generate_u16(&mut self, output: &mut DeviceMemory<u16>) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_short(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
))
}
}
pub fn generate_uniform(&mut self, output: &mut DeviceMemory<f32>) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_uniform(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
))
}
}
pub fn generate_uniform_double(&mut self, output: &mut DeviceMemory<f64>) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_uniform_double(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
))
}
}
pub fn generate_normal(
&mut self,
output: &mut DeviceMemory<f32>,
mean: f32,
stddev: f32,
) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_normal(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
mean,
stddev,
))
}
}
pub fn generate_normal_double(
&mut self,
output: &mut DeviceMemory<f64>,
mean: f64,
stddev: f64,
) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_normal_double(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
mean,
stddev,
))
}
}
pub fn generate_log_normal(
&mut self,
output: &mut DeviceMemory<f32>,
mean: f32,
stddev: f32,
) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_log_normal(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
mean,
stddev,
))
}
}
pub fn generate_log_normal_double(
&mut self,
output: &mut DeviceMemory<f64>,
mean: f64,
stddev: f64,
) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_log_normal_double(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
mean,
stddev,
))
}
}
pub fn generate_poisson(&mut self, output: &mut DeviceMemory<u32>, lambda: f64) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_poisson(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
lambda,
))
}
}
}
impl Generator for PseudoRng {
fn as_ptr(&self) -> bindings::rocrand_generator {
self.generator.as_ptr()
}
}
impl Drop for PseudoRng {
    /// Releases the underlying generator with `rocrand_destroy_generator`.
    ///
    /// The returned status is intentionally discarded: `drop` cannot report errors, and
    /// panicking here could abort the process if it happens during unwinding.
    fn drop(&mut self) {
        let handle = self.generator.as_ptr();
        // SAFETY: `handle` is the generator this wrapper owns, and it is destroyed
        // only here, once, as the wrapper goes away.
        let _ = unsafe { bindings::rocrand_destroy_generator(handle) };
    }
}
/// Owning wrapper around a rocRAND generator handle intended for quasi-random use.
///
/// Built by [`QuasiRng::new`]; the handle is released with `rocrand_destroy_generator`
/// when the wrapper is dropped.
pub struct QuasiRng {
    // `NonNull` encodes that a constructed wrapper never holds a null handle.
    generator: NonNull<bindings::rocrand_generator_base_type>,
}
impl QuasiRng {
pub fn new(rng_type: bindings::rocrand_rng_type) -> Result<Self> {
let mut generator = ptr::null_mut();
unsafe {
Error::from_status(bindings::rocrand_create_generator(&mut generator, rng_type))?;
Ok(Self {
generator: NonNull::new(generator).unwrap(),
})
}
}
pub fn set_dimensions(&mut self, dimensions: u32) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_set_quasi_random_generator_dimensions(
self.generator.as_ptr(),
dimensions,
))
}
}
pub fn set_offset(&mut self, offset: u64) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_set_offset(
self.generator.as_ptr(),
offset,
))
}
}
pub fn generate_uniform(&mut self, output: &mut DeviceMemory<f32>) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_uniform(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
))
}
}
pub fn generate_uniform_double(&mut self, output: &mut DeviceMemory<f64>) -> Result<()> {
unsafe {
Error::from_status(bindings::rocrand_generate_uniform_double(
self.generator.as_ptr(),
output.as_ptr().cast(),
output.count(),
))
}
}
}
impl Generator for QuasiRng {
fn as_ptr(&self) -> bindings::rocrand_generator {
self.generator.as_ptr()
}
}
impl Drop for QuasiRng {
    /// Releases the underlying generator with `rocrand_destroy_generator`.
    ///
    /// The returned status is intentionally discarded: `drop` cannot report errors, and
    /// panicking here could abort the process if it happens during unwinding.
    fn drop(&mut self) {
        let handle = self.generator.as_ptr();
        // SAFETY: `handle` is the generator this wrapper owns, and it is destroyed
        // only here, once, as the wrapper goes away.
        let _ = unsafe { bindings::rocrand_destroy_generator(handle) };
    }
}