use crate::error::Result;
use crate::hip::DeviceMemory;
use crate::rocrand::{
Generator, LogNormal, Normal, Poisson, PseudoRng, QuasiRng, Uniform, rng_type,
};
/// Defines a public function `$name(count, seed)` that allocates `count` elements of
/// `$elem` in device memory and fills them by calling the `PseudoRng` method `$method`
/// on a generator of kind `rng_type::$kind`.
macro_rules! generate_uniform_rand_func {
    ($name:ident, $elem:ty, $method:ident, $kind:ident) => {
        paste::paste! {
            #[doc = "Generate random " $elem " values on device"]
            ///
            /// When `seed` is `Some`, the generator is seeded with it before it is
            /// initialized; when it is `None`, no seed is set explicitly.
            ///
            /// # Errors
            ///
            /// Returns an error if creating, seeding or initializing the generator,
            /// allocating the device buffer, or generating the values fails.
            pub fn $name(count: usize, seed: Option<u64>) -> Result<DeviceMemory<$elem>> {
                let mut rng = PseudoRng::new(rng_type::$kind)?;
                if let Some(s) = seed {
                    rng.set_seed(s)?;
                }
                rng.initialize()?;

                let mut buffer = DeviceMemory::<$elem>::new(count)?;
                rng.$method(&mut buffer)?;
                Ok(buffer)
            }
        }
    };
}

// Uniform floating-point and raw 32-bit integer generators, all backed by XORWOW.
generate_uniform_rand_func!(generate_uniform_f32, f32, generate_uniform, XORWOW);
generate_uniform_rand_func!(generate_uniform_f64, f64, generate_uniform_double, XORWOW);
generate_uniform_rand_func!(generate_u32, u32, generate_u32, XORWOW);
/// Defines a public function `$fn_name(count, mean, stddev, seed)` that allocates
/// `count` elements of `$data_type` in device memory and fills them with samples from
/// the distribution built by `$dist(mean, stddev)`, drawn from a `PseudoRng` of kind
/// `rng_type::$rng_type`.
///
/// `$desc` is a string literal spliced into the generated documentation
/// (e.g. `"normally"`, `"log-normally"`) so each function describes its own distribution.
macro_rules! generate_normal_rand_func {
    ($fn_name:ident, $data_type:ty, $rng_type:ident, $dist:expr, $desc:tt) => {
        paste::paste! {
            #[doc = "Generate " $desc " distributed random " $data_type " values with specified mean and standard deviation"]
            ///
            /// `mean` and `stddev` are forwarded unchanged to the distribution constructor.
            /// When `seed` is `Some`, the generator is seeded with it before it is
            /// initialized; when it is `None`, no seed is set explicitly.
            ///
            /// # Errors
            ///
            /// Returns an error if creating, seeding or initializing the generator,
            /// allocating the device buffer, or generating the values fails.
            pub fn $fn_name(
                count: usize,
                // NOTE(review): `mean`/`stddev` are always `f32`; a non-`f32` `$data_type`
                // may also need wider parameters — confirm against the distribution API.
                mean: f32,
                stddev: f32,
                seed: Option<u64>,
            ) -> Result<DeviceMemory<$data_type>> {
                let mut generator = PseudoRng::new(rng_type::$rng_type)?;
                if let Some(seed_value) = seed {
                    generator.set_seed(seed_value)?;
                }
                generator.initialize()?;
                let dist = $dist(mean, stddev);
                // Allocate with the declared element type (previously hard-coded to `f32`)
                // so the buffer always matches the function's return type.
                let mut device_output = DeviceMemory::<$data_type>::new(count)?;
                dist.generate(&mut generator, &mut device_output)?;
                Ok(device_output)
            }
        }
    };
}

// Normal and log-normal generators, both backed by PHILOX4_32_10.
generate_normal_rand_func!(generate_normal_f32, f32, PHILOX4_32_10, Normal::new, "normally");
generate_normal_rand_func!(
    generate_log_normal_f32,
    f32,
    PHILOX4_32_10,
    LogNormal::new,
    "log-normally"
);
/// Generate `count` Poisson-distributed `u32` values on the device.
///
/// Samples are drawn from an MTGP32 pseudo-random generator. `lambda` is forwarded to
/// `Poisson::new` as the distribution parameter (presumably the mean/rate — confirm
/// against the `Poisson` API). When `seed` is `Some`, the generator is seeded with it
/// before it is initialized; when it is `None`, no seed is set explicitly.
///
/// # Errors
///
/// Returns an error if creating, seeding or initializing the generator, allocating the
/// device buffer, or generating the values fails.
pub fn generate_poisson(
    count: usize,
    lambda: f64,
    seed: Option<u64>,
) -> Result<DeviceMemory<u32>> {
    let mut rng = PseudoRng::new(rng_type::MTGP32)?;
    if let Some(s) = seed {
        rng.set_seed(s)?;
    }
    rng.initialize()?;

    // The distribution is built before the output buffer is allocated and kept alive
    // until the function returns.
    let distribution = Poisson::new(lambda);
    let mut samples = DeviceMemory::<u32>::new(count)?;
    distribution.generate(&mut rng, &mut samples)?;
    Ok(samples)
}
/// Generate `count` quasi-random `f32` values on the device.
///
/// Uses a SOBOL32 quasi-random generator configured for `dimensions` dimensions and
/// fills the buffer via `Uniform::generate_quasi`. No seed is accepted; the generator
/// is only configured with its dimension count before initialization.
///
/// NOTE(review): quasi-random generators may require `count` to be a multiple of
/// `dimensions` — confirm the constraint and document/validate it here.
///
/// # Errors
///
/// Returns an error if creating the generator, setting its dimensions, initializing it,
/// allocating the device buffer, or generating the values fails.
pub fn generate_quasi_f32(count: usize, dimensions: u32) -> Result<DeviceMemory<f32>> {
    let mut sobol = QuasiRng::new(rng_type::SOBOL32)?;
    sobol.set_dimensions(dimensions)?;
    sobol.initialize()?;

    let mut output = DeviceMemory::<f32>::new(count)?;
    Uniform::generate_quasi(&mut sobol, &mut output)?;
    Ok(output)
}