#![cfg_attr(not(feature = "std"), no_std)]
#[cfg(feature = "alloc")]
extern crate alloc;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use super::ntt::montgomery_reduce;
use super::params::{Modulus, NttModulus}; use crate::error::{Error, Result};
use core::marker::PhantomData;
use core::ops::{Add, Neg, Sub};
use zeroize::Zeroize;
/// Maps a residue into Montgomery form: `val * M::MONT_R mod M::Q`.
///
/// The multiplication is carried out in `u64` and reduced at the end, so `val` does not
/// need to be reduced beforehand (assuming `MONT_R` fits in 32 bits — confirm in `params`).
#[inline(always)]
fn to_montgomery<M: NttModulus>(val: u32) -> u32 {
    let modulus = M::Q as u64;
    let r = M::MONT_R as u64;
    let wide = u64::from(val) * r;
    (wide % modulus) as u32
}
/// A polynomial with `M::N` coefficients over `Z_q`, where `q = M::Q`.
///
/// The backing storage is selected by the `alloc` feature:
/// * with `alloc`: a heap-allocated `Vec<u32>`;
/// * without `alloc`: a fixed `[u32; 256]` array, of which only the first `M::N` entries
///   are exposed by the slice accessors.
///
/// Stored coefficients are not required to be reduced modulo `M::Q`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Polynomial<M: Modulus> {
    /// Coefficient storage (heap-backed variant).
    #[cfg(feature = "alloc")]
    pub coeffs: Vec<u32>,
    /// Coefficient storage (fixed-capacity variant).
    #[cfg(not(feature = "alloc"))]
    pub coeffs: [u32; 256],
    /// Ties the polynomial to its modulus parameter set without storing a value of type `M`.
    _marker: PhantomData<M>,
}
/// Wipes every stored coefficient in place while keeping the storage length unchanged.
///
/// The coefficients are zeroized one element at a time rather than by zeroizing the
/// container itself. NOTE(review): per the zeroize crate docs, `Vec::zeroize` also truncates
/// the vector, and the length must survive (the tests check it).
impl<M: Modulus> Zeroize for Polynomial<M> {
    fn zeroize(&mut self) {
        // `iter_mut` exists on both the `Vec` and the fixed-array storage, so one code path
        // covers both cfg variants.
        self.coeffs.iter_mut().for_each(|coeff| coeff.zeroize());
    }
}
/// Constructors, accessors and coefficient-wise arithmetic over `Z_q` with `q = M::Q`.
///
/// Arithmetic methods accept unreduced coefficients (any `u32`) and always return fully
/// reduced coefficients in `[0, q)`.
impl<M: Modulus> Polynomial<M> {
    /// Returns the polynomial whose `M::N` coefficients are all zero.
    ///
    /// With the `alloc` feature the storage is a `Vec` of length `M::N`. Without it, a fixed
    /// 256-entry array is used, so `M::N` must not exceed 256 (the slice accessors panic
    /// otherwise).
    pub fn zero() -> Self {
        // Build storage matching the cfg-selected type of the `coeffs` field; an
        // unconditional `vec!` does not type-check when `alloc` is disabled.
        #[cfg(feature = "alloc")]
        let coeffs = vec![0; M::N];
        #[cfg(not(feature = "alloc"))]
        let coeffs = [0u32; 256];
        Self {
            coeffs,
            _marker: PhantomData,
        }
    }

    /// Builds a polynomial from exactly `M::N` coefficients, copied as-is (not reduced).
    ///
    /// # Errors
    ///
    /// Returns `Error::Parameter` if `coeffs_slice.len() != M::N`, or, without the `alloc`
    /// feature, if `M::N` exceeds the 256-entry fixed storage.
    pub fn from_coeffs(coeffs_slice: &[u32]) -> Result<Self> {
        if coeffs_slice.len() != M::N {
            return Err(Error::Parameter {
                name: "coeffs_slice".into(),
                reason: "Incorrect number of coefficients for polynomial degree N".into(),
            });
        }
        #[cfg(feature = "alloc")]
        let coeffs = coeffs_slice.to_vec();
        #[cfg(not(feature = "alloc"))]
        let coeffs = {
            let mut buf = [0u32; 256];
            // The fixed array cannot hold more than 256 coefficients; report this rather
            // than panicking on the slice copy below.
            if M::N > buf.len() {
                return Err(Error::Parameter {
                    name: "coeffs_slice".into(),
                    reason: "Polynomial degree N exceeds fixed no-alloc capacity of 256".into(),
                });
            }
            buf[..M::N].copy_from_slice(coeffs_slice);
            buf
        };
        Ok(Self {
            coeffs,
            _marker: PhantomData,
        })
    }

    /// Number of coefficients, `M::N`.
    pub fn degree() -> usize {
        M::N
    }

    /// The coefficient modulus `M::Q`.
    pub fn modulus_q() -> u32 {
        M::Q
    }

    /// Borrows the `M::N` meaningful coefficients.
    ///
    /// # Panics
    ///
    /// Panics if the storage holds fewer than `M::N` entries.
    pub fn as_coeffs_slice(&self) -> &[u32] {
        &self.coeffs[..M::N]
    }

    /// Mutably borrows the `M::N` meaningful coefficients.
    ///
    /// # Panics
    ///
    /// Panics if the storage holds fewer than `M::N` entries.
    pub fn as_mut_coeffs_slice(&mut self) -> &mut [u32] {
        &mut self.coeffs[..M::N]
    }

    /// Fully reduces one coefficient into `[0, q)`.
    ///
    /// A single conditional subtraction is not enough, because stored coefficients may be
    /// any `u32`.
    #[inline(always)]
    fn reduce_coefficient(a: u32) -> u32 {
        a % M::Q
    }

    /// Reduces a signed wide value into `[0, q)`. The Euclidean remainder maps negative
    /// values to their non-negative representative.
    #[inline(always)]
    fn reduce_signed(a: i64) -> u32 {
        a.rem_euclid(M::Q as i64) as u32
    }

    /// Coefficient-wise `self + other (mod q)`.
    ///
    /// The sum is formed in `i64`, so unreduced inputs neither wrap nor stay unreduced.
    pub fn add(&self, other: &Self) -> Self {
        let mut result = Self::zero();
        let lhs = self.as_coeffs_slice();
        let rhs = other.as_coeffs_slice();
        for ((out, &x), &y) in result.as_mut_coeffs_slice().iter_mut().zip(lhs).zip(rhs) {
            *out = Self::reduce_signed(i64::from(x) + i64::from(y));
        }
        result
    }

    /// Coefficient-wise `self - other (mod q)`, computed in `i64` so that a negative
    /// difference wraps to its representative in `[0, q)`.
    pub fn sub(&self, other: &Self) -> Self {
        let mut result = Self::zero();
        let lhs = self.as_coeffs_slice();
        let rhs = other.as_coeffs_slice();
        for ((out, &x), &y) in result.as_mut_coeffs_slice().iter_mut().zip(lhs).zip(rhs) {
            *out = Self::reduce_signed(i64::from(x) - i64::from(y));
        }
        result
    }

    /// Coefficient-wise negation `-self (mod q)`. Zero coefficients stay zero.
    pub fn neg(&self) -> Self {
        let q = M::Q;
        let mut result = Self::zero();
        for (out, &coeff) in result
            .as_mut_coeffs_slice()
            .iter_mut()
            .zip(self.as_coeffs_slice())
        {
            // Reduce first: `q - coeff` would underflow (panicking in debug builds) for an
            // unreduced coefficient greater than q.
            let c = Self::reduce_coefficient(coeff);
            // Branch-free select so that -0 maps to 0 rather than to q.
            let nonzero = u32::from(c != 0).wrapping_neg();
            *out = (q - c) & nonzero;
        }
        result
    }

    /// Multiplies every coefficient by `scalar` modulo q.
    ///
    /// Each product is formed in `u64`, where a `u32 * u32` product cannot overflow.
    pub fn scalar_mul(&self, scalar: u32) -> Self {
        let q = u64::from(M::Q);
        let s = u64::from(scalar);
        let mut result = Self::zero();
        for (out, &c) in result
            .as_mut_coeffs_slice()
            .iter_mut()
            .zip(self.as_coeffs_slice())
        {
            *out = ((u64::from(c) * s) % q) as u32;
        }
        result
    }

    /// Negacyclic product `self * other` in `Z_q[X] / (X^N + 1)`, by the O(N^2) schoolbook
    /// method.
    ///
    /// Results are accumulated directly into the output coefficients, so no scratch
    /// allocation is needed and this works without `alloc`. Every factor is reduced below
    /// `q < 2^32` before multiplying, so each product fits in `u64` for any modulus and any
    /// input values.
    pub fn schoolbook_mul(&self, other: &Self) -> Self {
        let n = M::N;
        let q = u64::from(M::Q);
        let b = other.as_coeffs_slice();
        let mut result = Self::zero();
        // Invariant: every `acc` entry stays in [0, q).
        let acc = result.as_mut_coeffs_slice();
        for (i, &ai) in self.as_coeffs_slice().iter().enumerate() {
            let ai = u64::from(ai) % q;
            for (j, &bj) in b.iter().enumerate() {
                let prod = ai * (u64::from(bj) % q) % q;
                let k = i + j;
                if k < n {
                    acc[k] = ((u64::from(acc[k]) + prod) % q) as u32;
                } else {
                    // X^k = X^(k-n) * X^n = -X^(k-n), since X^n = -1 in this ring.
                    acc[k - n] = ((u64::from(acc[k - n]) + q - prod) % q) as u32;
                }
            }
        }
        result
    }

    /// Reduces every meaningful coefficient in place into `[0, q)`.
    pub fn reduce_coeffs(&mut self) {
        for c in self.as_mut_coeffs_slice() {
            *c = Self::reduce_coefficient(*c);
        }
    }
}
/// Extension operations for polynomials whose modulus implements `NttModulus`, i.e. one
/// that carries the Montgomery constant (`MONT_R`) used by the Montgomery arithmetic paths.
pub trait PolynomialNttExt<M>
where
    M: NttModulus,
{
    /// Multiplies every coefficient by `scalar`, using Montgomery multiplication internally.
    ///
    /// NOTE(review): the result equals `coeff * scalar mod q` only if `montgomery_reduce`
    /// divides by the same `MONT_R`; confirm in the `ntt` module.
    fn scalar_mul_montgomery(&self, scalar: u32) -> Polynomial<M>;
}
/// Montgomery-based scalar multiplication for any polynomial over an `NttModulus`.
impl<M: NttModulus> PolynomialNttExt<M> for Polynomial<M> {
    fn scalar_mul_montgomery(&self, scalar: u32) -> Polynomial<M> {
        // Lift the scalar to Montgomery form once (scalar * MONT_R mod q). Each widened
        // product then goes through `montgomery_reduce`, which is presumed to multiply by
        // MONT_R^-1 mod q and so cancel the extra factor (TODO confirm in `ntt`).
        let factor = u64::from(to_montgomery::<M>(scalar));
        let mut scaled = Polynomial::<M>::zero();
        let inputs = &self.coeffs[..M::N];
        for (dst, &coeff) in scaled.coeffs[..M::N].iter_mut().zip(inputs) {
            *dst = montgomery_reduce::<M>(u64::from(coeff) * factor);
        }
        scaled
    }
}
/// Reduces `a` into `[0, M::Q)`.
///
/// Despite its name, this is a plain remainder operation, not a multiply-and-shift Barrett
/// reduction.
#[inline(always)]
pub fn barrett_reduce<M: Modulus>(a: u32) -> u32 {
    let q = M::Q;
    a.rem_euclid(q)
}
impl<M: Modulus> Add for &Polynomial<M> {
type Output = Polynomial<M>;
fn add(self, other: Self) -> Self::Output {
self.add(other)
}
}
impl<M: Modulus> Sub for &Polynomial<M> {
type Output = Polynomial<M>;
fn sub(self, other: Self) -> Self::Output {
self.sub(other)
}
}
impl<M: Modulus> Neg for &Polynomial<M> {
type Output = Polynomial<M>;
fn neg(self) -> Self::Output {
self.neg()
}
}
impl<M: Modulus> Add for Polynomial<M> {
type Output = Self;
fn add(self, other: Self) -> Self::Output {
&self + &other
}
}
impl<M: Modulus> Sub for Polynomial<M> {
type Output = Self;
fn sub(self, other: Self) -> Self::Output {
&self - &other
}
}
impl<M: Modulus> Neg for Polynomial<M> {
type Output = Self;
fn neg(self) -> Self::Output {
-&self
}
}
// Unit tests for construction, arithmetic, reduction, zeroization and negacyclic
// multiplication of `Polynomial`.
#[cfg(test)]
mod tests {
    use super::*;

    // Kyber-sized modulus with a tiny degree so expected values can be worked out by hand.
    #[derive(Clone)]
    struct TestModulus;
    impl Modulus for TestModulus {
        const Q: u32 = 3329;
        const N: usize = 4;
    }

    #[test]
    fn test_polynomial_creation() {
        let poly = Polynomial::<TestModulus>::zero();
        assert_eq!(poly.as_coeffs_slice(), &[0, 0, 0, 0]);
        // Plain array rather than `vec!`, so this test does not depend on `alloc`.
        let coeffs: [u32; 4] = [1, 2, 3, 4];
        let poly = Polynomial::<TestModulus>::from_coeffs(&coeffs).unwrap();
        assert_eq!(poly.as_coeffs_slice(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_polynomial_addition() {
        let a = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 3, 4]).unwrap();
        let b = Polynomial::<TestModulus>::from_coeffs(&[5, 6, 7, 8]).unwrap();
        let c = a + b;
        assert_eq!(c.as_coeffs_slice(), &[6, 8, 10, 12]);
    }

    #[test]
    fn test_polynomial_subtraction() {
        let a = Polynomial::<TestModulus>::from_coeffs(&[10, 20, 30, 40]).unwrap();
        let b = Polynomial::<TestModulus>::from_coeffs(&[5, 6, 7, 8]).unwrap();
        let c = a - b;
        assert_eq!(c.as_coeffs_slice(), &[5, 14, 23, 32]);
    }

    #[test]
    fn test_polynomial_negation() {
        let a = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 0, 4]).unwrap();
        let neg_a = -a;
        // -0 must stay 0; every other coefficient maps to q - c.
        assert_eq!(neg_a.as_coeffs_slice(), &[3328, 3327, 0, 3325]);
    }

    #[test]
    fn test_modular_reduction() {
        // Deliberately unreduced inputs (>= q) are accepted by `from_coeffs`.
        let a = Polynomial::<TestModulus>::from_coeffs(&[3330, 3331, 3328, 0]).unwrap();
        let mut b = a.clone();
        b.reduce_coeffs();
        assert_eq!(b.as_coeffs_slice(), &[1, 2, 3328, 0]);
    }

    #[test]
    fn test_zeroization() {
        let mut poly = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 3, 4]).unwrap();
        poly.zeroize();
        assert_eq!(poly.as_coeffs_slice(), &[0, 0, 0, 0]);
        // Zeroizing must wipe the values without shrinking the storage. Only the `Vec`
        // backing has length N; the no-alloc array always holds 256 entries.
        #[cfg(feature = "alloc")]
        assert_eq!(poly.coeffs.len(), 4);
        #[cfg(not(feature = "alloc"))]
        assert_eq!(poly.coeffs.len(), 256);
    }

    #[test]
    fn test_schoolbook_mul_negacyclic() {
        // X^3 * X = X^4 = -1 in Z_q[X]/(X^4 + 1).
        let mut x_cubed = Polynomial::<TestModulus>::zero();
        x_cubed.coeffs[3] = 1;
        let mut x = Polynomial::<TestModulus>::zero();
        x.coeffs[1] = 1;
        let result = x_cubed.schoolbook_mul(&x);
        assert_eq!(result.coeffs[0], TestModulus::Q - 1);
        assert_eq!(result.coeffs[1], 0);
        assert_eq!(result.coeffs[2], 0);
        assert_eq!(result.coeffs[3], 0);
        // General product: terms with i + j >= N wrap around with a sign flip.
        let a = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 3, 4]).unwrap();
        let b = Polynomial::<TestModulus>::from_coeffs(&[5, 6, 7, 8]).unwrap();
        let c = a.schoolbook_mul(&b);
        let expected_0 = ((5i32 - 61i32).rem_euclid(TestModulus::Q as i32)) as u32;
        let expected_1 = ((16i32 - 52i32).rem_euclid(TestModulus::Q as i32)) as u32;
        let expected_2 = ((34i32 - 32i32).rem_euclid(TestModulus::Q as i32)) as u32;
        let expected_3 = 60u32;
        assert_eq!(c.coeffs[0], expected_0);
        assert_eq!(c.coeffs[1], expected_1);
        assert_eq!(c.coeffs[2], expected_2);
        assert_eq!(c.coeffs[3], expected_3);
    }

    #[test]
    fn test_dilithium_negacyclic() {
        // Dilithium's modulus with a tiny degree.
        #[derive(Clone)]
        struct DilithiumTestModulus;
        impl Modulus for DilithiumTestModulus {
            const Q: u32 = 8380417;
            const N: usize = 4;
        }
        let mut x_to_n_minus_1 = Polynomial::<DilithiumTestModulus>::zero();
        x_to_n_minus_1.coeffs[3] = 1;
        let mut x = Polynomial::<DilithiumTestModulus>::zero();
        x.coeffs[1] = 1;
        let result = x_to_n_minus_1.schoolbook_mul(&x);
        assert_eq!(result.coeffs[0], DilithiumTestModulus::Q - 1);
        assert_eq!(result.coeffs[1], 0);
        assert_eq!(result.coeffs[2], 0);
        assert_eq!(result.coeffs[3], 0);
        // (1 - X^2) * dense, written with -1 as q - 1.
        let mut sparse = Polynomial::<DilithiumTestModulus>::zero();
        sparse.coeffs[0] = 1;
        sparse.coeffs[2] = DilithiumTestModulus::Q - 1;
        let dense =
            Polynomial::<DilithiumTestModulus>::from_coeffs(&[100, 200, 300, 400]).unwrap();
        let result = sparse.schoolbook_mul(&dense);
        assert_eq!(result.coeffs[0], 400);
        assert_eq!(result.coeffs[1], 600);
        assert_eq!(result.coeffs[2], 200);
        assert_eq!(result.coeffs[3], 200);
    }
}