use crate::int::algos::support::limbs::{add_assign, sub_assign};
use crate::int::algos::mul::mul_schoolbook::mul_schoolbook;
use crate::int::types::compute_limbs::{ComputeLimbs, Limb, Limbs};
/// Largest operand length, in limbs, that `mul_karatsuba`'s fixed stack scratch is sized for.
const KARATSUBA_MAX_WIDTH: usize = 256;

/// Number of `u64` scratch limbs `mul_karatsuba` reserves on the stack.
///
/// This is the requirement of a `KARATSUBA_MAX_WIDTH`-limb multiply at the smallest
/// legal threshold (4). The requirement never decreases as the operand grows, and never
/// increases as the threshold rises. So this one value covers every operand of up to
/// `KARATSUBA_MAX_WIDTH` limbs at any threshold >= 4.
pub(crate) const KARATSUBA_SCRATCH_LIMBS: usize =
    karatsuba_scratch_needed_th(KARATSUBA_MAX_WIDTH, 4);
/// Multiplies two equal-length `u64` limb slices (least-significant limb first) with
/// Karatsuba, storing the `2 * a.len()`-limb product in `out[..2 * a.len()]`.
///
/// `out[..2 * a.len()]` is cleared first, so its prior contents do not matter.
///
/// Scratch space is a fixed stack array of `KARATSUBA_SCRATCH_LIMBS` limbs, sized for
/// operands of up to `KARATSUBA_MAX_WIDTH` limbs. That fit is only debug-asserted; in
/// release builds a longer operand panics when the scratch is split.
///
/// `threshold` is the operand length below which the schoolbook kernel is used. Values
/// below 4 are raised to 4, the same clamp `mul_karatsuba_limb` applies, because the
/// recursion does not terminate for smaller thresholds.
pub(crate) fn mul_karatsuba(a: &[u64], b: &[u64], out: &mut [u64], threshold: usize) {
    // Clamp before anything else: karatsuba_scratch_needed_th / karatsuba_rec never
    // shrink the `hi + 1` half when n <= 3, so a threshold below 4 recurses forever.
    let threshold = threshold.max(4);
    debug_assert_eq!(a.len(), b.len());
    debug_assert!(out.len() >= 2 * a.len());
    debug_assert!(
        karatsuba_scratch_needed_th(a.len(), threshold) <= KARATSUBA_SCRATCH_LIMBS,
        "Karatsuba scratch overflow: n={} needs {} limbs, have {}",
        a.len(),
        karatsuba_scratch_needed_th(a.len(), threshold),
        KARATSUBA_SCRATCH_LIMBS,
    );
    // The recursive kernel adds z2 into out[2h..2n]. That is only correct if the window
    // starts at zero, so clear it here (as mul_karatsuba_forced does) rather than trust
    // the caller.
    for o in out[..2 * a.len()].iter_mut() { *o = 0; }
    let mut scratch = [0u64; KARATSUBA_SCRATCH_LIMBS];
    karatsuba_rec(a, b, out, &mut scratch, threshold);
}
/// Benchmark-only Karatsuba entry (behind the `bench-alt` feature).
///
/// It zeroes all of `out`, then runs the recursive kernel with a fixed stack scratch
/// of `KARATSUBA_SCRATCH_LIMBS` limbs. Preconditions are checked only in debug builds.
#[cfg(feature = "bench-alt")]
pub(crate) fn mul_karatsuba_forced(a: &[u64], b: &[u64], out: &mut [u64], threshold: usize) {
    let n = a.len();
    debug_assert_eq!(n, b.len());
    debug_assert!(out.len() >= 2 * n);
    debug_assert!(
        karatsuba_scratch_needed_th(n, threshold) <= KARATSUBA_SCRATCH_LIMBS,
        "Karatsuba scratch overflow in forced bench entry"
    );
    out.iter_mut().for_each(|limb| *limb = 0);
    let mut buf = [0u64; KARATSUBA_SCRATCH_LIMBS];
    karatsuba_rec(a, b, out, &mut buf, threshold);
}
/// Test-only Karatsuba entry that heap-allocates exactly the scratch `threshold`
/// requires.
///
/// Unlike `mul_karatsuba`, it is therefore not bounded by `KARATSUBA_SCRATCH_LIMBS`.
/// `out` is passed to the kernel as-is.
#[cfg(test)]
pub(crate) fn mul_karatsuba_with_threshold(
    a: &[u64],
    b: &[u64],
    out: &mut [u64],
    threshold: usize,
) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert!(out.len() >= 2 * a.len());
    let mut scratch_buf: Vec<u64> = vec![0; karatsuba_scratch_needed_th(a.len(), threshold)];
    karatsuba_rec(a, b, out, scratch_buf.as_mut_slice(), threshold);
}
/// Karatsuba multiply of two `N`-limb `u64` operands, carried out in the limb type `L`.
///
/// Steps:
/// 1. Pack both operands into `h = L::packed_len(N)` limbs of `L`.
/// 2. Multiply them with `karatsuba_rec_limb`, using the work buffer from `L::karatsuba`.
///    Its first `2 * h` limbs hold the product; the rest is scratch.
/// 3. Unpack that product into `out[..2 * N]`.
///
/// When packing shrinks the operand (`h < N`), `threshold` is divided by 2. Presumably
/// each packed limb carries about two `u64` limbs' worth of data (TODO: confirm against
/// `Limb::pack`). The packed threshold is then clamped to at least 4 so the recursion
/// terminates.
pub(crate) fn mul_karatsuba_limb<const N: usize, L: Limb>(
    a: &[u64; N],
    b: &[u64; N],
    out: &mut [u64],
    threshold: usize,
) where
    Limbs<N>: ComputeLimbs,
{
    let h = L::packed_len(N);
    debug_assert!(h > 0 && h <= N);
    // Same output-size contract as the `u64` entry points (`out.len() >= 2 * n`).
    debug_assert!(out.len() >= 2 * N);
    let mut ap = [L::ZERO; N];
    let mut bp = [L::ZERO; N];
    L::pack(a, &mut ap[..h]);
    L::pack(b, &mut bp[..h]);
    let ratio: usize = if h < N { 2 } else { 1 };
    let threshold_packed = (threshold / ratio).max(4);
    let mut work = L::karatsuba::<Limbs<N>>();
    let work = work.as_mut();
    let (prod, scratch) = work.split_at_mut(2 * h);
    debug_assert!(
        scratch.len() >= karatsuba_scratch_needed_th(h, threshold_packed),
        "Karatsuba scratch overflow: h={}, threshold_packed={}, need={}, have={}",
        h,
        threshold_packed,
        karatsuba_scratch_needed_th(h, threshold_packed),
        scratch.len(),
    );
    karatsuba_rec_limb::<L>(&ap[..h], &bp[..h], &mut *prod, scratch, threshold_packed);
    L::unpack(&prod[..2 * h], &mut out[..2 * N]);
}
/// Returns how many scratch limbs `karatsuba_rec` / `karatsuba_rec_limb` need to
/// multiply two `n`-limb operands when `threshold` is the schoolbook cutoff.
///
/// Each recursion level carves off the front of its scratch:
/// - `z0` (`2h` limbs),
/// - `z2` (`2hi` limbs),
/// - `sa` and `sb` (`hi + 1` limbs each),
/// - `z1` (`2(hi + 1)` limbs).
///
/// The three sub-products reuse the remaining tail one after another, so the total is
/// the sum of those per-level sizes along the largest (`hi + 1`-limb) branch. It is 0
/// when `n < threshold`.
///
/// Thresholds below 4 are treated as 4, mirroring the `.max(4)` clamp in
/// `mul_karatsuba_limb`. With a smaller threshold the `hi + 1` branch stops shrinking
/// (e.g. n = 3 recurses on 3), and the previous recursive version never terminated.
pub(crate) const fn karatsuba_scratch_needed_th(n: usize, threshold: usize) -> usize {
    let threshold = if threshold < 4 { 4 } else { threshold };
    let mut len = n;
    let mut total = 0;
    // Walk the largest branch; each step strictly shrinks `len` because len >= 4.
    while len >= threshold {
        let h = len / 2;
        let hi = len - h;
        total += 2 * h + 2 * hi + (hi + 1) + (hi + 1) + 2 * (hi + 1);
        len = hi + 1;
    }
    total
}
/// Recursive Karatsuba kernel on `u64` limbs (least-significant limb first).
///
/// Writes the `2 * a.len()`-limb product of `a` and `b` into `out[..2 * a.len()]`,
/// independent of what that window held before. This assumes `mul_schoolbook`
/// computes `a * b` into a zeroed window.
///
/// Requirements:
/// - `a.len() == b.len()`;
/// - `out.len() >= 2 * a.len()`;
/// - `scratch.len() >= karatsuba_scratch_needed_th(a.len(), threshold)`;
/// - `threshold >= 4`, the operand length below which `mul_schoolbook` is used.
fn karatsuba_rec(a: &[u64], b: &[u64], out: &mut [u64], scratch: &mut [u64], threshold: usize) {
    debug_assert!(threshold >= 4, "Karatsuba threshold must be >= 4 to terminate");
    let n = a.len();
    if n < threshold {
        // Start the schoolbook kernel from a cleared window (as karatsuba_rec_unbalanced
        // does) so stale data in a caller-supplied `out` can never leak into the product.
        for v in out[..2 * n].iter_mut() { *v = 0; }
        mul_schoolbook(a, b, out);
        return;
    }
    // Split each operand as x = x_lo + x_hi * B^h (B = 2^64); the low half is the shorter.
    let h = n / 2;
    let hi = n - h;
    let (a_lo, a_hi) = a.split_at(h);
    let (b_lo, b_hi) = b.split_at(h);
    // Carve this level's temporaries off the front of `scratch`. The three recursive
    // products run one after another and share `tail`.
    let (z0, rest) = scratch.split_at_mut(2 * h);
    let (z2, rest) = rest.split_at_mut(2 * hi);
    let (sa, rest) = rest.split_at_mut(hi + 1);
    let (sb, rest) = rest.split_at_mut(hi + 1);
    let (z1, tail) = rest.split_at_mut(2 * (hi + 1));
    // z0 = a_lo * b_lo and z2 = a_hi * b_hi. Both callees overwrite their whole output
    // window, so z0/z2/z1 need no pre-clearing.
    karatsuba_rec(a_lo, b_lo, z0, tail, threshold);
    karatsuba_rec_unbalanced(a_hi, b_hi, z2, tail, threshold);
    // sa = a_lo + a_hi and sb = b_lo + b_hi. The sums fit in hi + 1 limbs, so the
    // returned carry is always clear.
    sa[..h].copy_from_slice(a_lo);
    for v in sa[h..].iter_mut() { *v = 0; }
    sb[..h].copy_from_slice(b_lo);
    for v in sb[h..].iter_mut() { *v = 0; }
    let _ = add_assign(sa, a_hi);
    let _ = add_assign(sb, b_hi);
    // z1 = sa * sb - z0 - z2 = a_lo * b_hi + a_hi * b_lo, which is never negative.
    karatsuba_rec_unbalanced(sa, sb, z1, tail, threshold);
    let _ = sub_assign(z1, z0);
    let _ = sub_assign(z1, z2);
    // out = z0 + z2 * B^(2h) + z1 * B^h. z0 (2h limbs) and z2 (2hi limbs) tile out[..2n]
    // exactly, so copy them instead of adding z2 into whatever `out` previously held.
    // Only z1 straddles both halves and must be added.
    out[..2 * h].copy_from_slice(z0);
    out[2 * h..2 * n].copy_from_slice(z2);
    let _ = add_assign(&mut out[h..], z1);
}
/// Chooses the kernel for one equal-length sub-product.
///
/// Below `threshold` it clears `out` and uses `mul_schoolbook`; otherwise it recurses
/// through `karatsuba_rec`. Despite the name, `a` and `b` must have the same length
/// (debug-asserted).
fn karatsuba_rec_unbalanced(
    a: &[u64],
    b: &[u64],
    out: &mut [u64],
    scratch: &mut [u64],
    threshold: usize,
) {
    debug_assert_eq!(a.len(), b.len());
    if a.len() < threshold {
        out.iter_mut().for_each(|limb| *limb = 0);
        mul_schoolbook(a, b, out);
        return;
    }
    karatsuba_rec(a, b, out, scratch, threshold);
}
/// Accumulating schoolbook multiply over generic limbs: adds `a * b` into `out`.
///
/// `out` is not cleared first, so callers wanting a plain product must zero it. Row
/// `row` adds `a[row] * b` starting at `out[row]`. Its final carry then ripples upward
/// while it is non-zero and limbs remain in `out`; a carry past the end is dropped.
/// Zero limbs of `a` are skipped entirely.
#[inline]
pub(crate) fn schoolbook_rec_limb<L: Limb>(a: &[L], b: &[L], out: &mut [L]) {
    for (row, &multiplier) in a.iter().enumerate() {
        if multiplier == L::ZERO {
            continue;
        }
        let mut carry = L::ZERO;
        for (col, &bj) in b.iter().enumerate() {
            let slot = row + col;
            let (prod_lo, prod_hi) = multiplier.widening_mul(bj);
            let (partial, carry_a) = out[slot].overflowing_add(prod_lo);
            let (sum, carry_b) = partial.overflowing_add(carry);
            out[slot] = sum;
            carry = prod_hi.add_carries(carry_a, carry_b);
        }
        let mut slot = row + b.len();
        while carry != L::ZERO && slot < out.len() {
            let (sum, overflow) = out[slot].overflowing_add(carry);
            out[slot] = sum;
            carry = if overflow { L::ONE } else { L::ZERO };
            slot += 1;
        }
    }
}
/// Adds `b` into `a` in place and returns the carry out of the top limb of `a`.
///
/// Limbs are least-significant first. Limbs of `b` past its end are treated as zero.
/// Only `a.len()` limbs are processed, so limbs of `b` beyond `a.len()` are ignored.
#[inline]
pub(crate) fn limb_add_assign<L: Limb>(a: &mut [L], b: &[L]) -> bool {
    let mut carry_in = false;
    for (idx, slot) in a.iter_mut().enumerate() {
        let addend = match b.get(idx) {
            Some(&v) => v,
            None => L::ZERO,
        };
        let incoming = if carry_in { L::ONE } else { L::ZERO };
        let (first, c_first) = (*slot).overflowing_add(addend);
        let (second, c_second) = first.overflowing_add(incoming);
        *slot = second;
        carry_in = c_first | c_second;
    }
    carry_in
}
/// Subtracts `b` from `a` in place and returns the borrow out of the top limb of `a`.
///
/// Limbs are least-significant first. Limbs of `b` past its end are treated as zero.
/// Only `a.len()` limbs are processed, so limbs of `b` beyond `a.len()` are ignored.
#[inline]
pub(crate) fn limb_sub_assign<L: Limb>(a: &mut [L], b: &[L]) -> bool {
    let mut borrow_in = false;
    for (idx, slot) in a.iter_mut().enumerate() {
        let subtrahend = match b.get(idx) {
            Some(&v) => v,
            None => L::ZERO,
        };
        let incoming = if borrow_in { L::ONE } else { L::ZERO };
        let (first, b_first) = (*slot).overflowing_sub(subtrahend);
        let (second, b_second) = first.overflowing_sub(incoming);
        *slot = second;
        borrow_in = b_first | b_second;
    }
    borrow_in
}
/// Generic-limb dispatcher for one equal-length sub-product.
///
/// Below `threshold` it clears `out` and uses the accumulating `schoolbook_rec_limb`;
/// otherwise it recurses via `karatsuba_rec_limb`. Despite the name, `a` and `b` must
/// have equal length (debug-asserted).
fn karatsuba_rec_limb_unbalanced<L: Limb>(
    a: &[L],
    b: &[L],
    out: &mut [L],
    scratch: &mut [L],
    threshold: usize,
) {
    debug_assert_eq!(a.len(), b.len());
    if a.len() < threshold {
        out.iter_mut().for_each(|limb| *limb = L::ZERO);
        schoolbook_rec_limb::<L>(a, b, out);
        return;
    }
    karatsuba_rec_limb::<L>(a, b, out, scratch, threshold);
}
/// Recursive Karatsuba kernel over generic limbs `L` (least-significant limb first).
///
/// Writes the `2 * a.len()`-limb product of `a` and `b` into `out[..2 * a.len()]`,
/// independent of what that window held before.
///
/// Requirements:
/// - `a.len() == b.len()`;
/// - `out.len() >= 2 * a.len()`;
/// - `scratch.len() >= karatsuba_scratch_needed_th(a.len(), threshold)`;
/// - `threshold >= 4`, the length below which `schoolbook_rec_limb` is used.
fn karatsuba_rec_limb<L: Limb>(
    a: &[L],
    b: &[L],
    out: &mut [L],
    scratch: &mut [L],
    threshold: usize,
) {
    debug_assert!(threshold >= 4);
    let n = a.len();
    if n < threshold {
        // schoolbook_rec_limb accumulates into `out`, so clear the product window first.
        // Otherwise stale contents (e.g. an uncleared top-level buffer) would be added
        // into the result.
        for v in out[..2 * n].iter_mut() { *v = L::ZERO; }
        schoolbook_rec_limb::<L>(a, b, out);
        return;
    }
    // Split each operand as x = x_lo + x_hi * B^h (B = limb radix); the low half is shorter.
    let h = n / 2;
    let hi = n - h;
    let (a_lo, a_hi) = a.split_at(h);
    let (b_lo, b_hi) = b.split_at(h);
    // Carve this level's temporaries off the front of `scratch`. The three recursive
    // products run one after another and share `tail`.
    let (z0, rest) = scratch.split_at_mut(2 * h);
    let (z2, rest) = rest.split_at_mut(2 * hi);
    let (sa, rest) = rest.split_at_mut(hi + 1);
    let (sb, rest) = rest.split_at_mut(hi + 1);
    let (z1, tail) = rest.split_at_mut(2 * (hi + 1));
    // z0 = a_lo * b_lo and z2 = a_hi * b_hi. Both callees overwrite their whole output
    // window, so z0/z2/z1 need no pre-clearing.
    karatsuba_rec_limb::<L>(a_lo, b_lo, z0, tail, threshold);
    karatsuba_rec_limb_unbalanced::<L>(a_hi, b_hi, z2, tail, threshold);
    // sa = a_lo + a_hi and sb = b_lo + b_hi. The sums fit in hi + 1 limbs, so the
    // returned carry is always clear.
    sa[..h].copy_from_slice(a_lo);
    for v in sa[h..].iter_mut() { *v = L::ZERO; }
    sb[..h].copy_from_slice(b_lo);
    for v in sb[h..].iter_mut() { *v = L::ZERO; }
    let _ = limb_add_assign::<L>(sa, a_hi);
    let _ = limb_add_assign::<L>(sb, b_hi);
    // z1 = sa * sb - z0 - z2 = a_lo * b_hi + a_hi * b_lo, which is never negative.
    karatsuba_rec_limb_unbalanced::<L>(sa, sb, z1, tail, threshold);
    let _ = limb_sub_assign::<L>(z1, z0);
    let _ = limb_sub_assign::<L>(z1, z2);
    // out = z0 + z2 * B^(2h) + z1 * B^h. z0 (2h limbs) and z2 (2hi limbs) tile out[..2n]
    // exactly, so copy them instead of adding z2 into whatever `out` previously held.
    // Only z1 straddles both halves and must be added.
    out[..2 * h].copy_from_slice(z0);
    out[2 * h..2 * n].copy_from_slice(z2);
    let _ = limb_add_assign::<L>(&mut out[h..], z1);
}