use crate::{FloatExt, XError, XResult, random::PAR_THRESHOLD};
use rand::prelude::*;
use rand_distr::{Exp, Exp1};
use rand_xoshiro::Xoshiro256PlusPlus;
use rayon::prelude::*;
/// Exponential distribution described by its rate parameter `lambda`.
///
/// The mean of the distribution is `1 / lambda`. Build one with
/// [`Exponential::new`], which validates the rate, or use [`Default`] for the
/// standard distribution with `lambda = 1`.
#[derive(Debug, Clone)]
pub struct Exponential<T = f64>
where
    T: FloatExt,
{
    /// Rate parameter of the distribution.
    lambda: T,
}
impl<T: FloatExt> Default for Exponential<T> {
fn default() -> Self {
Self { lambda: T::one() }
}
}
impl<T: FloatExt> Exponential<T> {
    /// Creates an exponential distribution with rate `lambda`.
    ///
    /// # Errors
    ///
    /// Returns [`XError::InvalidParameters`] unless `lambda` is strictly
    /// greater than zero. NaN is rejected too.
    pub fn new(lambda: T) -> XResult<Self> {
        // Comparing with `>` (instead of testing `lambda <= 0`) makes NaN fail
        // the check: every ordered comparison involving NaN is false, so the old
        // `<=` test let NaN through and the error only surfaced later, when
        // sampling built a `rand_distr::Exp`.
        let is_positive = lambda > T::zero();
        if !is_positive {
            return Err(XError::InvalidParameters(format!(
                "The rate parameter `lambda` must be greater than 0, got {lambda:?}"
            )));
        }
        Ok(Self { lambda })
    }

    /// Returns the rate parameter `lambda`.
    pub fn get_lambda(&self) -> T {
        self.lambda
    }

    /// Draws `n` independent samples from this distribution.
    ///
    /// When `lambda` is exactly one, this routes to [`standard_rands`] and skips
    /// building a `rand_distr::Exp`. Otherwise it delegates to [`rands`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`rands`] for the stored rate.
    pub fn samples(&self, n: usize) -> XResult<Vec<T>>
    where
        Exp1: Distribution<T>,
    {
        if self.lambda == T::one() {
            Ok(standard_rands(n))
        } else {
            rands(self.lambda, n)
        }
    }
}
/// Draws a single sample from the standard exponential distribution (rate 1).
///
/// A fresh `Xoshiro256PlusPlus` is seeded from the thread-local RNG for this
/// one draw.
pub fn standard_rand<T: FloatExt>() -> T
where
    Exp1: Distribution<T>,
{
    let mut seeder = rand::rng();
    let mut generator = Xoshiro256PlusPlus::from_rng(&mut seeder);
    standard_rand_with_rng::<T, _>(&mut generator)
}
/// Draws one standard exponential (rate 1) sample from the caller-supplied
/// random number generator `rng`.
#[inline]
pub(crate) fn standard_rand_with_rng<T, R>(rng: &mut R) -> T
where
    T: FloatExt,
    R: Rng + ?Sized,
    Exp1: Distribution<T>,
{
    <Exp1 as Distribution<T>>::sample(&Exp1, rng)
}
pub fn standard_rands<T: FloatExt>(n: usize) -> Vec<T>
where
Exp1: Distribution<T>,
{
let dist = Exp1;
if n <= PAR_THRESHOLD {
let mut rng = Xoshiro256PlusPlus::from_rng(&mut rand::rng());
(0..n).map(|_| rng.sample(dist)).collect()
} else {
(0..n)
.into_par_iter()
.map_init(
|| Xoshiro256PlusPlus::from_rng(&mut rand::rng()),
|r, _| r.sample(dist),
)
.collect()
}
}
/// Draws a single sample from the exponential distribution with rate `lambda`.
///
/// # Errors
///
/// Returns the error produced by `rand_distr::Exp::new` when it rejects
/// `lambda`, converted into the crate's error type.
pub fn rand<T: FloatExt>(lambda: T) -> XResult<T>
where
    Exp1: Distribution<T>,
{
    let distribution = Exp::new(lambda)?;
    let sample = Xoshiro256PlusPlus::from_rng(&mut rand::rng()).sample(distribution);
    Ok(sample)
}
/// Draws `n` independent samples from the exponential distribution with rate
/// `lambda`.
///
/// Sampling runs in parallel with rayon when `n` exceeds `PAR_THRESHOLD`, and
/// sequentially from one `Xoshiro256PlusPlus` otherwise.
///
/// # Errors
///
/// Returns the error produced by `rand_distr::Exp::new` when it rejects
/// `lambda`, converted into the crate's error type. The rate is validated
/// before any generator is seeded.
pub fn rands<T: FloatExt>(lambda: T, n: usize) -> XResult<Vec<T>>
where
    Exp1: Distribution<T>,
{
    let distribution = Exp::new(lambda)?;
    let draws: Vec<T> = if n > PAR_THRESHOLD {
        (0..n)
            .into_par_iter()
            .map_init(
                || Xoshiro256PlusPlus::from_rng(&mut rand::rng()),
                |source, _| source.sample(distribution),
            )
            .collect()
    } else {
        let mut source = Xoshiro256PlusPlus::from_rng(&mut rand::rng());
        (0..n).map(|_| source.sample(distribution)).collect()
    };
    Ok(draws)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::calculate_stats;
    use num_traits::Float;

    /// Asserts that `actual` is within a relative tolerance `rel_tol` of the
    /// (non-zero) `expected` value. `what` names the quantity for the failure
    /// message.
    fn assert_rel_close(actual: f64, expected: f64, rel_tol: f64, what: &str) {
        assert!(
            ((actual - expected) / expected).abs() < rel_tol,
            "The {what} should be close to {expected}, but got {actual}"
        );
    }

    #[test]
    fn test_standard_rand() {
        let random = standard_rand::<f64>();
        assert!(random.is_finite());
    }

    #[test]
    fn test_standard_rands() {
        let randoms = standard_rands::<f64>(10);
        assert_eq!(randoms.len(), 10);
        assert!(randoms.iter().all(|r| r.is_finite()));
    }

    #[test]
    fn test_rand() {
        let random = rand(1.0).unwrap();
        assert!(random.is_finite());
    }

    #[test]
    fn test_rands() {
        let randoms = rands(1.0, 10).unwrap();
        assert_eq!(randoms.len(), 10);
        assert!(randoms.iter().all(|r| r.is_finite()));
    }

    #[test]
    fn test_negative_lambda_is_rejected_by_free_functions() {
        assert!(rand(-1.0).is_err());
        assert!(rands(-1.0, 10).is_err());
    }

    #[test]
    fn test_new_rejects_non_positive_lambda() {
        assert!(Exponential::<f64>::new(0.0).is_err());
        assert!(Exponential::<f64>::new(-2.5).is_err());
    }

    #[test]
    fn test_new_and_default_lambda() {
        assert_eq!(Exponential::<f64>::new(2.0).unwrap().get_lambda(), 2.0);
        assert_eq!(Exponential::<f64>::default().get_lambda(), 1.0);
    }

    #[test]
    fn test_samples_cover_both_branches() {
        // lambda == 1 takes the `standard_rands` fast path, any other rate goes
        // through `rands`; both must return `n` non-negative finite values.
        let standard = Exponential::<f64>::new(1.0).unwrap().samples(5).unwrap();
        assert_eq!(standard.len(), 5);
        assert!(standard.iter().all(|x| x.is_finite() && *x >= 0.0));

        let scaled = Exponential::<f64>::new(3.0).unwrap().samples(7).unwrap();
        assert_eq!(scaled.len(), 7);
        assert!(scaled.iter().all(|x| x.is_finite() && *x >= 0.0));
    }

    #[test]
    fn test_standard_exponential_stats() {
        let n = 1_000_000;
        let samples = standard_rands::<f64>(n);
        let (mean, variance) = calculate_stats(&samples);
        // With n = 1e6 the standard errors are ~0.001 (mean) and ~0.003
        // (variance), so these relative bounds are ~10 sigma.
        assert_rel_close(mean, 1.0, 0.01, "mean of the standard exponential distribution");
        assert_rel_close(
            variance,
            1.0,
            0.03,
            "variance of the standard exponential distribution",
        );
    }

    #[test]
    fn test_exponential_stats() {
        let n = 1_000_000;
        let lambda = 2.0;
        let samples = rands(lambda, n).unwrap();
        let (mean, variance) = calculate_stats(&samples);
        let expected_mean = 1.0 / lambda;
        let expected_variance = 1.0 / (lambda * lambda);
        // Relative tolerances keep the check meaningful for small expected
        // values; an absolute 0.05 bound on a variance of 0.25 allowed 20% error.
        assert_rel_close(mean, expected_mean, 0.01, "mean of the exponential distribution");
        assert_rel_close(
            variance,
            expected_variance,
            0.03,
            "variance of the exponential distribution",
        );
    }
}