use crate::int::algos::hypot::hypot_pythagoras::hypot_pythagoras;
use crate::int::algos::support::limbs::{fit_k, fit_one};
use crate::int::types::compute_limbs::{ComputeLimbs, Limbs};
use crate::int::types::Int;
use crate::support::rounding::RoundingMode;
/// Integer square root of a `u128` together with its remainder.
///
/// Returns `(r, n - r * r)`, where `r` is the first iterate whose square does
/// not exceed `n`. The iteration starts at the value produced by
/// `sqrt_seed_u128` and strictly decreases, so `r` is `floor(sqrt(n))` as long
/// as the seed is not below the true root.
/// NOTE(review): that seed guarantee lives in `sqrt_seed_u128` — confirm there.
///
/// `n == 0` short-circuits to `(0, 0)` without consulting the seed.
#[inline]
fn isqrt_u128(n: u128) -> (u128, u128) {
    if n == 0 {
        return (0, 0);
    }
    let bit_len = 128 - n.leading_zeros();
    let mut guess: u128 = crate::algo_x_support::seed::sqrt_seed_u128(n, bit_len);
    loop {
        // A square that overflows `u128` is certainly larger than `n`, so an
        // overflow is treated exactly like "still too big".
        match guess.checked_mul(guess) {
            Some(square) if square <= n => return (guess, n - square),
            _ => {}
        }
        // Newton step; if it fails to make progress, step down by one so the
        // sequence is strictly decreasing and the loop terminates.
        let newton = (guess + n / guess) >> 1;
        guess = if newton < guess { newton } else { guess - 1 };
    }
}
/// Squares a `u128` and returns the result as four little-endian `u64` limbs.
///
/// The operand is zero-extended to four limbs and handed to
/// `sqr_low_fixed::<4>`.
/// NOTE(review): `sqr_low_fixed` presumably keeps the low four limbs of the
/// product, which is the whole square here because `x < 2^128` — confirm.
#[inline]
fn sq_u256(x: u128) -> [u64; 4] {
    let operand: [u64; 4] = [x as u64, (x >> 64) as u64, 0, 0];
    let mut square = [0u64; 4];
    crate::int::algos::sqr::sqr_low_fixed::sqr_low_fixed::<4>(&operand, &mut square);
    square
}
/// Adds two 256-bit values given as little-endian `u64` limbs.
///
/// Returns the wrapped 256-bit sum and `true` when a carry left the top limb
/// (i.e. the true sum does not fit in 256 bits).
#[inline]
fn add_u256(a: [u64; 4], b: [u64; 4]) -> ([u64; 4], bool) {
    let mut sum = [0u64; 4];
    let mut carry = false;
    for (slot, (&x, &y)) in sum.iter_mut().zip(a.iter().zip(b.iter())) {
        // At most one of the two partial additions can overflow.
        let (partial, c1) = x.overflowing_add(y);
        let (total, c2) = partial.overflowing_add(u64::from(carry));
        *slot = total;
        carry = c1 || c2;
    }
    (sum, carry)
}
/// Divides a 256-bit value (little-endian limbs) by a `u128` and returns the
/// low 128 bits of the quotient.
///
/// A divisor that fits in one limb is handled by limb-wise long division;
/// anything wider is delegated to `knuth_d_256_by_128`. Quotient limbs above
/// bit 127 are computed in the single-limb path but not returned.
#[inline]
fn div_u256_by_u128(n: [u64; 4], d: u128) -> u128 {
    if (d >> 64) != 0 {
        return knuth_d_256_by_128(n, d);
    }

    let divisor = (d as u64) as u128;
    let mut quot = [0u64; 4];
    let mut rem: u128 = 0;
    // Walk from the most significant limb down, carrying the remainder.
    for (q, &limb) in quot.iter_mut().zip(n.iter()).rev() {
        let cur = (rem << 64) | limb as u128;
        *q = (cur / divisor) as u64;
        rem = cur % divisor;
    }
    (quot[0] as u128) | ((quot[1] as u128) << 64)
}
/// Schoolbook long division (Knuth's Algorithm D) of a 256-bit value by a
/// 128-bit divisor, returning a 128-bit quotient.
///
/// `n` is little-endian. Only the two low quotient limbs are produced, so the
/// caller must guarantee that the high limb of `d` is non-zero and that
/// `n / d < 2^128`; the third quotient digit is never computed.
#[inline]
fn knuth_d_256_by_128(n: [u64; 4], d: u128) -> u128 {
    const BASE: u128 = 1u128 << 64;

    // Normalise: shift the divisor left until its top bit is set, and shift
    // the dividend by the same amount into a five-limb work buffer.
    let norm = ((d >> 64) as u64).leading_zeros();
    let scaled = d << norm;
    let v_hi: u128 = scaled >> 64;
    let v_lo: u128 = (scaled as u64) as u128;

    let mut work = [0u64; 5];
    if norm == 0 {
        work[..4].copy_from_slice(&n);
    } else {
        let back = 64 - norm;
        work[0] = n[0] << norm;
        for k in 1..4 {
            work[k] = (n[k] << norm) | (n[k - 1] >> back);
        }
        work[4] = n[3] >> back;
    }

    let mut quot = [0u64; 2];
    for j in (0..2usize).rev() {
        // Estimate the quotient digit from the top two limbs of the window.
        let top = ((work[j + 2] as u128) << 64) | work[j + 1] as u128;
        let mut q_est = top / v_hi;
        let mut r_est = top % v_hi;
        if q_est >= BASE {
            q_est = BASE - 1;
            r_est = top - q_est * v_hi;
        }
        // Refine the estimate using the second divisor limb.
        while r_est < BASE && q_est * v_lo > (r_est << 64) + work[j] as u128 {
            q_est -= 1;
            r_est += v_hi;
        }

        // Three-limb product digit * divisor.
        let digit = q_est as u64;
        let prod_lo = digit as u128 * v_lo;
        let prod_hi = digit as u128 * v_hi;
        let mid = (prod_lo >> 64) + ((prod_hi as u64) as u128);
        let product = [
            prod_lo as u64,
            mid as u64,
            ((mid >> 64) + (prod_hi >> 64)) as u64,
        ];

        // Subtract the product from the current three-limb window.
        let mut borrow = false;
        for (w, &p) in work[j..j + 3].iter_mut().zip(product.iter()) {
            let (t, b1) = w.overflowing_sub(p);
            let (t, b2) = t.overflowing_sub(u64::from(borrow));
            *w = t;
            borrow = b1 || b2;
        }

        quot[j] = if borrow {
            // The estimate was one too large: add the divisor back (the final
            // carry out of the window is discarded) and lower the digit.
            let addend = [v_lo as u64, v_hi as u64, 0u64];
            let mut carry = false;
            for (w, &a) in work[j..j + 3].iter_mut().zip(addend.iter()) {
                let (t, c1) = w.overflowing_add(a);
                let (t, c2) = t.overflowing_add(u64::from(carry));
                *w = t;
                carry = c1 || c2;
            }
            digit.wrapping_sub(1)
        } else {
            digit
        };
    }
    (quot[0] as u128) | ((quot[1] as u128) << 64)
}
/// Integer square root of a 256-bit value (little-endian limbs), returned
/// together with the remainder `n - root^2`.
///
/// The starting guess comes from `sqrt_seed`; a seed that spills into a third
/// limb is clamped to `u128::MAX`, and a zero seed is bumped to `1`. From
/// there the guess strictly decreases until its square no longer exceeds `n`.
/// The debug assertion documents the expectation that the seed is never below
/// the true root.
#[inline]
fn isqrt_u256(n: [u64; 4]) -> (u128, [u64; 4]) {
    // Bit length of `n`: position of the highest non-zero limb plus the
    // significant bits inside it; zero when every limb is zero.
    let bits: u32 = n
        .iter()
        .enumerate()
        .rev()
        .find(|&(_, &limb)| limb != 0)
        .map_or(0, |(idx, &limb)| 64 * idx as u32 + (64 - limb.leading_zeros()));
    if bits == 0 {
        return (0, [0u64; 4]);
    }

    let mut seed_out = [0u64; 3];
    crate::algo_x_support::seed::sqrt_seed(&n, bits, &mut seed_out);
    let mut x: u128 = match seed_out {
        [lo, hi, 0] => (lo as u128) | ((hi as u128) << 64),
        // The seed needs more than 128 bits: saturate.
        _ => u128::MAX,
    };
    // Keep the divisor used below non-zero.
    x = x.max(1);
    debug_assert!(
        x == u128::MAX || cmp_u256(sq_u256(x), n) >= 0,
        "isqrt_u256 seed under-estimate: seed={x} n={n:?}"
    );

    loop {
        let square = sq_u256(x);
        if cmp_u256(square, n) <= 0 {
            let (rem, _) = sub_u256(n, square);
            return (x, rem);
        }
        // Newton step; force a decrement when it would not move downwards.
        let mean = avg_u128(x, div_u256_by_u128(n, x));
        x = if mean < x { mean } else { x - 1 };
    }
}
/// Returns `floor((a + b) / 2)` without overflowing `u128`.
#[inline]
fn avg_u128(a: u128, b: u128) -> u128 {
    // Halve each operand first, then restore the unit lost when both are odd.
    let both_odd = a & b & 1;
    (a >> 1) + (b >> 1) + both_odd
}
/// Three-way comparison of two 256-bit values stored as little-endian limbs.
///
/// Returns `-1` when `a < b`, `0` when equal and `1` when `a > b`.
#[inline]
fn cmp_u256(a: [u64; 4], b: [u64; 4]) -> i32 {
    // The most significant differing limb decides the ordering.
    a.iter()
        .rev()
        .zip(b.iter().rev())
        .find(|(x, y)| x != y)
        .map_or(0, |(x, y)| if x < y { -1 } else { 1 })
}
/// Computes `sqrt(a^2 + b^2)` rounded per `mode`, with fast paths for small
/// operands.
///
/// Three tiers are tried in order:
/// 1. both magnitudes pass `fit_one` and the sum of squares fits in a `u128`;
/// 2. both magnitudes pass `fit_k(_, 2)` and the sum of squares fits in 256
///    bits;
/// 3. otherwise the call is forwarded to `hypot_pythagoras`.
///
/// The fast tiers return whatever `finish` returns, which is `None` when the
/// rounded root does not fit.
#[inline]
#[must_use]
#[allow(dead_code)]
pub(crate) fn hypot_u128_fast<const N: usize>(a: Int<N>, b: Int<N>, mode: RoundingMode) -> Option<Int<N>>
where
    Limbs<N>: ComputeLimbs,
{
    let abs_a = a.unsigned_abs();
    let abs_b = b.unsigned_abs();
    let limbs_a = abs_a.as_limbs();
    let limbs_b = abs_b.as_limbs();

    // Tier 1: single-limb magnitudes, sum of squares in a `u128`.
    if fit_one(limbs_a) && fit_one(limbs_b) {
        let x = limbs_a[0] as u128;
        let y = limbs_b[0] as u128;
        match (x * x).checked_add(y * y) {
            Some(0) => return Some(Int::<N>::ZERO),
            Some(sum) => {
                let (root, rem) = isqrt_u128(sum);
                // The fraction exceeds one half exactly when rem > root.
                return finish::<N>(root, rem != 0, rem > root, mode);
            }
            // The sum overflowed `u128`; fall through to the wider tier.
            None => {}
        }
    }

    // Tier 2: two-limb magnitudes, sum of squares in 256 bits.
    if fit_k(limbs_a, 2) && fit_k(limbs_b, 2) {
        let x = (limbs_a[0] as u128) | ((limbs_a[1] as u128) << 64);
        let y = (limbs_b[0] as u128) | ((limbs_b[1] as u128) << 64);
        if let (sum, false) = add_u256(sq_u256(x), sq_u256(y)) {
            if sum == [0u64; 4] {
                return Some(Int::<N>::ZERO);
            }
            let (root, rem) = isqrt_u256(sum);
            let inexact = rem != [0u64; 4];
            let above_half = cmp_u256_u128(rem, root) > 0;
            return finish::<N>(root, inexact, above_half, mode);
        }
    }

    // Tier 3: general fallback.
    hypot_pythagoras::<N>(a, b, mode)
}
/// Subtracts `b` from `a`, both 256-bit values in little-endian limbs.
///
/// Returns the wrapped 256-bit difference and `true` when a borrow left the
/// top limb (i.e. `a < b`).
#[inline]
fn sub_u256(a: [u64; 4], b: [u64; 4]) -> ([u64; 4], bool) {
    let mut diff = [0u64; 4];
    let mut borrow = false;
    for (slot, (&x, &y)) in diff.iter_mut().zip(a.iter().zip(b.iter())) {
        // At most one of the two partial subtractions can underflow.
        let (partial, b1) = x.overflowing_sub(y);
        let (total, b2) = partial.overflowing_sub(u64::from(borrow));
        *slot = total;
        borrow = b1 || b2;
    }
    (diff, borrow)
}
/// Three-way comparison of a 256-bit value (little-endian limbs) against a
/// `u128`.
///
/// Returns `-1`, `0` or `1` for `a < q`, `a == q` and `a > q` respectively.
#[inline]
fn cmp_u256_u128(a: [u64; 4], q: u128) -> i32 {
    match a {
        [lo, hi, 0, 0] => {
            let low = (lo as u128) | ((hi as u128) << 64);
            (low > q) as i32 - (low < q) as i32
        }
        // Any bit above position 127 makes `a` larger than every `u128`.
        _ => 1,
    }
}
/// Applies the rounding decision to a floor square root and packs the result
/// into an `Int<N>`.
///
/// * `q` — floor of the exact root.
/// * `diff_nonzero` — the exact root is not an integer.
/// * `halfway_round_up` — the fractional part is above one half.
/// * `mode` — rounding mode. The three half-way modes share one rule here
///   because they bump on `halfway_round_up` alone; `Trunc` and `Floor` never
///   bump; `Ceiling` bumps whenever the root is inexact.
///
/// Returns `None` when the rounded value does not fit: for `N == 1` and
/// `N == 2` the top bit of the highest limb must stay clear (presumably the
/// sign bit of `Int<N>` — confirm), and a carry out of the 128-bit addition
/// is never representable. For `N >= 3` a carry is stored in limb 2.
#[inline]
fn finish<const N: usize>(q: u128, diff_nonzero: bool, halfway_round_up: bool, mode: RoundingMode) -> Option<Int<N>> {
    let bump = match mode {
        RoundingMode::HalfToEven
        | RoundingMode::HalfAwayFromZero
        | RoundingMode::HalfTowardZero => halfway_round_up,
        RoundingMode::Trunc | RoundingMode::Floor => false,
        RoundingMode::Ceiling => diff_nonzero,
    };
    let (result, carried) = q.overflowing_add(bump as u128);
    let lo = result as u64;
    let hi = (result >> 64) as u64;

    let mut out = [0u64; N];
    match N {
        1 => {
            // `carried` must be rejected here as well: a wrapped sum has
            // `hi == 0 && lo == 0` and would otherwise be reported as zero.
            if carried || hi != 0 || (lo >> 63) != 0 {
                return None;
            }
            out[0] = lo;
        }
        2 => {
            if carried || (hi >> 63) != 0 {
                return None;
            }
            out[0] = lo;
            out[1] = hi;
        }
        _ => {
            out[0] = lo;
            out[1] = hi;
            if carried {
                out[2] = 1;
            }
        }
    }
    Some(Int::<N>::from_limbs(out))
}
#[cfg(test)]
mod tests {
use super::{cmp_u256, hypot_u128_fast, isqrt_u256, sq_u256};
use crate::int::algos::hypot::hypot_pythagoras::hypot_pythagoras;
use crate::int::types::Int;
use crate::support::rounding::RoundingMode;
// The six rounding modes the tests in this module iterate over; these are the
// same variants that `finish` matches on.
const ALL_MODES: [RoundingMode; 6] = [
RoundingMode::HalfToEven,
RoundingMode::HalfAwayFromZero,
RoundingMode::HalfTowardZero,
RoundingMode::Trunc,
RoundingMode::Floor,
RoundingMode::Ceiling,
];
/// Deterministic SplitMix64-style generator used to drive the randomized
/// tests: advances the state in `s` and returns the next 64-bit output.
fn mix(s: &mut u64) -> u64 {
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MUL_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MUL_B: u64 = 0x94D0_49BB_1331_11EB;

    *s = s.wrapping_add(GAMMA);
    let stage1 = (*s ^ (*s >> 30)).wrapping_mul(MUL_A);
    let stage2 = (stage1 ^ (stage1 >> 27)).wrapping_mul(MUL_B);
    stage2 ^ (stage2 >> 31)
}
/// Checks `sq_u256` on hand-computed vectors, then on pseudo-random inputs by
/// comparing the low 128 bits with `wrapping_mul` and by round-tripping the
/// square through `isqrt_u256`.
#[test]
fn sq_u256_matches_reference() {
    let fixed: [(u128, [u64; 4]); 4] = [
        (0, [0, 0, 0, 0]),
        (1, [1, 0, 0, 0]),
        (u64::MAX as u128, [1, u64::MAX - 1, 0, 0]),
        (u128::MAX, [1, 0, u64::MAX - 1, u64::MAX]),
    ];
    for (input, expected) in fixed {
        assert_eq!(sq_u256(input), expected);
    }

    let mut s = 0x1234_5678_9ABC_DEF0u64;
    for _ in 0..5000 {
        // Draw the low limb first, then the high limb.
        let lo_limb = mix(&mut s);
        let hi_limb = mix(&mut s);
        let x = (lo_limb as u128) | ((hi_limb as u128) << 64);

        let q = sq_u256(x);
        let low = (q[0] as u128) | ((q[1] as u128) << 64);
        assert_eq!(low, x.wrapping_mul(x), "low128 x={x}");
        assert_eq!(isqrt_u256(q), (x, [0u64; 4]), "roundtrip x={x}");
    }
}
/// For pseudo-random roots `r` of varying bit length, checks `isqrt_u256` on
/// `r²` (exact), `r² + r` (same root, remainder `r`) and `r² - 1` (root
/// `r - 1`).
#[test]
fn isqrt_u256_exact_floor() {
    let widen = |v: u128| -> [u64; 4] { [v as u64, (v >> 64) as u64, 0, 0] };

    let mut s = 0xDEAD_F00D_u64;
    for _ in 0..3000 {
        // Three draws in fixed order: low limb, high limb, right-shift amount.
        let lo_limb = mix(&mut s) as u128;
        let hi_limb = mix(&mut s) as u128;
        let shrink = mix(&mut s) % 64;
        let r = (lo_limb | (hi_limb << 64)) >> shrink;
        if r == 0 {
            continue;
        }

        let rsq = sq_u256(r);
        assert_eq!(isqrt_u256(rsq), (r, [0u64; 4]), "perfect square r={r}");

        if let (rsq_plus, false) = super::add_u256(rsq, widen(r)) {
            assert_eq!(isqrt_u256(rsq_plus), (r, widen(r)), "r²+r, r={r}");
        }

        if let (rsq_minus, false) = super::sub_u256(rsq, [1, 0, 0, 0]) {
            assert_eq!(isqrt_u256(rsq_minus).0, r - 1, "r²-1, r={r}");
        }
    }
}
/// Reference floor square root of a 256-bit value: decides the root one bit
/// at a time from bit 127 down, keeping a bit whenever the trial square does
/// not exceed `n`.
fn isqrt_u256_ref(n: [u64; 4]) -> u128 {
    (0..128u32).rev().fold(0u128, |root, bit| {
        let trial = root | (1u128 << bit);
        if cmp_u256(sq_u256(trial), n) <= 0 {
            trial
        } else {
            root
        }
    })
}
// Stress test for the seed/refinement pair in `isqrt_u256`.
// Each input has a perfect square k² (k in [2^31, 2^32)) as its leading bits,
// placed at bit offset `shift`, with every bit below `shift` set to one. For
// each such input the test checks that the square of the `sqrt_seed` output is
// not below `n`, and that `isqrt_u256` agrees with the bit-by-bit reference.
#[test]
fn isqrt_u256_adversarial_perfect_square_tops() {
// Same seed decoding as `isqrt_u256`: saturate to `u128::MAX` when the seed
// spills into the third output limb.
fn seed_of(n: [u64; 4], bits: u32) -> u128 {
let mut out = [0u64; 3];
crate::algo_x_support::seed::sqrt_seed(&n, bits, &mut out);
if out[2] != 0 {
u128::MAX
} else {
(out[0] as u128) | ((out[1] as u128) << 64)
}
}
// Bit length of a 256-bit little-endian value (0 for zero).
fn bit_len_u256(n: [u64; 4]) -> u32 {
if n[3] != 0 {
192 + (64 - n[3].leading_zeros())
} else if n[2] != 0 {
128 + (64 - n[2].leading_zeros())
} else if n[1] != 0 {
64 + (64 - n[1].leading_zeros())
} else {
64 - n[0].leading_zeros()
}
}
let mut s = 0x5EED_A11A_C0DE_u64;
// Candidate roots: both ends of [2^31, 2^32), 32 pseudo-random values in that
// range, and a few hand-picked values near the ends.
let mut ks: Vec<u64> = vec![1u64 << 31, (1u64 << 32) - 1];
for _ in 0..32 {
let k = (mix(&mut s) % (1u64 << 31)) + (1u64 << 31); ks.push(k);
}
ks.extend_from_slice(&[(1u64 << 31) + 1, (1u64 << 31) + 7, (1u64 << 32) - 2]);
for &k in &ks {
// `top` = k² fits in 64 bits because k < 2^32.
let top = (k as u128) * (k as u128); for shift in 64u32..=191 {
let mut n = [0u64; 4];
let limb = (shift / 64) as usize;
let off = shift % 64;
// Place `top` at bit offset `shift`; since top < 2^64 and off < 64, `hi` is
// always zero here.
let lo = (top << off) as u128; let hi = if off == 0 { 0u128 } else { top >> (128 - off) };
let pieces = [lo as u64, (lo >> 64) as u64, hi as u64];
for (idx, &pc) in pieces.iter().enumerate() {
if limb + idx < 4 {
n[limb + idx] = pc;
}
}
// Fill every bit below `shift` with ones: whole limbs first, then the
// partial limb.
let full_limbs = (shift / 64) as usize;
for l in n.iter_mut().take(full_limbs) {
*l = u64::MAX;
}
if off != 0 && full_limbs < 4 {
n[full_limbs] |= (1u64 << off) - 1;
}
let bits = bit_len_u256(n);
let seed = seed_of(n, bits);
// The seed must not under-estimate: seed² >= n.
assert!(
cmp_u256(sq_u256(seed), n) >= 0,
"seed under-estimate: k={k} shift={shift} seed={seed} n={n:?}"
);
let got = isqrt_u256(n).0;
let want = isqrt_u256_ref(n);
assert_eq!(got, want, "floor mismatch: k={k} shift={shift} n={n:?}");
}
}
}
/// Exercises inputs whose root is at or near `u128::MAX`:
/// values just above `(2^128 - 1)²`, pseudo-random values with the top bit
/// set, and a three-limb `hypot` call compared against `hypot_pythagoras`.
#[test]
fn isqrt_u256_saturation_zone_floor_2pow128_minus_1() {
    let max = u128::MAX;
    let maxsq = sq_u256(max);
    let offsets: [u128; 7] = [0, 1, 2, 3, 7, 1000, max];
    for d in offsets {
        let addend = [d as u64, (d >> 64) as u64, 0, 0];
        let (n, carry) = super::add_u256(maxsq, addend);
        assert!(!carry, "n must stay < 2^256 (d={d})");
        let got = isqrt_u256(n).0;
        let want = isqrt_u256_ref(n);
        assert_eq!(want, max, "reference floor must be 2^128-1 (d={d})");
        assert_eq!(got, want, "isqrt_u256 saturation-zone floor (d={d}) n={n:?}");
    }

    let mut s = 0xCAFE_5A7Eu64;
    for _ in 0..500 {
        // Four draws, lowest limb first; then force the top bit on.
        let mut n = [0u64; 4];
        for limb in &mut n {
            *limb = mix(&mut s);
        }
        n[3] |= 1u64 << 63;
        assert_eq!(isqrt_u256(n).0, isqrt_u256_ref(n), "near-2^256 floor n={n:?}");
    }

    // Magnitude 2^128 - 1 held in three limbs, paired with 1.
    let big = Int::<3>::from_limbs([u64::MAX, u64::MAX, 0]);
    let one = Int::<3>::from_i64(1);
    for mode in ALL_MODES {
        let fast = hypot_u128_fast::<3>(big, one, mode);
        let reference = hypot_pythagoras::<3>(big, one, mode);
        assert_eq!(fast, reference, "hypot saturation-zone mismatch mode={mode:?}");
    }
}
// Differential test driver for one limb count `N`: compares `hypot_u128_fast`
// against `hypot_pythagoras` in every rounding mode, first on 400
// pseudo-random operand pairs and then on a short list of explicit pairs.
fn diff_at<const N: usize>()
where
crate::int::types::compute_limbs::Limbs<N>: crate::int::types::compute_limbs::ComputeLimbs,
{
// The PRNG seed is varied by `N` so each width sees a different sequence.
let mut s = 0xDEAD_BEEF_CAFE_F00D_u64 ^ (N as u64);
for _ in 0..400 {
let mut la = [0u64; N];
let mut lb = [0u64; N];
// Pick one of five operand shapes. Shapes 2 and 3 need at least two limbs;
// when `N < 2` their guards fail and the catch-all arm runs instead.
let shape = mix(&mut s) % 5;
match shape {
// Shape 0: both operands below 2^32.
0 => {
la[0] = mix(&mut s) & 0xFFFF_FFFF;
lb[0] = mix(&mut s) & 0xFFFF_FFFF;
}
// Shape 1: a full random low limb, higher limbs zero.
1 => {
la[0] = mix(&mut s);
lb[0] = mix(&mut s);
}
// Shape 2: two random limbs.
2 if N >= 2 => {
la[0] = mix(&mut s);
la[1] = mix(&mut s);
lb[0] = mix(&mut s);
lb[1] = mix(&mut s);
}
// Shape 3: two random limbs with the top bit of limb 1 forced on.
3 if N >= 2 => {
la[0] = mix(&mut s);
la[1] = mix(&mut s) | (1u64 << 63);
lb[0] = mix(&mut s);
lb[1] = mix(&mut s) | (1u64 << 63);
}
// Otherwise: every limb random, with the top bit of the last limb cleared.
_ => {
for k in 0..N {
la[k] = mix(&mut s);
lb[k] = mix(&mut s);
}
la[N - 1] &= i64::MAX as u64;
lb[N - 1] &= i64::MAX as u64;
}
}
let a = Int::<N>::from_limbs(la);
let b = Int::<N>::from_limbs(lb);
for mode in ALL_MODES {
assert_eq!(
hypot_u128_fast::<N>(a, b, mode),
hypot_pythagoras::<N>(a, b, mode),
"N={N} mode={mode:?} a={:?} b={:?}",
a.as_limbs(),
b.as_limbs()
);
}
}
// Explicit pairs: Pythagorean triples, an inexact case (1, 1) and zero operands.
let checks: [(i64, i64); 6] = [(3, 4), (5, 12), (8, 15), (1, 1), (0, 0), (0, 42)];
for (av, bv) in checks {
let a = Int::<N>::from_i64(av);
let b = Int::<N>::from_i64(bv);
for mode in ALL_MODES {
assert_eq!(
hypot_u128_fast::<N>(a, b, mode),
hypot_pythagoras::<N>(a, b, mode),
"explicit N={N} mode={mode:?} a={av} b={bv}"
);
}
}
}
// Runs the differential driver `diff_at` for limb counts 1 through 4.
// Only compiled when the `_wide-support` feature is enabled.
#[cfg(feature = "_wide-support")]
#[test]
fn hypot_u128_fast_matches_pythagoras() {
diff_at::<1>();
diff_at::<2>();
diff_at::<3>();
diff_at::<4>();
}
// Runs the differential driver `diff_at` for two limbs only; unlike the
// multi-width test above it carries no feature gate.
#[test]
fn hypot_u128_fast_two_limb_matches_pythagoras_n2() {
diff_at::<2>();
}
/// The 3-4-5 triangle has an exact hypotenuse, so every rounding mode must
/// return 5.
#[test]
fn hypot_u128_fast_perfect_square() {
    let (a, b) = (Int::<2>::from_i64(3), Int::<2>::from_i64(4));
    for mode in ALL_MODES {
        let hyp = hypot_u128_fast::<2>(a, b, mode).unwrap();
        assert_eq!(hyp.as_i128(), 5);
    }
}
}