use crate::hip::DeviceMemory;
use crate::rocrand::bindings;
use crate::rocrand::error::{Error, Result};
use crate::rocrand::generator::{PseudoRng, QuasiRng};
use std::ptr::NonNull;
/// Stateless namespace type for the uniform-generation helpers.
///
/// The type has no fields; every operation is an associated function on
/// `Uniform` that takes the generator and the output buffer explicitly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Uniform;
impl Uniform {
    /// Calls `generate_uniform` on the given `PseudoRng`, passing `output`
    /// through, and returns that call's result unchanged.
    pub fn generate(
        generator: &mut PseudoRng,
        output: &mut DeviceMemory<f32>,
    ) -> Result<()> {
        generator.generate_uniform(output)
    }

    /// `f64` counterpart of [`Uniform::generate`]: calls
    /// `generate_uniform_double` on the given `PseudoRng`.
    pub fn generate_double(
        generator: &mut PseudoRng,
        output: &mut DeviceMemory<f64>,
    ) -> Result<()> {
        generator.generate_uniform_double(output)
    }

    /// Calls `generate_uniform` on the given `QuasiRng`, passing `output`
    /// through, and returns that call's result unchanged.
    pub fn generate_quasi(
        generator: &mut QuasiRng,
        output: &mut DeviceMemory<f32>,
    ) -> Result<()> {
        generator.generate_uniform(output)
    }

    /// `f64` counterpart of [`Uniform::generate_quasi`]: calls
    /// `generate_uniform_double` on the given `QuasiRng`.
    pub fn generate_quasi_double(
        generator: &mut QuasiRng,
        output: &mut DeviceMemory<f64>,
    ) -> Result<()> {
        generator.generate_uniform_double(output)
    }
}
/// Single-precision parameter pair handed to the generator's
/// `generate_normal` method by `Normal::generate`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Normal {
    /// Forwarded as the `mean` argument of `generate_normal`.
    mean: f32,
    /// Forwarded as the `stddev` argument of `generate_normal`.
    stddev: f32,
}
impl Normal {
pub fn new(mean: f32, stddev: f32) -> Self {
Self { mean, stddev }
}
pub fn generate(
&self,
generator: &mut PseudoRng,
output: &mut DeviceMemory<f32>,
) -> Result<()> {
generator.generate_normal(output, self.mean, self.stddev)
}
}
/// Double-precision parameter pair handed to the generator's
/// `generate_normal_double` method by `NormalDouble::generate`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NormalDouble {
    /// Forwarded as the `mean` argument of `generate_normal_double`.
    mean: f64,
    /// Forwarded as the `stddev` argument of `generate_normal_double`.
    stddev: f64,
}
impl NormalDouble {
pub fn new(mean: f64, stddev: f64) -> Self {
Self { mean, stddev }
}
pub fn generate(
&self,
generator: &mut PseudoRng,
output: &mut DeviceMemory<f64>,
) -> Result<()> {
generator.generate_normal_double(output, self.mean, self.stddev)
}
}
/// Single-precision parameter pair handed to the generator's
/// `generate_log_normal` method by `LogNormal::generate`.
///
/// NOTE(review): `mean` / `stddev` presumably describe the underlying normal
/// distribution rather than the log-normal itself — confirm against the
/// rocRAND documentation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LogNormal {
    /// Forwarded as the `mean` argument of `generate_log_normal`.
    mean: f32,
    /// Forwarded as the `stddev` argument of `generate_log_normal`.
    stddev: f32,
}
impl LogNormal {
pub fn new(mean: f32, stddev: f32) -> Self {
Self { mean, stddev }
}
pub fn generate(
&self,
generator: &mut PseudoRng,
output: &mut DeviceMemory<f32>,
) -> Result<()> {
generator.generate_log_normal(output, self.mean, self.stddev)
}
}
/// Double-precision parameter pair handed to the generator's
/// `generate_log_normal_double` method by `LogNormalDouble::generate`.
///
/// NOTE(review): `mean` / `stddev` presumably describe the underlying normal
/// distribution rather than the log-normal itself — confirm against the
/// rocRAND documentation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LogNormalDouble {
    /// Forwarded as the `mean` argument of `generate_log_normal_double`.
    mean: f64,
    /// Forwarded as the `stddev` argument of `generate_log_normal_double`.
    stddev: f64,
}
impl LogNormalDouble {
pub fn new(mean: f64, stddev: f64) -> Self {
Self { mean, stddev }
}
pub fn generate(
&self,
generator: &mut PseudoRng,
output: &mut DeviceMemory<f64>,
) -> Result<()> {
generator.generate_log_normal_double(output, self.mean, self.stddev)
}
}
/// Parameter handed to the generator's `generate_poisson` method by
/// `Poisson::generate`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Poisson {
    /// Forwarded as the `lambda` argument of `generate_poisson`.
    lambda: f64,
}
impl Poisson {
pub fn new(lambda: f64) -> Self {
Self { lambda }
}
pub fn generate(
&self,
generator: &mut PseudoRng,
output: &mut DeviceMemory<u32>,
) -> Result<()> {
generator.generate_poisson(output, self.lambda)
}
}
pub struct Discrete {
distribution: NonNull<bindings::rocrand_discrete_distribution_st>,
}
impl Discrete {
pub fn from_probabilities(probabilities: &[f64], offset: u32) -> Result<Self> {
let mut distribution = std::ptr::null_mut();
unsafe {
Error::from_status(bindings::rocrand_create_discrete_distribution(
probabilities.as_ptr(),
probabilities.len() as u32,
offset,
&mut distribution,
))?;
Ok(Self {
distribution: NonNull::new(distribution).unwrap(),
})
}
}
pub fn poisson(lambda: f64) -> Result<Self> {
let mut distribution = std::ptr::null_mut();
unsafe {
Error::from_status(bindings::rocrand_create_poisson_distribution(
lambda,
&mut distribution,
))?;
Ok(Self {
distribution: NonNull::new(distribution).unwrap(),
})
}
}
pub fn as_ptr(&self) -> bindings::rocrand_discrete_distribution {
self.distribution.as_ptr()
}
}
impl Drop for Discrete {
    /// Hands the wrapped handle to `rocrand_destroy_discrete_distribution`.
    ///
    /// The returned status is deliberately discarded: `drop` has no way to
    /// report a failure to the caller.
    fn drop(&mut self) {
        let handle = self.distribution.as_ptr();
        // SAFETY: `handle` is the non-null pointer stored in this value. It is
        // presumed to still be live, i.e. nobody destroyed it through the raw
        // pointer exposed by `as_ptr` — verify callers uphold this.
        let _ = unsafe { bindings::rocrand_destroy_discrete_distribution(handle) };
    }
}