/// Logistic (sigmoid) function, `1 / (1 + e^(-x))`.
///
/// Maps any real `x` into `[0, 1]`; `logistic(0.)` is exactly `0.5`, and the
/// infinities map to `0.` and `1.`.
pub fn logistic(x: f64) -> f64 {
    let decay = (-x).exp();
    (1. + decay).recip()
}
/// Logit (log-odds) function, `ln(p / (1 - p))`; the inverse of [`logistic`].
///
/// The endpoints are accepted: `logit(0.)` is negative infinity and
/// `logit(1.)` is positive infinity.
///
/// # Panics
///
/// Panics if `p` lies outside `[0, 1]` (this includes NaN).
pub fn logit(p: f64) -> f64 {
    assert!((0. ..=1.).contains(&p), "p must be in [0, 1]");
    let odds = p / (1. - p);
    odds.ln()
}
/// One-parameter Box-Cox transform of `x`.
///
/// Returns `ln(x)` when `lambda` is exactly zero and
/// `(x^lambda - 1) / lambda` otherwise.
///
/// # Panics
///
/// Panics if `x` is not strictly positive.
pub fn boxcox(x: f64, lambda: f64) -> f64 {
    assert!(x > 0., "x must be positive");
    // The power form divides by lambda, so lambda == 0 takes the log form.
    if lambda != 0. {
        return (x.powf(lambda) - 1.) / lambda;
    }
    x.ln()
}
/// Two-parameter (shifted) Box-Cox transform of `x`.
///
/// The input is shifted by `alpha` before transforming: the result is
/// `ln(x + alpha)` when `lambda` is exactly zero and
/// `((x + alpha)^lambda - 1) / lambda` otherwise.
///
/// # Panics
///
/// Panics unless `x + alpha` is strictly positive, i.e. unless `x > -alpha`.
pub fn boxcox_shifted(x: f64, lambda: f64, alpha: f64) -> f64 {
    // The logarithm / power is taken of `x + alpha`, so that sum (not
    // `x - alpha`) is what has to be positive.
    assert!(x > -alpha, "x must be larger than -alpha");
    let shifted = x + alpha;
    if lambda == 0. {
        shifted.ln()
    } else {
        (shifted.powf(lambda) - 1.) / lambda
    }
}
/// Softmax of a slice: `exp(x_i) / sum_j exp(x_j)` for every element.
///
/// The largest element is subtracted from every input before
/// exponentiating. This leaves the mathematical result unchanged but keeps
/// `exp` from overflowing for large inputs (or underflowing to an all-zero
/// sum for very negative ones), which would otherwise produce NaN.
///
/// Returns an empty vector for empty input.
pub fn softmax(x: &[f64]) -> Vec<f64> {
    let max = x.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // Only shift by a finite maximum: subtracting an infinite one would turn
    // `inf - inf` into NaN and change the result for such degenerate inputs.
    let shift = if max.is_finite() { max } else { 0. };
    // Exponentiate once and reuse the values for both the sum and the ratios.
    let exps: Vec<f64> = x.iter().map(|v| (v - shift).exp()).collect();
    let sum_exp: f64 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum_exp).collect()
}
// Coefficients of the rational approximation to the error function given as
// formula 7.1.26 in Abramowitz & Stegun, "Handbook of Mathematical Functions".
const ERF_P: f64 = 0.3275911;
const ERF_A1: f64 = 0.254829592;
const ERF_A2: f64 = -0.284496736;
const ERF_A3: f64 = 1.421413741;
const ERF_A4: f64 = -1.453152027;
const ERF_A5: f64 = 1.061405429;

/// Approximates the error function `erf(x)`.
///
/// For `x >= 0` this evaluates
/// `1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2)` with
/// `t = 1 / (1 + p*x)`; negative inputs use the odd symmetry
/// `erf(-x) = -erf(x)`.
///
/// Returns NaN for NaN input.
pub fn erf(x: f64) -> f64 {
    // NaN compares false against everything, so it has to be handled up
    // front; a sign-based recursion would never terminate on it.
    if x.is_nan() {
        return f64::NAN;
    }
    let z = x.abs();
    let t = 1. / (1. + ERF_P * z);
    // Horner evaluation of the degree-5 polynomial in t (no constant term).
    let poly = ((((ERF_A5 * t + ERF_A4) * t + ERF_A3) * t + ERF_A2) * t + ERF_A1) * t;
    let magnitude = 1. - poly * (-z * z).exp();
    if x < 0. {
        -magnitude
    } else {
        magnitude
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributions::{Distribution1D, Exponential, Uniform};
    use approx_eq::assert_approx_eq;

    /// Checks the symmetry `logistic(x) + logistic(-x) == 1`, monotonicity
    /// over pairs of samples, and the values at 0 and at the infinities.
    #[test]
    fn test_logistic() {
        let samples = Exponential::new(5.).sample_n(100_usize);
        for x in samples.iter() {
            assert_approx_eq!(logistic(*x) + logistic(-*x), 1.);
        }
        // Every pair (i, j) with j >= i: a larger input must not map lower.
        for (i, a) in samples.iter().enumerate() {
            for b in samples.iter().skip(i) {
                if *a >= *b {
                    assert!(logistic(*a) >= logistic(*b));
                }
            }
        }
        assert_eq!(logistic(f64::NEG_INFINITY), 0.);
        assert_eq!(logistic(0.), 0.5);
        assert_eq!(logistic(f64::INFINITY), 1.);
    }

    /// Checks that `logit` and `logistic` undo each other, that a difference
    /// of logits equals the log of the odds ratio, and the boundary values.
    #[test]
    fn test_logit() {
        let samples = Uniform::new(0., 1.).sample_n(100_usize);
        for p in samples.iter() {
            assert_approx_eq!(*p, logistic(logit(*p)));
            assert_approx_eq!(*p, logit(logistic(*p)));
        }
        // Every pair (i, j) with j > i.
        for (i, a) in samples.iter().enumerate() {
            for b in samples.iter().skip(i + 1) {
                let odds_a = *a / (1. - *a);
                let odds_b = *b / (1. - *b);
                assert_approx_eq!(logit(*a) - logit(*b), (odds_a / odds_b).ln());
            }
        }
        assert_eq!(logit(0.), f64::NEG_INFINITY);
        assert_eq!(logit(0.5), 0.);
        assert_eq!(logit(1.), f64::INFINITY);
    }

    /// Compares `softmax` against precomputed values and checks that the
    /// outputs sum to one.
    #[test]
    fn test_softmax() {
        let input = vec![1., 2., 3., 4., 1., 2., 3.];
        let expected = vec![
            0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
        ];
        let actual = softmax(&input);
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert_approx_eq!(*got, *want);
        }
        assert_approx_eq!(actual.iter().sum(), 1.);
    }

    /// Compares `erf` against a table of reference values with a tolerance
    /// of 1e-5.
    #[test]
    fn test_erf() {
        // (input, expected erf(input))
        let table: [(f64, f64); 32] = [
            (0., 0.),
            (0.02, 0.022564575),
            (0.04, 0.045111106),
            (0.06, 0.067621594),
            (0.08, 0.090078126),
            (0.1, 0.112462916),
            (0.2, 0.222702589),
            (0.3, 0.328626759),
            (0.4, 0.428392355),
            (0.5, 0.520499878),
            (0.6, 0.603856091),
            (0.7, 0.677801194),
            (0.8, 0.742100965),
            (0.9, 0.796908212),
            (1., 0.842700793),
            (1.1, 0.88020507),
            (1.2, 0.910313978),
            (1.3, 0.934007945),
            (1.4, 0.95228512),
            (1.5, 0.966105146),
            (1.6, 0.976348383),
            (1.7, 0.983790459),
            (1.8, 0.989090502),
            (1.9, 0.992790429),
            (2., 0.995322265),
            (2.1, 0.997020533),
            (2.2, 0.998137154),
            (2.3, 0.998856823),
            (2.4, 0.999311486),
            (2.5, 0.999593048),
            (3., 0.99997791),
            (3.5, 0.999999257),
        ];
        for (x, expected) in table.iter() {
            assert_approx_eq!(erf(*x), *expected, 1e-5);
        }
    }
}