use crate::int::algos::mul::mul_karatsuba::{
limb_add_assign, limb_sub_assign, schoolbook_rec_limb,
};
use crate::int::types::compute_limbs::Limb;
/// Capacity, in `u64` limbs, of the stack scratch buffer used by [`mul_toom3`].
///
/// `2^10` limbs covers 64-limb operands at [`TOOM3_BASE_THRESHOLD`]; the
/// `toom3_max_width_fixed_scratch` test pins that bound.
pub(crate) const TOOM3_SCRATCH_LIMBS: usize = 1 << 10;
/// Operand width (in limbs) below which the Toom-3 recursion hands off to
/// schoolbook multiplication.
pub(crate) const TOOM3_BASE_THRESHOLD: usize = 9;
/// Number of scratch limbs `toom3_rec_limb` consumes for `n`-limb operands
/// with the given recursion `threshold`.
///
/// A Toom level on width `w >= threshold` reserves `26 * (ceil(w / 3) + 1)`
/// limbs (ten evaluation buffers plus eight product/temporary buffers, each
/// sized from `ceil(w / 3) + 1`). It then recurses on width
/// `ceil(w / 3) + 1`. Widths below `threshold` need no scratch at all.
/// `threshold` should be at least 3; below that the width can stop shrinking
/// and the recursion never terminates.
pub(crate) const fn toom3_scratch_needed(n: usize, threshold: usize) -> usize {
    if n >= threshold {
        let child_width = n.div_ceil(3) + 1;
        26 * child_width + toom3_scratch_needed(child_width, threshold)
    } else {
        0
    }
}
/// Multiplies two equal-length little-endian `u64` limb slices with Toom-3.
///
/// The `2 * a.len()`-limb product is written to `out[..2 * a.len()]`.
/// Working memory comes from a stack buffer of [`TOOM3_SCRATCH_LIMBS`] limbs.
/// Operands narrower than [`TOOM3_BASE_THRESHOLD`] go straight to schoolbook
/// multiplication.
///
/// # Panics
///
/// Debug builds assert three things:
/// - `a` and `b` have equal length;
/// - `out` holds at least `2 * a.len()` limbs;
/// - the fixed stack scratch is large enough for this width.
///
/// In release builds, a too-short `out` panics when it is sliced. Insufficient
/// scratch panics inside the recursion, where the scratch buffer is split.
pub(crate) fn mul_toom3(a: &[u64], b: &[u64], out: &mut [u64]) {
    let n = a.len();
    debug_assert_eq!(n, b.len());
    debug_assert!(out.len() >= 2 * n);
    debug_assert!(
        toom3_scratch_needed(n, TOOM3_BASE_THRESHOLD) <= TOOM3_SCRATCH_LIMBS,
        "Toom-3 scratch overflow: n={} needs {} limbs, have {}",
        n,
        toom3_scratch_needed(n, TOOM3_BASE_THRESHOLD),
        TOOM3_SCRATCH_LIMBS,
    );
    // The recursion *adds* its interpolated coefficients into `out`
    // (add_into_out_limb), so the product region must start at zero.
    // mul_toom3_limb zeroes its destination for the same reason. Without
    // this, a reused, non-zero buffer yields a wrong product.
    out[..2 * n].fill(0);
    let mut scratch = [0u64; TOOM3_SCRATCH_LIMBS];
    toom3_rec_limb::<u64>(a, b, out, &mut scratch, TOOM3_BASE_THRESHOLD, inv3_of::<u64>());
}
/// Packed-limb Toom-3 used by tests and the `bench-alt` feature.
///
/// The function works in three steps:
/// 1. The `N`-limb `u64` operands are packed into `h = L::packed_len(N)`
///    limbs of type `L`.
/// 2. The packed operands are multiplied with `toom3_rec_limb`.
/// 3. The `2 * h`-limb packed product is unpacked into `out[..2 * N]`.
///
/// When packing shortens the operands (`h < N`), the recursion threshold is
/// halved, but never below 3. Presumably this keeps the Toom/schoolbook
/// cutover near the same `u64` width (TODO confirm).
///
/// All working buffers live on the stack. The packed product is capped at
/// 512 limbs and the scratch at 3584 limbs of `L`; both limits are
/// debug-asserted.
#[cfg(any(test, feature = "bench-alt"))]
pub(crate) fn mul_toom3_limb<const N: usize, L: Limb>(
    a: &[u64; N],
    b: &[u64; N],
    out: &mut [u64],
) {
    const PRODUCT_CAP: usize = 2 * 256;
    const SCRATCH_CAP: usize = 3584;

    let h = L::packed_len(N);
    debug_assert!(h > 0 && h <= N);

    let mut a_packed = [L::ZERO; N];
    let mut b_packed = [L::ZERO; N];
    L::pack(a, &mut a_packed[..h]);
    L::pack(b, &mut b_packed[..h]);

    // Halve the cutover width when packing shortened the operands.
    let divisor = if h < N { 2 } else { 1 };
    let threshold_packed = (TOOM3_BASE_THRESHOLD / divisor).max(3);

    let mut product = [L::ZERO; PRODUCT_CAP];
    let mut scratch = [L::ZERO; SCRATCH_CAP];
    debug_assert!(2 * h <= product.len());
    debug_assert!(
        toom3_scratch_needed(h, threshold_packed) <= SCRATCH_CAP,
        "Toom-3 limb scratch overflow: h={} threshold={} needs {} > {}",
        h,
        threshold_packed,
        toom3_scratch_needed(h, threshold_packed),
        SCRATCH_CAP,
    );

    let prod_len = 2 * h;
    // The recursion accumulates into its output, so clear the product region.
    product[..prod_len].fill(L::ZERO);
    toom3_rec_limb::<L>(
        &a_packed[..h],
        &b_packed[..h],
        &mut product[..prod_len],
        &mut scratch,
        threshold_packed,
        inv3_of::<L>(),
    );
    L::unpack(&product[..prod_len], &mut out[..2 * N]);
}
/// Test-only Toom-3 entry point with a caller-chosen recursion `threshold`.
///
/// Unlike [`mul_toom3`], the scratch buffer is heap-allocated and sized by
/// [`toom3_scratch_needed`]. Writes the `2 * a.len()`-limb product of the
/// equal-length little-endian slices `a` and `b` into `out[..2 * a.len()]`.
///
/// # Panics
///
/// Panics if `threshold < 3`. Below that, `toom3_scratch_needed` may never
/// reach its base case and would overflow the stack instead of failing cleanly.
#[cfg(test)]
pub(crate) fn mul_toom3_with_threshold(a: &[u64], b: &[u64], out: &mut [u64], threshold: usize) {
    let n = a.len();
    debug_assert_eq!(n, b.len());
    debug_assert!(out.len() >= 2 * n);
    assert!(threshold >= 3, "Toom-3 threshold must be at least 3, got {threshold}");
    // toom3_rec_limb adds its coefficients into `out`, so the product region
    // must start zeroed. mul_toom3_limb clears its buffer the same way.
    out[..2 * n].fill(0);
    let need = toom3_scratch_needed(n, threshold);
    let mut scratch = vec![0u64; need.max(1)];
    toom3_rec_limb::<u64>(a, b, out, &mut scratch, threshold, inv3_of::<u64>());
}
/// Recursive Toom-3 (Toom-Cook 3-way) multiplication over generic limbs.
///
/// Multiplies the equal-length little-endian limb slices `a` and `b` into
/// `out[..2 * a.len()]`. Operands shorter than `threshold` go to
/// `schoolbook_rec_limb`. Otherwise the Toom path runs:
/// 1. Split each operand into pieces of `k = ceil(n / 3)` limbs; the top piece
///    may be shorter or empty.
/// 2. Evaluate the piece polynomials at 0, 1, -1, 2 and infinity.
/// 3. Compute the five point products recursively.
/// 4. Recover the five product coefficients by interpolation.
/// 5. Accumulate them into `out` at limb offsets 0, k, 2k, 3k and 4k.
///
/// - `out`: at least `2 * n` limbs. On the Toom path the coefficients are
///   added into `out` via `add_into_out_limb`, so `out[..2 * n]` must be zero
///   on entry for the result to be the product. The recursive calls below
///   receive freshly zeroed `w*` buffers.
/// - `scratch`: at least `toom3_scratch_needed(n, threshold)` limbs. This
///   level takes `26 * (k + 1)` limbs and passes the remainder to its
///   children. Too little scratch panics in `split_at_mut`.
/// - `threshold`: recursion cutoff; must be at least 3 (debug-asserted).
/// - `inv`: inverse of 3 modulo the limb radix (from `inv3_of::<L>()`). It is
///   used for the exact division by 3 during interpolation.
fn toom3_rec_limb<L: Limb>(
a: &[L],
b: &[L],
out: &mut [L],
scratch: &mut [L],
threshold: usize,
inv: L,
) {
let n = a.len();
debug_assert_eq!(n, b.len());
debug_assert!(out.len() >= 2 * n);
debug_assert!(threshold >= 3);
if n < threshold {
schoolbook_rec_limb::<L>(a, b, out);
return;
}
// Split into low a[..k], middle a[k..k2], and high a[k2..] (possibly short or empty).
let k = n.div_ceil(3);
let k2 = (2 * k).min(n);
let a0 = &a[..k];
let a1 = &a[k..k2];
let a2 = &a[k2..];
let b0 = &b[..k];
let b1 = &b[k..k2];
let b2 = &b[k2..];
// ew: limbs per evaluated operand. One extra limb absorbs the carry of
// a0 + 2*a1 + 4*a2. pw: limbs per point product (ew x ew).
let ew = k + 1; let pw = 2 * ew;
// 10 evaluation buffers of ew limbs + 8 product/temporary buffers of pw limbs
// (= 26 * ew, matching toom3_scratch_needed).
let level_need = 10 * ew + 8 * pw;
let (level_buf, child_scratch) = scratch.split_at_mut(level_need);
// Clear this level's buffers. Evaluations copy a possibly shorter piece and
// then add across the full width, and recursive outputs must start at zero.
for v in level_buf.iter_mut() {
*v = L::ZERO;
}
let (pa0, r) = level_buf.split_at_mut(ew);
let (pa1, r) = r.split_at_mut(ew);
let (pam, r) = r.split_at_mut(ew);
let (pa2, r) = r.split_at_mut(ew);
let (pai, r) = r.split_at_mut(ew);
let (pb0, r) = r.split_at_mut(ew);
let (pb1, r) = r.split_at_mut(ew);
let (pbm, r) = r.split_at_mut(ew);
let (pb2, r) = r.split_at_mut(ew);
let (pbi, r) = r.split_at_mut(ew);
let (w0, r) = r.split_at_mut(pw);
let (w1, r) = r.split_at_mut(pw);
let (wm, r) = r.split_at_mut(pw);
let (w2, r) = r.split_at_mut(pw);
let (wi, r) = r.split_at_mut(pw);
let (t1, r) = r.split_at_mut(pw);
let (t2, r) = r.split_at_mut(pw);
let (tmp, _) = r.split_at_mut(pw);
// Evaluate A(x) = a0 + a1*x + a2*x^2 at x = 0 (pa0), inf (pai), 1 (pa1),
// -1 (pam, stored as magnitude with sign flag am_neg) and 2 (pa2).
pa0[..a0.len()].copy_from_slice(a0);
pai[..a2.len()].copy_from_slice(a2);
pa1[..a0.len()].copy_from_slice(a0);
let _ = limb_add_assign(pa1, a1);
let _ = limb_add_assign(pa1, a2);
pam[..a0.len()].copy_from_slice(a0);
let _ = limb_add_assign(pam, a2);
let am_neg = signed_sub_limb(pam, a1);
// pa2 = a0 + 2*a1 + 4*a2: a1 is added twice; a2 is shifted left 2 bits in tmp.
pa2[..a0.len()].copy_from_slice(a0);
let _ = limb_add_assign(pa2, a1);
let _ = limb_add_assign(pa2, a1); tmp[..a2.len()].copy_from_slice(a2);
shl_inplace_limb(&mut tmp[..ew], 2); let _ = limb_add_assign(pa2, &tmp[..ew]);
// Same five evaluations for B(x).
pb0[..b0.len()].copy_from_slice(b0);
pbi[..b2.len()].copy_from_slice(b2);
pb1[..b0.len()].copy_from_slice(b0);
let _ = limb_add_assign(pb1, b1);
let _ = limb_add_assign(pb1, b2);
pbm[..b0.len()].copy_from_slice(b0);
let _ = limb_add_assign(pbm, b2);
let bm_neg = signed_sub_limb(pbm, b1);
pb2[..b0.len()].copy_from_slice(b0);
let _ = limb_add_assign(pb2, b1);
// tmp still holds 4*a2; clear it before building 4*b2.
let _ = limb_add_assign(pb2, b1); for v in tmp.iter_mut() {
*v = L::ZERO;
}
tmp[..b2.len()].copy_from_slice(b2);
shl_inplace_limb(&mut tmp[..ew], 2); let _ = limb_add_assign(pb2, &tmp[..ew]);
// Point products r(x) = A(x) * B(x); the five calls reuse child_scratch in turn.
// w0 = r(0), w1 = r(1), wm = |r(-1)|, w2 = r(2), wi = r(inf).
toom3_rec_limb::<L>(pa0, pb0, w0, child_scratch, threshold, inv);
toom3_rec_limb::<L>(pa1, pb1, w1, child_scratch, threshold, inv);
toom3_rec_limb::<L>(pam, pbm, wm, child_scratch, threshold, inv);
toom3_rec_limb::<L>(pa2, pb2, w2, child_scratch, threshold, inv);
toom3_rec_limb::<L>(pai, pbi, wi, child_scratch, threshold, inv);
// Interpolation, with r(x) = c0 + c1*x + c2*x^2 + c3*x^3 + c4*x^4.
// Sign of r(-1) is the XOR of the operand signs at -1.
let wm_neg = am_neg ^ bm_neg;
let t1_neg;
let t2_neg;
// Step A: t1 = r(1) + r(-1) and t2 = r(1) - r(-1) in sign-magnitude form.
if !wm_neg {
t1[..pw].copy_from_slice(w1);
let _ = limb_add_assign(t1, wm);
t1_neg = false;
t2[..pw].copy_from_slice(w1);
t2_neg = signed_sub_limb(t2, wm); } else {
t1[..pw].copy_from_slice(w1);
t1_neg = signed_sub_limb(t1, wm); t2[..pw].copy_from_slice(w1);
let _ = limb_add_assign(t2, wm);
t2_neg = false;
}
// Halve both: t1 = c0 + c2 + c4, t2 = c1 + c3. Then t1 -= c0 (w0).
shr1_limb(t1); shr1_limb(t2); let t1_neg = signed_sub_signed_limb(t1, t1_neg, w0, false);
// t1 -= c4 (wi), leaving t1 = c2.
let t1_neg = signed_sub_signed_limb(t1, t1_neg, wi, false);
debug_assert!(!t1_neg, "c2 (t1) negative -- interpolation error (step A)");
debug_assert!(!t2_neg, "t2=c1+c3 negative -- interpolation error (step A)");
// Step B: w2 = r(2) - c0 = 2*c1 + 4*c2 + 8*c3 + 16*c4.
let went_neg = signed_sub_limb(w2, w0);
debug_assert!(!went_neg, "Toom-3: r2 < r0 -- invariant violated for unsigned inputs");
// Halve to c1 + 2*c2 + 4*c3 + 8*c4, then subtract 2*c2 (tmp = t1 << 1).
shr1_limb(w2); tmp[..pw].copy_from_slice(t1); shl_inplace_limb(tmp, 1); let s_neg = signed_sub_signed_limb(w2, false, tmp, t1_neg);
// Subtract 8*c4 (tmp = wi << 3), leaving s = w2 = c1 + 4*c3.
tmp[..pw].copy_from_slice(wi);
shl_inplace_limb(tmp, 3); let s_neg = signed_sub_signed_limb(w2, s_neg, tmp, false);
debug_assert!(!s_neg, "s=c1+4c3 negative -- interpolation error (step B)");
// Step C: tmp = s - t2 = 3*c3, then exact division by 3 gives tmp = c3.
tmp[..pw].copy_from_slice(w2); let c3_neg = signed_sub_signed_limb(tmp, s_neg, t2, t2_neg); debug_assert!(!c3_neg, "3*c3 negative -- interpolation error (step C)");
let c3_neg = div3_limb(tmp, c3_neg, inv);
// Step D: w2 = t2 - c3 = c1.
w2[..pw].copy_from_slice(t2); let c1_neg = signed_sub_signed_limb(w2, t2_neg, tmp, c3_neg); debug_assert!(!c1_neg, "c1 negative -- interpolation error (step D)");
debug_assert!(!t1_neg, "c2 negative -- final check");
debug_assert!(!c3_neg, "c3 negative -- final check");
// Recombine: out += c0 + c1*R^k + c2*R^(2k) + c3*R^(3k) + c4*R^(4k), R = limb radix.
// Coefficients: w0 = c0, w2 = c1, t1 = c2, tmp = c3, wi = c4.
add_into_out_limb(out, 0, w0); add_into_out_limb(out, k, w2); add_into_out_limb(out, 2 * k, t1); add_into_out_limb(out, 3 * k, tmp); add_into_out_limb(out, 4 * k, wi); }
/// Compares two little-endian magnitudes that may have different lengths.
///
/// Missing high limbs of the shorter slice are treated as zero. Returns `1`
/// if `a > b`, `-1` if `a < b`, and `0` if they are equal.
#[inline]
fn cmp_cross_limb<L: Limb>(a: &[L], b: &[L]) -> i32 {
    (0..a.len().max(b.len()))
        .rev()
        .find_map(|idx| {
            let lhs = a.get(idx).copied().unwrap_or(L::ZERO);
            let rhs = b.get(idx).copied().unwrap_or(L::ZERO);
            if lhs > rhs {
                Some(1)
            } else if lhs < rhs {
                Some(-1)
            } else {
                None
            }
        })
        .unwrap_or(0)
}
/// Reverse subtraction in place: `dst = src - dst`.
///
/// Only `dst.len()` limbs are produced. Limbs of `src` beyond that are
/// ignored, and the final borrow is discarded. Callers must guarantee
/// `src >= dst` and that the difference fits in `dst`.
#[inline]
fn rsub_assign_limb<L: Limb>(dst: &mut [L], src: &[L]) {
    let mut owe = false;
    for (idx, cell) in dst.iter_mut().enumerate() {
        let minuend = src.get(idx).copied().unwrap_or(L::ZERO);
        let (partial, under_a) = minuend.overflowing_sub(*cell);
        let (diff, under_b) = partial.overflowing_sub(if owe { L::ONE } else { L::ZERO });
        *cell = diff;
        owe = under_a | under_b;
    }
}
/// Replaces `dst` with `|dst - src|`.
///
/// Returns `true` when `src > dst`, i.e. when the true difference is
/// negative. Equal values give zero with `false`.
#[inline]
fn signed_sub_limb<L: Limb>(dst: &mut [L], src: &[L]) -> bool {
    let src_larger = cmp_cross_limb(dst, src) < 0;
    if src_larger {
        rsub_assign_limb(dst, src);
    } else {
        let _ = limb_sub_assign(dst, src);
    }
    src_larger
}
/// Sign-magnitude subtraction `(a_neg, a) -= (b_neg, b)`; returns the new sign of `a`.
///
/// Implemented as addition of `b` with its sign flipped.
#[inline]
fn signed_sub_signed_limb<L: Limb>(a: &mut [L], a_neg: bool, b: &[L], b_neg: bool) -> bool {
    let negated_b = !b_neg;
    signed_add_signed_limb(a, a_neg, b, negated_b)
}
/// Sign-magnitude addition `(a_neg, a) += (b_neg, b)`; returns the sign of the result.
///
/// The carry out of a same-sign addition is discarded. When the magnitudes
/// cancel exactly, the zero result keeps `a_neg` as its sign.
#[inline]
fn signed_add_signed_limb<L: Limb>(a: &mut [L], a_neg: bool, b: &[L], b_neg: bool) -> bool {
    if a_neg != b_neg {
        // Opposite signs: subtract the smaller magnitude from the larger one.
        if cmp_cross_limb(a, b) < 0 {
            rsub_assign_limb(a, b);
            return b_neg;
        }
        let _ = limb_sub_assign(a, b);
        return a_neg;
    }
    // Same sign: magnitudes add and the sign is unchanged.
    let _ = limb_add_assign(a, b);
    a_neg
}
/// Shifts the little-endian magnitude `a` left by `shift` bits in place.
///
/// Bits shifted out of the top limb are discarded. `shift == 0` is a no-op.
/// The callers in this file pass 1, 2 or 3.
#[inline]
fn shl_inplace_limb<L: Limb>(a: &mut [L], shift: u32) {
    if shift == 0 {
        return;
    }
    let back_shift = L::BITS - shift;
    let _ = a.iter_mut().fold(L::ZERO, |spill, cell| {
        let next_spill = cell.wrapping_shr(back_shift);
        *cell = cell.wrapping_shl(shift).overflowing_add(spill).0;
        next_spill
    });
}
/// Halves the little-endian magnitude `a` in place (logical right shift by one bit).
///
/// The value must be even. Debug builds assert that no set bit falls off the
/// bottom limb.
#[inline]
fn shr1_limb<L: Limb>(a: &mut [L]) {
    let top_shift = L::BITS - 1;
    let leftover = a.iter_mut().rev().fold(L::ZERO, |incoming, cell| {
        let dropped = cell.wrapping_shl(top_shift);
        *cell = cell.wrapping_shr(1).overflowing_add(incoming).0;
        dropped
    });
    debug_assert!(leftover == L::ZERO, "shr1: odd value -- must be even for Toom-3");
}
/// Multiplicative inverse of 3 modulo the limb radix `2^L::BITS`.
///
/// Uses the Newton/Hensel iteration `x <- x * (2 - 3x)`, keeping the low half
/// (`.0`) of each `widening_mul`. The seed `x = 3` is correct modulo 8,
/// since `3 * 3 = 9`. Each round doubles the number of correct low bits, so
/// six rounds give 192 bits, enough for the `u64`/`u128` limbs the tests use.
#[inline]
fn inv3_of<L: Limb>() -> L {
    let two = L::ONE.overflowing_add(L::ONE).0;
    let three = two.overflowing_add(L::ONE).0;
    (0..6).fold(three, |approx, _| {
        let correction = two.overflowing_sub(three.widening_mul(approx).0).0;
        approx.widening_mul(correction).0
    })
}
/// Exact in-place division of the little-endian magnitude `a` by 3.
///
/// `inv` must be the inverse of 3 modulo the limb radix (see `inv3_of`).
/// The magnitude must be divisible by 3; debug builds check that the final
/// borrow is zero. The sign flag `neg` is returned unchanged.
#[inline]
fn div3_limb<L: Limb>(a: &mut [L], neg: bool, inv: L) -> bool {
    let three = L::ONE.overflowing_add(L::ONE).0.overflowing_add(L::ONE).0;
    let leftover = a.iter_mut().fold(L::ZERO, |carry, cell| {
        let (adjusted, wrapped) = cell.overflowing_sub(carry);
        let quotient = adjusted.widening_mul(inv).0;
        *cell = quotient;
        // The high half of quotient * 3 is what must be borrowed from the next limb.
        let high = quotient.widening_mul(three).1;
        high.overflowing_add(if wrapped { L::ONE } else { L::ZERO }).0
    });
    debug_assert!(leftover == L::ZERO, "div3: not divisible by 3 -- Toom-3 interpolation error");
    neg
}
/// Adds `src` into `out` starting at limb `offset`, discarding the final carry.
///
/// Does nothing when `offset` is at or past the end of `out`. How a `src`
/// longer than `out[offset..]` is handled is up to `limb_add_assign`
/// (defined elsewhere).
#[inline]
fn add_into_out_limb<L: Limb>(out: &mut [L], offset: usize, src: &[L]) {
    if offset >= out.len() {
        return;
    }
    let _ = limb_add_assign(&mut out[offset..], src);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::int::algos::mul::mul_schoolbook::mul_schoolbook;

    /// `n` deterministic pseudo-random limbs from a SplitMix64 stream seeded by `seed`.
    fn fill(n: usize, seed: u64) -> Vec<u64> {
        (0..n)
            .scan(seed, |state, _| {
                *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
                let mut z = *state;
                z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
                z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
                Some(z ^ (z >> 31))
            })
            .collect()
    }

    /// Reference product via `mul_schoolbook`, zero-padding both operands to the longer width.
    fn schoolbook_ref(a: &[u64], b: &[u64]) -> Vec<u64> {
        let width = a.len().max(b.len());
        let pad = |src: &[u64]| {
            let mut padded = vec![0u64; width];
            padded[..src.len()].copy_from_slice(src);
            padded
        };
        let (lhs, rhs) = (pad(a), pad(b));
        let mut product = vec![0u64; 2 * width];
        mul_schoolbook(&lhs, &rhs, &mut product);
        product
    }

    /// Boundary operands of width `n`.
    ///
    /// The cases are: all zeros, all ones, only the low limb saturated, only
    /// the top limb set to 1 (`n > 1`), and only the middle limb saturated
    /// (`n > 2`).
    fn edge_cases(n: usize) -> Vec<Vec<u64>> {
        let single = |idx: usize, val: u64| {
            let mut v = vec![0u64; n];
            v[idx] = val;
            v
        };
        let mut cases = vec![vec![0u64; n], vec![u64::MAX; n], single(0, u64::MAX)];
        if n > 1 {
            cases.push(single(n - 1, 1));
        }
        if n > 2 {
            cases.push(single(n / 2, u64::MAX));
        }
        cases
    }

    #[test]
    fn toom3_bit_identical_to_schoolbook() {
        const WIDTHS: &[usize] = &[
            3, 4, 5, 6, 7, 8, 9, 10, 12, 15, 16, 18, 21, 24, 27, 32, 33, 36, 48, 51, 60, 63, 64,
        ];
        const THRESHOLDS: &[usize] = &[3, 6, 9];
        const SEEDS: [u64; 8] = [1, 3, 7, 13, 42, 1337, 0xDEAD_BEEF, 0xCAFE_F00D];
        for &n in WIDTHS {
            let edges = edge_cases(n);
            for &th in THRESHOLDS {
                for a in &edges {
                    for b in &edges {
                        let expected = schoolbook_ref(a, b);
                        let mut got = vec![0u64; 2 * n];
                        mul_toom3_with_threshold(a, b, &mut got, th);
                        assert_eq!(got, expected, "mismatch (edge) n={n} th={th}");
                        let mut swapped = vec![0u64; 2 * n];
                        mul_toom3_with_threshold(b, a, &mut swapped, th);
                        assert_eq!(swapped, expected, "not commutative n={n} th={th}");
                    }
                }
                for seed in SEEDS {
                    let a = fill(n, seed);
                    let b = fill(n, seed.wrapping_add(0x1234_5678));
                    let expected = schoolbook_ref(&a, &b);
                    let mut got = vec![0u64; 2 * n];
                    mul_toom3_with_threshold(&a, &b, &mut got, th);
                    assert_eq!(got, expected, "mismatch (random) n={n} th={th} seed={seed}");
                }
            }
        }
    }

    /// Checks `mul_toom3_limb` with `u64` and `u128` limbs against schoolbook at width `N`.
    fn check_limb_width<const N: usize>() {
        let to_array = |v: Vec<u64>| -> [u64; N] {
            let mut arr = [0u64; N];
            arr.copy_from_slice(&v);
            arr
        };
        let mut low_only = [0u64; N];
        low_only[0] = u64::MAX;
        let mut ops: Vec<([u64; N], [u64; N])> = vec![
            ([0u64; N], [u64::MAX; N]),
            ([u64::MAX; N], [u64::MAX; N]),
            (low_only, [u64::MAX; N]),
        ];
        for seed in [1u64, 7, 42, 0xDEAD_BEEF, 0xCAFE_F00D, 1009] {
            ops.push((to_array(fill(N, seed)), to_array(fill(N, seed.wrapping_add(0x9999)))));
        }
        for (a, b) in ops {
            let expected = schoolbook_ref(&a, &b);
            let mut via_u64 = vec![0u64; 2 * N];
            super::mul_toom3_limb::<N, u64>(&a, &b, &mut via_u64);
            assert_eq!(via_u64, expected, "toom3_limb u64 mismatch N={}", N);
            let mut via_u128 = vec![0u64; 2 * N];
            super::mul_toom3_limb::<N, u128>(&a, &b, &mut via_u128);
            assert_eq!(via_u128, expected, "toom3_limb u128 mismatch N={}", N);
        }
    }

    #[test]
    fn toom3_limb_u128_bit_identical() {
        check_limb_width::<18>();
        check_limb_width::<24>();
        check_limb_width::<32>();
        check_limb_width::<48>();
        check_limb_width::<64>();
    }

    #[test]
    fn toom3_max_width_fixed_scratch() {
        const WIDTH: usize = 64;
        assert!(
            toom3_scratch_needed(WIDTH, TOOM3_BASE_THRESHOLD) <= TOOM3_SCRATCH_LIMBS,
            "scratch too small for n=64 threshold={}",
            TOOM3_BASE_THRESHOLD,
        );
        let a = fill(WIDTH, 0xABCD_EF01_2345_6789);
        let b = fill(WIDTH, 0xFEDC_BA98_7654_3210);
        let expected = schoolbook_ref(&a, &b);
        let mut got = vec![0u64; 2 * WIDTH];
        mul_toom3(&a, &b, &mut got);
        assert_eq!(got, expected, "n=64 fixed-scratch mismatch");
    }
}