/// Parses the base-10 string literal `$s` into an integer of type `$T`.
///
/// Used to spell small constants (`0`, `1`, `10`, …) for integer types that
/// have no native literal syntax. Fails loudly with a panic naming the
/// offending literal if `$s` is not a valid base-10 number for `$T`.
macro_rules! wide_lit {
    ($T:ty, $s:literal) => {{
        let parsed = <$T>::from_str_radix($s, 10);
        if let ::core::result::Result::Ok(value) = parsed {
            value
        } else {
            panic!(concat!("wide_roots: invalid base-10 literal ", $s))
        }
    }};
}
/// Generates correctly rounded `sqrt_strict`, `cbrt_strict` and `hypot_strict`
/// methods (plus `sqrt`/`cbrt` aliases under the `strict` feature) for the
/// fixed-point type `$Type<SCALE>`.
///
/// * `$Storage`  — the raw integer type produced by `to_bits` / consumed by `from_bits`.
/// * `$SqrtWide` — intermediate integer type for `sqrt_strict` and `hypot_strict`.
/// * `$CbrtWide` — intermediate integer type for `cbrt_strict`.
///
/// NOTE(review): the intermediates must hold `raw * 10^SCALE` (sqrt),
/// `raw^2 + raw^2` (hypot) and `8 * |raw| * 10^(2*SCALE)` (cbrt) without
/// overflow. This presumably means `$SqrtWide` is at least twice as wide as
/// `$Storage`. Confirm at the invocation sites.
macro_rules! decl_wide_roots {
    ($Type:ident, $Storage:ty, $SqrtWide:ty, $CbrtWide:ty) => {
        impl<const SCALE: u32> $Type<SCALE> {
            /// Square root, rounded to the nearest representable value.
            ///
            /// Zero and negative inputs return `Self::ZERO`.
            #[inline]
            #[must_use]
            pub fn sqrt_strict(self) -> Self {
                let raw = self.to_bits();
                if raw <= $crate::macros::wide_roots::wide_lit!($Storage, "0") {
                    return Self::ZERO;
                }
                let one = $crate::macros::wide_roots::wide_lit!($SqrtWide, "1");
                let ten = $crate::macros::wide_roots::wide_lit!($SqrtWide, "10");
                // value = raw / 10^SCALE, so the raw result is sqrt(raw * 10^SCALE).
                let n: $SqrtWide = raw.resize::<$SqrtWide>() * ten.pow(SCALE);
                let q = n.isqrt();
                // Round to nearest: sqrt(n) >= q + 1/2  <=>  n - q^2 > q (integers).
                let diff = n - q * q;
                let q = if diff > q { q + one } else { q };
                Self::from_bits(q.resize::<$Storage>())
            }
            /// Cube root, rounded to the nearest representable value.
            ///
            /// The sign of the input is preserved; rounding is symmetric
            /// about zero (the magnitude is rounded, then the sign reapplied).
            #[inline]
            #[must_use]
            pub fn cbrt_strict(self) -> Self {
                let raw = self.to_bits();
                let storage_zero = $crate::macros::wide_roots::wide_lit!($Storage, "0");
                if raw == storage_zero {
                    return Self::ZERO;
                }
                let zero = $crate::macros::wide_roots::wide_lit!($CbrtWide, "0");
                let one = $crate::macros::wide_roots::wide_lit!($CbrtWide, "1");
                let three = $crate::macros::wide_roots::wide_lit!($CbrtWide, "3");
                let ten = $crate::macros::wide_roots::wide_lit!($CbrtWide, "10");
                let widened = raw.resize::<$CbrtWide>();
                let negative = widened < zero;
                let mag = if negative { -widened } else { widened };
                // value = raw / 10^SCALE, so the raw result is cbrt(|raw| * 10^(2*SCALE)).
                let n: $CbrtWide = mag * ten.pow(2 * SCALE);
                // Start Newton's iteration from a power of two that is >= cbrt(n),
                // so the integer iteration decreases monotonically to floor(cbrt(n)).
                let sig_bits = <$CbrtWide>::BITS - n.leading_zeros();
                let mut x = one << sig_bits.div_ceil(3);
                loop {
                    let y = (x + x + n / (x * x)) / three;
                    if y >= x {
                        break;
                    }
                    x = y;
                }
                let q = x;
                // Round to nearest: cbrt(n) >= q + 1/2  <=>  8n >= (2q + 1)^3.
                let eight_n = n << 3u32;
                let t = q + q + one;
                let q = if eight_n >= t * t * t { q + one } else { q };
                let signed = if negative { -q } else { q };
                Self::from_bits(signed.resize::<$Storage>())
            }
            /// Alias for [`Self::sqrt_strict`] when the `strict` feature is
            /// enabled and `fast` is not.
            #[cfg(all(feature = "strict", not(feature = "fast")))]
            #[inline]
            #[must_use]
            pub fn sqrt(self) -> Self {
                self.sqrt_strict()
            }
            /// Alias for [`Self::cbrt_strict`] when the `strict` feature is
            /// enabled and `fast` is not.
            #[cfg(all(feature = "strict", not(feature = "fast")))]
            #[inline]
            #[must_use]
            pub fn cbrt(self) -> Self {
                self.cbrt_strict()
            }
            /// Euclidean norm `sqrt(self^2 + other^2)`, rounded to the nearest
            /// representable value.
            ///
            /// With raw values `a` and `b` at scale `S`,
            /// `(a/10^S)^2 + (b/10^S)^2 = (a^2 + b^2) / 10^(2S)`, whose square
            /// root is `sqrt(a^2 + b^2) / 10^S`. The raw result is therefore
            /// the rounded integer square root of `a^2 + b^2`, computed exactly
            /// in the wide intermediate type. No rounded ratio is involved,
            /// because scaling such a ratio by the larger operand would
            /// amplify its rounding error.
            ///
            /// # Panics
            ///
            /// Panics if the result does not fit in the storage type.
            #[inline]
            #[must_use]
            pub fn hypot_strict(self, other: Self) -> Self {
                let zero = $crate::macros::wide_roots::wide_lit!($SqrtWide, "0");
                let one = $crate::macros::wide_roots::wide_lit!($SqrtWide, "1");
                // Squaring the signed raw values discards the sign without an
                // `abs()` that could overflow on the storage minimum.
                let a = self.to_bits().resize::<$SqrtWide>();
                let b = other.to_bits().resize::<$SqrtWide>();
                // NOTE(review): if both inputs are the storage minimum, this
                // sum exceeds a signed type exactly twice the storage width.
                // The true result is out of range in that case anyway.
                let n = a * a + b * b;
                if n == zero {
                    return Self::ZERO;
                }
                let q = n.isqrt();
                // Same round-to-nearest rule as `sqrt_strict`.
                let diff = n - q * q;
                let q = if diff > q { q + one } else { q };
                let out = q.resize::<$Storage>();
                // `resize` would silently truncate a result wider than the storage.
                if out.resize::<$SqrtWide>() != q {
                    panic!("hypot_strict: result overflows the storage type");
                }
                Self::from_bits(out)
            }
        }
    };
}
pub(crate) use {decl_wide_roots, wide_lit};
#[cfg(all(test, not(feature = "fast")))]
mod tests {
    //! Exactness and cross-width consistency checks for the wide-integer
    //! root methods generated by `decl_wide_roots!`.
    use crate::{D38, D76, D153, D307};
    #[test]
    fn sqrt_perfect_squares_are_exact() {
        assert_eq!(D76::<6>::from_int(4).sqrt_strict(), D76::<6>::from_int(2));
        assert_eq!(D76::<6>::from_int(9).sqrt_strict(), D76::<6>::from_int(3));
        assert_eq!(
            D76::<6>::from_int(144).sqrt_strict(),
            D76::<6>::from_int(12)
        );
        assert_eq!(D153::<6>::from_int(25).sqrt_strict(), D153::<6>::from_int(5));
        assert_eq!(
            D307::<6>::from_int(81).sqrt_strict(),
            D307::<6>::from_int(9)
        );
    }
    #[test]
    fn sqrt_zero_and_negative_saturate() {
        assert_eq!(D76::<6>::ZERO.sqrt_strict(), D76::<6>::ZERO);
        assert_eq!(D76::<6>::from_int(-4).sqrt_strict(), D76::<6>::ZERO);
        assert_eq!(D307::<6>::from_int(-1).sqrt_strict(), D307::<6>::ZERO);
    }
    #[test]
    fn cbrt_perfect_cubes_are_exact() {
        assert_eq!(D76::<6>::from_int(8).cbrt_strict(), D76::<6>::from_int(2));
        assert_eq!(
            D76::<6>::from_int(27).cbrt_strict(),
            D76::<6>::from_int(3)
        );
        assert_eq!(
            D76::<6>::from_int(-8).cbrt_strict(),
            D76::<6>::from_int(-2)
        );
        assert_eq!(
            D153::<6>::from_int(125).cbrt_strict(),
            D153::<6>::from_int(5)
        );
        assert_eq!(
            D307::<6>::from_int(-64).cbrt_strict(),
            D307::<6>::from_int(-4)
        );
    }
    #[test]
    fn cbrt_zero_is_zero() {
        assert_eq!(D76::<6>::ZERO.cbrt_strict(), D76::<6>::ZERO);
        assert_eq!(D153::<6>::ZERO.cbrt_strict(), D153::<6>::ZERO);
        assert_eq!(D307::<6>::ZERO.cbrt_strict(), D307::<6>::ZERO);
    }
    #[test]
    fn hypot_pythagorean_triples_are_exact() {
        assert_eq!(
            D76::<6>::from_int(3).hypot_strict(D76::<6>::from_int(4)),
            D76::<6>::from_int(5)
        );
        // Signs of either operand must not matter.
        assert_eq!(
            D76::<6>::from_int(-3).hypot_strict(D76::<6>::from_int(4)),
            D76::<6>::from_int(5)
        );
        assert_eq!(
            D153::<6>::from_int(6).hypot_strict(D153::<6>::from_int(-8)),
            D153::<6>::from_int(10)
        );
        assert_eq!(
            D76::<50>::from_int(4).hypot_strict(D76::<50>::from_int(3)),
            D76::<50>::from_int(5)
        );
    }
    #[test]
    fn hypot_with_zero_operands() {
        assert_eq!(D76::<6>::ZERO.hypot_strict(D76::<6>::ZERO), D76::<6>::ZERO);
        assert_eq!(
            D307::<6>::from_int(7).hypot_strict(D307::<6>::ZERO),
            D307::<6>::from_int(7)
        );
        assert_eq!(
            D307::<6>::ZERO.hypot_strict(D307::<6>::from_int(-7)),
            D307::<6>::from_int(7)
        );
    }
    #[test]
    fn wide_roots_match_d38() {
        for raw in [2i64, 3, 5, 7, 10, 123, 1_000, 999_983] {
            let narrow = D38::<6>::from_int(raw);
            let wide: D76<6> = narrow.into();
            let narrow_sqrt: D76<6> = narrow.sqrt_strict().into();
            assert_eq!(wide.sqrt_strict(), narrow_sqrt, "sqrt mismatch for {raw}");
            let narrow_cbrt: D76<6> = narrow.cbrt_strict().into();
            assert_eq!(wide.cbrt_strict(), narrow_cbrt, "cbrt mismatch for {raw}");
        }
    }
    #[test]
    fn sqrt_cbrt_at_wide_only_scale() {
        assert_eq!(
            D76::<50>::from_int(4).sqrt_strict(),
            D76::<50>::from_int(2)
        );
        assert_eq!(
            D76::<50>::from_int(8).cbrt_strict(),
            D76::<50>::from_int(2)
        );
        assert_eq!(
            D307::<150>::from_int(9).sqrt_strict(),
            D307::<150>::from_int(3)
        );
        assert_eq!(
            D307::<150>::from_int(27).cbrt_strict(),
            D307::<150>::from_int(3)
        );
    }
}