use rand_distr::Distribution as Distribution2;
use rand_distr::Normal as Normal2;
use crate::distributions::Distribution;
use rand::Rng;
/// A normal (Gaussian) distribution described by its mean and standard
/// deviation.
///
/// Values are built with [`Normal::new`], which validates the standard
/// deviation before constructing the struct.
#[derive(Debug, Clone, PartialEq)]
pub struct Normal {
    /// Mean of the distribution.
    mean: f64,
    /// Standard deviation; `new` only accepts values that are greater than
    /// zero and not NaN.
    std_dev: f64,
}

impl Normal {
    /// Creates a normal distribution with the given `mean` and `std_dev`.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when `std_dev` is zero, negative, or
    /// NaN.
    pub fn new(mean: f64, std_dev: f64) -> Result<Normal, String> {
        // Every ordered comparison with NaN is false, so `std_dev <= 0.0`
        // alone would let a NaN standard deviation through.
        if std_dev.is_nan() || std_dev <= 0.0 {
            return Err(format!(
                "Normal: illegal std_dev `{}` should be greater than 0",
                std_dev
            ));
        }
        Ok(Normal { mean, std_dev })
    }
}
/// Sampling and log-density evaluation for [`Normal`] over `f64` values.
impl<R: Rng + ?Sized> Distribution<R> for Normal {
    type Domain = f64;

    /// Draws one value by building a `rand_distr` normal from the stored
    /// parameters and sampling it with `rng`.
    ///
    /// Panics if `rand_distr` rejects the stored parameters.
    fn sample(&self, rng: &mut R) -> f64 {
        let backend = Normal2::new(self.mean, self.std_dev).unwrap();
        backend.sample(rng)
    }

    /// Natural log of the density at `x`:
    /// `-z^2 / 2 - ln(std_dev) - ln(2 * pi) / 2`, with
    /// `z = (x - mean) / std_dev`.
    fn log_prob(&self, x: &f64) -> f64 {
        let standardized = (x - self.mean) / self.std_dev;
        let quadratic = -0.5 * standardized * standardized;
        let log_scale = self.std_dev.ln();
        let log_norm = 0.5 * (2.0 * std::f64::consts::PI).ln();
        // Subtract in the order quadratic, scale, normaliser.
        quadratic - log_scale - log_norm
    }

    /// Always `false`: this distribution reports itself as not discrete.
    fn is_discrete(&self) -> bool {
        false
    }
}
impl std::fmt::Display for Normal {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Normal {{ mean = {}, std_dev = {} }}",
self.mean, self.std_dev
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::ThreadRng;
    use rand::thread_rng;

    /// The empirical mean of many draws must lie within five standard errors
    /// of the configured mean.
    #[test]
    fn normal_sample() {
        let (mean, std_dev) = (3.0f64, 2.0f64);
        let dist = Normal::new(mean, std_dev).unwrap();
        println!("dist = {}", dist);

        let mut rng = thread_rng();
        let trials: u32 = 10000;
        let total: f64 = (0..trials).map(|_| dist.sample(&mut rng)).sum();

        let draws = f64::from(trials);
        let empirical_mean = total / draws;
        let tolerance = 5.0 * std_dev / draws.sqrt();
        println!(
            "empirical mean is {} 5sigma error is {}",
            empirical_mean, tolerance
        );
        assert!((empirical_mean - mean).abs() < tolerance);
    }

    /// A standard normal has log-density `-ln(2 * pi) / 2` at zero and is
    /// reported as not discrete.
    #[test]
    fn normal_log_prob() {
        let standard = Normal::new(0.0, 1.0).unwrap();
        let expected = -0.5 * (2.0 * std::f64::consts::PI).ln();
        let actual = <Normal as Distribution<ThreadRng>>::log_prob(&standard, &0.0);
        assert!((actual - expected).abs() < 1e-10);
        assert!(!<Normal as Distribution<ThreadRng>>::is_discrete(&standard));
    }

    /// Unwrapping a distribution built with a zero standard deviation panics.
    #[test]
    #[should_panic]
    fn normal_zero_std() {
        let _ = Normal::new(0.0, 0.0).unwrap();
    }

    /// Unwrapping a distribution built with a negative standard deviation
    /// panics.
    #[test]
    #[should_panic]
    fn normal_negative_std() {
        let _ = Normal::new(0.0, -1.0).unwrap();
    }
}