use ferray_core::{Array, FerrayError, IxDyn};
use crate::bitgen::{BitGenerator, Xoshiro256StarStar};
/// Random number generator that draws all of its values from a [`BitGenerator`].
///
/// When the type parameter is omitted, the bit generator is [`Xoshiro256StarStar`].
pub struct Generator<B: BitGenerator = Xoshiro256StarStar> {
    /// Underlying source of raw random bits; every draw goes through it.
    pub(crate) bg: B,
    /// Seed recorded at construction. It is `0` for generators built with
    /// `Generator::new`. `spawn_generators` passes it to `B::stream` when the
    /// bit generator cannot jump ahead.
    pub(crate) seed: u64,
}
impl<B: BitGenerator> Generator<B> {
    /// Wraps an already-constructed bit generator.
    ///
    /// The original seed of `bg` is not known here, so the recorded seed is `0`.
    pub const fn new(bg: B) -> Self {
        Self::new_with_seed(bg, 0)
    }

    /// Wraps `bg` and records the `seed` it was created from.
    pub(crate) const fn new_with_seed(bg: B, seed: u64) -> Self {
        Self { bg, seed }
    }

    /// Mutable access to the underlying bit generator.
    #[inline]
    pub const fn bit_generator(&mut self) -> &mut B {
        &mut self.bg
    }

    /// Next raw 64-bit value from the bit generator.
    #[inline]
    pub fn next_u64(&mut self) -> u64 {
        self.bg.next_u64()
    }

    /// Next `f64`, as produced by the bit generator's `next_f64`.
    #[inline]
    pub fn next_f64(&mut self) -> f64 {
        self.bg.next_f64()
    }

    /// Next `f32`, as produced by the bit generator's `next_f32`.
    #[inline]
    pub fn next_f32(&mut self) -> f32 {
        self.bg.next_f32()
    }

    /// Next value below `bound`, delegated to the bit generator's `next_u64_bounded`.
    #[inline]
    pub fn next_u64_bounded(&mut self, bound: u64) -> u64 {
        self.bg.next_u64_bounded(bound)
    }

    /// Returns `n` random bytes.
    ///
    /// Each group of up to 8 bytes comes from one `next_u64` call, laid out in
    /// little-endian order. A trailing partial group keeps only the low bytes of
    /// its word, so exactly `n.div_ceil(8)` words are consumed.
    pub fn bytes(&mut self, n: usize) -> Vec<u8> {
        let mut buf = vec![0u8; n];
        for chunk in buf.chunks_mut(8) {
            let word = self.bg.next_u64().to_le_bytes();
            chunk.copy_from_slice(&word[..chunk.len()]);
        }
        buf
    }
}
/// Builds a [`Generator`] seeded from operating-system entropy.
///
/// Eight bytes are requested from `getrandom::fill`. If that call fails, a
/// best-effort seed is used instead. It is built by XOR-folding three values:
/// the low 64 bits of the wall-clock nanoseconds since the Unix epoch, the high
/// 64 bits of the same value, and the address of a stack local. The resulting
/// seed is passed to [`default_rng_seeded`].
#[must_use]
pub fn default_rng() -> Generator<Xoshiro256StarStar> {
    let mut entropy = [0u8; 8];
    let seed = match getrandom::fill(&mut entropy) {
        Ok(_) => u64::from_ne_bytes(entropy),
        Err(_) => {
            use std::time::SystemTime;
            // A clock set before the epoch yields a zero duration instead of an error.
            let nanos = SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos();
            let anchor: u8 = 0;
            let anchor_addr = &raw const anchor;
            let low = nanos as u64;
            let high = (nanos >> 64) as u64;
            low ^ high ^ anchor_addr as u64
        }
    };
    default_rng_seeded(seed)
}
/// Builds a deterministic [`Generator`] from `seed`.
///
/// The same `seed` always produces the same sequence of draws. The seed is
/// also recorded on the returned generator.
#[must_use]
pub fn default_rng_seeded(seed: u64) -> Generator<Xoshiro256StarStar> {
    Generator::new_with_seed(Xoshiro256StarStar::seed_from_u64(seed), seed)
}
/// Creates `n` child generators derived from `parent`.
///
/// Three strategies are tried in order:
/// 1. **Jump-ahead.** Used if `jump()` returns `Some`. Child `i` starts at the
///    parent's state advanced by `i` jumps, and the parent is left `n` jumps ahead.
/// 2. **Streams.** Used if `B::stream(parent.seed, i)` returns `Some` for every
///    `i` in `0..n`. Child `i` is stream `i`. The parent is not modified.
/// 3. **Reseeding.** The fallback. Each child is seeded from one `next_u64` draw
///    of the parent, so the parent advances by `n` draws.
///
/// # Errors
/// Returns an invalid-value error when `n == 0`.
pub fn spawn_generators<B: BitGenerator + Clone>(
    parent: &mut Generator<B>,
    n: usize,
) -> Result<Vec<Generator<B>>, FerrayError> {
    if n == 0 {
        return Err(FerrayError::invalid_value("spawn count must be > 0"));
    }

    // Strategy 1: jump-ahead. After a successful probe jump, `probe` sits one
    // jump past the parent, which is exactly where the second child starts.
    // Reuse it rather than cloning and re-jumping from scratch.
    let mut probe = parent.bg.clone();
    if probe.jump().is_some() {
        let mut children = Vec::with_capacity(n);
        children.push(Generator::new(parent.bg.clone()));
        let mut current = probe;
        for _ in 1..n {
            children.push(Generator::new(current.clone()));
            current.jump();
        }
        // Leave the parent past every child's sub-sequence.
        parent.bg = current;
        return Ok(children);
    }

    // Strategy 2: independent streams keyed by the parent's recorded seed.
    // Collecting into `Option` stops at the first unsupported stream.
    // NOTE(review): the parent is not advanced here, so repeated spawns from the
    // same parent yield identical children. The children are also built with
    // `Generator::new` (seed 0), so spawning from any of them reuses the seed-0
    // streams. Confirm whether this is intended.
    let seed = parent.seed;
    let streams: Option<Vec<Generator<B>>> = (0..n)
        .map(|i| B::stream(seed, i as u64).map(Generator::new))
        .collect();
    if let Some(children) = streams {
        return Ok(children);
    }

    // Strategy 3: reseed each child from a fresh parent draw, in order.
    let children = (0..n)
        .map(|_| Generator::new(B::seed_from_u64(parent.bg.next_u64())))
        .collect();
    Ok(children)
}
/// Draws `size` `f64` values by calling `f` on the generator's bit source, in order.
pub(crate) fn generate_vec<B: BitGenerator>(
    rng: &mut Generator<B>,
    size: usize,
    mut f: impl FnMut(&mut B) -> f64,
) -> Vec<f64> {
    (0..size).map(|_| f(&mut rng.bg)).collect()
}
/// Draws `size` `f32` values by calling `f` on the generator's bit source, in order.
pub(crate) fn generate_vec_f32<B: BitGenerator>(
    rng: &mut Generator<B>,
    size: usize,
    mut f: impl FnMut(&mut B) -> f32,
) -> Vec<f32> {
    let mut values = Vec::with_capacity(size);
    values.extend((0..size).map(|_| f(&mut rng.bg)));
    values
}
/// Draws `size` `i64` values by calling `f` on the generator's bit source, in order.
pub(crate) fn generate_vec_i64<B: BitGenerator>(
    rng: &mut Generator<B>,
    size: usize,
    mut f: impl FnMut(&mut B) -> i64,
) -> Vec<i64> {
    let mut out = Vec::with_capacity(size);
    out.resize_with(size, || f(&mut rng.bg));
    out
}
/// Number of elements described by `shape`.
///
/// This is the product of the dimensions, except that an empty shape yields
/// `0` rather than the empty product `1`.
#[inline]
pub(crate) fn shape_size(shape: &[usize]) -> usize {
    match shape {
        [] => 0,
        dims => dims.iter().product(),
    }
}
/// Wraps `data` in a dynamic-rank `f64` array with the given `shape`.
///
/// Any error from `Array::from_vec`, such as a shape/length mismatch, is
/// passed through unchanged.
pub(crate) fn vec_to_array_f64(
    data: Vec<f64>,
    shape: &[usize],
) -> Result<Array<f64, IxDyn>, FerrayError> {
    let dim = IxDyn::new(shape);
    Array::<f64, IxDyn>::from_vec(dim, data)
}
/// Wraps `data` in a dynamic-rank `f32` array with the given `shape`.
///
/// Any error from `Array::from_vec`, such as a shape/length mismatch, is
/// passed through unchanged.
pub(crate) fn vec_to_array_f32(
    data: Vec<f32>,
    shape: &[usize],
) -> Result<Array<f32, IxDyn>, FerrayError> {
    let dim = IxDyn::new(shape);
    Array::<f32, IxDyn>::from_vec(dim, data)
}
/// Wraps `data` in a dynamic-rank `i64` array with the given `shape`.
///
/// Any error from `Array::from_vec`, such as a shape/length mismatch, is
/// passed through unchanged.
pub(crate) fn vec_to_array_i64(
    data: Vec<i64>,
    shape: &[usize],
) -> Result<Array<i64, IxDyn>, FerrayError> {
    let dim = IxDyn::new(shape);
    Array::<i64, IxDyn>::from_vec(dim, data)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_rng_seeded_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        for _ in 0..100 {
            assert_eq!(rng1.next_u64(), rng2.next_u64());
        }
    }

    #[test]
    fn default_rng_works() {
        let mut rng = default_rng();
        let v = rng.next_f64();
        assert!((0.0..1.0).contains(&v));
    }

    #[test]
    fn spawn_xoshiro() {
        let mut parent = default_rng_seeded(42);
        let children = spawn_generators(&mut parent, 4).unwrap();
        assert_eq!(children.len(), 4);
    }

    /// Children must not replay each other's sequence.
    #[test]
    fn spawn_children_are_distinct() {
        let mut parent = default_rng_seeded(42);
        let mut children = spawn_generators(&mut parent, 4).unwrap();
        let firsts: Vec<u64> = children.iter_mut().map(|c| c.next_u64()).collect();
        for (i, a) in firsts.iter().enumerate() {
            for b in &firsts[i + 1..] {
                assert_ne!(a, b);
            }
        }
    }

    #[test]
    fn spawn_zero_is_error() {
        let mut parent = default_rng_seeded(42);
        assert!(spawn_generators(&mut parent, 0).is_err());
    }

    #[test]
    fn bytes_length_zero() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.bytes(0).is_empty());
    }

    #[test]
    fn bytes_length_full_word() {
        let mut rng = default_rng_seeded(42);
        let b = rng.bytes(8);
        assert_eq!(b.len(), 8);
    }

    #[test]
    fn bytes_length_partial_word() {
        let mut rng = default_rng_seeded(42);
        let b = rng.bytes(13);
        assert_eq!(b.len(), 13);
    }

    #[test]
    fn bytes_deterministic_for_same_seed() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        assert_eq!(rng1.bytes(64), rng2.bytes(64));
    }

    /// `bytes` is the little-endian concatenation of successive `next_u64`
    /// words, with the last word truncated to the remaining length.
    #[test]
    fn bytes_match_little_endian_words() {
        let mut rng = default_rng_seeded(7);
        let mut reference = default_rng_seeded(7);
        let b = rng.bytes(13);
        let w0 = reference.next_u64().to_le_bytes();
        let w1 = reference.next_u64().to_le_bytes();
        assert_eq!(&b[..8], &w0[..]);
        assert_eq!(&b[8..], &w1[..5]);
    }

    #[test]
    fn shape_size_edge_cases() {
        assert_eq!(shape_size(&[]), 0);
        assert_eq!(shape_size(&[2, 3]), 6);
        assert_eq!(shape_size(&[4, 0]), 0);
    }
}