#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
#![allow(clippy::arithmetic_side_effects)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::indexing_slicing)]
#![allow(clippy::unnecessary_wraps)]
use crate::prelude::error::{LatticeArcError, Result};
use crate::primitives::polynomial::arithmetic::{mod_inverse, mod_pow};
/// Precomputed number-theoretic transform (NTT) over `Z_q` for a fixed power-of-two size `n`.
///
/// [`NttProcessor::forward`] maps coefficients `a_0..a_{n-1}` to evaluations
/// `A_k = Σ_j a_j·ω^(j·k) mod q`, where `ω` is a principal `n`-th root of unity, and
/// [`NttProcessor::inverse`] undoes it (including the `n^(-1)` scaling). Because this is
/// the *cyclic* transform, [`NttProcessor::multiply`] computes products in
/// `Z_q[X]/(X^n − 1)`; it is not the negacyclic (`X^n + 1`) transform used by ML-KEM/ML-DSA.
///
/// All transform outputs are canonical residues in `[0, q)`.
pub struct NttProcessor {
    /// `forward_twiddles[k] = ω^k mod q` for `k in 0..n`.
    forward_twiddles: Vec<i32>,
    /// `inverse_twiddles[k] = ω^(-k) mod q` for `k in 0..n`.
    inverse_twiddles: Vec<i32>,
    /// Transform size; always a power of two.
    pub(crate) n: usize,
    /// Modulus `q`.
    pub(crate) modulus: i64,
}

impl NttProcessor {
    /// Returns the forward twiddle table; entry `k` is `ω^k mod q`.
    #[must_use]
    pub fn forward_twiddles(&self) -> &[i32] {
        &self.forward_twiddles
    }

    /// Returns the inverse twiddle table; entry `k` is `ω^(-k) mod q`.
    #[must_use]
    pub fn inverse_twiddles(&self) -> &[i32] {
        &self.inverse_twiddles
    }

    /// Looks up the base `g` from which the principal root `ω = g^((q-1)/n)` is derived.
    ///
    /// `g` need not generate all of `Z_q^*`; it only has to make `g^((q-1)/n)` a root of
    /// unity of order exactly `n`. `principal_root` re-checks this at construction time.
    ///
    /// # Errors
    /// Returns [`LatticeArcError::InvalidInput`] for unsupported `(n, q)` pairs.
    fn find_primitive_root(n: usize, modulus: i64) -> Result<i32> {
        match (n, modulus) {
            // 17 has order 256 mod 3329 (ML-KEM's ζ). (3329 − 1)/256 = 13 is odd, so
            // 17^13 still has order 256.
            (256, 3329) => Ok(17),
            // 11 is a quadratic non-residue mod 12289, so 2^12 divides its order and
            // 11^((12289 − 1)/n) has order exactly n for every n dividing 4096.
            // Note: 49 would NOT work here. It has order only 1024, which yields a root of
            // order 128 for n = 512 and 256 for n = 1024, and the inverse then aliases.
            (512 | 1024, 12289) => Ok(11),
            _ => Err(LatticeArcError::InvalidInput(format!(
                "No known primitive root for N={}, modulus={}",
                n, modulus
            ))),
        }
    }

    /// Computes `ω = g^((q-1)/n) mod q` and verifies that it has multiplicative order
    /// exactly `n`.
    ///
    /// # Errors
    /// Returns [`LatticeArcError::InvalidInput`] if `n` does not divide `q − 1` or if the
    /// derived `ω` is not a primitive `n`-th root of unity.
    fn principal_root(n: usize, generator: i32, modulus: i64) -> Result<i64> {
        let n_i64 = i64::try_from(n).map_err(|_e| {
            LatticeArcError::InvalidInput("NTT size exceeds i64 range".to_string())
        })?;
        if (modulus - 1) % n_i64 != 0 {
            return Err(LatticeArcError::InvalidInput(format!(
                "NTT size {} does not divide modulus - 1 = {}",
                n,
                modulus - 1
            )));
        }
        let omega =
            mod_pow(i64::from(generator), (modulus - 1) / n_i64, modulus).rem_euclid(modulus);
        // n is a power of two. If n >= 2 and ω^(n/2) ≡ −1 (≠ 1), then ω^n ≡ 1 and the order
        // of ω divides n but not n/2, i.e. it is exactly n.
        let has_order_n = if n == 1 {
            omega == 1
        } else {
            mod_pow(omega, n_i64 / 2, modulus).rem_euclid(modulus) == modulus - 1
        };
        if !has_order_n {
            return Err(LatticeArcError::InvalidInput(format!(
                "Derived root of unity for N={}, modulus={} does not have order N",
                n, modulus
            )));
        }
        Ok(omega)
    }

    /// Returns `[base^0, base^1, …, base^(n-1)] mod q`; `base` must already lie in `[0, q)`.
    fn powers(base: i64, n: usize, modulus: i64) -> Vec<i32> {
        let mut table = Vec::with_capacity(n);
        let mut acc = 1_i64;
        for _ in 0..n {
            // acc < q, and the supported moduli fit in i32.
            table.push(acc as i32);
            acc = acc * base % modulus;
        }
        table
    }

    /// Builds an NTT processor for size `n` and modulus `modulus`.
    ///
    /// Supported parameter sets are `(256, 3329)`, `(512, 12289)` and `(1024, 12289)`.
    ///
    /// # Errors
    /// Returns [`LatticeArcError::InvalidInput`] in these cases:
    /// - `n` is not a power of two;
    /// - `modulus <= 1`;
    /// - the pair is unsupported;
    /// - a root of unity of order `n` cannot be derived.
    pub fn new(n: usize, modulus: i64) -> Result<Self> {
        if !n.is_power_of_two() {
            return Err(LatticeArcError::InvalidInput("NTT size must be a power of 2".to_string()));
        }
        if modulus <= 1 {
            return Err(LatticeArcError::InvalidInput(
                "Modulus must be greater than 1".to_string(),
            ));
        }
        let generator = Self::find_primitive_root(n, modulus)?;
        let omega = Self::principal_root(n, generator, modulus)?;
        let omega_inv = mod_inverse(omega, modulus)?.rem_euclid(modulus);
        Ok(Self {
            forward_twiddles: Self::powers(omega, n, modulus),
            inverse_twiddles: Self::powers(omega_inv, n, modulus),
            n,
            modulus,
        })
    }

    /// Computes the forward NTT `A_k = Σ_j a_j·ω^(j·k) mod q` of `coeffs`.
    ///
    /// Inputs may be any `i32` values; outputs lie in `[0, q)`.
    ///
    /// # Errors
    /// Returns [`LatticeArcError::InvalidInput`] if `coeffs.len() != n`.
    pub fn forward(&self, coeffs: &[i32]) -> Result<Vec<i32>> {
        self.check_len(coeffs.len())?;
        let mut result = coeffs.to_vec();
        self.ntt(&mut result, false)?;
        Ok(result)
    }

    /// Computes the inverse NTT of `evaluations`, including the `n^(-1)` scaling, so that
    /// `inverse(forward(a)) ≡ a (mod q)`.
    ///
    /// # Errors
    /// Returns [`LatticeArcError::InvalidInput`] if `evaluations.len() != n` or if `n` is not
    /// invertible modulo `q`.
    pub fn inverse(&self, evaluations: &[i32]) -> Result<Vec<i32>> {
        self.check_len(evaluations.len())?;
        let mut result = evaluations.to_vec();
        self.ntt(&mut result, true)?;
        let n_i64 = i64::try_from(self.n).map_err(|_e| {
            LatticeArcError::InvalidInput("NTT size exceeds i64 range".to_string())
        })?;
        let n_inv_i64 = mod_inverse(n_i64, self.modulus)?.rem_euclid(self.modulus);
        let n_inv = i32::try_from(n_inv_i64).map_err(|_e| {
            LatticeArcError::InvalidInput("NTT inverse exceeds i32 range".to_string())
        })?;
        for coeff in &mut result {
            *coeff = self.mod_mul(*coeff, n_inv);
        }
        Ok(result)
    }

    /// Multiplies two polynomials in `Z_q[X]/(X^n − 1)` (cyclic convolution) via the NTT.
    ///
    /// # Errors
    /// Returns [`LatticeArcError::InvalidInput`] if either input's length differs from `n`.
    pub fn multiply(&self, a: &[i32], b: &[i32]) -> Result<Vec<i32>> {
        if a.len() != self.n || b.len() != self.n {
            return Err(LatticeArcError::InvalidInput(
                "Polynomial lengths must match NTT size".to_string(),
            ));
        }
        let a_eval = self.forward(a)?;
        let b_eval = self.forward(b)?;
        let c_eval: Vec<i32> =
            a_eval.iter().zip(&b_eval).map(|(&x, &y)| self.mod_mul(x, y)).collect();
        self.inverse(&c_eval)
    }

    /// Rejects inputs whose length differs from the transform size.
    fn check_len(&self, len: usize) -> Result<()> {
        if len == self.n {
            Ok(())
        } else {
            Err(LatticeArcError::InvalidInput(format!(
                "Input length {} doesn't match NTT size {}",
                len, self.n
            )))
        }
    }

    /// Reduces `value` to its canonical representative in `[0, q)`.
    fn reduce(&self, value: i64) -> i32 {
        // rem_euclid is non-negative for a positive modulus (guaranteed by `new`).
        value.rem_euclid(self.modulus) as i32
    }

    /// `a·b mod q`; the product is formed in i64, so no i32 input can overflow it.
    fn mod_mul(&self, a: i32, b: i32) -> i32 {
        self.reduce(i64::from(a) * i64::from(b))
    }

    /// `a + b mod q`, canonical.
    fn mod_add(&self, a: i32, b: i32) -> i32 {
        self.reduce(i64::from(a) + i64::from(b))
    }

    /// `a − b mod q`, canonical.
    fn mod_sub(&self, a: i32, b: i32) -> i32 {
        self.reduce(i64::from(a) - i64::from(b))
    }

    /// In-place iterative radix-2 Cooley–Tukey NTT (decimation in time).
    ///
    /// `data` is first permuted into bit-reversed order. Each stage of width `length` then
    /// combines pairs using the twiddle `ω^(j·n/length)` (or its inverse), read from the
    /// precomputed tables. The inverse direction does NOT apply the `n^(-1)` scaling.
    ///
    /// # Errors
    /// Returns [`LatticeArcError::InvalidInput`] if `data.len()` differs from the table size.
    fn ntt(&self, data: &mut [i32], inverse: bool) -> Result<()> {
        let twiddles: &[i32] =
            if inverse { &self.inverse_twiddles } else { &self.forward_twiddles };
        let n = data.len();
        if n != twiddles.len() {
            return Err(LatticeArcError::InvalidInput(format!(
                "Input length {} doesn't match NTT size {}",
                n,
                twiddles.len()
            )));
        }

        // Bit-reversal permutation.
        let mut j = 0usize;
        for i in 1..n {
            let mut bit = n >> 1;
            while j & bit != 0 {
                j ^= bit;
                bit >>= 1;
            }
            j ^= bit;
            if i < j {
                data.swap(i, j);
            }
        }

        let mut length = 2;
        while length <= n {
            let half = length / 2;
            // The stage root is ω^(n/length), so its k-th power is table entry k·stride.
            let stride = n / length;
            for block in data.chunks_exact_mut(length) {
                let (lo, hi) = block.split_at_mut(half);
                for (k, (u_slot, v_slot)) in lo.iter_mut().zip(hi.iter_mut()).enumerate() {
                    let u = *u_slot;
                    let v = self.mod_mul(*v_slot, twiddles[k * stride]);
                    *u_slot = self.mod_add(u, v);
                    *v_slot = self.mod_sub(u, v);
                }
            }
            length *= 2;
        }
        Ok(())
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[allow(clippy::expect_used)]
#[allow(clippy::arithmetic_side_effects)]
#[allow(clippy::indexing_slicing)]
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cast_possible_wrap)]
mod tests {
    use super::*;

    /// Every `(n, q)` pair accepted by `NttProcessor::new`.
    const SUPPORTED_PARAMS: [(usize, i64); 3] = [(256, 3329), (512, 12289), (1024, 12289)];

    /// Maps `x` to its canonical representative in `[0, q)`.
    fn canon(x: i32, q: i64) -> i64 {
        i64::from(x).rem_euclid(q)
    }

    /// Deterministic dense test vector with every coefficient in `[0, q)`.
    fn dense_poly(n: usize, q: i64, mul: i64, add: i64) -> Vec<i32> {
        (0..n).map(|i| ((i as i64 * mul + add) % q) as i32).collect()
    }

    /// Reference `a · b mod (X^n − 1, q)` computed by direct O(n²) cyclic convolution.
    fn schoolbook_cyclic(a: &[i32], b: &[i32], q: i64) -> Vec<i64> {
        let n = a.len();
        let mut acc = vec![0i64; n];
        for (i, &ai) in a.iter().enumerate() {
            for (j, &bj) in b.iter().enumerate() {
                acc[(i + j) % n] += canon(ai, q) * canon(bj, q) % q;
            }
        }
        acc.into_iter().map(|v| v.rem_euclid(q)).collect()
    }

    #[test]
    fn test_ntt_new_kyber_params_succeeds_with_correct_size_has_correct_size() {
        let ntt = NttProcessor::new(256, 3329).expect("Kyber params should work");
        assert_eq!(ntt.n, 256);
        assert_eq!(ntt.modulus, 3329);
        assert_eq!(ntt.forward_twiddles().len(), 256);
        assert_eq!(ntt.inverse_twiddles().len(), 256);
    }

    // NOTE: 12289 is not the ML-DSA/Dilithium modulus (8380417); it is the Falcon/NewHope
    // modulus. The "dilithium" test names are kept for test-ID stability.
    #[test]
    fn test_ntt_new_dilithium_512_succeeds() {
        let ntt = NttProcessor::new(512, 12289).expect("Dilithium params should work");
        assert_eq!(ntt.n, 512);
        assert_eq!(ntt.modulus, 12289);
    }

    #[test]
    fn test_ntt_new_dilithium_1024_succeeds() {
        let ntt = NttProcessor::new(1024, 12289).expect("Dilithium params should work");
        assert_eq!(ntt.n, 1024);
    }

    #[test]
    fn test_ntt_new_non_power_of_two_returns_error() {
        let result = NttProcessor::new(100, 3329);
        assert!(result.is_err());
    }

    #[test]
    fn test_ntt_new_invalid_modulus_returns_error() {
        let result = NttProcessor::new(256, 0);
        assert!(result.is_err());
        let result = NttProcessor::new(256, 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_ntt_new_unknown_params_returns_error() {
        let result = NttProcessor::new(256, 7681);
        assert!(result.is_err());
    }

    #[test]
    fn test_ntt_forward_inverse_roundtrip_kyber_roundtrip() {
        let ntt = NttProcessor::new(256, 3329).unwrap();
        let mut coeffs = vec![0i32; 256];
        coeffs[0] = 1;
        coeffs[1] = 2;
        coeffs[2] = 3;
        coeffs[10] = 42;
        let eval = ntt.forward(&coeffs).unwrap();
        let recovered = ntt.inverse(&eval).unwrap();
        for (i, (&original, &result)) in coeffs.iter().zip(recovered.iter()).enumerate() {
            let orig_mod = (i64::from(original) % 3329 + 3329) % 3329;
            let res_mod = (i64::from(result) % 3329 + 3329) % 3329;
            assert_eq!(orig_mod, res_mod, "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_ntt_forward_inverse_roundtrip_dilithium_roundtrip() {
        let ntt = NttProcessor::new(512, 12289).unwrap();
        let mut coeffs = vec![0i32; 512];
        coeffs[0] = 100;
        coeffs[1] = 200;
        coeffs[255] = 500;
        let eval = ntt.forward(&coeffs).unwrap();
        let recovered = ntt.inverse(&eval).unwrap();
        for (i, (&original, &result)) in coeffs.iter().zip(recovered.iter()).enumerate() {
            let orig_mod = (i64::from(original) % 12289 + 12289) % 12289;
            let res_mod = (i64::from(result) % 12289 + 12289) % 12289;
            assert_eq!(orig_mod, res_mod, "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_ntt_forward_wrong_size_returns_error() {
        let ntt = NttProcessor::new(256, 3329).unwrap();
        let too_short = vec![1i32; 128];
        assert!(ntt.forward(&too_short).is_err());
        let too_long = vec![1i32; 512];
        assert!(ntt.forward(&too_long).is_err());
    }

    #[test]
    fn test_ntt_inverse_wrong_size_returns_error() {
        let ntt = NttProcessor::new(256, 3329).unwrap();
        let wrong = vec![1i32; 100];
        assert!(ntt.inverse(&wrong).is_err());
    }

    #[test]
    fn test_ntt_multiply_identity_preserves_polynomial_succeeds() {
        let ntt = NttProcessor::new(256, 3329).unwrap();
        let mut a = vec![0i32; 256];
        a[0] = 5;
        a[1] = 3;
        let mut identity = vec![0i32; 256];
        identity[0] = 1;
        let result = ntt.multiply(&a, &identity).unwrap();
        // Check every coefficient, not just a[0]: a · 1 must reproduce a exactly (mod q).
        for (i, (&orig, &res)) in a.iter().zip(result.iter()).enumerate() {
            assert_eq!(
                canon(orig, 3329),
                canon(res, 3329),
                "Multiplication by identity should preserve a[{}]",
                i
            );
        }
    }

    #[test]
    fn test_ntt_multiply_zero_returns_all_zeros_succeeds() {
        let ntt = NttProcessor::new(256, 3329).unwrap();
        let mut a = vec![0i32; 256];
        a[0] = 42;
        let zero = vec![0i32; 256];
        let result = ntt.multiply(&a, &zero).unwrap();
        for (i, &val) in result.iter().enumerate() {
            assert_eq!(val, 0, "a * 0 should be zero at index {}", i);
        }
    }

    #[test]
    fn test_ntt_multiply_wrong_size_returns_error() {
        let ntt = NttProcessor::new(256, 3329).unwrap();
        let a = vec![1i32; 256];
        let b = vec![1i32; 128];
        assert!(ntt.multiply(&a, &b).is_err());
        assert!(ntt.multiply(&b, &a).is_err());
    }

    #[test]
    fn test_twiddle_factors_first_element_is_one_succeeds() {
        let ntt = NttProcessor::new(256, 3329).unwrap();
        assert_eq!(ntt.forward_twiddles()[0], 1);
        assert_eq!(ntt.inverse_twiddles()[0], 1);
    }

    #[test]
    fn test_twiddle_factors_length_matches_ntt_size_has_correct_size() {
        let ntt = NttProcessor::new(256, 3329).unwrap();
        assert_eq!(ntt.forward_twiddles().len(), 256);
        assert_eq!(ntt.inverse_twiddles().len(), 256);
    }

    #[test]
    fn test_ntt_forward_zero_polynomial_returns_all_zeros_succeeds() {
        let ntt = NttProcessor::new(256, 3329).unwrap();
        let zeros = vec![0i32; 256];
        let result = ntt.forward(&zeros).unwrap();
        for &val in &result {
            assert_eq!(val, 0, "NTT of zero polynomial should be zero");
        }
    }

    #[test]
    fn test_ntt_forward_inverse_roundtrip_1024_roundtrip() {
        let ntt = NttProcessor::new(1024, 12289).unwrap();
        let mut coeffs = vec![0i32; 1024];
        coeffs[0] = 50;
        coeffs[1] = 100;
        coeffs[511] = 200;
        coeffs[1023] = 300;
        let eval = ntt.forward(&coeffs).unwrap();
        let recovered = ntt.inverse(&eval).unwrap();
        for (i, (&original, &result)) in coeffs.iter().zip(recovered.iter()).enumerate() {
            let orig_mod = (i64::from(original) % 12289 + 12289) % 12289;
            let res_mod = (i64::from(result) % 12289 + 12289) % 12289;
            assert_eq!(orig_mod, res_mod, "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_ntt_multiply_1024_by_zero_returns_zeros_succeeds() {
        let ntt = NttProcessor::new(1024, 12289).unwrap();
        let mut a = vec![0i32; 1024];
        a[0] = 7;
        let zero = vec![0i32; 1024];
        let result = ntt.multiply(&a, &zero).unwrap();
        for (i, &val) in result.iter().enumerate() {
            assert_eq!(val, 0, "a * 0 should be zero at index {}", i);
        }
    }

    #[test]
    fn test_ntt_negative_modulus_returns_error() {
        let result = NttProcessor::new(256, -1);
        assert!(result.is_err());
    }

    #[test]
    fn test_ntt_inverse_zero_polynomial_returns_all_zeros_succeeds() {
        let ntt = NttProcessor::new(256, 3329).unwrap();
        let zeros = vec![0i32; 256];
        let result = ntt.inverse(&zeros).unwrap();
        for &val in &result {
            assert_eq!(val, 0, "Inverse NTT of zeros should be zeros");
        }
    }

    #[test]
    fn test_ntt_multiply_commutativity_holds_succeeds() {
        let ntt = NttProcessor::new(256, 3329).unwrap();
        let mut a = vec![0i32; 256];
        a[0] = 3;
        a[1] = 5;
        let mut b = vec![0i32; 256];
        b[0] = 7;
        b[1] = 2;
        let ab = ntt.multiply(&a, &b).unwrap();
        let ba = ntt.multiply(&b, &a).unwrap();
        assert_eq!(ab, ba, "Polynomial multiplication should be commutative");
    }

    #[test]
    fn test_ntt_twiddle_factors_1024_has_correct_length_and_first_element_has_correct_size() {
        let ntt = NttProcessor::new(1024, 12289).unwrap();
        assert_eq!(ntt.forward_twiddles().len(), 1024);
        assert_eq!(ntt.inverse_twiddles().len(), 1024);
        assert_eq!(ntt.forward_twiddles()[0], 1);
        assert_eq!(ntt.inverse_twiddles()[0], 1);
    }

    // Dense inputs touch every index, so a root of unity with the wrong order (which makes
    // the inverse fold x_j and x_{j + n/m} together) cannot slip through.
    #[test]
    fn test_ntt_roundtrip_dense_input_all_params_roundtrip() {
        for &(n, q) in &SUPPORTED_PARAMS {
            let ntt = NttProcessor::new(n, q).unwrap();
            let coeffs = dense_poly(n, q, 37, 11);
            let recovered = ntt.inverse(&ntt.forward(&coeffs).unwrap()).unwrap();
            for (i, (&orig, &res)) in coeffs.iter().zip(recovered.iter()).enumerate() {
                assert_eq!(canon(orig, q), canon(res, q), "N={} q={} mismatch at {}", n, q, i);
            }
        }
    }

    #[test]
    fn test_ntt_multiply_matches_schoolbook_cyclic_convolution_succeeds() {
        for &(n, q) in &SUPPORTED_PARAMS {
            let ntt = NttProcessor::new(n, q).unwrap();
            let a = dense_poly(n, q, 7, 3);
            let b = dense_poly(n, q, 13, 5);
            let expected = schoolbook_cyclic(&a, &b, q);
            let got: Vec<i64> =
                ntt.multiply(&a, &b).unwrap().iter().map(|&v| canon(v, q)).collect();
            assert_eq!(got, expected, "N={} q={}", n, q);
        }
    }

    #[test]
    fn test_ntt_twiddle_root_has_exact_order_n_succeeds() {
        for &(n, q) in &SUPPORTED_PARAMS {
            let ntt = NttProcessor::new(n, q).unwrap();
            let fwd = ntt.forward_twiddles();
            let inv = ntt.inverse_twiddles();
            // ω^(n/2) must be −1; otherwise ω is not a primitive n-th root of unity.
            assert_eq!(canon(fwd[n / 2], q), q - 1, "N={} q={}", n, q);
            assert_eq!(canon(inv[n / 2], q), q - 1, "N={} q={}", n, q);
            for k in 0..n {
                assert_eq!(canon(fwd[k], q) * canon(inv[k], q) % q, 1, "k={}", k);
            }
        }
    }

    #[test]
    fn test_ntt_forward_of_x_equals_forward_twiddles_succeeds() {
        for &(n, q) in &SUPPORTED_PARAMS {
            let ntt = NttProcessor::new(n, q).unwrap();
            let mut x = vec![0i32; n];
            x[1] = 1;
            // A_k = Σ_j a_j·ω^(j·k), so the transform of X is (ω^0, ω^1, …, ω^(n-1)).
            let eval: Vec<i64> = ntt.forward(&x).unwrap().iter().map(|&v| canon(v, q)).collect();
            let table: Vec<i64> = ntt.forward_twiddles().iter().map(|&v| canon(v, q)).collect();
            assert_eq!(eval, table, "N={} q={}", n, q);
        }
    }
}