use crate::error::{StatsError, StatsResult};
use num_traits::ToPrimitive;
/// Magnitude below which `erf` is evaluated with its Maclaurin series.
const ERF_SERIES_CUTOFF: f64 = 1.5;

/// Magnitude at or beyond which `erf(x)` rounds to exactly `±1.0` in `f64`
/// (`erfc(6) ≈ 2.15e-17`, below half an ULP of 1.0).
const ERF_SATURATION: f64 = 6.0;

/// Safety cap on series / continued-fraction iterations; both converge in
/// a few dozen steps over the ranges where they are used.
const ERF_MAX_ITER: u32 = 200;

/// Computes the Gauss error function `erf(x) = 2/√π ∫₀ˣ e^(-t²) dt`.
///
/// The value is computed to close to full `f64` precision:
///
/// * `|x| < 1.5`: Maclaurin series. This also preserves full *relative*
///   precision for tiny arguments, where `erf(x) ≈ 2x/√π`.
/// * `1.5 ≤ |x| < 6`: `1 - erfc(|x|)`, with `erfc` taken from its continued
///   fraction (evaluated with the modified Lentz algorithm).
/// * `|x| ≥ 6`: the result rounds to `±1` exactly.
///
/// Special values follow IEEE 754 / C99 conventions: `erf(±0) = ±0`,
/// `erf(±∞) = ±1`, and `erf(NaN)` is NaN.
///
/// # Errors
///
/// Returns [`StatsError::ConversionError`] if `x` cannot be converted to `f64`.
#[inline]
pub fn erf<T>(x: T) -> StatsResult<f64>
where
    T: ToPrimitive,
{
    let x = x.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "prob::erf: Failed to convert x to f64".to_string(),
    })?;
    Ok(erf_f64(x))
}

/// Non-generic core of [`erf`], so the numerics are compiled once rather
/// than once per input type.
fn erf_f64(x: f64) -> f64 {
    // NaN propagates unchanged; ±0 is returned as-is to keep the sign of zero.
    if x.is_nan() || x == 0.0 {
        return x;
    }
    let ax = x.abs();
    let magnitude = if ax < ERF_SERIES_CUTOFF {
        erf_maclaurin(ax)
    } else if ax < ERF_SATURATION {
        1.0 - erfc_continued_fraction(ax)
    } else {
        // Also covers ±∞, and keeps `ax * ax` from overflowing in the
        // continued fraction.
        1.0
    };
    // erf is odd: working on |x| and restoring the sign makes
    // erf(-x) == -erf(x) hold exactly.
    magnitude.copysign(x)
}

/// `erf(x)` for `0 < x < ERF_SERIES_CUTOFF`, from the Maclaurin series
/// `erf(x) = 2/√π · Σₙ (-1)ⁿ x²ⁿ⁺¹ / (n! (2n + 1))`.
fn erf_maclaurin(x: f64) -> f64 {
    let x2 = x * x;
    // `power` holds (-1)ⁿ x²ⁿ⁺¹ / n!, updated incrementally.
    let mut power = x;
    let mut sum = x;
    for n in 1..=ERF_MAX_ITER {
        let n = f64::from(n);
        power *= -x2 / n;
        let term = power / (2.0 * n + 1.0);
        sum += term;
        // Terms shrink monotonically for x < 1.5, so once one no longer
        // affects the sum, the remaining tail cannot either.
        if term.abs() <= f64::EPSILON * sum.abs() {
            break;
        }
    }
    std::f64::consts::FRAC_2_SQRT_PI * sum
}

/// `erfc(x)` for `x ≥ ERF_SERIES_CUTOFF`, using `erfc(x) = Q(1/2, x²)` and
/// the continued fraction of the upper incomplete gamma function
/// (equivalently DLMF 7.9.2), evaluated with the modified Lentz algorithm.
fn erfc_continued_fraction(x: f64) -> f64 {
    // Replaces exact zeros in Lentz's recurrence to avoid division by zero.
    const TINY: f64 = 1e-300;
    let z = x * x;
    let mut b = z + 0.5; // z + 1 - a, with a = 1/2
    let mut c = 1.0 / TINY;
    let mut d = 1.0 / b;
    let mut h = d;
    for i in 1..=ERF_MAX_ITER {
        let i = f64::from(i);
        let an = -i * (i - 0.5);
        b += 2.0;
        d = an * d + b;
        if d.abs() < TINY {
            d = TINY;
        }
        c = b + an / c;
        if c.abs() < TINY {
            c = TINY;
        }
        d = 1.0 / d;
        let delta = c * d;
        h *= delta;
        if (delta - 1.0).abs() <= f64::EPSILON {
            break;
        }
    }
    // Q(1/2, z) = e^{-z} · z^{1/2} / Γ(1/2) · h, with z^{1/2} = x and Γ(1/2) = √π.
    0.5 * std::f64::consts::FRAC_2_SQRT_PI * x * (-z).exp() * h
}
#[cfg(test)]
mod tests {
    use super::*;

    /// erf saturates to ±1 at ±∞ and propagates NaN.
    #[test]
    fn test_erf_special_cases() {
        assert!((erf(f64::INFINITY).unwrap() - 1.0).abs() < 1e-10);
        assert!((erf(f64::NEG_INFINITY).unwrap() + 1.0).abs() < 1e-10);
        assert!(erf(f64::NAN).unwrap().is_nan());
    }

    /// Compares erf against tabulated reference values on both sides of zero.
    #[test]
    fn test_erf_against_known_values() {
        let test_cases: [(f64, f64); 8] = [
            (-3.0, -0.999977909503),
            (-2.0, -0.995322265019),
            (-1.0, -0.842700792950),
            (0.0, 0.0),
            (0.5, 0.520499877813),
            (1.0, 0.842700792950),
            (2.0, 0.995322265019),
            (3.0, 0.999977909503),
        ];
        for (x, expected) in test_cases {
            let actual = erf(x).unwrap();
            assert!(
                (actual - expected).abs() < 1e-6,
                "For x = {}, expected {}, but got {}",
                x,
                expected,
                actual
            );
        }
    }

    /// erf is an odd function, so erf(x) + erf(-x) must vanish.
    #[test]
    fn test_erf_symmetry() {
        let x = 0.7;
        let actual = erf(x).unwrap() + erf(-x).unwrap();
        assert!(
            actual.abs() < 1e-10,
            "erf(x) + erf(-x) should be 0.0, but got {}",
            actual
        );
    }

    /// Far from zero, erf is indistinguishable from ±1 in f64.
    #[test]
    fn test_erf_limits() {
        assert!((erf(10.0).unwrap() - 1.0).abs() < 1e-15);
        assert!((erf(-10.0).unwrap() + 1.0).abs() < 1e-15);
    }

    /// A large negative argument yields a value close to -1.
    #[test]
    fn test_erf_large_negative() {
        let x = -8.0;
        let actual = erf(x).unwrap();
        assert!(
            (actual + 1.0).abs() < 1e-10,
            "For large negative x, erf(x) should be close to -1.0, but got {}",
            actual
        );
    }
}