#![cfg_attr(not(feature = "std"), no_std)]
use super::params::{Modulus, NttModulus, PostInvNtt};
use super::polynomial::Polynomial;
use crate::error::{Error, Result};
/// Computes `base^exp mod M::Q` by binary (square-and-multiply) exponentiation.
///
/// Intermediate products are carried in `u64`, so any `u32` base is accepted
/// without prior reduction. An exponent of zero yields `1`.
#[inline(always)]
fn pow_mod<M: Modulus>(base: u32, exp: u32) -> u32 {
    let modulus = u64::from(M::Q);
    let mut result: u64 = 1;
    let mut square = u64::from(base);
    let mut remaining = exp;
    while remaining > 0 {
        // Fold the current power of `base` into the result when this bit is set.
        if remaining & 1 != 0 {
            result = result * square % modulus;
        }
        square = square * square % modulus;
        remaining >>= 1;
    }
    result as u32
}
/// Forward number-theoretic transform over polynomials parameterised by `M`.
pub trait NttOperator<M: NttModulus> {
/// Transforms `poly` in place; failures are reported through the crate's `Result`.
fn ntt(poly: &mut Polynomial<M>) -> Result<()>;
}
/// Inverse number-theoretic transform over polynomials parameterised by `M`.
pub trait InverseNttOperator<M: NttModulus> {
/// Applies the inverse transform to `poly` in place; failures are reported
/// through the crate's `Result`.
fn inv_ntt(poly: &mut Polynomial<M>) -> Result<()>;
}
/// Field-less marker type on which the forward and inverse NTT operators in
/// this file are implemented.
pub struct CooleyTukeyNtt;
/// Montgomery reduction of a 64-bit value with radix `2^32`.
///
/// Computes `m = (a mod 2^32) * NEG_QINV mod 2^32`, then `(a + m * Q) >> 32`,
/// followed by a single masked (branch-free) subtraction of `Q`.
/// Provided `NEG_QINV` is `-Q^{-1} mod 2^32` (as its name suggests) and
/// `a < Q * 2^32`, the result is `a * 2^{-32} mod Q` in `[0, Q)`.
#[inline(always)]
pub fn montgomery_reduce<M: NttModulus>(a: u64) -> u32 {
    const LOW_WORD: u64 = 0xFFFF_FFFF;

    // Multiplier that makes the low 32 bits of `a + m * Q` vanish.
    let m = (a & LOW_WORD).wrapping_mul(M::NEG_QINV as u64) & LOW_WORD;
    let t = (a.wrapping_add(m.wrapping_mul(M::Q as u64)) >> 32) as u32;

    // All-ones when `t >= Q`, zero otherwise; avoids a data-dependent branch.
    let subtract_mask = 0u32.wrapping_sub((t >= M::Q) as u32);
    t.wrapping_sub(M::Q & subtract_mask)
}
/// Reduces an arbitrary `u32` into the range `[0, Q)`.
///
/// Two masked subtractions of `Q` handle inputs below `3Q`; anything still
/// out of range falls through to a Barrett reduction. The Barrett constants
/// come from `M::BARRETT_MU` / `M::BARRETT_K` when `BARRETT_MU` is non-zero,
/// otherwise they are derived from `Q` as `k = bitlen(Q) + 32` and
/// `mu = floor(2^k / Q)`.
#[inline]
fn reduce_to_q<M: Modulus>(x: u32) -> u32 {
    let modulus = M::Q;
    // Subtracts `Q` exactly when `v >= Q`, selected through an all-ones/zero mask.
    let fold_once = |v: u32| v - (modulus & ((v >= modulus) as u32).wrapping_neg());

    let folded = fold_once(fold_once(x));
    if folded < modulus {
        return folded;
    }

    let (mu, shift) = match M::BARRETT_MU {
        0 => {
            let bit_len = 64 - (modulus as u64).leading_zeros();
            let shift = bit_len + 32;
            ((1u128 << shift) / modulus as u128, shift)
        }
        preset => (preset, M::BARRETT_K),
    };

    // Quotient estimate; it may undershoot, hence the final masked fold.
    let quotient = ((folded as u128 * mu) >> shift) as u32;
    let remainder = folded.wrapping_sub(quotient.wrapping_mul(modulus));
    remainder.wrapping_sub(modulus & ((remainder >= modulus) as u32).wrapping_neg())
}
/// Multiplies `a` and `b` as a 64-bit product and passes it through
/// `montgomery_reduce`, i.e. the product carries one factor of `2^{-32}`.
#[inline(always)]
fn montgomery_mul<M: NttModulus>(a: u32, b: u32) -> u32 {
    let wide_product = u64::from(a) * u64::from(b);
    montgomery_reduce::<M>(wide_product)
}
/// Returns `(a + b) mod Q`, computed in 64-bit so any pair of `u32` inputs
/// is handled without overflow.
#[inline(always)]
fn add_mod<M: Modulus>(a: u32, b: u32) -> u32 {
    let sum = u64::from(a) + u64::from(b);
    (sum % u64::from(M::Q)) as u32
}
/// Adds `a` and `b` and subtracts `Q` once, via a mask, if the sum reached `Q`.
///
/// The sum is formed with plain `u32` addition, so the result is only fully
/// reduced when both inputs are already below `Q` (and `a + b` fits in `u32`).
#[inline(always)]
fn add_mod_fast<M: Modulus>(a: u32, b: u32) -> u32 {
    let sum = a + b;
    // All-ones when a correction is needed, zero otherwise.
    let correction_mask = 0u32.wrapping_sub(u32::from(sum >= M::Q));
    sum - (correction_mask & M::Q)
}
/// Computes `a - b + Q` with wrapping arithmetic and subtracts `Q` once, via
/// a mask, if the lifted value reached `Q`.
///
/// For inputs already below `Q` the result is `(a - b) mod Q` in `[0, Q)`.
#[inline(always)]
fn sub_mod_fast<M: Modulus>(a: u32, b: u32) -> u32 {
    let lifted = a.wrapping_sub(b).wrapping_add(M::Q);
    // All-ones when a correction is needed, zero otherwise.
    let correction_mask = 0u32.wrapping_sub(u32::from(lifted >= M::Q));
    lifted - (correction_mask & M::Q)
}
/// Returns `a + Q - b` using wrapping arithmetic, with no final reduction.
///
/// For inputs below `Q` the result lies in `[1, 2Q)`; callers are expected to
/// reduce later.
#[inline(always)]
fn sub_mod_upto_2q<M: Modulus>(a: u32, b: u32) -> u32 {
    let q_minus_b = M::Q.wrapping_sub(b);
    q_minus_b.wrapping_add(a)
}
/// Multiplies `val` by `M::MONT_R` modulo `Q` using a 64-bit intermediate.
#[inline(always)]
fn to_montgomery<M: NttModulus>(val: u32) -> u32 {
    let scaled = u64::from(val) * u64::from(M::MONT_R);
    (scaled % u64::from(M::Q)) as u32
}
impl<M: NttModulus> NttOperator<M> for CooleyTukeyNtt {
    /// Forward NTT of `poly`, performed in place.
    ///
    /// Two code paths exist, selected by whether `M::ZETAS` is non-empty:
    ///
    /// * **Table-driven path** (`ZETAS` non-empty): butterflies run with `len`
    ///   going from `n / 2` down to `1`, consuming one table entry per block.
    ///   The difference output is left unreduced (`sub_mod_upto_2q`) between
    ///   layers, and every coefficient is brought back into `[0, Q)` at the end.
    /// * **Generic path** (`ZETAS` empty): coefficients are first mapped with
    ///   `to_montgomery`, then butterflies run with `len` going from `1` up to
    ///   `n / 2`, deriving the stage root from `M::ZETA`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Parameter` when `M::N` is not a power of two (including
    /// zero), or when the zeta table holds fewer than `N - 1` entries. Both
    /// checks happen before any coefficient is modified.
    fn ntt(poly: &mut Polynomial<M>) -> Result<()> {
        let n = M::N;
        // `is_power_of_two` also rejects `n == 0`, where `n & (n - 1)` would underflow.
        if !n.is_power_of_two() {
            return Err(Error::Parameter {
                name: "NTT".into(),
                reason: "Polynomial degree must be a power of 2".into(),
            });
        }

        let is_dilithium = !M::ZETAS.is_empty();
        // The inverse transform walks the table backwards from `ZETAS.len()`,
        // consuming its last `n - 1` entries. Start the forward walk at the
        // matching offset so both directions use the same entries, whether or
        // not the table carries a leading unused slot.
        let first_zeta = if is_dilithium {
            match M::ZETAS.len().checked_sub(n - 1) {
                Some(offset) => offset,
                None => {
                    return Err(Error::Parameter {
                        name: "NTT".into(),
                        reason: "Zeta table is too short for the polynomial degree".into(),
                    });
                }
            }
        } else {
            0
        };

        let coeffs = poly.as_mut_coeffs_slice();
        if is_dilithium {
            let mut k = first_zeta;
            let mut len = n / 2;
            while len >= 1 {
                for start in (0..n).step_by(2 * len) {
                    let zeta = M::ZETAS[k];
                    k += 1;
                    for j in start..start + len {
                        let a = coeffs[j];
                        let b = coeffs[j + len];
                        let t = montgomery_mul::<M>(b, zeta);
                        coeffs[j] = add_mod::<M>(a, t);
                        // Left unreduced here; normalised after the last layer.
                        coeffs[j + len] = sub_mod_upto_2q::<M>(a, t);
                    }
                }
                len >>= 1;
            }
            for c in coeffs.iter_mut() {
                *c = reduce_to_q::<M>(*c);
            }
        } else {
            for c in coeffs.iter_mut() {
                *c = to_montgomery::<M>(*c);
            }
            let mut len = 1_usize;
            while len < n {
                // Stage root: ZETA^(n / 2len), converted with `to_montgomery`.
                let exp = n / (len << 1);
                let root_std = pow_mod::<M>(M::ZETA, exp as u32);
                let root_mont = to_montgomery::<M>(root_std);
                for start in (0..n).step_by(len << 1) {
                    // Running twiddle, restarted at MONT_R for every block.
                    let mut w_mont = M::MONT_R;
                    for j in 0..len {
                        let u = coeffs[start + j];
                        let v = montgomery_mul::<M>(coeffs[start + j + len], w_mont);
                        coeffs[start + j] = add_mod_fast::<M>(u, v);
                        coeffs[start + j + len] = sub_mod_fast::<M>(u, v);
                        w_mont = montgomery_mul::<M>(w_mont, root_mont);
                    }
                }
                len <<= 1;
            }
        }
        Ok(())
    }
}
impl<M: NttModulus> InverseNttOperator<M> for CooleyTukeyNtt {
    /// Inverse NTT of `poly`, performed in place.
    ///
    /// Two code paths exist, selected by whether `M::ZETAS` is non-empty:
    ///
    /// * **Table-driven path** (`ZETAS` non-empty): coefficients are reduced
    ///   into `[0, Q)`, then butterflies run with `len` going from `1` up to
    ///   `n / 2`. The table is walked backwards from its end and each entry is
    ///   negated modulo `Q`. The result is scaled by `N^{Q-2} mod Q` and, when
    ///   `M::POST_INVNTT_MODE` is `Montgomery`, passed through `to_montgomery`.
    /// * **Generic path** (`ZETAS` empty): butterflies run with `len` going
    ///   from `n / 2` down to `1`, using powers of `ZETA^{Q-2}`. Every
    ///   coefficient is then multiplied by `M::N_INV` via `montgomery_mul`
    ///   and, when the mode is `Standard`, passed once more through
    ///   `montgomery_reduce`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Parameter` when `M::N` is not a power of two (including
    /// zero), or when the zeta table holds fewer than `N - 1` entries. Both
    /// checks happen before any coefficient is modified.
    fn inv_ntt(poly: &mut Polynomial<M>) -> Result<()> {
        let n = M::N;
        // `is_power_of_two` also rejects `n == 0`, where `n & (n - 1)` would underflow.
        if !n.is_power_of_two() {
            return Err(Error::Parameter {
                name: "Inverse NTT".into(),
                reason: "Polynomial degree must be a power of 2".into(),
            });
        }

        let is_dilithium = !M::ZETAS.is_empty();
        // The backward table walk below decrements `k` exactly `n - 1` times;
        // reject a short table here instead of underflowing mid-transform.
        if is_dilithium && M::ZETAS.len() < n - 1 {
            return Err(Error::Parameter {
                name: "Inverse NTT".into(),
                reason: "Zeta table is too short for the polynomial degree".into(),
            });
        }

        let coeffs = poly.as_mut_coeffs_slice();
        if is_dilithium {
            for c in coeffs.iter_mut() {
                *c = reduce_to_q::<M>(*c);
            }
            let mut k = M::ZETAS.len();
            let mut len = 1;
            while len < n {
                for start in (0..n).step_by(2 * len) {
                    k -= 1;
                    let zeta_fwd = M::ZETAS[k];
                    // Negate modulo Q, keeping 0 as 0 rather than producing Q.
                    let zeta = if zeta_fwd == 0 { 0 } else { M::Q - zeta_fwd };
                    for j in start..start + len {
                        let t = coeffs[j];
                        let u = coeffs[j + len];
                        coeffs[j] = add_mod::<M>(t, u);
                        let diff = sub_mod_upto_2q::<M>(t, u);
                        coeffs[j + len] = montgomery_mul::<M>(diff, zeta);
                    }
                }
                len <<= 1;
            }

            // Normalise and scale by N^{-1} (Fermat inverse) in a single pass.
            let q = M::Q as u64;
            let n_inv_std = pow_mod::<M>(M::N as u32, M::Q - 2) as u64;
            for c in coeffs.iter_mut() {
                *c = ((reduce_to_q::<M>(*c) as u64 * n_inv_std) % q) as u32;
            }

            match M::POST_INVNTT_MODE {
                PostInvNtt::Standard => {}
                PostInvNtt::Montgomery => {
                    for c in coeffs.iter_mut() {
                        *c = to_montgomery::<M>(*c);
                    }
                }
            }
        } else {
            let root_inv_std = pow_mod::<M>(M::ZETA, M::Q - 2);
            let mut len = n >> 1;
            while len >= 1 {
                // Stage root: (ZETA^{-1})^(n / 2len), converted with `to_montgomery`.
                let exp = n / (len << 1);
                let root_std = pow_mod::<M>(root_inv_std, exp as u32);
                let root_mont = to_montgomery::<M>(root_std);
                for start in (0..n).step_by(len << 1) {
                    // Running twiddle, restarted at MONT_R for every block.
                    let mut w_mont = M::MONT_R;
                    for j in 0..len {
                        let u = coeffs[start + j];
                        let v = coeffs[start + j + len];
                        coeffs[start + j] = add_mod_fast::<M>(u, v);
                        coeffs[start + j + len] =
                            montgomery_mul::<M>(sub_mod_fast::<M>(u, v), w_mont);
                        w_mont = montgomery_mul::<M>(w_mont, root_mont);
                    }
                }
                len >>= 1;
            }
            for c in coeffs.iter_mut() {
                *c = montgomery_mul::<M>(*c, M::N_INV);
            }
            if M::POST_INVNTT_MODE == PostInvNtt::Standard {
                // NOTE(review): presumably strips the remaining Montgomery factor — confirm
                // against how N_INV is defined.
                for c in coeffs.iter_mut() {
                    *c = montgomery_reduce::<M>(*c as u64);
                }
            }
        }
        Ok(())
    }
}
impl<M: NttModulus> Polynomial<M> {
pub fn ntt_inplace(&mut self) -> Result<()> {
CooleyTukeyNtt::ntt(self)
}
pub fn from_ntt_inplace(&mut self) -> Result<()> {
CooleyTukeyNtt::inv_ntt(self)
}
pub fn ntt_mul(&self, other: &Self) -> Self {
let mut result = Self::zero();
let n = M::N;
let is_dilithium = !M::ZETAS.is_empty();
if is_dilithium {
for i in 0..n {
result.coeffs[i] =
((self.coeffs[i] as u64 * other.coeffs[i] as u64) % M::Q as u64) as u32;
}
} else {
for i in 0..n {
result.coeffs[i] = montgomery_mul::<M>(self.coeffs[i], other.coeffs[i]);
}
}
result
}
}
// Unit tests live in the separate `tests` submodule, compiled only for test builds.
#[cfg(test)]
mod tests;