use crate::int::types::compute_limbs::Limb;
/// Writes the low `N` words of `x²` (that is, `x * x mod 2^(64·N)`) into `out`,
/// performing the arithmetic on limbs of type `L`.
///
/// `x` is packed into `L::packed_len(N)` limbs and the square is assembled in
/// three passes over that many limbs:
///
/// 1. every off-diagonal product `x[a]·x[b]` with `a < b` is accumulated once
///    at limb position `a + b`;
/// 2. that partial sum is doubled (a one-bit left shift across the limbs);
/// 3. each diagonal term `x[a]²` is added at limb position `2a`.
///
/// Contributions at or above the packed length are discarded, so carries out
/// of the top limb are dropped on purpose. The packed result is finally
/// unpacked into `out`.
#[inline]
pub(crate) fn sqr_low_limb<const N: usize, L: Limb>(x: &[u64; N], out: &mut [u64; N]) {
    let len = L::packed_len(N);

    let mut limbs = [L::ZERO; N];
    L::pack(x, &mut limbs[..len]);
    let mut sum = [L::ZERO; N];

    // Pass 1: off-diagonal products. A zero limb contributes nothing, so skip it.
    for (a, &xa) in limbs[..len].iter().enumerate() {
        if xa == L::ZERO {
            continue;
        }
        let mut carry = L::ZERO;
        // Column b satisfies b > a and a + b < len; the carry out of the last
        // column would land in the discarded high half.
        for b in (a + 1)..(len - a) {
            let (lo, hi) = xa.widening_mul(limbs[b]);
            let (partial, c_lo) = sum[a + b].overflowing_add(lo);
            let (partial, c_carry) = partial.overflowing_add(carry);
            sum[a + b] = partial;
            carry = hi.add_carries(c_lo, c_carry);
        }
    }

    // Pass 2: double the cross terms, shifting each limb's top bit into the next.
    let mut shifted = L::ZERO;
    for limb in sum[..len].iter_mut() {
        let v = *limb;
        let (twice, c_double) = v.overflowing_add(v);
        let (twice, c_shift) = twice.overflowing_add(shifted);
        *limb = twice;
        shifted = if c_double || c_shift { L::ONE } else { L::ZERO };
    }

    // Pass 3: add the diagonal squares x[a]² at even positions 2a < len.
    for pos in (0..len).step_by(2) {
        let xa = limbs[pos / 2];
        let (lo, hi) = xa.widening_mul(xa);
        let (low_sum, mut carry) = sum[pos].overflowing_add(lo);
        sum[pos] = low_sum;

        let next = pos + 1;
        if next >= len {
            // `hi` and any carry would land beyond the kept limbs.
            continue;
        }
        let carry_in = if carry { L::ONE } else { L::ZERO };
        let (high_sum, c_hi) = sum[next].overflowing_add(hi);
        let (high_sum, c_in) = high_sum.overflowing_add(carry_in);
        sum[next] = high_sum;
        carry = c_hi || c_in;

        // Ripple the carry upward until it is absorbed or falls off the top.
        for limb in sum[next + 1..len].iter_mut() {
            if !carry {
                break;
            }
            let (bumped, c) = (*limb).overflowing_add(L::ONE);
            *limb = bumped;
            carry = c;
        }
    }

    L::unpack(&sum[..len], out);
}
#[cfg(test)]
mod tests {
    use super::sqr_low_limb;
    use crate::int::algos::sqr::sqr_low_fixed::sqr_low_fixed;

    /// Deterministic pseudo-random words derived from `seed` via SplitMix64 steps.
    fn splitmix<const N: usize>(seed: u64) -> [u64; N] {
        let mut x = [0u64; N];
        let mut s = seed;
        for limb in x.iter_mut() {
            s = s.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = s;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            *limb = z ^ (z >> 31);
        }
        x
    }

    /// Squares `x` with both `u64` and `u128` limbs and compares each result
    /// against `sqr_low_fixed`; `label` identifies the input in failure messages.
    fn check<const N: usize>(x: &[u64; N], label: &str) {
        let mut want = [0u64; N];
        sqr_low_fixed::<N>(x, &mut want);
        let mut got_u64 = [0u64; N];
        sqr_low_limb::<N, u64>(x, &mut got_u64);
        assert_eq!(got_u64, want, "u64 N={N} {label}");
        let mut got_u128 = [0u64; N];
        sqr_low_limb::<N, u128>(x, &mut got_u128);
        assert_eq!(got_u128, want, "u128 N={N} {label}");
    }

    /// Dense pseudo-random inputs, one per seed.
    fn diff_at<const N: usize>(seeds: &[u64]) {
        for &seed in seeds {
            check(&splitmix::<N>(seed), &format!("seed={seed:#x}"));
        }
    }

    /// The all-ones input, which maximizes every carry chain.
    fn all_ones_at<const N: usize>() {
        check(&[u64::MAX; N], "all-ones");
    }

    /// Inputs containing zero limbs, which take the zero-limb skip path in
    /// `sqr_low_limb` that dense pseudo-random inputs never reach. Words are
    /// zeroed in aligned pairs so packed `u128` limbs are zero as well.
    /// Requires `N >= 2`.
    fn zero_limbs_at<const N: usize>(seeds: &[u64]) {
        check(&[0u64; N], "all-zero");
        for &seed in seeds {
            let dense = splitmix::<N>(seed);

            let mut holes = dense;
            for (k, w) in holes.iter_mut().enumerate() {
                if (k / 2) % 2 == 1 {
                    *w = 0;
                }
            }
            check(&holes, &format!("holes seed={seed:#x}"));

            let mut low_only = [0u64; N];
            low_only[0] = dense[0];
            check(&low_only, &format!("low-only seed={seed:#x}"));

            let mut high_only = [0u64; N];
            high_only[N - 1] = dense[N - 1];
            check(&high_only, &format!("high-only seed={seed:#x}"));
        }
    }

    const SEEDS: [u64; 8] = [0, 1, 2, 3, 0xDEAD_BEEF, 0xFFFF_FFFF_FFFF_FFFF, 7, 0x1357_9BDF];

    #[test]
    fn sqr_low_limb_matches_sqr_low_fixed_even_widths() {
        diff_at::<2>(&SEEDS);
        diff_at::<4>(&SEEDS);
        diff_at::<6>(&SEEDS);
        diff_at::<8>(&SEEDS);
        diff_at::<16>(&SEEDS);
        diff_at::<32>(&SEEDS);
        diff_at::<64>(&SEEDS);
        all_ones_at::<2>();
        all_ones_at::<4>();
        all_ones_at::<6>();
        all_ones_at::<8>();
        all_ones_at::<16>();
        all_ones_at::<32>();
        all_ones_at::<64>();
    }

    #[test]
    fn sqr_low_limb_handles_zero_limbs() {
        zero_limbs_at::<2>(&SEEDS);
        zero_limbs_at::<4>(&SEEDS);
        zero_limbs_at::<6>(&SEEDS);
        zero_limbs_at::<8>(&SEEDS);
        zero_limbs_at::<16>(&SEEDS);
    }
}