use crate::error::{StatsError, StatsResult};
use crate::error_messages::{helpers, validation};
use crate::sampling::SampleableDistribution;
use crate::traits::{ContinuousDistribution, Distribution};
use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::random::{Distribution as RandDistribution, Normal as RandNormal};
/// Normal (Gaussian) distribution with location `loc` and scale `scale`.
///
/// NOTE(review): the parameter fields are `pub`, but `rand_distr` is built once in
/// `Normal::new`. Mutating `loc`/`scale` afterwards changes `pdf`/`cdf`/`ppf`/`mean`/`std`,
/// but `rvs` keeps sampling with the parameters given at construction.
pub struct Normal<F: Float> {
    /// Location parameter; the `Distribution::mean` impl returns it directly.
    pub loc: F,
    /// Scale parameter (standard deviation); `Distribution::std` returns it, and `new`
    /// rejects non-positive values via `validation::ensure_positive`.
    pub scale: F,
    /// `f64` sampler constructed from `loc` and `scale` in `new`; used only by `rvs`.
    rand_distr: RandNormal<f64>,
}
impl<F: Float + NumCast + std::fmt::Display> Normal<F> {
    /// Creates a normal distribution with location `loc` (mean) and scale `scale`
    /// (standard deviation).
    ///
    /// # Errors
    ///
    /// * The error produced by `validation::ensure_positive` when `scale` is not positive.
    /// * A numerical error when `loc` or `scale` cannot be converted to `f64`, or when the
    ///   `f64` sampler rejects the converted parameters.
    pub fn new(loc: F, scale: F) -> StatsResult<Self> {
        validation::ensure_positive(scale, "scale")?;
        let loc_f64 = <f64 as NumCast>::from(loc)
            .ok_or_else(|| helpers::numerical_error("failed to convert loc to f64"))?;
        let scale_f64 = <f64 as NumCast>::from(scale)
            .ok_or_else(|| helpers::numerical_error("failed to convert scale to f64"))?;
        // Sampling always happens in f64; `rvs` converts every draw back to `F`.
        let rand_distr = RandNormal::new(loc_f64, scale_f64)
            .map_err(|_| helpers::numerical_error("normal distribution creation"))?;
        Ok(Normal {
            loc,
            scale,
            rand_distr,
        })
    }

    /// Probability density at `x`: `exp(-z²/2) / (scale·√(2π))` with
    /// `z = (x - loc) / scale`.
    pub fn pdf(&self, x: F) -> F {
        let two = Self::constant(2.0);
        let pi = Self::constant(std::f64::consts::PI);
        let z = (x - self.loc) / self.scale;
        F::one() / (self.scale * (two * pi).sqrt()) * (-z * z / two).exp()
    }

    /// Cumulative distribution at `x`: `½·(1 + erf(z/√2))` with `z = (x - loc) / scale`.
    ///
    /// Accuracy is bounded by the polynomial `erf` approximation defined in this module.
    pub fn cdf(&self, x: F) -> F {
        let half = Self::constant(0.5);
        let z = (x - self.loc) / self.scale;
        // The erf approximation is not exactly zero at the origin (its coefficients sum
        // to 0.999999999), so return the exact median value there.
        if z == F::zero() {
            return half;
        }
        let two = Self::constant(2.0);
        half * (F::one() + erf(z / two.sqrt()))
    }

    /// Percent-point function (inverse CDF): the `x` for which `cdf(x) ≈ p`.
    ///
    /// Returns `-∞` for `p == 0` and `+∞` for `p == 1`. Interior values use Acklam's
    /// rational approximation of the standard-normal quantile (no refinement step),
    /// then scale by `scale` and shift by `loc`.
    ///
    /// # Errors
    ///
    /// `StatsError::DomainError` when `p` is NaN or lies outside `[0, 1]`.
    pub fn ppf(&self, p: F) -> StatsResult<F> {
        // `p < 0 || p > 1` is false for NaN, so NaN must be rejected explicitly;
        // otherwise it falls into the upper-tail branch and yields `Ok(NaN)`.
        if p.is_nan() || p < F::zero() || p > F::one() {
            return Err(StatsError::DomainError(
                "Probability must be between 0 and 1".to_string(),
            ));
        }
        if p == F::zero() {
            return Ok(F::neg_infinity());
        }
        if p == F::one() {
            return Ok(F::infinity());
        }
        Ok(Self::standard_normal_quantile(p) * self.scale + self.loc)
    }

    /// Draws `size` independent samples.
    ///
    /// Samples come from the `f64` sampler built in `new`, using the generator returned
    /// by `scirs2_core::random::thread_rng()` (presumably unseeded — confirm if
    /// reproducibility matters).
    ///
    /// # Errors
    ///
    /// A numerical error if a drawn `f64` cannot be converted to `F` (this used to panic).
    pub fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
        let mut rng = scirs2_core::random::thread_rng();
        let samples = (0..size)
            .map(|_| {
                let draw = self.rand_distr.sample(&mut rng);
                F::from(draw)
                    .ok_or_else(|| helpers::numerical_error("failed to convert sample to float"))
            })
            .collect::<StatsResult<Vec<F>>>()?;
        Ok(Array1::from(samples))
    }

    /// Acklam's approximation of the standard-normal quantile for `0 < p < 1`.
    fn standard_normal_quantile(p: F) -> F {
        // Central region: numerator (A) and denominator (B, trailing 1) coefficients.
        const A: [f64; 6] = [
            -3.969683028665376e+01,
            2.209460984245205e+02,
            -2.759285104469687e+02,
            1.383577518672690e+02,
            -3.066479806614716e+01,
            2.506628277459239e+00,
        ];
        const B: [f64; 6] = [
            -5.447609879822406e+01,
            1.615858368580409e+02,
            -1.556989798598866e+02,
            6.680131188771972e+01,
            -1.328068155288572e+01,
            1.0,
        ];
        // Tail regions: numerator (C) and denominator (D, trailing 1) coefficients.
        const C: [f64; 6] = [
            -7.784894002430293e-03,
            -3.223964580411365e-01,
            -2.400758277161838e+00,
            -2.549732539343734e+00,
            4.374664141464968e+00,
            2.938163982698783e+00,
        ];
        const D: [f64; 5] = [
            7.784695709041462e-03,
            3.224671290700398e-01,
            2.445134137142996e+00,
            3.754408661907416e+00,
            1.0,
        ];
        // Boundary between the lower tail and the central region.
        const P_LOW: f64 = 0.02425;

        let p_low = Self::constant(P_LOW);
        let p_high = F::one() - p_low;
        let two = Self::constant(2.0);
        if p < p_low {
            let q = (-two * p.ln()).sqrt();
            Self::horner(&C, q) / Self::horner(&D, q)
        } else if p <= p_high {
            let q = p - Self::constant(0.5);
            let r = q * q;
            Self::horner(&A, r) * q / Self::horner(&B, r)
        } else {
            // Upper tail mirrors the lower tail via `1 - p`.
            let q = (-two * (F::one() - p).ln()).sqrt();
            -Self::horner(&C, q) / Self::horner(&D, q)
        }
    }

    /// Converts an `f64` constant to `F`, falling back to zero if it is not representable.
    fn constant(value: f64) -> F {
        F::from(value).unwrap_or_else(F::zero)
    }

    /// Evaluates the polynomial whose coefficients are given highest degree first at `x`
    /// using Horner's rule.
    fn horner(coeffs: &[f64], x: F) -> F {
        coeffs
            .iter()
            .fold(F::zero(), |acc, &c| acc * x + Self::constant(c))
    }
}
/// Approximates the error function `erf(x)`.
///
/// The coefficients match the rational approximation of Abramowitz & Stegun
/// formula 7.1.26:
/// `erf(x) ≈ 1 - (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵)·exp(-x²)` with
/// `t = 1 / (1 + p·x)`. Negative arguments are handled through the odd symmetry
/// `erf(-x) = -erf(x)`, so the polynomial is only evaluated for `x >= 0`.
///
/// # Panics
///
/// Panics if one of the `f64` constants cannot be converted to `F`.
// NOTE(review): `erf` is called by `Normal::cdf`, so this allow looks unnecessary — confirm.
#[allow(dead_code)]
fn erf<F: Float>(x: F) -> F {
    if x < F::zero() {
        return -erf(-x);
    }
    let coef = |value: f64| F::from(value).expect("Failed to convert constant to float");
    let t = F::one() / (F::one() + coef(0.3275911) * x);
    // Horner's rule from a5 down to a1: ((((a5·t + a4)·t + a3)·t + a2)·t + a1).
    let poly = [-1.453152027, 1.421413741, -0.284496736, 0.254829592]
        .iter()
        .fold(coef(1.061405429), |acc, &c| acc * t + coef(c));
    F::one() - poly * t * (-x * x).exp()
}
impl<F: Float + NumCast + std::fmt::Display> Distribution<F> for Normal<F> {
    /// Mean of the distribution, i.e. `loc`.
    fn mean(&self) -> F {
        self.loc
    }

    /// Variance of the distribution, `scale²`.
    fn var(&self) -> F {
        self.scale * self.scale
    }

    /// Standard deviation of the distribution, i.e. `scale`.
    fn std(&self) -> F {
        self.scale
    }

    /// Draws `size` samples by delegating to the inherent `Normal::rvs`.
    fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
        Normal::rvs(self, size)
    }

    /// Differential entropy in nats: `½·ln(2πe·σ²) = ½·ln(2πe) + ln σ`.
    ///
    /// The previous implementation returned `½ + ½·ln(2πe·σ²)`, which counted the
    /// `½` contributed by `e` twice (≈1.919 instead of ≈1.419 for σ = 1).
    fn entropy(&self) -> F {
        let half = F::from(0.5).expect("Failed to convert constant to float");
        let two = F::from(2.0).expect("Failed to convert constant to float");
        let pi = F::from(std::f64::consts::PI).expect("Failed to convert to float");
        let e = F::from(std::f64::consts::E).expect("Failed to convert to float");
        // Taking `ln σ` separately avoids overflowing `σ²` for very large scales;
        // `abs` keeps the sign-insensitivity of the σ² form, since `scale` is a pub field.
        half * (two * pi * e).ln() + self.scale.abs().ln()
    }
}
/// `ContinuousDistribution` view of [`Normal`]: every method forwards to the inherent
/// method of the same name.
impl<F: Float + NumCast + std::fmt::Display> ContinuousDistribution<F> for Normal<F> {
    /// Density at `x`; forwards to the inherent `Normal::pdf`.
    fn pdf(&self, x: F) -> F {
        Normal::<F>::pdf(self, x)
    }

    /// Cumulative probability at `x`; forwards to the inherent `Normal::cdf`.
    fn cdf(&self, x: F) -> F {
        Normal::<F>::cdf(self, x)
    }

    /// Quantile for probability `p`; forwards to the inherent `Normal::ppf`.
    fn ppf(&self, p: F) -> StatsResult<F> {
        Normal::<F>::ppf(self, p)
    }
}
/// `SampleableDistribution` view of [`Normal`], returning samples as a `Vec`.
impl<F: Float + NumCast + std::fmt::Display> SampleableDistribution<F> for Normal<F> {
    /// Draws `size` samples via the inherent `Normal::rvs` and copies them into a `Vec`.
    fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
        Normal::<F>::rvs(self, size).map(|draws| draws.to_vec())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Standard normal N(0, 1) used by most tests.
    fn standard() -> Normal<f64> {
        Normal::new(0.0, 1.0).expect("Operation failed")
    }

    #[test]
    fn test_normal_creation() {
        let valid: [(f64, f64); 2] = [(0.0, 1.0), (5.0, 2.0)];
        for &(loc, scale) in &valid {
            let dist = Normal::new(loc, scale).expect("Operation failed");
            assert_eq!(dist.loc, loc);
            assert_eq!(dist.scale, scale);
        }
        // Zero and negative scales must be rejected.
        for &bad_scale in &[0.0, -1.0] {
            assert!(Normal::<f64>::new(0.0, bad_scale).is_err());
        }
    }

    #[test]
    fn test_normal_pdf() {
        let norm = standard();
        let expected: [(f64, f64); 3] = [(0.0, 0.3989423), (1.0, 0.2419707), (-1.0, 0.2419707)];
        for &(x, density) in &expected {
            assert_relative_eq!(norm.pdf(x), density, epsilon = 1e-7);
        }
        let shifted = Normal::<f64>::new(5.0, 2.0).expect("Operation failed");
        assert_relative_eq!(shifted.pdf(5.0), 0.19947114, epsilon = 1e-7);
    }

    #[test]
    fn test_normal_cdf() {
        let norm = standard();
        assert_relative_eq!(norm.cdf(0.0), 0.5, epsilon = 1e-7);
        let expected: [(f64, f64); 2] = [(1.0, 0.8413447), (-1.0, 0.1586553)];
        for &(x, prob) in &expected {
            assert_relative_eq!(norm.cdf(x), prob, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_normal_ppf() {
        let norm = standard();
        let median = norm.ppf(0.5).expect("Operation failed");
        assert_relative_eq!(median, 0.0, epsilon = 1e-5);
        let expected: [(f64, f64); 2] = [(0.975, 1.96), (0.025, -1.96)];
        for &(p, quantile) in &expected {
            let got = norm.ppf(p).expect("Operation failed");
            assert_relative_eq!(got, quantile, epsilon = 1e-2);
        }
        for &out_of_range in &[-0.1, 1.1] {
            assert!(norm.ppf(out_of_range).is_err());
        }
    }

    #[test]
    fn test_normal_rvs() {
        let draws = standard().rvs(1000).expect("Operation failed");
        assert_eq!(draws.len(), 1000);
        let count = draws.len() as f64;
        let mean = draws.iter().sum::<f64>() / count;
        assert!(
            mean.abs() < 0.15,
            "Sample mean {} is outside expected range",
            mean
        );
        let variance = draws
            .iter()
            .map(|&v| (v - mean) * (v - mean))
            .sum::<f64>()
            / count;
        let std_dev = variance.sqrt();
        assert!(
            (std_dev - 1.0).abs() < 0.15,
            "Sample std dev {} is outside expected range",
            std_dev
        );
    }
}