use crate::algo_x_support::seed::sqrt_seed;
use crate::int::types::traits::BigInt;
use crate::int::types::Int;
use crate::support::rounding::RoundingMode;
/// Floor square root of `n` in the wide working type `Int<W>`, using integer Newton iteration.
///
/// Starts from the estimate written by `sqrt_seed` and repeats
/// `x <- (x + n / x) / 2` until the value stops decreasing, then returns the last `x`.
/// That stopping rule yields `floor(sqrt(n))` only when the starting value is
/// `>= floor(sqrt(n))`. NOTE(review): the fast path relies on `sqrt_seed` producing
/// such an overestimate whenever it returns a positive seed — confirm against its contract.
///
/// Returns zero for `n <= 0`. The only caller in this file passes a positive `n`, but
/// the guard keeps the helper from dividing by zero.
#[inline]
fn isqrt_w_seeded<const W: usize>(n: Int<W>) -> Int<W> {
    // Nothing to iterate on. Without this guard, a zero radicand would step from the
    // fallback seed down to `x == 0`, and the next `n / x` would divide by zero.
    if n <= Int::<W>::ZERO {
        return Int::<W>::ZERO;
    }
    let two = Int::<W>::from_i128(2);
    let bits = n.bit_length();
    let mag = n.unsigned_abs();
    let mut seed_limbs = [0u64; W];
    sqrt_seed(mag.as_limbs(), bits, &mut seed_limbs);
    let seed = Int::<W>::from_mag_limbs(&seed_limbs, false);
    // If the seed is unusable, start from `n / 2 + 1`, which is >= floor(sqrt(n)) for
    // every n >= 1 and cannot overflow. A fallback of `1` would not work: it is below
    // the root for n >= 4, so the loop would stop on its first test and return 1.
    // Convergence from this fallback is slow (about one halving per step), but it is
    // only used when the seed is unusable.
    let mut x = if seed <= Int::<W>::ZERO {
        n / two + Int::<W>::ONE
    } else {
        seed
    };
    loop {
        let y = (x + n / x) / two;
        // From any start >= floor(sqrt(n)), the sequence decreases strictly until it
        // reaches floor(sqrt(n)); the first non-decrease marks the answer.
        if y >= x {
            break x;
        }
        x = y;
    }
}
/// Square root of `raw * pow10_scale`, rounded to an integer per `mode` and narrowed to `Int<N>`.
///
/// How it works:
/// * The product is formed in the wider working type `Int<W>`.
/// * Its floor root comes from `isqrt_w_seeded`.
/// * That root is raised by at most one unit according to `mode`.
/// * The result is resized back to `Int<N>`.
///
/// Parameters:
/// * `raw` — the input value. Any value `<= 0` yields zero immediately.
/// * `pow10_scale` — the multiplier applied before taking the root. The tests in this
///   module pass `10^scale`; NOTE(review): confirm every caller does.
/// * `mode` — how the exact root is rounded to an integer.
///
/// The caller must pick `W` so that `raw * pow10_scale` fits in `Int<W>`. This
/// function performs no range check.
///
/// Rounding rules, with `q = floor(sqrt(n))` and `r = n - q^2`:
/// * The root is at least `q + 1/2` exactly when `n > q^2 + q`, i.e. `r > q`.
/// * A true tie would need `n = q^2 + q + 1/4`, which is impossible for integer `n`.
///   Therefore all three half-* modes make the same decision.
/// * The root is positive, so Trunc and Floor keep `q`.
/// * Ceiling bumps whenever `r` is nonzero.
#[inline]
#[must_use]
pub(crate) fn sqrt_native<const N: usize, const W: usize>(
    raw: Int<N>,
    pow10_scale: Int<W>,
    mode: RoundingMode,
) -> Int<N> {
    if raw <= Int::<N>::ZERO {
        return Int::<N>::ZERO;
    }
    let radicand = raw.resize_to::<Int<W>>() * pow10_scale;
    let floor_root = isqrt_w_seeded::<W>(radicand);
    let remainder = radicand - floor_root * floor_root;
    let round_up = match mode {
        RoundingMode::Trunc | RoundingMode::Floor => false,
        RoundingMode::Ceiling => remainder != Int::<W>::ZERO,
        RoundingMode::HalfToEven
        | RoundingMode::HalfAwayFromZero
        | RoundingMode::HalfTowardZero => remainder > floor_root,
    };
    let rounded = if round_up {
        floor_root + Int::<W>::ONE
    } else {
        floor_root
    };
    rounded.resize_to::<Int<N>>()
}
#[cfg(all(test, feature = "std"))]
mod tests {
    //! Cross-checks `sqrt_native` against the generic `sqrt_newton` reference in every
    //! rounding mode, across several storage-width / working-width / scale cells.

    use super::sqrt_native;
    use crate::algos::sqrt::sqrt_newton::sqrt_newton;
    use crate::int::types::Int;
    use crate::support::rounding::RoundingMode;

    /// Every rounding mode that `sqrt_native` dispatches on.
    const ALL_MODES: [RoundingMode; 6] = [
        RoundingMode::HalfToEven,
        RoundingMode::HalfAwayFromZero,
        RoundingMode::HalfTowardZero,
        RoundingMode::Trunc,
        RoundingMode::Floor,
        RoundingMode::Ceiling,
    ];

    /// Asserts that `sqrt_native::<N, W>` with multiplier `10^scale` agrees with
    /// `sqrt_newton::<N>` at `scale`, for each value in `raws` and every mode.
    fn check_cell<const N: usize, const W: usize>(scale: u32, raws: &[i128])
    where
        crate::int::types::compute_limbs::Limbs<N>: crate::int::types::compute_limbs::ComputeLimbs,
    {
        let multiplier = Int::<W>::TEN.pow(scale);
        for r in raws.iter().copied() {
            let raw = Int::<N>::from_i128(r);
            for mode in ALL_MODES {
                assert_eq!(
                    sqrt_native::<N, W>(raw, multiplier, mode),
                    sqrt_newton::<N>(raw, scale, mode),
                    "N={N} W={W} scale={scale} raw={r} mode={mode:?}"
                );
            }
        }
    }

    #[test]
    fn sqrt_native_matches_generic_newton_d76_s35() {
        let raws: [i128; 8] = [
            0,
            1,
            -5,
            400_000_000_000_000_000_000_000_000_000_000_000,
            150_000_000_000_000_000_000_000_000_000_000_000,
            (1i128 << 100) | 0xBEEF,
            (1i128 << 120) | 0x1357,
            i128::MAX,
        ];
        check_cell::<4, 6>(35, &raws);
    }

    #[cfg(feature = "_wide-support")]
    #[test]
    fn sqrt_native_matches_generic_newton_d153_s75() {
        let raws: [i128; 6] = [
            0,
            1,
            1i128 << 64,
            (1i128 << 100) | 0xABCD,
            (1i128 << 126) | 0x99,
            i128::MAX,
        ];
        check_cell::<8, 11>(75, &raws);
    }

    #[cfg(any(feature = "x-wide", feature = "xx-wide"))]
    #[test]
    fn sqrt_native_matches_generic_newton_d307_s150() {
        let raws: [i128; 5] = [
            0,
            1,
            (1i128 << 64) | 7,
            (1i128 << 120) | 0x1,
            i128::MAX,
        ];
        check_cell::<16, 21>(150, &raws);
    }

    #[test]
    fn sqrt_native_perfect_square_four_is_two() {
        let unit = Int::<4>::from_i128(10).pow(35);
        let four = Int::<4>::from_i128(4) * unit;
        let two = Int::<4>::from_i128(2) * unit;
        let multiplier = Int::<6>::TEN.pow(35);
        for mode in ALL_MODES {
            assert_eq!(sqrt_native::<4, 6>(four, multiplier, mode), two, "mode {mode:?}");
        }
    }

    /// Builds a large positive `Int<N>`: every limb is all ones except the last, whose
    /// high bit is cleared. The last limb is assumed to be most significant, so this
    /// should be the largest positive magnitude — confirm against `from_mag_limbs`.
    #[cfg(feature = "_wide-support")]
    fn near_max<const N: usize>() -> Int<N> {
        let mut mag = [u64::MAX; N];
        mag[N - 1] >>= 1;
        Int::<N>::from_mag_limbs(&mag, false)
    }

    #[cfg(any(feature = "x-wide", feature = "xx-wide"))]
    #[test]
    fn sqrt_native_near_max_magnitude_all_cells() {
        let m4 = near_max::<4>();
        let m6 = near_max::<6>();
        let m8 = near_max::<8>();
        let m12 = near_max::<12>();
        let m16 = near_max::<16>();
        for mode in ALL_MODES {
            assert_eq!(
                sqrt_native::<4, 6>(m4, Int::<6>::TEN.pow(35), mode),
                sqrt_newton::<4>(m4, 35, mode),
                "D76 mode {mode:?}"
            );
            assert_eq!(
                sqrt_native::<6, 9>(m6, Int::<9>::TEN.pow(57), mode),
                sqrt_newton::<6>(m6, 57, mode),
                "D115 mode {mode:?}"
            );
            assert_eq!(
                sqrt_native::<8, 12>(m8, Int::<12>::TEN.pow(75), mode),
                sqrt_newton::<8>(m8, 75, mode),
                "D153s75 mode {mode:?}"
            );
            assert_eq!(
                sqrt_native::<8, 12>(m8, Int::<12>::TEN.pow(76), mode),
                sqrt_newton::<8>(m8, 76, mode),
                "D153s76 mode {mode:?}"
            );
            assert_eq!(
                sqrt_native::<12, 19>(m12, Int::<19>::TEN.pow(115), mode),
                sqrt_newton::<12>(m12, 115, mode),
                "D230 mode {mode:?}"
            );
            assert_eq!(
                sqrt_native::<16, 24>(m16, Int::<24>::TEN.pow(150), mode),
                sqrt_newton::<16>(m16, 150, mode),
                "D307 mode {mode:?}"
            );
        }
    }

    #[cfg(feature = "_wide-support")]
    #[test]
    fn sqrt_native_routed_2n_widths_near_max() {
        let m3 = near_max::<3>();
        let m4 = near_max::<4>();
        let m6 = near_max::<6>();
        let m8 = near_max::<8>();
        for mode in ALL_MODES {
            for s in [0u32, 20, 28, 57] {
                assert_eq!(
                    sqrt_native::<3, 6>(m3, Int::<6>::TEN.pow(s), mode),
                    sqrt_newton::<3>(m3, s, mode),
                    "D57 W=6 s={s} mode {mode:?}"
                );
            }
            for s in [0u32, 20, 35, 76] {
                assert_eq!(
                    sqrt_native::<4, 8>(m4, Int::<8>::TEN.pow(s), mode),
                    sqrt_newton::<4>(m4, s, mode),
                    "D76 W=8 s={s} mode {mode:?}"
                );
            }
            for s in [0u32, 25, 57, 115] {
                assert_eq!(
                    sqrt_native::<6, 12>(m6, Int::<12>::TEN.pow(s), mode),
                    sqrt_newton::<6>(m6, s, mode),
                    "D115 W=12 s={s} mode {mode:?}"
                );
            }
            for s in [0u32, 25, 75, 153] {
                assert_eq!(
                    sqrt_native::<8, 16>(m8, Int::<16>::TEN.pow(s), mode),
                    sqrt_newton::<8>(m8, s, mode),
                    "D153 W=16 s={s} mode {mode:?}"
                );
            }
        }
    }
}