use super::mod_q::DEFAULT_Q;
/// Precomputed tables for a Montgomery-form number-theoretic transform over one or
/// more moduli (one "limb" per modulus).
///
/// All per-modulus vectors are indexed in the same order as `moduli`.
#[derive(Clone)]
pub struct NttContext {
/// Transform dimension; the constructor asserts it is a power of two.
n: usize,
/// The moduli, one per limb, in the order given to the constructor.
moduli: Vec<u64>,
/// Per modulus `q`: the negated inverse of `q` modulo 2^64, used by Montgomery reduction.
q_inv_neg: Vec<u64>,
/// Per modulus `q`: `(2^64 mod q)^2 mod q`, used to convert values into Montgomery form.
r_squared: Vec<u64>,
/// Per modulus: forward twiddle table (Montgomery form) built from a primitive 2n-th root `psi`.
psi_powers: Vec<Vec<u64>>,
/// Per modulus: inverse twiddle table (Montgomery form) built from `psi^(q-2) mod q`.
psi_inv_powers: Vec<Vec<u64>>,
/// Per modulus: `n^(q-2) mod q` in Montgomery form (the inverse of `n` when `q` is prime).
n_inv: Vec<u64>,
}
impl NttContext {
/// Creates a context for dimension `n` with the single modulus `q`.
///
/// Delegates to `with_moduli` with a one-element modulus list, so the same
/// assertions on `n` and `q` apply.
pub fn new(n: usize, q: u64) -> Self {
Self::with_moduli(n, &[q])
}
/// Creates a context for dimension `n` with one table set per entry of `moduli`.
///
/// For every modulus `q` this precomputes the Montgomery constants, a primitive
/// 2n-th root of unity `psi`, the forward and inverse twiddle tables, and the
/// Montgomery form of `n^(q-2) mod q`. Inverses are obtained by raising to the
/// power `q - 2`, so each `q` is expected to be prime.
///
/// # Panics
///
/// Panics if `n` is not a power of two, if `moduli` is empty, if any `q` is not
/// congruent to 1 modulo `2n`, if any `q` is `2^63` or larger, or if no primitive
/// 2n-th root of unity can be found for some `q`.
pub fn with_moduli(n: usize, moduli: &[u64]) -> Self {
    assert!(n.is_power_of_two(), "n must be a power of two");
    assert!(!moduli.is_empty(), "moduli must be non-empty");

    let two_n = 2 * n as u64;
    let limbs = moduli.len();
    let mut q_inv_neg = Vec::with_capacity(limbs);
    let mut r_squared = Vec::with_capacity(limbs);
    let mut psi_powers = Vec::with_capacity(limbs);
    let mut psi_inv_powers = Vec::with_capacity(limbs);
    let mut n_inv = Vec::with_capacity(limbs);

    for &q in moduli {
        assert!(q % two_n == 1, "q must be ≡ 1 (mod 2n)");
        // The butterflies form `u + v` in u64 and the Montgomery reduction forms
        // `a * b + m * q` in u128. Both stay overflow-free only while q < 2^63;
        // larger moduli would wrap silently in release builds.
        assert!(q < (1u64 << 63), "q must be less than 2^63");

        let q_inv = Self::compute_q_inv_neg(q);
        let r2 = Self::compute_r_squared(q);
        let to_mont = |x: u64| Self::to_montgomery(x, q, r2, q_inv);

        let psi = Self::find_primitive_root(two_n, q);
        // Fermat inverses: x^(q-2) is x^-1 modulo a prime q.
        let psi_inv = Self::mod_pow(psi, q - 2, q);
        let n_inv_val = Self::mod_pow(n as u64, q - 2, q);

        q_inv_neg.push(q_inv);
        r_squared.push(r2);
        psi_powers.push(Self::compute_twiddle_factors(n, to_mont(psi), q, q_inv, r2));
        psi_inv_powers.push(Self::compute_twiddle_factors(n, to_mont(psi_inv), q, q_inv, r2));
        n_inv.push(to_mont(n_inv_val));
    }

    Self {
        n,
        moduli: moduli.to_vec(),
        q_inv_neg,
        r_squared,
        psi_powers,
        psi_inv_powers,
        n_inv,
    }
}
/// Creates a single-modulus context for dimension `n` using `DEFAULT_Q`
/// (imported from `super::mod_q`).
pub fn with_default_q(n: usize) -> Self {
Self::new(n, DEFAULT_Q)
}
/// Returns the transform dimension `n` this context was built with.
pub fn dimension(&self) -> usize {
self.n
}
/// Returns the product of all moduli.
///
/// The multiplication saturates, so the result is `u64::MAX` whenever the true
/// product does not fit in a `u64`.
pub fn modulus(&self) -> u64 {
    let mut product = 1u64;
    for &m in &self.moduli {
        product = product.saturating_mul(m);
    }
    product
}
/// Returns the moduli in the order they were supplied at construction.
pub fn moduli(&self) -> &[u64] {
&self.moduli
}
/// Returns the number of moduli (limbs) held by this context.
pub fn crt_count(&self) -> usize {
self.moduli.len()
}
/// Converts `coeffs` to Montgomery form and applies the forward transform in place.
///
/// `coeffs` holds `crt_count()` consecutive limbs of `dimension()` values each;
/// limb `idx` is processed under modulus `idx`. The output is left in Montgomery
/// form.
///
/// # Panics
///
/// Panics if `coeffs.len() != dimension() * crt_count()`.
pub fn forward(&self, coeffs: &mut [u64]) {
    let expected_len = self.n * self.crt_count();
    assert_eq!(
        coeffs.len(),
        expected_len,
        "Input length must match dimension * crt_count"
    );
    for (idx, limb) in coeffs.chunks_exact_mut(self.n).enumerate() {
        let q = self.moduli[idx];
        let r2 = self.r_squared[idx];
        let q_inv = self.q_inv_neg[idx];
        for value in limb.iter_mut() {
            *value = Self::to_montgomery_at(*value, q, r2, q_inv);
        }
        self.forward_inplace_at(limb, idx);
    }
}
/// Applies the forward transform to every limb of `coeffs` in place.
///
/// Unlike `forward`, no Montgomery conversion is performed on the data first.
///
/// # Panics
///
/// Panics if `coeffs.len() != dimension() * crt_count()`.
pub fn forward_inplace(&self, coeffs: &mut [u64]) {
    let expected_len = self.n * self.crt_count();
    assert_eq!(
        coeffs.len(),
        expected_len,
        "Input length must match dimension * crt_count"
    );
    for (idx, limb) in coeffs.chunks_exact_mut(self.n).enumerate() {
        self.forward_inplace_at(limb, idx);
    }
}
/// Runs the forward butterfly network on one limb under modulus `idx`.
///
/// Each stage doubles the number of groups and halves their width. Group `g` of
/// a stage with `groups` groups uses twiddle `psi_powers[idx][groups + g]`, and
/// pairs element `j` of its lower half with element `j` of its upper half.
fn forward_inplace_at(&self, coeffs: &mut [u64], idx: usize) {
    let q = self.moduli[idx];
    let twiddles = &self.psi_powers[idx];
    let mut half = self.n;
    let mut groups = 1;
    while groups < self.n {
        half >>= 1;
        for (g, &w) in twiddles[groups..2 * groups].iter().enumerate() {
            let base = 2 * g * half;
            let (lo, hi) = coeffs[base..base + 2 * half].split_at_mut(half);
            for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                let u = *x;
                let v = self.montgomery_mul_at(*y, w, idx);
                // Modular add and subtract, keeping both results in [0, q).
                *x = if u + v >= q { u + v - q } else { u + v };
                *y = if u >= v { u - v } else { q - v + u };
            }
        }
        groups <<= 1;
    }
}
/// Applies the inverse transform in place and converts the result out of
/// Montgomery form.
///
/// This runs `inverse_inplace` and then Montgomery-multiplies every value by 1
/// under its limb's modulus.
///
/// # Panics
///
/// Panics if `coeffs.len() != dimension() * crt_count()`.
pub fn inverse(&self, coeffs: &mut [u64]) {
    let expected_len = self.n * self.crt_count();
    assert_eq!(
        coeffs.len(),
        expected_len,
        "Input length must match dimension * crt_count"
    );
    self.inverse_inplace(coeffs);
    for (idx, limb) in coeffs.chunks_exact_mut(self.n).enumerate() {
        for value in limb.iter_mut() {
            *value = self.montgomery_mul_at(*value, 1, idx);
        }
    }
}
/// Applies the inverse transform (including the `n^-1` scaling) to every limb of
/// `coeffs` in place, without converting the data out of Montgomery form.
///
/// # Panics
///
/// Panics if `coeffs.len() != dimension() * crt_count()`.
pub fn inverse_inplace(&self, coeffs: &mut [u64]) {
    let expected_len = self.n * self.crt_count();
    assert_eq!(
        coeffs.len(),
        expected_len,
        "Input length must match dimension * crt_count"
    );
    for (idx, limb) in coeffs.chunks_exact_mut(self.n).enumerate() {
        self.inverse_inplace_at(limb, idx);
    }
}
/// Runs the inverse butterfly network on one limb under modulus `idx`, then
/// scales every value by the stored `n_inv[idx]`.
///
/// Each stage halves the number of groups and doubles their width. Group `g` of
/// a stage with `groups` groups uses twiddle `psi_inv_powers[idx][groups + g]`.
fn inverse_inplace_at(&self, coeffs: &mut [u64], idx: usize) {
    let q = self.moduli[idx];
    let twiddles = &self.psi_inv_powers[idx];
    let mut half = 1;
    let mut groups = self.n;
    while groups > 1 {
        groups >>= 1;
        for (g, &w) in twiddles[groups..2 * groups].iter().enumerate() {
            let base = 2 * g * half;
            let (lo, hi) = coeffs[base..base + 2 * half].split_at_mut(half);
            for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                let (u, v) = (*x, *y);
                *x = if u + v >= q { u + v - q } else { u + v };
                // The twiddle is applied to the difference, after the subtraction.
                let diff = if u >= v { u - v } else { q - v + u };
                *y = self.montgomery_mul_at(diff, w, idx);
            }
        }
        half <<= 1;
    }
    let scale = self.n_inv[idx];
    for value in coeffs.iter_mut() {
        *value = self.montgomery_mul_at(*value, scale, idx);
    }
}
/// Writes the element-wise Montgomery product of `a` and `b` into `result`.
///
/// All three slices use the limb layout of `forward`: `crt_count()` consecutive
/// limbs of `dimension()` values, with limb `idx` multiplied under modulus `idx`.
///
/// # Panics
///
/// Panics if any of the three slices has a length other than
/// `dimension() * crt_count()`.
pub fn pointwise_mul(&self, a: &[u64], b: &[u64], result: &mut [u64]) {
    let expected_len = self.n * self.crt_count();
    assert_eq!(
        a.len(),
        expected_len,
        "Input length must match dimension * crt_count"
    );
    assert_eq!(
        b.len(),
        expected_len,
        "Input length must match dimension * crt_count"
    );
    assert_eq!(
        result.len(),
        expected_len,
        "Input length must match dimension * crt_count"
    );
    let limbs = result
        .chunks_exact_mut(self.n)
        .zip(a.chunks_exact(self.n))
        .zip(b.chunks_exact(self.n))
        .enumerate();
    for (idx, ((out, lhs), rhs)) in limbs {
        for ((dst, &x), &y) in out.iter_mut().zip(lhs).zip(rhs) {
            *dst = self.montgomery_mul_at(x, y, idx);
        }
    }
}
/// Montgomery-multiplies two values under the first modulus (index 0).
#[inline]
pub fn pointwise_mul_single(&self, a: u64, b: u64) -> u64 {
    self.pointwise_mul_single_at(a, b, 0)
}
/// Montgomery-multiplies two values under the modulus at index `idx`.
///
/// Panics if `idx` is not a valid modulus index.
#[inline]
pub fn pointwise_mul_single_at(&self, a: u64, b: u64, idx: usize) -> u64 {
self.montgomery_mul_at(a, b, idx)
}
/// Converts `a` into Montgomery form under the first modulus (index 0).
pub fn to_mont(&self, a: u64) -> u64 {
Self::to_montgomery(a, self.moduli[0], self.r_squared[0], self.q_inv_neg[0])
}
/// Converts `a` out of Montgomery form under the first modulus (index 0).
///
/// Implemented as a Montgomery multiplication by 1, which divides out the
/// `2^64` factor.
pub fn from_mont(&self, a: u64) -> u64 {
self.montgomery_mul_at(a, 1, 0)
}
/// Montgomery product of `a` and `b` under the modulus at index `idx`:
/// `a * b * 2^-64 mod q`, with a single conditional subtraction at the end.
fn montgomery_mul_at(&self, a: u64, b: u64, idx: usize) -> u64 {
    let q = self.moduli[idx];
    let q_inv_neg = self.q_inv_neg[idx];
    let product = u128::from(a) * u128::from(b);
    // Choose `k` so that `product + k * q` is divisible by 2^64.
    let k = u128::from((product as u64).wrapping_mul(q_inv_neg));
    let reduced = ((product + k * u128::from(q)) >> 64) as u64;
    reduced.checked_sub(q).unwrap_or(reduced)
}
/// Converts `a` into Montgomery form for modulus `q`.
///
/// This is a Montgomery multiplication of `a` by `r_squared`
/// (`(2^64 mod q)^2 mod q`), giving `a * 2^64 mod q`. `q_inv_neg` is the
/// negated inverse of `q` modulo 2^64.
fn to_montgomery(a: u64, q: u64, r_squared: u64, q_inv_neg: u64) -> u64 {
    let product = u128::from(a) * u128::from(r_squared);
    // Choose `k` so that `product + k * q` is divisible by 2^64.
    let k = u128::from((product as u64).wrapping_mul(q_inv_neg));
    let reduced = ((product + k * u128::from(q)) >> 64) as u64;
    reduced.checked_sub(q).unwrap_or(reduced)
}
/// Thin alias for `to_montgomery` with identical arguments and result.
#[inline]
fn to_montgomery_at(a: u64, q: u64, r_squared: u64, q_inv_neg: u64) -> u64 {
Self::to_montgomery(a, q, r_squared, q_inv_neg)
}
/// Returns the negated inverse of `q` modulo 2^64, for odd `q`.
///
/// The inverse is built one bit at a time: whenever bit `i` of `y * q` is set,
/// bit `i` of `y` is set too, which clears that bit of the product because `q`
/// is odd. After 63 steps `y * q` is 1 modulo 2^64.
fn compute_q_inv_neg(q: u64) -> u64 {
    let inverse = (1..64).fold(1u64, |y, bit| {
        let pending = y.wrapping_mul(q) & (1u64 << bit);
        y | pending
    });
    inverse.wrapping_neg()
}
/// Returns `(2^64 mod q)^2 mod q`, the constant used to enter Montgomery form.
///
/// Panics if `q` is zero (division by zero).
fn compute_r_squared(q: u64) -> u64 {
    let modulus = u128::from(q);
    let r = (1u128 << 64) % modulus;
    let r_squared = (r * r) % modulus;
    r_squared as u64
}
/// Computes `base^exp mod m` by square-and-multiply, using `u128` intermediates
/// so the products cannot overflow.
///
/// The result is always in `[0, m)`, including the `m == 1` case where every
/// power is 0.
///
/// Panics if `m` is zero (division by zero).
fn mod_pow(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let modulus = u128::from(m);
    let mul_mod = |x: u64, y: u64| ((u128::from(x) * u128::from(y)) % modulus) as u64;

    // Start from `1 mod m` rather than 1 so that `exp == 0` with `m == 1` yields 0.
    let mut result = 1 % m;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mul_mod(result, base);
        }
        exp >>= 1;
        base = mul_mod(base, base);
    }
    result
}
/// Searches for an element of order `n` modulo `q`.
///
/// Candidates are `g^((q-1)/n) mod q` for `g = 2, 3, ...`. A candidate is
/// accepted when its `n`-th power is 1 and its `n/2`-th power is not. That
/// test pins the order to exactly `n` only when `n` is a power of two, which
/// holds for the `2n` passed by the constructor.
///
/// Panics if no candidate below `q` qualifies.
fn find_primitive_root(n: u64, q: u64) -> u64 {
    let cofactor = (q - 1) / n;
    (2..q)
        .map(|g| Self::mod_pow(g, cofactor, q))
        .find(|&root| Self::mod_pow(root, n, q) == 1 && Self::mod_pow(root, n / 2, q) != 1)
        .expect("No primitive root found (should not happen for valid parameters)")
}
/// Builds the length-`n` twiddle table for one modulus, in bit-reversed order.
///
/// `psi` must already be in Montgomery form, and every entry of the result is
/// too. For `1 <= m < n`, entry `m` is `psi` raised to the bit-reversal of `m`
/// over `log2(n)` bits: power-of-two indices hold `psi^(n / 2m)`, and every
/// other index is the product of the entries for its lowest set bit and for the
/// remaining bits. Entry 0 is left at 0; the butterflies only read indices >= 1.
///
/// `n` is expected to be a power of two. For `n == 1` the table is the single
/// unused entry `[0]`.
fn compute_twiddle_factors(
    n: usize,
    psi: u64,
    q: u64,
    q_inv_neg: u64,
    r_squared: u64,
) -> Vec<u64> {
    // Montgomery product `a * b * 2^-64 mod q`, reduced into [0, q).
    let mont_mul = |a: u64, b: u64| -> u64 {
        let ab = u128::from(a) * u128::from(b);
        let k = u128::from((ab as u64).wrapping_mul(q_inv_neg));
        let t = ((ab + k * u128::from(q)) >> 64) as u64;
        if t >= q {
            t - q
        } else {
            t
        }
    };

    // The value 1 in Montgomery form.
    let one = mont_mul(1, r_squared);
    let mut factors = vec![0u64; n];
    for m in 1..n {
        factors[m] = if m.is_power_of_two() {
            (0..n / (2 * m)).fold(one, |pow, _| mont_mul(pow, psi))
        } else {
            // Both indices are smaller than `m`, so they are already filled in.
            let lowest_bit = m & m.wrapping_neg();
            let remaining_bits = m & (m - 1);
            mont_mul(factors[remaining_bits], factors[lowest_bit])
        };
    }
    factors
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` coefficients from `coeff(i)`, runs `forward` then `inverse` with
    /// the default modulus, and checks the input comes back unchanged.
    fn assert_roundtrip(n: usize, coeff: impl Fn(u64) -> u64) {
        let ctx = NttContext::with_default_q(n);
        let original: Vec<u64> = (0..n as u64).map(coeff).collect();
        let mut coeffs = original.clone();
        ctx.forward(&mut coeffs);
        ctx.inverse(&mut coeffs);
        assert_eq!(coeffs, original);
    }

    /// Multiplies two polynomials through the transform domain:
    /// forward both, multiply pointwise, inverse the product.
    fn multiply(ctx: &NttContext, mut a: Vec<u64>, mut b: Vec<u64>) -> Vec<u64> {
        ctx.forward(&mut a);
        ctx.forward(&mut b);
        let mut product = vec![0u64; a.len()];
        ctx.pointwise_mul(&a, &b, &mut product);
        ctx.inverse(&mut product);
        product
    }

    #[test]
    fn test_ntt_inverse_roundtrip_small() {
        assert_roundtrip(16, |i| i);
    }

    #[test]
    fn test_ntt_inverse_roundtrip_1024() {
        assert_roundtrip(1024, |i| i);
    }

    #[test]
    fn test_ntt_inverse_roundtrip_2048() {
        assert_roundtrip(2048, |i| i * 1000 % DEFAULT_Q);
    }

    #[test]
    fn test_ntt_inverse_roundtrip_4096() {
        assert_roundtrip(4096, |i| (i * 12345) % DEFAULT_Q);
    }

    #[test]
    fn test_ntt_zero_polynomial() {
        let n = 256;
        let ctx = NttContext::with_default_q(n);
        let mut coeffs = vec![0u64; n];
        ctx.forward(&mut coeffs);
        assert!(coeffs.iter().all(|&c| c == 0));
        ctx.inverse(&mut coeffs);
        assert!(coeffs.iter().all(|&c| c == 0));
    }

    #[test]
    fn test_ntt_constant_polynomial() {
        assert_roundtrip(256, |i| if i == 0 { 42 } else { 0 });
    }

    #[test]
    fn test_pointwise_multiplication() {
        let n = 256;
        let ctx = NttContext::with_default_q(n);
        // 1 * 1 = 1.
        let mut a = vec![0u64; n];
        let mut b = vec![0u64; n];
        a[0] = 1;
        b[0] = 1;
        let result = multiply(&ctx, a, b);
        assert_eq!(result[0], 1);
        assert!(result[1..].iter().all(|&c| c == 0));
    }

    #[test]
    fn test_negacyclic_convolution() {
        let n = 256;
        let q = DEFAULT_Q;
        let ctx = NttContext::with_default_q(n);
        // x * x^(n-1) = x^n, which wraps around to -1 in the negacyclic ring.
        let mut a = vec![0u64; n];
        a[1] = 1;
        let mut b = vec![0u64; n];
        b[n - 1] = 1;
        let result = multiply(&ctx, a, b);
        assert_eq!(result[0], q - 1);
        assert!(result[1..].iter().all(|&c| c == 0));
    }

    #[test]
    fn test_linearity() {
        let n = 256;
        let q = DEFAULT_Q;
        let ctx = NttContext::with_default_q(n);
        let a: Vec<u64> = (0..n as u64).collect();
        let b: Vec<u64> = (0..n as u64).map(|i| (i * 2) % q).collect();

        let mut sum: Vec<u64> = a.iter().zip(&b).map(|(&x, &y)| (x + y) % q).collect();
        let mut a_ntt = a;
        let mut b_ntt = b;
        ctx.forward(&mut a_ntt);
        ctx.forward(&mut b_ntt);
        ctx.forward(&mut sum);

        // The transform of a sum equals the sum of the transforms.
        for ((&s, &x), &y) in sum.iter().zip(&a_ntt).zip(&b_ntt) {
            let expected = (x + y) % q;
            assert_eq!(s, expected);
        }
    }
}