Skip to main content

lib_q_ring/
poly.rs

//! Coefficient-domain (`Poly`) and NTT-domain (`NttPoly`) polynomial newtypes.
2
3use subtle::{
4    Choice,
5    ConditionallySelectable,
6    ConstantTimeGreater,
7};
8use zeroize::{
9    Zeroize,
10    ZeroizeOnDrop,
11};
12
13use crate::coeff::{
14    COEFFICIENTS_IN_SIMD_UNIT,
15    Coefficients,
16    FieldElement,
17    SIMD_UNITS_IN_RING_ELEMENT,
18};
19use crate::constants::{
20    COEFFICIENTS_IN_RING_ELEMENT,
21    FIELD_MODULUS,
22};
23use crate::field::{
24    add_coeffs,
25    reduce_element,
26    reduce_poly_simd,
27    subtract_coeffs,
28};
29use crate::ntt::{
30    intt_montgomery,
31    ntt_forward_simd,
32    ntt_multiply_montgomery,
33};
34
35#[inline]
36fn ct_gt_i32(a: i32, b: i32) -> Choice {
37    let flip = 1u32 << 31;
38    let a_u = (a as u32) ^ flip;
39    let b_u = (b as u32) ^ flip;
40    a_u.ct_gt(&b_u)
41}
42
/// Branch-free absolute value of an `i32`.
///
/// `sign` is `0` for non-negative inputs and `-1` (all bits set) for negative ones. Masking
/// `2 * x` with it and subtracting yields `x` or `-x` without a data-dependent branch.
///
/// Wrapping arithmetic is used so that inputs below `-2^30` do not panic in debug builds; for
/// those inputs the intermediate `x - 2x` leaves the `i32` range. The result is still the exact
/// absolute value for every input except `i32::MIN`, which maps to itself, matching
/// `i32::wrapping_abs`.
#[inline]
fn centered_abs_i32(coefficient: i32) -> i32 {
    let sign = coefficient >> 31;
    coefficient.wrapping_sub(sign & coefficient.wrapping_shl(1))
}
48
49/// Polynomial in the time (coefficient) domain, canonical representatives mod `q`.
50#[derive(Clone, Debug, Eq, PartialEq, Hash, Zeroize, ZeroizeOnDrop)]
51pub struct Poly {
52    /// Coefficients `c[0] + c[1] X + … + c[255] X^{255}`.
53    pub coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT],
54}
55
56impl Poly {
57    /// Zero polynomial.
58    #[must_use]
59    pub const fn zero() -> Self {
60        Self {
61            coeffs: [0; COEFFICIENTS_IN_RING_ELEMENT],
62        }
63    }
64
65    /// Construct from canonical coefficients (already reduced mod `q` is recommended).
66    #[must_use]
67    pub const fn from_coeffs(coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT]) -> Self {
68        Self { coeffs }
69    }
70
71    /// Coefficient-wise addition mod `q` (Barrett reduction).
72    pub fn add_assign(&mut self, rhs: &Self) {
73        for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
74            self.coeffs[i] = reduce_element(self.coeffs[i] + rhs.coeffs[i]);
75        }
76    }
77
78    /// Coefficient-wise subtraction mod `q`.
79    pub fn sub_assign(&mut self, rhs: &Self) {
80        for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
81            self.coeffs[i] = reduce_element(self.coeffs[i] - rhs.coeffs[i]);
82        }
83    }
84
85    /// Multiply every coefficient by a small integer, then reduce mod `q`.
86    pub fn scalar_mul_assign(&mut self, k: i32) {
87        let q = FIELD_MODULUS as i64;
88        for c in &mut self.coeffs {
89            let wide = *c as i64 * k as i64;
90            *c = reduce_element(wide.rem_euclid(q) as i32);
91        }
92    }
93
94    /// Negacyclic convolution mod `(X^256 + 1)` via schoolbook \(O(n^2)\) (test / reference).
95    #[must_use]
96    pub fn mul_negacyclic(&self, rhs: &Self) -> Self {
97        let mut acc = [0i64; COEFFICIENTS_IN_RING_ELEMENT];
98        let q = FIELD_MODULUS as i64;
99        for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
100            for j in 0..COEFFICIENTS_IN_RING_ELEMENT {
101                let k = i + j;
102                let prod = (self.coeffs[i] as i64).wrapping_mul(rhs.coeffs[j] as i64);
103                if k < COEFFICIENTS_IN_RING_ELEMENT {
104                    acc[k] += prod;
105                } else {
106                    let idx = k - COEFFICIENTS_IN_RING_ELEMENT;
107                    acc[idx] -= prod;
108                }
109            }
110        }
111        let mut out = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
112        for (o, a) in out.iter_mut().zip(acc) {
113            let mut r = a % q;
114            if r < 0 {
115                r += q;
116            }
117            *o = reduce_element(r as i32);
118        }
119        Self { coeffs: out }
120    }
121
122    /// Infinity norm on absolute representatives in \([-q/2, q/2]\)-style range.
123    ///
124    /// Branch-free over coefficient values (ML-DSA portable `infinity_norm_exceeds` model):
125    /// leaking which coefficient exceeds a bound is acceptable on verify paths; the sign of the
126    /// centered representative must not leak via control flow.
127    #[must_use]
128    pub fn infinity_norm(&self) -> i32 {
129        let half = FIELD_MODULUS / 2;
130        let mut m = 0i32;
131        for &c in &self.coeffs {
132            let gt_half = ct_gt_i32(c, half);
133            let centered = i32::conditional_select(&c, &c.wrapping_sub(FIELD_MODULUS), gt_half);
134            let abs = centered_abs_i32(centered);
135            let gt_max = ct_gt_i32(abs, m);
136            m = i32::conditional_select(&m, &abs, gt_max);
137        }
138        m
139    }
140
141    /// Returns `1` iff [`Self::infinity_norm`] is at most `bound` (inclusive).
142    #[must_use]
143    pub fn norm_within_bound(&self, bound: i32) -> Choice {
144        let exceeds = ct_gt_i32(self.infinity_norm(), bound);
145        exceeds ^ Choice::from(1u8)
146    }
147
148    /// Map every coefficient into canonical `[0, q)` via Barrett reduction, then branch-free
149    /// non-negative fixup: `v + ((v >> 31) & q)`.
150    pub fn normalize_mod_q_assign(&mut self) {
151        let q = FIELD_MODULUS;
152        for c in &mut self.coeffs {
153            *c = reduce_element(*c);
154            let sign = *c >> 31;
155            *c += sign & q;
156        }
157    }
158
159    /// Multiply every coefficient by `scalar` (mod `q`) using wide multiply + Barrett reduction.
160    #[must_use]
161    pub fn scalar_mul_by_u32_mod_q(&self, scalar: u32) -> Poly {
162        let q = FIELD_MODULUS as i64;
163        let r = (scalar % FIELD_MODULUS as u32) as i64;
164        let mut out = self.clone();
165        for c in &mut out.coeffs {
166            let v = (*c as i64 * r).rem_euclid(q) as i32;
167            *c = reduce_element(v);
168        }
169        out
170    }
171
172    /// SIMD lane layout (ML-DSA coefficient order).
173    #[must_use]
174    pub fn to_simd(&self) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
175        let mut s = [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT];
176        for (i, lane) in s.iter_mut().enumerate() {
177            let base = i * COEFFICIENTS_IN_SIMD_UNIT;
178            lane.values
179                .copy_from_slice(&self.coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT]);
180        }
181        s
182    }
183
184    fn from_simd(simd: &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT]) -> Self {
185        let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
186        for (i, lane) in simd.iter().enumerate() {
187            let base = i * COEFFICIENTS_IN_SIMD_UNIT;
188            coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&lane.values);
189        }
190        Self { coeffs }
191    }
192
193    /// Map to the NTT/Montgomery SIMD representation used by ML-DSA.
194    #[must_use]
195    pub fn to_ntt(&self) -> NttPoly {
196        let mut simd = self.to_simd();
197        ntt_forward_simd(&mut simd);
198        NttPoly { simd }
199    }
200}
201
202/// Polynomial in the NTT domain (Montgomery-lane representation).
203#[derive(Clone, Debug, PartialEq, Eq, Hash)]
204pub struct NttPoly {
205    pub(crate) simd: [Coefficients; SIMD_UNITS_IN_RING_ELEMENT],
206}
207
208impl NttPoly {
209    /// Zero polynomial in NTT form.
210    #[must_use]
211    pub fn zero() -> Self {
212        Self {
213            simd: [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT],
214        }
215    }
216
217    /// Coefficients in SIMD lane order (Montgomery NTT domain) without inverse transform.
218    #[must_use]
219    pub fn packed_ntt_coefficients(&self) -> [FieldElement; COEFFICIENTS_IN_RING_ELEMENT] {
220        let mut c = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
221        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
222            let base = i * COEFFICIENTS_IN_SIMD_UNIT;
223            c[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&self.simd[i].values);
224        }
225        c
226    }
227
228    /// Borrow the internal SIMD lanes (read-only).
229    #[must_use]
230    pub fn as_simd(&self) -> &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
231        &self.simd
232    }
233
234    /// Mutable SIMD lanes (expert use).
235    pub fn as_simd_mut(&mut self) -> &mut [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
236        &mut self.simd
237    }
238
239    /// Pointwise Montgomery multiply `self *= rhs` in the NTT domain.
240    pub fn pointwise_mul_assign(&mut self, rhs: &Self) {
241        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
242            ntt_multiply_montgomery(&mut self.simd[i], &rhs.simd[i]);
243        }
244    }
245
246    /// Inverse NTT into coefficient domain with canonical reduction.
247    #[must_use]
248    pub fn to_poly(mut self) -> Poly {
249        intt_montgomery(&mut self.simd);
250        reduce_poly_simd(&mut self.simd);
251        Poly::from_simd(&self.simd)
252    }
253
254    /// Add two NTT polynomials lane-wise (no modular reduction between adds; use before INTT as in ML-DSA accumulators).
255    pub fn add_assign(&mut self, rhs: &Self) {
256        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
257            add_coeffs(&mut self.simd[i], &rhs.simd[i]);
258        }
259    }
260
261    /// Subtract lane-wise.
262    pub fn sub_assign(&mut self, rhs: &Self) {
263        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
264            subtract_coeffs(&mut self.simd[i], &rhs.simd[i]);
265        }
266    }
267}
268
269/// Fill SIMD layout from the first 256 coefficients of `buf` (ML-DSA `from_i32_array` order).
270#[must_use]
271pub fn simd_from_i256(
272    buf: &[i32; COEFFICIENTS_IN_RING_ELEMENT],
273) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
274    Poly::from_coeffs(*buf).to_simd()
275}
276
277/// Returns `1` iff every polynomial in `polys` has infinity norm at most `bound`.
278#[must_use]
279pub fn polys_norm_within_bound(polys: &[Poly], bound: i32) -> Choice {
280    let mut acc = Choice::from(1u8);
281    for p in polys {
282        acc &= p.norm_within_bound(bound);
283    }
284    acc
285}
286
// Unit tests: NTT round-trip scaling, NTT multiplication against the schoolbook reference,
// the branch-free infinity norm against a branching reference, and normalization/scalar
// smoke checks. All randomness comes from a seeded LCG, so runs are deterministic.
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 64-bit LCG step (multiply then add 1); returns the high 32 bits of the
    /// new state.
    fn lcg_step(state: &mut u64) -> u32 {
        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (*state >> 32) as u32
    }

    /// Polynomial whose coefficients lie in `[-bound, bound]`, drawn from `state`.
    /// The values are taken via `% width`, so they are slightly biased rather than exactly uniform.
    fn small_poly(state: &mut u64, bound: i32) -> Poly {
        let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        let width = (2 * bound + 1) as u32;
        for c in &mut coeffs {
            let v = (lcg_step(state) % width) as i32;
            *c = v - bound;
        }
        Poly::from_coeffs(coeffs)
    }

    // `to_ntt().to_poly()` is checked to be a constant scaling, not necessarily the identity.
    // The scale is measured by round-tripping the unit polynomial `1`.
    #[test]
    fn ntt_inverse_has_expected_linear_scale() {
        let mut one = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        one[0] = 1;
        let scale = Poly::from_coeffs(one).to_ntt().to_poly().coeffs[0];

        let mut st = 0xC0DEC0DE_u64;
        for _ in 0..16 {
            let p = small_poly(&mut st, 8);
            let back = p.clone().to_ntt().to_poly();
            for (orig, got) in p.coeffs.iter().zip(back.coeffs.iter()) {
                let expected = reduce_element((*orig as i64 * scale as i64) as i32);
                assert_eq!(expected, *got);
            }
        }
    }

    // NTT-domain pointwise multiplication followed by the INTT must produce exactly the same
    // `Poly` value as the schoolbook negacyclic product.
    #[test]
    fn ntt_pointwise_matches_schoolbook_for_small_coeffs() {
        let mut st = 0xDEADBEEF_u64;
        for _ in 0..4 {
            let a = small_poly(&mut st, 8);
            let b = small_poly(&mut st, 8);
            let schoolbook = a.mul_negacyclic(&b);

            let mut ntt = a.to_ntt();
            let b_ntt = b.to_ntt();
            ntt.pointwise_mul_assign(&b_ntt);
            let back = ntt.to_poly();

            assert_eq!(schoolbook, back);
        }
    }

    /// Straightforward branching version of `Poly::infinity_norm`. Values above `q/2` are
    /// shifted by `-q`; every other value (including negatives) is taken as-is.
    fn infinity_norm_branchy_reference(p: &Poly) -> i32 {
        let half = FIELD_MODULUS / 2;
        let mut m = 0i32;
        for &c in &p.coeffs {
            let v = if c > half { c - FIELD_MODULUS } else { c };
            m = m.max(v.abs());
        }
        m
    }

    #[test]
    fn infinity_norm_matches_branchy_reference() {
        let q = FIELD_MODULUS;
        let mut st = 0xA11CE_u64;
        for _ in 0..256 {
            let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
            for c in &mut coeffs {
                // Signed `%` keeps the dividend's sign, so values land in `(-q, q)`.
                *c = (lcg_step(&mut st) as i32) % q;
            }
            let p = Poly::from_coeffs(coeffs);
            assert_eq!(p.infinity_norm(), infinity_norm_branchy_reference(&p));
        }
        // Boundary values around `q/2` and `q - 1`, each paired with its negation.
        for &edge in &[0, 1, q / 2, q / 2 + 1, q - 1] {
            let mut p = Poly::zero();
            p.coeffs[0] = edge;
            p.coeffs[1] = -edge;
            assert_eq!(p.infinity_norm(), infinity_norm_branchy_reference(&p));
        }
    }

    // Normalization must land in `[0, q)` for an input above `q` and for a negative input.
    // Scalar multiplication of a small coefficient must agree with reducing the plain product.
    #[test]
    fn normalize_mod_q_and_scalar_mul_smoke() {
        let mut p = Poly::zero();
        p.coeffs[0] = FIELD_MODULUS + 5;
        p.normalize_mod_q_assign();
        assert!((0..FIELD_MODULUS).contains(&p.coeffs[0]));
        p.coeffs[1] = -3;
        p.normalize_mod_q_assign();
        assert!((0..FIELD_MODULUS).contains(&p.coeffs[1]));
        let scaled = p.scalar_mul_by_u32_mod_q(3);
        assert_eq!(scaled.coeffs[0], reduce_element(p.coeffs[0] * 3));
    }
}