use std::marker::PhantomData;
use crate::reference::ntt120::primes::PrimeSet;
use poulpy_hal::alloc_aligned;
/// Per-level parameters of the lazily reduced NTT butterflies.
///
/// One entry is stored per level in `NttTable::level_metadata` and
/// `NttTableInv::level_metadata`.
#[derive(Clone, Debug)]
pub struct NttStepMeta {
/// Per-prime offset `Q[k] << s` (a multiple of `Q[k]`) added before a
/// subtraction, so that `a + q2bs - b` stays non-negative without changing its
/// residue modulo `Q[k]`. All zeros on the forward pre-multiplication level,
/// which never subtracts.
pub q2bs: [u64; 4],
/// Tracked upper bound, in bits, on the values produced by this level.
pub bs: u64,
/// Bit position at which `split_precompmul` splits its input on this level
/// (0 on levels that perform no twiddle multiplication).
pub half_bs: u64,
/// `(1 << half_bs) - 1`: the low-part mask matching `half_bs`.
pub mask: u64,
/// Whether this level first passes its inputs through `modq_red`.
pub reduce: bool,
}
/// Constants for the partial reduction `modq_red`, which maps `x` to
/// `(x & mask) + (x >> h) * modulo_red_cst[k]`. This is congruent to `x`
/// modulo `Q[k]` as long as the sum does not wrap.
#[derive(Clone, Debug)]
pub struct NttReducMeta {
/// `2^h mod Q[k]` for each of the four primes.
pub modulo_red_cst: [u64; 4],
/// `(1 << h) - 1`.
pub mask: u64,
/// Split position chosen by `fill_reduction_meta` to minimise the bit size
/// of the reduced value.
pub h: u64,
}
/// Precomputed data for the forward NTT (`ntt_ref`) over the four primes of `P`.
pub struct NttTable<P: PrimeSet> {
/// Transform size (a power of two, at most 2^16, as enforced by `new`).
pub n: usize,
/// For `n > 1`: entry 0 describes the pre-multiplication by powers of the
/// root. It is followed by one entry per butterfly level, for spans
/// `n, n/2, ..., 2`. Empty when `n == 1`.
pub level_metadata: Vec<NttStepMeta>,
/// Packed twiddles (format of `pack_omega`), four consecutive words per power,
/// one per prime. For `n > 1`: first the `n` pre-multiplication factors, then
/// `span/2 - 1` factors for each butterfly level of span `n, n/2, ..., 4`.
pub powomega: Vec<u64>,
/// Constants for the lazy reduction used on levels whose `reduce` flag is set.
pub reduc_metadata: NttReducMeta,
/// Bit-size bound assumed for the input words (64).
pub input_bit_size: u64,
/// Tracked bit-size bound of the transform output.
pub output_bit_size: u64,
_phantom: PhantomData<P>,
}
/// Precomputed data for the inverse NTT (`intt_ref`) over the four primes of `P`.
pub struct NttTableInv<P: PrimeSet> {
/// Transform size (a power of two, at most 2^16, as enforced by `new`).
pub n: usize,
/// For `n > 1`, the levels in order:
/// - the span-2 butterfly level;
/// - one entry per butterfly level of span `4, 8, ..., n`;
/// - the final level, which scales coefficient `i` by `n^-1 * omega^-i`.
///
/// Empty when `n == 1`.
pub level_metadata: Vec<NttStepMeta>,
/// Packed twiddles (format of `pack_omega`), four consecutive words per power,
/// one per prime. For `n > 1`: `span/2 - 1` inverse twiddles for each span
/// `4, ..., n`, followed by the `n` final scaling factors.
pub powomega: Vec<u64>,
/// Constants for the lazy reduction used on levels whose `reduce` flag is set.
pub reduc_metadata: NttReducMeta,
/// Bit-size bound assumed for the input words (64).
pub input_bit_size: u64,
/// Tracked bit-size bound of the transform output.
pub output_bit_size: u64,
_phantom: PhantomData<P>,
}
/// Computes `x^n mod q` by square-and-multiply.
///
/// The exponent is first reduced modulo `q - 1` (Fermat's little theorem), so a
/// negative `n` yields powers of the modular inverse. That reduction is only
/// valid for prime `q`, which is how this module uses it (`P::Q[k]`).
///
/// A base that is `0 mod q` returns `1` for `n == 0` and `0` for every other `n`.
/// For a negative `n` this matches the Fermat-style "inverse" of zero.
/// A modulus of `1` always yields `0`.
///
/// # Panics
///
/// Panics if `q == 0`.
pub fn modq_pow(x: u32, n: i64, q: u32) -> u32 {
    assert!(q != 0, "modq_pow: modulus must be non-zero");
    if q == 1 {
        return 0;
    }
    let q64 = q as u64;
    let base = x as u64 % q64;
    if base == 0 {
        // Reducing the exponent mod q-1 would map n = q-1 (or any non-zero
        // multiple) to 0 and wrongly return 1; 0^n is 0 for every n != 0.
        return u32::from(n == 0);
    }
    // Euclidean remainder keeps the exponent in [0, q-1) for negative n too.
    let mut exp = n.rem_euclid(i64::from(q) - 1) as u64;
    let mut val_pow = base;
    let mut res: u64 = 1;
    // Both factors are < q < 2^32, so every product fits in a u64.
    while exp != 0 {
        if exp & 1 != 0 {
            res = res * val_pow % q64;
        }
        val_pow = val_pow * val_pow % q64;
        exp >>= 1;
    }
    res as u32
}
/// Returns, for each of the four primes of `P`, `OMEGA[k]^(2^16 / n) mod Q[k]`.
///
/// This is the per-prime root used by the size-`n` transform tables.
/// `n` must be a power of two in `[1, 2^16]`: the exponent `2^16 / n` uses
/// integer division, so any other `n` would silently produce the wrong root.
fn fill_omegas<P: PrimeSet>(n: usize) -> [u32; 4] {
    debug_assert!(
        n.is_power_of_two() && n <= (1 << 16),
        "n must be a power of two in [1, 2^16], got {n}"
    );
    let exponent = (1i64 << 16) / n as i64;
    std::array::from_fn(|k| modq_pow(P::OMEGA[k], exponent, P::Q[k]))
}
/// Chooses the split point `h` for the lazy reduction `modq_red` applied to
/// values of at most `bs_start` bits.
///
/// For each candidate `h` in `[bs_start / 2, bs_start)`, the reduced value
/// `(x & (2^h - 1)) + (x >> h) * (2^h mod Q[k])` has, per prime, at most
/// `1 + max(h, bs_start - h + ceil_log2(2^h mod Q[k]))` bits. The worst case
/// over the four primes is the cost of `h`. The first `h` with minimal cost is
/// kept.
///
/// Returns the reduction constants together with the resulting bit-size bound.
/// The bound is `u64::MAX` if the candidate range is empty.
fn fill_reduction_meta<P: PrimeSet>(bs_start: u64) -> (NttReducMeta, u64) {
    let first_h = bs_start / 2;

    // Worst-case output bit size over the four primes for a given split `h`.
    let worst_bits = |h: u64| -> u64 {
        (0..4)
            .map(|k| {
                let cst_bits = ceil_log2_u64(pow2_mod(h, P::Q[k] as u64));
                1 + (bs_start - h + cst_bits).max(h)
            })
            .max()
            .unwrap_or(0)
    };

    // `min_by_key` returns the first of several equal minima.
    let (h, bs_after_reduc) = (first_h..bs_start)
        .map(|h| (h, worst_bits(h)))
        .min_by_key(|&(_, bits)| bits)
        .unwrap_or((first_h, u64::MAX));

    let mask = (1u64 << h) - 1;
    let modulo_red_cst: [u64; 4] = std::array::from_fn(|k| pow2_mod(h, P::Q[k] as u64));
    let reduc = NttReducMeta {
        modulo_red_cst,
        mask,
        h,
    };
    (reduc, bs_after_reduc)
}
/// Packs a twiddle `t` (< 2^32) for `split_precompmul`.
///
/// The low 32 bits hold `t`. The high 32 bits hold `t * 2^half_bs mod q`, the
/// factor applied to the high part of a split input.
#[inline(always)]
fn pack_omega(t: u64, half_bs: u64, q: u64) -> u64 {
    let shifted_residue = (t << half_bs) % q;
    t | (shifted_residue << 32)
}
impl<P: PrimeSet> NttTable<P> {
    /// Builds the forward-NTT table for a transform of size `n`.
    ///
    /// The table tracks, level by level, an upper bound on the bit size of the
    /// intermediate values. This lets `ntt_ref` work on unreduced 64-bit words,
    /// inserting a partial reduction (`modq_red`) only when a level's input may
    /// occupy all 64 bits.
    ///
    /// Layout of `powomega` (four words per entry, one per prime):
    /// * `n` packed factors `omega^i`, `i = 0..n` (level-0 pre-multiplication);
    /// * for each butterfly level of span `nn = n, n/2, ..., 4`, the `nn/2 - 1`
    ///   packed factors `omega_m^i`, `i = 1..nn/2`, with `omega_m = omega^(2n/nn)`.
    ///
    /// For `n == 1` the transform is the identity; the table has no levels and
    /// no twiddles.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a power of two `<= 2^16`, or if a level's bit-size
    /// bound would exceed 64 bits (i.e. `P::LOG_Q` is too large).
    pub fn new(n: usize) -> Self {
        assert!(
            n.is_power_of_two() && n <= (1 << 16),
            "NTT size must be a power of two ≤ 2^16, got {n}"
        );
        let omega_vec = fill_omegas::<P>(n);
        let log_q = P::LOG_Q;

        // Inputs are arbitrary 64-bit words.
        let mut bs = 64u64;
        let (reduc_metadata, bs_after_reduc) = fill_reduction_meta::<P>(bs);
        let input_bit_size = bs;

        // 8n words always suffice: 4n for level 0 plus 4(n - 1 - log2 n) for the
        // butterfly levels. The unused tail is truncated at the end.
        let mut powomega: Vec<u64> = alloc_aligned::<u64>(4 * 2 * n.max(1));
        let mut po_ptr = 0usize;
        let mut level_metadata: Vec<NttStepMeta> = Vec::new();

        if n == 1 {
            // Identity transform: no twiddle is ever read, so do not expose the
            // scratch words as table entries (the main path truncates likewise).
            powomega.clear();
            return Self {
                n,
                level_metadata,
                powomega,
                reduc_metadata,
                input_bit_size,
                output_bit_size: bs,
                _phantom: PhantomData,
            };
        }

        // Level 0: multiply coefficient i by omega^i. The 64-bit input is split at
        // half_bs for `split_precompmul`, whose result needs half_bs + LOG_Q + 1 bits.
        let half_bs_0 = bs.div_ceil(2);
        bs = half_bs_0 + log_q + 1;
        assert!(bs <= 64, "NTT bit-size overflow at first level");
        level_metadata.push(NttStepMeta {
            q2bs: [0; 4],
            bs,
            half_bs: half_bs_0,
            mask: (1u64 << half_bs_0) - 1,
            reduce: false,
        });
        let mut pow_om: [u64; 4] = [1; 4];
        for i in 0..n {
            for k in 0..4 {
                let q = P::Q[k] as u64;
                powomega[po_ptr + 4 * i + k] = pack_omega(pow_om[k], half_bs_0, q);
            }
            for k in 0..4 {
                pow_om[k] = (pow_om[k] * omega_vec[k] as u64) % P::Q[k] as u64;
            }
        }
        po_ptr += 4 * n;

        // Butterfly levels of span nn = n, n/2, ..., 2.
        let mut nn = n;
        while nn >= 2 {
            let halfnn = nn / 2;
            // Reduce lazily, only once the inputs may use all 64 bits.
            let do_reduce = bs == 64;
            if do_reduce {
                bs = bs_after_reduc;
            }
            // Multiple of Q added before subtracting so that a - b stays non-negative.
            let q2bs: [u64; 4] = std::array::from_fn(|k| (P::Q[k] as u64) << (bs - log_q));
            let (new_bs, half_bs, mask) = if nn >= 4 {
                // a + b gains one bit; (a - b) * omega is bounded by the split product.
                let bs1 = bs + 1;
                let half_bs = bs1.div_ceil(2);
                let bs2 = half_bs + log_q + 1;
                (bs1.max(bs2), half_bs, (1u64 << half_bs) - 1)
            } else {
                // Last level (nn == 2): a single twiddle-free add/sub.
                (bs + 1, 0, 0)
            };
            assert!(new_bs <= 64, "NTT bit-size overflow at level nn={nn}");
            level_metadata.push(NttStepMeta {
                q2bs,
                bs: new_bs,
                half_bs,
                mask,
                reduce: do_reduce,
            });
            bs = new_bs;

            if halfnn > 1 {
                // Twiddles omega_m^i for i = 1..halfnn, with omega_m = omega^(n / halfnn).
                let m = n / halfnn;
                let omega_m: [u64; 4] =
                    std::array::from_fn(|k| modq_pow(omega_vec[k], m as i64, P::Q[k]) as u64);
                let mut pow_om: [u64; 4] = omega_m;
                for i in 0..halfnn - 1 {
                    for k in 0..4 {
                        let q = P::Q[k] as u64;
                        powomega[po_ptr + 4 * i + k] = pack_omega(pow_om[k], half_bs, q);
                    }
                    for k in 0..4 {
                        pow_om[k] = (pow_om[k] * omega_m[k]) % P::Q[k] as u64;
                    }
                }
                po_ptr += 4 * (halfnn - 1);
            }
            nn /= 2;
        }

        powomega.truncate(po_ptr);
        Self {
            n,
            level_metadata,
            powomega,
            reduc_metadata,
            input_bit_size,
            output_bit_size: bs,
            _phantom: PhantomData,
        }
    }
}
impl<P: PrimeSet> NttTableInv<P> {
    /// Builds the inverse-NTT table for a transform of size `n`.
    ///
    /// The levels, in order:
    /// * a span-2 butterfly level;
    /// * butterfly levels of span `nn = 4, 8, ..., n`, where the twiddle
    ///   multiplies the second input before the add/sub;
    /// * a final level that multiplies coefficient `i` by `n^-1 * omega^-i`.
    ///
    /// Layout of `powomega` (four words per entry, one per prime):
    /// * for each span `nn = 4, ..., n`, the `nn/2 - 1` packed factors
    ///   `omega_m^-i`, `i = 1..nn/2`, with `omega_m = omega^(2n/nn)`;
    /// * `n` packed factors `n^-1 * omega^-i`, `i = 0..n`.
    ///
    /// For `n == 1` the transform is the identity; the table has no levels and
    /// no twiddles.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a power of two `<= 2^16`, or if a level's bit-size
    /// bound would exceed 64 bits.
    pub fn new(n: usize) -> Self {
        assert!(
            n.is_power_of_two() && n <= (1 << 16),
            "iNTT size must be a power of two ≤ 2^16, got {n}"
        );
        let omega_vec = fill_omegas::<P>(n);
        let log_q = P::LOG_Q;

        // Inputs are arbitrary 64-bit words.
        let mut bs = 64u64;
        let (reduc_metadata, bs_after_reduc) = fill_reduction_meta::<P>(bs);
        let input_bit_size = bs;

        // 8n words always suffice; the unused tail is truncated at the end.
        let mut powomega: Vec<u64> = alloc_aligned::<u64>(4 * 2 * n.max(1));
        let mut po_ptr = 0usize;
        let mut level_metadata: Vec<NttStepMeta> = Vec::new();

        if n == 1 {
            // Identity transform: no twiddle is ever read, so do not expose the
            // scratch words as table entries (the main path truncates likewise).
            powomega.clear();
            return Self {
                n,
                level_metadata,
                powomega,
                reduc_metadata,
                input_bit_size,
                output_bit_size: bs,
                _phantom: PhantomData,
            };
        }

        // First level: span-2 butterflies (twiddle 1), inputs reduced if full-width.
        {
            let do_reduce = bs == 64;
            if do_reduce {
                bs = bs_after_reduc;
            }
            let q2bs: [u64; 4] = std::array::from_fn(|k| (P::Q[k] as u64) << (bs - log_q));
            let new_bs = bs + 1;
            level_metadata.push(NttStepMeta {
                q2bs,
                bs: new_bs,
                half_bs: 0,
                mask: 0,
                reduce: do_reduce,
            });
            bs = new_bs;
        }

        // Butterfly levels of span nn = 4, 8, ..., n.
        let mut nn = 4usize;
        while nn <= n {
            let halfnn = nn / 2;
            let do_reduce = bs == 64;
            if do_reduce {
                bs = bs_after_reduc;
            }
            let half_bs = bs.div_ceil(2);
            // Bit size of b * omega^-i produced by split_precompmul.
            let bs_mult = half_bs + log_q + 1;
            let new_bs = 1 + bs.max(bs_mult);
            assert!(new_bs <= 64, "iNTT bit-size overflow at level nn={nn}");
            // The subtracted term is the product, so the offset is sized from bs_mult.
            let q2bs: [u64; 4] = std::array::from_fn(|k| (P::Q[k] as u64) << (bs_mult - log_q));
            let mask = (1u64 << half_bs) - 1;
            level_metadata.push(NttStepMeta {
                q2bs,
                bs: new_bs,
                half_bs,
                mask,
                reduce: do_reduce,
            });
            bs = new_bs;

            // Twiddles omega^(-m*i) for i = 1..halfnn, with m = n / halfnn.
            let m = n / halfnn;
            let omega_inv_m: [u64; 4] =
                std::array::from_fn(|k| modq_pow(omega_vec[k], -(m as i64), P::Q[k]) as u64);
            let mut pow_om: [u64; 4] = omega_inv_m;
            for i in 0..halfnn - 1 {
                for k in 0..4 {
                    let q = P::Q[k] as u64;
                    powomega[po_ptr + 4 * i + k] = pack_omega(pow_om[k], half_bs, q);
                }
                for k in 0..4 {
                    pow_om[k] = (pow_om[k] * omega_inv_m[k]) % P::Q[k] as u64;
                }
            }
            po_ptr += 4 * (halfnn - 1);
            nn *= 2;
        }

        // Final level: multiply coefficient i by n^-1 * omega^-i.
        {
            let do_reduce = bs == 64;
            if do_reduce {
                bs = bs_after_reduc;
            }
            let half_bs = bs.div_ceil(2);
            let new_bs = half_bs + log_q + 1;
            assert!(new_bs <= 64, "iNTT bit-size overflow at last level");
            let q2bs: [u64; 4] = std::array::from_fn(|k| (P::Q[k] as u64) << (new_bs - log_q));
            let mask = (1u64 << half_bs) - 1;
            level_metadata.push(NttStepMeta {
                q2bs,
                bs: new_bs,
                half_bs,
                mask,
                reduce: do_reduce,
            });
            bs = new_bs;
            for k in 0..4 {
                let q = P::Q[k] as u64;
                let inv_n = modq_pow(n as u32, -1, P::Q[k]) as u64;
                let omega_inv = modq_pow(omega_vec[k], -1, P::Q[k]) as u64;
                let mut pow_om = inv_n;
                for i in 0..n {
                    powomega[po_ptr + 4 * i + k] = pack_omega(pow_om, half_bs, q);
                    pow_om = (pow_om * omega_inv) % q;
                }
            }
            po_ptr += 4 * n;
        }

        powomega.truncate(po_ptr);
        Self {
            n,
            level_metadata,
            powomega,
            reduc_metadata,
            input_bit_size,
            output_bit_size: bs,
            _phantom: PhantomData,
        }
    }
}
#[inline(always)]
pub fn split_precompmul(inp: u64, powomega_packed: u64, half_bs: u64, mask: u64) -> u64 {
let inp_low = inp & mask;
let t = powomega_packed & 0xFFFF_FFFF; let t1 = powomega_packed >> 32; inp_low.wrapping_mul(t).wrapping_add((inp >> half_bs).wrapping_mul(t1))
}
/// Partial modular reduction: returns `(x & mask) + (x >> h) * cst` with
/// wrapping arithmetic.
///
/// With `mask = 2^h - 1` and `cst = 2^h mod q`, the result is congruent to `x`
/// modulo `q` (as long as it does not wrap) and usually much smaller than `x`.
#[inline(always)]
pub fn modq_red(x: u64, h: u64, mask: u64, cst: u64) -> u64 {
    let high = x >> h;
    let low = x & mask;
    high.wrapping_mul(cst).wrapping_add(low)
}
/// Forward NTT (reference implementation) of `table.n` coefficients, in place.
///
/// `data` interleaves the primes: word `4 * i + k` is coefficient `i` for
/// prime `P::Q[k]`. Results are not fully reduced. They are only meaningful
/// modulo `Q[k]`, and their tracked bit-size bound is `table.output_bit_size`.
pub fn ntt_ref<P: PrimeSet>(table: &NttTable<P>, data: &mut [u64]) {
    let n = table.n;
    // A size-1 transform is the identity.
    if n == 1 {
        return;
    }
    debug_assert!(data.len() >= 4 * n);

    let levels = &table.level_metadata;
    let twiddles = &table.powomega;

    // Level 0: scale coefficient i by the i-th packed power stored at the start
    // of the twiddle table.
    let first = &levels[0];
    for i in 0..n {
        for k in 0..4 {
            let idx = 4 * i + k;
            data[idx] = split_precompmul(data[idx], twiddles[idx], first.half_bs, first.mask);
        }
    }
    let mut po_off = 4 * n;

    // Butterfly levels of span n, n/2, ..., 2; each consumes span/2 - 1 twiddles.
    let mut level = 1usize;
    let mut nn = n;
    while nn >= 2 {
        let halfnn = nn / 2;
        let meta = &levels[level];
        for blk in (0..n).step_by(nn) {
            ntt_butterfly_block(data, blk, halfnn, meta, &table.reduc_metadata, twiddles, po_off);
        }
        po_off += 4 * halfnn.saturating_sub(1);
        level += 1;
        nn /= 2;
    }
}
/// Inverse NTT (reference implementation) of `table.n` coefficients, in place.
///
/// Same interleaved layout as `ntt_ref`: word `4 * i + k` is coefficient `i`
/// for prime `P::Q[k]`. The final level also applies the `n^-1` scaling.
/// Results are only meaningful modulo `Q[k]`.
pub fn intt_ref<P: PrimeSet>(table: &NttTableInv<P>, data: &mut [u64]) {
    let n = table.n;
    // A size-1 transform is the identity.
    if n == 1 {
        return;
    }
    debug_assert!(data.len() >= 4 * n);

    let reduc = &table.reduc_metadata;
    let twiddles = &table.powomega;
    let levels = &table.level_metadata;
    let log_n = n.trailing_zeros() as usize;

    // Level 0: span-2 butterflies, no twiddles consumed.
    for blk in (0..n).step_by(2) {
        intt_butterfly_block(data, blk, 1, &levels[0], reduc, twiddles, 0);
    }

    // Levels of span 2^j for j = 2..=log_n, stored at index j - 1.
    let mut po_off = 0usize;
    for j in 2..=log_n {
        let nn = 1usize << j;
        let halfnn = nn / 2;
        let meta = &levels[j - 1];
        for blk in (0..n).step_by(nn) {
            intt_butterfly_block(data, blk, halfnn, meta, reduc, twiddles, po_off);
        }
        po_off += 4 * (halfnn - 1);
    }

    // Final level: optional reduction, then scale coefficient i by its packed factor.
    let last = &levels[log_n];
    for i in 0..n {
        for k in 0..4 {
            let idx = 4 * i + k;
            let x = if last.reduce {
                modq_red(data[idx], reduc.h, reduc.mask, reduc.modulo_red_cst[k])
            } else {
                data[idx]
            };
            data[idx] = split_precompmul(x, twiddles[po_off + idx], last.half_bs, last.mask);
        }
    }
}
/// Applies one forward (decimation-in-frequency) butterfly block in place.
///
/// For `i` in `0..halfnn`, the pair at positions `blk + i` and
/// `blk + halfnn + i` becomes `(a + b, (a + q2bs - b) * w_i)`. Here `w_0 = 1`,
/// and `w_i` is the packed twiddle at `powomega[po_off + 4 * (i - 1) + k]`.
/// Inputs are first passed through `modq_red` when `meta.reduce` is set.
#[inline(always)]
fn ntt_butterfly_block(
    data: &mut [u64],
    blk: usize,
    halfnn: usize,
    meta: &NttStepMeta,
    reduc: &NttReducMeta,
    powomega: &[u64],
    po_off: usize,
) {
    let (h, mask) = (meta.half_bs, meta.mask);
    // Lazy reduction of one input word, if this level requests it.
    let load = |word: u64, k: usize| -> u64 {
        if meta.reduce {
            modq_red(word, reduc.h, reduc.mask, reduc.modulo_red_cst[k])
        } else {
            word
        }
    };

    // Pair 0 has twiddle 1: plain add and offset subtract.
    for k in 0..4 {
        let lo = 4 * blk + k;
        let hi = 4 * (blk + halfnn) + k;
        let a = load(data[lo], k);
        let b = load(data[hi], k);
        data[lo] = a.wrapping_add(b);
        data[hi] = a.wrapping_add(meta.q2bs[k]).wrapping_sub(b);
    }

    // Remaining pairs: the difference is multiplied by twiddle i.
    for i in 1..halfnn {
        let tw_base = po_off + 4 * (i - 1);
        for k in 0..4 {
            let lo = 4 * (blk + i) + k;
            let hi = 4 * (blk + halfnn + i) + k;
            let a = load(data[lo], k);
            let b = load(data[hi], k);
            data[lo] = a.wrapping_add(b);
            let diff = a.wrapping_add(meta.q2bs[k]).wrapping_sub(b);
            data[hi] = split_precompmul(diff, powomega[tw_base + k], h, mask);
        }
    }
}
/// Applies one inverse (decimation-in-time) butterfly block in place.
///
/// For `i` in `0..halfnn`, with `bo = b * w_i`, the pair at positions
/// `blk + i` and `blk + halfnn + i` becomes `(a + bo, a + q2bs - bo)`. Here
/// `w_0 = 1`, and `w_i` is the packed twiddle at
/// `powomega[po_off + 4 * (i - 1) + k]`. Inputs are first passed through
/// `modq_red` when `meta.reduce` is set.
#[inline(always)]
fn intt_butterfly_block(
    data: &mut [u64],
    blk: usize,
    halfnn: usize,
    meta: &NttStepMeta,
    reduc: &NttReducMeta,
    powomega: &[u64],
    po_off: usize,
) {
    let (h, mask) = (meta.half_bs, meta.mask);
    // Lazy reduction of one input word, if this level requests it.
    let load = |word: u64, k: usize| -> u64 {
        if meta.reduce {
            modq_red(word, reduc.h, reduc.mask, reduc.modulo_red_cst[k])
        } else {
            word
        }
    };

    // Pair 0 has twiddle 1: plain add and offset subtract.
    for k in 0..4 {
        let lo = 4 * blk + k;
        let hi = 4 * (blk + halfnn) + k;
        let a = load(data[lo], k);
        let b = load(data[hi], k);
        data[lo] = a.wrapping_add(b);
        data[hi] = a.wrapping_add(meta.q2bs[k]).wrapping_sub(b);
    }

    // Remaining pairs: the second input is multiplied by twiddle i before the add/sub.
    for i in 1..halfnn {
        let tw_base = po_off + 4 * (i - 1);
        for k in 0..4 {
            let lo = 4 * (blk + i) + k;
            let hi = 4 * (blk + halfnn + i) + k;
            let a = load(data[lo], k);
            let b = split_precompmul(load(data[hi], k), powomega[tw_base + k], h, mask);
            data[lo] = a.wrapping_add(b);
            data[hi] = a.wrapping_add(meta.q2bs[k]).wrapping_sub(b);
        }
    }
}
use super::pow2_mod;
/// Returns the smallest `e` such that `2^e >= x`, or `0` when `x <= 1`.
fn ceil_log2_u64(x: u64) -> u64 {
    match x {
        0 | 1 => 0,
        // For x >= 2, the bit length of x - 1 is exactly ceil(log2(x)).
        _ => u64::from(64 - (x - 1).leading_zeros()),
    }
}
// Unit tests for the reference NTT over the Primes30 prime set.
#[cfg(test)]
mod tests {
use super::*;
use crate::reference::ntt120::{
arithmetic::{b_from_znx64_ref, b_to_znx128_ref},
primes::Primes30,
};
/// Forward NTT followed by inverse NTT must give back the input, modulo each
/// prime, for every size 2^1 ..= 2^8.
#[test]
fn ntt_intt_identity() {
for log_n in 1..=8usize {
let n = 1 << log_n;
let fwd = NttTable::<Primes30>::new(n);
let inv = NttTableInv::<Primes30>::new(n);
// Small signed coefficients in [-100, 100].
let coeffs: Vec<i64> = (0..n as i64).map(|i| (i * 7 + 3) % 201 - 100).collect();
let mut data = vec![0u64; 4 * n];
b_from_znx64_ref::<Primes30>(n, &mut data, &coeffs);
let data_orig = data.clone();
ntt_ref::<Primes30>(&fwd, &mut data);
intt_ref::<Primes30>(&inv, &mut data);
// The transforms leave words unreduced, so compare residues modulo Q[k]
// rather than raw words.
for i in 0..n {
for k in 0..4 {
let orig = data_orig[4 * i + k] % Primes30::Q[k] as u64;
let got = data[4 * i + k] % Primes30::Q[k] as u64;
assert_eq!(orig, got, "n={n} i={i} k={k}: mismatch after NTT+iNTT round-trip");
}
}
}
}
/// A pointwise product in the NTT domain must equal the polynomial product:
/// (1 + 2x)(3 + 4x) = 3 + 10x + 8x^2, whose degree is below n = 8, so no
/// wrap-around term appears.
#[test]
fn ntt_convolution() {
let n = 8usize;
let fwd = NttTable::<Primes30>::new(n);
let inv = NttTableInv::<Primes30>::new(n);
let a: Vec<i64> = [1, 2, 0, 0, 0, 0, 0, 0].to_vec();
let b: Vec<i64> = [3, 4, 0, 0, 0, 0, 0, 0].to_vec();
let mut da = vec![0u64; 4 * n];
let mut db = vec![0u64; 4 * n];
b_from_znx64_ref::<Primes30>(n, &mut da, &a);
b_from_znx64_ref::<Primes30>(n, &mut db, &b);
ntt_ref::<Primes30>(&fwd, &mut da);
ntt_ref::<Primes30>(&fwd, &mut db);
// Pointwise product; operands are reduced first so the product fits in u64.
let mut dc = vec![0u64; 4 * n];
for i in 0..n {
for k in 0..4 {
let q = Primes30::Q[k] as u64;
dc[4 * i + k] = (da[4 * i + k] % q * (db[4 * i + k] % q)) % q;
}
}
intt_ref::<Primes30>(&inv, &mut dc);
let mut result = vec![0i128; n];
b_to_znx128_ref::<Primes30>(n, &mut result, &dc);
let expected: Vec<i128> = [3, 10, 8, 0, 0, 0, 0, 0].to_vec();
assert_eq!(result, expected, "NTT convolution mismatch");
}
}