use super::*;
/// The standard exponential distribution `Exp(1)`, whose density is `e^{-x}`.
///
/// `f64` samples are drawn with the ziggurat method; `f32` samples are produced
/// by drawing an `f64` sample and narrowing it with an `as` cast.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Exp1;
impl Distribution<f32> for Exp1 {
    /// Draws an `f32` sample from `Exp(1)`.
    ///
    /// The variate is generated at `f64` precision by the `f64` implementation
    /// and then narrowed to `f32` with an `as` cast.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rand: &mut Random<R>) -> f32 {
        let wide = <Exp1 as Distribution<f64>>::sample(self, rand);
        wide as f32
    }
}
impl Distribution<f64> for Exp1 {
    /// Draws an `f64` sample from `Exp(1)` using the ziggurat method.
    ///
    /// The ziggurat runs in its non-symmetric mode over the exponential
    /// layer tables `ZIG_EXP_X` / `ZIG_EXP_F`.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rand: &mut Random<R>) -> f64 {
        // Density of Exp(1): e^{-x}.
        #[inline]
        fn density(x: f64) -> f64 {
            f64::exp(-x)
        }

        // Fallback handed to `ziggurat` (presumably for the base/tail layer —
        // confirm against `ziggurat::ziggurat`): returns `ZIG_EXP_R - ln(U)`
        // for a fresh uniform `U` from `float01`. The second argument is unused.
        #[inline]
        fn tail<R: Rng + ?Sized>(rand: &mut Random<R>, _u: f64) -> f64 {
            let uniform = rand.float01();
            ziggurat::ZIG_EXP_R - uniform.ln()
        }

        // The exponential distribution is one-sided, so no random sign is applied.
        let symmetric = false;
        ziggurat::ziggurat(
            rand,
            symmetric,
            &ziggurat::ZIG_EXP_X,
            &ziggurat::ZIG_EXP_F,
            density,
            tail,
        )
    }
}
/// Error returned when constructing an [`Exp`] with an invalid rate.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ExpError {
    /// `lambda` did not satisfy `lambda >= 0` (it was negative or NaN).
    LambdaTooSmall,
}

impl fmt::Display for ExpError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            ExpError::LambdaTooSmall => "lambda is negative or NaN in exponential distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for ExpError {}
/// Per-float-type constructor hook for [`Exp`].
///
/// In this file it is implemented for `Exp<f32>` and `Exp<f64>` by the
/// `impl_exp!` macro; the generic `Exp::try_new` / `Exp::new` constructors
/// dispatch through it.
pub trait ExpImpl<Float>: Sized {
    /// Validates `lambda` and builds the distribution. The implementations in
    /// this file return [`ExpError::LambdaTooSmall`] when `lambda` is negative
    /// or NaN.
    fn try_new(lambda: Float) -> Result<Self, ExpError>;
}
/// The exponential distribution `Exp(lambda)` with rate `lambda`.
///
/// Sampling scales a standard [`Exp1`] variate by the stored reciprocal of
/// the rate. Construct it with `Exp::try_new` or `Exp::new`.
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Exp<Float> {
    // Reciprocal of the rate, computed once at construction; each sample is
    // an `Exp1` variate multiplied by this value.
    lambda_inverse: Float,
}
impl<Float> Exp<Float>
where
    Self: ExpImpl<Float>,
{
    /// Creates an exponential distribution with rate `lambda`.
    ///
    /// # Errors
    ///
    /// Propagates the error from the [`ExpImpl`] implementation for `Float`;
    /// for `f32`/`f64` in this file that is [`ExpError::LambdaTooSmall`] when
    /// `lambda` is negative or NaN.
    #[inline]
    pub fn try_new(lambda: Float) -> Result<Exp<Float>, ExpError> {
        <Self as ExpImpl<Float>>::try_new(lambda)
    }

    /// Creates an exponential distribution with rate `lambda`, panicking if
    /// it is rejected.
    ///
    /// # Panics
    ///
    /// Panics (reporting the caller's location) whenever [`Exp::try_new`]
    /// would return an error.
    #[track_caller]
    #[inline]
    pub fn new(lambda: Float) -> Exp<Float> {
        let validated = Self::try_new(lambda);
        validated.unwrap()
    }
}
/// Implements [`ExpImpl`] and [`Distribution`] for `Exp<$f>`, where `$f` is a
/// primitive float type (`f32` or `f64`).
macro_rules! impl_exp {
    ($f:ty) => {
        impl ExpImpl<$f> for Exp<$f> {
            /// Validates `lambda` and precomputes its reciprocal.
            ///
            /// Any `lambda >= 0` is accepted, including `0` and `+inf`; negative
            /// values and NaN are rejected with [`ExpError::LambdaTooSmall`].
            #[inline]
            fn try_new(lambda: $f) -> Result<Self, ExpError> {
                // Written as `!(lambda >= 0.0)` rather than `lambda < 0.0` so that
                // NaN (for which every comparison is false) is rejected as well.
                if !(lambda >= 0.0) {
                    return Err(ExpError::LambdaTooSmall);
                }
                // `-0.0 >= 0.0` is true, so a negative zero reaches this point.
                // Without normalisation `1.0 / -0.0` would be `-inf` and every
                // sample would be negative. Adding `+0.0` maps `-0.0` to `+0.0`
                // and is the identity for every other accepted value.
                let lambda = lambda + 0.0;
                Ok(Exp {
                    lambda_inverse: 1.0 / lambda,
                })
            }
        }

        impl Distribution<$f> for Exp<$f> {
            /// Draws an `Exp(lambda)` sample by scaling a standard [`Exp1`]
            /// sample by `1 / lambda`.
            #[inline]
            fn sample<R: Rng + ?Sized>(&self, rand: &mut Random<R>) -> $f {
                let x: $f = Exp1.sample(rand);
                x * self.lambda_inverse
            }
        }
    };
}
impl_exp!(f32);
impl_exp!(f64);
#[cfg(test)]
mod tests;