use crate::support::{
CastFrom, CastInto, DInt, Float, FpResult, HInt, Int, IntTy, MinInt, Round, Status, cold_path,
};
/// Square root of `x` computed with the round-to-nearest mode.
///
/// Equivalent to calling [`sqrt_round`] with [`Round::Nearest`] and keeping only the value;
/// any status flags reported by that call are discarded.
#[inline]
pub fn sqrt<F>(x: F) -> F
where
    F: Float + SqrtHelper,
    F::Int: HInt,
    F::Int: From<u8>,
    F::Int: From<F::ISet2>,
    F::Int: CastInto<F::ISet1>,
    F::Int: CastInto<F::ISet2>,
    u32: CastInto<F::Int>,
{
    let rounded = sqrt_round(x, Round::Nearest);
    rounded.val
}
/// Square root of `x`, returned as an [`FpResult`] carrying the value and status.
///
/// Outline:
/// 1. Reduce `x` to `4^e * m`, with `m` in `[1, 4)` held as an unsigned fixed-point integer
///    with two integer bits (the `_u2` suffix).
/// 2. Look up a 16-bit estimate of `1/sqrt(m)` in [`RSQRT_TAB`].
/// 3. Refine it with Goldschmidt iterations at the integer widths chosen by [`SqrtHelper`].
/// 4. Decide the last significand bit by squaring the candidate and comparing against the
///    scaled `m`.
///
/// Special inputs:
/// - `+0.0` and `-0.0` are returned unchanged.
/// - `+inf` is returned unchanged.
/// - NaN and every negative nonzero input (including `-inf` and negative subnormals) yield
///   `F::NAN` with [`Status::INVALID`]; the input NaN's payload is not propagated.
///
/// `_round` is never read, so the result does not depend on it.
#[inline]
pub fn sqrt_round<F>(x: F, _round: Round) -> FpResult<F>
where
F: Float + SqrtHelper,
F::Int: HInt,
F::Int: From<u8>,
F::Int: From<F::ISet2>,
F::Int: CastInto<F::ISet1>,
F::Int: CastInto<F::ISet2>,
u32: CastInto<F::Int>,
{
let zero = IntTy::<F>::ZERO;
let one = IntTy::<F>::ONE;
let mut ix = x.to_bits();
// Formats no wider than 32 bits keep the exponent in place inside `ix` (`Exp::NoShift`),
// saving shifts; wider formats extract sign+exponent into a `u32` (`Exp::Shifted`).
let noshift = F::BITS <= u32::BITS;
// `special_case` is true when the biased exponent is 0 (zero/subnormal), is all ones
// (inf/NaN), or the sign bit is set. The unsigned wrapping subtraction folds all three
// into one comparison: values below the lower bound wrap around to very large numbers.
let (mut top, special_case) = if noshift {
let exp_lsb = one << F::SIG_BITS;
let special_case = ix.wrapping_sub(exp_lsb) >= F::EXP_MASK - exp_lsb;
(Exp::NoShift(()), special_case)
} else {
let top = u32::cast_from(ix >> F::SIG_BITS);
let special_case = top.wrapping_sub(1) >= F::EXP_SAT - 1;
(Exp::Shifted(top), special_case)
};
if special_case {
cold_path();
// +0 or -0: shifting out the sign leaves nothing.
if ix << 1 == zero {
return FpResult::ok(x);
}
// +inf
if ix == F::EXP_MASK {
return FpResult::ok(x);
}
// NaN (either sign) or any negative value: everything with the sign bit set compares
// above EXP_MASK, as does a positive NaN.
if ix > F::EXP_MASK {
return FpResult::new(F::NAN, Status::INVALID);
}
// Only positive subnormals remain. Multiply by 2^SIG_BITS (biased exponent
// SIG_BITS + EXP_BIAS) to make the value normal, then take SIG_BITS back off the exponent.
let scaled = x * F::from_parts(false, F::SIG_BITS + F::EXP_BIAS, zero);
ix = scaled.to_bits();
match top {
Exp::Shifted(ref mut v) => {
*v = scaled.ex();
*v = (*v).wrapping_sub(F::SIG_BITS);
}
Exp::NoShift(()) => {
// Subtract SIG_BITS from the exponent field in place; this may wrap, and the
// wrapped bits are dealt with by the masking below.
ix = ix.wrapping_sub((F::SIG_BITS << F::SIG_BITS).cast());
}
}
}
// Split `x = 4^e * m`: `m_u2` is the reduced significand in [1, 4) and `exp` is the
// biased exponent of the result (half the input's unbiased exponent, re-biased).
let (m_u2, exp) = match top {
Exp::Shifted(top) => {
let mut e = top;
// Restore the implicit bit and move it to the top bit: a U2 value in [2, 4).
let mut m_u2 = (ix | F::IMPLICIT_BIT) << F::EXP_BITS;
// `even` is true for an odd *biased* exponent. The bias `EXP_SAT >> 1` is odd, so this
// means an even unbiased exponent; m is then halved into [1, 2).
let even = (e & 1) != 0;
if even {
m_u2 >>= 1;
}
// (e + bias) / 2: halve the unbiased exponent and re-apply the bias.
e = (e.wrapping_add(F::EXP_SAT >> 1)) >> 1;
(m_u2, Exp::Shifted(e))
}
Exp::NoShift(()) => {
// Same parity test as above, on the in-place exponent's lowest bit.
let even = ix & (one << F::SIG_BITS) != zero;
// Halve the in-place exponent and add half of the bit pattern of 1.0
// (EXP_MASK with its top bit cleared), then keep only the exponent field.
let mut e_noshift = ix >> 1;
e_noshift += (F::EXP_MASK ^ (F::SIGN_MASK >> 1)) >> 1;
e_noshift &= F::EXP_MASK;
// m1: significand at the top with the top bit standing in for the implicit bit,
// giving [2, 4).
let m1 = (ix << F::EXP_BITS) | F::SIGN_MASK;
// m0: significand one bit lower. The exponent's low bit (set in this case) lands in
// the implicit-bit position, giving [1, 2).
let m0 = (ix << (F::EXP_BITS - 1)) & !F::SIGN_MASK;
let m_u2 = if even { m0 } else { m1 };
(m_u2, Exp::NoShift(e_noshift))
}
};
// 7-bit table index: bit 6 is the lowest exponent bit, bits 0-5 are the top six
// significand bits.
let i = usize::cast_from(ix >> (F::SIG_BITS - 6)) & 0b1111111;
// Initial 1/sqrt(m) estimate, placed in the top 16 bits of an ISet1.
let r1_u0: F::ISet1 = F::ISet1::cast_from(RSQRT_TAB[i]) << (F::ISet1::BITS - 16);
// Stage 1: refine at ISet1 width, using the top bits of m.
let s1_u2: F::ISet1 = ((m_u2) >> (F::BITS - F::ISet1::BITS)).cast();
let (r1_u0, _s1_u2) = goldschmidt::<F, F::ISet1>(r1_u0, s1_u2, F::SET1_ROUNDS, false);
// Stage 2: widen r (same fixed-point value) and refine at ISet2 width.
let r2_u0: F::ISet2 = F::ISet2::from(r1_u0) << (F::ISet2::BITS - F::ISet1::BITS);
let s2_u2: F::ISet2 = ((m_u2) >> (F::BITS - F::ISet2::BITS)).cast();
let (r2_u0, _s2_u2) = goldschmidt::<F, F::ISet2>(r2_u0, s2_u2, F::SET2_ROUNDS, false);
// Final stage at full width. Only its `s` (the root estimate) is kept.
let r_u0: F::Int = F::Int::from(r2_u0) << (F::BITS - F::ISet2::BITS);
let s_u2: F::Int = m_u2;
let (_r_u0, s_u2) = goldschmidt::<F, F::Int>(r_u0, s_u2, F::FINAL_ROUNDS, true);
// Shift the estimate down into significand position. `goldschmidt` skips doubling `s` on
// the last round when `final_set` is true; this shift presumably accounts for that.
let mut m = s_u2 >> (F::EXP_BITS - 2);
// Rounding check. Let M = m_u2 << shift. Then d0 = M - m*m and d1 = m*m + m - M.
// Read as a signed value (assuming the wrapping differences stay small), d1 is negative
// exactly when (m + 1/2)^2 < M, i.e. when the true root lies above the midpoint. In that
// case its top bit is 1 and m is bumped up by one unit.
let shift = 2 * F::SIG_BITS - (F::BITS - 2);
let d0 = (m_u2 << shift).wrapping_sub(m.wrapping_mul(m));
let d1 = m.wrapping_sub(d0);
m += d1 >> (F::BITS - 1);
// Drop the implicit bit and attach the result exponent.
m &= F::SIG_MASK;
match exp {
Exp::Shifted(e) => m |= IntTy::<F>::cast_from(e) << F::SIG_BITS,
Exp::NoShift(e) => m |= e,
};
let mut y = F::from_bits(m);
// Skipped for 16-bit formats. The block adds `t` (either zero or the bit pattern
// IMPLICIT_BIT, possibly with the sign bit) to `y`. Since `y`'s exponent is roughly half
// the input's, `y` is far from the subnormal range. So under round-to-nearest this cannot
// change its value. Presumably the addition exists for floating-point-environment side
// effects (inexact flag / directed rounding).
// NOTE(review): `m` has already been rounded, masked and had exponent bits ORed in by this
// point. So `d2` is not `(m + 1)^2 - M` for the pre-rounding candidate. If `d2 == 0` is
// meant to detect exact results, `d2` would need to be captured before `m` is updated.
// Confirm.
if F::BITS > 16 {
let d2 = d1.wrapping_add(m).wrapping_add(one);
let mut tiny = if d2 == zero {
cold_path();
zero
} else {
F::IMPLICIT_BIT
};
tiny |= (d1 ^ d2) & F::SIGN_MASK;
let t = F::from_bits(tiny);
y = y + t;
}
FpResult::ok(y)
}
/// Fixed-point multiply helper.
///
/// Forms the full double-width product of `a` and `b` and keeps its upper half. This is the
/// product scaled down by `2^I::BITS`.
fn wmulh<I: HInt>(a: I, b: I) -> I {
    let product = a.widen_mul(b);
    product.hi()
}
/// Runs `count` Goldschmidt iterations refining `r ≈ 1/sqrt(m)` and `s ≈ sqrt(m)`, and
/// returns the updated `(r_u0, s_u2)`.
///
/// `sqrt_round` passes the reduced input `m` (truncated to `I`) as `s_u2`. Because `u_u0`
/// starts equal to `r_u0`, the first iteration computes `s = m * r`.
///
/// Each iteration then does:
/// - `s = wmulh(s, u)`, followed by `s <<= 1` except on the first iteration. When
///   `final_set` is true, the last iteration also skips the doubling and leaves it to the
///   caller.
/// - `u = 3 - wmulh(s, r)`, where `three_u2` is 3 with two integer bits.
/// - `r = wmulh(r, u) << 1`.
///
/// With `count == 0` the loop does not run and the inputs are returned unchanged. Suffix
/// convention: `_u0` has no integer bits, `_u2` has two, within an `I::BITS`-wide integer.
/// `F` is not used in the body.
#[inline]
fn goldschmidt<F, I>(mut r_u0: I, mut s_u2: I, count: u32, final_set: bool) -> (I, I)
where
F: SqrtHelper,
I: HInt + From<u8>,
{
let three_u2 = I::from(0b11u8) << (I::BITS - 2);
let mut u_u0 = r_u0;
for i in 0..count {
// First iteration: s = m*r (u == r). Later iterations: s = s*u.
s_u2 = wmulh(s_u2, u_u0);
// Double s except on the first iteration, and (in the final set) the last iteration.
if i > 0 && (!final_set || i + 1 < count) {
s_u2 <<= 1;
}
// u = 3 - s*r
let d_u2 = wmulh(s_u2, r_u0);
u_u0 = three_u2.wrapping_sub(d_u2);
// r = r*u, doubled
r_u0 = wmulh(r_u0, u_u0) << 1;
}
(r_u0, s_u2)
}
/// Where the result exponent is held while the square root is computed.
///
/// Formats no wider than 32 bits leave it in its natural bit position inside the integer
/// representation, which saves shifts. Wider formats pull it out into an LSB-aligned `u32`.
enum Exp<T> {
    /// Exponent kept in place within the float's integer representation.
    NoShift(T),
    /// Exponent extracted into a `u32`, aligned to bit 0.
    Shifted(u32),
}
/// Per-format tuning for the Goldschmidt square root.
///
/// Refinement runs in up to three stages of increasing integer width:
/// 1. `SET1_ROUNDS` at [`Self::ISet1`];
/// 2. `SET2_ROUNDS` at [`Self::ISet2`];
/// 3. `FINAL_ROUNDS` at the format's full-width `Self::Int`.
///
/// A stage with zero rounds just hands its estimate on unchanged. Its integer type can then
/// simply match the next stage's type.
pub trait SqrtHelper: Float {
    /// Rounds performed at `ISet1`; none unless overridden.
    const SET1_ROUNDS: u32 = 0;
    /// Rounds performed at `ISet2`; none unless overridden.
    const SET2_ROUNDS: u32 = 0;
    /// Rounds performed at full width; every format must choose this.
    const FINAL_ROUNDS: u32;

    /// Integer type for the first (narrowest) stage. It also holds the initial table estimate.
    type ISet1: HInt + Into<Self::ISet2> + CastFrom<Self::Int> + From<u8>;
    /// Integer type for the second stage.
    type ISet2: HInt + From<Self::ISet1> + From<u8>;
}
#[cfg(f16_enabled)]
impl SqrtHelper for f16 {
    // Two full-width rounds only. The narrow stages run no rounds, so both just use `u16`.
    const FINAL_ROUNDS: u32 = 2;

    type ISet1 = u16;
    type ISet2 = u16;
}
impl SqrtHelper for f32 {
    // Three full-width rounds only. The narrow stages run no rounds, so both just use `u32`.
    const FINAL_ROUNDS: u32 = 3;

    type ISet1 = u32;
    type ISet2 = u32;
}
impl SqrtHelper for f64 {
    // Two rounds on `u32`, then two at full width. The first stage runs no rounds, so
    // `ISet1` only carries the table estimate through to `ISet2`.
    const SET2_ROUNDS: u32 = 2;
    const FINAL_ROUNDS: u32 = 2;

    type ISet1 = u32;
    type ISet2 = u32;
}
#[cfg(f128_enabled)]
impl SqrtHelper for f128 {
    // All three stages are used: one round on `u32`, two on `u64`, two at full width.
    const SET1_ROUNDS: u32 = 1;
    const SET2_ROUNDS: u32 = 2;
    const FINAL_ROUNDS: u32 = 2;

    type ISet1 = u32;
    type ISet2 = u64;
}
/// Initial estimates of `1/sqrt(m)` as 16-bit fixed-point fractions (value = entry / 2^16).
///
/// `sqrt_round` indexes the table with the float's lowest exponent bit (index bit 6) and the
/// top six significand bits (index bits 0-5).
/// - Entries 0..64 decrease from about 0.704 to 0.501, roughly `1/sqrt(m)` for `m` in
///   `[2, 4)`.
/// - Entries 64..128 decrease from about 0.996 to 0.708, roughly `m` in `[1, 2)`.
#[rustfmt::skip]
static RSQRT_TAB: [u16; 128] = [
0xb451, 0xb2f0, 0xb196, 0xb044, 0xaef9, 0xadb6, 0xac79, 0xab43,
0xaa14, 0xa8eb, 0xa7c8, 0xa6aa, 0xa592, 0xa480, 0xa373, 0xa26b,
0xa168, 0xa06a, 0x9f70, 0x9e7b, 0x9d8a, 0x9c9d, 0x9bb5, 0x9ad1,
0x99f0, 0x9913, 0x983a, 0x9765, 0x9693, 0x95c4, 0x94f8, 0x9430,
0x936b, 0x92a9, 0x91ea, 0x912e, 0x9075, 0x8fbe, 0x8f0a, 0x8e59,
0x8daa, 0x8cfe, 0x8c54, 0x8bac, 0x8b07, 0x8a64, 0x89c4, 0x8925,
0x8889, 0x87ee, 0x8756, 0x86c0, 0x862b, 0x8599, 0x8508, 0x8479,
0x83ec, 0x8361, 0x82d8, 0x8250, 0x81c9, 0x8145, 0x80c2, 0x8040,
0xff02, 0xfd0e, 0xfb25, 0xf947, 0xf773, 0xf5aa, 0xf3ea, 0xf234,
0xf087, 0xeee3, 0xed47, 0xebb3, 0xea27, 0xe8a3, 0xe727, 0xe5b2,
0xe443, 0xe2dc, 0xe17a, 0xe020, 0xdecb, 0xdd7d, 0xdc34, 0xdaf1,
0xd9b3, 0xd87b, 0xd748, 0xd61a, 0xd4f1, 0xd3cd, 0xd2ad, 0xd192,
0xd07b, 0xcf69, 0xce5b, 0xcd51, 0xcc4a, 0xcb48, 0xca4a, 0xc94f,
0xc858, 0xc764, 0xc674, 0xc587, 0xc49d, 0xc3b7, 0xc2d4, 0xc1f4,
0xc116, 0xc03c, 0xbf65, 0xbe90, 0xbdbe, 0xbcef, 0xbc23, 0xbb59,
0xba91, 0xb9cc, 0xb90a, 0xb84a, 0xb78c, 0xb6d0, 0xb617, 0xb560,
];
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the special-case and boundary behavior of `sqrt_round` for one format.
    ///
    /// Invalid inputs must yield NaN with `Status::INVALID`. This includes a negative
    /// subnormal, which must be rejected rather than rescaled. Exact inputs must round-trip
    /// bit-for-bit with `Status::OK`. The smallest positive subnormal must come back as a
    /// non-NaN result.
    fn spec_test<F>()
    where
        F: Float + SqrtHelper,
        F::Int: HInt,
        F::Int: From<u8>,
        F::Int: From<F::ISet2>,
        F::Int: CastInto<F::ISet1>,
        F::Int: CastInto<F::ISet2>,
        u32: CastInto<F::Int>,
    {
        // Negative subnormal of smallest magnitude: has the sign bit, so it is out of domain.
        let neg_min_subnormal = F::from_bits(F::SIGN_MASK | IntTy::<F>::ONE);
        let nan = [F::NEG_INFINITY, F::NEG_ONE, F::NAN, F::MIN, neg_min_subnormal];
        // sqrt(1) == 1 is exact, like the zeros and +inf.
        let roundtrip = [F::ZERO, F::NEG_ZERO, F::INFINITY, F::ONE];

        for x in nan {
            let FpResult { val, status } = sqrt_round(x, Round::Nearest);
            assert!(val.is_nan());
            assert!(status == Status::INVALID);
        }

        for x in roundtrip {
            let FpResult { val, status } = sqrt_round(x, Round::Nearest);
            assert_biteq!(val, x);
            assert!(status == Status::OK);
        }

        // The smallest positive subnormal goes through the rescaling path.
        let min_subnormal = F::from_bits(IntTy::<F>::ONE);
        let FpResult { val, status } = sqrt_round(min_subnormal, Round::Nearest);
        assert!(!val.is_nan());
        assert!(status == Status::OK);
    }

    #[test]
    #[cfg(f16_enabled)]
    fn sanity_check_f16() {
        assert_biteq!(sqrt(100.0f16), 10.0);
        assert_biteq!(sqrt(4.0f16), 2.0);
    }

    #[test]
    #[cfg(f16_enabled)]
    fn spec_tests_f16() {
        spec_test::<f16>();
    }

    #[test]
    #[cfg(f16_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f16() {
        let cases = [
            (f16::PI, 0x3f17_u16),
            (f16::from_bits(0x70e2), 0x5640_u16),
            (f16::from_bits(0x0000000f), 0x13bf_u16),
            (f16::INFINITY, f16::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f16::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    #[test]
    fn sanity_check_f32() {
        assert_biteq!(sqrt(100.0f32), 10.0);
        assert_biteq!(sqrt(4.0f32), 2.0);
    }

    #[test]
    fn spec_tests_f32() {
        spec_test::<f32>();
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f32() {
        let cases = [
            (f32::PI, 0x3fe2dfc5_u32),
            (10000.0f32, 0x42c80000_u32),
            (f32::from_bits(0x0000000f), 0x1b2f456f_u32),
            (f32::INFINITY, f32::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f32::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    #[test]
    fn sanity_check_f64() {
        assert_biteq!(sqrt(100.0f64), 10.0);
        assert_biteq!(sqrt(4.0f64), 2.0);
    }

    #[test]
    fn spec_tests_f64() {
        spec_test::<f64>();
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f64() {
        let cases = [
            (f64::PI, 0x3ffc5bf891b4ef6a_u64),
            (10000.0, 0x4059000000000000_u64),
            (f64::from_bits(0x0000000f), 0x1e7efbdeb14f4eda_u64),
            (f64::INFINITY, f64::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f64::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    #[test]
    #[cfg(f128_enabled)]
    fn sanity_check_f128() {
        assert_biteq!(sqrt(100.0f128), 10.0);
        assert_biteq!(sqrt(4.0f128), 2.0);
    }

    #[test]
    #[cfg(f128_enabled)]
    fn spec_tests_f128() {
        spec_test::<f128>();
    }

    #[test]
    #[cfg(f128_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f128() {
        let cases = [
            (f128::PI, 0x3fffc5bf891b4ef6aa79c3b0520d5db9_u128),
            (
                f128::from_bits(0x400c3880000000000000000000000000),
                0x40059000000000000000000000000000_u128,
            ),
            (
                f128::from_bits(0x0000000f),
                0x1fc9efbdeb14f4ed9b17ae807907e1e9_u128,
            ),
            (f128::INFINITY, f128::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f128::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }
}