Skip to main content

vaea_ntt/
poly.rs

1// Copyright (C) 2024-2026 Vaea SAS
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//
4// This file is part of VaeaNTT.
5//
6// VaeaNTT is free software: you can redistribute it and/or modify it under
7// the terms of the GNU Affero General Public License as published by the
8// Free Software Foundation, either version 3 of the License, or (at your
9// option) any later version.
10//
11// VaeaNTT is distributed in the hope that it will be useful, but WITHOUT
12// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
14// License for more details.
15//
16// You should have received a copy of the GNU Affero General Public License
17// along with VaeaNTT. If not, see <https://www.gnu.org/licenses/>.
18
19//! # Polynomial over Z_q\[X\]/(X^N + 1)
20//!
21//! Polynomials are stored in coefficient domain by default.
22//! Use [`Poly64::forward_ntt`] / [`Poly64::inverse_ntt`] to switch between
23//! coefficient and NTT (evaluation) domains.
24//!
25//! In NTT domain, multiplication is pointwise O(N) instead of O(N²).
26
27use crate::ntt64::arith::{mod_mul_barrett, Ntt64Arith};
28use crate::ntt64::context::{ntt_forward, ntt_inverse, Ntt64Context};
29use alloc::vec;
30use alloc::vec::Vec;
31#[cfg(feature = "rand")]
32use rand::Rng;
33#[cfg(feature = "rand")]
34use rand_distr::{Distribution, Normal};
35
36// ---------------------------------------------------------------------------
37// Poly64 — polynomial in Z_q\[X\]/(X^N+1)
38// ---------------------------------------------------------------------------
39
40/// Polynomial in R_q = Z_q\[X\]/(X^N + 1) with 64-bit coefficients.
41///
42/// Tracks whether the data is in coefficient domain or NTT (evaluation) domain.
43/// In NTT domain, multiplication is pointwise (O(N) instead of O(N²)).
44#[derive(Clone, Debug)]
45pub struct Poly64 {
46    /// Coefficients (coefficient domain) or evaluations (NTT domain).
47    pub data: Vec<u64>,
48    /// `true` if the polynomial is in NTT (evaluation) domain.
49    pub is_ntt: bool,
50}
51
52impl Poly64 {
53    // -------------------------------------------------------------------
54    // Constructors
55    // -------------------------------------------------------------------
56
57    /// Creates the zero polynomial with N coefficients.
58    #[inline]
59    pub fn new_zero(n: usize) -> Self {
60        Self {
61            data: vec![0u64; n],
62            is_ntt: false,
63        }
64    }
65
66    /// Creates a polynomial with uniform random coefficients in [0, q).
67    ///
68    /// Requires the `rand` feature.
69    #[cfg(feature = "rand")]
70    pub fn new_random(n: usize, arith: &Ntt64Arith) -> Self {
71        let mut rng = rand::thread_rng();
72        let q = arith.modulus;
73        let data: Vec<u64> = (0..n).map(|_| rng.gen_range(0..q)).collect();
74        Self {
75            data,
76            is_ntt: false,
77        }
78    }
79
80    /// Creates a ternary polynomial with coefficients in {0, 1, q−1}.
81    ///
82    /// q−1 represents −1 mod q. The distribution is uniform over {−1, 0, 1}.
83    /// Used for secret keys in CKKS/BFV.
84    ///
85    /// Requires the `rand` feature.
86    #[cfg(feature = "rand")]
87    pub fn new_ternary(n: usize, arith: &Ntt64Arith) -> Self {
88        let mut rng = rand::thread_rng();
89        let q = arith.modulus;
90        let data: Vec<u64> = (0..n)
91            .map(|_| match rng.gen_range(0u32..3) {
92                0 => 0,
93                1 => 1,
94                _ => q - 1,
95            })
96            .collect();
97        Self {
98            data,
99            is_ntt: false,
100        }
101    }
102
103    /// Creates a polynomial with discrete Gaussian noise.
104    ///
105    /// Each coefficient is drawn from N(0, σ²), rounded to the nearest integer,
106    /// then reduced mod q. Negative values are represented as q + value.
107    ///
108    /// Requires the `rand` feature.
109    #[cfg(feature = "rand")]
110    pub fn new_gaussian(n: usize, sigma: f64, arith: &Ntt64Arith) -> Self {
111        let mut rng = rand::thread_rng();
112        let q = arith.modulus;
113        let normal = Normal::new(0.0, sigma).expect("sigma must be > 0");
114        let data: Vec<u64> = (0..n)
115            .map(|_| {
116                let sample: f64 = normal.sample(&mut rng);
117                let rounded = sample.round() as i64;
118                if rounded >= 0 {
119                    (rounded as u64) % q
120                } else {
121                    let abs_val = (-rounded) as u64;
122                    let r = abs_val % q;
123                    if r == 0 {
124                        0
125                    } else {
126                        q - r
127                    }
128                }
129            })
130            .collect();
131        Self {
132            data,
133            is_ntt: false,
134        }
135    }
136
137    // -------------------------------------------------------------------
138    // NTT transforms
139    // -------------------------------------------------------------------
140
141    /// Converts from coefficient domain to NTT domain (in-place).
142    ///
143    /// # Panics
144    /// Panics if the polynomial is already in NTT domain.
145    pub fn forward_ntt(&mut self, ntt_ctx: &Ntt64Context) {
146        assert!(!self.is_ntt, "polynomial is already in NTT domain");
147        ntt_forward(&mut self.data, ntt_ctx);
148        self.is_ntt = true;
149    }
150
151    /// Converts from NTT domain to coefficient domain (in-place).
152    ///
153    /// # Panics
154    /// Panics if the polynomial is not in NTT domain.
155    pub fn inverse_ntt(&mut self, ntt_ctx: &Ntt64Context) {
156        assert!(self.is_ntt, "polynomial is not in NTT domain");
157        ntt_inverse(&mut self.data, ntt_ctx);
158        self.is_ntt = false;
159    }
160
161    // -------------------------------------------------------------------
162    // Arithmetic
163    // -------------------------------------------------------------------
164
165    /// Pointwise addition: `self += other (mod q)`.
166    ///
167    /// Both polynomials must be in the same domain (NTT or coefficient).
168    ///
169    /// # Panics
170    /// Panics if domains or sizes don't match.
171    pub fn add_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
172        assert_eq!(
173            self.is_ntt, other.is_ntt,
174            "polynomials must be in the same domain"
175        );
176        assert_eq!(
177            self.data.len(),
178            other.data.len(),
179            "polynomials must have the same size"
180        );
181        let q = arith.modulus;
182        for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
183            let sum = *a + b;
184            // Branchless via overflowing_sub
185            let (sub, borrow) = sum.overflowing_sub(q);
186            *a = if borrow { sum } else { sub };
187        }
188    }
189
190    /// Pointwise subtraction: `self -= other (mod q)`.
191    ///
192    /// Both polynomials must be in the same domain.
193    ///
194    /// # Panics
195    /// Panics if domains or sizes don't match.
196    pub fn sub_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
197        assert_eq!(
198            self.is_ntt, other.is_ntt,
199            "polynomials must be in the same domain"
200        );
201        assert_eq!(
202            self.data.len(),
203            other.data.len(),
204            "polynomials must have the same size"
205        );
206        let q = arith.modulus;
207        for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
208            let (sub, borrow) = (*a).overflowing_sub(b);
209            *a = if borrow { sub.wrapping_add(q) } else { sub };
210        }
211    }
212
213    /// Pointwise multiplication: `self *= other (mod q)`.
214    ///
215    /// **Both polynomials must be in NTT domain** so that pointwise multiplication
216    /// corresponds to negacyclic convolution in coefficient domain.
217    ///
218    /// # Panics
219    /// Panics if polynomials are not in NTT domain or have different sizes.
220    pub fn mul_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
221        assert!(
222            self.is_ntt && other.is_ntt,
223            "both polynomials must be in NTT domain for multiplication"
224        );
225        assert_eq!(
226            self.data.len(),
227            other.data.len(),
228            "polynomials must have the same size"
229        );
230        for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
231            *a = mod_mul_barrett(*a, b, arith);
232        }
233    }
234
235    /// Scalar multiplication: `self *= scalar (mod q)`.
236    pub fn scalar_mul(&mut self, scalar: u64, arith: &Ntt64Arith) {
237        for a in self.data.iter_mut() {
238            *a = mod_mul_barrett(*a, scalar, arith);
239        }
240    }
241
242    /// Negation: `self = −self (mod q)`, i.e. `self[i] = q − self[i]`.
243    pub fn negate(&mut self, arith: &Ntt64Arith) {
244        let q = arith.modulus;
245        for a in self.data.iter_mut() {
246            // Branchless: mask = (a != 0) as u64 * u64::MAX, then q & mask - *a ...
247            // but the branch here is on public data (coefficients), not secrets,
248            // and the branch predictor handles it well. Keep it simple.
249            *a = if *a == 0 { 0 } else { q - *a };
250        }
251    }
252
253    // -------------------------------------------------------------------
254    // Utilities
255    // -------------------------------------------------------------------
256
257    /// Number of coefficients (= max degree + 1).
258    #[inline]
259    pub fn len(&self) -> usize {
260        self.data.len()
261    }
262
263    /// Whether the polynomial has zero length.
264    #[inline]
265    pub fn is_empty(&self) -> bool {
266        self.data.is_empty()
267    }
268}
269
270// ---------------------------------------------------------------------------
271// Naive polynomial multiplication (test-only)
272// ---------------------------------------------------------------------------
273
/// Reference (schoolbook) multiplication in Z_q\[X\]/(X^N + 1).
///
/// Runs in O(N²) and exists only so the tests can cross-check the NTT-based
/// product. Because X^N ≡ −1 in this ring, a partial product whose degree
/// `i + j` reaches `N` folds back to degree `i + j − N` and is subtracted
/// (mod q) instead of added.
///
/// # Panics
/// Panics if `a` and `b` have different lengths, or if `q == 0` with
/// non-empty inputs.
#[cfg(test)]
// Only referenced from the test module.
#[allow(dead_code)]
fn naive_poly_mul(a: &[u64], b: &[u64], q: u64) -> Vec<u64> {
    let n = a.len();
    assert_eq!(n, b.len());
    let modulus = u128::from(q);
    let mut acc = vec![0u64; n];

    for (i, &ai) in a.iter().enumerate() {
        for (j, &bj) in b.iter().enumerate() {
            let term = u128::from(ai) * u128::from(bj);
            let degree = i + j;
            if degree >= n {
                // X^degree = X^(degree − n) · X^n = −X^(degree − n).
                let slot = &mut acc[degree - n];
                *slot = ((u128::from(*slot) + modulus - term % modulus) % modulus) as u64;
            } else {
                let slot = &mut acc[degree];
                *slot = ((u128::from(*slot) + term) % modulus) as u64;
            }
        }
    }
    acc
}
301
302// ---------------------------------------------------------------------------
303// Tests
304// ---------------------------------------------------------------------------
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ntt64::arith::Ntt64Arith;
    use crate::ntt64::context::Ntt64Context;

    // Small NTT-friendly prime for N = 256: q = 7681 = 15·512 + 1,
    // i.e. q ≡ 1 (mod 2N).
    const TEST_Q: u64 = 7681;
    const TEST_N: usize = 256;

    fn test_arith() -> Ntt64Arith {
        Ntt64Arith::new(TEST_Q)
    }

    fn test_ntt_ctx() -> Ntt64Context {
        Ntt64Context::new(TEST_N, test_arith())
    }

    /// Builds a coefficient-domain polynomial with the given coefficients.
    fn poly_from(coeffs: &[u64]) -> Poly64 {
        Poly64 {
            data: coeffs.to_vec(),
            is_ntt: false,
        }
    }

    /// Asserts element-wise equality, reporting the first mismatching index.
    fn assert_same(got: &[u64], want: &[u64], what: &str) {
        assert_eq!(got.len(), want.len(), "{what}: length mismatch");
        for (i, (g, w)) in got.iter().zip(want).enumerate() {
            assert_eq!(g, w, "{what} at index {i}");
        }
    }

    // -------------------------------------------------------------------
    // Deterministic tests (run with or without the `rand` feature)
    // -------------------------------------------------------------------

    #[test]
    fn test_add_sub_wraparound() {
        let arith = test_arith();

        let mut sum = poly_from(&[TEST_Q - 1, 5, 0, 0]);
        sum.add_assign(&poly_from(&[1, TEST_Q - 5, 0, 3]), &arith);
        assert_same(&sum.data, &[0, 0, 0, 3], "add wrap-around");

        let mut diff = poly_from(&[0, 5, 0, 3]);
        diff.sub_assign(&poly_from(&[1, 5, TEST_Q - 1, 0]), &arith);
        assert_same(&diff.data, &[TEST_Q - 1, 0, 1, 3], "sub wrap-around");
    }

    #[test]
    fn test_negate_fixed_values() {
        let arith = test_arith();
        let mut p = poly_from(&[0, 1, TEST_Q - 1, 100]);
        p.negate(&arith);
        assert_same(&p.data, &[0, TEST_Q - 1, 1, TEST_Q - 100], "negate");
    }

    #[test]
    fn test_poly_mul_ntt() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();

        let mut a = Poly64::new_zero(TEST_N);
        a.data[0] = 1;
        a.data[1] = 1;

        let mut b = Poly64::new_zero(TEST_N);
        b.data[0] = 1;
        b.data[2] = 1;

        let expected = naive_poly_mul(&a.data, &b.data, TEST_Q);

        a.forward_ntt(&ntt_ctx);
        b.forward_ntt(&ntt_ctx);
        a.mul_assign(&b, &arith);
        a.inverse_ntt(&ntt_ctx);

        assert_same(&a.data, &expected, "NTT mul != naive");
    }

    #[test]
    fn test_new_zero() {
        let poly = Poly64::new_zero(64);
        assert_eq!(poly.len(), 64);
        assert!(!poly.is_ntt);
        assert!(poly.data.iter().all(|&c| c == 0));
    }

    #[test]
    #[should_panic(expected = "already in NTT domain")]
    fn test_double_forward_ntt_panics() {
        let ntt_ctx = test_ntt_ctx();
        let mut poly = Poly64::new_zero(TEST_N);
        poly.forward_ntt(&ntt_ctx);
        poly.forward_ntt(&ntt_ctx);
    }

    #[test]
    #[should_panic(expected = "not in NTT domain")]
    fn test_inverse_ntt_without_forward_panics() {
        let ntt_ctx = test_ntt_ctx();
        let mut poly = Poly64::new_zero(TEST_N);
        poly.inverse_ntt(&ntt_ctx);
    }

    #[test]
    #[should_panic(expected = "same domain")]
    fn test_add_domain_mismatch_panics() {
        let arith = test_arith();
        let mut a = Poly64::new_zero(4);
        let mut b = Poly64::new_zero(4);
        b.is_ntt = true;
        a.add_assign(&b, &arith);
    }

    #[test]
    #[should_panic(expected = "must be in NTT domain")]
    fn test_mul_in_coefficient_domain_panics() {
        let arith = test_arith();
        let mut a = Poly64::new_zero(4);
        let b = Poly64::new_zero(4);
        a.mul_assign(&b, &arith);
    }

    // -------------------------------------------------------------------
    // Randomized tests: the samplers only exist with the `rand` feature,
    // so these are gated to keep `cargo test` compiling without it.
    // -------------------------------------------------------------------

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_add_sub() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);
        let b = Poly64::new_random(TEST_N, &arith);

        let mut c = a.clone();
        c.add_assign(&b, &arith);
        c.sub_assign(&b, &arith);

        assert_same(&c.data, &a.data, "add/sub roundtrip fails");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_add_commutative() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);
        let b = Poly64::new_random(TEST_N, &arith);

        let mut ab = a.clone();
        ab.add_assign(&b, &arith);

        let mut ba = b.clone();
        ba.add_assign(&a, &arith);

        assert_same(&ab.data, &ba.data, "add not commutative");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_negate() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);

        let mut neg_a = a.clone();
        neg_a.negate(&arith);

        let mut sum = a.clone();
        sum.add_assign(&neg_a, &arith);

        for (i, &c) in sum.data.iter().enumerate() {
            assert_eq!(c, 0, "a + (-a) != 0 at index {i}");
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_scalar_mul() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);

        let mut doubled = a.clone();
        doubled.scalar_mul(2, &arith);

        let mut sum = a.clone();
        sum.add_assign(&a, &arith);

        assert_same(&doubled.data, &sum.data, "2*a != a+a");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_mul_random_ntt() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();

        let a_orig = Poly64::new_random(TEST_N, &arith);
        let b_orig = Poly64::new_random(TEST_N, &arith);

        let expected = naive_poly_mul(&a_orig.data, &b_orig.data, TEST_Q);

        let mut a = a_orig.clone();
        let mut b = b_orig.clone();
        a.forward_ntt(&ntt_ctx);
        b.forward_ntt(&ntt_ctx);
        a.mul_assign(&b, &arith);
        a.inverse_ntt(&ntt_ctx);

        assert_same(&a.data, &expected, "NTT mul != naive");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_ternary_distribution() {
        let arith = test_arith();
        let poly = Poly64::new_ternary(1024, &arith);

        for (i, &coeff) in poly.data.iter().enumerate() {
            assert!(
                coeff == 0 || coeff == 1 || coeff == TEST_Q - 1,
                "invalid ternary coefficient at index {i}: {coeff}"
            );
        }

        let count_zero = poly.data.iter().filter(|&&c| c == 0).count();
        let count_one = poly.data.iter().filter(|&&c| c == 1).count();
        let count_neg = poly.data.iter().filter(|&&c| c == TEST_Q - 1).count();

        assert!(count_zero > 0);
        assert!(count_one > 0);
        assert!(count_neg > 0);
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_gaussian_distribution() {
        let arith = test_arith();
        let sigma = 3.2;
        let n = 8192;
        let poly = Poly64::new_gaussian(n, sigma, &arith);

        let q = TEST_Q as f64;
        let half_q = q / 2.0;
        // Map residues above q/2 back to their negative representatives.
        let centered: Vec<f64> = poly
            .data
            .iter()
            .map(|&c| {
                let c = c as f64;
                if c > half_q {
                    c - q
                } else {
                    c
                }
            })
            .collect();

        let mean = centered.iter().sum::<f64>() / n as f64;
        assert!(mean.abs() < 0.5, "mean too far from 0: {mean}");

        let variance = centered.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        let std_dev = variance.sqrt();
        assert!(
            (std_dev - sigma).abs() < 1.0,
            "stddev too far from {sigma}: {std_dev}"
        );
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_ntt_roundtrip() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();
        let original = Poly64::new_random(TEST_N, &arith);

        let mut poly = original.clone();
        poly.forward_ntt(&ntt_ctx);
        assert!(poly.is_ntt);
        poly.inverse_ntt(&ntt_ctx);
        assert!(!poly.is_ntt);

        assert_same(&poly.data, &original.data, "NTT roundtrip fails");
    }
}