use ferray_core::{Array, FerrayError, IxDyn};
use crate::bitgen::BitGenerator;
use crate::generator::{
Generator, generate_vec, generate_vec_f32, generate_vec_i64, shape_size, vec_to_array_f32,
vec_to_array_f64, vec_to_array_i64,
};
use crate::shape::IntoShape;
impl<B: BitGenerator> Generator<B> {
pub fn random(&mut self, size: impl IntoShape) -> Result<Array<f64, IxDyn>, FerrayError> {
let shape = size.into_shape()?;
let n = shape_size(&shape);
let data = generate_vec(self, n, super::super::bitgen::BitGenerator::next_f64);
vec_to_array_f64(data, &shape)
}
pub fn uniform(
&mut self,
low: f64,
high: f64,
size: impl IntoShape,
) -> Result<Array<f64, IxDyn>, FerrayError> {
if low >= high {
return Err(FerrayError::invalid_value(format!(
"low ({low}) must be less than high ({high})"
)));
}
let shape = size.into_shape()?;
let n = shape_size(&shape);
let range = high - low;
let data = generate_vec(self, n, |bg| bg.next_f64().mul_add(range, low));
vec_to_array_f64(data, &shape)
}
pub fn random_f32(&mut self, size: impl IntoShape) -> Result<Array<f32, IxDyn>, FerrayError> {
let shape = size.into_shape()?;
let n = shape_size(&shape);
let data = generate_vec_f32(self, n, super::super::bitgen::BitGenerator::next_f32);
vec_to_array_f32(data, &shape)
}
pub fn uniform_f32(
&mut self,
low: f32,
high: f32,
size: impl IntoShape,
) -> Result<Array<f32, IxDyn>, FerrayError> {
if low >= high {
return Err(FerrayError::invalid_value(format!(
"low ({low}) must be less than high ({high})"
)));
}
let shape = size.into_shape()?;
let n = shape_size(&shape);
let range = high - low;
let data = generate_vec_f32(self, n, |bg| bg.next_f32().mul_add(range, low));
vec_to_array_f32(data, &shape)
}
pub fn integers(
&mut self,
low: i64,
high: i64,
size: impl IntoShape,
) -> Result<Array<i64, IxDyn>, FerrayError> {
if low >= high {
return Err(FerrayError::invalid_value(format!(
"low ({low}) must be less than high ({high})"
)));
}
let shape = size.into_shape()?;
let n = shape_size(&shape);
let range = (high - low) as u64;
let data = generate_vec_i64(self, n, |bg| low + bg.next_u64_bounded(range) as i64);
vec_to_array_i64(data, &shape)
}
}
#[cfg(test)]
mod tests {
use crate::default_rng_seeded;
#[test]
fn random_in_range() {
let mut rng = default_rng_seeded(42);
let arr = rng.random(10_000).unwrap();
let slice = arr.as_slice().unwrap();
for &v in slice {
assert!((0.0..1.0).contains(&v));
}
}
#[test]
fn random_deterministic() {
let mut rng1 = default_rng_seeded(42);
let mut rng2 = default_rng_seeded(42);
let a = rng1.random(100).unwrap();
let b = rng2.random(100).unwrap();
assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
}
#[test]
fn uniform_in_range() {
let mut rng = default_rng_seeded(42);
let arr = rng.uniform(5.0, 10.0, 10_000).unwrap();
let slice = arr.as_slice().unwrap();
for &v in slice {
assert!((5.0..10.0).contains(&v), "value {v} out of range");
}
}
#[test]
fn uniform_bad_range() {
let mut rng = default_rng_seeded(42);
assert!(rng.uniform(10.0, 5.0, 100).is_err());
assert!(rng.uniform(5.0, 5.0, 100).is_err());
}
#[test]
fn integers_in_range() {
let mut rng = default_rng_seeded(42);
let arr = rng.integers(0, 10, 10_000).unwrap();
let slice = arr.as_slice().unwrap();
for &v in slice {
assert!((0..10).contains(&v), "value {v} out of range");
}
}
#[test]
fn integers_negative_range() {
let mut rng = default_rng_seeded(42);
let arr = rng.integers(-5, 5, 1000).unwrap();
let slice = arr.as_slice().unwrap();
for &v in slice {
assert!((-5..5).contains(&v), "value {v} out of range");
}
}
#[test]
fn integers_bad_range() {
let mut rng = default_rng_seeded(42);
assert!(rng.integers(10, 5, 100).is_err());
}
#[test]
fn uniform_mean_variance() {
let mut rng = default_rng_seeded(42);
let n = 100_000;
let arr = rng.uniform(2.0, 8.0, n).unwrap();
let slice = arr.as_slice().unwrap();
let mean: f64 = slice.iter().sum::<f64>() / n as f64;
let var: f64 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
let expected_mean = 5.0;
let expected_var = 3.0;
let se_mean = (expected_var / n as f64).sqrt();
assert!(
(mean - expected_mean).abs() < 3.0 * se_mean,
"mean {mean} too far from {expected_mean}"
);
assert!(
(var - expected_var).abs() < 0.1,
"variance {var} too far from {expected_var}"
);
}
#[test]
fn reproducibility_golden_values() {
let mut rng = default_rng_seeded(42);
let arr = rng.random(5).unwrap();
let vals = arr.as_slice().unwrap();
let golden = [vals[0], vals[1], vals[2], vals[3], vals[4]];
let mut rng2 = default_rng_seeded(42);
let arr2 = rng2.random(5).unwrap();
let vals2 = arr2.as_slice().unwrap();
for i in 0..5 {
assert_eq!(
vals2[i].to_bits(),
golden[i].to_bits(),
"golden value mismatch at index {i}"
);
}
}
#[test]
fn different_seeds_different_values() {
let mut rng1 = default_rng_seeded(42);
let mut rng2 = default_rng_seeded(123);
let a = rng1.random(100).unwrap();
let b = rng2.random(100).unwrap();
let diffs = a
.as_slice()
.unwrap()
.iter()
.zip(b.as_slice().unwrap().iter())
.filter(|(x, y)| x != y)
.count();
assert!(diffs > 50, "seeds 42 and 123 produced too-similar output");
}
#[test]
fn random_nd_shape_from_array() {
let mut rng = default_rng_seeded(42);
let arr = rng.random([3, 4]).unwrap();
assert_eq!(arr.shape(), &[3, 4]);
assert_eq!(arr.size(), 12);
}
#[test]
fn random_nd_shape_from_slice() {
let mut rng = default_rng_seeded(42);
let shape: &[usize] = &[2, 3, 4];
let arr = rng.random(shape).unwrap();
assert_eq!(arr.shape(), &[2, 3, 4]);
assert_eq!(arr.size(), 24);
}
#[test]
fn random_nd_shape_from_vec() {
let mut rng = default_rng_seeded(42);
let shape = vec![5, 5];
let arr = rng.random(shape).unwrap();
assert_eq!(arr.shape(), &[5, 5]);
}
#[test]
fn random_nd_zero_axis_returns_empty() {
let mut rng = default_rng_seeded(42);
let a = rng.random([3, 0]).unwrap();
assert_eq!(a.shape(), &[3, 0]);
assert_eq!(a.size(), 0);
let b = rng.random(0usize).unwrap();
assert_eq!(b.shape(), &[0]);
assert_eq!(b.size(), 0);
}
#[test]
fn random_nd_equivalent_to_reshape() {
let mut rng1 = default_rng_seeded(42);
let mut rng2 = default_rng_seeded(42);
let a = rng1.random(12).unwrap();
let b = rng2.random([3, 4]).unwrap();
assert_eq!(a.size(), b.size());
let a_data: Vec<f64> = a.iter().copied().collect();
let b_data: Vec<f64> = b.iter().copied().collect();
assert_eq!(a_data, b_data);
}
#[test]
fn uniform_nd_shape() {
let mut rng = default_rng_seeded(42);
let arr = rng.uniform(0.0, 10.0, [2, 5]).unwrap();
assert_eq!(arr.shape(), &[2, 5]);
for &v in arr.iter() {
assert!((0.0..10.0).contains(&v));
}
}
#[test]
fn integers_nd_shape() {
let mut rng = default_rng_seeded(42);
let arr = rng.integers(0, 100, [4, 3]).unwrap();
assert_eq!(arr.shape(), &[4, 3]);
for &v in arr.iter() {
assert!((0..100).contains(&v));
}
}
#[test]
fn random_f32_in_range() {
let mut rng = default_rng_seeded(42);
let arr = rng.random_f32(10_000).unwrap();
for &v in arr.as_slice().unwrap() {
assert!((0.0..1.0).contains(&v), "f32 value out of range: {v}");
}
}
#[test]
fn random_f32_deterministic() {
let mut rng1 = default_rng_seeded(42);
let mut rng2 = default_rng_seeded(42);
let a = rng1.random_f32(100).unwrap();
let b = rng2.random_f32(100).unwrap();
assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
}
#[test]
fn random_f32_nd_shape() {
let mut rng = default_rng_seeded(42);
let arr = rng.random_f32([3, 4]).unwrap();
assert_eq!(arr.shape(), &[3, 4]);
}
#[test]
fn random_f32_mean() {
let mut rng = default_rng_seeded(42);
let n = 100_000usize;
let arr = rng.random_f32(n).unwrap();
let sum: f64 = arr.as_slice().unwrap().iter().map(|&v| v as f64).sum();
let mean = sum / n as f64;
assert!(
(mean - 0.5).abs() < 0.01,
"f32 random mean {mean} too far from 0.5"
);
}
#[test]
fn uniform_f32_in_range() {
let mut rng = default_rng_seeded(42);
let arr = rng.uniform_f32(5.0, 10.0, 10_000).unwrap();
for &v in arr.as_slice().unwrap() {
assert!(
(5.0..10.0).contains(&v),
"f32 uniform value out of range: {v}"
);
}
}
#[test]
fn uniform_f32_bad_range() {
let mut rng = default_rng_seeded(42);
assert!(rng.uniform_f32(10.0, 5.0, 100).is_err());
assert!(rng.uniform_f32(5.0, 5.0, 100).is_err());
}
#[test]
fn uniform_f32_nd_shape() {
let mut rng = default_rng_seeded(42);
let arr = rng.uniform_f32(-1.0, 1.0, [2, 5]).unwrap();
assert_eq!(arr.shape(), &[2, 5]);
for &v in arr.iter() {
assert!((-1.0..1.0).contains(&v));
}
}
#[test]
fn random_f32_zero_axis_returns_empty() {
let mut rng = default_rng_seeded(42);
let a = rng.random_f32([3, 0]).unwrap();
assert_eq!(a.shape(), &[3, 0]);
assert_eq!(a.size(), 0);
}
}