use crate::float_trait::Float;
use conv::*;
impl Default for ErrorFunction {
fn default() -> Self {
Self::Direct
}
}
impl ErrorFunction {
pub fn erf<T>(&self, x: T) -> T
where
T: Float + LibMFloat + ErfEps1Over1e3Float,
{
match self {
Self::Direct => x.libm_erf(),
Self::Eps1Over1e3 => x.erf_eps_1over1e3(),
}
}
pub fn normal_cdf<T>(&self, x: T, mean: T, w: T) -> T
where
T: Float + LibMFloat + ErfEps1Over1e3Float,
{
let inv_sigma = T::sqrt(w);
T::half() * (T::one() + self.erf((x - mean) * inv_sigma * T::FRAC_1_SQRT_2()))
}
}
/// Floating-point types whose error function is provided by the `libm` crate.
pub trait LibMFloat {
    /// Returns `erf(self)` as computed by `libm`.
    fn libm_erf(self) -> Self;
}
impl LibMFloat for f32 {
    /// Single-precision error function computed by `libm::erff`.
    fn libm_erf(self) -> f32 {
        libm::erff(self)
    }
}
impl LibMFloat for f64 {
    /// Double-precision error function computed by `libm::erf`.
    fn libm_erf(self) -> f64 {
        libm::erf(self)
    }
}
/// Algorithm used to evaluate the error function `erf`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ErrorFunction {
    /// Evaluate `erf` through the `libm` crate.
    Direct,
    /// Linear interpolation over a 64-node lookup table, targeting an absolute error
    /// around `1e-3`; values outside the table saturate to `±1`.
    Eps1Over1e3,
}
/// Floating-point types with a fast, table-based approximation of the error function.
///
/// The approximation linearly interpolates `Y_FOR_ERF_EPS_1OVER1E3` over the uniform grid
/// `X_FOR_ERF_EPS_1OVER1E3` and saturates to `±1` outside the grid. The name refers to an
/// absolute-error target of `1e-3`.
pub trait ErfEps1Over1e3Float: ApproxInto<usize, RoundToZero> + num_traits::Float {
    /// Ascending, uniformly spaced grid nodes.
    const X_FOR_ERF_EPS_1OVER1E3: [Self; 64];
    /// Reciprocal of the grid step, `1 / (X[i + 1] - X[i])`.
    const INVERSED_DX_FOR_ERF_EPS_1OVER1E3: Self;
    /// Tabulated function values at the grid nodes.
    const Y_FOR_ERF_EPS_1OVER1E3: [Self; 64];

    /// Approximates `erf(self)` by linear interpolation of the tabulated values.
    ///
    /// Returns `-1` below the first node and `1` at or above the last node. NaN input is
    /// returned unchanged, which matches `libm`'s `erf`. Previously NaN slipped past both range
    /// checks and panicked in the index conversion.
    fn erf_eps_1over1e3(self) -> Self {
        if self.is_nan() {
            return self;
        }

        let last = Self::X_FOR_ERF_EPS_1OVER1E3.len() - 1;
        let x_first = Self::X_FOR_ERF_EPS_1OVER1E3[0];
        if self < x_first {
            return -Self::one();
        }
        if self >= Self::X_FOR_ERF_EPS_1OVER1E3[last] {
            return Self::one();
        }

        // Fractional position of `self` on the grid. Because NaN and out-of-range inputs
        // (including infinities) returned above, this is finite and non-negative.
        let pos = (self - x_first) * Self::INVERSED_DX_FOR_ERF_EPS_1OVER1E3;
        let idx: usize = pos
            .approx_by::<RoundToZero>()
            .expect("grid position is finite and non-negative");

        // Rounding in `pos` can land on the last node even though `self` is below it.
        // Interpolating there would read past the table end (`idx + 1`), and the interpolated
        // value would be the last tabulated value anyway.
        if idx >= last {
            return Self::Y_FOR_ERF_EPS_1OVER1E3[last];
        }

        let alpha = pos.fract();
        Self::Y_FOR_ERF_EPS_1OVER1E3[idx] * (Self::one() - alpha)
            + Self::Y_FOR_ERF_EPS_1OVER1E3[idx + 1] * alpha
    }
}
// These literals are identical to those of the `f64` implementation and carry more significant
// digits than `f32` can hold. The lint is silenced instead of truncating them per type.
#[allow(clippy::excessive_precision)]
impl ErfEps1Over1e3Float for f32 {
    /// 64 ascending grid nodes, symmetric around zero, spanning about [-2.397, 2.397] with a
    /// uniform step of about 0.0760933.
    const X_FOR_ERF_EPS_1OVER1E3: [Self; 64] = [
        -2.39693895,
        -2.32084565,
        -2.24475235,
        -2.16865905,
        -2.09256575,
        -2.01647245,
        -1.94037915,
        -1.86428585,
        -1.78819255,
        -1.71209925,
        -1.63600595,
        -1.55991265,
        -1.48381935,
        -1.40772605,
        -1.33163275,
        -1.25553945,
        -1.17944615,
        -1.10335285,
        -1.02725955,
        -0.95116625,
        -0.87507295,
        -0.79897965,
        -0.72288635,
        -0.64679305,
        -0.57069975,
        -0.49460645,
        -0.41851315,
        -0.34241985,
        -0.26632655,
        -0.19023325,
        -0.11413995,
        -0.03804665,
        0.03804665,
        0.11413995,
        0.19023325,
        0.26632655,
        0.34241985,
        0.41851315,
        0.49460645,
        0.57069975,
        0.64679305,
        0.72288635,
        0.79897965,
        0.87507295,
        0.95116625,
        1.02725955,
        1.10335285,
        1.17944615,
        1.25553945,
        1.33163275,
        1.40772605,
        1.48381935,
        1.55991265,
        1.63600595,
        1.71209925,
        1.78819255,
        1.86428585,
        1.94037915,
        2.01647245,
        2.09256575,
        2.16865905,
        2.24475235,
        2.32084565,
        2.39693895,
    ];
    /// Reciprocal of the grid step (`1 / 0.0760933`).
    const INVERSED_DX_FOR_ERF_EPS_1OVER1E3: Self = 13.141761468984605;
    /// Values at the grid nodes. They are antisymmetric around zero and appear to be `erf`
    /// sampled at the nodes; they reach ±0.99930052 at the ends.
    const Y_FOR_ERF_EPS_1OVER1E3: [Self; 64] = [
        -0.99930052,
        -0.99896989,
        -0.99849936,
        -0.99783743,
        -0.99691696,
        -0.9956517,
        -0.99393249,
        -0.99162334,
        -0.98855749,
        -0.98453378,
        -0.97931372,
        -0.97261948,
        -0.96413348,
        -0.9534999,
        -0.94032851,
        -0.92420128,
        -0.90468204,
        -0.88132908,
        -0.85371082,
        -0.82142392,
        -0.78411334,
        -0.74149338,
        -0.69336849,
        -0.6396527,
        -0.58038613,
        -0.51574736,
        -0.44606033,
        -0.37179495,
        -0.29356079,
        -0.21209374,
        -0.12823602,
        -0.04291034,
        0.04291034,
        0.12823602,
        0.21209374,
        0.29356079,
        0.37179495,
        0.44606033,
        0.51574736,
        0.58038613,
        0.6396527,
        0.69336849,
        0.74149338,
        0.78411334,
        0.82142392,
        0.85371082,
        0.88132908,
        0.90468204,
        0.92420128,
        0.94032851,
        0.9534999,
        0.96413348,
        0.97261948,
        0.97931372,
        0.98453378,
        0.98855749,
        0.99162334,
        0.99393249,
        0.9956517,
        0.99691696,
        0.99783743,
        0.99849936,
        0.99896989,
        0.99930052,
    ];
}
// Same table as the `f32` implementation. Every literal fits in `f64` precision, so no clippy
// `excessive_precision` allowance is needed here.
impl ErfEps1Over1e3Float for f64 {
    /// 64 ascending grid nodes, symmetric around zero, spanning about [-2.397, 2.397] with a
    /// uniform step of about 0.0760933.
    const X_FOR_ERF_EPS_1OVER1E3: [Self; 64] = [
        -2.39693895,
        -2.32084565,
        -2.24475235,
        -2.16865905,
        -2.09256575,
        -2.01647245,
        -1.94037915,
        -1.86428585,
        -1.78819255,
        -1.71209925,
        -1.63600595,
        -1.55991265,
        -1.48381935,
        -1.40772605,
        -1.33163275,
        -1.25553945,
        -1.17944615,
        -1.10335285,
        -1.02725955,
        -0.95116625,
        -0.87507295,
        -0.79897965,
        -0.72288635,
        -0.64679305,
        -0.57069975,
        -0.49460645,
        -0.41851315,
        -0.34241985,
        -0.26632655,
        -0.19023325,
        -0.11413995,
        -0.03804665,
        0.03804665,
        0.11413995,
        0.19023325,
        0.26632655,
        0.34241985,
        0.41851315,
        0.49460645,
        0.57069975,
        0.64679305,
        0.72288635,
        0.79897965,
        0.87507295,
        0.95116625,
        1.02725955,
        1.10335285,
        1.17944615,
        1.25553945,
        1.33163275,
        1.40772605,
        1.48381935,
        1.55991265,
        1.63600595,
        1.71209925,
        1.78819255,
        1.86428585,
        1.94037915,
        2.01647245,
        2.09256575,
        2.16865905,
        2.24475235,
        2.32084565,
        2.39693895,
    ];
    /// Reciprocal of the grid step (`1 / 0.0760933`).
    const INVERSED_DX_FOR_ERF_EPS_1OVER1E3: Self = 13.141761468984605;
    /// Values at the grid nodes. They are antisymmetric around zero and appear to be `erf`
    /// sampled at the nodes; they reach ±0.99930052 at the ends.
    const Y_FOR_ERF_EPS_1OVER1E3: [Self; 64] = [
        -0.99930052,
        -0.99896989,
        -0.99849936,
        -0.99783743,
        -0.99691696,
        -0.9956517,
        -0.99393249,
        -0.99162334,
        -0.98855749,
        -0.98453378,
        -0.97931372,
        -0.97261948,
        -0.96413348,
        -0.9534999,
        -0.94032851,
        -0.92420128,
        -0.90468204,
        -0.88132908,
        -0.85371082,
        -0.82142392,
        -0.78411334,
        -0.74149338,
        -0.69336849,
        -0.6396527,
        -0.58038613,
        -0.51574736,
        -0.44606033,
        -0.37179495,
        -0.29356079,
        -0.21209374,
        -0.12823602,
        -0.04291034,
        0.04291034,
        0.12823602,
        0.21209374,
        0.29356079,
        0.37179495,
        0.44606033,
        0.51574736,
        0.58038613,
        0.6396527,
        0.69336849,
        0.74149338,
        0.78411334,
        0.82142392,
        0.85371082,
        0.88132908,
        0.90468204,
        0.92420128,
        0.94032851,
        0.9534999,
        0.96413348,
        0.97261948,
        0.97931372,
        0.98453378,
        0.98855749,
        0.99162334,
        0.99393249,
        0.9956517,
        0.99691696,
        0.99783743,
        0.99849936,
        0.99896989,
        0.99930052,
    ];
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array1;

    /// The table-based `f32` approximation must stay within 7e-4 of `libm`'s `erff` on a
    /// dense grid over [-5, 5]. The grid extends past the table, so the saturated tails are
    /// covered too.
    #[test]
    fn erf_eps_1over1e3() {
        const POINTS: usize = 1 << 20;
        let grid: Array1<f32> = Array1::linspace(-5.0, 5.0, POINTS);
        let reference = grid.mapv(f32::libm_erf);
        let approximation = grid.mapv(f32::erf_eps_1over1e3);
        let (got, expected) = (
            approximation.as_slice().unwrap(),
            reference.as_slice().unwrap(),
        );
        assert_abs_diff_eq!(got, expected, epsilon = 7e-4,);
    }
}