1#![cfg_attr(not(feature = "std"), no_std)]
4
5#[cfg(feature = "alloc")]
6extern crate alloc;
7#[cfg(feature = "alloc")]
8use alloc::vec::Vec;
9
10use super::ntt::montgomery_reduce;
11use super::params::{Modulus, NttModulus}; use crate::error::{Error, Result};
13use core::marker::PhantomData;
14use core::ops::{Add, Neg, Sub};
15use zeroize::Zeroize;
16
17#[inline(always)]
19fn to_montgomery<M: NttModulus>(val: u32) -> u32 {
20 ((val as u64 * M::MONT_R as u64) % M::Q as u64) as u32
21}
22
23#[derive(Debug, Clone, PartialEq, Eq)]
25pub struct Polynomial<M: Modulus> {
26 #[cfg(feature = "alloc")]
28 pub coeffs: Vec<u32>,
29 #[cfg(not(feature = "alloc"))]
31 pub coeffs: [u32; 256], _marker: PhantomData<M>,
33}
34
35impl<M: Modulus> Zeroize for Polynomial<M> {
37 fn zeroize(&mut self) {
38 #[cfg(feature = "alloc")]
40 {
41 for coeff in self.coeffs.iter_mut() {
42 coeff.zeroize();
43 }
44 }
45 #[cfg(not(feature = "alloc"))]
46 {
47 self.coeffs.zeroize();
48 }
49 }
50}
51
52impl<M: Modulus> Polynomial<M> {
53 pub fn zero() -> Self {
55 Self {
56 coeffs: vec![0; M::N], _marker: PhantomData,
58 }
59 }
60
61 pub fn from_coeffs(coeffs_slice: &[u32]) -> Result<Self> {
63 if coeffs_slice.len() != M::N {
64 return Err(Error::Parameter {
65 name: "coeffs_slice".into(),
66 reason: "Incorrect number of coefficients for polynomial degree N".into(),
67 });
68 }
69
70 #[cfg(feature = "alloc")]
71 let coeffs = coeffs_slice.to_vec();
72
73 #[cfg(not(feature = "alloc"))]
74 let mut coeffs = [0u32; 256];
75 #[cfg(not(feature = "alloc"))]
76 coeffs[..M::N].copy_from_slice(coeffs_slice);
77
78 Ok(Self {
79 coeffs,
80 _marker: PhantomData,
81 })
82 }
83
84 pub fn degree() -> usize {
86 M::N
87 }
88
89 pub fn modulus_q() -> u32 {
91 M::Q
92 }
93
94 pub fn as_coeffs_slice(&self) -> &[u32] {
96 &self.coeffs[..M::N]
97 }
98
99 pub fn as_mut_coeffs_slice(&mut self) -> &mut [u32] {
101 &mut self.coeffs[..M::N]
102 }
103
104 #[inline(always)]
106 fn reduce_coefficient(a: u32) -> u32 {
107 let q = M::Q;
109 let mask = ((a >= q) as u32).wrapping_neg();
110 a.wrapping_sub(q & mask)
111 }
112
113 #[inline(always)]
116 fn conditional_sub_q(a: i64) -> u32 {
117 let q = M::Q as i64;
118 a.rem_euclid(q) as u32
120 }
121
122 pub fn add(&self, other: &Self) -> Self {
124 let mut result = Self::zero();
125 for i in 0..M::N {
126 let sum = self.coeffs[i].wrapping_add(other.coeffs[i]);
127 result.coeffs[i] = Self::reduce_coefficient(sum);
128 }
129 result
130 }
131
132 pub fn sub(&self, other: &Self) -> Self {
134 let mut result = Self::zero();
135 for i in 0..M::N {
136 let diff = (self.coeffs[i] as i64) - (other.coeffs[i] as i64);
137 result.coeffs[i] = Self::conditional_sub_q(diff);
138 }
139 result
140 }
141
142 pub fn neg(&self) -> Self {
144 let mut result = Self::zero();
145 for i in 0..M::N {
146 let mask = ((self.coeffs[i] != 0) as u32).wrapping_neg();
148 result.coeffs[i] = (M::Q - self.coeffs[i]) & mask;
149 }
150 result
151 }
152
153 pub fn scalar_mul(&self, scalar: u32) -> Self {
155 let mut result = Self::zero();
156 for i in 0..M::N {
157 let prod = (self.coeffs[i] as u64) * (scalar as u64);
158 result.coeffs[i] = (prod % M::Q as u64) as u32;
159 }
160 result
161 }
162
163 pub fn schoolbook_mul(&self, other: &Self) -> Self {
166 let mut result = Self::zero();
167 let n = M::N;
168 let q = M::Q as u64;
169
170 let mut tmp = vec![0u64; 2 * n];
173
174 for (i, &ai_u32) in self.coeffs.iter().enumerate().take(n) {
177 let ai = ai_u32 as u64;
178 for (j, &bj_u32) in other.coeffs.iter().enumerate().take(n) {
179 let bj = bj_u32 as u64;
180 tmp[i + j] = tmp[i + j].wrapping_add(ai * bj);
181 }
182 }
183
184 for k in n..(2 * n) {
187 let upper_val = tmp[k] % q;
190 if upper_val > 0 {
191 tmp[k - n] = (tmp[k - n] + q - upper_val) % q;
193 }
194 }
195
196 #[allow(clippy::needless_range_loop)]
198 for i in 0..n {
200 result.coeffs[i] = (tmp[i] % q) as u32;
201 }
202
203 result
204 }
205
206 pub fn reduce_coeffs(&mut self) {
208 for i in 0..M::N {
209 self.coeffs[i] = Self::reduce_coefficient(self.coeffs[i]);
210 }
211 }
212}
213
/// Polynomial operations that rely on the Montgomery constants provided by an
/// [`NttModulus`].
pub trait PolynomialNttExt<M: NttModulus> {
    /// Multiplies every coefficient by `scalar`, using Montgomery
    /// multiplication internally.
    ///
    /// NOTE(review): the `Polynomial` implementation lifts `scalar` into the
    /// Montgomery domain and applies `montgomery_reduce` once per coefficient;
    /// presumably the result is `c * scalar mod q` in `[0, q)` — confirm the
    /// output range of `montgomery_reduce` in the `ntt` module.
    fn scalar_mul_montgomery(&self, scalar: u32) -> Polynomial<M>;
}
222
223impl<M: NttModulus> PolynomialNttExt<M> for Polynomial<M> {
224 fn scalar_mul_montgomery(&self, scalar: u32) -> Polynomial<M> {
226 let mut result = Polynomial::<M>::zero();
227 let scalar_mont = to_montgomery::<M>(scalar);
229 for i in 0..M::N {
230 let prod = (self.coeffs[i] as u64) * (scalar_mont as u64);
232 result.coeffs[i] = montgomery_reduce::<M>(prod);
233 }
234 result
235 }
236}
237
238#[inline(always)]
240pub fn barrett_reduce<M: Modulus>(a: u32) -> u32 {
241 a % M::Q
244}
245
246impl<M: Modulus> Add for &Polynomial<M> {
249 type Output = Polynomial<M>;
250
251 fn add(self, other: Self) -> Self::Output {
252 self.add(other)
253 }
254}
255
256impl<M: Modulus> Sub for &Polynomial<M> {
257 type Output = Polynomial<M>;
258
259 fn sub(self, other: Self) -> Self::Output {
260 self.sub(other)
261 }
262}
263
264impl<M: Modulus> Neg for &Polynomial<M> {
265 type Output = Polynomial<M>;
266
267 fn neg(self) -> Self::Output {
268 self.neg()
269 }
270}
271
272impl<M: Modulus> Add for Polynomial<M> {
274 type Output = Self;
275
276 fn add(self, other: Self) -> Self::Output {
277 &self + &other
278 }
279}
280
281impl<M: Modulus> Sub for Polynomial<M> {
282 type Output = Self;
283
284 fn sub(self, other: Self) -> Self::Output {
285 &self - &other
286 }
287}
288
289impl<M: Modulus> Neg for Polynomial<M> {
290 type Output = Self;
291
292 fn neg(self) -> Self::Output {
293 -&self
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Kyber's modulus q = 3329 with a toy degree N = 4, small enough that
    /// every expected value below can be checked by hand.
    #[derive(Clone)]
    struct TestModulus;
    impl Modulus for TestModulus {
        const Q: u32 = 3329;
        const N: usize = 4;
    }

    #[test]
    fn test_polynomial_creation() {
        let poly = Polynomial::<TestModulus>::zero();
        assert_eq!(poly.as_coeffs_slice(), &[0, 0, 0, 0]);

        // A plain array rather than `vec!`, so the test also builds without
        // `alloc`/`std`.
        let coeffs = [1, 2, 3, 4];
        let poly = Polynomial::<TestModulus>::from_coeffs(&coeffs).unwrap();
        assert_eq!(poly.as_coeffs_slice(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_polynomial_addition() {
        let a = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 3, 4]).unwrap();
        let b = Polynomial::<TestModulus>::from_coeffs(&[5, 6, 7, 8]).unwrap();
        let c = a + b;
        // No sum reaches q, so no reduction takes place.
        assert_eq!(c.as_coeffs_slice(), &[6, 8, 10, 12]);
    }

    #[test]
    fn test_polynomial_subtraction() {
        let a = Polynomial::<TestModulus>::from_coeffs(&[10, 20, 30, 40]).unwrap();
        let b = Polynomial::<TestModulus>::from_coeffs(&[5, 6, 7, 8]).unwrap();
        let c = a - b;
        // No difference is negative, so no wrap-around takes place.
        assert_eq!(c.as_coeffs_slice(), &[5, 14, 23, 32]);
    }

    #[test]
    fn test_polynomial_negation() {
        let a = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 0, 4]).unwrap();
        let neg_a = -a;
        // -c mod q is q - c, except that zero stays zero.
        assert_eq!(neg_a.as_coeffs_slice(), &[3328, 3327, 0, 3325]);
    }

    #[test]
    fn test_modular_reduction() {
        let a = Polynomial::<TestModulus>::from_coeffs(&[3330, 3331, 3328, 0]).unwrap();
        let mut b = a.clone();
        b.reduce_coeffs();
        assert_eq!(b.as_coeffs_slice(), &[1, 2, 3328, 0]);
    }

    #[test]
    fn test_zeroization() {
        let mut poly = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 3, 4]).unwrap();
        poly.zeroize();
        assert_eq!(poly.as_coeffs_slice(), &[0, 0, 0, 0]);
        // Zeroizing must overwrite the storage in place, not shrink it: the
        // `alloc` Vec keeps its N entries and the fixed array always has 256.
        let expected_len = if cfg!(feature = "alloc") { 4 } else { 256 };
        assert_eq!(poly.coeffs.len(), expected_len);
    }

    #[test]
    fn test_schoolbook_mul_negacyclic() {
        // X^3 * X = X^4 = -1 modulo X^4 + 1.
        let mut x_cubed = Polynomial::<TestModulus>::zero();
        x_cubed.coeffs[3] = 1;
        let mut x = Polynomial::<TestModulus>::zero();
        x.coeffs[1] = 1;
        let result = x_cubed.schoolbook_mul(&x);

        assert_eq!(result.coeffs[0], TestModulus::Q - 1);
        assert_eq!(result.coeffs[1], 0);
        assert_eq!(result.coeffs[2], 0);
        assert_eq!(result.coeffs[3], 0);

        // c_k = sum_{i+j=k} a_i*b_j - sum_{i+j=k+4} a_i*b_j  (mod q)
        let a = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 3, 4]).unwrap();
        let b = Polynomial::<TestModulus>::from_coeffs(&[5, 6, 7, 8]).unwrap();
        let c = a.schoolbook_mul(&b);

        let q = TestModulus::Q as i32;
        // c_0 = 1*5 - (2*8 + 3*7 + 4*6) = 5 - 61
        let expected_0 = (5i32 - 61i32).rem_euclid(q) as u32;
        // c_1 = (1*6 + 2*5) - (3*8 + 4*7) = 16 - 52
        let expected_1 = (16i32 - 52i32).rem_euclid(q) as u32;
        // c_2 = (1*7 + 2*6 + 3*5) - 4*8 = 34 - 32
        let expected_2 = (34i32 - 32i32).rem_euclid(q) as u32;
        // c_3 = 1*8 + 2*7 + 3*6 + 4*5 = 60 (no wrap-around terms)
        let expected_3 = 60u32;

        assert_eq!(c.coeffs[0], expected_0);
        assert_eq!(c.coeffs[1], expected_1);
        assert_eq!(c.coeffs[2], expected_2);
        assert_eq!(c.coeffs[3], expected_3);
    }

    #[test]
    fn test_dilithium_negacyclic() {
        /// Dilithium's modulus q = 8380417 with a toy degree N = 4.
        #[derive(Clone)]
        struct DilithiumTestModulus;
        impl Modulus for DilithiumTestModulus {
            const Q: u32 = 8380417;
            const N: usize = 4;
        }

        // X^3 * X = X^4 = -1 modulo X^4 + 1.
        let mut x_to_n_minus_1 = Polynomial::<DilithiumTestModulus>::zero();
        x_to_n_minus_1.coeffs[3] = 1;
        let mut x = Polynomial::<DilithiumTestModulus>::zero();
        x.coeffs[1] = 1;
        let result = x_to_n_minus_1.schoolbook_mul(&x);

        assert_eq!(result.coeffs[0], DilithiumTestModulus::Q - 1);
        assert_eq!(result.coeffs[1], 0);
        assert_eq!(result.coeffs[2], 0);
        assert_eq!(result.coeffs[3], 0);

        // (1 - X^2) * (100 + 200X + 300X^2 + 400X^3)
        //   = (100 + 300) + (200 + 400)X + (300 - 100)X^2 + (400 - 200)X^3
        let mut sparse = Polynomial::<DilithiumTestModulus>::zero();
        sparse.coeffs[0] = 1;
        sparse.coeffs[2] = DilithiumTestModulus::Q - 1;
        let dense =
            Polynomial::<DilithiumTestModulus>::from_coeffs(&[100, 200, 300, 400]).unwrap();
        let result = sparse.schoolbook_mul(&dense);

        assert_eq!(result.coeffs[0], 400);
        assert_eq!(result.coeffs[1], 600);
        assert_eq!(result.coeffs[2], 200);
        assert_eq!(result.coeffs[3], 200);
    }
}