use super::prime::NttRootTable;
use super::scalar::compute_shoup;
use alloc::vec;
use alloc::vec::Vec;
/// Precomputed parameters and twiddle tables for length-`n` number-theoretic
/// transforms over `Z_q`, with a prime modulus `q < 2^28`.
///
/// Construct with `Ntt32Context::try_new` (fallible) or `Ntt32Context::new`
/// (panicking). The transform methods dispatch to NEON kernels on aarch64
/// and to scalar kernels on every other target.
#[derive(Debug, Clone)]
pub struct Ntt32Context {
/// Transform length: a power of two, at least 2 (enforced by `try_new`).
pub n: usize,
/// Logarithm of the transform size, taken from `NttRootTable`
/// (presumably `log2(n)` — TODO confirm in `prime.rs`).
pub log_n: u32,
/// Prime modulus. `try_new` guarantees `q < 2^28` and that `2n` divides `q - 1`.
pub q: u32,
/// Cached `2 * q`; presumably used by the kernels for lazy reduction into
/// `[0, 2q)` — NOTE(review): confirm against the scalar/NEON code.
pub two_q: u32,
/// Forward-transform twiddle factors produced by `NttRootTable`
/// (their ordering is defined there).
pub root_powers: Vec<u32>,
/// Shoup companions: `compute_shoup(w, q)` for each `w` in `root_powers`.
pub root_powers_shoup: Vec<u32>,
/// aarch64 only: `floor(w * 2^31 / q)` as `i32` for each `w` in
/// `root_powers`, consumed by the NEON kernels.
#[cfg(target_arch = "aarch64")]
pub root_powers_qmulh: Vec<i32>,
/// Inverse-transform twiddle factors produced by `NttRootTable`.
pub inv_root_powers: Vec<u32>,
/// Shoup companions: `compute_shoup(w, q)` for each `w` in `inv_root_powers`.
pub inv_root_powers_shoup: Vec<u32>,
/// aarch64 only: `floor(w * 2^31 / q)` as `i32` for each `w` in
/// `inv_root_powers`, consumed by the NEON kernels.
#[cfg(target_arch = "aarch64")]
pub inv_root_powers_qmulh: Vec<i32>,
/// Scaling factor `n^{-1} mod q` from `NttRootTable`. Multiplying the output
/// of `inverse_lazy` by it yields the output of `inverse`.
pub n_inv: u32,
/// Shoup companion of `n_inv`: `compute_shoup(n_inv, q)`.
pub n_inv_shoup: u32,
}
impl Ntt32Context {
    /// Builds a context for length-`n` transforms modulo the prime `q`.
    ///
    /// Every table the transforms read is precomputed here: the plain
    /// twiddles, their Shoup companions, and (on aarch64) the NEON `qmulh`
    /// constants.
    ///
    /// # Errors
    ///
    /// - `NttError::InvalidSize` if `n < 2` or `n` is not a power of two.
    /// - `NttError::PrimeTooLarge` if `q >= 2^28`.
    /// - `NttError::NotPrime` if `q` is not prime.
    /// - `NttError::NotNttFriendly` if `2n` does not divide `q - 1`. This is
    ///   also returned when `2n` would overflow `usize`.
    pub fn try_new(n: usize, q: u32) -> Result<Self, crate::NttError> {
        if n < 2 || !n.is_power_of_two() {
            return Err(crate::NttError::InvalidSize(n));
        }
        if q >= (1u32 << 28) {
            return Err(crate::NttError::PrimeTooLarge(q as u64));
        }
        // Checked before `q - 1` below: a prime is >= 2, so that subtraction
        // cannot underflow.
        if !super::prime::is_prime_32(q) {
            return Err(crate::NttError::NotPrime(q as u64));
        }
        // A primitive 2n-th root of unity modulo a prime q exists iff
        // 2n | q - 1. `2 * n` overflows for the largest power-of-two `n`
        // (e.g. 2^31 on 32-bit targets), which panics in debug builds. Such
        // an `n` can never divide a 28-bit `q - 1`, so report it as not
        // NTT-friendly instead.
        let ntt_friendly = n
            .checked_mul(2)
            .is_some_and(|two_n| ((q - 1) as usize).is_multiple_of(two_n));
        if !ntt_friendly {
            return Err(crate::NttError::NotNttFriendly { q: q as u64, n });
        }

        let base = NttRootTable::new(n, q);
        let root_powers_shoup = Self::shoup_table(&base.root_powers, q);
        let inv_root_powers_shoup = Self::shoup_table(&base.inv_root_powers, q);
        let n_inv_shoup = compute_shoup(base.n_inv, q);
        #[cfg(target_arch = "aarch64")]
        let root_powers_qmulh = Self::qmulh_table(&base.root_powers, q);
        #[cfg(target_arch = "aarch64")]
        let inv_root_powers_qmulh = Self::qmulh_table(&base.inv_root_powers, q);

        Ok(Self {
            n,
            log_n: base.log_n,
            q,
            two_q: 2 * q,
            root_powers: base.root_powers,
            root_powers_shoup,
            #[cfg(target_arch = "aarch64")]
            root_powers_qmulh,
            inv_root_powers: base.inv_root_powers,
            inv_root_powers_shoup,
            #[cfg(target_arch = "aarch64")]
            inv_root_powers_qmulh,
            n_inv: base.n_inv,
            n_inv_shoup,
        })
    }

    /// Shoup companion constants, `compute_shoup(w, q)`, for each twiddle in `roots`.
    fn shoup_table(roots: &[u32], q: u32) -> Vec<u32> {
        roots.iter().map(|&w| compute_shoup(w, q)).collect()
    }

    /// NEON `qmulh` constants, `floor(w * 2^31 / q)`, for each twiddle in `roots`.
    ///
    /// The `i32` cast is lossless when `w < q`, because the quotient is then
    /// below `2^31`. The root table presumably stores reduced values; verify
    /// this in `NttRootTable`.
    #[cfg(target_arch = "aarch64")]
    fn qmulh_table(roots: &[u32], q: u32) -> Vec<i32> {
        roots
            .iter()
            .map(|&w| ((u64::from(w) << 31) / u64::from(q)) as i32)
            .collect()
    }

    /// Like `try_new`, but panics on invalid parameters.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid NTT parameters" if `try_new` returns an error.
    pub fn new(n: usize, q: u32) -> Self {
        Self::try_new(n, q).expect("Invalid NTT parameters")
    }

    /// In-place forward NTT of `data`.
    ///
    /// Uses the NEON kernel on aarch64 and the scalar kernel elsewhere.
    /// `data` is presumably expected to hold exactly `self.n` coefficients;
    /// length handling is left to the kernel (NOTE(review): confirm).
    #[inline]
    pub fn forward(&self, data: &mut [u32]) {
        #[cfg(target_arch = "aarch64")]
        {
            super::neon::ntt_fwd_neon(data, self);
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            super::scalar::ntt_forward_scalar(data, self);
        }
    }

    /// In-place inverse NTT of `data`, including the `n^{-1}` scaling.
    ///
    /// After this call, `inverse(forward(x)) == x`.
    #[inline]
    pub fn inverse(&self, data: &mut [u32]) {
        #[cfg(target_arch = "aarch64")]
        {
            super::neon::ntt_inv_neon(data, self);
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            super::scalar::ntt_inverse_scalar(data, self);
        }
    }

    /// In-place inverse NTT of `data` WITHOUT the final `n^{-1}` scaling.
    ///
    /// Multiplying each output by `self.n_inv()` (mod `q`) gives the result
    /// of [`Self::inverse`]. The outputs may not be fully reduced mod `q`
    /// (NOTE(review): confirm against the kernels).
    #[inline]
    pub fn inverse_lazy(&self, data: &mut [u32]) {
        #[cfg(target_arch = "aarch64")]
        {
            super::neon::ntt_inv_neon_lazy(data, self);
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            super::scalar::ntt_inverse_scalar_lazy(data, self);
        }
    }

    /// Returns `n^{-1} mod q`.
    #[inline]
    pub fn n_inv(&self) -> u32 {
        self.n_inv
    }

    /// Returns the Shoup companion constant of [`Self::n_inv`].
    #[inline]
    pub fn n_inv_shoup(&self) -> u32 {
        self.n_inv_shoup
    }

    /// Computes the element-wise product of two NTT-domain vectors modulo `q`
    /// and writes it into `result`.
    ///
    /// Always uses the scalar kernel with length `self.n`. Any checks on the
    /// slice lengths are performed by that kernel.
    pub fn pointwise_mul(&self, a: &[u32], b: &[u32], result: &mut [u32]) {
        super::scalar::ntt_pointwise_mul_scalar(a, b, result, self.q, self.n);
    }

    /// Returns the product `a * b` in `Z_q[X]/(X^n + 1)`.
    ///
    /// Allocates scratch copies of both inputs plus the output. Use
    /// [`Self::negacyclic_mul_into`] to reuse buffers.
    ///
    /// # Panics
    ///
    /// Panics if `a.len()` or `b.len()` differs from `self.n`.
    pub fn negacyclic_mul(&self, a: &[u32], b: &[u32]) -> Vec<u32> {
        let n = self.n;
        assert_eq!(a.len(), n, "negacyclic_mul: a.len() must be N");
        assert_eq!(b.len(), n, "negacyclic_mul: b.len() must be N");
        let mut a_buf = a.to_vec();
        let mut b_buf = b.to_vec();
        let mut result = vec![0u32; n];
        self.negacyclic_mul_into(&mut a_buf, &mut b_buf, &mut result);
        result
    }

    /// Allocation-free negacyclic product: writes `a_buf * b_buf` in
    /// `Z_q[X]/(X^n + 1)` into `result`.
    ///
    /// Both `a_buf` and `b_buf` are clobbered: they are overwritten in place
    /// with their forward NTTs.
    ///
    /// # Panics
    ///
    /// Panics if any of the three slices has a length other than `self.n`.
    pub fn negacyclic_mul_into(&self, a_buf: &mut [u32], b_buf: &mut [u32], result: &mut [u32]) {
        let n = self.n;
        assert_eq!(a_buf.len(), n, "a_buf.len()={} != N={n}", a_buf.len());
        assert_eq!(b_buf.len(), n, "b_buf.len()={} != N={n}", b_buf.len());
        assert_eq!(result.len(), n, "result.len()={} != N={n}", result.len());
        self.forward(a_buf);
        self.forward(b_buf);
        self.pointwise_mul(a_buf, b_buf, result);
        self.inverse(result);
    }
}
#[cfg(test)]
// `dead_code` is needed only for the compile-time Send/Sync check at the
// bottom: its helper functions exist to be type-checked and are never called.
#[allow(dead_code)]
mod tests {
    use super::*;
    use crate::ntt32::prime::generate_primes_28;

    /// A 28-bit prime from the crate's generator that is suitable for length
    /// `n` (presumably `q ≡ 1 mod 2n`).
    fn test_prime(n: usize) -> u32 {
        generate_primes_28(n, 1)[0]
    }

    /// Deterministic, full-range pseudo-random coefficients in `[0, q)`.
    fn make_test_data(n: usize, q: u32) -> Vec<u32> {
        (0..n)
            .map(|i| ((i as u64 * 314_159_265 + 271_828_182) % q as u64) as u32)
            .collect()
    }

    /// Reference O(n^2) product in `Z_q[X]/(X^n + 1)`.
    ///
    /// Both inputs must have the same length `n`, and their entries are
    /// assumed to lie in `[0, q)`.
    fn schoolbook_negacyclic(a: &[u32], b: &[u32], q: u32) -> Vec<u32> {
        let n = a.len();
        assert_eq!(b.len(), n, "schoolbook_negacyclic: length mismatch");
        let q = q as u64;
        let mut acc = vec![0u64; n];
        for (i, &ai) in a.iter().enumerate() {
            for (j, &bj) in b.iter().enumerate() {
                let prod = (ai as u64 * bj as u64) % q;
                if i + j < n {
                    acc[i + j] = (acc[i + j] + prod) % q;
                } else {
                    // X^n = -1: terms that wrap past degree n-1 are subtracted.
                    let k = i + j - n;
                    acc[k] = (acc[k] + q - prod) % q;
                }
            }
        }
        acc.into_iter().map(|x| x as u32).collect()
    }

    #[test]
    fn test_roundtrip_n2() {
        // 5 - 1 = 4 = 2N.
        let q = 5u32;
        let ctx = Ntt32Context::new(2, q);
        let original = vec![1u32, 3];
        let mut data = original.clone();
        ctx.forward(&mut data);
        assert_ne!(data, original, "NTT forward did nothing for N=2");
        ctx.inverse(&mut data);
        assert_eq!(data, original, "NTT roundtrip failed for N=2");
    }

    #[test]
    fn test_roundtrip_n4() {
        // 17 - 1 = 16, divisible by 2N = 8.
        let q = 17u32;
        let ctx = Ntt32Context::new(4, q);
        let original = vec![1u32, 5, 9, 13];
        let mut data = original.clone();
        ctx.forward(&mut data);
        assert_ne!(data, original, "NTT forward did nothing for N=4");
        ctx.inverse(&mut data);
        assert_eq!(data, original, "NTT roundtrip failed for N=4");
    }

    #[test]
    fn test_roundtrip_n16() {
        let n = 16;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);
        let mut data = original.clone();
        ctx.forward(&mut data);
        assert_ne!(data, original, "NTT forward did nothing for N={n}");
        ctx.inverse(&mut data);
        assert_eq!(data, original, "NTT roundtrip failed for N={n}");
    }

    #[test]
    fn test_roundtrip_n64() {
        let n = 64;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);
        let mut data = original.clone();
        ctx.forward(&mut data);
        ctx.inverse(&mut data);
        assert_eq!(data, original, "NTT roundtrip failed for N={n}");
    }

    #[test]
    fn test_roundtrip_n1024() {
        let n = 1024;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);
        let mut data = original.clone();
        ctx.forward(&mut data);
        ctx.inverse(&mut data);
        assert_eq!(data, original, "NTT roundtrip failed for N={n}");
    }

    #[test]
    fn test_roundtrip_n32768() {
        let n = 32768;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);
        let mut data = original.clone();
        ctx.forward(&mut data);
        ctx.inverse(&mut data);
        assert_eq!(data, original, "NTT roundtrip failed for N=32768");
    }

    #[test]
    fn test_roundtrip_zeros() {
        let n = 64;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let mut data = vec![0u32; n];
        ctx.forward(&mut data);
        ctx.inverse(&mut data);
        assert_eq!(data, vec![0u32; n]);
    }

    #[test]
    fn test_constant_polynomial() {
        // A constant polynomial evaluates to the same value at every root,
        // whatever ordering the forward transform uses.
        let n = 64;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let c = 42u32;
        let mut data = vec![0u32; n];
        data[0] = c;
        ctx.forward(&mut data);
        for (i, &v) in data.iter().enumerate() {
            assert_eq!(v, c, "NTT of constant: data[{i}]={v}, expected {c}");
        }
    }

    #[test]
    fn test_negacyclic_mul_identity() {
        let n = 64;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let a: Vec<u32> = (0..n)
            .map(|i| ((i as u64 * 17 + 5) % q as u64) as u32)
            .collect();
        let mut one = vec![0u32; n];
        one[0] = 1;
        let result = ctx.negacyclic_mul(&a, &one);
        assert_eq!(result, a, "Multiply by 1 is not identity");
    }

    #[test]
    fn test_negacyclic_mul_n16() {
        let n = 16;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let a: Vec<u32> = (0..n).map(|i| (i as u32 + 1) % q).collect();
        let b: Vec<u32> = vec![1u32; n];
        let expected = schoolbook_negacyclic(&a, &b, q);
        let result = ctx.negacyclic_mul(&a, &b);
        assert_eq!(result, expected, "Negacyclic multiplication mismatch");
    }

    #[test]
    fn test_negacyclic_mul_matches_schoolbook() {
        // Full-range operands exercise modular products close to q^2, unlike
        // the multiply-by-one and all-ones tests.
        let cases = [
            (64usize, test_prime(64)),
            (256, 8_380_417u32),
            (512, 12_289u32),
        ];
        for (n, q) in cases {
            let ctx = Ntt32Context::new(n, q);
            let a = make_test_data(n, q);
            let b: Vec<u32> = a.iter().rev().copied().collect();
            let expected = schoolbook_negacyclic(&a, &b, q);
            let result = ctx.negacyclic_mul(&a, &b);
            assert_eq!(result, expected, "schoolbook mismatch for N={n}, q={q}");
        }
    }

    #[test]
    fn test_inverse_lazy_no_normalization() {
        let n = 256;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);
        let mut data = original.clone();
        ctx.forward(&mut data);
        ctx.inverse_lazy(&mut data);
        assert_ne!(
            data, original,
            "inverse_lazy should not match original (no N^{{-1}})"
        );
        let n_inv = ctx.n_inv();
        for x in data.iter_mut() {
            *x = ((*x as u64 * n_inv as u64) % q as u64) as u32;
        }
        assert_eq!(
            data, original,
            "inverse_lazy + manual N^{{-1}} should match original"
        );
    }

    #[test]
    fn test_inverse_lazy_matches_concrete_ntt_style() {
        let n = 1024;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);
        let mut data_full = original.clone();
        let mut data_lazy = original.clone();
        ctx.forward(&mut data_full);
        ctx.forward(&mut data_lazy);
        ctx.inverse(&mut data_full);
        ctx.inverse_lazy(&mut data_lazy);
        let n_inv = ctx.n_inv();
        let data_lazy_normalized: Vec<u32> = data_lazy
            .iter()
            .map(|&x| ((x as u64 * n_inv as u64) % q as u64) as u32)
            .collect();
        assert_eq!(data_full, data_lazy_normalized);
    }

    #[test]
    fn test_negacyclic_mul_into_matches_negacyclic_mul() {
        let n = 256;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let a: Vec<u32> = (0..n)
            .map(|i| ((i as u64 * 17 + 3) % q as u64) as u32)
            .collect();
        let b: Vec<u32> = (0..n)
            .map(|i| ((i as u64 * 31 + 7) % q as u64) as u32)
            .collect();
        let result_alloc = ctx.negacyclic_mul(&a, &b);
        let mut a_buf = a.clone();
        let mut b_buf = b.clone();
        let mut result_inplace = vec![0u32; n];
        ctx.negacyclic_mul_into(&mut a_buf, &mut b_buf, &mut result_inplace);
        assert_eq!(
            result_alloc, result_inplace,
            "negacyclic_mul_into must match negacyclic_mul"
        );
    }

    #[test]
    fn test_negacyclic_mul_into_reusable_buffers() {
        let n = 64;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let mut a_buf = vec![0u32; n];
        let mut b_buf = vec![0u32; n];
        let mut result = vec![0u32; n];
        for round in 0..3u64 {
            // Refill the buffers, which the previous round overwrote with NTT data.
            for (i, (a, b)) in a_buf.iter_mut().zip(b_buf.iter_mut()).enumerate() {
                let i = i as u64;
                *a = ((i * (round + 17) + 3) % q as u64) as u32;
                *b = ((i * (round + 31) + 7) % q as u64) as u32;
            }
            let expected = ctx.negacyclic_mul(&a_buf, &b_buf);
            ctx.negacyclic_mul_into(&mut a_buf, &mut b_buf, &mut result);
            assert_eq!(
                result, expected,
                "Reusable buffer mismatch at round {round}"
            );
        }
    }

    #[test]
    fn test_pq_mldsa_roundtrip() {
        // 8_380_417 - 1 = 2^13 * 1023, so 2N = 512 divides q - 1.
        let q: u32 = 8_380_417;
        let n = 256;
        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);
        let mut data = original.clone();
        ctx.forward(&mut data);
        assert_ne!(data, original, "Forward NTT should change data");
        ctx.inverse(&mut data);
        assert_eq!(data, original, "ML-DSA roundtrip failed");
    }

    #[test]
    fn test_pq_mldsa_negacyclic_mul() {
        let q: u32 = 8_380_417;
        let n = 256;
        let ctx = Ntt32Context::new(n, q);
        let a: Vec<u32> = (0..n)
            .map(|i| ((i as u64 * 17 + 5) % q as u64) as u32)
            .collect();
        let mut one = vec![0u32; n];
        one[0] = 1;
        let result = ctx.negacyclic_mul(&a, &one);
        assert_eq!(result, a, "ML-DSA: multiply by 1 is not identity");
    }

    #[test]
    fn test_pq_falcon512_roundtrip() {
        // 12_289 - 1 = 3 * 2^12, so 2N up to 4096 divides q - 1.
        let q: u32 = 12_289;
        let n = 512;
        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);
        let mut data = original.clone();
        ctx.forward(&mut data);
        ctx.inverse(&mut data);
        assert_eq!(data, original, "Falcon-512 roundtrip failed");
    }

    #[test]
    fn test_pq_falcon1024_roundtrip() {
        let q: u32 = 12_289;
        let n = 1024;
        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);
        let mut data = original.clone();
        ctx.forward(&mut data);
        ctx.inverse(&mut data);
        assert_eq!(data, original, "Falcon-1024 roundtrip failed");
    }

    #[test]
    fn test_pq_falcon_negacyclic_mul() {
        let q: u32 = 12_289;
        let n = 512;
        let ctx = Ntt32Context::new(n, q);
        let a: Vec<u32> = (0..n)
            .map(|i| ((i as u64 * 17 + 5) % q as u64) as u32)
            .collect();
        let mut one = vec![0u32; n];
        one[0] = 1;
        let result = ctx.negacyclic_mul(&a, &one);
        assert_eq!(result, a, "Falcon: multiply by 1 is not identity");
    }

    #[test]
    fn test_pq_mlkem_proxy_roundtrip() {
        // 3_329 - 1 = 13 * 2^8, so the largest length this context can use
        // with this q is N = 128 (2N = 256).
        let q: u32 = 3_329;
        let n = 128;
        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);
        let mut data = original.clone();
        ctx.forward(&mut data);
        ctx.inverse(&mut data);
        assert_eq!(data, original, "ML-KEM proxy roundtrip failed");
    }

    #[test]
    fn test_pq_mlkem_negacyclic_mul() {
        let q: u32 = 3_329;
        let n = 128;
        let ctx = Ntt32Context::new(n, q);
        let a: Vec<u32> = (0..n).map(|i| (i as u32 + 1) % q).collect();
        let b: Vec<u32> = vec![1u32; n];
        let expected = schoolbook_negacyclic(&a, &b, q);
        let result = ctx.negacyclic_mul(&a, &b);
        assert_eq!(
            result, expected,
            "ML-KEM negacyclic multiplication mismatch"
        );
    }

    // Compile-time check that the context can be shared across threads:
    // `check` never runs, but its body must type-check.
    const _: () = {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        fn check() {
            assert_send::<super::Ntt32Context>();
            assert_sync::<super::Ntt32Context>();
        }
    };
}