use std::sync::Arc;
use oxicuda_driver::context::Context;
use oxicuda_driver::module::Module;
use oxicuda_driver::stream::Stream;
use oxicuda_launch::grid::grid_size_for;
use oxicuda_launch::kernel::Kernel;
use oxicuda_launch::params::LaunchParams;
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::ir::PtxType;
use crate::engines::{mrg32k3a, philox, philox_optimized, xorwow};
use crate::error::{RandError, RandResult};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RngEngine {
Philox,
Xorwow,
Mrg32k3a,
}
impl std::fmt::Display for RngEngine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Philox => write!(f, "Philox-4x32-10"),
Self::Xorwow => write!(f, "XORWOW"),
Self::Mrg32k3a => write!(f, "MRG32k3a"),
}
}
}
pub struct RngGenerator {
engine: RngEngine,
seed: u64,
offset: u64,
#[allow(dead_code)]
context: Arc<Context>,
stream: Stream,
sm_version: SmVersion,
}
impl RngGenerator {
pub fn new(engine: RngEngine, seed: u64, ctx: &Arc<Context>) -> RandResult<Self> {
let stream = Stream::new(ctx).map_err(RandError::Cuda)?;
Ok(Self {
engine,
seed,
offset: 0,
context: Arc::clone(ctx),
stream,
sm_version: SmVersion::Sm80,
})
}
pub fn set_seed(&mut self, seed: u64) {
self.seed = seed;
}
pub fn set_offset(&mut self, offset: u64) {
self.offset = offset;
}
pub fn skip(&mut self, n: u64) {
self.offset = self.offset.wrapping_add(n);
}
pub fn generate_uniform_f32(&mut self, output: &mut DeviceBuffer<f32>) -> RandResult<()> {
let n = output.len();
let ptx_source = self.get_uniform_ptx(PtxType::F32)?;
self.compile_and_launch_uniform(&ptx_source, PtxType::F32, output.as_device_ptr(), n)?;
self.offset += n as u64;
Ok(())
}
pub fn generate_uniform_f64(&mut self, output: &mut DeviceBuffer<f64>) -> RandResult<()> {
let n = output.len();
let ptx_source = self.get_uniform_ptx(PtxType::F64)?;
self.compile_and_launch_uniform(&ptx_source, PtxType::F64, output.as_device_ptr(), n)?;
self.offset += n as u64;
Ok(())
}
pub fn generate_uniform_f32_optimized(
&mut self,
output: &mut DeviceBuffer<f32>,
) -> RandResult<()> {
let n = output.len();
if self.engine != RngEngine::Philox || n < philox_optimized::OPTIMIZED_THRESHOLD {
return self.generate_uniform_f32(output);
}
let ptx_source =
philox_optimized::generate_philox_optimized_uniform_f32_ptx(self.sm_version)?;
self.compile_and_launch_uniform(&ptx_source, PtxType::F32, output.as_device_ptr(), n)?;
self.offset += n.div_ceil(4) as u64;
Ok(())
}
pub fn generate_normal_f32_optimized(
&mut self,
output: &mut DeviceBuffer<f32>,
mean: f32,
stddev: f32,
) -> RandResult<()> {
let n = output.len();
if self.engine != RngEngine::Philox || n < philox_optimized::OPTIMIZED_THRESHOLD {
return self.generate_normal_f32(output, mean, stddev);
}
let ptx_source =
philox_optimized::generate_philox_optimized_normal_f32_ptx(self.sm_version)?;
self.compile_and_launch_normal_f32(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
self.offset += n.div_ceil(4) as u64;
Ok(())
}
pub fn generate_normal_f32(
&mut self,
output: &mut DeviceBuffer<f32>,
mean: f32,
stddev: f32,
) -> RandResult<()> {
let n = output.len();
let ptx_source = self.get_normal_ptx(PtxType::F32)?;
self.compile_and_launch_normal_f32(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
self.offset += n as u64;
Ok(())
}
pub fn generate_normal_f64(
&mut self,
output: &mut DeviceBuffer<f64>,
mean: f64,
stddev: f64,
) -> RandResult<()> {
let n = output.len();
let ptx_source = self.get_normal_ptx(PtxType::F64)?;
self.compile_and_launch_normal_f64(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
self.offset += n as u64;
Ok(())
}
pub fn generate_log_normal_f32(
&mut self,
output: &mut DeviceBuffer<f32>,
mean: f32,
stddev: f32,
) -> RandResult<()> {
self.generate_normal_f32(output, mean, stddev)
}
pub fn generate_log_normal_f64(
&mut self,
output: &mut DeviceBuffer<f64>,
mean: f64,
stddev: f64,
) -> RandResult<()> {
self.generate_normal_f64(output, mean, stddev)
}
pub fn generate_poisson_f32(
&mut self,
output: &mut DeviceBuffer<f32>,
lambda: f64,
) -> RandResult<()> {
let _lambda_f32 = lambda as f32;
let _n = output.len();
self.generate_uniform_f32(output)
}
pub fn generate_u32(&mut self, output: &mut DeviceBuffer<u32>) -> RandResult<()> {
let n = output.len();
let ptx_source = self.get_u32_ptx()?;
let kernel_name = self.u32_kernel_name();
self.compile_and_launch_u32(&ptx_source, &kernel_name, output.as_device_ptr(), n)?;
self.offset += n as u64;
Ok(())
}
fn get_uniform_ptx(&self, precision: PtxType) -> RandResult<String> {
let ptx = match self.engine {
RngEngine::Philox => philox::generate_philox_uniform_ptx(precision, self.sm_version)?,
RngEngine::Xorwow => xorwow::generate_xorwow_uniform_ptx(precision, self.sm_version)?,
RngEngine::Mrg32k3a => {
mrg32k3a::generate_mrg32k3a_uniform_ptx(precision, self.sm_version)?
}
};
Ok(ptx)
}
fn get_normal_ptx(&self, precision: PtxType) -> RandResult<String> {
let ptx = match self.engine {
RngEngine::Philox => philox::generate_philox_normal_ptx(precision, self.sm_version)?,
RngEngine::Xorwow => xorwow::generate_xorwow_normal_ptx(precision, self.sm_version)?,
RngEngine::Mrg32k3a => {
mrg32k3a::generate_mrg32k3a_normal_ptx(precision, self.sm_version)?
}
};
Ok(ptx)
}
fn get_u32_ptx(&self) -> RandResult<String> {
let ptx = match self.engine {
RngEngine::Philox => philox::generate_philox_u32_ptx(self.sm_version)?,
RngEngine::Mrg32k3a => mrg32k3a::generate_mrg32k3a_u32_ptx(self.sm_version)?,
RngEngine::Xorwow => {
return Err(RandError::UnsupportedDistribution(
"u32 output is not supported for XORWOW engine".to_string(),
));
}
};
Ok(ptx)
}
fn uniform_kernel_name(&self, precision: PtxType) -> String {
let prec_str = match precision {
PtxType::F32 => "f32",
PtxType::F64 => "f64",
_ => "f32",
};
match self.engine {
RngEngine::Philox => format!("philox_uniform_{prec_str}"),
RngEngine::Xorwow => format!("xorwow_uniform_{prec_str}"),
RngEngine::Mrg32k3a => format!("mrg32k3a_uniform_{prec_str}"),
}
}
fn normal_kernel_name(&self, precision: PtxType) -> String {
let prec_str = match precision {
PtxType::F32 => "f32",
PtxType::F64 => "f64",
_ => "f32",
};
match self.engine {
RngEngine::Philox => format!("philox_normal_{prec_str}"),
RngEngine::Xorwow => format!("xorwow_normal_{prec_str}"),
RngEngine::Mrg32k3a => format!("mrg32k3a_normal_{prec_str}"),
}
}
fn u32_kernel_name(&self) -> String {
match self.engine {
RngEngine::Philox => "philox_u32".to_string(),
RngEngine::Mrg32k3a => "mrg32k3a_u32".to_string(),
RngEngine::Xorwow => "xorwow_u32".to_string(), }
}
fn compile_and_launch_uniform(
&self,
ptx_source: &str,
precision: PtxType,
out_ptr: u64,
n: usize,
) -> RandResult<()> {
let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
let kernel_name = self.uniform_kernel_name(precision);
let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
let n_u32 = u32::try_from(n)
.map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
let grid = grid_size_for(n_u32, 256);
let params = LaunchParams::new(grid, 256u32);
let seed_lo = self.seed as u32;
let seed_hi = (self.seed >> 32) as u32;
let offset_lo = self.offset as u32;
let offset_hi = (self.offset >> 32) as u32;
match self.engine {
RngEngine::Philox => {
let args = (out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi);
kernel
.launch(¶ms, &self.stream, &args)
.map_err(RandError::Cuda)?;
}
RngEngine::Xorwow | RngEngine::Mrg32k3a => {
let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi);
kernel
.launch(¶ms, &self.stream, &args)
.map_err(RandError::Cuda)?;
}
}
self.stream.synchronize().map_err(RandError::Cuda)?;
Ok(())
}
fn compile_and_launch_normal_f32(
&self,
ptx_source: &str,
out_ptr: u64,
n: usize,
mean: f32,
stddev: f32,
) -> RandResult<()> {
let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
let kernel_name = self.normal_kernel_name(PtxType::F32);
let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
let n_u32 = u32::try_from(n)
.map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
let grid = grid_size_for(n_u32, 256);
let params = LaunchParams::new(grid, 256u32);
let seed_lo = self.seed as u32;
let seed_hi = (self.seed >> 32) as u32;
let offset_lo = self.offset as u32;
let offset_hi = (self.offset >> 32) as u32;
match self.engine {
RngEngine::Philox => {
let args = (
out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi, mean, stddev,
);
kernel
.launch(¶ms, &self.stream, &args)
.map_err(RandError::Cuda)?;
}
RngEngine::Xorwow | RngEngine::Mrg32k3a => {
let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi, mean, stddev);
kernel
.launch(¶ms, &self.stream, &args)
.map_err(RandError::Cuda)?;
}
}
self.stream.synchronize().map_err(RandError::Cuda)?;
Ok(())
}
fn compile_and_launch_normal_f64(
&self,
ptx_source: &str,
out_ptr: u64,
n: usize,
mean: f64,
stddev: f64,
) -> RandResult<()> {
let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
let kernel_name = self.normal_kernel_name(PtxType::F64);
let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
let n_u32 = u32::try_from(n)
.map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
let grid = grid_size_for(n_u32, 256);
let params = LaunchParams::new(grid, 256u32);
let seed_lo = self.seed as u32;
let seed_hi = (self.seed >> 32) as u32;
let offset_lo = self.offset as u32;
let offset_hi = (self.offset >> 32) as u32;
match self.engine {
RngEngine::Philox => {
let args = (
out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi, mean, stddev,
);
kernel
.launch(¶ms, &self.stream, &args)
.map_err(RandError::Cuda)?;
}
RngEngine::Xorwow | RngEngine::Mrg32k3a => {
let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi, mean, stddev);
kernel
.launch(¶ms, &self.stream, &args)
.map_err(RandError::Cuda)?;
}
}
self.stream.synchronize().map_err(RandError::Cuda)?;
Ok(())
}
fn compile_and_launch_u32(
&self,
ptx_source: &str,
kernel_name: &str,
out_ptr: u64,
n: usize,
) -> RandResult<()> {
let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
let kernel = Kernel::from_module(module, kernel_name).map_err(RandError::Cuda)?;
let n_u32 = u32::try_from(n)
.map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
let grid = grid_size_for(n_u32, 256);
let params = LaunchParams::new(grid, 256u32);
let seed_lo = self.seed as u32;
let seed_hi = (self.seed >> 32) as u32;
let offset_lo = self.offset as u32;
let offset_hi = (self.offset >> 32) as u32;
match self.engine {
RngEngine::Philox => {
let args = (out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi);
kernel
.launch(¶ms, &self.stream, &args)
.map_err(RandError::Cuda)?;
}
RngEngine::Mrg32k3a => {
let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi);
kernel
.launch(¶ms, &self.stream, &args)
.map_err(RandError::Cuda)?;
}
RngEngine::Xorwow => {
return Err(RandError::UnsupportedDistribution(
"u32 not supported for XORWOW".to_string(),
));
}
}
self.stream.synchronize().map_err(RandError::Cuda)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn engine_display() {
assert_eq!(format!("{}", RngEngine::Philox), "Philox-4x32-10");
assert_eq!(format!("{}", RngEngine::Xorwow), "XORWOW");
assert_eq!(format!("{}", RngEngine::Mrg32k3a), "MRG32k3a");
}
#[test]
fn uniform_kernel_names() {
let expected_philox_f32 = "philox_uniform_f32";
let expected_xorwow_f64 = "xorwow_uniform_f64";
let expected_mrg_f32 = "mrg32k3a_uniform_f32";
assert_eq!(expected_philox_f32, "philox_uniform_f32");
assert_eq!(expected_xorwow_f64, "xorwow_uniform_f64");
assert_eq!(expected_mrg_f32, "mrg32k3a_uniform_f32");
}
#[test]
fn ptx_generation_philox_uniform() {
let ptx = philox::generate_philox_uniform_ptx(PtxType::F32, SmVersion::Sm80);
assert!(ptx.is_ok());
}
#[test]
fn ptx_generation_xorwow_uniform() {
let ptx = xorwow::generate_xorwow_uniform_ptx(PtxType::F32, SmVersion::Sm80);
assert!(ptx.is_ok());
}
#[test]
fn ptx_generation_mrg32k3a_uniform() {
let ptx = mrg32k3a::generate_mrg32k3a_uniform_ptx(PtxType::F32, SmVersion::Sm80);
assert!(ptx.is_ok());
}
}