/// Operand length at or below which [`poly_mul_kara`] stops recursing and
/// multiplies directly with [`poly_mul_schoolbook`].
// NOTE(review): 48 looks like an empirically tuned crossover point — confirm with benchmarks.
const KARA_THRESHOLD: usize = 48;
/// Full (linear, non-cyclic) product of `a` and `b` with coefficients in Z/2^16.
///
/// `c` must have length `a.len() + b.len() - 1`. It is overwritten with the
/// coefficients of `a * b`, and all arithmetic wraps modulo 2^16.
///
/// Runs in O(a.len() * b.len()); used as the base case of [`poly_mul_kara`].
fn poly_mul_schoolbook(c: &mut [u16], a: &[u16], b: &[u16]) {
    debug_assert_eq!(c.len(), a.len() + b.len() - 1);
    c.fill(0);
    for (i, &ai) in a.iter().enumerate() {
        // Row i adds ai * b onto c[i..i + b.len()]. Zipping over that subslice does a
        // single bounds check per row instead of one per element in this quadratic loop.
        for (slot, &bj) in c[i..i + b.len()].iter_mut().zip(b) {
            *slot = slot.wrapping_add(ai.wrapping_mul(bj));
        }
    }
}
/// Full (non-cyclic) product of two equal-length polynomials over Z/2^16 using Karatsuba.
///
/// `a` and `b` must both have length `n`, and `c` must have length `2n - 1`. `c` is
/// overwritten with the coefficients of `a * b`. Operands of length at most
/// [`KARA_THRESHOLD`] are multiplied with [`poly_mul_schoolbook`].
///
/// Each level splits the operands at `m = ceil(n / 2)` into low halves of length `m` and
/// high halves of length `h = n - m`. It then recurses three times:
/// `z0 = a_lo * b_lo`, `z2 = a_hi * b_hi`, and `(a_lo + a_hi) * (b_lo + b_hi)`.
/// The middle term is recovered from these three products.
fn poly_mul_kara(c: &mut [u16], a: &[u16], b: &[u16]) {
    let n = a.len();
    debug_assert_eq!(b.len(), n);
    debug_assert_eq!(c.len(), 2 * n - 1);
    if n <= KARA_THRESHOLD {
        poly_mul_schoolbook(c, a, b);
        return;
    }

    let m = n.div_ceil(2);
    let (a_lo, a_hi) = a.split_at(m);
    let (b_lo, b_hi) = b.split_at(m);

    // z0 has length 2m - 1 and belongs at c[..2m - 1]. z2 has length 2h - 1 and belongs
    // at c[2m..]. These regions are disjoint, so both products are written straight into
    // `c` with no scratch buffers. The one coefficient between them, c[2m - 1], only
    // receives a middle-term contribution, so it must start at zero.
    {
        let (z0, rest) = c.split_at_mut(2 * m - 1);
        let (gap, z2) = rest.split_at_mut(1);
        poly_mul_kara(z0, a_lo, b_lo);
        poly_mul_kara(z2, a_hi, b_hi);
        gap[0] = 0;
    }

    // a_lo + a_hi and b_lo + b_hi. The high half has h <= m entries and is added onto
    // the front of the low half.
    let mut a_sum = a_lo.to_vec();
    let mut b_sum = b_lo.to_vec();
    for (s, &x) in a_sum.iter_mut().zip(a_hi) {
        *s = s.wrapping_add(x);
    }
    for (s, &x) in b_sum.iter_mut().zip(b_hi) {
        *s = s.wrapping_add(x);
    }

    // z1 := (a_lo + a_hi)(b_lo + b_hi) - z0 - z2 = a_lo * b_hi + a_hi * b_lo.
    // This must read z0 and z2 out of `c` before z1 is added back into `c`.
    let mut z1 = vec![0u16; 2 * m - 1];
    poly_mul_kara(&mut z1, &a_sum, &b_sum);
    for (mid, &lo) in z1.iter_mut().zip(&c[..2 * m - 1]) {
        *mid = mid.wrapping_sub(lo);
    }
    for (mid, &hi) in z1.iter_mut().zip(&c[2 * m..]) {
        *mid = mid.wrapping_sub(hi);
    }

    // Add the middle term at offset m. It fits because 3m - 1 <= 2n - 1 for every n >= 3.
    // Slicing to exactly z1.len() makes any violation a panic rather than silent truncation.
    for (dst, &mid) in c[m..m + z1.len()].iter_mut().zip(&z1) {
        *dst = dst.wrapping_add(mid);
    }
}
/// Cyclic product `r = a * b mod (x^n - 1)` over Z/2^16, where `n = a.len()`.
///
/// `a`, `b` and `r` must all have length `n`. The first `n` entries of `r` are
/// overwritten. The full linear product is computed with [`poly_mul_kara`]. Its
/// coefficients of degree `n..2n - 1` are then folded onto degrees `0..n - 1`,
/// because `x^n = 1` in this ring.
///
/// Empty operands are a no-op. For `n == 1` the result is `a[0] * b[0]`.
pub(crate) fn poly_mul_cyclic(r: &mut [u16], a: &[u16], b: &[u16]) {
    let n = a.len();
    debug_assert_eq!(b.len(), n);
    debug_assert_eq!(r.len(), n);
    if n == 0 {
        // The buffer size 2 * n - 1 below would underflow. The empty product needs no work.
        return;
    }
    let mut c = vec![0u16; 2 * n - 1];
    poly_mul_kara(&mut c, a, b);

    let (low, high) = c.split_at(n);
    let r = &mut r[..n];
    r.copy_from_slice(low);
    // `high` holds the n - 1 coefficients of degree n..2n - 1, so the zip stops before
    // r[n - 1]. That entry keeps c[n - 1] unchanged.
    for (dst, &wrapped) in r.iter_mut().zip(high) {
        *dst = dst.wrapping_add(wrapped);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Naive O(n^2) cyclic product used as the oracle.
    ///
    /// For each output index `k` it accumulates `a[j] * b[(k - j) mod n]` over every `j`.
    /// The sum is split into the non-wrapping part (`j <= k`) and the part whose
    /// exponent wraps past `x^n` (`j > k`).
    fn poly_mul_cyclic_ref(r: &mut [u16], a: &[u16], b: &[u16]) {
        let n = a.len();
        r.fill(0);
        for k in 0..n {
            let direct = (0..=k).fold(0u16, |acc, i| {
                acc.wrapping_add(a[k - i].wrapping_mul(b[i]))
            });
            let wrapped = (k + 1..n).fold(0u16, |acc, j| {
                acc.wrapping_add(a[j].wrapping_mul(b[n + k - j]))
            });
            r[k] = direct.wrapping_add(wrapped);
        }
    }

    /// Builds two pseudo-random length-`n` operands from `seed` and asserts that
    /// `poly_mul_cyclic` agrees with the naive oracle.
    fn check(n: usize, seed: u32) {
        // 32-bit linear congruential generator; each draw uses the high 16 bits of the state.
        let mut state = seed.wrapping_add(0x9E37_79B9);
        let mut draw = move || {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            (state >> 16) as u16
        };
        let a: Vec<u16> = (0..n).map(|_| draw()).collect();
        let b: Vec<u16> = (0..n).map(|_| draw()).collect();

        let mut got = vec![0u16; n];
        let mut want = vec![0u16; n];
        poly_mul_cyclic(&mut got, &a, &b);
        poly_mul_cyclic_ref(&mut want, &a, &b);
        assert_eq!(got, want, "mismatch at n = {n}");
    }

    #[test]
    fn matches_reference_at_threshold_and_above() {
        // The sizes straddle KARA_THRESHOLD and include odd sizes that force uneven splits.
        const SIZES: [usize; 12] = [2, 7, 32, 47, 48, 49, 64, 100, 509, 677, 701, 821];
        SIZES.iter().for_each(|&n| check(n, n as u32));
    }

    #[test]
    fn handles_zero_input() {
        let n = 256;
        let zeros = vec![0u16; n];
        let ramp: Vec<u16> = (0..n as u16).collect();
        let mut got = vec![0u16; n];
        poly_mul_cyclic(&mut got, &zeros, &ramp);
        assert!(got.iter().all(|&coeff| coeff == 0));
    }
}