use crate::utils::ziggurat;
use crate::{Distribution, ziggurat_tables};
use core::fmt;
use num_traits::Float;
use rand::{Rng, RngExt};
/// Samples floating-point numbers from the exponential distribution with
/// rate `lambda = 1`, i.e. with (unnormalised) density `exp(-x)`.
///
/// The `f64` sampler uses the ziggurat method; the `f32` sampler draws an
/// `f64` and narrows it. For other rates use [`Exp`].
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Exp1;
impl Distribution<f32> for Exp1 {
    /// Draws a rate-1 exponential variate in double precision and narrows it
    /// to `f32`.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
        // Reuse the `f64` ziggurat sampler rather than a separate
        // single-precision implementation.
        let wide = <Exp1 as Distribution<f64>>::sample(self, rng);
        wide as f32
    }
}
/// Rate-1 exponential sampler in double precision, built on the shared
/// `ziggurat` helper with the `ZIG_EXP_*` tables.
impl Distribution<f64> for Exp1 {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
/// Unnormalised target density `exp(-x)`, passed to the ziggurat as its
/// acceptance test.
#[inline]
fn pdf(x: f64) -> f64 {
(-x).exp()
}
/// Tail sampler handed to the ziggurat. Presumably `utils::ziggurat` calls
/// it when a draw falls beyond the last table edge `ZIG_EXP_R` (confirm
/// against the helper). By memorylessness, an exponential conditioned on
/// `x > ZIG_EXP_R` is `ZIG_EXP_R + Exp(1)`, and `-ln(U)` with `U` uniform is an
/// `Exp(1)` draw (inverse CDF). The `_u` argument is unused.
#[inline]
fn zero_case<R: Rng + ?Sized>(rng: &mut R, _u: f64) -> f64 {
// NOTE(review): if `random::<f64>()` can return exactly 0.0 (rand's
// standard `f64` range is half-open [0, 1)), `ln` yields -inf and this
// returns +inf. The probability is negligible. Switching to an open-interval
// uniform would change the sample stream, so it is left as-is.
ziggurat_tables::ZIG_EXP_R - rng.random::<f64>().ln()
}
// `false` presumably selects the one-sided (non-symmetric) ziggurat
// variant, since these samples are non-negative — TODO confirm against the
// parameter name in `utils::ziggurat`.
ziggurat(
rng,
false,
&ziggurat_tables::ZIG_EXP_X,
&ziggurat_tables::ZIG_EXP_F,
pdf,
zero_case,
)
}
}
/// The exponential distribution `Exp(lambda)`.
///
/// Each sample is an [`Exp1`] (rate 1) sample multiplied by `1 / lambda`.
/// See `Exp::new` for which `lambda` values are accepted.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Exp<F>
where
F: Float,
Exp1: Distribution<F>,
{
// Reciprocal of the rate `lambda`. It is computed once in `Exp::new`, so
// sampling needs only a single multiplication.
lambda_inverse: F,
}
/// Error returned by `Exp::new` when the rate parameter is invalid.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `lambda` was negative (including `-0.0`) or NaN.
    LambdaTooSmall,
}

impl fmt::Display for Error {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::LambdaTooSmall => {
                "lambda is negative (including -0.0) or NaN in exponential distribution"
            }
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
impl<F> Exp<F>
where
    F: Float,
    Exp1: Distribution<F>,
{
    /// Constructs a new `Exp` with rate parameter `lambda`.
    ///
    /// Samples are drawn by scaling an [`Exp1`] sample by `1 / lambda`, so
    /// the mean of the distribution is `1 / lambda`.
    ///
    /// `lambda == +0.0` is accepted. The stored scale is then `1 / 0`, which is
    /// `+inf` for the primitive float types, so samples are infinite.
    ///
    /// # Errors
    ///
    /// Returns [`Error::LambdaTooSmall`] if `lambda` is negative (including
    /// `-0.0`) or NaN.
    #[inline]
    pub fn new(lambda: F) -> Result<Exp<F>, Error> {
        // `is_sign_negative` deliberately also rejects `-0.0`, whose
        // reciprocal would be `-inf` and yield non-positive samples.
        if lambda.is_sign_negative() || lambda.is_nan() {
            return Err(Error::LambdaTooSmall);
        }
        Ok(Exp {
            lambda_inverse: F::one() / lambda,
        })
    }
}
impl<F> Distribution<F> for Exp<F>
where
    F: Float,
    Exp1: Distribution<F>,
{
    /// Draws a rate-1 exponential variate and rescales it by `1 / lambda`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        let standard: F = <Exp1 as Distribution<F>>::sample(&Exp1, rng);
        standard * self.lambda_inverse
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Samples from a valid `Exp` are never negative.
    #[test]
    fn test_exp() {
        let exp = Exp::new(10.0).unwrap();
        let mut rng = crate::test::rng(221);
        for _ in 0..1000 {
            assert!(exp.sample(&mut rng) >= 0.0);
        }
    }

    /// `lambda == 0` is accepted and yields infinite samples (scale is `1 / 0`).
    #[test]
    fn test_zero() {
        let d = Exp::new(0.0).unwrap();
        assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity());
    }

    // The rejection tests assert the exact error value rather than relying on
    // `#[should_panic]`, which would also pass on an unrelated panic.

    #[test]
    fn test_exp_invalid_lambda_neg() {
        assert_eq!(Exp::new(-10.0), Err(Error::LambdaTooSmall));
    }

    /// `-0.0` is rejected as well, as promised by the error message.
    #[test]
    fn test_exp_invalid_lambda_neg_zero() {
        assert_eq!(Exp::new(-0.0_f64), Err(Error::LambdaTooSmall));
    }

    #[test]
    fn test_exp_invalid_lambda_nan() {
        assert_eq!(Exp::new(f64::nan()), Err(Error::LambdaTooSmall));
    }

    #[test]
    fn exponential_distributions_can_be_compared() {
        assert_eq!(Exp::new(1.0), Exp::new(1.0));
    }
}