use crate::utils::ziggurat;
use num_traits::Float;
use crate::{ziggurat_tables, Distribution, Open01};
use rand::Rng;
use core::fmt;
/// The standard normal distribution, `N(0, 1)`.
///
/// `f64` samples come from the ziggurat sampler implemented below; `f32`
/// samples are an `f64` sample narrowed with an `as` cast.
#[derive(Debug, Clone, Copy)]
pub struct StandardNormal;

impl Distribution<f32> for StandardNormal {
    /// Draws an `f64` standard-normal sample and rounds it to `f32`.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
        <Self as Distribution<f64>>::sample(self, rng) as f32
    }
}
/// Samples `N(0, 1)` as `f64` by handing the `ziggurat_tables` normal tables
/// (`ZIG_NORM_X`, `ZIG_NORM_F`, `ZIG_NORM_R`) to the shared `ziggurat` helper.
impl Distribution<f64> for StandardNormal {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        // Unnormalised standard-normal density, exp(-x^2 / 2). The 1/sqrt(2*pi)
        // factor is omitted; presumably `ziggurat` only needs the density up to
        // a constant factor — confirm against `utils::ziggurat`.
        #[inline]
        fn pdf(x: f64) -> f64 {
            (-x * x / 2.0).exp()
        }
        // Fallback sampler for the tail beyond `ZIG_NORM_R`, passed to
        // `ziggurat` (presumably used when a draw lands in the base layer —
        // confirm against `utils::ziggurat`).
        //
        // Repeatedly draws two `Open01` uniforms u1, u2, sets x = ln(u1) / R and
        // y = ln(u2), and stops once -2y >= x^2. This appears to be Marsaglia's
        // normal-tail rejection method. Assuming `Open01` yields values in
        // (0, 1) and R > 0, x is negative, so `R - x` lies beyond R. The sign
        // of `u` chooses which tail (positive or negative) is returned.
        #[inline]
        fn zero_case<R: Rng + ?Sized>(rng: &mut R, u: f64) -> f64 {
            // Seed values make the loop condition true (0 < 1), so the body
            // always runs at least once.
            let mut x = 1.0f64;
            let mut y = 0.0f64;
            while -2.0 * y < x * x {
                // Draw order (x_ first, then y_) determines the consumed RNG
                // stream; keep it fixed for reproducible output.
                let x_: f64 = rng.sample(Open01);
                let y_: f64 = rng.sample(Open01);
                x = x_.ln() / ziggurat_tables::ZIG_NORM_R;
                y = y_.ln();
            }
            if u < 0.0 {
                x - ziggurat_tables::ZIG_NORM_R
            } else {
                ziggurat_tables::ZIG_NORM_R - x
            }
        }
        ziggurat(
            rng,
            // NOTE(review): `true` presumably selects the symmetric (signed)
            // variant of the ziggurat — confirm against `utils::ziggurat`.
            true, &ziggurat_tables::ZIG_NORM_X,
            &ziggurat_tables::ZIG_NORM_F,
            pdf,
            zero_case,
        )
    }
}
/// The normal (Gaussian) distribution `N(mean, std_dev^2)`.
///
/// Samples are produced as `mean + std_dev * z`, with `z` drawn from
/// [`StandardNormal`].
#[derive(Debug, Clone, Copy)]
pub struct Normal<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
{
    /// Location parameter `μ`.
    mean: F,
    /// Scale parameter `σ`; the constructors do not validate its sign.
    std_dev: F,
}
/// Error returned by the [`Normal`] and [`LogNormal`] constructors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The mean was rejected by a log-normal constructor
    /// (`LogNormal::from_mean_cv`).
    MeanTooSmall,
    /// A dispersion parameter (standard deviation, `sigma`, or coefficient of
    /// variation) was rejected, e.g. because it is not finite.
    BadVariance,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Error::MeanTooSmall => "mean < 0 or NaN in log-normal distribution",
            Error::BadVariance => "variation parameter is non-finite in (log)normal distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
impl<F> Normal<F>
where F: Float, StandardNormal: Distribution<F>
{
#[inline]
pub fn new(mean: F, std_dev: F) -> Result<Normal<F>, Error> {
if !std_dev.is_finite() {
return Err(Error::BadVariance);
}
Ok(Normal { mean, std_dev })
}
#[inline]
pub fn from_mean_cv(mean: F, cv: F) -> Result<Normal<F>, Error> {
if !cv.is_finite() || cv < F::zero() {
return Err(Error::BadVariance);
}
let std_dev = cv * mean;
Ok(Normal { mean, std_dev })
}
#[inline]
pub fn from_zscore(&self, zscore: F) -> F {
self.mean + self.std_dev * zscore
}
pub fn mean(&self) -> F {
self.mean
}
pub fn std_dev(&self) -> F {
self.std_dev
}
}
impl<F> Distribution<F> for Normal<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
{
    /// Draws `z ~ N(0, 1)` and maps it through [`Normal::from_zscore`].
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        let z: F = rng.sample(StandardNormal);
        self.from_zscore(z)
    }
}
/// The log-normal distribution: samples are `exp(X)` where `X` follows the
/// wrapped [`Normal`] distribution.
#[derive(Debug, Clone, Copy)]
pub struct LogNormal<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
{
    /// Distribution of the logarithm of the samples (log-space `μ` and `σ`).
    norm: Normal<F>,
}
impl<F> LogNormal<F>
where F: Float, StandardNormal: Distribution<F>
{
#[inline]
pub fn new(mu: F, sigma: F) -> Result<LogNormal<F>, Error> {
let norm = Normal::new(mu, sigma)?;
Ok(LogNormal { norm })
}
#[inline]
pub fn from_mean_cv(mean: F, cv: F) -> Result<LogNormal<F>, Error> {
if cv == F::zero() {
let mu = mean.ln();
let norm = Normal::new(mu, F::zero()).unwrap();
return Ok(LogNormal { norm });
}
if !(mean > F::zero()) {
return Err(Error::MeanTooSmall);
}
if !(cv >= F::zero()) {
return Err(Error::BadVariance);
}
let a = F::one() + cv * cv; let mu = F::from(0.5).unwrap() * (mean * mean / a).ln();
let sigma = a.ln().sqrt();
let norm = Normal::new(mu, sigma)?;
Ok(LogNormal { norm })
}
#[inline]
pub fn from_zscore(&self, zscore: F) -> F {
self.norm.from_zscore(zscore).exp()
}
}
impl<F> Distribution<F> for LogNormal<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
{
    /// Samples the underlying normal (log space) and exponentiates it.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        let log_space: F = self.norm.sample(rng);
        log_space.exp()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normal() {
        let norm = Normal::new(10.0, 10.0).unwrap();
        let mut rng = crate::test::rng(210);
        for _ in 0..1000 {
            // Finite parameters must never produce NaN or infinite samples.
            let x: f64 = norm.sample(&mut rng);
            assert!(x.is_finite());
        }
    }

    #[test]
    fn test_normal_cv() {
        let norm = Normal::from_mean_cv(1024.0, 1.0 / 256.0).unwrap();
        assert_eq!((norm.mean, norm.std_dev), (1024.0, 4.0));
    }

    #[test]
    fn test_normal_invalid_sd() {
        assert_eq!(Normal::from_mean_cv(10.0, -1.0).unwrap_err(), Error::BadVariance);
        assert_eq!(
            Normal::from_mean_cv(10.0, core::f64::NAN).unwrap_err(),
            Error::BadVariance
        );
        assert_eq!(
            Normal::new(10.0, core::f64::INFINITY).unwrap_err(),
            Error::BadVariance
        );
        assert_eq!(Normal::new(10.0, core::f64::NAN).unwrap_err(), Error::BadVariance);
    }

    #[test]
    fn test_log_normal() {
        let lnorm = LogNormal::new(10.0, 10.0).unwrap();
        let mut rng = crate::test::rng(211);
        for _ in 0..1000 {
            // exp() of a finite, moderately sized normal sample is strictly
            // positive and finite.
            let x: f64 = lnorm.sample(&mut rng);
            assert!(x > 0.0 && x.is_finite());
        }
    }

    #[test]
    fn test_log_normal_cv() {
        let lnorm = LogNormal::from_mean_cv(0.0, 0.0).unwrap();
        assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (-core::f64::INFINITY, 0.0));
        let lnorm = LogNormal::from_mean_cv(1.0, 0.0).unwrap();
        assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (0.0, 0.0));
        let e = core::f64::consts::E;
        let lnorm = LogNormal::from_mean_cv(e.sqrt(), (e - 1.0).sqrt()).unwrap();
        assert_almost_eq!(lnorm.norm.mean, 0.0, 2e-16);
        assert_almost_eq!(lnorm.norm.std_dev, 1.0, 2e-16);
        let lnorm = LogNormal::from_mean_cv(e.powf(1.5), (e - 1.0).sqrt()).unwrap();
        assert_almost_eq!(lnorm.norm.mean, 1.0, 1e-15);
        assert_eq!(lnorm.norm.std_dev, 1.0);
    }

    #[test]
    fn test_log_normal_invalid_sd() {
        assert_eq!(LogNormal::from_mean_cv(-1.0, 1.0).unwrap_err(), Error::MeanTooSmall);
        assert_eq!(LogNormal::from_mean_cv(0.0, 1.0).unwrap_err(), Error::MeanTooSmall);
        assert_eq!(LogNormal::from_mean_cv(1.0, -1.0).unwrap_err(), Error::BadVariance);
        assert_eq!(
            LogNormal::from_mean_cv(1.0, core::f64::NAN).unwrap_err(),
            Error::BadVariance
        );
        assert_eq!(
            LogNormal::new(0.0, core::f64::INFINITY).unwrap_err(),
            Error::BadVariance
        );
    }
}