/// Approximate base-2 logarithm of `x` using the IEEE-754 bit-pattern trick.
///
/// The raw bits of `x` are read as an unsigned integer and converted to
/// `f32`. Scaling by 2^-23 moves the biased exponent field into the integer
/// part and leaves the mantissa bits as a linear fractional part. Subtracting
/// a constant just below the exponent bias then yields an approximation of
/// `log2(x)`.
///
/// The arithmetic is kept in exactly this order so the result stays
/// bit-identical to the reference C++ formula.
///
/// `x` must be strictly positive; this is only checked in debug builds.
fn fast_log2f(x: f32) -> f32 {
    // Exactly 2^-23: shifts the exponent field (bits 23..=30) to the units place.
    const INV_MANTISSA_SCALE: f32 = 1.192_092_9e-7;
    // Exponent bias 127 minus ~0.0573. The offset presumably compensates for
    // treating the mantissa as linear in log space.
    const BIAS_CORRECTION: f32 = 126.942_695;

    debug_assert!(x > 0.0);
    let as_integer = x.to_bits() as f32;
    as_integer * INV_MANTISSA_SCALE - BIAS_CORRECTION
}
/// Square root of `f`.
///
/// Despite the name, this currently delegates to the exact [`f32::sqrt`]. It
/// therefore matches the standard library bit-for-bit, including returning
/// NaN for negative inputs.
pub(crate) fn sqrt_fast_approximation(f: f32) -> f32 {
    f32::sqrt(f)
}
/// Approximate natural logarithm of `x`, computed as `ln(2) * log2(x)`.
///
/// `log2(x)` comes from the bit-trick `fast_log2f`, so `x` must be strictly
/// positive (debug-asserted there) and the accuracy is that of `fast_log2f`
/// scaled by `ln(2)`.
pub(crate) fn log_approximation(x: f32) -> f32 {
    core::f32::consts::LN_2 * fast_log2f(x)
}
/// Element-wise [`log_approximation`]: stores `log_approximation(x[i])` into `y[i]`.
///
/// Only indices present in both slices are processed. If the lengths differ,
/// surplus input is ignored and surplus output is left untouched.
pub(crate) fn log_approximation_batch(x: &[f32], y: &mut [f32]) {
    y.iter_mut()
        .zip(x)
        .for_each(|(out, &input)| *out = log_approximation(input));
}
/// Computes `2^p`.
///
/// Currently this is the exact [`f32::exp2`]; the name reserves room for a
/// faster approximation later.
pub(crate) fn pow2_approximation(p: f32) -> f32 {
    f32::exp2(p)
}
/// Approximates `x^p` via the identity `x^p = 2^(p * log2(x))`.
///
/// `log2(x)` uses the bit-trick `fast_log2f`, so `x` must be strictly
/// positive. Any error in that logarithm is multiplied by `p` before
/// exponentiation.
pub(crate) fn pow_approximation(x: f32, p: f32) -> f32 {
    let log2_x = fast_log2f(x);
    pow2_approximation(p * log2_x)
}
/// Approximates `e^x` as `10^(x * log10(e))` using [`pow_approximation`].
///
/// Routing through base 10 matches the reference implementation. Because
/// `fast_log2f(10.0)` evaluates to about 3.3073 rather than
/// `log2(10) ≈ 3.3219`, the result behaves roughly like `e^(0.9956 * x)`.
pub(crate) fn exp_approximation(x: f32) -> f32 {
    // e^x == 10^(x * log10(e))
    let base10_exponent = x * core::f32::consts::LOG10_E;
    pow_approximation(10.0, base10_exponent)
}
/// Element-wise [`exp_approximation`]: stores `exp_approximation(x[i])` into `y[i]`.
///
/// Processes `min(x.len(), y.len())` elements; anything beyond that in
/// either slice is ignored or left as-is.
pub(crate) fn exp_approximation_batch(x: &[f32], y: &mut [f32]) {
    for (out, &value) in y.iter_mut().zip(x) {
        *out = exp_approximation(value);
    }
}
/// Element-wise `e^(-x)`: stores `exp_approximation(-x[i])` into `y[i]`.
///
/// Processes `min(x.len(), y.len())` elements; anything beyond that in
/// either slice is ignored or left as-is.
pub(crate) fn exp_approximation_sign_flip(x: &[f32], y: &mut [f32]) {
    y.iter_mut()
        .zip(x.iter())
        .for_each(|(dst, src)| *dst = exp_approximation(-*src));
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::f32::consts::{E, SQRT_2};

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn within(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn fast_log2f_powers_of_two() {
        let cases = [(1.0_f32, 0.0_f32), (2.0, 1.0), (4.0, 2.0), (8.0, 3.0), (0.5, -1.0)];
        for &(input, want) in &cases {
            assert!(within(fast_log2f(input), want, 0.1));
        }
    }

    #[test]
    fn log_approximation_positive_values() {
        assert!(within(log_approximation(1.0), 0.0, 0.1));
        assert!(within(log_approximation(E), 1.0, 0.1));
    }

    #[test]
    fn exp_approximation_known_values() {
        assert!(within(exp_approximation(0.0), 1.0, 0.1));
        assert!(within(exp_approximation(1.0), E, 0.5));
    }

    #[test]
    fn pow_approximation_squares() {
        assert!(within(pow_approximation(2.0, 2.0), 4.0, 0.5));
        assert!(within(pow_approximation(3.0, 2.0), 9.0, 1.0));
    }

    #[test]
    fn sqrt_matches_std() {
        assert_eq!(sqrt_fast_approximation(4.0), 2.0);
        assert_eq!(sqrt_fast_approximation(0.0), 0.0);
        assert!(within(sqrt_fast_approximation(2.0), SQRT_2, 1e-6));
    }

    #[test]
    fn batch_operations() {
        let x = [1.0_f32, 2.0, 4.0, 8.0];
        let mut y = [0.0_f32; 4];

        log_approximation_batch(&x, &mut y);
        for (i, (&yi, &xi)) in y.iter().zip(&x).enumerate() {
            let expected = xi.ln();
            assert!(
                (yi - expected).abs() < 0.2,
                "log mismatch at {i}: got {yi}, expected {expected}"
            );
        }

        // Reuse the same output buffer for the exponential pass.
        let x2 = [0.0_f32, 0.5, 1.0, 2.0];
        exp_approximation_batch(&x2, &mut y);
        for (i, (&yi, &xi)) in y.iter().zip(&x2).enumerate() {
            let expected = xi.exp();
            assert!(
                (yi - expected).abs() / expected.max(1.0) < 0.15,
                "exp mismatch at {i}: got {yi}, expected {expected}"
            );
        }
    }

    #[test]
    fn exp_sign_flip() {
        let x = [0.0_f32, 1.0, 2.0];
        let mut y = [0.0_f32; 3];
        exp_approximation_sign_flip(&x, &mut y);
        for (i, (&yi, &xi)) in y.iter().zip(&x).enumerate() {
            let expected = (-xi).exp();
            assert!(
                (yi - expected).abs() / expected.max(1.0) < 0.15,
                "exp_sign_flip mismatch at {i}: got {yi}, expected {expected}"
            );
        }
    }

    #[test]
    fn fast_log2f_matches_cpp_bit_trick() {
        const SAMPLES: [f32; 8] = [0.001, 0.1, 0.5, 1.0, 2.0, 10.0, 100.0, 1000.0];
        for v in SAMPLES.iter().copied() {
            // Reference value recomputed independently of `fast_log2f`.
            let cpp_result = (v.to_bits() as f32) * 1.192_092_9e-7 - 126.942_695;
            let rust_result = fast_log2f(v);
            assert_eq!(
                rust_result, cpp_result,
                "fast_log2f({v}) mismatch: rust={rust_result}, cpp={cpp_result}"
            );
        }
    }
}