use crate::error::{StatsError, StatsResult};
use crate::utils::constants::INV_SQRT_2PI;
use num_traits::ToPrimitive;
/// Evaluates the probability density function of a normal distribution with
/// mean `avg` and standard deviation `stddev` at the point `x`.
///
/// The density is `exp(-z² / 2) * INV_SQRT_2PI / stddev`, where
/// `z = (x - avg) / stddev`.
///
/// # Arguments
///
/// * `x` - Point at which to evaluate the density; any value convertible to `f64`.
/// * `avg` - Mean of the distribution.
/// * `stddev` - Standard deviation of the distribution; must be strictly positive.
///
/// # Errors
///
/// * [`StatsError::ConversionError`] if `x` cannot be converted to `f64`.
/// * [`StatsError::InvalidInput`] if `stddev` is zero, negative, or NaN.
#[inline]
pub fn probability_density<T>(x: T, avg: f64, stddev: f64) -> StatsResult<f64>
where
    T: ToPrimitive,
{
    let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "prob::probability_density: Failed to convert x to f64".to_string(),
    })?;
    // Reject zero (division by zero), negative values (which would produce a
    // negative "density"), and NaN (which would silently yield Ok(NaN)).
    if stddev.is_nan() || stddev <= 0.0 {
        return Err(StatsError::InvalidInput {
            message: "prob::probability_density: Standard deviation must be positive".to_string(),
        });
    }
    let z = (x_64 - avg) / stddev;
    let exponent = -0.5 * z * z;
    Ok(exponent.exp() * INV_SQRT_2PI / stddev)
}
/// Evaluates the standard normal density φ(z), i.e. the normal density with
/// mean 0 and standard deviation 1, at the standardized score `z`.
///
/// Computed as `INV_SQRT_2PI * exp(-z² / 2)`.
///
/// # Errors
///
/// This function never returns `Err`. Every input yields `Ok`; a NaN `z`
/// propagates to a NaN result.
#[inline]
pub fn normal_probability_density(z: f64) -> StatsResult<f64> {
    let neg_half_z_squared = -0.5 * z * z;
    Ok(INV_SQRT_2PI * neg_half_z_squared.exp())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOLERANCE: f64 = 1e-10;

    /// Reference values `(z, φ(z))` of the standard normal density.
    const STANDARD_NORMAL_CASES: [(f64, f64); 5] = [
        (0.0, 0.3989422804014327),
        (1.0, 0.24197072451914337),
        (-1.0, 0.24197072451914337),
        (2.0, 0.05399096651318806),
        (3.0, 0.00443184841193801),
    ];

    /// Asserts that `actual` lies within `TOLERANCE` of `expected`; the failure
    /// message names the input as `label = input`.
    fn assert_close(label: &str, input: f64, expected: f64, actual: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "For {} = {}, expected {}, but got {}",
            label,
            input,
            expected,
            actual
        );
    }

    #[test]
    fn test_probability_density_basic() {
        let avg = 0.0;
        let stddev = 1.0;
        for (x, expected) in STANDARD_NORMAL_CASES {
            let actual = probability_density(x, avg, stddev).unwrap();
            assert_close("x", x, expected, actual);
        }
    }

    #[test]
    fn test_probability_density_different_mean() {
        let x = 7.0;
        let actual = probability_density(x, 5.0, 2.0).unwrap();
        assert_close("x", x, 0.12098536225957168, actual);
    }

    #[test]
    fn test_probability_density_different_stddev() {
        let x = 0.0;
        let actual = probability_density(x, 0.0, 0.5).unwrap();
        assert_close("x", x, 0.7978845608028654, actual);
    }

    #[test]
    fn test_probability_density_integer_input() {
        // Exercises the generic `T: ToPrimitive` conversion path with a non-float type.
        let actual = probability_density(7_i32, 5.0, 2.0).unwrap();
        assert_close("x", 7.0, 0.12098536225957168, actual);
    }

    #[test]
    fn test_normal_probability_density_basic() {
        for (z, expected) in STANDARD_NORMAL_CASES {
            let actual = normal_probability_density(z).unwrap();
            assert_close("z", z, expected, actual);
        }
    }

    #[test]
    fn test_normal_probability_density_symmetry() {
        let z = 0.7;
        let actual =
            normal_probability_density(z).unwrap() - normal_probability_density(-z).unwrap();
        assert!(
            actual.abs() < TOLERANCE,
            "normal_probability_density(z) should equal normal_probability_density(-z), but got {}",
            actual
        );
    }

    #[test]
    fn test_normal_probability_density_limits() {
        assert!(normal_probability_density(10.0).unwrap() < 1e-20);
        assert!(normal_probability_density(-10.0).unwrap() < 1e-20);
    }

    #[test]
    fn test_probability_density_stddev_zero() {
        let result = probability_density(0.0, 0.0, 0.0);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }
}