use num_traits::ToPrimitive;
use crate::distributions::traits::Distribution;
use crate::error::{StatsError, StatsResult};
use crate::prob::erf;
use crate::utils::constants::{INV_SQRT_2PI, SQRT_2};
use serde::{Deserialize, Serialize};
/// Parameters of a normal (Gaussian) distribution, generic over any numeric type
/// that can be converted to `f64`.
///
/// Build it with [`NormalConfig::new`] to have the parameters validated. The fields
/// are public, so a struct literal bypasses those checks.
///
/// The `ToPrimitive` requirement lives on the `impl` blocks rather than on the
/// struct definition. Naming `NormalConfig<T>` in a signature therefore does not
/// force callers to repeat the bound.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct NormalConfig<T> {
    /// Location parameter (mean).
    pub mean: T,
    /// Scale parameter (standard deviation).
    pub std_dev: T,
}
impl<T> NormalConfig<T>
where
    T: ToPrimitive,
{
    /// Creates a validated normal-distribution configuration.
    ///
    /// Both parameters are converted to `f64` only for validation. The original
    /// `T` values are stored unchanged.
    ///
    /// # Errors
    ///
    /// - [`StatsError::ConversionError`] if `std_dev` or `mean` cannot be converted
    ///   to `f64` (`std_dev` is converted first).
    /// - [`StatsError::InvalidInput`] if `std_dev` is not strictly positive (NaN
    ///   included), or if `mean` is NaN.
    pub fn new(mean: T, std_dev: T) -> StatsResult<Self> {
        let std_dev_64 = std_dev
            .to_f64()
            .ok_or_else(|| StatsError::ConversionError {
                message: "NormalConfig::new: Failed to convert std_dev to f64".to_string(),
            })?;
        let mean_64 = mean.to_f64().ok_or_else(|| StatsError::ConversionError {
            message: "NormalConfig::new: Failed to convert mean to f64".to_string(),
        })?;
        // Every comparison with NaN is false, so `<= 0.0` alone would let NaN through.
        if std_dev_64.is_nan() || std_dev_64 <= 0.0 {
            return Err(StatsError::InvalidInput {
                message: "NormalConfig::new: std_dev must be positive".to_string(),
            });
        }
        // Report a NaN mean as such rather than blaming std_dev.
        if mean_64.is_nan() {
            return Err(StatsError::InvalidInput {
                message: "NormalConfig::new: mean must not be NaN".to_string(),
            });
        }
        Ok(Self { mean, std_dev })
    }
}
/// Probability density function of the normal distribution N(`mean`, `std_dev`²).
///
/// Evaluates `exp(-z² / 2) * INV_SQRT_2PI / std_dev` with `z = (x - mean) / std_dev`.
///
/// # Errors
///
/// - [`StatsError::InvalidInput`] if `std_dev` is not strictly positive; NaN is
///   rejected as well.
/// - [`StatsError::ConversionError`] if `x` cannot be converted to `f64`.
#[inline]
pub fn normal_pdf<T>(x: T, mean: f64, std_dev: f64) -> StatsResult<f64>
where
    T: ToPrimitive,
{
    // A bare `std_dev <= 0.0` lets NaN through because every comparison with NaN is false.
    if std_dev.is_nan() || std_dev <= 0.0 {
        return Err(StatsError::InvalidInput {
            message: "normal_pdf: Standard deviation must be positive".to_string(),
        });
    }
    let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "normal_pdf: Failed to convert x to f64".to_string(),
    })?;
    let z = (x_64 - mean) / std_dev;
    let exponent = -0.5 * z * z;
    Ok(exponent.exp() * INV_SQRT_2PI / std_dev)
}
/// Cumulative distribution function of the normal distribution N(`mean`, `std_dev`²).
///
/// Computes `0.5 * (1 + erf((x - mean) / (std_dev * SQRT_2)))`. When `x` equals
/// `mean` exactly, `0.5` is returned directly without evaluating `erf`.
///
/// NOTE: for `x` far below `mean`, `erf` approaches -1 and `1 + erf(z)` loses
/// relative precision through cancellation. Deep lower-tail probabilities are
/// therefore only accurate in absolute terms.
///
/// # Errors
///
/// - [`StatsError::InvalidInput`] if `std_dev` is not strictly positive; NaN is
///   rejected as well.
/// - [`StatsError::ConversionError`] if `x` cannot be converted to `f64`.
/// - Any error propagated from `erf`.
#[inline]
pub fn normal_cdf<T>(x: T, mean: f64, std_dev: f64) -> StatsResult<f64>
where
    T: ToPrimitive,
{
    // A bare `std_dev <= 0.0` lets NaN through because every comparison with NaN is false.
    if std_dev.is_nan() || std_dev <= 0.0 {
        return Err(StatsError::InvalidInput {
            message: "normal_cdf: Standard deviation must be positive".to_string(),
        });
    }
    let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "normal_cdf: Failed to convert x to f64".to_string(),
    })?;
    // The distribution is symmetric about its mean, so the answer here is exact.
    if x_64 == mean {
        return Ok(0.5);
    }
    let z = (x_64 - mean) / (std_dev * SQRT_2);
    Ok(0.5 * (1.0 + erf(z)?))
}
/// Quantile function (inverse CDF) of the normal distribution N(`mean`, `std_dev`²).
///
/// Uses Peter Acklam's rational approximation of the standard-normal quantile
/// (published relative error ≈ 1.15e-9) and then rescales: `mean + std_dev * z`.
/// The approximation is always evaluated on the lower half, `q = min(p, 1 - p)`.
/// It is mirrored for `p > 0.5` via `Φ⁻¹(p) = -Φ⁻¹(1 - p)`.
///
/// Returns `-inf` for `p == 0` and `+inf` for `p == 1`.
///
/// Unlike [`normal_pdf`] and [`normal_cdf`], `std_dev` is not validated here. The
/// standard quantile is simply scaled by whatever value is passed.
///
/// # Errors
///
/// - [`StatsError::ConversionError`] if `p` cannot be converted to `f64`.
/// - [`StatsError::InvalidInput`] if `p` is NaN or outside `[0, 1]`.
#[inline]
pub fn normal_inverse_cdf<T>(p: T, mean: f64, std_dev: f64) -> StatsResult<f64>
where
    T: ToPrimitive,
{
    // Below this probability the tail approximation is used instead of the central one.
    const P_LOW: f64 = 0.02425;
    // Central region: numerator (A) and denominator (B) coefficients in r² = (q - 0.5)².
    const A: [f64; 6] = [
        -3.969_683_028_665_376e1,
        2.209_460_984_245_205e2,
        -2.759_285_104_469_687e2,
        1.383_577_518_672_69e2,
        -3.066_479_806_614_716e1,
        2.506_628_277_459_239,
    ];
    const B: [f64; 6] = [
        -5.447_609_879_822_406e1,
        1.615_858_368_580_409e2,
        -1.556_989_798_598_866e2,
        6.680_131_188_771_972e1,
        -1.328_068_155_288_572e1,
        1.0,
    ];
    // Tail region: numerator (C) and denominator (D) coefficients in t = sqrt(-2 ln q).
    const C: [f64; 6] = [
        -7.784_894_002_430_293e-3,
        -3.223_964_580_411_365e-1,
        -2.400_758_277_161_838,
        -2.549_732_539_343_734,
        4.374_664_141_464_968,
        2.938_163_982_698_783,
    ];
    const D: [f64; 5] = [
        7.784_695_709_041_462e-3,
        3.224_671_290_700_398e-1,
        2.445_134_137_142_996,
        3.754_408_661_907_416,
        1.0,
    ];

    let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "normal_inverse_cdf: Failed to convert p to f64".to_string(),
    })?;
    // `contains` is false for NaN, so NaN is rejected here too.
    if !(0.0..=1.0).contains(&p_64) {
        return Err(StatsError::InvalidInput {
            message: "normal_inverse_cdf: Probability must be between 0 and 1".to_string(),
        });
    }
    if p_64 == 0.0 {
        return Ok(f64::NEG_INFINITY);
    }
    if p_64 == 1.0 {
        return Ok(f64::INFINITY);
    }

    // q ∈ (0, 0.5]. For p in [0.5, 1) the subtraction 1 - p is exact and strictly
    // positive, so no further zero check is needed.
    let q = if p_64 <= 0.5 { p_64 } else { 1.0 - p_64 };

    // Standard-normal quantile of q (never positive, since q <= 0.5).
    let z = if q > P_LOW {
        let r = q - 0.5;
        let r2 = r * r;
        let num = ((((A[0] * r2 + A[1]) * r2 + A[2]) * r2 + A[3]) * r2 + A[4]) * r2 + A[5];
        let den = ((((B[0] * r2 + B[1]) * r2 + B[2]) * r2 + B[3]) * r2 + B[4]) * r2 + B[5];
        r * num / den
    } else {
        // In the tail the rational function of t IS the quantile; it must not be
        // offset by t. The negative sign comes from the coefficients themselves.
        let t = (-2.0 * q.ln()).sqrt();
        let num = ((((C[0] * t + C[1]) * t + C[2]) * t + C[3]) * t + C[4]) * t + C[5];
        let den = (((D[0] * t + D[1]) * t + D[2]) * t + D[3]) * t + D[4];
        num / den
    };

    // Mirror back to the upper half: Φ⁻¹(p) = -Φ⁻¹(1 - p).
    let final_z = if p_64 > 0.5 { -z } else { z };
    Ok(mean + std_dev * final_z)
}
/// A normal (Gaussian) distribution with `f64` parameters.
///
/// Prefer constructing it through [`Normal::new`] or [`Normal::fit`], which
/// validate the parameters. Because the fields are public, a struct literal
/// skips those checks.
#[derive(Clone, Copy, Debug)]
pub struct Normal {
    /// Location parameter.
    pub mean: f64,
    /// Scale parameter; `Normal::new` requires it to be strictly positive.
    pub std_dev: f64,
}
impl Normal {
    /// Creates a normal distribution after validating its parameters.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidInput`] unless `mean` is finite and `std_dev`
    /// is both finite and strictly positive. NaN and ±infinity are rejected, which
    /// is what the error message promises.
    pub fn new(mean: f64, std_dev: f64) -> StatsResult<Self> {
        // `is_finite` is false for NaN and ±inf, so this also covers the NaN checks.
        let valid = mean.is_finite() && std_dev.is_finite() && std_dev > 0.0;
        if !valid {
            return Err(StatsError::InvalidInput {
                message: "Normal::new: std_dev must be positive and parameters must be finite"
                    .to_string(),
            });
        }
        Ok(Self { mean, std_dev })
    }

    /// Fits a normal distribution to `data` by maximum likelihood.
    ///
    /// The mean is the sample mean. The standard deviation is the population
    /// (divide-by-`n`) standard deviation, i.e. the MLE rather than the unbiased
    /// `n - 1` estimator.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidInput`] if `data` is empty. It also does so
    /// when [`Normal::new`] rejects the fitted parameters: for example, constant
    /// data (zero spread), data containing NaN or infinity, or values whose
    /// squared deviations overflow.
    pub fn fit(data: &[f64]) -> StatsResult<Self> {
        if data.is_empty() {
            return Err(StatsError::InvalidInput {
                message: "Normal::fit: data must not be empty".to_string(),
            });
        }
        let n = data.len() as f64;
        let mean = data.iter().sum::<f64>() / n;
        // Two-pass variance: deviations are taken from the already-computed mean.
        let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
        Self::new(mean, variance.sqrt())
    }
}
impl Distribution for Normal {
    fn name(&self) -> &str {
        "Normal"
    }

    fn num_params(&self) -> usize {
        2
    }

    fn mean(&self) -> f64 {
        self.mean
    }

    fn variance(&self) -> f64 {
        let sigma = self.std_dev;
        sigma * sigma
    }

    /// Density at `x`; delegates to `normal_pdf` (which re-checks `std_dev`).
    fn pdf(&self, x: f64) -> StatsResult<f64> {
        normal_pdf(x, self.mean, self.std_dev)
    }

    /// Natural log of the density, evaluated in closed form as
    /// `-z²/2 - ln(std_dev) - ln(2π)/2` rather than `pdf(x).ln()`. It therefore
    /// stays finite where the density itself underflows to zero. `std_dev` is not
    /// re-validated here.
    fn logpdf(&self, x: f64) -> StatsResult<f64> {
        let z = (x - self.mean) / self.std_dev;
        let log_sigma = self.std_dev.ln();
        let half_log_two_pi = 0.5 * (2.0 * std::f64::consts::PI).ln();
        Ok(-0.5 * z * z - log_sigma - half_log_two_pi)
    }

    /// Cumulative probability at `x`; delegates to `normal_cdf`.
    fn cdf(&self, x: f64) -> StatsResult<f64> {
        normal_cdf(x, self.mean, self.std_dev)
    }

    /// Quantile for probability `p`; delegates to `normal_inverse_cdf`.
    fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
        normal_inverse_cdf(p, self.mean, self.std_dev)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for reference values quoted to roughly seven significant digits.
    const EPSILON: f64 = 1e-7;

    /// Asserts `|actual - expected| < tol`, reporting both values on failure.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (± {}), got {}",
            expected,
            tol,
            actual
        );
    }

    /// Asserts that `result` failed with `StatsError::InvalidInput`.
    fn assert_invalid_input<T>(result: StatsResult<T>) {
        assert!(
            matches!(result, Err(StatsError::InvalidInput { .. })),
            "expected StatsError::InvalidInput"
        );
    }

    #[test]
    fn test_normal_pdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;
        assert_close(normal_pdf(mean, mean, sigma).unwrap(), 0.3989422804014327, 1e-10);
        assert_close(
            normal_pdf(mean + sigma, mean, sigma).unwrap(),
            0.24197072451914337,
            1e-10,
        );
    }

    #[test]
    fn test_normal_pdf_non_standard() {
        let mean = 5.0;
        let sigma = 2.0;
        assert_close(normal_pdf(mean, mean, sigma).unwrap(), 0.19947114020071635, 1e-10);
        assert_close(
            normal_pdf(mean + sigma, mean, sigma).unwrap(),
            0.12098536225957168,
            1e-10,
        );
    }

    #[test]
    fn test_normal_pdf_symmetry() {
        let mean = 0.0;
        let sigma = 1.0;
        let x = 1.5;
        let pdf_plus = normal_pdf(mean + x, mean, sigma).unwrap();
        let pdf_minus = normal_pdf(mean - x, mean, sigma).unwrap();
        assert_close(pdf_plus, pdf_minus, 1e-10);
    }

    #[test]
    fn test_normal_cdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;
        assert_close(normal_cdf(mean, mean, sigma).unwrap(), 0.5, 1e-10);
        assert_close(
            normal_cdf(mean + sigma, mean, sigma).unwrap(),
            0.8413447460685429,
            EPSILON,
        );
        assert_close(
            normal_cdf(mean - sigma, mean, sigma).unwrap(),
            0.15865525393145707,
            EPSILON,
        );
    }

    #[test]
    fn test_normal_cdf_non_standard() {
        let mean = 100.0;
        let sigma = 15.0;
        assert_close(normal_cdf(mean, mean, sigma).unwrap(), 0.5, 1e-10);
        assert_close(
            normal_cdf(mean + sigma, mean, sigma).unwrap(),
            0.8413447460685429,
            EPSILON,
        );
    }

    #[test]
    fn test_normal_inverse_cdf() {
        let mean = 0.0;
        let sigma = 1.0;
        assert_close(normal_inverse_cdf(0.5, mean, sigma).unwrap(), mean, EPSILON);
        assert_close(
            normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap(),
            sigma,
            EPSILON,
        );
        assert_close(
            normal_inverse_cdf(0.15865525393145707, mean, sigma).unwrap(),
            -sigma,
            EPSILON,
        );
    }

    #[test]
    fn test_normal_inverse_cdf_non_standard() {
        let mean = 50.0;
        let sigma = 5.0;
        assert_close(normal_inverse_cdf(0.5, mean, sigma).unwrap(), mean, EPSILON);
        assert_close(
            normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap(),
            mean + sigma,
            EPSILON,
        );
    }

    #[test]
    fn test_normal_pdf_standard_normal() {
        assert_close(normal_pdf(0.0, 0.0, 1.0).unwrap(), 0.3989422804014327, 1e-10);
        let pdf_plus1 = normal_pdf(1.0, 0.0, 1.0).unwrap();
        let pdf_minus1 = normal_pdf(-1.0, 0.0, 1.0).unwrap();
        assert_close(pdf_plus1, pdf_minus1, 1e-15);
        assert_close(pdf_plus1, 0.24197072451914337, 1e-10);
        assert_close(normal_pdf(2.0, 0.0, 1.0).unwrap(), 0.05399096651318806, 1e-10);
    }

    #[test]
    fn test_normal_pdf_invalid_sigma() {
        // Negative standard deviation must be rejected.
        assert_invalid_input(normal_pdf(0.0, 0.0, -1.0));
    }

    #[test]
    fn test_normal_cdf_standard_normal() {
        // x == mean short-circuits to exactly one half.
        assert_eq!(normal_cdf(0.0, 0.0, 1.0).unwrap(), 0.5);
        assert_close(normal_cdf(1.0, 0.0, 1.0).unwrap(), 0.8413447460685429, EPSILON);
        assert_close(normal_cdf(-1.0, 0.0, 1.0).unwrap(), 0.15865525393145707, EPSILON);
        assert_close(normal_cdf(2.0, 0.0, 1.0).unwrap(), 0.9772498680518208, EPSILON);
    }

    #[test]
    fn test_normal_cdf_invalid_sigma() {
        // Negative standard deviation must be rejected.
        assert_invalid_input(normal_cdf(0.0, 0.0, -1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_standard_normal() {
        assert_close(normal_inverse_cdf(0.5, 0.0, 1.0).unwrap(), 0.0, 1e-12);
        // 0.8413447 is Φ(1) truncated to 7 digits, so the exact quantile is ≈ 1 - 1.9e-7.
        assert_close(normal_inverse_cdf(0.8413447, 0.0, 1.0).unwrap(), 1.0, 1e-6);
        assert_close(normal_inverse_cdf(0.1586553, 0.0, 1.0).unwrap(), -1.0, 1e-6);
    }

    #[test]
    fn test_normal_config_new_nan_mean() {
        assert_invalid_input(NormalConfig::new(f64::NAN, 1.0));
    }

    #[test]
    fn test_normal_config_new_nan_std_dev() {
        assert_invalid_input(NormalConfig::new(0.0, f64::NAN));
    }

    #[test]
    fn test_normal_config_new_std_dev_zero() {
        assert_invalid_input(NormalConfig::new(0.0, 0.0));
    }

    #[test]
    fn test_normal_config_new_std_dev_negative() {
        assert_invalid_input(NormalConfig::new(0.0, -1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_negative() {
        assert_invalid_input(normal_inverse_cdf(-0.1, 0.0, 1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_greater_than_one() {
        assert_invalid_input(normal_inverse_cdf(1.5, 0.0, 1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_zero() {
        let result = normal_inverse_cdf(0.0, 0.0, 1.0).unwrap();
        assert_eq!(result, f64::NEG_INFINITY);
    }

    #[test]
    fn test_normal_inverse_cdf_p_one() {
        let result = normal_inverse_cdf(1.0, 0.0, 1.0).unwrap();
        assert_eq!(result, f64::INFINITY);
    }

    #[test]
    fn test_normal_pdf_std_dev_zero() {
        assert_invalid_input(normal_pdf(0.0, 0.0, 0.0));
    }

    #[test]
    fn test_normal_cdf_std_dev_zero() {
        assert_invalid_input(normal_cdf(0.0, 0.0, 0.0));
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_zero() {
        // The quantile function does not validate std_dev; it only rescales.
        let result = normal_inverse_cdf(0.5, 5.0, 0.0).unwrap();
        assert_eq!(result, 5.0);
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_negative() {
        let result = normal_inverse_cdf(0.5, 0.0, -1.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_normal_config_new_valid() {
        let config = NormalConfig::new(0.0, 1.0);
        assert!(config.is_ok());
        let config = config.unwrap();
        assert_eq!(config.mean, 0.0);
        assert_eq!(config.std_dev, 1.0);
    }

    #[test]
    fn test_normal_new_validation() {
        assert!(Normal::new(0.0, 1.0).is_ok());
        assert_invalid_input(Normal::new(0.0, 0.0));
        assert_invalid_input(Normal::new(0.0, -1.0));
        assert_invalid_input(Normal::new(f64::NAN, 1.0));
        assert_invalid_input(Normal::new(0.0, f64::NAN));
    }

    #[test]
    fn test_normal_fit() {
        let fitted = Normal::fit(&[1.0, 2.0, 3.0]).unwrap();
        assert_close(fitted.mean, 2.0, 1e-12);
        // Maximum-likelihood (divide-by-n) standard deviation: sqrt(2 / 3).
        assert_close(fitted.std_dev, (2.0_f64 / 3.0).sqrt(), 1e-12);
    }

    #[test]
    fn test_normal_fit_invalid_data() {
        let empty: [f64; 0] = [];
        assert_invalid_input(Normal::fit(&empty));
        // Constant data has zero spread, which `Normal::new` rejects.
        assert_invalid_input(Normal::fit(&[5.0, 5.0]));
    }

    #[test]
    fn test_normal_distribution_trait() {
        let dist = Normal::new(1.0, 2.0).unwrap();
        assert_eq!(dist.name(), "Normal");
        assert_eq!(dist.num_params(), 2);
        assert_eq!(dist.mean(), 1.0);
        assert_eq!(dist.variance(), 4.0);
        assert_eq!(dist.cdf(1.0).unwrap(), 0.5);
        assert_close(dist.inverse_cdf(0.5).unwrap(), 1.0, 1e-12);
        // The closed-form log-density must agree with the log of the density.
        for &x in &[-2.0, 0.0, 1.0, 3.5] {
            let log_density = dist.logpdf(x).unwrap();
            let density = dist.pdf(x).unwrap();
            assert_close(log_density, density.ln(), 1e-9);
        }
    }
}