Skip to main content

vaea_ntt/
poly.rs

1// Copyright (C) 2024-2026 Vaea SAS
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//
4// This file is part of VaeaNTT.
5//
6// VaeaNTT is free software: you can redistribute it and/or modify it under
7// the terms of the GNU Affero General Public License as published by the
8// Free Software Foundation, either version 3 of the License, or (at your
9// option) any later version.
10//
11// VaeaNTT is distributed in the hope that it will be useful, but WITHOUT
12// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
14// License for more details.
15//
16// You should have received a copy of the GNU Affero General Public License
17// along with VaeaNTT. If not, see <https://www.gnu.org/licenses/>.
18
19
20//! # Polynomial over Z_q\[X\]/(X^N + 1)
21//!
22//! Polynomials are stored in coefficient domain by default.
23//! Use [`Poly64::forward_ntt`] / [`Poly64::inverse_ntt`] to switch between
24//! coefficient and NTT (evaluation) domains.
25//!
26//! In NTT domain, multiplication is pointwise O(N) instead of O(N²).
27
28use crate::ntt64::arith::{mod_mul_barrett, Ntt64Arith};
29use crate::ntt64::context::{ntt_forward, ntt_inverse, Ntt64Context};
30use alloc::vec;
31use alloc::vec::Vec;
32#[cfg(feature = "rand")]
33use rand::Rng;
34#[cfg(feature = "rand")]
35use rand_distr::{Distribution, Normal};
36
37// ---------------------------------------------------------------------------
38// Poly64 — polynomial in Z_q\[X\]/(X^N+1)
39// ---------------------------------------------------------------------------
40
41/// Polynomial in R_q = Z_q\[X\]/(X^N + 1) with 64-bit coefficients.
42///
43/// Tracks whether the data is in coefficient domain or NTT (evaluation) domain.
44/// In NTT domain, multiplication is pointwise (O(N) instead of O(N²)).
/// An element of the negacyclic ring R_q = Z_q\[X\]/(X^N + 1), stored as
/// `N` unsigned 64-bit residues.
///
/// The same buffer holds either coefficients or NTT evaluations; the
/// `is_ntt` flag records which representation is current so that the
/// arithmetic methods can reject mixed-domain operations.
#[derive(Clone, Debug)]
pub struct Poly64 {
    /// Residues in `[0, q)`: coefficients when `is_ntt == false`,
    /// evaluations when `is_ntt == true`.
    pub data: Vec<u64>,
    /// Domain tag: `true` means `data` holds NTT (evaluation) values.
    pub is_ntt: bool,
}
52
53impl Poly64 {
54    // -------------------------------------------------------------------
55    // Constructors
56    // -------------------------------------------------------------------
57
58    /// Creates the zero polynomial with N coefficients.
59    #[inline]
60    pub fn new_zero(n: usize) -> Self {
61        Self {
62            data: vec![0u64; n],
63            is_ntt: false,
64        }
65    }
66
67    /// Creates a polynomial with uniform random coefficients in [0, q).
68    ///
69    /// Requires the `rand` feature.
70    #[cfg(feature = "rand")]
71    pub fn new_random(n: usize, arith: &Ntt64Arith) -> Self {
72        let mut rng = rand::thread_rng();
73        let q = arith.modulus;
74        let data: Vec<u64> = (0..n).map(|_| rng.gen_range(0..q)).collect();
75        Self {
76            data,
77            is_ntt: false,
78        }
79    }
80
81    /// Creates a ternary polynomial with coefficients in {0, 1, q−1}.
82    ///
83    /// q−1 represents −1 mod q. The distribution is uniform over {−1, 0, 1}.
84    /// Used for secret keys in CKKS/BFV.
85    ///
86    /// Requires the `rand` feature.
87    #[cfg(feature = "rand")]
88    pub fn new_ternary(n: usize, arith: &Ntt64Arith) -> Self {
89        let mut rng = rand::thread_rng();
90        let q = arith.modulus;
91        let data: Vec<u64> = (0..n)
92            .map(|_| match rng.gen_range(0u32..3) {
93                0 => 0,
94                1 => 1,
95                _ => q - 1,
96            })
97            .collect();
98        Self {
99            data,
100            is_ntt: false,
101        }
102    }
103
104    /// Creates a polynomial with discrete Gaussian noise.
105    ///
106    /// Each coefficient is drawn from N(0, σ²), rounded to the nearest integer,
107    /// then reduced mod q. Negative values are represented as q + value.
108    ///
109    /// Requires the `rand` feature.
110    #[cfg(feature = "rand")]
111    pub fn new_gaussian(n: usize, sigma: f64, arith: &Ntt64Arith) -> Self {
112        let mut rng = rand::thread_rng();
113        let q = arith.modulus;
114        let normal = Normal::new(0.0, sigma).expect("sigma must be > 0");
115        let data: Vec<u64> = (0..n)
116            .map(|_| {
117                let sample: f64 = normal.sample(&mut rng);
118                let rounded = sample.round() as i64;
119                if rounded >= 0 {
120                    (rounded as u64) % q
121                } else {
122                    let abs_val = (-rounded) as u64;
123                    let r = abs_val % q;
124                    if r == 0 {
125                        0
126                    } else {
127                        q - r
128                    }
129                }
130            })
131            .collect();
132        Self {
133            data,
134            is_ntt: false,
135        }
136    }
137
138    // -------------------------------------------------------------------
139    // NTT transforms
140    // -------------------------------------------------------------------
141
142    /// Converts from coefficient domain to NTT domain (in-place).
143    ///
144    /// # Panics
145    /// Panics if the polynomial is already in NTT domain.
146    pub fn forward_ntt(&mut self, ntt_ctx: &Ntt64Context) {
147        assert!(!self.is_ntt, "polynomial is already in NTT domain");
148        ntt_forward(&mut self.data, ntt_ctx);
149        self.is_ntt = true;
150    }
151
152    /// Converts from NTT domain to coefficient domain (in-place).
153    ///
154    /// # Panics
155    /// Panics if the polynomial is not in NTT domain.
156    pub fn inverse_ntt(&mut self, ntt_ctx: &Ntt64Context) {
157        assert!(self.is_ntt, "polynomial is not in NTT domain");
158        ntt_inverse(&mut self.data, ntt_ctx);
159        self.is_ntt = false;
160    }
161
162    // -------------------------------------------------------------------
163    // Arithmetic
164    // -------------------------------------------------------------------
165
166    /// Pointwise addition: `self += other (mod q)`.
167    ///
168    /// Both polynomials must be in the same domain (NTT or coefficient).
169    ///
170    /// # Panics
171    /// Panics if domains or sizes don't match.
172    pub fn add_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
173        assert_eq!(
174            self.is_ntt, other.is_ntt,
175            "polynomials must be in the same domain"
176        );
177        assert_eq!(
178            self.data.len(),
179            other.data.len(),
180            "polynomials must have the same size"
181        );
182        let q = arith.modulus;
183        for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
184            let sum = *a + b;
185            // Branchless via overflowing_sub
186            let (sub, borrow) = sum.overflowing_sub(q);
187            *a = if borrow { sum } else { sub };
188        }
189    }
190
191    /// Pointwise subtraction: `self -= other (mod q)`.
192    ///
193    /// Both polynomials must be in the same domain.
194    ///
195    /// # Panics
196    /// Panics if domains or sizes don't match.
197    pub fn sub_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
198        assert_eq!(
199            self.is_ntt, other.is_ntt,
200            "polynomials must be in the same domain"
201        );
202        assert_eq!(
203            self.data.len(),
204            other.data.len(),
205            "polynomials must have the same size"
206        );
207        let q = arith.modulus;
208        for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
209            let (sub, borrow) = (*a).overflowing_sub(b);
210            *a = if borrow { sub.wrapping_add(q) } else { sub };
211        }
212    }
213
214    /// Pointwise multiplication: `self *= other (mod q)`.
215    ///
216    /// **Both polynomials must be in NTT domain** so that pointwise multiplication
217    /// corresponds to negacyclic convolution in coefficient domain.
218    ///
219    /// # Panics
220    /// Panics if polynomials are not in NTT domain or have different sizes.
221    pub fn mul_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
222        assert!(
223            self.is_ntt && other.is_ntt,
224            "both polynomials must be in NTT domain for multiplication"
225        );
226        assert_eq!(
227            self.data.len(),
228            other.data.len(),
229            "polynomials must have the same size"
230        );
231        for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
232            *a = mod_mul_barrett(*a, b, arith);
233        }
234    }
235
236    /// Scalar multiplication: `self *= scalar (mod q)`.
237    pub fn scalar_mul(&mut self, scalar: u64, arith: &Ntt64Arith) {
238        for a in self.data.iter_mut() {
239            *a = mod_mul_barrett(*a, scalar, arith);
240        }
241    }
242
243    /// Negation: `self = −self (mod q)`, i.e. `self[i] = q − self[i]`.
244    pub fn negate(&mut self, arith: &Ntt64Arith) {
245        let q = arith.modulus;
246        for a in self.data.iter_mut() {
247            // Branchless: mask = (a != 0) as u64 * u64::MAX, then q & mask - *a ...
248            // but the branch here is on public data (coefficients), not secrets,
249            // and the branch predictor handles it well. Keep it simple.
250            *a = if *a == 0 { 0 } else { q - *a };
251        }
252    }
253
254    // -------------------------------------------------------------------
255    // Utilities
256    // -------------------------------------------------------------------
257
258    /// Number of coefficients (= max degree + 1).
259    #[inline]
260    pub fn len(&self) -> usize {
261        self.data.len()
262    }
263
264    /// Whether the polynomial has zero length.
265    #[inline]
266    pub fn is_empty(&self) -> bool {
267        self.data.is_empty()
268    }
269}
270
271// ---------------------------------------------------------------------------
272// Naive polynomial multiplication (test-only)
273// ---------------------------------------------------------------------------
274
275/// Naive polynomial multiplication in Z_q\[X\]/(X^N+1).
276///
277/// O(N²) complexity. Used only in tests to verify NTT-based multiplication.
278#[cfg(test)]
279fn naive_poly_mul(a: &[u64], b: &[u64], q: u64) -> Vec<u64> {
280    let n = a.len();
281    assert_eq!(n, b.len());
282    let mut result = vec![0u64; n];
283
284    for i in 0..n {
285        for j in 0..n {
286            let prod = (a[i] as u128) * (b[j] as u128);
287            let idx = i + j;
288            if idx < n {
289                let val = (result[idx] as u128 + prod) % (q as u128);
290                result[idx] = val as u64;
291            } else {
292                let wrapped_idx = idx - n;
293                let val = (result[wrapped_idx] as u128 + (q as u128) - (prod % (q as u128)))
294                    % (q as u128);
295                result[wrapped_idx] = val as u64;
296            }
297        }
298    }
299    result
300}
301
302// ---------------------------------------------------------------------------
303// Tests
304// ---------------------------------------------------------------------------
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ntt64::arith::Ntt64Arith;
    use crate::ntt64::context::Ntt64Context;

    // Small NTT-friendly prime for N=256: q = 7681 = 15·512+1, so 2N divides
    // q−1 as required for the negacyclic NTT.
    const TEST_Q: u64 = 7681;
    const TEST_N: usize = 256;

    fn test_arith() -> Ntt64Arith {
        Ntt64Arith::new(TEST_Q)
    }

    fn test_ntt_ctx() -> Ntt64Context {
        Ntt64Context::new(TEST_N, test_arith())
    }

    // ---------------------------------------------------------------
    // Deterministic tests (build and run without the `rand` feature)
    // ---------------------------------------------------------------

    #[test]
    fn test_poly_add_wraps_at_modulus() {
        let arith = test_arith();
        let mut a = Poly64::new_zero(4);
        a.data.copy_from_slice(&[TEST_Q - 1, TEST_Q - 1, 0, 5]);
        let mut b = Poly64::new_zero(4);
        b.data.copy_from_slice(&[1, TEST_Q - 1, 0, 7]);

        a.add_assign(&b, &arith);
        assert_eq!(a.data, [0, TEST_Q - 2, 0, 12]);
    }

    #[test]
    fn test_poly_sub_wraps_at_zero() {
        let arith = test_arith();
        let mut a = Poly64::new_zero(3);
        a.data.copy_from_slice(&[0, 3, 5]);
        let mut b = Poly64::new_zero(3);
        b.data.copy_from_slice(&[1, 3, 2]);

        a.sub_assign(&b, &arith);
        assert_eq!(a.data, [TEST_Q - 1, 0, 3]);
    }

    #[test]
    fn test_poly_negate_fixed_values() {
        let arith = test_arith();
        let mut a = Poly64::new_zero(3);
        a.data.copy_from_slice(&[0, 1, TEST_Q - 1]);

        a.negate(&arith);
        assert_eq!(a.data, [0, TEST_Q - 1, 1]);
    }

    #[test]
    fn test_poly_mul_ntt() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();

        let mut a = Poly64::new_zero(TEST_N);
        a.data[0] = 1;
        a.data[1] = 1;

        let mut b = Poly64::new_zero(TEST_N);
        b.data[0] = 1;
        b.data[2] = 1;

        let expected = naive_poly_mul(&a.data, &b.data, TEST_Q);

        a.forward_ntt(&ntt_ctx);
        b.forward_ntt(&ntt_ctx);
        a.mul_assign(&b, &arith);
        a.inverse_ntt(&ntt_ctx);

        for i in 0..TEST_N {
            assert_eq!(a.data[i], expected[i], "NTT mul != naive at index {i}");
        }
    }

    #[test]
    fn test_new_zero() {
        let poly = Poly64::new_zero(64);
        assert_eq!(poly.len(), 64);
        assert!(!poly.is_ntt);
        for &c in &poly.data {
            assert_eq!(c, 0);
        }
    }

    #[test]
    #[should_panic(expected = "already in NTT domain")]
    fn test_double_forward_ntt_panics() {
        let ntt_ctx = test_ntt_ctx();
        let mut poly = Poly64::new_zero(TEST_N);
        poly.forward_ntt(&ntt_ctx);
        poly.forward_ntt(&ntt_ctx);
    }

    #[test]
    #[should_panic(expected = "not in NTT domain")]
    fn test_inverse_ntt_without_forward_panics() {
        let ntt_ctx = test_ntt_ctx();
        let mut poly = Poly64::new_zero(TEST_N);
        poly.inverse_ntt(&ntt_ctx);
    }

    // ---------------------------------------------------------------
    // Randomized tests (the samplers only exist with the `rand` feature)
    // ---------------------------------------------------------------

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_add_sub() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);
        let b = Poly64::new_random(TEST_N, &arith);

        let mut c = a.clone();
        c.add_assign(&b, &arith);
        c.sub_assign(&b, &arith);

        for i in 0..TEST_N {
            assert_eq!(c.data[i], a.data[i], "add/sub roundtrip fails at index {i}");
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_add_commutative() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);
        let b = Poly64::new_random(TEST_N, &arith);

        let mut ab = a.clone();
        ab.add_assign(&b, &arith);

        let mut ba = b.clone();
        ba.add_assign(&a, &arith);

        for i in 0..TEST_N {
            assert_eq!(ab.data[i], ba.data[i], "add not commutative at index {i}");
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_negate() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);

        let mut neg_a = a.clone();
        neg_a.negate(&arith);

        let mut sum = a.clone();
        sum.add_assign(&neg_a, &arith);

        for i in 0..TEST_N {
            assert_eq!(sum.data[i], 0, "a + (-a) != 0 at index {i}");
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_scalar_mul() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);

        let mut doubled = a.clone();
        doubled.scalar_mul(2, &arith);

        let mut sum = a.clone();
        sum.add_assign(&a, &arith);

        for i in 0..TEST_N {
            assert_eq!(doubled.data[i], sum.data[i], "2*a != a+a at index {i}");
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_mul_random_ntt() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();

        let a_orig = Poly64::new_random(TEST_N, &arith);
        let b_orig = Poly64::new_random(TEST_N, &arith);

        let expected = naive_poly_mul(&a_orig.data, &b_orig.data, TEST_Q);

        let mut a = a_orig.clone();
        let mut b = b_orig.clone();
        a.forward_ntt(&ntt_ctx);
        b.forward_ntt(&ntt_ctx);
        a.mul_assign(&b, &arith);
        a.inverse_ntt(&ntt_ctx);

        for i in 0..TEST_N {
            assert_eq!(a.data[i], expected[i], "NTT mul != naive at index {i}");
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_ternary_distribution() {
        let arith = test_arith();
        let poly = Poly64::new_ternary(1024, &arith);

        for (i, &coeff) in poly.data.iter().enumerate() {
            assert!(
                coeff == 0 || coeff == 1 || coeff == TEST_Q - 1,
                "invalid ternary coefficient at index {i}: {coeff}"
            );
        }

        let count_zero = poly.data.iter().filter(|&&c| c == 0).count();
        let count_one = poly.data.iter().filter(|&&c| c == 1).count();
        let count_neg = poly.data.iter().filter(|&&c| c == TEST_Q - 1).count();

        assert!(count_zero > 0);
        assert!(count_one > 0);
        assert!(count_neg > 0);
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_gaussian_distribution() {
        let arith = test_arith();
        let sigma = 3.2;
        let n = 8192;
        let poly = Poly64::new_gaussian(n, sigma, &arith);

        // Map residues back to the centered range (-q/2, q/2].
        let q = TEST_Q as f64;
        let half_q = q / 2.0;
        let centered: Vec<f64> = poly
            .data
            .iter()
            .map(|&c| {
                let c = c as f64;
                if c > half_q {
                    c - q
                } else {
                    c
                }
            })
            .collect();

        let mean = centered.iter().sum::<f64>() / n as f64;
        assert!(mean.abs() < 0.5, "mean too far from 0: {mean}");

        let variance = centered.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        let std_dev = variance.sqrt();
        assert!(
            (std_dev - sigma).abs() < 1.0,
            "stddev too far from {sigma}: {std_dev}"
        );
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_ntt_roundtrip() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();
        let original = Poly64::new_random(TEST_N, &arith);

        let mut poly = original.clone();
        poly.forward_ntt(&ntt_ctx);
        assert!(poly.is_ntt);
        poly.inverse_ntt(&ntt_ctx);
        assert!(!poly.is_ntt);

        for i in 0..TEST_N {
            assert_eq!(
                poly.data[i], original.data[i],
                "NTT roundtrip fails at index {i}"
            );
        }
    }
}