use super::*;
/// The standard normal distribution `N(0, 1)`.
///
/// Sampling is implemented for `f64` (ziggurat method) and for `f32` (an `f64` draw
/// narrowed to single precision).
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StandardNormal;
/// Single-precision standard normal samples: an `f64` draw narrowed to `f32`.
impl Distribution<f32> for StandardNormal {
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rand: &mut Random<R>) -> f32 {
        // Name the f64 implementation explicitly instead of selecting it through a
        // type-annotated binding; the narrowing cast is the only extra step.
        <Self as Distribution<f64>>::sample(self, rand) as f32
    }
}
/// Draws a standard normal `N(0, 1)` sample with the ziggurat method, driven by the
/// `ziggurat::ZIG_NORM_X` / `ziggurat::ZIG_NORM_F` tables, the unnormalized density `pdf`,
/// and the tail sampler `zero_case` defined below.
impl Distribution<f64> for StandardNormal {
    fn sample<R: Rng + ?Sized>(&self, rand: &mut Random<R>) -> f64 {
        /// Unnormalized standard-normal density `exp(-x^2 / 2)`; the `1/sqrt(2*pi)`
        /// factor is deliberately absent.
        // NOTE(review): presumably `ZIG_NORM_F` is tabulated against this same
        // unnormalized form — confirm in the `ziggurat` module.
        #[inline]
        fn pdf(x: f64) -> f64 {
            (-x * x / 2.0).exp()
        }
        /// Samples the tail beyond `ZIG_NORM_R` by rejection: draw `x = ln(u1) / R` and
        /// `y = ln(u2)` until `-2y >= x^2`, then return `R - x` (or `x - R` when `u` is
        /// negative), so the result takes the sign of `u`.
        // NOTE(review): the magnitude is `>= R` only if `x <= 0`, i.e. assuming `float01`
        // never exceeds 1 and `ZIG_NORM_R` is positive — neither is checked here.
        #[inline]
        fn zero_case<R: Rng + ?Sized>(rand: &mut Random<R>, u: f64) -> f64 {
            // Seed values that make the loop condition true on entry (`-0.0 < 1.0`),
            // emulating a do-while; the first iteration overwrites both.
            let mut x = 1.0f64;
            let mut y = 0.0f64;
            while -2.0 * y < x * x {
                // NOTE(review): if `float01` can return exactly 0.0, `ln` yields -inf
                // below; if both draws are 0.0 the loop exits with `x == -inf` and the
                // sample is infinite — confirm whether `float01` excludes 0.
                let x_: f64 = rand.float01();
                let y_: f64 = rand.float01();
                x = x_.ln() / ziggurat::ZIG_NORM_R;
                y = y_.ln();
            }
            if u < 0.0 {
                x - ziggurat::ZIG_NORM_R
            }
            else {
                ziggurat::ZIG_NORM_R - x
            }
        }
        // NOTE(review): the `true` argument presumably selects the symmetric (two-sided)
        // ziggurat variant — confirm against `ziggurat::ziggurat`'s parameter list.
        ziggurat::ziggurat(
            rand,
            true, &ziggurat::ZIG_NORM_X,
            &ziggurat::ZIG_NORM_F,
            pdf,
            zero_case,
        )
    }
}
/// Errors returned when constructing a `Normal` or `LogNormal` distribution.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum NormalError {
    /// The mean passed to `LogNormal::try_from_mean_cv` is `<= 0` or NaN.
    MeanTooSmall,
    /// A spread parameter is unusable: a standard deviation / `sigma` that is not finite,
    /// or a coefficient of variation that is negative or not finite.
    BadVariance,
}

impl fmt::Display for NormalError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The messages name the conditions that are actually rejected:
        // - `MeanTooSmall` is raised by a `!(mean > 0)` check, so zero is included.
        // - `BadVariance` also covers a negative coefficient of variation.
        f.write_str(match self {
            NormalError::MeanTooSmall => "mean <= 0 or NaN in log-normal distribution",
            NormalError::BadVariance => {
                "variation parameter is negative or non-finite in (log)normal distribution"
            }
        })
    }
}

#[cfg(feature = "std")]
impl std::error::Error for NormalError {}
/// Float-type-specific construction logic shared by `Normal` and `LogNormal`.
///
/// The `f32` and `f64` implementations are generated by the `impl_normal!` and
/// `impl_log_normal!` macros in this module. The inherent constructors on both
/// distribution types forward to these methods.
pub trait NormalImpl<Float>: Sized {
    /// Builds the distribution from its location and scale parameters
    /// (`mean`/`std_dev` for `Normal`, `mu`/`sigma` for `LogNormal`).
    fn try_new(location: Float, scale: Float) -> Result<Self, NormalError>;

    /// Builds the distribution from a mean and a coefficient of variation.
    fn try_from_mean_cv(mean: Float, cv: Float) -> Result<Self, NormalError>;

    /// Maps a standard-normal z-score onto this distribution.
    fn from_zscore(&self, zscore: Float) -> Float;
}
/// A normal (Gaussian) distribution with location `mean` and scale `std_dev`.
///
/// Construct with `Normal::try_new` / `Normal::new`, or from a mean and coefficient of
/// variation with `Normal::try_from_mean_cv` / `Normal::from_mean_cv`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Normal<Float> {
    /// Location (μ).
    mean: Float,
    /// Scale (σ). `try_new` rejects non-finite values; negative values are not rejected.
    std_dev: Float,
}
/// Constructors, accessors and z-score mapping for `Normal`, available for every float
/// type with a `NormalImpl` implementation (`f32` and `f64` in this module).
impl<Float: Copy> Normal<Float>
where
    Self: NormalImpl<Float>,
{
    /// Creates a normal distribution with the given `mean` and `std_dev`.
    ///
    /// # Errors
    /// For the `f32`/`f64` implementations in this module, returns
    /// `NormalError::BadVariance` when `std_dev` is not finite.
    #[inline]
    pub fn try_new(mean: Float, std_dev: Float) -> Result<Self, NormalError> {
        <Self as NormalImpl<Float>>::try_new(mean, std_dev)
    }

    /// Panicking counterpart of `Normal::try_new`.
    ///
    /// # Panics
    /// When `Normal::try_new` would return an error. Thanks to `#[track_caller]`, the
    /// panic is reported at the caller's location.
    #[track_caller]
    #[inline]
    pub fn new(mean: Float, std_dev: Float) -> Self {
        let checked = <Self as NormalImpl<Float>>::try_new(mean, std_dev);
        checked.unwrap()
    }

    /// Creates a normal distribution from `mean` and a coefficient of variation `cv`,
    /// so that `std_dev = cv * mean`.
    ///
    /// # Errors
    /// For the `f32`/`f64` implementations in this module, returns
    /// `NormalError::BadVariance` when `cv` is negative or not finite.
    #[inline]
    pub fn try_from_mean_cv(mean: Float, cv: Float) -> Result<Self, NormalError> {
        <Self as NormalImpl<Float>>::try_from_mean_cv(mean, cv)
    }

    /// Panicking counterpart of `Normal::try_from_mean_cv`.
    ///
    /// # Panics
    /// When `Normal::try_from_mean_cv` would return an error. The panic is reported at
    /// the caller's location.
    #[track_caller]
    #[inline]
    pub fn from_mean_cv(mean: Float, cv: Float) -> Self {
        let checked = <Self as NormalImpl<Float>>::try_from_mean_cv(mean, cv);
        checked.unwrap()
    }

    /// Returns the location parameter (μ).
    #[inline]
    pub fn mean(&self) -> Float {
        self.mean
    }

    /// Returns the scale parameter (σ).
    #[inline]
    pub fn std_dev(&self) -> Float {
        self.std_dev
    }

    /// Maps a standard-normal z-score to this distribution (`mean + std_dev * zscore` in
    /// the `f32`/`f64` implementations).
    #[inline]
    pub fn from_zscore(&self, zscore: Float) -> Float {
        <Self as NormalImpl<Float>>::from_zscore(self, zscore)
    }
}
// Implements `NormalImpl` and `Distribution` for `Normal<$ty>`, where `$ty` is a
// primitive float type.
macro_rules! impl_normal {
    ($ty:ty) => {
        impl NormalImpl<$ty> for Normal<$ty> {
            // Only the standard deviation is validated, and it must be finite. Any mean
            // (including NaN or ±inf) and negative standard deviations are accepted.
            #[inline]
            fn try_new(mean: $ty, std_dev: $ty) -> Result<Normal<$ty>, NormalError> {
                if !std_dev.is_finite() {
                    return Err(NormalError::BadVariance);
                }
                Ok(Normal { mean, std_dev })
            }

            // `std_dev = cv * mean`. `cv` must be finite and non-negative.
            #[inline]
            fn try_from_mean_cv(mean: $ty, cv: $ty) -> Result<Normal<$ty>, NormalError> {
                if !cv.is_finite() || cv < 0.0 {
                    return Err(NormalError::BadVariance);
                }
                // Go through `try_new` instead of building the struct directly, so the
                // derived standard deviation meets the same finiteness check. `cv * mean`
                // can overflow to ±inf for a very large `mean`. It is NaN when `mean` is
                // NaN, or when `mean` is infinite and `cv == 0`.
                <Normal<$ty> as NormalImpl<$ty>>::try_new(mean, cv * mean)
            }

            // Maps a standard-normal z-score to `mean + std_dev * zscore`, computed as a
            // fused multiply-add.
            #[inline]
            fn from_zscore(&self, zscore: $ty) -> $ty {
                self.std_dev.mul_add(zscore, self.mean)
            }
        }

        impl Distribution<$ty> for Normal<$ty> {
            // Scale and shift a standard-normal draw.
            fn sample<R: Rng + ?Sized>(&self, rand: &mut Random<R>) -> $ty {
                self.from_zscore(StandardNormal.sample(rand))
            }
        }
    };
}
impl_normal!(f32);
impl_normal!(f64);
/// A log-normal distribution: samples are `exp` of draws from the wrapped normal.
///
/// Construct with `LogNormal::try_new` / `LogNormal::new` from the parameters of the
/// underlying normal, or with `LogNormal::try_from_mean_cv` / `LogNormal::from_mean_cv`
/// from the log-normal variable's own mean and coefficient of variation.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogNormal<Float> {
    /// Distribution of `ln(X)`, with location `mu` and scale `sigma`.
    norm: Normal<Float>,
}
/// Constructors and z-score mapping for `LogNormal`, available for every float type with a
/// `NormalImpl` implementation (`f32` and `f64` in this module).
impl<Float: Copy> LogNormal<Float>
where
    Self: NormalImpl<Float>,
{
    /// Creates a log-normal distribution whose logarithm has location `mu` and scale `sigma`.
    ///
    /// # Errors
    /// For the `f32`/`f64` implementations in this module, returns
    /// `NormalError::BadVariance` when `sigma` is not finite.
    #[inline]
    pub fn try_new(mu: Float, sigma: Float) -> Result<Self, NormalError> {
        <Self as NormalImpl<Float>>::try_new(mu, sigma)
    }

    /// Panicking counterpart of `LogNormal::try_new`.
    ///
    /// # Panics
    /// When `LogNormal::try_new` would return an error. The panic is reported at the
    /// caller's location.
    #[track_caller]
    #[inline]
    pub fn new(mu: Float, sigma: Float) -> Self {
        let checked = <Self as NormalImpl<Float>>::try_new(mu, sigma);
        checked.unwrap()
    }

    /// Creates a log-normal distribution from the mean and coefficient of variation of the
    /// log-normal variable itself (not of its logarithm).
    ///
    /// # Errors
    /// Returns `NormalError::MeanTooSmall` for an unusable `mean` and
    /// `NormalError::BadVariance` for an unusable `cv`. The exact conditions are in the
    /// `f32`/`f64` implementations.
    #[inline]
    pub fn try_from_mean_cv(mean: Float, cv: Float) -> Result<Self, NormalError> {
        <Self as NormalImpl<Float>>::try_from_mean_cv(mean, cv)
    }

    /// Panicking counterpart of `LogNormal::try_from_mean_cv`.
    ///
    /// # Panics
    /// When `LogNormal::try_from_mean_cv` would return an error. The panic is reported at
    /// the caller's location.
    #[track_caller]
    #[inline]
    pub fn from_mean_cv(mean: Float, cv: Float) -> Self {
        let checked = <Self as NormalImpl<Float>>::try_from_mean_cv(mean, cv);
        checked.unwrap()
    }

    /// Maps a standard-normal z-score to this distribution (`exp(mu + sigma * zscore)` in
    /// the `f32`/`f64` implementations).
    #[inline]
    pub fn from_zscore(&self, zscore: Float) -> Float {
        <Self as NormalImpl<Float>>::from_zscore(self, zscore)
    }
}
// Implements `NormalImpl` and `Distribution` for `LogNormal<$ty>`, where `$ty` is a
// primitive float type.
macro_rules! impl_log_normal {
    ($ty:ty) => {
        impl NormalImpl<$ty> for LogNormal<$ty> {
            // `mu`/`sigma` are the location/scale of `ln(X)`. Validation is delegated to
            // `Normal::try_new`, so `sigma` must be finite.
            #[inline]
            fn try_new(mu: $ty, sigma: $ty) -> Result<LogNormal<$ty>, NormalError> {
                let norm = Normal::try_new(mu, sigma)?;
                Ok(LogNormal { norm })
            }

            // Parameterizes by the log-normal variable's own mean and coefficient of
            // variation:
            //   sigma^2 = ln(1 + cv^2),  mu = ln(mean) - sigma^2 / 2.
            // Errors:
            //   `MeanTooSmall` - mean negative or NaN, or zero with `cv != 0`.
            //   `BadVariance`  - `cv` negative or NaN, or large enough that `sigma` is
            //                    not finite.
            #[inline]
            fn try_from_mean_cv(mean: $ty, cv: $ty) -> Result<LogNormal<$ty>, NormalError> {
                if cv == 0.0 {
                    // Degenerate case: zero spread around ln(mean). `mean == 0` stays
                    // valid: mu = -inf, so every sample is exp(-inf) = 0. A negative or
                    // NaN mean would make mu NaN and every sample NaN, so reject it.
                    if !(mean >= 0.0) {
                        return Err(NormalError::MeanTooSmall);
                    }
                    let norm = Normal::try_new(mean.ln(), 0.0)?;
                    return Ok(LogNormal { norm });
                }
                if !(mean > 0.0) {
                    return Err(NormalError::MeanTooSmall);
                }
                if !(cv >= 0.0) {
                    return Err(NormalError::BadVariance);
                }
                let a = 1.0 + cv * cv;
                let sigma_sq = a.ln();
                // Equivalent to `0.5 * ln(mean * mean / a)`, but avoids squaring `mean`.
                // That square overflows for f32 means above ~1.8e19 (mu = +inf) and
                // underflows for means below ~1e-19 (precision loss, eventually -inf).
                let mu = mean.ln() - 0.5 * sigma_sq;
                let sigma = sigma_sq.sqrt();
                let norm = Normal::try_new(mu, sigma)?;
                Ok(LogNormal { norm })
            }

            // exp(mu + sigma * zscore)
            #[inline]
            fn from_zscore(&self, zscore: $ty) -> $ty {
                self.norm.from_zscore(zscore).exp()
            }
        }

        impl Distribution<$ty> for LogNormal<$ty> {
            // Exponentiate a draw from the underlying normal.
            fn sample<R: Rng + ?Sized>(&self, rand: &mut Random<R>) -> $ty {
                self.norm.sample(rand).exp()
            }
        }
    };
}
impl_log_normal!(f32);
impl_log_normal!(f64);
// Unit tests for this module are declared out-of-line (a sibling `tests` module file)
// and compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;