use crate::field::{barrett_reduce, fqadd, fqmul, fqsub, montgomery_reduce, Fe, MONT_R};
use crate::params::{N, Q};
pub const ZETAS: [Fe; 128] = [
1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296,
2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331,
3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435,
807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, 17, 2761, 583,
2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789,
1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037,
3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403,
1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
];
pub const ZETAS_MONT: [Fe; 128] = {
let mut z = [0u16; 128];
let mut i = 0;
while i < 128 {
z[i] = ((ZETAS[i] as u32 * MONT_R as u32) % Q as u32) as u16;
i += 1;
}
z
};
pub const F_MONT: Fe = 512;
pub const GAMMAS: [Fe; 128] = {
let mut g = [0u16; 128];
let mut i = 0;
while i < 128 {
let zpr = ZETAS[i] as u64;
g[i] = ((zpr * zpr * 17) % crate::params::Q as u64) as u16;
i += 1;
}
g
};
pub const GAMMAS_MONT: [Fe; 128] = {
let mut g = [0u16; 128];
let mut i = 0;
while i < 128 {
g[i] = ((GAMMAS[i] as u32 * MONT_R as u32) % Q as u32) as u16;
i += 1;
}
g
};
pub fn ntt_forward(a: &mut [Fe; N]) {
let mut k: usize = 1;
for &len in &[128usize, 64, 32, 16, 8, 4, 2] {
let mut start = 0;
while start < 256 {
let zeta_mont = ZETAS_MONT[k];
k += 1;
for j in start..(start + len) {
let t = montgomery_reduce(a[j + len] as i32 * zeta_mont as i32);
a[j + len] = fqsub(a[j], t);
a[j] = fqadd(a[j], t);
}
start += 2 * len;
}
}
}
pub fn ntt_inverse(a: &mut [Fe; N]) {
let mut k: usize = 127;
for &len in &[2usize, 4, 8, 16, 32, 64, 128] {
let mut start = 0;
while start < 256 {
let zeta_mont = ZETAS_MONT[k];
k = k.wrapping_sub(1);
for j in start..(start + len) {
let tmp = a[j];
a[j] = barrett_reduce(tmp as i32 + a[j + len] as i32);
a[j + len] = fqsub(a[j + len], tmp);
a[j + len] = montgomery_reduce(a[j + len] as i32 * zeta_mont as i32);
}
start += 2 * len;
}
}
for coef in a.iter_mut() {
*coef = montgomery_reduce(*coef as i32 * F_MONT as i32);
}
}
#[inline]
pub fn basemul(a0: Fe, a1: Fe, b0: Fe, b1: Fe, gamma_mont: Fe) -> (Fe, Fe) {
let prod_a1b1 = fqmul(a1, b1);
let gamma_term = montgomery_reduce(prod_a1b1 as i32 * gamma_mont as i32);
let c0 = gamma_term as i32 + fqmul(a0, b0) as i32;
let c0 = barrett_reduce(c0);
let c1 = fqmul(a0, b1) as i32 + fqmul(a1, b0) as i32;
let c1 = barrett_reduce(c1);
(c0, c1)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::params::Q;
#[test]
fn ntt_roundtrip() {
let mut a: [Fe; N] = [0; N];
for (i, x) in a.iter_mut().enumerate() {
*x = (i as u16) % 3329;
}
let orig = a;
ntt_forward(&mut a);
ntt_inverse(&mut a);
for i in 0..N {
assert_eq!(a[i], orig[i], "mismatch at {i}");
}
}
fn naive_mul(a: &[Fe; N], b: &[Fe; N]) -> [Fe; N] {
let mut out = [0u32; 2 * N];
for i in 0..N {
for j in 0..N {
out[i + j] = (out[i + j] + (a[i] as u32) * (b[j] as u32)) % Q as u32;
}
}
let mut r = [0u16; N];
for i in 0..N {
let v = (out[i] + Q as u32 - (out[N + i] % Q as u32)) % Q as u32;
r[i] = v as u16;
}
r
}
#[test]
fn ntt_mul_matches_naive() {
let mut a = [0u16; N];
let mut b = [0u16; N];
for i in 0..N {
a[i] = ((i * 7 + 3) % Q as usize) as u16;
b[i] = ((i * 13 + 5) % Q as usize) as u16;
}
let expected = naive_mul(&a, &b);
let mut af = a;
let mut bf = b;
ntt_forward(&mut af);
ntt_forward(&mut bf);
use crate::poly::PolyNtt;
let ap = PolyNtt(af);
let bp = PolyNtt(bf);
let cp = ap.basemul(&bp);
let mut cc = cp.0;
ntt_inverse(&mut cc);
for i in 0..N {
assert_eq!(cc[i], expected[i], "mismatch at {i}");
}
}
}