use crate::prelude::*;
use statrs::distribution::Normal as StatrsNormal;
/// A normal distribution bounded to `[min, max]`.
///
/// The underlying normal is centred on the midpoint of the range with standard
/// deviation `(max - min) / 6`. Draws and quantiles of the underlying normal are
/// clamped into the range, so values that would fall outside it land on the
/// nearest bound instead.
#[derive(Debug, Clone)]
pub struct Normal {
    /// Unbounded `statrs` normal that supplies densities, probabilities and draws.
    inner: StatrsNormal,
    /// Lower bound of the support.
    min: f64,
    /// Midpoint of `[min, max]`; reported as both the mode and the median.
    likely: f64,
    /// Upper bound of the support.
    max: f64,
}
impl Normal {
    /// Creates a normal distribution bounded to `[min, max]`.
    ///
    /// The underlying normal is centred on the midpoint of the range and uses
    /// `(max - min) / 6` as its standard deviation, so each bound sits three
    /// standard deviations from the mean.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidRange`] when either bound is NaN or infinite, when
    /// `min` is not strictly less than `max`, or when the derived mean or standard
    /// deviation is not a finite, positive-width value (the sum or difference of
    /// the bounds overflowed, or the width underflowed to zero). Errors from
    /// constructing the underlying `statrs` normal are propagated with `?`.
    pub fn new(min: f64, max: f64) -> Result<Self> {
        // Checking finiteness first also rejects NaN, which `min >= max` alone
        // would let through (every comparison with NaN is false).
        if !min.is_finite() || !max.is_finite() || min >= max {
            return Err(Error::InvalidRange { min, max });
        }
        let likely = (min + max) / 2.0;
        let std_dev = (max - min) / 6.0;
        // Finite bounds can still overflow in the sum/difference (huge magnitudes)
        // or underflow to a zero width (adjacent subnormals); either would yield a
        // distribution whose queries are meaningless.
        if !likely.is_finite() || !std_dev.is_finite() || std_dev <= 0.0 {
            return Err(Error::InvalidRange { min, max });
        }
        Ok(Normal {
            inner: StatrsNormal::new(likely, std_dev)?,
            min,
            likely,
            max,
        })
    }

    /// Restricts `x` to `[self.min, self.max]`.
    ///
    /// `f64::max` returns its other operand when one is NaN, so a NaN input maps
    /// to `self.min` (unlike `f64::clamp`, which would propagate the NaN).
    fn clamp(&self, x: f64) -> f64 {
        x.max(self.min).min(self.max)
    }
}
// Opts `Normal` into `EstimationDistribution`; no trait items are overridden here.
impl EstimationDistribution for Normal {}
impl Distribution<f64> for Normal {
    /// Mean of the underlying normal, which is the midpoint of `[min, max]`.
    ///
    /// Clamping is symmetric about this point, so the clamped draws share it.
    fn mean(&self) -> Option<f64> {
        let underlying = &self.inner;
        underlying.mean()
    }

    /// Variance of the underlying, unclamped normal.
    ///
    /// NOTE(review): draws are clamped towards the centre, so their variance is
    /// slightly smaller than the value reported here.
    fn variance(&self) -> Option<f64> {
        let underlying = &self.inner;
        underlying.variance()
    }

    /// Always zero: the bounds are placed symmetrically around the mean.
    fn skewness(&self) -> Option<f64> {
        let symmetric: f64 = 0.0;
        Some(symmetric)
    }

    /// Entropy of the underlying, unclamped normal.
    ///
    /// NOTE(review): this does not account for the clamping to `[min, max]`.
    fn entropy(&self) -> Option<f64> {
        let underlying = &self.inner;
        underlying.entropy()
    }
}
impl Median<f64> for Normal {
fn median(&self) -> f64 {
self.likely
}
}
impl Mode<f64> for Normal {
fn mode(&self) -> f64 {
self.likely
}
}
impl Continuous<f64, f64> for Normal {
    /// Density of the underlying normal inside `[min, max]`, and `0.0` outside it.
    ///
    /// The value is not renormalised for the bounds. A NaN `x` fails both range
    /// comparisons and is therefore passed to the underlying normal.
    fn pdf(&self, x: f64) -> f64 {
        let below = x < self.min;
        let above = x > self.max;
        if below || above {
            return 0.0;
        }
        self.inner.pdf(x)
    }

    /// Log-density matching [`Continuous::pdf`]: the underlying normal's log-density
    /// inside `[min, max]`, negative infinity outside it.
    fn ln_pdf(&self, x: f64) -> f64 {
        let below = x < self.min;
        let above = x > self.max;
        if below || above {
            return f64::NEG_INFINITY;
        }
        self.inner.ln_pdf(x)
    }
}
impl ContinuousCDF<f64, f64> for Normal {
    /// Probability that a draw from this distribution is `<= x`.
    ///
    /// Sampling and `inverse_cdf` clamp the underlying normal into `[min, max]`,
    /// so the underlying tail mass below `min` sits exactly on `min`. The CDF is
    /// therefore `0` strictly below `min` and jumps to `inner.cdf(min)` at `min`
    /// (right-continuous, consistent with `inverse_cdf`). It is `1` from `max` on.
    fn cdf(&self, x: f64) -> f64 {
        // Strict `<`: `x == min` must include the point mass clamped onto `min`.
        if x < self.min {
            0.0
        } else if x >= self.max {
            1.0
        } else {
            self.inner.cdf(x)
        }
    }

    /// Quantile function: the underlying normal's quantile clamped into `[min, max]`.
    ///
    /// NOTE(review): the unclamped quantile comes from `statrs`; confirm how the
    /// installed statrs version treats `p` outside `[0, 1]` (it may panic).
    fn inverse_cdf(&self, p: f64) -> f64 {
        self.clamp(self.inner.inverse_cdf(p))
    }
}
impl Min<f64> for Normal {
fn min(&self) -> f64 {
self.min
}
}
impl Max<f64> for Normal {
fn max(&self) -> f64 {
self.max
}
}
impl RandDistribution<f64> for Normal {
    /// Draws from the underlying normal and clamps the draw into `[min, max]`.
    ///
    /// Out-of-range draws are moved onto the nearest bound rather than rejected,
    /// so `min` and `max` each receive a small point mass.
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let raw = self.inner.sample(rng);
        self.clamp(raw)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use rand::distributions::Distribution as RandDistribution;
    use rand::rngs::StdRng;
    use rand::SeedableRng;
    use statrs::statistics::{Data, Distribution, OrderStatistics};

    /// Population skewness (third standardised central moment) of `xs`.
    fn skewness_of(xs: &[f64]) -> f64 {
        let len = xs.len() as f64;
        let mean = xs.iter().sum::<f64>() / len;
        let m2 = xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / len;
        let m3 = xs.iter().map(|x| (x - mean).powi(3)).sum::<f64>() / len;
        m3 / m2.powf(1.5)
    }

    #[test]
    fn test_new() {
        let normal = Normal::new(1.0, 3.0).unwrap();
        assert_eq!(normal.min(), 1.0);
        assert_eq!(normal.mode(), 2.0);
        assert_eq!(normal.max(), 3.0);
        assert_eq!(normal.mean().unwrap(), 2.0);
        assert_relative_eq!(normal.std_dev().unwrap(), 1.0 / 3.0, epsilon = 1e-6);
    }

    #[test]
    fn test_invalid_parameters() {
        assert!(Normal::new(3.0, 1.0).is_err());
    }

    #[test]
    fn test_pdf() {
        let normal = Normal::new(1.0, 3.0).unwrap();
        assert_relative_eq!(normal.pdf(2.0), normal.inner.pdf(2.0), epsilon = 1e-6);
        assert_eq!(normal.pdf(0.5), 0.0);
        assert_eq!(normal.pdf(3.5), 0.0);
    }

    #[test]
    fn test_cdf() {
        let normal = Normal::new(1.0, 3.0).unwrap();
        assert_relative_eq!(normal.cdf(2.0), 0.5, epsilon = 1e-6);
        assert_eq!(normal.cdf(0.5), 0.0);
        assert_eq!(normal.cdf(3.5), 1.0);
    }

    #[test]
    fn test_inverse_cdf() {
        let normal = Normal::new(1.0, 3.0).unwrap();
        assert_relative_eq!(normal.inverse_cdf(0.5), 2.0, epsilon = 1e-6);
        assert_eq!(normal.inverse_cdf(0.0), 1.0);
        assert_eq!(normal.inverse_cdf(1.0), 3.0);
    }

    #[test]
    fn test_sampling() {
        let normal = Normal::new(1.0, 3.0).unwrap();
        let mut rng = StdRng::seed_from_u64(42);
        let samples: Vec<f64> = (0..1000).map(|_| normal.sample(&mut rng)).collect();
        assert!(samples.iter().all(|&x| (1.0..=3.0).contains(&x)));
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert_relative_eq!(mean, 2.0, epsilon = 0.1);
    }

    #[test]
    fn test_distribution_statistics() {
        let test_cases = vec![
            (1.0, 2.0, 3.0),
            (0.0, 5.0, 10.0),
            (-5.0, 0.0, 5.0),
            (100.0, 200.0, 300.0),
        ];
        for (min, likely, max) in test_cases {
            let normal = Normal::new(min, max).unwrap();
            let n = 100_000;
            let mut rng = StdRng::seed_from_u64(42);
            let samples: Vec<f64> = (0..n).map(|_| normal.sample(&mut rng)).collect();
            let mut data = Data::new(samples.clone());

            let sample_mean = data.mean().unwrap();
            let theoretical_mean = normal.mean().unwrap();
            assert_relative_eq!(
                sample_mean,
                theoretical_mean,
                epsilon = 0.01,
                max_relative = 0.01
            );

            let sample_variance = data.variance().unwrap();
            let theoretical_variance = normal.variance().unwrap();
            assert_relative_eq!(
                sample_variance,
                theoretical_variance,
                epsilon = 0.05,
                max_relative = 0.05
            );

            let sample_std_dev = data.std_dev().unwrap();
            let theoretical_std_dev = normal.std_dev().unwrap();
            assert_relative_eq!(
                sample_std_dev,
                theoretical_std_dev,
                epsilon = 0.05,
                max_relative = 0.05
            );

            // Estimate skewness from the draws instead of hard-coding 0.0, so this
            // actually checks that the clamped samples are symmetric.
            let sample_skewness = skewness_of(&samples);
            let theoretical_skewness = normal.skewness().unwrap();
            assert_relative_eq!(
                sample_skewness,
                theoretical_skewness,
                epsilon = 0.1,
                max_relative = 0.1
            );

            let sample_median = data.median();
            let theoretical_median = normal.median();
            assert_relative_eq!(
                sample_median,
                theoretical_median,
                epsilon = 0.01,
                max_relative = 0.01
            );

            assert!(normal.mode() >= min && normal.mode() <= max);

            let sample_min = data.min();
            let sample_max = data.max();
            assert!(sample_min >= normal.min());
            assert!(sample_max <= normal.max());

            let percentiles: [f64; 5] = [0.1, 0.25, 0.5, 0.75, 0.9];
            for &p in percentiles.iter() {
                // Round rather than truncate: `p * 100.0` can land just below an integer.
                let sample_percentile = data.percentile((p * 100.0).round() as usize);
                let theoretical_percentile = normal.inverse_cdf(p);
                assert_relative_eq!(
                    sample_percentile,
                    theoretical_percentile,
                    epsilon = 0.05,
                    max_relative = 0.05
                );

                let sample_cdf =
                    samples.iter().filter(|&x| x <= &sample_percentile).count() as f64 / n as f64;
                let theoretical_cdf = normal.cdf(sample_percentile);
                assert_relative_eq!(
                    sample_cdf,
                    theoretical_cdf,
                    epsilon = 0.01,
                    max_relative = 0.01
                );
            }

            // Histogram-style density estimate. The window scales with the standard
            // deviation, and densities are compared after multiplying by it, so the
            // tolerance means the same thing for every range. Raw densities of wide
            // ranges are far below the absolute `epsilon`, which made the check vacuous.
            let sigma = theoretical_std_dev;
            let half_width = 0.1 * sigma;
            let pdf_points = [min, (min + likely) / 2.0, likely, (likely + max) / 2.0, max];
            for &x in pdf_points.iter() {
                let in_window = samples.iter().filter(|&s| (s - x).abs() < half_width).count();
                let sample_pdf = in_window as f64 / (n as f64 * 2.0 * half_width);
                let theoretical_pdf = normal.pdf(x);
                assert_relative_eq!(
                    sample_pdf * sigma,
                    theoretical_pdf * sigma,
                    epsilon = 0.1,
                    max_relative = 0.1
                );
            }

            println!("Test passed for Normal({}, {}, {})", min, likely, max);
        }
    }
}