use multiversion::multiversion;
/// Approximate reciprocal `1.0 / w`: a hardware reciprocal estimate refined by one
/// Newton–Raphson step, `r' = r * (2 - w * r)`.
///
/// * x86_64 built with AVX2 enabled: `_mm_rcp_ss` estimate plus scalar SSE refinement.
/// * aarch64 built with NEON enabled: `vrecpeq_f32` estimate plus one `vrecpsq_f32` step.
/// * Any other build: exact `1.0 / w`.
///
/// Zero, infinite, subnormal and NaN inputs always take the exact division. On x86 the
/// estimate-plus-refinement sequence evaluates `0 * inf` for `w = ±0` and `w = ±inf`, which
/// gives NaN instead of `±inf` / `±0`. For subnormal `w`, RCPSS treats the input as zero, so
/// the estimate is `±inf` and the refinement produces an infinity of the wrong sign.
///
/// NOTE(review): `#[cfg(target_feature = ...)]` is resolved once for the whole crate build,
/// not per `multiversion` clone. The SIMD paths are therefore only used when the crate itself
/// is compiled with those features enabled — confirm this is intended.
#[multiversion(targets(
    "x86_64+avx2+bmi1+bmi2+popcnt+lzcnt",
    "x86_64+avx512f+avx512bw+avx512dq+avx512vl",
    "aarch64+neon"
))]
#[must_use]
pub(crate) fn rcp_nr(w: f32) -> f32 {
    // The Newton–Raphson refinement is only well-behaved for normal inputs; IEEE division
    // gives the correct answer for ±0, ±inf, subnormals and NaN.
    if !w.is_normal() {
        return 1.0 / w;
    }
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    #[allow(unsafe_code)]
    // SAFETY: the cfg guarantees an x86_64 build with AVX2 (and therefore SSE) enabled, and
    // these intrinsics operate purely on register values.
    unsafe {
        use std::arch::x86_64::*;
        let w_vec = _mm_set_ss(w);
        let rcp = _mm_rcp_ss(w_vec);
        let two = _mm_set_ss(2.0);
        let prod = _mm_mul_ss(w_vec, rcp);
        let diff = _mm_sub_ss(two, prod);
        let res = _mm_mul_ss(rcp, diff);
        return _mm_cvtss_f32(res);
    }
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    #[allow(unsafe_code)]
    // SAFETY: the cfg guarantees an aarch64 build with NEON enabled; the intrinsics operate
    // purely on register values and lane 0 is always in bounds.
    unsafe {
        use std::arch::aarch64::*;
        let v_w = vdupq_n_f32(w);
        let res_vec = vrecpeq_f32(v_w);
        // vrecpsq_f32 computes (2 - w * r); multiplying by r is one Newton–Raphson step.
        let res_vec = vmulq_f32(res_vec, vrecpsq_f32(v_w, res_vec));
        return vgetq_lane_f32(res_vec, 0);
    }
    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        1.0 / w
    }
}
/// Bilinearly interpolates four `u8` samples using 16.16 fixed-point weights.
///
/// `p00`, `p10`, `p01` and `p11` are the samples at the corners `(0,0)`, `(1,0)`, `(0,1)`
/// and `(1,1)` of the unit cell. Only the fractional parts of `x` and `y` are used, quantised
/// to 16 bits (`fx`, `fy` in `0..=0xFFFF`). The final value is truncated, not rounded.
///
/// The four weights always sum to exactly `0x10000`. `w11` is the truncated product `fx * fy`,
/// and the other three weights are derived from it by subtraction. If each weight were instead
/// truncated independently with `>> 16`, the total could fall up to 3 short of `0x10000`, and
/// a uniform patch of 255s would interpolate to 254.
///
/// NOTE(review): for negative `x`/`y`, `fract()` is negative and the saturating `as u32` cast
/// yields 0, so negative coordinates behave as if integer-aligned — confirm callers only pass
/// non-negative coordinates.
#[must_use]
#[allow(clippy::cast_sign_loss, dead_code)]
pub(crate) fn bilinear_interpolate_fixed(x: f32, y: f32, p00: u8, p10: u8, p01: u8, p11: u8) -> u8 {
    /// 1.0 in 16.16 fixed point.
    const ONE: u64 = 0x1_0000;
    let fx = u64::from(((x.fract() * 65536.0) as u32) & 0xFFFF);
    let fy = u64::from(((y.fract() * 65536.0) as u32) & 0xFFFF);
    // fx * fy, truncated. Because fx, fy < ONE, w11 <= min(fx, fy), so the subtractions
    // below cannot underflow.
    let w11 = (fx * fy) >> 16;
    // fx * (1 - fy)
    let w10 = fx - w11;
    // (1 - fx) * fy
    let w01 = fy - w11;
    // (1 - fx) * (1 - fy) = floor((ONE - fx) * (ONE - fy) / ONE), which is >= 0. The operands
    // are ordered so the intermediate u64 values never go negative.
    let w00 = ONE + w11 - fx - fy;
    let acc = u64::from(p00) * w00
        + u64::from(p10) * w10
        + u64::from(p01) * w01
        + u64::from(p11) * w11;
    // The weights sum to ONE, so acc <= 255 * ONE and the shifted value fits in a u8.
    (acc >> 16) as u8
}
/// Approximates the error function `erf(x)` with the Abramowitz & Stegun 7.1.26 formula
/// `erf(x) ≈ 1 - (a1 t + a2 t² + a3 t³ + a4 t⁴ + a5 t⁵) e^{-x²}`, where `t = 1 / (1 + p|x|)`.
///
/// The formula is evaluated on `|x|`, and the sign is restored afterwards because erf is odd.
/// An input of `0.0` (or `-0.0`) returns exactly `0.0`. NaN propagates.
#[must_use]
pub(crate) fn erf_approx(x: f64) -> f64 {
    const A1: f64 = 0.254_829_592;
    const A2: f64 = -0.284_496_736;
    const A3: f64 = 1.421_413_741;
    const A4: f64 = -1.453_152_027;
    const A5: f64 = 1.061_405_429;
    const P: f64 = 0.327_591_1;

    // Without this special case the formula gives about 1e-9 at zero.
    if x == 0.0 {
        return 0.0;
    }

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    // Horner form of a5 t^4 + a4 t^3 + a3 t^2 + a2 t + a1.
    let horner = (((A5 * t + A4) * t + A3) * t + A2) * t + A1;
    let y = 1.0 - horner * t * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -y
    } else {
        y
    }
}
/// Vectorised `erf_approx`: evaluates the Abramowitz & Stegun 7.1.26 approximation of the
/// error function independently on each of the four `f64` lanes of `x`.
///
/// The polynomial uses FMA, so non-zero lanes may differ from the scalar `erf_approx` in the
/// last few bits. Lanes equal to `±0.0` return exactly `+0.0`, matching the scalar version's
/// special case. Without that case the formula evaluates to about `±1e-9` at zero.
///
/// # Safety
///
/// This function is only compiled when AVX2 and FMA are enabled for the whole crate, so every
/// intrinsic it uses is available. It has no further preconditions. It is `unsafe` only
/// because it calls `std::arch` intrinsics, including unaligned loads/stores on local arrays.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
#[must_use]
pub(crate) unsafe fn erf_approx_v4(x: std::arch::x86_64::__m256d) -> std::arch::x86_64::__m256d {
    use std::arch::x86_64::*;
    let sign_mask = _mm256_set1_pd(-0.0);
    // erf is odd: evaluate on |x| and OR the original sign bit back in at the end.
    let sign_bits = _mm256_and_pd(x, sign_mask);
    let abs_x = _mm256_andnot_pd(sign_mask, x);
    // All-ones in lanes where x == ±0.0. The ordered compare is false for NaN lanes.
    let is_zero = _mm256_cmp_pd::<_CMP_EQ_OQ>(abs_x, _mm256_setzero_pd());
    let a1 = _mm256_set1_pd(0.254_829_592);
    let a2 = _mm256_set1_pd(-0.284_496_736);
    let a3 = _mm256_set1_pd(1.421_413_741);
    let a4 = _mm256_set1_pd(-1.453_152_027);
    let a5 = _mm256_set1_pd(1.061_405_429);
    let p = _mm256_set1_pd(0.327_591_1);
    let one = _mm256_set1_pd(1.0);
    // t = 1 / (1 + p * |x|)
    let t = _mm256_div_pd(one, _mm256_fmadd_pd(p, abs_x, one));
    // Horner: (((a5 t + a4) t + a3) t + a2) t + a1
    let poly = _mm256_fmadd_pd(a5, t, a4);
    let poly = _mm256_fmadd_pd(poly, t, a3);
    let poly = _mm256_fmadd_pd(poly, t, a2);
    let poly = _mm256_fmadd_pd(poly, t, a1);
    // -x^2: square |x| (sign bit clear), then flip the sign bit.
    let neg_x2 = _mm256_xor_pd(_mm256_mul_pd(abs_x, abs_x), sign_mask);
    // AVX2 has no exp instruction, so exp is evaluated lane by lane with scalar `f64::exp`.
    let mut lanes = [0.0_f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), neg_x2);
    let exp_lanes = lanes.map(f64::exp);
    let exp_vals = _mm256_loadu_pd(exp_lanes.as_ptr());
    // y = 1 - (poly * t) * exp(-x^2)
    let y = _mm256_fnmadd_pd(_mm256_mul_pd(poly, t), exp_vals, one);
    let signed = _mm256_or_pd(y, sign_bits);
    // Zero inputs yield exactly +0.0, as in the scalar version.
    _mm256_andnot_pd(is_zero, signed)
}
#[cfg(not(all(
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
#[must_use]
#[allow(dead_code)]
pub(crate) fn erf_approx_v4(x: [f64; 4]) -> [f64; 4] {
[
erf_approx(x[0]),
erf_approx(x[1]),
erf_approx(x[2]),
erf_approx(x[3]),
]
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::float_cmp, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Asserts that each lane of a vectorised `erf` result is within 1e-15 of the scalar
    /// `erf_approx` evaluated on the corresponding input.
    fn assert_lanes_match_scalar(inputs: [f64; 4], lanes: [f64; 4]) {
        for (i, (&input, &lane)) in inputs.iter().zip(&lanes).enumerate() {
            let scalar = erf_approx(input);
            let diff = (lane - scalar).abs();
            assert!(
                diff < 1e-15,
                "erf_approx_v4 lane {i}: expected {scalar}, got {lane}, diff {diff}"
            );
        }
    }

    #[test]
    fn test_rcp_nr_precision() {
        for w in [1.0_f32, 2.0, 10.0, 0.5, 123.456] {
            let expected = 1.0 / w;
            let actual = rcp_nr(w);
            let diff = (expected - actual).abs();
            assert!(
                diff < 1e-4,
                "rcp_nr({w}) failed: expected {expected}, got {actual}, diff {diff}"
            );
        }
    }

    #[test]
    fn test_erf_approx_properties() {
        assert_eq!(erf_approx(0.0), 0.0);
        // Odd symmetry: erf(-x) == -erf(x).
        for x in [0.1, 0.5, 1.0, 2.0, 5.0] {
            assert!((erf_approx(-x) + erf_approx(x)).abs() < 1e-15);
        }
        // Saturation towards ±1 for large |x|.
        assert!((erf_approx(10.0) - 1.0).abs() < 1e-7);
        assert!((erf_approx(-10.0) + 1.0).abs() < 1e-7);
        assert!((erf_approx(100.0) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn test_erf_approx_accuracy() {
        let cases = [
            (0.5, 0.520_499_877_813_046_5),
            (1.0, 0.842_700_792_949_714_8),
            (2.0, 0.995_322_265_018_952_7),
        ];
        for (x, expected) in cases {
            let actual = erf_approx(x);
            let diff = (actual - expected).abs();
            assert!(
                diff < 1.5e-7,
                "erf_approx({x}) error {diff} exceeds tolerance 1.5e-7"
            );
        }
    }

    #[test]
    fn test_erf_approx_v4_matches_scalar() {
        let inputs = [0.5, -1.0, 2.0, -0.3];
        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "avx2",
            target_feature = "fma"
        ))]
        {
            use std::arch::x86_64::*;
            // SAFETY: the cfg guarantees AVX2 and FMA are enabled. `__m256d` and `[f64; 4]`
            // are both 32 bytes, and every bit pattern is a valid f64.
            let lanes: [f64; 4] = unsafe {
                // _mm256_set_pd takes the highest lane first, so lane i holds inputs[i].
                let v = _mm256_set_pd(inputs[3], inputs[2], inputs[1], inputs[0]);
                std::mem::transmute(erf_approx_v4(v))
            };
            assert_lanes_match_scalar(inputs, lanes);
        }
        #[cfg(not(all(
            target_arch = "x86_64",
            target_feature = "avx2",
            target_feature = "fma"
        )))]
        {
            assert_lanes_match_scalar(inputs, erf_approx_v4(inputs));
        }
    }

    #[test]
    fn test_bilinear_fixed() {
        assert_eq!(
            bilinear_interpolate_fixed(0.5, 0.5, 100, 200, 100, 200),
            150
        );
        assert_eq!(bilinear_interpolate_fixed(0.0, 0.0, 100, 200, 50, 250), 100);
        assert_eq!(
            bilinear_interpolate_fixed(0.999, 0.999, 100, 200, 50, 250),
            249
        );
    }
}