qfall_math/integer_mod_q/ntt_basis_polynomial_ring_zq/
from.rs

1// Copyright © 2025 Marvin Beckmann
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
//! This module contains implementations to construct [`NTTBasisPolynomialRingZq`] for the
//! NTT of [`PolyOverZq`](crate::integer_mod_q::PolyOverZq) objects in a polynomial ring.
11//!
12//! The explicit functions contain the documentation.
13
14use super::NTTBasisPolynomialRingZq;
15use crate::{
16    integer::Z,
17    integer_mod_q::{Modulus, Zq},
18    traits::Pow,
19};
20
21impl NTTBasisPolynomialRingZq {
22    /// This function allows to initialize a [`NTTBasisPolynomialRingZq`]
23    /// object.
24    /// We currently only allow for two kinds of moduli to accompany the construction:
25    /// It must be either cyclic (`X^n - 1`) or negacyclic (`X^n + 1`) convoluted wrapping (also submitted in the algorithm)
26    /// and the degree of the polynomial must be a power of two.
27    /// Only then can we compute a full-split of the polynomial ring.
28    /// Accordingly, the `root_of_unity` must be a `n`th root of unity or respectively a `2n`th root of unity.
29    ///
30    /// This function does not check if the provided root of unity is actually a root of unity!
31    ///
32    /// Parameters:
33    /// - `n`: the degree of the polynomial
34    /// - `root_of_unity`: the `n`-th or `2n`-th root of unity
35    /// - `q`: the modulus of the polynomial
36    /// - `convolution_type`: defines whether convolution is cyclic or negacyclic
37    ///
38    /// # Examples
39    /// ```
40    /// use qfall_math::integer_mod_q::Modulus;
41    /// use qfall_math::integer_mod_q::NTTBasisPolynomialRingZq;
42    /// use qfall_math::integer_mod_q::ConvolutionType;
43    ///
44    /// let modulus = Modulus::from(7681);
45    ///
46    /// // Initializes the NTT for `X^4 - 1 mod 7681`
47    /// let ntt_basis =
48    ///     NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);
49    ///
50    /// // Initializes the NTT for `X^4 + 1 mod 7681`
51    /// let ntt_basis =
52    ///     NTTBasisPolynomialRingZq::init(4, 1925, &modulus, ConvolutionType::Negacyclic);
53    /// ```
54    ///
55    /// # Panics...
56    /// - if `n` is not a power of two.
57    pub fn init(
58        n: usize,
59        root_of_unity: impl Into<Z>,
60        modulus: impl Into<Modulus>,
61        convolution_type: ConvolutionType,
62    ) -> Self {
63        assert_eq!(n.next_power_of_two(), n);
64        let n = n as i64;
65        let root_of_unity = Zq::from((root_of_unity, modulus));
66        let modulus = root_of_unity.get_mod();
67
68        let n_inv = Zq::from((n, modulus)).inverse().unwrap();
69        let root_of_unity_inv = root_of_unity.inverse().unwrap();
70
71        // map the input to the `n`th root of unity and prepare power computation
72        let (psi, psi_inv, omega, omega_inv) = match convolution_type {
73            ConvolutionType::Cyclic => (None, None, root_of_unity.clone(), root_of_unity_inv),
74            ConvolutionType::Negacyclic => (
75                Some(&root_of_unity),
76                Some(&root_of_unity_inv),
77                root_of_unity.pow(2).unwrap(),
78                root_of_unity.pow(-2).unwrap(),
79            ),
80        };
81
82        // precompute powers of `n`th root of unity
83        let powers_of_omega = (0..n)
84            .map(|i| {
85                omega
86                    .pow(i)
87                    .unwrap()
88                    .get_representative_least_nonnegative_residue()
89            })
90            .collect();
91        let powers_of_omega_inv = (0..n)
92            .map(|i| {
93                omega_inv
94                    .pow(i)
95                    .unwrap()
96                    .get_representative_least_nonnegative_residue()
97            })
98            .collect();
99
100        // precompute powers of `2n`th root of unity
101        let powers_of_psi = match convolution_type {
102            ConvolutionType::Cyclic => Vec::new(),
103            ConvolutionType::Negacyclic => (0..n)
104                .map(|i| {
105                    psi.unwrap()
106                        .pow(i)
107                        .unwrap()
108                        .get_representative_least_nonnegative_residue()
109                })
110                .collect(),
111        };
112        let powers_of_psi_inv = match convolution_type {
113            ConvolutionType::Cyclic => Vec::new(),
114            ConvolutionType::Negacyclic => (0..n)
115                .map(|i| {
116                    psi_inv
117                        .unwrap()
118                        .pow(i)
119                        .unwrap()
120                        .get_representative_least_nonnegative_residue()
121                })
122                .collect(),
123        };
124
125        Self {
126            n,
127            n_inv: n_inv.get_representative_least_nonnegative_residue(),
128            powers_of_omega,
129            powers_of_omega_inv,
130            powers_of_psi,
131            powers_of_psi_inv,
132            modulus: root_of_unity.get_mod(),
133            convolution_type: convolution_type.clone(),
134        }
135    }
136}
137
/// This enum only serves the purpose of distinguishing between cyclic and negacyclic wrapping
/// in polynomial rings, and more specifically, of distinguishing them when utilizing the NTT
/// for polynomial rings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConvolutionType {
    /// Cyclic wrapping, i.e. the polynomial modulus is `X^n - 1`.
    Cyclic,
    /// Negacyclic wrapping, i.e. the polynomial modulus is `X^n + 1`.
    Negacyclic,
}
146
#[cfg(test)]
mod test_init {
    use super::ConvolutionType;
    use crate::{
        integer::Z,
        integer_mod_q::{Modulus, NTTBasisPolynomialRingZq},
    };

    /// Our algorithm only supports complete splits as of right now, so other inputs should be prohibited for now.
    #[test]
    #[should_panic]
    fn n_not_power_2() {
        let _ = NTTBasisPolynomialRingZq::init(12315, 1, 2, ConvolutionType::Cyclic);
    }

    /// `0` is not a power of two and therefore must be rejected as well.
    #[test]
    #[should_panic]
    fn n_zero() {
        let _ = NTTBasisPolynomialRingZq::init(0, 1, 7681, ConvolutionType::Cyclic);
    }

    /// `n` has to be invertible modulo `q`, which is not the case for `n = 4` and `q = 8`.
    #[test]
    #[should_panic]
    fn n_not_invertible() {
        let _ = NTTBasisPolynomialRingZq::init(4, 3, 8, ConvolutionType::Cyclic);
    }

    /// The root of unity has to be invertible modulo `q`, which `0` is not.
    #[test]
    #[should_panic]
    fn root_of_unity_not_invertible() {
        let _ = NTTBasisPolynomialRingZq::init(4, 0, 7681, ConvolutionType::Negacyclic);
    }

    /// The smallest valid degree `n = 1` yields single-element power tables.
    #[test]
    fn degree_one_cyclic() {
        let ntt_basis = NTTBasisPolynomialRingZq::init(1, 1, 7681, ConvolutionType::Cyclic);

        assert_eq!(ConvolutionType::Cyclic, ntt_basis.convolution_type);
        assert_eq!(Modulus::from(7681), ntt_basis.modulus);
        assert_eq!(1, ntt_basis.n);
        assert_eq!(Z::from(1), ntt_basis.n_inv);
        assert!(ntt_basis.powers_of_psi.is_empty());
        assert!(ntt_basis.powers_of_psi_inv.is_empty());
        assert_eq!(vec![Z::from(1)], ntt_basis.powers_of_omega);
        assert_eq!(vec![Z::from(1)], ntt_basis.powers_of_omega_inv);
    }

    /// Ensure that the algorithm sets the set of values as expected
    #[test]
    fn set_values_correctly_cyclic() {
        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 3383, 7681, ConvolutionType::Cyclic);

        assert_eq!(ConvolutionType::Cyclic, ntt_basis.convolution_type);
        assert_eq!(Modulus::from(7681), ntt_basis.modulus);
        assert_eq!(4, ntt_basis.n);
        assert_eq!(Z::from(5761), ntt_basis.n_inv);
        assert!(ntt_basis.powers_of_psi.is_empty());
        assert!(ntt_basis.powers_of_psi_inv.is_empty());
        assert_eq!(
            vec![Z::from(1), Z::from(3383), Z::from(7680), Z::from(4298)],
            ntt_basis.powers_of_omega
        );
        assert_eq!(
            vec![Z::from(1), Z::from(4298), Z::from(7680), Z::from(3383)],
            ntt_basis.powers_of_omega_inv
        );
    }

    /// Ensure that the algorithm sets the set of values as expected
    #[test]
    fn set_values_correctly_negacyclic() {
        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 1925, 7681, ConvolutionType::Negacyclic);

        assert_eq!(ConvolutionType::Negacyclic, ntt_basis.convolution_type);
        assert_eq!(Modulus::from(7681), ntt_basis.modulus);
        assert_eq!(4, ntt_basis.n);
        assert_eq!(Z::from(5761), ntt_basis.n_inv);
        assert_eq!(
            vec![Z::from(1), Z::from(1925), Z::from(3383), Z::from(6468)],
            ntt_basis.powers_of_psi
        );
        assert_eq!(
            vec![Z::from(1), Z::from(1213), Z::from(4298), Z::from(5756)],
            ntt_basis.powers_of_psi_inv
        );
        assert_eq!(
            vec![Z::from(1), Z::from(3383), Z::from(7680), Z::from(4298)],
            ntt_basis.powers_of_omega
        );
        assert_eq!(
            vec![Z::from(1), Z::from(4298), Z::from(7680), Z::from(3383)],
            ntt_basis.powers_of_omega_inv
        );
    }
}