use ferray_core::{Array, FerrayError, IxDyn};
use crate::bitgen::BitGenerator;
use crate::distributions::ziggurat::{standard_normal_ziggurat, standard_normal_ziggurat_f32};
use crate::generator::{
Generator, generate_vec, generate_vec_f32, shape_size, vec_to_array_f32, vec_to_array_f64,
};
use crate::shape::IntoShape;
/// Returns one `f64` sample from the standard normal distribution N(0, 1).
///
/// Thin wrapper over the ziggurat sampler, so that a single named function can be handed to the
/// array-filling helpers as the per-element sampler.
pub(crate) fn standard_normal_single<B: BitGenerator>(bit_gen: &mut B) -> f64 {
    standard_normal_ziggurat(bit_gen)
}
/// Returns one `f32` sample from the standard normal distribution N(0, 1).
///
/// Single-precision counterpart of `standard_normal_single`; delegates to the `f32` ziggurat
/// sampler.
pub(crate) fn standard_normal_single_f32<B: BitGenerator>(bit_gen: &mut B) -> f32 {
    standard_normal_ziggurat_f32(bit_gen)
}
impl<B: BitGenerator> Generator<B> {
    /// Draws `f64` samples from the standard normal distribution N(0, 1).
    ///
    /// The returned array has the shape described by `size`; every element is produced by
    /// `standard_normal_single`.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while converting `size` into a shape or while building the
    /// output array.
    pub fn standard_normal(
        &mut self,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        let shape = size.into_shape()?;
        let n = shape_size(&shape);
        let data = generate_vec(self, n, standard_normal_single);
        vec_to_array_f64(data, &shape)
    }

    /// Draws `f32` samples from the standard normal distribution N(0, 1).
    ///
    /// Single-precision counterpart of [`Self::standard_normal`]; every element is produced by
    /// `standard_normal_single_f32`.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while converting `size` into a shape or while building the
    /// output array.
    pub fn standard_normal_f32(
        &mut self,
        size: impl IntoShape,
    ) -> Result<Array<f32, IxDyn>, FerrayError> {
        let shape = size.into_shape()?;
        let n = shape_size(&shape);
        let data = generate_vec_f32(self, n, standard_normal_single_f32);
        vec_to_array_f32(data, &shape)
    }

    /// Draws `f64` samples from a normal distribution with mean `loc` and standard deviation
    /// `scale`.
    ///
    /// Each element is `scale * z + loc`, evaluated as a fused multiply-add, where `z` is a
    /// standard-normal draw. `loc` is not validated, so a NaN `loc` yields NaN samples. An
    /// infinite `scale` is accepted and yields non-finite samples.
    ///
    /// # Errors
    ///
    /// Returns an invalid-value error if `scale` is zero, negative, or NaN. Also propagates
    /// errors from shape conversion and array construction.
    pub fn normal(
        &mut self,
        loc: f64,
        scale: f64,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        require_positive("scale", scale)?;
        let shape = size.into_shape()?;
        let n = shape_size(&shape);
        let data = generate_vec(self, n, |bg| scale.mul_add(standard_normal_single(bg), loc));
        vec_to_array_f64(data, &shape)
    }

    /// Draws `f32` samples from a normal distribution with mean `loc` and standard deviation
    /// `scale`.
    ///
    /// Single-precision counterpart of [`Self::normal`]: each element is
    /// `scale.mul_add(z, loc)` with `z` a single-precision standard-normal draw.
    ///
    /// # Errors
    ///
    /// Returns an invalid-value error if `scale` is zero, negative, or NaN. Also propagates
    /// errors from shape conversion and array construction.
    pub fn normal_f32(
        &mut self,
        loc: f32,
        scale: f32,
        size: impl IntoShape,
    ) -> Result<Array<f32, IxDyn>, FerrayError> {
        require_positive("scale", scale)?;
        let shape = size.into_shape()?;
        let n = shape_size(&shape);
        let data = generate_vec_f32(self, n, |bg| {
            scale.mul_add(standard_normal_single_f32(bg), loc)
        });
        vec_to_array_f32(data, &shape)
    }

    /// Draws `f64` samples from a log-normal distribution.
    ///
    /// Each element is `exp(sigma * z + mean)` with `z` a standard-normal draw. `mean` and
    /// `sigma` are therefore the mean and standard deviation of the underlying normal
    /// distribution, not of the log-normal samples themselves.
    ///
    /// # Errors
    ///
    /// Returns an invalid-value error if `sigma` is zero, negative, or NaN. Also propagates
    /// errors from shape conversion and array construction.
    pub fn lognormal(
        &mut self,
        mean: f64,
        sigma: f64,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        require_positive("sigma", sigma)?;
        let shape = size.into_shape()?;
        let n = shape_size(&shape);
        let data = generate_vec(self, n, |bg| {
            sigma.mul_add(standard_normal_single(bg), mean).exp()
        });
        vec_to_array_f64(data, &shape)
    }

    /// Draws `f32` samples from a log-normal distribution.
    ///
    /// Single-precision counterpart of [`Self::lognormal`]: each element is
    /// `exp(sigma * z + mean)` with `z` a single-precision standard-normal draw.
    ///
    /// # Errors
    ///
    /// Returns an invalid-value error if `sigma` is zero, negative, or NaN. Also propagates
    /// errors from shape conversion and array construction.
    pub fn lognormal_f32(
        &mut self,
        mean: f32,
        sigma: f32,
        size: impl IntoShape,
    ) -> Result<Array<f32, IxDyn>, FerrayError> {
        require_positive("sigma", sigma)?;
        let shape = size.into_shape()?;
        let n = shape_size(&shape);
        let data = generate_vec_f32(self, n, |bg| {
            sigma.mul_add(standard_normal_single_f32(bg), mean).exp()
        });
        vec_to_array_f32(data, &shape)
    }
}

/// Checks that a distribution parameter is strictly positive.
///
/// Returns an invalid-value error of the form `"{name} must be positive, got {value}"` when
/// `value` is zero, negative, or NaN. Only used with `f32`/`f64`, whose `Default` is `0.0`.
fn require_positive<T>(name: &str, value: T) -> Result<(), FerrayError>
where
    T: PartialOrd + Default + std::fmt::Display,
{
    // Phrase the check as `value > 0` rather than `value <= 0`. Every comparison involving
    // NaN is false, so NaN lands in the error branch instead of slipping through.
    if value > T::default() {
        Ok(())
    } else {
        Err(FerrayError::invalid_value(format!(
            "{name} must be positive, got {value}"
        )))
    }
}
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Population mean and variance (divisor `len`, not `len - 1`) of `xs`.
    fn mean_and_var(xs: &[f64]) -> (f64, f64) {
        let n = xs.len() as f64;
        let mean = xs.iter().sum::<f64>() / n;
        let var = xs.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
        (mean, var)
    }

    /// Widens `f32` samples to `f64` so the statistics accumulate at full precision.
    fn widen(xs: &[f32]) -> Vec<f64> {
        xs.iter().copied().map(f64::from).collect()
    }

    // ---- f64: standard normal / normal -------------------------------------------------

    #[test]
    fn standard_normal_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.standard_normal(1000).unwrap();
        let b = rng2.standard_normal(1000).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn standard_normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let arr = rng.standard_normal(n).unwrap();
        let (mean, var) = mean_and_var(arr.as_slice().unwrap());
        let se = (1.0 / n as f64).sqrt();
        assert!(mean.abs() < 3.0 * se, "mean {mean} too far from 0");
        assert!((var - 1.0).abs() < 0.05, "variance {var} too far from 1");
    }

    #[test]
    fn normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let (loc, scale) = (5.0, 2.0);
        let arr = rng.normal(loc, scale, n).unwrap();
        let (mean, var) = mean_and_var(arr.as_slice().unwrap());
        let se = (scale * scale / n as f64).sqrt();
        assert!(
            (mean - loc).abs() < 3.0 * se,
            "mean {mean} too far from {loc}"
        );
        assert!(
            (var - scale * scale).abs() < 0.2,
            "variance {var} too far from {}",
            scale * scale
        );
    }

    #[test]
    fn normal_bad_scale() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.normal(0.0, 0.0, 100).is_err());
        assert!(rng.normal(0.0, -1.0, 100).is_err());
    }

    #[test]
    fn normal_nan_loc_produces_nan_output() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.normal(f64::NAN, 1.0, 5).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v.is_nan(), "expected NaN, got {v}");
        }
    }

    #[test]
    fn normal_inf_scale_produces_inf_output() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.normal(0.0, f64::INFINITY, 5).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v.is_infinite() || v.is_nan(), "expected Inf/NaN, got {v}");
        }
    }

    #[test]
    fn normal_nan_scale_never_yields_ordinary_samples() {
        // NaN is not a positive scale. Whether it is reported as an error or propagated into
        // the output, it must never silently produce ordinary-looking numbers.
        let mut rng = default_rng_seeded(42);
        if let Ok(arr) = rng.normal(0.0, f64::NAN, 5) {
            for &v in arr.as_slice().unwrap() {
                assert!(v.is_nan(), "NaN scale produced {v}");
            }
        }
    }

    #[test]
    fn standard_normal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal([3, 4]).unwrap();
        assert_eq!(arr.shape(), &[3, 4]);
    }

    #[test]
    fn normal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.normal(10.0, 2.0, [2, 3, 4]).unwrap();
        assert_eq!(arr.shape(), &[2, 3, 4]);
    }

    // ---- f64: lognormal ----------------------------------------------------------------

    #[test]
    fn lognormal_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal(0.0, 1.0, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v > 0.0, "lognormal produced non-positive value: {v}");
        }
    }

    #[test]
    fn lognormal_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let mu = 0.0;
        let sigma = 0.5;
        let arr = rng.lognormal(mu, sigma, n).unwrap();
        let (mean, _) = mean_and_var(arr.as_slice().unwrap());
        let expected_mean = (mu + sigma * sigma / 2.0).exp();
        let expected_var = (sigma * sigma).exp_m1() * 2.0f64.mul_add(mu, sigma * sigma).exp();
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "lognormal mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn lognormal_bad_sigma() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.lognormal(0.0, 0.0, 100).is_err());
        assert!(rng.lognormal(0.0, -0.5, 100).is_err());
    }

    #[test]
    fn lognormal_nan_sigma_never_yields_ordinary_samples() {
        // Same contract as for `normal`: a NaN sigma is either rejected or propagates as NaN.
        let mut rng = default_rng_seeded(42);
        if let Ok(arr) = rng.lognormal(0.0, f64::NAN, 5) {
            for &v in arr.as_slice().unwrap() {
                assert!(v.is_nan(), "NaN sigma produced {v}");
            }
        }
    }

    #[test]
    fn lognormal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal(0.0, 1.0, [5, 5]).unwrap();
        assert_eq!(arr.shape(), &[5, 5]);
        for &v in arr.iter() {
            assert!(v > 0.0);
        }
    }

    // ---- f32 variants ------------------------------------------------------------------

    #[test]
    fn standard_normal_f32_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.standard_normal_f32(1000).unwrap();
        let b = rng2.standard_normal_f32(1000).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn standard_normal_f32_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000usize;
        let arr = rng.standard_normal_f32(n).unwrap();
        let (mean, var) = mean_and_var(&widen(arr.as_slice().unwrap()));
        let se = (1.0 / n as f64).sqrt();
        assert!(mean.abs() < 5.0 * se, "f32 mean {mean} too far from 0");
        assert!(
            (var - 1.0).abs() < 0.05,
            "f32 variance {var} too far from 1"
        );
    }

    #[test]
    fn standard_normal_f32_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal_f32([3, 4]).unwrap();
        assert_eq!(arr.shape(), &[3, 4]);
    }

    #[test]
    fn normal_f32_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000usize;
        let (loc, scale) = (5.0f32, 2.0f32);
        let arr = rng.normal_f32(loc, scale, n).unwrap();
        let (mean, var) = mean_and_var(&widen(arr.as_slice().unwrap()));
        let expected_var = f64::from(scale * scale);
        assert!(
            (mean - f64::from(loc)).abs() < 0.05,
            "f32 normal mean {mean} too far from {loc}"
        );
        assert!(
            (var - expected_var).abs() < 0.2,
            "f32 normal variance {var} too far from {expected_var}"
        );
    }

    #[test]
    fn normal_f32_bad_scale() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.normal_f32(0.0, 0.0, 100).is_err());
        assert!(rng.normal_f32(0.0, -1.0, 100).is_err());
    }

    #[test]
    fn lognormal_f32_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal_f32(0.0, 1.0, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v > 0.0, "lognormal_f32 produced non-positive value: {v}");
        }
    }

    #[test]
    fn lognormal_f32_bad_sigma() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.lognormal_f32(0.0, 0.0, 100).is_err());
        assert!(rng.lognormal_f32(0.0, -0.5, 100).is_err());
    }
}