use crate::error::RusTorchError; use num_traits::Float;
use std::f64::consts::PI;
const ERF_A: [f64; 5] = [
0.254829592,
-0.284496736,
1.421413741,
-1.453152027,
1.061405429,
];
const ERF_P: f64 = 0.3275911;
const ERFINV_A: f64 = 0.147;
pub fn erf_scalar<T: Float>(x: T) -> T {
let x_f64 = match x.to_f64() {
Some(val) => val,
None => return T::zero(),
};
if x_f64 == 0.0 {
return T::zero();
}
let sign = if x_f64 < 0.0 { -1.0 } else { 1.0 };
let x_abs = x_f64.abs();
let result = if x_abs < 0.5 {
sign * erf_series_small(x_abs)
} else if x_abs < 3.5 {
sign * erf_approx(x_abs)
} else {
sign
};
T::from(result).unwrap_or(T::zero())
}
fn erf_approx(x: f64) -> f64 {
let t = 1.0 / (1.0 + ERF_P * x);
let t2 = t * t;
let t3 = t2 * t;
let t4 = t2 * t2;
let t5 = t2 * t3;
let poly = ERF_A[0] * t + ERF_A[1] * t2 + ERF_A[2] * t3 + ERF_A[3] * t4 + ERF_A[4] * t5;
1.0 - poly * (-x * x).exp()
}
fn erf_series_small(x: f64) -> f64 {
let x2 = x * x;
let mut sum = x;
let mut term = x;
for n in 1..100 {
term *= -x2 / n as f64;
let new_term = term / (2 * n + 1) as f64;
sum += new_term;
if new_term.abs() < 1e-16 * sum.abs() {
break;
}
}
2.0 / PI.sqrt() * sum
}
pub fn erf_series<T: Float>(x: T) -> T {
let x_f64 = match x.to_f64() {
Some(val) => val,
None => return T::zero(),
};
if x_f64 == 0.0 {
return T::zero();
}
if x_f64.abs() < 1.5 {
let sign = if x_f64 < 0.0 { -1.0 } else { 1.0 };
T::from(sign * erf_series_small(x_f64.abs())).unwrap_or(T::zero())
} else {
erf_scalar(x)
}
}
pub fn erfc_scalar<T: Float>(x: T) -> T {
let x_f64 = match x.to_f64() {
Some(val) => val,
None => return T::one(),
};
if x_f64 == 0.0 {
return T::one();
}
if x_f64 > 3.5 {
T::from(erfc_asymptotic(x_f64)).unwrap_or(T::zero())
} else {
T::one() - erf_scalar(x)
}
}
fn erfc_asymptotic(x: f64) -> f64 {
if x < 3.5 {
return 1.0 - erf_approx(x);
}
let x2 = x * x;
let exp_neg_x2 = (-x2).exp();
let mut sum = 1.0;
let mut term = 1.0;
for n in 1..20 {
term *= -(2.0 * n as f64 - 1.0) / (2.0 * x2);
sum += term;
if term.abs() < 1e-15 {
break;
}
}
exp_neg_x2 / (x * PI.sqrt()) * sum
}
pub fn erfinv_scalar<T: Float>(x: T) -> Result<T, RusTorchError> {
let x_f64 = x.to_f64().ok_or(RusTorchError::DomainError(
"Cannot convert to f64".to_string(),
))?;
if x_f64.abs() >= 1.0 {
if x_f64.abs() == 1.0 {
return T::from(x_f64 * f64::INFINITY)
.ok_or(RusTorchError::OverflowError("Error function overflow"));
}
return Err(RusTorchError::DomainError(format!(
"erfinv undefined for |x| >= 1, got x = {}",
x_f64
)));
}
if x_f64 == 0.0 {
return Ok(T::zero());
}
let result = if x_f64.abs() < 0.7 {
erfinv_series(x_f64)?
} else {
erfinv_rational(x_f64)?
};
let refined = newton_raphson_erfinv(x_f64, result)?;
T::from(refined).ok_or(RusTorchError::OverflowError("Error function overflow"))
}
fn erfinv_series(x: f64) -> Result<f64, RusTorchError> {
let c0 = 1.0;
let c1 = 1.0 / 12.0;
let c2 = 7.0 / 480.0;
let c3 = 127.0 / 40320.0;
let c4 = 4369.0 / 5806080.0;
let c5 = 34807.0 / 182476800.0;
let x2 = x * x;
let x3 = x2 * x;
let x5 = x3 * x2;
let x7 = x5 * x2;
let x9 = x7 * x2;
let x11 = x9 * x2;
let sum = c0 * x + c1 * x3 + c2 * x5 + c3 * x7 + c4 * x9 + c5 * x11;
Ok((PI / 2.0).sqrt() * sum)
}
fn erfinv_rational(x: f64) -> Result<f64, RusTorchError> {
let ln_1_minus_x2 = (1.0 - x * x).ln();
let a = ERFINV_A;
let term1 = 2.0 / (PI * a) + ln_1_minus_x2 / 2.0;
let term2 = ln_1_minus_x2 / a;
let sqrt_arg = term1 * term1 - term2;
if sqrt_arg < 0.0 {
return Err(RusTorchError::DomainError(
"Negative argument in square root".to_string(),
));
}
Ok(x.signum() * (sqrt_arg.sqrt() - term1).sqrt())
}
fn newton_raphson_erfinv(target: f64, initial_guess: f64) -> Result<f64, RusTorchError> {
let mut x = initial_guess;
for _ in 0..10 {
let erf_x = erf_approx(x.abs()) * x.signum();
let diff = erf_x - target;
if diff.abs() < 1e-15 {
break;
}
let deriv = 2.0 / PI.sqrt() * (-x * x).exp();
if deriv.abs() < 1e-15 {
return Err(RusTorchError::ConvergenceError(
"Newton-Raphson iteration failed to converge".to_string(),
));
}
x -= diff / deriv;
}
Ok(x)
}
pub fn erfcinv_scalar<T: Float>(x: T) -> Result<T, RusTorchError> {
let x_f64 = x.to_f64().ok_or(RusTorchError::DomainError(
"Cannot convert to f64".to_string(),
))?;
if x_f64 <= 0.0 || x_f64 >= 2.0 {
if x_f64 == 0.0 {
return T::from(f64::INFINITY)
.ok_or(RusTorchError::OverflowError("Overflow at erfcinv(0)"));
}
if x_f64 == 2.0 {
return T::from(-f64::INFINITY)
.ok_or(RusTorchError::OverflowError("Overflow at erfcinv(2)"));
}
return Err(RusTorchError::DomainError(format!(
"erfcinv undefined for x <= 0 or x >= 2, got x = {}",
x_f64
)));
}
erfinv_scalar(T::one() - x)
}
pub fn erfcx_scalar<T: Float>(x: T) -> T {
let x_f64 = match x.to_f64() {
Some(val) => val,
None => return T::one(),
};
if x_f64 < -26.0 {
T::from(2.0 * (x_f64 * x_f64).exp()).unwrap_or(T::infinity())
} else if x_f64 < 0.0 {
T::from((x_f64 * x_f64).exp() * (1.0 - erf_approx(-x_f64))).unwrap_or(T::one())
} else if x_f64 < 3.5 {
T::from((x_f64 * x_f64).exp() * erfc_scalar(T::from(x_f64).unwrap()).to_f64().unwrap())
.unwrap_or(T::one())
} else {
T::from(erfcx_asymptotic(x_f64)).unwrap_or(T::zero())
}
}
fn erfcx_asymptotic(x: f64) -> f64 {
let x2 = x * x;
let mut sum = 1.0;
let mut term = 1.0;
for n in 1..20 {
term *= -(2.0 * n as f64 - 1.0) / (2.0 * x2);
sum += term;
if term.abs() < 1e-15 {
break;
}
}
sum / (x * PI.sqrt())
}
pub fn erf<T: Float + 'static>(
x: &crate::tensor::Tensor<T>,
) -> Result<crate::tensor::Tensor<T>, RusTorchError> {
let mut result = vec![T::zero(); x.data.len()];
for (i, &val) in x.data.iter().enumerate() {
result[i] = erf_scalar(val);
}
Ok(crate::tensor::Tensor::from_vec(result, x.shape().to_vec()))
}
pub fn erfc<T: Float + 'static>(
x: &crate::tensor::Tensor<T>,
) -> Result<crate::tensor::Tensor<T>, RusTorchError> {
let mut result = vec![T::zero(); x.data.len()];
for (i, &val) in x.data.iter().enumerate() {
result[i] = erfc_scalar(val);
}
Ok(crate::tensor::Tensor::from_vec(result, x.shape().to_vec()))
}
pub fn erfinv<T: Float + 'static>(
x: &crate::tensor::Tensor<T>,
) -> Result<crate::tensor::Tensor<T>, RusTorchError> {
let mut result = vec![T::zero(); x.data.len()];
for (i, &val) in x.data.iter().enumerate() {
result[i] = erfinv_scalar(val)?;
}
Ok(crate::tensor::Tensor::from_vec(result, x.shape().to_vec()))
}
pub fn erfcinv<T: Float + 'static>(
x: &crate::tensor::Tensor<T>,
) -> Result<crate::tensor::Tensor<T>, RusTorchError> {
let mut result = vec![T::zero(); x.data.len()];
for (i, &val) in x.data.iter().enumerate() {
result[i] = erfcinv_scalar(val)?;
}
Ok(crate::tensor::Tensor::from_vec(result, x.shape().to_vec()))
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_erf_basic() {
assert_relative_eq!(erf_scalar(0.0_f64), 0.0, epsilon = 1e-10);
assert_relative_eq!(erf_scalar(-1.0_f64), -erf_scalar(1.0_f64), epsilon = 1e-10);
assert_relative_eq!(erf_scalar(1.0_f64), 0.8427006897475899, epsilon = 1e-8);
assert_relative_eq!(erf_scalar(2.0_f64), 0.9953222650189527, epsilon = 2e-6);
assert_relative_eq!(erf_scalar(10.0_f64), 1.0, epsilon = 1e-10);
}
#[test]
fn test_erfc_basic() {
assert_relative_eq!(erfc_scalar(0.0_f64), 1.0, epsilon = 1e-10);
let x = 1.5;
assert_relative_eq!(erfc_scalar(x), 1.0 - erf_scalar(x), epsilon = 1e-10);
assert_relative_eq!(erfc_scalar(1.0_f64), 0.15729920705028513, epsilon = 2e-5);
assert!(erfc_scalar(5.0_f64) < 1e-10);
}
#[test]
fn test_erfinv_basic() {
assert_relative_eq!(erfinv_scalar(0.0_f64).unwrap(), 0.0, epsilon = 1e-10);
let x = 0.5;
let erf_x = erf_scalar(x);
assert_relative_eq!(erfinv_scalar(erf_x).unwrap(), x, epsilon = 1e-10);
let y = 0.7;
assert_relative_eq!(
erfinv_scalar(-y).unwrap(),
-erfinv_scalar(y).unwrap(),
epsilon = 1e-10
);
assert_relative_eq!(
erfinv_scalar(0.5_f64).unwrap(),
0.4769362762044698,
epsilon = 2e-5
);
}
#[test]
fn test_erfcinv_basic() {
assert_relative_eq!(erfcinv_scalar(1.0_f64).unwrap(), 0.0, epsilon = 1e-10);
let x = 1.5;
let erfc_x = erfc_scalar(x);
assert_relative_eq!(erfcinv_scalar(erfc_x).unwrap(), x, epsilon = 1e-10);
assert_relative_eq!(
erfcinv_scalar(0.5_f64).unwrap(),
0.4769362762044698,
epsilon = 2e-5
);
}
#[test]
fn test_erfcx_basic() {
assert_relative_eq!(erfcx_scalar(0.0_f64), 1.0, epsilon = 1e-10);
let x = 10.0;
let expected = 1.0 / (x * PI.sqrt());
assert_relative_eq!(erfcx_scalar(x), expected, epsilon = 1e-3);
let large_x = 30.0;
let erfcx_val = erfcx_scalar(large_x);
assert!(erfcx_val > 0.0 && erfcx_val.is_finite());
}
#[test]
fn test_erf_series_vs_approx() {
let test_values = [0.1, 0.5, 1.0, 1.5];
for &x in &test_values {
let series = erf_series(x);
let approx = erf_scalar(x);
assert_relative_eq!(series, approx, epsilon = 3e-6);
}
}
#[test]
fn test_domain_errors() {
assert!(erfinv_scalar(1.1_f64).is_err());
assert!(erfinv_scalar(-1.1_f64).is_err());
assert!(erfcinv_scalar(-0.1_f64).is_err());
assert!(erfcinv_scalar(2.1_f64).is_err());
}
}