diffusionx 0.12.0

A multi-threaded crate for random number generation and stochastic process simulation, with optional GPU acceleration.
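//! Normal (Gaussian) random number generation.
//!
//! The normal distribution is the stable distribution with stability index α = 2;
//! for other stable distributions, see [`crate::random::stable`].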

use crate::{FloatExt, XError, XResult, random::PAR_THRESHOLD};
use rand::prelude::*;
use rand_distr::StandardNormal;
use rand_xoshiro::Xoshiro256PlusPlus;
use rayon::prelude::*;

/// A normal (Gaussian) distribution parameterised by its mean and standard deviation.
///
/// The float type parameter defaults to `f64`.
#[derive(Debug, Clone)]
pub struct Normal<T = f64>
where
    T: FloatExt,
{
    /// Mean `μ` of the distribution.
    mu: T,
    /// Standard deviation `σ` of the distribution.
    sigma: T,
}

impl<T: FloatExt> Default for Normal<T> {
    fn default() -> Self {
        Self {
            mu: T::zero(),
            sigma: T::one(),
        }
    }
}

impl<T: FloatExt> Normal<T> {
    /// Create a new normal distribution with a given mean and standard deviation.
    ///
    /// # Arguments
    ///
    /// * `mu` - The mean of the normal distribution.
    /// * `sigma` - The standard deviation of the normal distribution, must be greater than 0.
    ///
    /// # Errors
    ///
    /// Returns [`XError::InvalidParameters`] if `sigma` is not strictly greater
    /// than zero, which includes `sigma = NaN`. Only `sigma` is checked here:
    /// `mu` is not validated, and an infinite `sigma` is not rejected by this
    /// constructor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use diffusionx::random::normal::Normal;
    ///
    /// let mu = 1.0;
    /// let sigma = 2.0;
    /// let normal = Normal::new(mu, sigma).unwrap();
    /// assert!(Normal::new(mu, f64::NAN).is_err());
    /// ```
    pub fn new(mu: T, sigma: T) -> XResult<Self> {
        // `partial_cmp` yields `None` for NaN, so requiring `Some(Greater)`
        // rejects zero, negative values and NaN alike. A plain `sigma <= 0`
        // test would let NaN through, because every comparison with NaN is false.
        if sigma.partial_cmp(&T::zero()) != Some(std::cmp::Ordering::Greater) {
            return Err(XError::InvalidParameters(format!(
                "The standard deviation `sigma` must be greater than 0, got {sigma:?}"
            )));
        }
        Ok(Self { mu, sigma })
    }

    /// Get the mean `mu`.
    pub fn get_mu(&self) -> T {
        self.mu
    }

    /// Get the standard deviation `sigma`.
    pub fn get_sigma(&self) -> T {
        self.sigma
    }

    /// Generate a vector of `n` samples from this distribution.
    ///
    /// # Arguments
    ///
    /// * `n` - The number of random numbers to generate; `n = 0` yields an empty vector.
    ///
    /// # Errors
    ///
    /// For non-standard parameters the work is delegated to [`rands`], and any
    /// error from `rand_distr::Normal::new` is propagated. The standard case
    /// (`mu = 0`, `sigma = 1`) always returns `Ok`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use diffusionx::random::normal::Normal;
    ///
    /// let normal = Normal::<f64>::default();
    /// let randoms = normal.samples(10).unwrap();
    /// assert_eq!(randoms.len(), 10);
    /// ```
    pub fn samples(&self, n: usize) -> XResult<Vec<T>>
    where
        StandardNormal: Distribution<T>,
    {
        // The standard distribution has a dedicated generator that samples
        // `StandardNormal` directly instead of building a `rand_distr::Normal`.
        if self.sigma == T::one() && self.mu == T::zero() {
            Ok(standard_rands(n))
        } else {
            rands(self.mu, self.sigma, n)
        }
    }
}

/// Draw a single sample from the standard normal distribution `N(0, 1)`.
///
/// Each call seeds a fresh [`Xoshiro256PlusPlus`] generator from `rand::rng()`.
///
/// # Example
///
/// ```rust
/// use diffusionx::random::normal::standard_rand;
///
/// let random = standard_rand::<f64>();
/// ```
pub fn standard_rand<T>() -> T
where
    T: FloatExt,
    StandardNormal: Distribution<T>,
{
    let seed_source = &mut rand::rng();
    let mut generator = Xoshiro256PlusPlus::from_rng(seed_source);
    standard_rand_with_rng(&mut generator)
}

#[inline]
pub(crate) fn standard_rand_with_rng<T, R>(rng: &mut R) -> T
where
    T: FloatExt,
    R: Rng + ?Sized,
    StandardNormal: Distribution<T>,
{
    rng.sample(rand_distr::StandardNormal)
}

/// Generate a vector of standard normal random numbers
///
/// # Example
///
/// ```rust
/// use diffusionx::random::normal::standard_rands;
///
/// let randoms = standard_rands::<f64>(10);
/// ```
pub fn standard_rands<T: FloatExt>(n: usize) -> Vec<T>
where
    StandardNormal: Distribution<T>,
{
    let dist = rand_distr::StandardNormal;
    if n <= PAR_THRESHOLD {
        let mut rng = Xoshiro256PlusPlus::from_rng(&mut rand::rng());
        (0..n).map(|_| rng.sample(dist)).collect()
    } else {
        (0..n)
            .into_par_iter()
            .map_init(
                || Xoshiro256PlusPlus::from_rng(&mut rand::rng()),
                |r, _| r.sample(dist),
            )
            .collect()
    }
}

/// Draw a single sample from the normal distribution `N(mean, std_dev²)`.
///
/// # Arguments
///
/// * `mean` - Mean of the distribution.
/// * `std_dev` - Standard deviation of the distribution.
///
/// # Errors
///
/// Propagates the error returned by `rand_distr::Normal::new` when it rejects
/// `mean`/`std_dev`. Unlike [`Normal::new`], this function does not perform a
/// `std_dev > 0` check of its own.
///
/// # Example
///
/// ```rust
/// use diffusionx::random::normal::rand;
///
/// let random = rand(0.0, 1.0).unwrap();
/// ```
pub fn rand<T>(mean: T, std_dev: T) -> XResult<T>
where
    T: FloatExt,
    StandardNormal: Distribution<T>,
{
    let dist = rand_distr::Normal::new(mean, std_dev)?;
    let mut generator = Xoshiro256PlusPlus::from_rng(&mut rand::rng());
    let value = dist.sample(&mut generator);
    Ok(value)
}

/// Draw `n` independent samples from the normal distribution `N(mean, std_dev²)`.
///
/// # Arguments
///
/// * `mean` - Mean of the distribution.
/// * `std_dev` - Standard deviation of the distribution.
/// * `n` - Number of samples to draw; `n = 0` yields an empty vector.
///
/// # Errors
///
/// Propagates the error returned by `rand_distr::Normal::new` when it rejects
/// `mean`/`std_dev`. Unlike [`Normal::new`], this function does not perform a
/// `std_dev > 0` check of its own.
///
/// # Example
///
/// ```rust
/// use diffusionx::random::normal::rands;
///
/// let randoms = rands(0.0, 1.0, 10).unwrap();
/// ```
pub fn rands<T>(mean: T, std_dev: T, n: usize) -> XResult<Vec<T>>
where
    T: FloatExt,
    StandardNormal: Distribution<T>,
{
    let dist = rand_distr::Normal::new(mean, std_dev)?;

    let values: Vec<T> = if n > PAR_THRESHOLD {
        // Large requests: each rayon job lazily seeds its own generator.
        (0..n)
            .into_par_iter()
            .map_init(
                || Xoshiro256PlusPlus::from_rng(&mut rand::rng()),
                |generator, _| dist.sample(generator),
            )
            .collect()
    } else {
        // Small requests: one generator, sequential draws.
        let mut generator = Xoshiro256PlusPlus::from_rng(&mut rand::rng());
        std::iter::repeat_with(|| dist.sample(&mut generator))
            .take(n)
            .collect()
    };

    Ok(values)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::calculate_stats;
    use num_traits::Float;

    #[test]
    fn test_standard_rand() {
        let random = standard_rand::<f64>();
        assert!(random.is_finite());
    }

    #[test]
    fn test_standard_rands() {
        let randoms = standard_rands::<f64>(10);
        assert_eq!(randoms.len(), 10);
        assert!(randoms.iter().all(|r| r.is_finite()));
    }

    #[test]
    fn test_rand() {
        let random = rand(0.0, 1.0).unwrap();
        assert!(random.is_finite());
    }

    #[test]
    fn test_rands() {
        let randoms = rands(0.0, 1.0, 10).unwrap();
        assert_eq!(randoms.len(), 10);
        assert!(randoms.iter().all(|r| r.is_finite()));
    }

    /// Zero-length requests return an empty vector rather than an error.
    #[test]
    fn test_zero_length_requests() {
        assert!(standard_rands::<f64>(0).is_empty());
        assert!(rands(0.0, 1.0, 0).unwrap().is_empty());
        assert!(Normal::new(2.0, 3.0).unwrap().samples(0).unwrap().is_empty());
    }

    /// `Normal::new` must reject a standard deviation that is zero or negative.
    #[test]
    fn test_normal_new_rejects_non_positive_sigma() {
        assert!(Normal::new(0.0, 0.0).is_err());
        assert!(Normal::new(0.0, -1.0).is_err());
        assert!(Normal::new(0.0, 1.0).is_ok());
    }

    /// Accessors return the stored parameters; `Default` is the standard normal.
    #[test]
    fn test_normal_accessors_and_default() {
        let normal = Normal::new(1.5, 2.5).unwrap();
        assert_eq!(normal.get_mu(), 1.5);
        assert_eq!(normal.get_sigma(), 2.5);

        let standard = Normal::<f64>::default();
        assert_eq!(standard.get_mu(), 0.0);
        assert_eq!(standard.get_sigma(), 1.0);
    }

    /// Both the standard fast path and the general path of `samples` work.
    #[test]
    fn test_normal_samples() {
        let standard = Normal::<f64>::default().samples(10).unwrap();
        assert_eq!(standard.len(), 10);
        assert!(standard.iter().all(|r| r.is_finite()));

        let shifted = Normal::new(2.0, 3.0).unwrap().samples(10).unwrap();
        assert_eq!(shifted.len(), 10);
        assert!(shifted.iter().all(|r| r.is_finite()));
    }

    #[test]
    fn test_standard_normal_stats() {
        let n = 1_000_000;
        let samples = standard_rands(n);
        let (mean, variance) = calculate_stats(&samples);
        let std_dev = variance.sqrt();

        assert!(
            mean.abs() < 0.01,
            "The mean of the standard normal distribution should be close to 0, got {mean}"
        );
        assert!(
            (std_dev - 1.0).abs() < 0.01,
            "The standard deviation of the standard normal distribution should be close to 1, got {std_dev}"
        );
    }

    #[test]
    fn test_normal_stats() {
        let n = 1_000_000;
        let mu = 2.0;
        let sigma = 3.0;
        let samples = rands(mu, sigma, n).unwrap();
        let (mean, variance) = calculate_stats(&samples);
        let std_dev = variance.sqrt();

        assert!(
            (mean - mu).abs() < 0.05,
            "The mean of the normal distribution should be close to {mu}, got {mean}"
        );
        assert!(
            (std_dev - sigma).abs() < 0.05,
            "The standard deviation of the normal distribution should be close to {sigma}, got {std_dev}"
        );
    }
}