use crate::error::RngError;
/// Names the random-number algorithm an engine is asked to use.
///
/// On the GPU path each variant is mapped to the backend engine of the same
/// name. On the CPU path the kind is only recorded; the same CPU generator is
/// constructed regardless of the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RngEngineKind {
    /// Reported as `"philox"`.
    Philox,
    /// Reported as `"xorwow"`.
    Xorwow,
    /// Reported as `"mrg32k3a"`.
    Mrg32k3a,
}

impl RngEngineKind {
    /// Returns the lowercase, statically allocated name of this kind.
    pub fn as_str(self) -> &'static str {
        match self {
            RngEngineKind::Philox => "philox",
            RngEngineKind::Xorwow => "xorwow",
            RngEngineKind::Mrg32k3a => "mrg32k3a",
        }
    }
}

impl std::fmt::Display for RngEngineKind {
    /// Writes exactly the text returned by `as_str`; because the name is
    /// emitted with `write_str`, width and alignment flags are not applied.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = (*self).as_str();
        formatter.write_str(name)
    }
}
/// Reports whether the GPU backend should be used.
///
/// This is currently a stub that unconditionally answers `false`, so any
/// caller that gates the GPU path on this function never takes it.
#[cfg(feature = "gpu")]
fn gpu_available() -> bool {
    false
}
/// State of the CPU generator: a 64-bit linear congruential recurrence whose
/// previous state is scrambled (xorshift, then a state-dependent rotation)
/// into each 32-bit output.
#[cfg(feature = "cpu")]
struct CpuRngState {
    /// Current recurrence state; advanced once per `next_u32` call.
    state: u64,
    /// Additive constant of the recurrence; `new` forces its low bit to 1.
    inc: u64,
}

#[cfg(feature = "cpu")]
impl CpuRngState {
    /// Multiplier of the 64-bit linear congruential step.
    const PCG_MULT: u64 = 6_364_136_223_846_793_005_u64;

    /// Seeds a generator: the increment is derived from `seed` (shifted left
    /// by one with the low bit set, so it is always odd), the state starts at
    /// zero, is stepped once, has `seed` added, and is stepped once more.
    fn new(seed: u64) -> Self {
        let mut rng = Self {
            state: 0,
            inc: seed.wrapping_shl(1) | 1,
        };
        rng.next_u32();
        rng.state = rng.state.wrapping_add(seed);
        rng.next_u32();
        rng
    }

    /// Advances the state and returns 32 bits derived from the *previous* state.
    #[inline]
    fn next_u32(&mut self) -> u32 {
        let previous = self.state;
        self.state = previous
            .wrapping_mul(Self::PCG_MULT)
            .wrapping_add(self.inc);
        // The top five bits of the previous state pick the rotation amount.
        let rotation = (previous >> 59) as u32;
        let scrambled = ((previous ^ (previous >> 18)) >> 27) as u32;
        scrambled.rotate_right(rotation)
    }

    /// Returns a sample in `[0, 1)` built from the top 23 bits of one `u32` draw.
    #[inline]
    fn next_f32(&mut self) -> f32 {
        // Bit pattern of 1.0_f32; OR-ing in 23 mantissa bits gives a value in
        // [1, 2), which is then shifted down by one.
        const ONE_BITS: u32 = 0x3f80_0000_u32;
        let mantissa = self.next_u32() >> 9;
        f32::from_bits(ONE_BITS | mantissa) - 1.0_f32
    }

    /// Returns two standard-normal samples via the Box-Muller transform.
    ///
    /// Consumes two `next_f32` draws: the first (clamped up to `f32::EPSILON`
    /// so the logarithm stays finite) gives the radius, the second the angle.
    #[inline]
    fn next_normal_pair(&mut self) -> (f32, f32) {
        let radial = self.next_f32().max(f32::EPSILON);
        let angular = self.next_f32();
        let radius = (-2.0_f32 * radial.ln()).sqrt();
        let angle = std::f32::consts::TAU * angular;
        (radius * angle.cos(), radius * angle.sin())
    }

    /// Concatenates two `u32` draws; the first draw forms the high half.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        let high = u64::from(self.next_u32());
        let low = u64::from(self.next_u32());
        (high << 32) | low
    }

    /// Returns a sample in `[0, 1)` built from the top 52 bits of one `u64` draw.
    #[inline]
    fn next_f64(&mut self) -> f64 {
        // Bit pattern of 1.0_f64; same [1, 2) construction as `next_f32`.
        const ONE_BITS: u64 = 0x3FF0_0000_0000_0000_u64;
        let mantissa = self.next_u64() >> 12;
        f64::from_bits(ONE_BITS | mantissa) - 1.0_f64
    }

    /// Double-precision Box-Muller pair.
    ///
    /// Both uniforms are drawn up front; the radius uses `1 - u1`, which maps
    /// `[0, 1)` onto `(0, 1]` so the logarithm is never taken at zero.
    #[inline]
    fn next_normal_pair_f64(&mut self) -> (f64, f64) {
        let first = self.next_f64();
        let second = self.next_f64();
        let radial = if first >= 1.0 {
            f64::EPSILON
        } else {
            1.0 - first
        };
        let radius = (-2.0_f64 * radial.ln()).sqrt();
        let angle = std::f64::consts::TAU * second;
        (radius * angle.cos(), radius * angle.sin())
    }
}
#[cfg(feature = "gpu")]
use oxicuda_rand::generator::{RngEngine as OxiRngEngine, RngGenerator};
#[cfg(feature = "gpu")]
use std::sync::Arc;
/// Backend-specific generator state stored inside an `RngEngine`.
///
/// Each variant exists only when its Cargo feature is enabled, so with neither
/// `cpu` nor `gpu` enabled this enum has no variants at all.
enum RngEngineInner {
/// State of the CPU generator (requires the `cpu` feature).
#[cfg(feature = "cpu")]
Cpu(CpuRngState),
/// Wrapper around the GPU generator (requires the `gpu` feature).
#[cfg(feature = "gpu")]
Gpu(GpuRngState),
}
/// GPU-side generator state; only compiled with the `gpu` feature.
#[cfg(feature = "gpu")]
struct GpuRngState {
/// The `oxicuda_rand` generator that fills device buffers with samples.
generator: RngGenerator,
}
/// Random-number engine that fills host slices with samples, backed by either
/// the CPU generator or the GPU generator depending on how it was constructed.
pub struct RngEngine {
/// The engine kind requested at construction time; returned by `kind()`.
kind: RngEngineKind,
/// The backend state actually producing the samples.
inner: RngEngineInner,
/// Zero-sized marker: a raw-pointer `PhantomData` suppresses the automatic
/// `Send` and `Sync` impls when the `gpu` feature is enabled.
#[cfg(feature = "gpu")]
_not_sync: std::marker::PhantomData<*const ()>,
}
// SAFETY: NOTE(review) — not fully established by this file; confirm before relying on it.
// With the `cpu` feature only, every field is plain integer data, so `Send` holds trivially.
// With the `gpu` feature, the `PhantomData<*const ()>` marker removes the automatic
// `Send`/`Sync` impls and this line restores `Send` alone; that is sound only if the
// wrapped `RngGenerator` may be moved to and used from another thread.
unsafe impl Send for RngEngine {}
impl RngEngine {
/// Creates an engine of the requested `kind`, seeded with `seed`.
///
/// Backend selection, in order:
/// 1. with the `gpu` feature enabled and `gpu_available()` returning `true`,
///    the GPU backend is initialised;
/// 2. otherwise, with the `cpu` feature enabled, the CPU generator is used
///    (the `kind` is recorded but does not change the CPU algorithm);
/// 3. otherwise no backend can be constructed and an error is returned.
///
/// # Errors
///
/// Returns `RngError::GpuError` if GPU initialisation fails, or if no usable
/// backend is compiled in / available.
pub fn new(kind: RngEngineKind, seed: u64) -> Result<Self, RngError> {
    #[cfg(feature = "gpu")]
    if gpu_available() {
        return Self::new_gpu(kind, seed);
    }

    #[cfg(feature = "cpu")]
    {
        Ok(Self {
            kind,
            inner: RngEngineInner::Cpu(CpuRngState::new(seed)),
            #[cfg(feature = "gpu")]
            _not_sync: std::marker::PhantomData,
        })
    }

    // GPU-only build where the GPU turned out to be unavailable: previously
    // this configuration had no tail expression and failed to compile.
    #[cfg(all(feature = "gpu", not(feature = "cpu")))]
    {
        Err(RngError::GpuError(
            "no GPU available and the `cpu` fallback feature is not enabled".to_string(),
        ))
    }

    #[cfg(not(any(feature = "cpu", feature = "gpu")))]
    {
        // Both arguments are otherwise unused in this configuration.
        let _ = (kind, seed);
        Err(RngError::GpuError(
            "no backend compiled: enable the `cpu` or `gpu` feature".to_string(),
        ))
    }
}
/// Builds a GPU-backed engine on device 0.
///
/// Initialises the driver, opens device 0, creates a context for it, and
/// constructs the backend generator matching `kind` with the given `seed`.
///
/// # Errors
///
/// Any failure from the driver or generator is converted to
/// `RngError::GpuError` carrying the backend's error text.
#[cfg(feature = "gpu")]
fn new_gpu(kind: RngEngineKind, seed: u64) -> Result<Self, RngError> {
    use oxicuda_driver::{context::Context, Device};

    oxicuda_driver::init().map_err(|err| RngError::GpuError(err.to_string()))?;
    let device = Device::get(0).map_err(|err| RngError::GpuError(err.to_string()))?;
    let context = Context::new(&device).map_err(|err| RngError::GpuError(err.to_string()))?;
    // NOTE(review): this `Arc` is dropped when the function returns; presumably the
    // generator keeps whatever it needs from the context alive — confirm.
    let ctx = Arc::new(context);

    let backend_kind = match kind {
        RngEngineKind::Philox => OxiRngEngine::Philox,
        RngEngineKind::Xorwow => OxiRngEngine::Xorwow,
        RngEngineKind::Mrg32k3a => OxiRngEngine::Mrg32k3a,
    };
    let generator = RngGenerator::new(backend_kind, seed, &ctx)
        .map_err(|err| RngError::GpuError(err.to_string()))?;

    let inner = RngEngineInner::Gpu(GpuRngState { generator });
    Ok(Self {
        kind,
        inner,
        _not_sync: std::marker::PhantomData::<*const ()>,
    })
}
#[inline]
pub fn kind(&self) -> RngEngineKind {
self.kind
}
/// Returns `true` when samples are produced by the GPU backend and `false`
/// when they come from the CPU generator.
pub fn is_gpu(&self) -> bool {
    match self.inner {
        #[cfg(feature = "gpu")]
        RngEngineInner::Gpu(_) => true,
        #[cfg(feature = "cpu")]
        RngEngineInner::Cpu(_) => false,
    }
}
pub fn uniform_f32(&mut self, out: &mut [f32]) -> Result<(), RngError> {
if out.is_empty() {
return Err(RngError::EmptyBuffer);
}
match &mut self.inner {
#[cfg(feature = "cpu")]
RngEngineInner::Cpu(state) => {
for slot in out.iter_mut() {
*slot = state.next_f32();
}
Ok(())
}
#[cfg(feature = "gpu")]
RngEngineInner::Gpu(gs) => {
use oxicuda_memory::DeviceBuffer;
let n = out.len();
let mut dev_buf =
DeviceBuffer::<f32>::alloc(n).map_err(|e| RngError::GpuError(e.to_string()))?;
gs.generator
.generate_uniform_f32(&mut dev_buf)
.map_err(|e| RngError::GpuError(e.to_string()))?;
dev_buf
.copy_to_host(out)
.map_err(|e| RngError::GpuError(e.to_string()))?;
Ok(())
}
}
}
pub fn normal_f32(&mut self, out: &mut [f32], mean: f32, std_dev: f32) -> Result<(), RngError> {
if out.is_empty() {
return Err(RngError::EmptyBuffer);
}
if !std_dev.is_finite() || std_dev < 0.0 {
return Err(RngError::InvalidParam(format!(
"std_dev must be finite and >= 0, got {std_dev}"
)));
}
if !mean.is_finite() {
return Err(RngError::InvalidParam(format!(
"mean must be finite, got {mean}"
)));
}
match &mut self.inner {
#[cfg(feature = "cpu")]
RngEngineInner::Cpu(state) => {
let n = out.len();
let mut i = 0usize;
while i + 1 < n {
let (z0, z1) = state.next_normal_pair();
out[i] = mean + std_dev * z0;
out[i + 1] = mean + std_dev * z1;
i += 2;
}
if i < n {
let (z0, _) = state.next_normal_pair();
out[i] = mean + std_dev * z0;
}
Ok(())
}
#[cfg(feature = "gpu")]
RngEngineInner::Gpu(gs) => {
use oxicuda_memory::DeviceBuffer;
let n = out.len();
let mut dev_buf =
DeviceBuffer::<f32>::alloc(n).map_err(|e| RngError::GpuError(e.to_string()))?;
gs.generator
.generate_normal_f32(&mut dev_buf, mean, std_dev)
.map_err(|e| RngError::GpuError(e.to_string()))?;
dev_buf
.copy_to_host(out)
.map_err(|e| RngError::GpuError(e.to_string()))?;
Ok(())
}
}
}
pub fn bernoulli(&mut self, out: &mut [u8], p: f32) -> Result<(), RngError> {
if out.is_empty() {
return Err(RngError::EmptyBuffer);
}
if !p.is_finite() || !(0.0..=1.0).contains(&p) {
return Err(RngError::InvalidParam(format!(
"p must be in [0.0, 1.0], got {p}"
)));
}
match &mut self.inner {
#[cfg(feature = "cpu")]
RngEngineInner::Cpu(state) => {
for slot in out.iter_mut() {
*slot = u8::from(state.next_f32() < p);
}
Ok(())
}
#[cfg(feature = "gpu")]
RngEngineInner::Gpu(gs) => {
use oxicuda_memory::DeviceBuffer;
let n = out.len();
let mut dev_buf =
DeviceBuffer::<f32>::alloc(n).map_err(|e| RngError::GpuError(e.to_string()))?;
gs.generator
.generate_uniform_f32(&mut dev_buf)
.map_err(|e| RngError::GpuError(e.to_string()))?;
let mut host_buf = vec![0f32; n];
dev_buf
.copy_to_host(&mut host_buf)
.map_err(|e| RngError::GpuError(e.to_string()))?;
for (slot, &u) in out.iter_mut().zip(host_buf.iter()) {
*slot = u8::from(u < p);
}
Ok(())
}
}
}
/// Fills `out` with double-precision uniform samples in `[0, 1)`.
///
/// Only the CPU path is implemented; a GPU-backed engine always reports an
/// error.
///
/// # Errors
///
/// * `RngError::EmptyBuffer` if `out` is empty.
/// * `RngError::GpuError` when called on a GPU-backed engine.
pub fn uniform_f64(&mut self, out: &mut [f64]) -> Result<(), RngError> {
    if out.is_empty() {
        return Err(RngError::EmptyBuffer);
    }
    match &mut self.inner {
        #[cfg(feature = "cpu")]
        RngEngineInner::Cpu(cpu) => {
            out.fill_with(|| cpu.next_f64());
            Ok(())
        }
        #[cfg(feature = "gpu")]
        RngEngineInner::Gpu(_) => {
            let message = "uniform_f64 on GPU path not yet implemented".to_string();
            Err(RngError::GpuError(message))
        }
    }
}
/// Fills `out` with double-precision normal samples of the given `mean` and
/// `std_dev`.
///
/// Samples are produced two at a time; when `out` has odd length, the final
/// element uses the first value of one extra pair. Only the CPU path is
/// implemented; a GPU-backed engine always reports an error.
///
/// # Errors
///
/// * `RngError::EmptyBuffer` if `out` is empty.
/// * `RngError::InvalidParam` if `std_dev` is negative or not finite, or if
///   `mean` is not finite (checked in that order).
/// * `RngError::GpuError` when called on a GPU-backed engine.
pub fn normal_f64(&mut self, out: &mut [f64], mean: f64, std_dev: f64) -> Result<(), RngError> {
    if out.is_empty() {
        return Err(RngError::EmptyBuffer);
    }
    let std_dev_ok = std_dev.is_finite() && std_dev >= 0.0;
    if !std_dev_ok {
        return Err(RngError::InvalidParam(format!(
            "std_dev must be finite and >= 0, got {std_dev}"
        )));
    }
    if !mean.is_finite() {
        return Err(RngError::InvalidParam(format!(
            "mean must be finite, got {mean}"
        )));
    }
    match &mut self.inner {
        #[cfg(feature = "cpu")]
        RngEngineInner::Cpu(cpu) => {
            let mut pairs = out.chunks_exact_mut(2);
            for pair in &mut pairs {
                let (z0, z1) = cpu.next_normal_pair_f64();
                pair[0] = mean + std_dev * z0;
                pair[1] = mean + std_dev * z1;
            }
            // Odd length: one slot is left over after the full pairs.
            if let [last] = pairs.into_remainder() {
                let (z0, _) = cpu.next_normal_pair_f64();
                *last = mean + std_dev * z0;
            }
            Ok(())
        }
        #[cfg(feature = "gpu")]
        RngEngineInner::Gpu(_) => {
            let message = "normal_f64 on GPU path not yet implemented".to_string();
            Err(RngError::GpuError(message))
        }
    }
}
/// Generates `total` single-precision uniform samples and hands them to
/// `consumer` in consecutive slices of at most `chunk_size` elements.
///
/// One scratch buffer is reused for every chunk, so the slice passed to
/// `consumer` is only valid for the duration of that call. The final chunk
/// may be shorter than `chunk_size`.
///
/// # Errors
///
/// * `RngError::EmptyBuffer` if `total` or `chunk_size` is zero.
/// * Any error returned by `uniform_f32`; chunks already delivered to
///   `consumer` are not rolled back.
pub fn fill_uniform_chunked<F: FnMut(&[f32])>(
    &mut self,
    total: usize,
    chunk_size: usize,
    consumer: &mut F,
) -> Result<(), RngError> {
    if total == 0 || chunk_size == 0 {
        return Err(RngError::EmptyBuffer);
    }
    // No chunk is ever longer than `total`, so cap the scratch buffer there;
    // an oversized `chunk_size` must not force an oversized allocation.
    let mut buf = vec![0f32; chunk_size.min(total)];
    let mut remaining = total;
    while remaining > 0 {
        let n = remaining.min(buf.len());
        self.uniform_f32(&mut buf[..n])?;
        consumer(&buf[..n]);
        remaining -= n;
    }
    Ok(())
}
/// Generates `total` double-precision uniform samples and hands them to
/// `consumer` in consecutive slices of at most `chunk_size` elements.
///
/// One scratch buffer is reused for every chunk, so the slice passed to
/// `consumer` is only valid for the duration of that call. The final chunk
/// may be shorter than `chunk_size`.
///
/// # Errors
///
/// * `RngError::EmptyBuffer` if `total` or `chunk_size` is zero.
/// * Any error returned by `uniform_f64`; chunks already delivered to
///   `consumer` are not rolled back.
pub fn fill_uniform_chunked_f64<F: FnMut(&[f64])>(
    &mut self,
    total: usize,
    chunk_size: usize,
    consumer: &mut F,
) -> Result<(), RngError> {
    if total == 0 || chunk_size == 0 {
        return Err(RngError::EmptyBuffer);
    }
    // No chunk is ever longer than `total`, so cap the scratch buffer there;
    // an oversized `chunk_size` must not force an oversized allocation.
    let mut buf = vec![0f64; chunk_size.min(total)];
    let mut remaining = total;
    while remaining > 0 {
        let n = remaining.min(buf.len());
        self.uniform_f64(&mut buf[..n])?;
        consumer(&buf[..n]);
        remaining -= n;
    }
    Ok(())
}
/// Generates `total` single-precision normal samples with the given `mean`
/// and `std_dev`, handing them to `consumer` in consecutive slices of at most
/// `chunk_size` elements.
///
/// Parameters are validated before any buffer is allocated. One scratch
/// buffer is reused for every chunk, so the slice passed to `consumer` is
/// only valid for the duration of that call. Each chunk is produced by an
/// independent `normal_f32` call.
///
/// # Errors
///
/// * `RngError::EmptyBuffer` if `total` or `chunk_size` is zero.
/// * `RngError::InvalidParam` if `std_dev` is negative or not finite, or if
///   `mean` is not finite (checked in that order).
/// * Any error returned by `normal_f32`; chunks already delivered to
///   `consumer` are not rolled back.
pub fn fill_normal_chunked<F: FnMut(&[f32])>(
    &mut self,
    total: usize,
    chunk_size: usize,
    mean: f32,
    std_dev: f32,
    consumer: &mut F,
) -> Result<(), RngError> {
    if total == 0 || chunk_size == 0 {
        return Err(RngError::EmptyBuffer);
    }
    if !std_dev.is_finite() || std_dev < 0.0 {
        return Err(RngError::InvalidParam(format!(
            "std_dev must be finite and >= 0, got {std_dev}"
        )));
    }
    if !mean.is_finite() {
        return Err(RngError::InvalidParam(format!(
            "mean must be finite, got {mean}"
        )));
    }
    // No chunk is ever longer than `total`, so cap the scratch buffer there;
    // an oversized `chunk_size` must not force an oversized allocation.
    let mut buf = vec![0f32; chunk_size.min(total)];
    let mut remaining = total;
    while remaining > 0 {
        let n = remaining.min(buf.len());
        self.normal_f32(&mut buf[..n], mean, std_dev)?;
        consumer(&buf[..n]);
        remaining -= n;
    }
    Ok(())
}
}
/// Unit tests for the CPU generator internals and the public `RngEngine` API.
#[cfg(test)]
mod tests {
    use super::*;

    /// Every engine kind paired with the name it is expected to report.
    const KIND_NAMES: [(RngEngineKind, &str); 3] = [
        (RngEngineKind::Philox, "philox"),
        (RngEngineKind::Xorwow, "xorwow"),
        (RngEngineKind::Mrg32k3a, "mrg32k3a"),
    ];

    /// Builds an engine, panicking if construction fails.
    fn engine(kind: RngEngineKind, seed: u64) -> RngEngine {
        RngEngine::new(kind, seed).unwrap()
    }

    // The increment's low bit must be set for every seed.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_inc_is_odd() {
        let seeds = [0u64, 1, 42, u64::MAX, u64::MAX / 2];
        for seed in seeds {
            let inc = CpuRngState::new(seed).inc;
            assert_eq!(inc & 1, 1, "inc must be odd for seed={seed}");
        }
    }

    // Raw f32 draws stay inside the half-open unit interval.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_uniform_in_range() {
        let mut rng = CpuRngState::new(12345);
        let unit = 0.0..1.0;
        for _ in 0..10_000 {
            let v = rng.next_f32();
            assert!(
                unit.contains(&v),
                "uniform sample {v} not in [0.0, 1.0)"
            );
        }
    }

    // Two generators with the same seed replay the same stream.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_deterministic_replay() {
        let mut first = CpuRngState::new(777);
        let mut second = CpuRngState::new(777);
        for _ in 0..1000 {
            assert_eq!(first.next_u32(), second.next_u32());
        }
    }

    // Distinct seeds must not yield the same first 100 outputs.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_different_seeds_differ() {
        let first_hundred = |seed: u64| -> Vec<u32> {
            let mut rng = CpuRngState::new(seed);
            (0..100).map(|_| rng.next_u32()).collect()
        };
        let outputs_a = first_hundred(0);
        let outputs_b = first_hundred(1);
        assert_ne!(
            outputs_a, outputs_b,
            "different seeds should produce different sequences"
        );
    }

    // Box-Muller output never degenerates into NaN or infinity.
    #[test]
    #[cfg(feature = "cpu")]
    fn box_muller_pair_is_finite() {
        let mut rng = CpuRngState::new(42);
        for _ in 0..10_000 {
            let (z0, z1) = rng.next_normal_pair();
            assert!(z0.is_finite(), "z0 is not finite: {z0}");
            assert!(z1.is_finite(), "z1 is not finite: {z1}");
        }
    }

    #[test]
    fn engine_kind_as_str() {
        for (kind, expected) in KIND_NAMES {
            assert_eq!(kind.as_str(), expected);
        }
    }

    #[test]
    fn engine_kind_display() {
        for (kind, expected) in KIND_NAMES {
            assert_eq!(format!("{}", kind), expected);
        }
    }

    #[test]
    fn engine_new_returns_ok() {
        for (kind, _) in KIND_NAMES {
            let built = RngEngine::new(kind, 0);
            assert!(built.is_ok(), "construction failed for {kind}");
        }
    }

    #[test]
    fn engine_kind_accessor() {
        let requested = RngEngineKind::Mrg32k3a;
        let eng = engine(requested, 1);
        assert_eq!(eng.kind(), requested);
    }

    #[test]
    fn engine_is_not_gpu_in_ci() {
        let on_gpu = engine(RngEngineKind::Philox, 42).is_gpu();
        assert!(!on_gpu);
    }

    #[test]
    fn uniform_empty_buffer_error() {
        let mut eng = engine(RngEngineKind::Philox, 0);
        let mut out = Vec::<f32>::new();
        let outcome = eng.uniform_f32(&mut out);
        assert!(matches!(outcome, Err(RngError::EmptyBuffer)));
    }

    #[test]
    fn normal_empty_buffer_error() {
        let mut eng = engine(RngEngineKind::Philox, 0);
        let mut out = Vec::<f32>::new();
        let outcome = eng.normal_f32(&mut out, 0.0, 1.0);
        assert!(matches!(outcome, Err(RngError::EmptyBuffer)));
    }

    #[test]
    fn bernoulli_empty_buffer_error() {
        let mut eng = engine(RngEngineKind::Philox, 0);
        let mut out = Vec::<u8>::new();
        let outcome = eng.bernoulli(&mut out, 0.5);
        assert!(matches!(outcome, Err(RngError::EmptyBuffer)));
    }

    #[test]
    fn normal_negative_stddev_error() {
        let mut eng = engine(RngEngineKind::Philox, 0);
        let mut out = vec![0f32; 10];
        let outcome = eng.normal_f32(&mut out, 0.0, -1.0);
        assert!(matches!(outcome, Err(RngError::InvalidParam(_))));
    }

    #[test]
    fn normal_nan_mean_error() {
        let mut eng = engine(RngEngineKind::Philox, 0);
        let mut out = vec![0f32; 10];
        let outcome = eng.normal_f32(&mut out, f32::NAN, 1.0);
        assert!(matches!(outcome, Err(RngError::InvalidParam(_))));
    }

    // Below range, above range, and NaN are all rejected.
    #[test]
    fn bernoulli_invalid_p_error() {
        let mut eng = engine(RngEngineKind::Philox, 0);
        let mut out = vec![0u8; 10];
        for bad_p in [-0.1_f32, 1.1, f32::NAN] {
            let outcome = eng.bernoulli(&mut out, bad_p);
            assert!(matches!(outcome, Err(RngError::InvalidParam(_))));
        }
    }

    #[test]
    fn uniform_in_range() {
        let mut eng = engine(RngEngineKind::Philox, 42);
        let mut out = vec![0f32; 1_000];
        eng.uniform_f32(&mut out).unwrap();
        for v in out {
            assert!((0.0..1.0).contains(&v), "uniform sample {v} out of [0,1)");
        }
    }

    // The buffer starts as NaN so any slot left unwritten is detected.
    #[test]
    fn normal_odd_length_fills_all_elements() {
        let mut eng = engine(RngEngineKind::Xorwow, 99);
        let mut out = vec![f32::NAN; 7];
        eng.normal_f32(&mut out, 0.0, 1.0).unwrap();
        for (i, v) in out.into_iter().enumerate() {
            assert!(v.is_finite(), "element {i} is not finite: {v}");
        }
    }

    // The buffer starts as 255 so any slot left unwritten is detected.
    #[test]
    fn bernoulli_outputs_only_zero_or_one() {
        let mut eng = engine(RngEngineKind::Mrg32k3a, 555);
        let mut out = vec![255u8; 1_000];
        eng.bernoulli(&mut out, 0.5).unwrap();
        for b in out {
            assert!(matches!(b, 0 | 1), "bernoulli output {b} is not 0 or 1");
        }
    }

    #[test]
    fn bernoulli_p_zero_produces_all_zeros() {
        let mut eng = engine(RngEngineKind::Philox, 1);
        let mut out = vec![1u8; 500];
        eng.bernoulli(&mut out, 0.0).unwrap();
        assert!(out.into_iter().all(|b| b == 0));
    }

    #[test]
    fn bernoulli_p_one_produces_all_ones() {
        let mut eng = engine(RngEngineKind::Philox, 2);
        let mut out = vec![0u8; 500];
        eng.bernoulli(&mut out, 1.0).unwrap();
        assert!(out.into_iter().all(|b| b == 1));
    }
}
/// Compile-time check, active only without the `gpu` feature, that
/// `RngEngine` is both `Send` and `Sync`. Nothing here is ever executed; the
/// build fails if either bound stops holding.
#[cfg(not(feature = "gpu"))]
mod send_sync_assertions {
    use super::RngEngine;

    /// Type-checks only when `T` satisfies both auto traits.
    fn _assert_send_and_sync<T: Send + Sync>() {}

    fn _check_rng_engine_send_sync() {
        _assert_send_and_sync::<RngEngine>();
    }
}