Skip to main content

lib_q_ring/
poly.rs

1//! Coefficient (`Poly`) vs NTT (`NttPoly`) newtypes.
2
3use subtle::{
4    Choice,
5    ConditionallySelectable,
6    ConstantTimeGreater,
7};
8use zeroize::{
9    Zeroize,
10    ZeroizeOnDrop,
11};
12
13use crate::coeff::{
14    COEFFICIENTS_IN_SIMD_UNIT,
15    Coefficients,
16    FieldElement,
17    SIMD_UNITS_IN_RING_ELEMENT,
18};
19use crate::constants::{
20    COEFFICIENTS_IN_RING_ELEMENT,
21    FIELD_MODULUS,
22};
23use crate::field::{
24    add_coeffs,
25    reduce_element,
26    reduce_poly_simd,
27    subtract_coeffs,
28};
29use crate::ntt::{
30    intt_montgomery,
31    ntt_forward_simd,
32    ntt_multiply_montgomery,
33};
34
35#[inline]
36fn ct_gt_i32(a: i32, b: i32) -> Choice {
37    let flip = 1u32 << 31;
38    let a_u = (a as u32) ^ flip;
39    let b_u = (b as u32) ^ flip;
40    a_u.ct_gt(&b_u)
41}
42
/// Branch-free absolute value of a (centered) coefficient.
///
/// `mask` is all ones when `coefficient` is negative and zero otherwise. The masked
/// `2 * coefficient` is therefore subtracted only for negative inputs, turning `c` into `-c`.
#[inline]
fn centered_abs_i32(coefficient: i32) -> i32 {
    let mask = coefficient >> 31;
    let doubled = coefficient << 1;
    coefficient - (doubled & mask)
}
48
49/// Polynomial in the time (coefficient) domain, canonical representatives mod `q`.
50#[derive(Clone, Debug, Eq, PartialEq, Hash, Zeroize, ZeroizeOnDrop)]
51pub struct Poly {
52    /// Coefficients `c[0] + c[1] X + … + c[255] X^{255}`.
53    pub coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT],
54}
55
56impl Poly {
57    /// Zero polynomial.
58    #[must_use]
59    pub const fn zero() -> Self {
60        Self {
61            coeffs: [0; COEFFICIENTS_IN_RING_ELEMENT],
62        }
63    }
64
65    /// Construct from canonical coefficients (already reduced mod `q` is recommended).
66    #[must_use]
67    pub const fn from_coeffs(coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT]) -> Self {
68        Self { coeffs }
69    }
70
71    /// Coefficient-wise addition mod `q` (Barrett reduction).
72    pub fn add_assign(&mut self, rhs: &Self) {
73        for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
74            self.coeffs[i] = reduce_element(self.coeffs[i] + rhs.coeffs[i]);
75        }
76    }
77
78    /// Coefficient-wise subtraction mod `q`.
79    pub fn sub_assign(&mut self, rhs: &Self) {
80        for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
81            self.coeffs[i] = reduce_element(self.coeffs[i] - rhs.coeffs[i]);
82        }
83    }
84
85    /// Multiply every coefficient by a small integer, then reduce mod `q`.
86    pub fn scalar_mul_assign(&mut self, k: i32) {
87        for c in &mut self.coeffs {
88            *c = reduce_element((*c as i64 * k as i64) as i32);
89        }
90    }
91
92    /// Negacyclic convolution mod `(X^256 + 1)` via schoolbook \(O(n^2)\) (test / reference).
93    #[must_use]
94    pub fn mul_negacyclic(&self, rhs: &Self) -> Self {
95        let mut acc = [0i64; COEFFICIENTS_IN_RING_ELEMENT];
96        let q = FIELD_MODULUS as i64;
97        for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
98            for j in 0..COEFFICIENTS_IN_RING_ELEMENT {
99                let k = i + j;
100                let prod = (self.coeffs[i] as i64).wrapping_mul(rhs.coeffs[j] as i64);
101                if k < COEFFICIENTS_IN_RING_ELEMENT {
102                    acc[k] += prod;
103                } else {
104                    let idx = k - COEFFICIENTS_IN_RING_ELEMENT;
105                    acc[idx] -= prod;
106                }
107            }
108        }
109        let mut out = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
110        for (o, a) in out.iter_mut().zip(acc) {
111            let mut r = a % q;
112            if r < 0 {
113                r += q;
114            }
115            *o = reduce_element(r as i32);
116        }
117        Self { coeffs: out }
118    }
119
120    /// Infinity norm on absolute representatives in \([-q/2, q/2]\)-style range.
121    ///
122    /// Branch-free over coefficient values (ML-DSA portable `infinity_norm_exceeds` model):
123    /// leaking which coefficient exceeds a bound is acceptable on verify paths; the sign of the
124    /// centered representative must not leak via control flow.
125    #[must_use]
126    pub fn infinity_norm(&self) -> i32 {
127        let half = FIELD_MODULUS / 2;
128        let mut m = 0i32;
129        for &c in &self.coeffs {
130            let gt_half = ct_gt_i32(c, half);
131            let centered = i32::conditional_select(&c, &c.wrapping_sub(FIELD_MODULUS), gt_half);
132            let abs = centered_abs_i32(centered);
133            let gt_max = ct_gt_i32(abs, m);
134            m = i32::conditional_select(&m, &abs, gt_max);
135        }
136        m
137    }
138
139    /// Returns `1` iff [`Self::infinity_norm`] is at most `bound` (inclusive).
140    #[must_use]
141    pub fn norm_within_bound(&self, bound: i32) -> Choice {
142        let exceeds = ct_gt_i32(self.infinity_norm(), bound);
143        exceeds ^ Choice::from(1u8)
144    }
145
146    /// Map every coefficient into canonical `[0, q)` via Barrett reduction, then branch-free
147    /// non-negative fixup: `v + ((v >> 31) & q)`.
148    pub fn normalize_mod_q_assign(&mut self) {
149        let q = FIELD_MODULUS;
150        for c in &mut self.coeffs {
151            *c = reduce_element(*c);
152            let sign = *c >> 31;
153            *c += sign & q;
154        }
155    }
156
157    /// Multiply every coefficient by `scalar` (mod `q`) using wide multiply + Barrett reduction.
158    #[must_use]
159    pub fn scalar_mul_by_u32_mod_q(&self, scalar: u32) -> Poly {
160        let q = FIELD_MODULUS as i64;
161        let r = (scalar % FIELD_MODULUS as u32) as i64;
162        let mut out = self.clone();
163        for c in &mut out.coeffs {
164            let v = (*c as i64 * r).rem_euclid(q) as i32;
165            *c = reduce_element(v);
166        }
167        out
168    }
169
170    /// SIMD lane layout (ML-DSA coefficient order).
171    #[must_use]
172    pub fn to_simd(&self) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
173        let mut s = [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT];
174        for (i, lane) in s.iter_mut().enumerate() {
175            let base = i * COEFFICIENTS_IN_SIMD_UNIT;
176            lane.values
177                .copy_from_slice(&self.coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT]);
178        }
179        s
180    }
181
182    fn from_simd(simd: &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT]) -> Self {
183        let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
184        for (i, lane) in simd.iter().enumerate() {
185            let base = i * COEFFICIENTS_IN_SIMD_UNIT;
186            coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&lane.values);
187        }
188        Self { coeffs }
189    }
190
191    /// Map to the NTT/Montgomery SIMD representation used by ML-DSA.
192    #[must_use]
193    pub fn to_ntt(&self) -> NttPoly {
194        let mut simd = self.to_simd();
195        ntt_forward_simd(&mut simd);
196        NttPoly { simd }
197    }
198}
199
200/// Polynomial in the NTT domain (Montgomery-lane representation).
201#[derive(Clone, Debug, PartialEq, Eq, Hash)]
202pub struct NttPoly {
203    pub(crate) simd: [Coefficients; SIMD_UNITS_IN_RING_ELEMENT],
204}
205
206impl NttPoly {
207    /// Zero polynomial in NTT form.
208    #[must_use]
209    pub fn zero() -> Self {
210        Self {
211            simd: [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT],
212        }
213    }
214
215    /// Coefficients in SIMD lane order (Montgomery NTT domain) without inverse transform.
216    #[must_use]
217    pub fn packed_ntt_coefficients(&self) -> [FieldElement; COEFFICIENTS_IN_RING_ELEMENT] {
218        let mut c = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
219        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
220            let base = i * COEFFICIENTS_IN_SIMD_UNIT;
221            c[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&self.simd[i].values);
222        }
223        c
224    }
225
226    /// Borrow the internal SIMD lanes (read-only).
227    #[must_use]
228    pub fn as_simd(&self) -> &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
229        &self.simd
230    }
231
232    /// Mutable SIMD lanes (expert use).
233    pub fn as_simd_mut(&mut self) -> &mut [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
234        &mut self.simd
235    }
236
237    /// Pointwise Montgomery multiply `self *= rhs` in the NTT domain.
238    pub fn pointwise_mul_assign(&mut self, rhs: &Self) {
239        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
240            ntt_multiply_montgomery(&mut self.simd[i], &rhs.simd[i]);
241        }
242    }
243
244    /// Inverse NTT into coefficient domain with canonical reduction.
245    #[must_use]
246    pub fn to_poly(mut self) -> Poly {
247        intt_montgomery(&mut self.simd);
248        reduce_poly_simd(&mut self.simd);
249        Poly::from_simd(&self.simd)
250    }
251
252    /// Add two NTT polynomials lane-wise (no modular reduction between adds; use before INTT as in ML-DSA accumulators).
253    pub fn add_assign(&mut self, rhs: &Self) {
254        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
255            add_coeffs(&mut self.simd[i], &rhs.simd[i]);
256        }
257    }
258
259    /// Subtract lane-wise.
260    pub fn sub_assign(&mut self, rhs: &Self) {
261        for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
262            subtract_coeffs(&mut self.simd[i], &rhs.simd[i]);
263        }
264    }
265}
266
267/// Fill SIMD layout from the first 256 coefficients of `buf` (ML-DSA `from_i32_array` order).
268#[must_use]
269pub fn simd_from_i256(
270    buf: &[i32; COEFFICIENTS_IN_RING_ELEMENT],
271) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
272    Poly::from_coeffs(*buf).to_simd()
273}
274
275/// Returns `1` iff every polynomial in `polys` has infinity norm at most `bound`.
276#[must_use]
277pub fn polys_norm_within_bound(polys: &[Poly], bound: i32) -> Choice {
278    let mut acc = Choice::from(1u8);
279    for p in polys {
280        acc &= p.norm_within_bound(bound);
281    }
282    acc
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 64-bit LCG step (Knuth's MMIX multiplier, increment 1).
    /// Returns the high 32 bits of the new state.
    fn lcg_step(state: &mut u64) -> u32 {
        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (*state >> 32) as u32
    }

    /// Polynomial with coefficients drawn from `[-bound, bound]`.
    fn small_poly(state: &mut u64, bound: i32) -> Poly {
        let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        let width = (2 * bound + 1) as u32;
        for c in &mut coeffs {
            let v = (lcg_step(state) % width) as i32;
            *c = v - bound;
        }
        Poly::from_coeffs(coeffs)
    }

    #[test]
    fn ntt_inverse_has_expected_linear_scale() {
        let mut one = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        one[0] = 1;
        // NTT -> INTT of the unit polynomial reveals the overall scale factor.
        let scale = Poly::from_coeffs(one).to_ntt().to_poly().coeffs[0];

        let mut st = 0xC0DEC0DE_u64;
        for _ in 0..16 {
            let p = small_poly(&mut st, 8);
            // `to_ntt` borrows, so no clone is needed to keep `p` for the comparison.
            let back = p.to_ntt().to_poly();
            for (orig, got) in p.coeffs.iter().zip(back.coeffs.iter()) {
                let expected = reduce_element((*orig as i64 * scale as i64) as i32);
                assert_eq!(expected, *got);
            }
        }
    }

    #[test]
    fn ntt_pointwise_matches_schoolbook_for_small_coeffs() {
        let mut st = 0xDEADBEEF_u64;
        for _ in 0..4 {
            let a = small_poly(&mut st, 8);
            let b = small_poly(&mut st, 8);
            let schoolbook = a.mul_negacyclic(&b);

            let mut ntt = a.to_ntt();
            let b_ntt = b.to_ntt();
            ntt.pointwise_mul_assign(&b_ntt);
            let back = ntt.to_poly();

            assert_eq!(schoolbook, back);
        }
    }

    /// Straightforward (branching) model of `Poly::infinity_norm`.
    fn infinity_norm_branchy_reference(p: &Poly) -> i32 {
        let half = FIELD_MODULUS / 2;
        let mut m = 0i32;
        for &c in &p.coeffs {
            let v = if c > half { c - FIELD_MODULUS } else { c };
            m = m.max(v.abs());
        }
        m
    }

    #[test]
    fn infinity_norm_matches_branchy_reference() {
        let q = FIELD_MODULUS;
        let mut st = 0xA11CE_u64;
        for _ in 0..256 {
            let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
            for c in &mut coeffs {
                // Values in (-q, q): exercises both signs.
                *c = (lcg_step(&mut st) as i32) % q;
            }
            let p = Poly::from_coeffs(coeffs);
            assert_eq!(p.infinity_norm(), infinity_norm_branchy_reference(&p));
        }
        for &edge in &[0, 1, q / 2, q / 2 + 1, q - 1] {
            let mut p = Poly::zero();
            p.coeffs[0] = edge;
            p.coeffs[1] = -edge;
            assert_eq!(p.infinity_norm(), infinity_norm_branchy_reference(&p));
        }
    }

    #[test]
    fn normalize_mod_q_and_scalar_mul_smoke() {
        let mut p = Poly::zero();
        p.coeffs[0] = FIELD_MODULUS + 5;
        p.normalize_mod_q_assign();
        assert!((0..FIELD_MODULUS).contains(&p.coeffs[0]));
        // The unique representative of q + 5 in [0, q) is 5.
        assert_eq!(p.coeffs[0], 5);
        p.coeffs[1] = -3;
        p.normalize_mod_q_assign();
        assert!((0..FIELD_MODULUS).contains(&p.coeffs[1]));
        // The unique representative of -3 in [0, q) is q - 3; coefficient 0 stays canonical.
        assert_eq!(p.coeffs[1], FIELD_MODULUS - 3);
        assert_eq!(p.coeffs[0], 5);
        let scaled = p.scalar_mul_by_u32_mod_q(3);
        assert_eq!(scaled.coeffs[0], reduce_element(p.coeffs[0] * 3));
    }
}
380}