use crate::{FloatExt, XError, XResult, random::PAR_THRESHOLD};
use rand::prelude::*;
use rand_distr::StandardNormal;
use rand_xoshiro::Xoshiro256PlusPlus;
use rayon::prelude::*;
/// A normal (Gaussian) distribution described by its mean `mu` and standard
/// deviation `sigma`.
///
/// Build one with [`Normal::new`], which rejects a non-positive `sigma`, or use
/// [`Default`] for the standard normal `N(0, 1)`. The float type defaults to `f64`.
#[derive(Debug, Clone)]
pub struct Normal<T = f64>
where
    T: FloatExt,
{
    /// Mean of the distribution.
    mu: T,
    /// Standard deviation of the distribution.
    sigma: T,
}
/// The standard normal distribution: `mu = 0`, `sigma = 1`.
impl<T> Default for Normal<T>
where
    T: FloatExt,
{
    fn default() -> Self {
        let (mu, sigma) = (T::zero(), T::one());
        Self { mu, sigma }
    }
}
impl<T: FloatExt> Normal<T> {
    /// Creates a normal distribution with mean `mu` and standard deviation `sigma`.
    ///
    /// # Errors
    ///
    /// Returns [`XError::InvalidParameters`] if `sigma` is not strictly greater
    /// than zero. This includes `NaN`: every ordered comparison involving `NaN`
    /// is false, so a plain `sigma <= 0` test would let it through.
    pub fn new(mu: T, sigma: T) -> XResult<Self> {
        // `partial_cmp` yields `None` for NaN, so only a genuine "greater than
        // zero" passes. (`!(sigma > 0)` would trip clippy's
        // `neg_cmp_op_on_partial_ord` lint.)
        let is_positive = sigma.partial_cmp(&T::zero()) == Some(std::cmp::Ordering::Greater);
        if !is_positive {
            return Err(XError::InvalidParameters(format!(
                "The standard deviation `sigma` must be greater than 0, got {sigma:?}"
            )));
        }
        Ok(Self { mu, sigma })
    }

    /// Returns the mean `mu`.
    pub fn get_mu(&self) -> T {
        self.mu
    }

    /// Returns the standard deviation `sigma`.
    pub fn get_sigma(&self) -> T {
        self.sigma
    }

    /// Draws `n` independent samples from this distribution.
    ///
    /// When the parameters are exactly `mu = 0` and `sigma = 1`, sampling goes
    /// through [`standard_rands`], which skips building a `rand_distr::Normal`.
    /// Otherwise it delegates to [`rands`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`rands`].
    pub fn samples(&self, n: usize) -> XResult<Vec<T>>
    where
        StandardNormal: Distribution<T>,
    {
        if self.sigma == T::one() && self.mu == T::zero() {
            Ok(standard_rands(n))
        } else {
            rands(self.mu, self.sigma, n)
        }
    }
}
/// Draws a single sample from the standard normal distribution `N(0, 1)`.
///
/// Each call seeds a fresh `Xoshiro256PlusPlus` generator from the thread-local
/// RNG (`rand::rng()`) and draws one value from it.
pub fn standard_rand<T>() -> T
where
    T: FloatExt,
    StandardNormal: Distribution<T>,
{
    standard_rand_with_rng(&mut Xoshiro256PlusPlus::from_rng(&mut rand::rng()))
}
#[inline]
pub(crate) fn standard_rand_with_rng<T, R>(rng: &mut R) -> T
where
T: FloatExt,
R: Rng + ?Sized,
StandardNormal: Distribution<T>,
{
rng.sample(rand_distr::StandardNormal)
}
pub fn standard_rands<T: FloatExt>(n: usize) -> Vec<T>
where
StandardNormal: Distribution<T>,
{
let dist = rand_distr::StandardNormal;
if n <= PAR_THRESHOLD {
let mut rng = Xoshiro256PlusPlus::from_rng(&mut rand::rng());
(0..n).map(|_| rng.sample(dist)).collect()
} else {
(0..n)
.into_par_iter()
.map_init(
|| Xoshiro256PlusPlus::from_rng(&mut rand::rng()),
|r, _| r.sample(dist),
)
.collect()
}
}
/// Draws a single sample from a normal distribution with the given `mean` and
/// `std_dev`.
///
/// The parameters are validated before any generator is seeded.
///
/// # Errors
///
/// Fails when `rand_distr::Normal::new` rejects the parameters. That error is
/// converted into the crate error type via `?`.
pub fn rand<T>(mean: T, std_dev: T) -> XResult<T>
where
    T: FloatExt,
    StandardNormal: Distribution<T>,
{
    let dist = rand_distr::Normal::new(mean, std_dev)?;
    Ok(dist.sample(&mut Xoshiro256PlusPlus::from_rng(&mut rand::rng())))
}
/// Draws `n` independent samples from a normal distribution with the given
/// `mean` and `std_dev`.
///
/// For `n <= PAR_THRESHOLD`, one `Xoshiro256PlusPlus` generator seeded from
/// `rand::rng()` produces every sample. Above that threshold, samples are
/// generated in parallel with rayon, and `map_init` seeds a separate generator
/// per rayon job.
///
/// # Errors
///
/// Fails when `rand_distr::Normal::new` rejects the parameters, before any
/// sample is drawn. That error is converted into the crate error type via `?`.
pub fn rands<T>(mean: T, std_dev: T, n: usize) -> XResult<Vec<T>>
where
    T: FloatExt,
    StandardNormal: Distribution<T>,
{
    let dist = rand_distr::Normal::new(mean, std_dev)?;
    let samples: Vec<T> = if n > PAR_THRESHOLD {
        (0..n)
            .into_par_iter()
            .map_init(
                || Xoshiro256PlusPlus::from_rng(&mut rand::rng()),
                |rng, _| dist.sample(rng),
            )
            .collect()
    } else {
        let mut rng = Xoshiro256PlusPlus::from_rng(&mut rand::rng());
        (0..n).map(|_| dist.sample(&mut rng)).collect()
    };
    Ok(samples)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::calculate_stats;
    use num_traits::Float;

    #[test]
    fn test_standard_rand() {
        let random = standard_rand::<f64>();
        assert!(random.is_finite());
    }

    #[test]
    fn test_standard_rands() {
        let randoms = standard_rands::<f64>(10);
        assert_eq!(randoms.len(), 10);
        assert!(randoms.iter().all(|r| r.is_finite()));
    }

    #[test]
    fn test_standard_rands_empty() {
        assert!(standard_rands::<f64>(0).is_empty());
    }

    #[test]
    fn test_rand() {
        let random = rand(0.0, 1.0).unwrap();
        assert!(random.is_finite());
    }

    #[test]
    fn test_rands() {
        let randoms = rands(0.0, 1.0, 10).unwrap();
        assert_eq!(randoms.len(), 10);
        assert!(randoms.iter().all(|r| r.is_finite()));
    }

    #[test]
    fn test_rands_empty() {
        assert!(rands(2.0_f64, 3.0, 0).unwrap().is_empty());
    }

    #[test]
    fn test_normal_default_is_standard() {
        let normal = Normal::<f64>::default();
        assert_eq!(normal.get_mu(), 0.0);
        assert_eq!(normal.get_sigma(), 1.0);
    }

    #[test]
    fn test_normal_new_validates_sigma() {
        assert!(Normal::<f64>::new(0.0, 0.0).is_err());
        assert!(Normal::<f64>::new(0.0, -1.0).is_err());
        let normal = Normal::<f64>::new(2.0, 3.0).unwrap();
        assert_eq!(normal.get_mu(), 2.0);
        assert_eq!(normal.get_sigma(), 3.0);
    }

    #[test]
    fn test_normal_samples() {
        // Standard parameters take the `standard_rands` fast path.
        let standard = Normal::<f64>::default().samples(10).unwrap();
        assert_eq!(standard.len(), 10);
        assert!(standard.iter().all(|r| r.is_finite()));

        // Non-standard parameters go through `rands`.
        let shifted = Normal::<f64>::new(2.0, 3.0).unwrap().samples(10).unwrap();
        assert_eq!(shifted.len(), 10);
        assert!(shifted.iter().all(|r| r.is_finite()));

        assert!(Normal::<f64>::default().samples(0).unwrap().is_empty());
    }

    // With n = 1_000_000 the standard error of the sample mean is sigma / 1000
    // (0.001 for sigma = 1, 0.003 for sigma = 3). The tolerances below are
    // therefore roughly 10+ standard errors wide, so spurious failures are
    // practically impossible.
    #[test]
    fn test_standard_normal_stats() {
        let n = 1_000_000;
        let samples = standard_rands(n);
        let (mean, variance) = calculate_stats(&samples);
        let std_dev = variance.sqrt();
        assert!(
            mean.abs() < 0.01,
            "The mean of the standard normal distribution should be close to 0, got {mean}"
        );
        assert!(
            (std_dev - 1.0).abs() < 0.01,
            "The standard deviation of the standard normal distribution should be close to 1, got {std_dev}"
        );
    }

    #[test]
    fn test_normal_stats() {
        let n = 1_000_000;
        let mu = 2.0;
        let sigma = 3.0;
        let samples = rands(mu, sigma, n).unwrap();
        let (mean, variance) = calculate_stats(&samples);
        let std_dev = variance.sqrt();
        assert!(
            (mean - mu).abs() < 0.05,
            "The mean of the normal distribution should be close to {mu}, got {mean}"
        );
        assert!(
            (std_dev - sigma).abs() < 0.05,
            "The standard deviation of the normal distribution should be close to {sigma}, got {std_dev}"
        );
    }
}