dcrypt_algorithms/poly/
polynomial.rs

//! polynomial.rs - Polynomials in `R_Q = Z_Q[X]/(X^N + 1)` and their arithmetic operations
2
3#![cfg_attr(not(feature = "std"), no_std)]
4
5#[cfg(feature = "alloc")]
6extern crate alloc;
7#[cfg(feature = "alloc")]
8use alloc::vec::Vec;
9
10use super::ntt::montgomery_reduce;
11use super::params::{Modulus, NttModulus}; // FIXED: Import NttModulus from params
12use crate::error::{Error, Result};
13use core::marker::PhantomData;
14use core::ops::{Add, Neg, Sub};
15use zeroize::Zeroize;
16
17/// Convert a value from standard domain to Montgomery domain
18#[inline(always)]
19fn to_montgomery<M: NttModulus>(val: u32) -> u32 {
20    ((val as u64 * M::MONT_R as u64) % M::Q as u64) as u32
21}
22
/// A polynomial in the ring `R_Q = Z_Q[X]/(X^N + 1)`, where `Q` and `N` are
/// supplied by the [`Modulus`] type parameter `M`.
///
/// NOTE(review): the derived `PartialEq` is not constant-time and the derived
/// `Debug` prints every coefficient; confirm neither is used on secret
/// polynomials.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Polynomial<M: Modulus> {
    /// Coefficients in standard (non-Montgomery) representation; index `i`
    /// holds the coefficient of `X^i`. `zero` and `from_coeffs` create
    /// exactly `M::N` entries.
    #[cfg(feature = "alloc")]
    pub coeffs: Vec<u32>,
    /// Coefficients in standard representation; index `i` holds the
    /// coefficient of `X^i`. Without `alloc` the storage is a fixed 256-entry
    /// array of which only the first `M::N` entries are used, so `M::N` must
    /// not exceed 256 (the slice accessors index `[..M::N]`).
    #[cfg(not(feature = "alloc"))]
    pub coeffs: [u32; 256], // Will need const generics for proper support
    /// Ties the polynomial to its modulus type without storing a value of it.
    _marker: PhantomData<M>,
}
34
35// Custom Zeroize implementation that preserves vector length
36impl<M: Modulus> Zeroize for Polynomial<M> {
37    fn zeroize(&mut self) {
38        // Zero all coefficients without changing the length
39        #[cfg(feature = "alloc")]
40        {
41            for coeff in self.coeffs.iter_mut() {
42                coeff.zeroize();
43            }
44        }
45        #[cfg(not(feature = "alloc"))]
46        {
47            self.coeffs.zeroize();
48        }
49    }
50}
51
52impl<M: Modulus> Polynomial<M> {
53    /// Creates a new polynomial with all coefficients set to zero
54    pub fn zero() -> Self {
55        Self {
56            coeffs: vec![0; M::N], // length = 256, every coeff = 0
57            _marker: PhantomData,
58        }
59    }
60
61    /// Creates a polynomial from a slice of coefficients
62    pub fn from_coeffs(coeffs_slice: &[u32]) -> Result<Self> {
63        if coeffs_slice.len() != M::N {
64            return Err(Error::Parameter {
65                name: "coeffs_slice".into(),
66                reason: "Incorrect number of coefficients for polynomial degree N".into(),
67            });
68        }
69
70        #[cfg(feature = "alloc")]
71        let coeffs = coeffs_slice.to_vec();
72
73        #[cfg(not(feature = "alloc"))]
74        let mut coeffs = [0u32; 256];
75        #[cfg(not(feature = "alloc"))]
76        coeffs[..M::N].copy_from_slice(coeffs_slice);
77
78        Ok(Self {
79            coeffs,
80            _marker: PhantomData,
81        })
82    }
83
84    /// Returns the degree N of the polynomial
85    pub fn degree() -> usize {
86        M::N
87    }
88
89    /// Returns the modulus Q for coefficient arithmetic
90    pub fn modulus_q() -> u32 {
91        M::Q
92    }
93
94    /// Returns a slice view of the coefficients
95    pub fn as_coeffs_slice(&self) -> &[u32] {
96        &self.coeffs[..M::N]
97    }
98
99    /// Returns a mutable slice view of the coefficients
100    pub fn as_mut_coeffs_slice(&mut self) -> &mut [u32] {
101        &mut self.coeffs[..M::N]
102    }
103
104    /// Branch-free modular reduction of a single coefficient
105    #[inline(always)]
106    fn reduce_coefficient(a: u32) -> u32 {
107        // Branch-free reduction: a - Q if a >= Q else a
108        let q = M::Q;
109        let mask = ((a >= q) as u32).wrapping_neg();
110        a.wrapping_sub(q & mask)
111    }
112
113    /// Branch-free conditional subtraction for signed results
114    /// FIXED: Simplified to use rem_euclid for proper modular arithmetic
115    #[inline(always)]
116    fn conditional_sub_q(a: i64) -> u32 {
117        let q = M::Q as i64;
118        // Use rem_euclid for proper modular arithmetic
119        a.rem_euclid(q) as u32
120    }
121
122    /// Polynomial addition modulo Q
123    pub fn add(&self, other: &Self) -> Self {
124        let mut result = Self::zero();
125        for i in 0..M::N {
126            let sum = self.coeffs[i].wrapping_add(other.coeffs[i]);
127            result.coeffs[i] = Self::reduce_coefficient(sum);
128        }
129        result
130    }
131
132    /// Polynomial subtraction modulo Q
133    pub fn sub(&self, other: &Self) -> Self {
134        let mut result = Self::zero();
135        for i in 0..M::N {
136            let diff = (self.coeffs[i] as i64) - (other.coeffs[i] as i64);
137            result.coeffs[i] = Self::conditional_sub_q(diff);
138        }
139        result
140    }
141
142    /// Polynomial negation modulo Q
143    pub fn neg(&self) -> Self {
144        let mut result = Self::zero();
145        for i in 0..M::N {
146            // Mask is 0xFFFF_FFFF when coeff ≠ 0, 0 otherwise
147            let mask = ((self.coeffs[i] != 0) as u32).wrapping_neg();
148            result.coeffs[i] = (M::Q - self.coeffs[i]) & mask;
149        }
150        result
151    }
152
153    /// Scalar multiplication
154    pub fn scalar_mul(&self, scalar: u32) -> Self {
155        let mut result = Self::zero();
156        for i in 0..M::N {
157            let prod = (self.coeffs[i] as u64) * (scalar as u64);
158            result.coeffs[i] = (prod % M::Q as u64) as u32;
159        }
160        result
161    }
162
163    /// Schoolbook polynomial multiplication with NEGACYCLIC reduction for Dilithium
164    /// In ring `R_q[x]/(x^N + 1)`, when degree >= N, we have `x^N ≡ -1`
165    pub fn schoolbook_mul(&self, other: &Self) -> Self {
166        let mut result = Self::zero();
167        let n = M::N;
168        let q = M::Q as u64;
169
170        // Use a temporary array to accumulate products without modular reduction
171        // This prevents overflow: max value is n * (q-1)^2 < 2^64 for Dilithium
172        let mut tmp = vec![0u64; 2 * n];
173
174        // Step 1: Compute full convolution without modular reduction
175        // FIXED: Use iterator instead of indexing
176        for (i, &ai_u32) in self.coeffs.iter().enumerate().take(n) {
177            let ai = ai_u32 as u64;
178            for (j, &bj_u32) in other.coeffs.iter().enumerate().take(n) {
179                let bj = bj_u32 as u64;
180                tmp[i + j] = tmp[i + j].wrapping_add(ai * bj);
181            }
182        }
183
184        // Step 2: Apply negacyclic reduction (x^N = -1)
185        // Fold upper half back with negation
186        for k in n..(2 * n) {
187            // When reducing x^k where k >= n, we use x^n = -1
188            // So x^k = -x^(k-n)
189            let upper_val = tmp[k] % q;
190            if upper_val > 0 {
191                // Subtract from lower coefficient (equivalent to adding the negative)
192                tmp[k - n] = (tmp[k - n] + q - upper_val) % q;
193            }
194        }
195
196        // Step 3: Final reduction to [0, q)
197        #[allow(clippy::needless_range_loop)]
198        // We need indexed access here to match tmp and result.coeffs indices
199        for i in 0..n {
200            result.coeffs[i] = (tmp[i] % q) as u32;
201        }
202
203        result
204    }
205
206    /// In-place coefficient reduction to ensure all coefficients are < Q
207    pub fn reduce_coeffs(&mut self) {
208        for i in 0..M::N {
209            self.coeffs[i] = Self::reduce_coefficient(self.coeffs[i]);
210        }
211    }
212}
213
214// NTT operations are implemented in ntt.rs as extension methods
215
/// Extension methods for polynomials whose modulus also implements
/// `params::NttModulus` (Montgomery / NTT parameters).
pub trait PolynomialNttExt<M: NttModulus> {
    /// Multiplies every coefficient by `scalar`, using Montgomery reduction
    /// for the modular multiplication.
    ///
    /// The implementation in this file lifts `scalar` into Montgomery form
    /// first. Assuming `montgomery_reduce` computes `x * R^-1 mod Q` (it is
    /// defined in ntt.rs — confirm), the result is `coeff * scalar mod Q` in the
    /// same representation as the input coefficients.
    fn scalar_mul_montgomery(&self, scalar: u32) -> Polynomial<M>;
}
222
223impl<M: NttModulus> PolynomialNttExt<M> for Polynomial<M> {
224    // FIXED: Now uses params::NttModulus
225    fn scalar_mul_montgomery(&self, scalar: u32) -> Polynomial<M> {
226        let mut result = Polynomial::<M>::zero();
227        // FIXED: Convert scalar to Montgomery form before multiplication
228        let scalar_mont = to_montgomery::<M>(scalar);
229        for i in 0..M::N {
230            // Now both operands are in Montgomery form, so the result stays in Montgomery form
231            let prod = (self.coeffs[i] as u64) * (scalar_mont as u64);
232            result.coeffs[i] = montgomery_reduce::<M>(prod);
233        }
234        result
235    }
236}
237
238/// Barrett reduction for fast modular arithmetic
239#[inline(always)]
240pub fn barrett_reduce<M: Modulus>(a: u32) -> u32 {
241    // Simplified Barrett reduction
242    // In production, would use precomputed Barrett constant
243    a % M::Q
244}
245
246// Implement standard ops traits for ergonomic usage
247// Define reference implementations first
248impl<M: Modulus> Add for &Polynomial<M> {
249    type Output = Polynomial<M>;
250
251    fn add(self, other: Self) -> Self::Output {
252        self.add(other)
253    }
254}
255
256impl<M: Modulus> Sub for &Polynomial<M> {
257    type Output = Polynomial<M>;
258
259    fn sub(self, other: Self) -> Self::Output {
260        self.sub(other)
261    }
262}
263
264impl<M: Modulus> Neg for &Polynomial<M> {
265    type Output = Polynomial<M>;
266
267    fn neg(self) -> Self::Output {
268        self.neg()
269    }
270}
271
272// Now owned implementations can use the reference implementations
273impl<M: Modulus> Add for Polynomial<M> {
274    type Output = Self;
275
276    fn add(self, other: Self) -> Self::Output {
277        &self + &other
278    }
279}
280
281impl<M: Modulus> Sub for Polynomial<M> {
282    type Output = Self;
283
284    fn sub(self, other: Self) -> Self::Output {
285        &self - &other
286    }
287}
288
289impl<M: Modulus> Neg for Polynomial<M> {
290    type Output = Self;
291
292    fn neg(self) -> Self::Output {
293        -&self
294    }
295}
296
#[cfg(test)]
mod tests {
    //! Unit tests for `Polynomial` arithmetic. Small `N = 4` moduli keep the
    //! expected results easy to verify by hand.

    use super::*;

    // Test modulus for unit tests
    #[derive(Clone)]
    struct TestModulus;
    impl Modulus for TestModulus {
        const Q: u32 = 3329; // Kyber's Q
        const N: usize = 4; // Small for testing
    }

    #[test]
    fn test_polynomial_creation() {
        let poly = Polynomial::<TestModulus>::zero();
        assert_eq!(poly.as_coeffs_slice(), &[0, 0, 0, 0]);

        // A plain array is enough here: `from_coeffs` only borrows a slice.
        let coeffs = [1, 2, 3, 4];
        let poly = Polynomial::<TestModulus>::from_coeffs(&coeffs).unwrap();
        assert_eq!(poly.as_coeffs_slice(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_polynomial_addition() {
        let a = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 3, 4]).unwrap();
        let b = Polynomial::<TestModulus>::from_coeffs(&[5, 6, 7, 8]).unwrap();
        // Use the + operator directly to avoid explicit borrows
        let c = a + b;
        assert_eq!(c.as_coeffs_slice(), &[6, 8, 10, 12]);
    }

    #[test]
    fn test_polynomial_subtraction() {
        let a = Polynomial::<TestModulus>::from_coeffs(&[10, 20, 30, 40]).unwrap();
        let b = Polynomial::<TestModulus>::from_coeffs(&[5, 6, 7, 8]).unwrap();
        // Use the - operator directly to avoid explicit borrows
        let c = a - b;
        assert_eq!(c.as_coeffs_slice(), &[5, 14, 23, 32]);
    }

    #[test]
    fn test_polynomial_negation() {
        let a = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 0, 4]).unwrap();
        // Use the - operator directly to avoid explicit borrows
        let neg_a = -a;
        // Zero stays zero; every other coefficient c maps to Q - c.
        assert_eq!(neg_a.as_coeffs_slice(), &[3328, 3327, 0, 3325]);
    }

    #[test]
    fn test_modular_reduction() {
        let a = Polynomial::<TestModulus>::from_coeffs(&[3330, 3331, 3328, 0]).unwrap();
        let mut b = a.clone();
        b.reduce_coeffs();
        assert_eq!(b.as_coeffs_slice(), &[1, 2, 3328, 0]);
    }

    #[test]
    fn test_zeroization() {
        let mut poly = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 3, 4]).unwrap();
        poly.zeroize();
        assert_eq!(poly.as_coeffs_slice(), &[0, 0, 0, 0]);
        assert_eq!(poly.coeffs.len(), 4); // Length preserved
    }

    #[test]
    fn test_schoolbook_mul_negacyclic() {
        // Test negacyclic property: x^N = -1
        // For N=4, x^4 = -1, so x^3 * x = -1
        let mut x_cubed = Polynomial::<TestModulus>::zero();
        x_cubed.coeffs[3] = 1; // x^3

        let mut x = Polynomial::<TestModulus>::zero();
        x.coeffs[1] = 1; // x

        let result = x_cubed.schoolbook_mul(&x);
        // x^3 * x = x^4 = -1 mod q = q-1
        assert_eq!(result.coeffs[0], TestModulus::Q - 1);
        assert_eq!(result.coeffs[1], 0);
        assert_eq!(result.coeffs[2], 0);
        assert_eq!(result.coeffs[3], 0);

        // Test a more complex example
        let a = Polynomial::<TestModulus>::from_coeffs(&[1, 2, 3, 4]).unwrap();
        let b = Polynomial::<TestModulus>::from_coeffs(&[5, 6, 7, 8]).unwrap();
        let c = a.schoolbook_mul(&b);

        // Manually compute expected result with negacyclic reduction
        // (1 + 2x + 3x^2 + 4x^3)(5 + 6x + 7x^2 + 8x^3)
        //
        // Full expansion (before reduction):
        // 1*5 = 5
        // 1*6x + 2*5x = 6x + 10x = 16x
        // 1*7x^2 + 2*6x^2 + 3*5x^2 = 7x^2 + 12x^2 + 15x^2 = 34x^2
        // 1*8x^3 + 2*7x^3 + 3*6x^3 + 4*5x^3 = 8x^3 + 14x^3 + 18x^3 + 20x^3 = 60x^3
        // 2*8x^4 + 3*7x^4 + 4*6x^4 = 16x^4 + 21x^4 + 24x^4 = 61x^4
        // 3*8x^5 + 4*7x^5 = 24x^5 + 28x^5 = 52x^5
        // 4*8x^6 = 32x^6
        //
        // Now apply x^4 = -1:
        // x^4 = -1
        // x^5 = -x
        // x^6 = -x^2
        //
        // So:
        // Constant: 5 - 61 = -56 → 3329 - 56 = 3273
        // x: 16 - 52 = -36 → 3329 - 36 = 3293
        // x^2: 34 - 32 = 2
        // x^3: 60

        let expected_0 = ((5i32 - 61i32).rem_euclid(TestModulus::Q as i32)) as u32;
        let expected_1 = ((16i32 - 52i32).rem_euclid(TestModulus::Q as i32)) as u32;
        let expected_2 = ((34i32 - 32i32).rem_euclid(TestModulus::Q as i32)) as u32;
        let expected_3 = 60u32;

        assert_eq!(c.coeffs[0], expected_0);
        assert_eq!(c.coeffs[1], expected_1);
        assert_eq!(c.coeffs[2], expected_2);
        assert_eq!(c.coeffs[3], expected_3);
    }

    #[test]
    fn test_dilithium_negacyclic() {
        // Test with Dilithium-like parameters
        #[derive(Clone)]
        struct DilithiumTestModulus;
        impl Modulus for DilithiumTestModulus {
            const Q: u32 = 8380417; // Dilithium's Q
            const N: usize = 4; // Small for testing, but same negacyclic property
        }

        // Test that x^N = -1 in the ring
        let mut x_to_n_minus_1 = Polynomial::<DilithiumTestModulus>::zero();
        x_to_n_minus_1.coeffs[3] = 1; // x^3

        let mut x = Polynomial::<DilithiumTestModulus>::zero();
        x.coeffs[1] = 1; // x

        let result = x_to_n_minus_1.schoolbook_mul(&x);
        // x^3 * x = x^4 = -1 mod q = q-1
        assert_eq!(result.coeffs[0], DilithiumTestModulus::Q - 1);
        assert_eq!(result.coeffs[1], 0);
        assert_eq!(result.coeffs[2], 0);
        assert_eq!(result.coeffs[3], 0);

        // Test with sparse polynomial (like challenge polynomial c)
        let mut sparse = Polynomial::<DilithiumTestModulus>::zero();
        sparse.coeffs[0] = 1; // +1
        sparse.coeffs[2] = DilithiumTestModulus::Q - 1; // -1

        let dense = Polynomial::<DilithiumTestModulus>::from_coeffs(&[100, 200, 300, 400]).unwrap();
        let result = sparse.schoolbook_mul(&dense);

        // (1 - x^2) * (100 + 200x + 300x^2 + 400x^3)
        // = 100 + 200x + 300x^2 + 400x^3 - 100x^2 - 200x^3 - 300x^4 - 400x^5
        // With x^4 = -1, x^5 = -x:
        // = 100 + 200x + (300-100)x^2 + (400-200)x^3 + 300 + 400x
        // = (100+300) + (200+400)x + 200x^2 + 200x^3
        // = 400 + 600x + 200x^2 + 200x^3

        assert_eq!(result.coeffs[0], 400);
        assert_eq!(result.coeffs[1], 600);
        assert_eq!(result.coeffs[2], 200);
        assert_eq!(result.coeffs[3], 200);
    }
}