use crate::common::f_fmla;
use crate::cube_roots::cbrtf::halley_refine_d;
use crate::double_double::DoubleDouble;
use crate::exponents::fast_ldexp;
/// Arithmetic-primitive abstraction for the f64 cube-root kernel.
///
/// The same algorithm (`cbrt_gen_impl`) runs against either a portable
/// software implementation (`GenericCbrtBackend`) or a hardware-FMA
/// implementation (`FmaCbrtBackend`) of these five operations.
pub(crate) trait CbrtBackend {
/// Fused multiply-add: `x * y + z` with a single rounding.
fn fma(&self, x: f64, y: f64, z: f64) -> f64;
/// Degree-3 polynomial evaluation; coefficients appear to be in
/// ascending order (a0 + a1*x + a2*x^2 + a3*x^3) — matches how
/// `cbrt_gen_impl` uses it, but confirm against `f_polyeval4`.
fn polyeval4(&self, x: f64, a0: f64, a1: f64, a2: f64, a3: f64) -> f64;
/// Refines the approximation `x` toward cbrt(a); exact semantics are
/// defined by `halley_refine_d` / `halley_refine_d_fma`.
fn halley(&self, x: f64, a: f64) -> f64;
/// Exact product `x * y` represented as a double-double (hi + lo).
fn exact_mul(&self, x: f64, y: f64) -> DoubleDouble;
/// Fast (non-compensated) double-double × f64 product.
fn quick_mult_f64(&self, x: DoubleDouble, y: f64) -> DoubleDouble;
}
pub(crate) struct GenericCbrtBackend {}
/// Portable [`CbrtBackend`]: every primitive is routed through the crate's
/// software helpers, so no special CPU features are required.
impl CbrtBackend for GenericCbrtBackend {
    #[inline(always)]
    fn fma(&self, a: f64, b: f64, c: f64) -> f64 {
        // Software-emulated fused multiply-add (`a * b + c`, one rounding).
        f_fmla(a, b, c)
    }

    #[inline(always)]
    fn polyeval4(&self, t: f64, c0: f64, c1: f64, c2: f64, c3: f64) -> f64 {
        use crate::polyeval::f_polyeval4;
        f_polyeval4(t, c0, c1, c2, c3)
    }

    #[inline(always)]
    fn halley(&self, approx: f64, target: f64) -> f64 {
        // Software Halley refinement of `approx` toward cbrt(target).
        halley_refine_d(approx, target)
    }

    #[inline(always)]
    fn exact_mul(&self, lhs: f64, rhs: f64) -> DoubleDouble {
        DoubleDouble::from_exact_mult(lhs, rhs)
    }

    #[inline(always)]
    fn quick_mult_f64(&self, dd: DoubleDouble, scalar: f64) -> DoubleDouble {
        DoubleDouble::quick_mult_f64(dd, scalar)
    }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// x86/x86_64 backend built on native FMA instructions. Only safe to use
/// after `avx` and `fma` have been detected at runtime (see `f_cbrt`).
pub(crate) struct FmaCbrtBackend {}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// Hardware-accelerated [`CbrtBackend`]: uses `f64::mul_add` and the
/// `_fma`-suffixed variants of the crate helpers.
impl CbrtBackend for FmaCbrtBackend {
    #[inline(always)]
    fn fma(&self, a: f64, b: f64, c: f64) -> f64 {
        // Lowers to a single hardware FMA when the `fma` target feature
        // is enabled on the calling function.
        f64::mul_add(a, b, c)
    }

    #[inline(always)]
    fn polyeval4(&self, t: f64, c0: f64, c1: f64, c2: f64, c3: f64) -> f64 {
        use crate::polyeval::d_polyeval4;
        d_polyeval4(t, c0, c1, c2, c3)
    }

    #[inline(always)]
    fn halley(&self, approx: f64, target: f64) -> f64 {
        use crate::cube_roots::cbrtf::halley_refine_d_fma;
        halley_refine_d_fma(approx, target)
    }

    #[inline(always)]
    fn exact_mul(&self, lhs: f64, rhs: f64) -> DoubleDouble {
        DoubleDouble::from_exact_mult_fma(lhs, rhs)
    }

    #[inline(always)]
    fn quick_mult_f64(&self, dd: DoubleDouble, scalar: f64) -> DoubleDouble {
        DoubleDouble::quick_mult_f64_fma(dd, scalar)
    }
}
/// Core cube-root kernel, generic over the arithmetic backend.
///
/// Outline:
/// 1. Special cases (inf/NaN/±0) return early; subnormals are normalized
///    by scaling with 2^54 and compensating the exponent.
/// 2. Decompose |x| = m * 2^exp with m in [1, 2).
/// 3. Initial polynomial guess for cbrt(m), scaled by cbrt(2^(exp mod 3)).
/// 4. One Halley refinement, then a Newton-style correction computed in
///    double-double precision for the final rounding.
/// 5. Reassemble with 2^(exp div 3) and the original sign.
#[inline(always)]
fn cbrt_gen_impl<B: CbrtBackend>(x: f64, backend: B) -> f64 {
// Scale factors cbrt(1), cbrt(2), cbrt(4), indexed by exp mod 3.
static ESCALE: [f64; 3] = [
1.0,
f64::from_bits(0x3ff428a2f98d728b),
f64::from_bits(0x3ff965fea53d6e3d),
];
let bits = x.to_bits();
// Biased exponent and 52-bit mantissa (sign bit stripped by the masks).
let mut exp = ((bits >> 52) & 0x7ff) as i32;
let mut mant = bits & ((1u64 << 52) - 1);
if exp == 0x7ff || x == 0.0 {
// `x + x` propagates NaN and returns ±inf / ±0 unchanged —
// all of which are their own cube roots.
return x + x;
}
if exp == 0 && x != 0.0 {
// Subnormal: multiply by 2^54 (0x435... has unbiased exponent 54)
// to renormalize, then subtract 54 from the extracted exponent.
let norm = x * f64::from_bits(0x4350000000000000); let norm_bits = norm.to_bits();
mant = norm_bits & ((1u64 << 52) - 1);
exp = ((norm_bits >> 52) & 0x7ff) as i32 - 54;
}
// Unbias the exponent and restore the implicit leading bit: m in [1, 2).
exp -= 1023;
mant |= 0x3ff << 52;
let m = f64::from_bits(mant);
// Cubic polynomial approximation of cbrt(m) on [1, 2); coefficients
// evaluate to ~1.0 at m=1 and ~1.26 at m=2 in ascending order.
let p = backend.polyeval4(
m,
f64::from_bits(0x3fe1b0babceeaafa),
f64::from_bits(0x3fe2c9a3e8e06a3c),
f64::from_bits(0xbfc4dc30afb71885),
f64::from_bits(0x3f97a8d3e05458e4),
);
// Split exp = 3q + rem (Euclidean, so rem is always 0..=2 even for
// negative exponents); z approximates cbrt(mm) where mm = m * 2^rem.
let q = exp.div_euclid(3);
let rem_scale = exp.rem_euclid(3);
let z = p * ESCALE[rem_scale as usize];
let mm = fast_ldexp(m, rem_scale);
let r = 1.0 / mm;
// Halley refinement, then residual h = (y0^3 - mm) / mm with y0^3
// computed in double-double so the error term is accurate.
let y0 = backend.halley(z, mm);
let d2y = backend.exact_mul(y0, y0);
let d3y = backend.quick_mult_f64(d2y, y0);
let h = ((d3y.hi - mm) + d3y.lo) * r;
// Newton step: y = y0 - (1/3) * y0 * h (0x3fd5555... is 1/3).
let y = backend.fma(-f64::from_bits(0x3fd5555555555555), y0 * h, y0);
// Apply the exponent quotient and restore the sign of x.
f64::copysign(fast_ldexp(y, q), x)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx", enable = "fma")]
/// FMA-specialized cube root entry point.
///
/// # Safety
/// The caller must ensure the running CPU supports `avx` and `fma`;
/// `f_cbrt` guarantees this via `is_x86_feature_detected!` before
/// storing this function pointer.
unsafe fn cbrt_fma_impl(x: f64) -> f64 {
cbrt_gen_impl(x, FmaCbrtBackend {})
}
/// Computes the cube root of `x`, preserving sign.
///
/// NaN returns NaN; ±inf and ±0 return themselves. On x86/x86_64 the
/// implementation is chosen once on first call: the FMA-accelerated path
/// when `avx` and `fma` are detected at runtime, the portable path
/// otherwise. Other architectures always use the portable path.
pub fn f_cbrt(x: f64) -> f64 {
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
{
cbrt_gen_impl(x, GenericCbrtBackend {})
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
use std::sync::OnceLock;
// Cached dispatch target; feature detection runs only on first call.
static EXECUTOR: OnceLock<unsafe fn(f64) -> f64> = OnceLock::new();
let q = EXECUTOR.get_or_init(|| {
if std::arch::is_x86_feature_detected!("avx")
&& std::arch::is_x86_feature_detected!("fma")
{
cbrt_fma_impl
} else {
// Safe fallback; the fn item coerces to the unsafe fn pointer type.
fn def_cbrt(x: f64) -> f64 {
cbrt_gen_impl(x, GenericCbrtBackend {})
}
def_cbrt
}
});
// SAFETY: `cbrt_fma_impl` is stored only after `avx` and `fma` were
// detected on this CPU; `def_cbrt` has no safety requirements.
unsafe { q(x) }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cbrt() {
        // Subnormal / tiny inputs exercise the 2^54 normalization path.
        assert_eq!(f_cbrt(0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005432309223745),
                   0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000017579026781511548);
        assert_eq!(f_cbrt(0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000139491540182158), 1.1173329935611586e-103);
        // Generic normal-range value.
        assert_eq!(f_cbrt(1.225158611559834), 1.0700336588124544);
        // Perfect cubes k^3 for k = 3..=9 must round-trip exactly,
        // for both signs.
        for k in 3i32..=9 {
            let root = f64::from(k);
            let cube = root * root * root;
            assert_eq!(f_cbrt(cube), root);
            assert_eq!(f_cbrt(-cube), -root);
        }
        // IEEE special values.
        assert_eq!(f_cbrt(0.0), 0.0);
        assert_eq!(f_cbrt(f64::INFINITY), f64::INFINITY);
        assert_eq!(f_cbrt(f64::NEG_INFINITY), f64::NEG_INFINITY);
        assert!(f_cbrt(f64::NAN).is_nan());
    }
}