Skip to main content

lib_q_ring/
poly.rs

1//! Coefficient (`Poly`) vs NTT (`NttPoly`) newtypes.
2
3use zeroize::{
4    Zeroize,
5    ZeroizeOnDrop,
6};
7
8use crate::coeff::{
9    COEFFICIENTS_IN_SIMD_UNIT,
10    Coefficients,
11    FieldElement,
12    SIMD_UNITS_IN_RING_ELEMENT,
13};
14use crate::constants::{
15    COEFFICIENTS_IN_RING_ELEMENT,
16    FIELD_MODULUS,
17};
18use crate::field::{
19    add_coeffs,
20    reduce_element,
21    reduce_poly_simd,
22    subtract_coeffs,
23};
24use crate::ntt::{
25    intt_montgomery,
26    ntt_forward_simd,
27    ntt_multiply_montgomery,
28};
29
30/// Polynomial in the time (coefficient) domain, canonical representatives mod `q`.
31#[derive(Clone, Debug, Eq, PartialEq, Hash, Zeroize, ZeroizeOnDrop)]
32pub struct Poly {
33    /// Coefficients `c[0] + c[1] X + … + c[255] X^{255}`.
34    pub coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT],
35}
36
37impl Poly {
38    /// Zero polynomial.
39    #[must_use]
40    pub const fn zero() -> Self {
41        Self {
42            coeffs: [0; COEFFICIENTS_IN_RING_ELEMENT],
43        }
44    }
45
46    /// Construct from canonical coefficients (already reduced mod `q` is recommended).
47    #[must_use]
48    pub const fn from_coeffs(coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT]) -> Self {
49        Self { coeffs }
50    }
51
52    /// Coefficient-wise addition mod `q` (Barrett reduction).
53    pub fn add_assign(&mut self, rhs: &Self) {
54        for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
55            self.coeffs[i] = reduce_element(self.coeffs[i] + rhs.coeffs[i]);
56        }
57    }
58
59    /// Coefficient-wise subtraction mod `q`.
60    pub fn sub_assign(&mut self, rhs: &Self) {
61        for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
62            self.coeffs[i] = reduce_element(self.coeffs[i] - rhs.coeffs[i]);
63        }
64    }
65
66    /// Multiply every coefficient by a small integer, then reduce mod `q`.
67    pub fn scalar_mul_assign(&mut self, k: i32) {
68        for c in &mut self.coeffs {
69            *c = reduce_element((*c as i64 * k as i64) as i32);
70        }
71    }
72
73    /// Negacyclic convolution mod `(X^256 + 1)` via schoolbook \(O(n^2)\) (test / reference).
74    #[must_use]
75    pub fn mul_negacyclic(&self, rhs: &Self) -> Self {
76        let mut acc = [0i64; COEFFICIENTS_IN_RING_ELEMENT];
77        let q = FIELD_MODULUS as i64;
78        for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
79            for j in 0..COEFFICIENTS_IN_RING_ELEMENT {
80                let k = i + j;
81                let prod = (self.coeffs[i] as i64).wrapping_mul(rhs.coeffs[j] as i64);
82                if k < COEFFICIENTS_IN_RING_ELEMENT {
83                    acc[k] += prod;
84                } else {
85                    let idx = k - COEFFICIENTS_IN_RING_ELEMENT;
86                    acc[idx] -= prod;
87                }
88            }
89        }
90        let mut out = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
91        for (o, a) in out.iter_mut().zip(acc) {
92            let mut r = a % q;
93            if r < 0 {
94                r += q;
95            }
96            *o = reduce_element(r as i32);
97        }
98        Self { coeffs: out }
99    }
100
101    /// Infinity norm on absolute representatives in \([-q/2, q/2]\)-style range.
102    #[must_use]
103    pub fn infinity_norm(&self) -> i32 {
104        let half = FIELD_MODULUS / 2;
105        let mut m = 0i32;
106        for &c in &self.coeffs {
107            let v = if c > half { c - FIELD_MODULUS } else { c };
108            m = m.max(v.abs());
109        }
110        m
111    }
112
113    /// SIMD lane layout (ML-DSA coefficient order).
114    #[must_use]
115    pub fn to_simd(&self) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
116        let mut s = [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT];
117        for (i, lane) in s.iter_mut().enumerate() {
118            let base = i * COEFFICIENTS_IN_SIMD_UNIT;
119            lane.values
120                .copy_from_slice(&self.coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT]);
121        }
122        s
123    }
124
125    fn from_simd(simd: &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT]) -> Self {
126        let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
127        for (i, lane) in simd.iter().enumerate() {
128            let base = i * COEFFICIENTS_IN_SIMD_UNIT;
129            coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&lane.values);
130        }
131        Self { coeffs }
132    }
133
134    /// Map to the NTT/Montgomery SIMD representation used by ML-DSA.
135    #[must_use]
136    pub fn to_ntt(&self) -> NttPoly {
137        let mut simd = self.to_simd();
138        ntt_forward_simd(&mut simd);
139        NttPoly { simd }
140    }
141}
142
143/// Polynomial in the NTT domain (Montgomery-lane representation).
144#[derive(Clone, Debug, PartialEq, Eq, Hash)]
145pub struct NttPoly {
146    pub(crate) simd: [Coefficients; SIMD_UNITS_IN_RING_ELEMENT],
147}
148
149impl NttPoly {
150    /// Zero polynomial in NTT form.
151    #[must_use]
152    pub fn zero() -> Self {
153        Self {
154            simd: [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT],
155        }
156    }
157
158    /// Coefficients in SIMD lane order (Montgomery NTT domain) without inverse transform.
159    #[must_use]
160    pub fn packed_ntt_coefficients(&self) -> [FieldElement; COEFFICIENTS_IN_RING_ELEMENT] {
161        let mut c = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
162        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
163            let base = i * COEFFICIENTS_IN_SIMD_UNIT;
164            c[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&self.simd[i].values);
165        }
166        c
167    }
168
169    /// Borrow the internal SIMD lanes (read-only).
170    #[must_use]
171    pub fn as_simd(&self) -> &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
172        &self.simd
173    }
174
175    /// Mutable SIMD lanes (expert use).
176    pub fn as_simd_mut(&mut self) -> &mut [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
177        &mut self.simd
178    }
179
180    /// Pointwise Montgomery multiply `self *= rhs` in the NTT domain.
181    pub fn pointwise_mul_assign(&mut self, rhs: &Self) {
182        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
183            ntt_multiply_montgomery(&mut self.simd[i], &rhs.simd[i]);
184        }
185    }
186
187    /// Inverse NTT into coefficient domain with canonical reduction.
188    #[must_use]
189    pub fn to_poly(mut self) -> Poly {
190        intt_montgomery(&mut self.simd);
191        reduce_poly_simd(&mut self.simd);
192        Poly::from_simd(&self.simd)
193    }
194
195    /// Add two NTT polynomials lane-wise (no modular reduction between adds; use before INTT as in ML-DSA accumulators).
196    pub fn add_assign(&mut self, rhs: &Self) {
197        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
198            add_coeffs(&mut self.simd[i], &rhs.simd[i]);
199        }
200    }
201
202    /// Subtract lane-wise.
203    pub fn sub_assign(&mut self, rhs: &Self) {
204        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
205            subtract_coeffs(&mut self.simd[i], &rhs.simd[i]);
206        }
207    }
208}
209
210/// Fill SIMD layout from the first 256 coefficients of `buf` (ML-DSA `from_i32_array` order).
211#[must_use]
212pub fn simd_from_i256(
213    buf: &[i32; COEFFICIENTS_IN_RING_ELEMENT],
214) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
215    Poly::from_coeffs(*buf).to_simd()
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Advances a 64-bit linear congruential generator and returns its upper 32 bits.
    fn lcg_step(state: &mut u64) -> u32 {
        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (*state >> 32) as u32
    }

    /// Deterministic polynomial whose coefficients lie in `[-bound, bound]`.
    fn small_poly(state: &mut u64, bound: i32) -> Poly {
        let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        let width = (2 * bound + 1) as u32;
        for c in &mut coeffs {
            let v = (lcg_step(state) % width) as i32;
            *c = v - bound;
        }
        Poly::from_coeffs(coeffs)
    }

    /// Checks that the NTT -> inverse-NTT round trip scales every coefficient by the same
    /// constant. The constant is measured from the image of the polynomial `1`.
    #[test]
    fn ntt_inverse_has_expected_linear_scale() {
        let mut one = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        one[0] = 1;
        let scale = Poly::from_coeffs(one).to_ntt().to_poly().coeffs[0];

        let mut st = 0xC0DEC0DE_u64;
        for _ in 0..16 {
            let p = small_poly(&mut st, 8);
            // `to_ntt` borrows `p`, so no clone is needed to keep it for the comparison.
            let back = p.to_ntt().to_poly();
            for (orig, got) in p.coeffs.iter().zip(back.coeffs.iter()) {
                let expected = reduce_element((*orig as i64 * scale as i64) as i32);
                assert_eq!(expected, *got);
            }
        }
    }

    /// Checks that the pointwise NTT product agrees with the schoolbook negacyclic product.
    #[test]
    fn ntt_pointwise_matches_schoolbook_for_small_coeffs() {
        let mut st = 0xDEADBEEF_u64;
        for _ in 0..4 {
            let a = small_poly(&mut st, 8);
            let b = small_poly(&mut st, 8);
            let schoolbook = a.mul_negacyclic(&b);

            let mut ntt = a.to_ntt();
            let b_ntt = b.to_ntt();
            ntt.pointwise_mul_assign(&b_ntt);
            let back = ntt.to_poly();

            assert_eq!(schoolbook, back);
        }
    }
}