/// Computes `e^x` for an `f64` without calling into libm.
///
/// The argument is reduced as `x = n * ln(2) + r` with `n = round(x * log2(e))`,
/// `exp(r) - 1` is approximated by a polynomial in `r` (terms up to `r^12`,
/// evaluated as separate even and odd parts), and the result is rebuilt as
/// `2^n * (1 + (exp(r) - 1))` by writing the exponent directly into the bit
/// pattern of an `f64`.
///
/// # Special cases
///
/// * `NaN` returns `NaN`.
/// * `x > 709.782712893384` (including `+inf`) returns `+inf`.
/// * `x < -745.1332191019411` (including `-inf`) returns `0.0`.
/// * Results smaller than `f64::MIN_POSITIVE` are returned as subnormal numbers.
#[inline(always)]
// The literals below carry more digits than an `f64` stores, spell out
// `log2(e)` by hand, and clamp with explicit comparisons; the matching
// clippy lints are silenced for this function.
#[allow(
    clippy::excessive_precision,
    clippy::approx_constant,
    clippy::manual_clamp
)]
pub fn exp_fast_f64(x: f64) -> f64 {
    // Arguments above this value produce +inf.
    const OVERFLOW_ABOVE: f64 = 709.7827128933840;
    // Arguments below this value produce 0.
    const UNDERFLOW_BELOW: f64 = -745.1332191019411;
    const LOG2_E: f64 = 1.44269504088896338700e+00;
    // `ln(2)` split into a high part and a small correction so that
    // `n * LN2_HI` can be subtracted from `x` with little rounding error.
    const LN2_HI: f64 = 6.93147180369123816490e-01;
    const LN2_LO: f64 = 1.90821492927058770002e-10;
    // Layout of an IEEE 754 binary64 value.
    const EXP_BIAS: i64 = 1023;
    const MANTISSA_BITS: u32 = 52;
    // Smallest exponent that still yields a normal `f64`.
    const MIN_NORMAL_EXP: i64 = -1022;
    // Extra binades used to keep the intermediate product normal when the
    // final result is subnormal.
    const SUBNORMAL_LIFT: i64 = 64;

    // Builds 2^e from its bit pattern; only valid for MIN_NORMAL_EXP <= e <= EXP_BIAS.
    let pow2 = |e: i64| f64::from_bits(((e + EXP_BIAS) as u64) << MANTISSA_BITS);

    // Keep the reduction inside the range where `n` fits the exponent logic.
    // NaN fails both comparisons and passes through unchanged.
    let clamped = if x > OVERFLOW_ABOVE {
        OVERFLOW_ABOVE
    } else if x < UNDERFLOW_BELOW {
        UNDERFLOW_BELOW
    } else {
        x
    };

    let n = (clamped * LOG2_E).round();
    // `as` saturates and maps NaN to 0, so this cast cannot misbehave.
    let ni = n as i64;
    let r = (clamped - n * LN2_HI) - n * LN2_LO;
    let r2 = r * r;

    // Even-power terms of exp(r) - 1: r^2/2 + r^4/24 + ...
    let even = r2
        * (5.000000000000000000e-01
            + r2 * (4.166666666666667823e-02
                + r2 * (1.388888888887908173e-03
                    + r2 * (2.480158733642404552e-05
                        + r2 * (2.755726329888269414e-07 + r2 * 2.091813031817600864e-09)))));
    // Odd-power terms of exp(r) - 1: r + r^3/6 + ...
    let odd = r
        * (1.000000000000000000e+00
            + r2 * (1.666666666666667962e-01
                + r2 * (8.333333333319599412e-03
                    + r2 * (1.984126989005147428e-04
                        + r2 * (2.755724091443086719e-06 + r2 * 2.511003898736913440e-08)))));
    let expm1 = even + odd;

    let finite = if ni > EXP_BIAS {
        // 2^1024 is not representable: scale by 2^1023 and double afterwards.
        let half = pow2(EXP_BIAS);
        (half + half * expm1) * 2.0
    } else if ni < MIN_NORMAL_EXP {
        // 2^ni has no normal bit pattern here. Build the value 64 binades
        // higher, then scale down so the only rounding into the subnormal
        // range happens in the final multiply.
        let lifted = pow2(ni + SUBNORMAL_LIFT);
        (lifted + lifted * expm1) * pow2(-SUBNORMAL_LIFT)
    } else {
        let scale = pow2(ni);
        scale + scale * expm1
    };

    // Out-of-domain inputs override whatever the clamped evaluation produced.
    if x.is_nan() {
        f64::NAN
    } else if x > OVERFLOW_ABOVE {
        f64::INFINITY
    } else if x < UNDERFLOW_BELOW {
        0.0
    } else {
        finite
    }
}
/// Applies [`exp_fast_f64`] to each element of `input` and stores the results
/// in the corresponding slots of `output`.
///
/// Only the first `input.len()` elements of `output` are written; any
/// remaining tail is left untouched.
///
/// # Panics
///
/// Panics if `output` is shorter than `input`. The length is checked once,
/// before any element is written.
#[inline(never)]
pub fn exp_fast_batch_f64(input: &[f64], output: &mut [f64]) {
    // One up-front range check replaces a bounds check per element.
    let output = &mut output[..input.len()];
    for (dst, &src) in output.iter_mut().zip(input) {
        *dst = exp_fast_f64(src);
    }
}
/// Computes `e^x` for an `f32` by widening the argument to `f64`, calling
/// [`exp_fast_f64`], and narrowing the result back with an `as` cast.
#[inline(always)]
pub fn exp_fast_f32(x: f32) -> f32 {
    let widened = f64::from(x);
    let wide_result = exp_fast_f64(widened);
    wide_result as f32
}
/// Applies [`exp_fast_f32`] to each element of `input` and stores the results
/// in the corresponding slots of `output`.
///
/// Only the first `input.len()` elements of `output` are written; any
/// remaining tail is left untouched.
///
/// # Panics
///
/// Panics if `output` is shorter than `input`. The length is checked once,
/// before any element is written.
#[inline(never)]
pub fn exp_fast_batch_f32(input: &[f32], output: &mut [f32]) {
    // One up-front range check replaces a bounds check per element.
    let output = &mut output[..input.len()];
    for (dst, &src) in output.iter_mut().zip(input) {
        *dst = exp_fast_f32(src);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks e^0, e^1 and e^-1 against their reference values.
    #[test]
    fn basic_values() {
        let e = std::f64::consts::E;
        let checks = [(0.0, 1.0, 1e-15), (1.0, e, 1e-14), (-1.0, 1.0 / e, 1e-15)];
        for (x, expected, tolerance) in checks {
            assert!((exp_fast_f64(x) - expected).abs() < tolerance);
        }
    }

    /// NaN, infinities, out-of-range arguments and signed zero.
    #[test]
    fn edge_cases() {
        assert!(exp_fast_f64(f64::NAN).is_nan());

        let exact = [
            (f64::INFINITY, f64::INFINITY),
            (f64::NEG_INFINITY, 0.0),
            (710.0, f64::INFINITY),
            (-750.0, 0.0),
            (0.0, 1.0),
            (-0.0, 1.0),
        ];
        for (x, expected) in exact {
            assert_eq!(exp_fast_f64(x), expected);
        }
    }

    /// Sweeps x from -700.0 to 709.7 in steps of 0.1 and requires the relative
    /// error against `f64::exp` to stay within 1.5 units of `f64::EPSILON`.
    #[test]
    fn accuracy_vs_libm() {
        for step in -7000..=7097 {
            let x = step as f64 * 0.1;
            let fast = exp_fast_f64(x);
            let reference = x.exp();

            // A relative error is only meaningful for finite, non-zero references.
            let comparable = reference != 0.0 && !reference.is_infinite();
            if comparable {
                let ulp = (fast - reference).abs() / (reference.abs() * f64::EPSILON);
                assert!(
                    ulp <= 1.5,
                    "exp_fast_f64({x}) = {fast}, libm = {reference}, ulp = {ulp}"
                );
            }
        }
    }

    /// The batch entry point must be bit-identical to the scalar function.
    #[test]
    fn batch_matches_scalar() {
        let input: Vec<f64> = (-100..=100).map(|i| i as f64 * 0.1).collect();
        let mut output = vec![0.0f64; input.len()];
        exp_fast_batch_f64(&input, &mut output);

        for (&x, &got) in input.iter().zip(&output) {
            assert_eq!(got.to_bits(), exp_fast_f64(x).to_bits());
        }
    }

    /// Basic values and special cases for the `f32` wrapper.
    #[test]
    fn f32_basic() {
        let at_zero = exp_fast_f32(0.0f32);
        let at_one = exp_fast_f32(1.0f32);
        assert!((at_zero - 1.0).abs() < 1e-7);
        assert!((at_one - std::f32::consts::E).abs() < 1e-6);

        assert!(exp_fast_f32(f32::NAN).is_nan());
        assert_eq!(exp_fast_f32(f32::INFINITY), f32::INFINITY);
        assert_eq!(exp_fast_f32(f32::NEG_INFINITY), 0.0);
    }
}