qfall_math/integer_mod_q/ntt_basis_polynomial_ring_zq/
ntt.rs

1// Copyright © 2025 Marvin Beckmann
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
9//! This module contains implementations to perform the NTT for [`PolyOverZq`]
10//! in the respective polynomial ring.
11//! The implementation mostly follows the description in <https://higashi.blog/2023/06/23/ntt-02/>.
12//!
13//! The explicit functions contain the documentation.
14
15use super::{NTTBasisPolynomialRingZq, from::ConvolutionType};
16use crate::{
17    integer::Z,
18    integer_mod_q::{Modulus, PolyOverZq},
19    traits::GetCoefficient,
20    utils::index::bit_reverse_permutation,
21};
22use flint_sys::fmpz_mod::{fmpz_mod_add, fmpz_mod_ctx, fmpz_mod_mul, fmpz_mod_sub};
23
24impl NTTBasisPolynomialRingZq {
25    /// For a given polynomial viewed in the polynomial ring defined by the `self`, it computes the NTT.
26    ///
27    /// It computes the iterative Cooley-Tukey transformation algorithm to compute the transform.
28    /// Polynomials of degree smaller than `n-1` are with `0` coefficients.
29    ///
30    /// Parameters:
31    /// - `poly`: The polynomial for which it is desired to compute the NTT
32    ///
33    /// Returns the NTT of a polynomial in the context of the polynomial ring.
34    ///
35    /// # Examples
36    /// ```
37    /// use qfall_math::integer::Z;
38    /// use qfall_math::integer_mod_q::{Modulus, PolyOverZq, NTTBasisPolynomialRingZq, ConvolutionType};
39    /// use std::str::FromStr;
40    ///
41    /// let g_poly = PolyOverZq::from_str("4  1 2 3 4 mod 7681").unwrap();
42    /// let modulus = Modulus::from(7681);
43    ///
44    /// let ntt_basis =
45    ///     NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);
46    ///
47    /// let ghat = ntt_basis.ntt(&g_poly);
48    ///
49    /// let cmp_ghat = vec![
50    ///     Z::from(10),
51    ///     Z::from(913),
52    ///     Z::from(7679),
53    ///     Z::from(6764),
54    /// ];
55    /// assert_eq!(cmp_ghat, ghat);
56    /// ```
57    /// ```
58    /// use qfall_math::integer::Z;
59    /// use qfall_math::integer_mod_q::{Modulus, PolyOverZq, NTTBasisPolynomialRingZq, ConvolutionType};
60    /// use std::str::FromStr;
61    ///
62    /// let g_poly = PolyOverZq::from_str("4  1 2 3 4 mod 7681").unwrap();
63    /// let modulus = Modulus::from(7681);
64    ///
65    /// let ntt_basis =
66    ///     NTTBasisPolynomialRingZq::init(4, 1925, &modulus, ConvolutionType::Negacyclic);
67    ///
68    /// let ghat = ntt_basis.ntt(&g_poly);
69    ///
70    /// let cmp_ghat = vec![
71    ///     Z::from(1467),
72    ///     Z::from(2807),
73    ///     Z::from(3471),
74    ///     Z::from(7621),
75    /// ];
76    /// assert_eq!(cmp_ghat, ghat);
77    /// ```
78    ///
79    /// # Panics if ...
80    /// - it is not reduced, i.e. has a coefficient of degree > n
81    /// - the modulus differs from the modulus over which we view the polynomial
82    ///
83    /// The algorithm is based on the Cooley-Tukey algorithm:
84    /// -\[1\] Cooley, James W., and John W. Tukey.
85    ///     "An algorithm for the machine calculation of complex Fourier series."
86    ///     Mathematics of computation 19.90 (1965): 297-301.
87    pub fn ntt(&self, poly: &PolyOverZq) -> Vec<Z> {
88        assert!(poly.get_degree() < self.n);
89        assert_eq!(poly.get_mod(), self.modulus);
90        // we use the unsafe getter, because we know that all indices are in the range
91        // and no error can occur here
92        let mut poly_coeffs: Vec<Z> = (0..self.n)
93            .map(|i| unsafe { poly.get_coeff_unchecked(i) })
94            .collect();
95        for _ in poly_coeffs.len()..(self.n as usize) {
96            poly_coeffs.push(Z::default());
97        }
98
99        // Negacyclic: perform preprocessing
100        if self.convolution_type == ConvolutionType::Negacyclic {
101            for (i, x) in poly_coeffs.iter_mut().enumerate() {
102                unsafe {
103                    fmpz_mod_mul(
104                        &mut x.value,
105                        &x.value,
106                        &self.powers_of_psi[i].value,
107                        self.modulus.get_fmpz_mod_ctx_struct(),
108                    );
109                }
110            }
111        }
112
113        iterative_ntt(poly_coeffs, &self.powers_of_omega, &self.modulus)
114    }
115}
116
117/// This function essentially computes the included butterfly computations for each provided
118/// chunk.
119/// The chunk is double the size of the stride.
120/// The computation currently performs the standard butterfly operation from Cooley-Tukey
121unsafe fn ntt_stride_steps(
122    chunk: &mut [Z],
123    stride: usize,
124    power_pointer: i64,
125    modulus_pointer: &fmpz_mod_ctx,
126    powers_of_omega_pointers: &[Z],
127) {
128    for i in 0..stride {
129        // compute power of the current level
130        let current_power = &powers_of_omega_pointers[2_usize.pow(power_pointer as u32) * (i)];
131
132        // CT butterfly
133        // by using Z, we can manage not to initialize additional modulus objects in this part
134        // and save runtime
135        let mut temp = Z::default();
136
137        unsafe {
138            fmpz_mod_mul(
139                &mut temp.value,
140                &current_power.value,
141                &chunk[i + stride].value,
142                modulus_pointer,
143            );
144            fmpz_mod_sub(
145                &mut chunk[i + stride].value,
146                &chunk[i].value,
147                &temp.value,
148                modulus_pointer,
149            );
150            fmpz_mod_add(
151                &mut chunk[i].value,
152                &chunk[i].value,
153                &temp.value,
154                modulus_pointer,
155            )
156        }
157    }
158}
159
160/// This algorithm performs an iterative version of the Cooley-Tukey algorithm.
161/// It takes in the coefficients of the polynomial and the precomputed powers of the
162/// root of unity.
163/// Here, we assume that precomputation has been done, i.e.: if the algorithm is
164/// called for negatively wrapped convolution, then this has been accounted for in the previous algorithm.
165///
166/// The algorithm possesses the option to be multi-threaded, but benchmarking has shown,
167/// that it makes the algorithm less efficient, so we turned it off.
168fn iterative_ntt(coefficients: Vec<Z>, powers_of_omega: &[Z], modulus: &Modulus) -> Vec<Z> {
169    let n = coefficients.len();
170    let nr_iterations = n.ilog2() as i64;
171
172    // compute the bit reversed order of the coefficients
173    let mut res = coefficients;
174    bit_reverse_permutation(&mut res);
175    let modulus_pointer = modulus.get_fmpz_mod_ctx_struct();
176
177    let mut power_pointer: i64 = nr_iterations - 1;
178    let mut stride = 1;
179    // iterate through all layers
180    while stride < n {
181        // split into strides and perform action for each respective stride
182        res.chunks_mut(2 * stride).for_each(|chunk| unsafe {
183            ntt_stride_steps(
184                chunk,
185                stride,
186                power_pointer,
187                modulus_pointer,
188                powers_of_omega,
189            )
190        });
191        stride *= 2;
192        power_pointer -= 1;
193    }
194    res
195}
196
#[cfg(test)]
mod test_ntt {
    use crate::{
        integer::Z,
        integer_mod_q::{ConvolutionType, Modulus, NTTBasisPolynomialRingZq, PolyOverZq},
    };
    use std::str::FromStr;

    /// Ensure that the cyclic NTT reproduces the values of Example 3.4
    /// from <https://eprint.iacr.org/2024/585.pdf>.
    #[test]
    fn example_34_multiplication_with_ntt() {
        let g_poly = PolyOverZq::from_str("4  1 2 3 4 mod 7681").unwrap();
        let modulus = Modulus::from(7681);

        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);

        let ghat = ntt_basis.ntt(&g_poly);
        let cmp_ghat = vec![Z::from(10), Z::from(913), Z::from(7679), Z::from(6764)];
        assert_eq!(cmp_ghat, ghat);
    }

    /// Ensure that NTT panics, if the degree of the polynomial is too high, i.e. not reduced.
    #[test]
    #[should_panic]
    fn degree_too_high() {
        let g_poly = PolyOverZq::from_str("5  1 2 3 4 5 mod 7681").unwrap();
        let modulus = Modulus::from(7681);

        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);

        let _ = ntt_basis.ntt(&g_poly);
    }

    /// Ensure that NTT panics, if the modulus of the polynomial is different.
    ///
    /// The basis is initialized with the same parameters as in the other tests and
    /// only the polynomial carries the deviating modulus. This way, the expected panic
    /// can only originate from `ntt` and not from the initialization of the basis.
    #[test]
    #[should_panic]
    fn different_modulus() {
        let g_poly = PolyOverZq::from_str("4  1 2 3 4 mod 7682").unwrap();
        let modulus = Modulus::from(7681);

        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);

        let _ = ntt_basis.ntt(&g_poly);
    }

    /// Ensure that NTT works for smaller degree polynomials,
    /// i.e. that missing coefficients are treated as `0`.
    #[test]
    fn small_degree() {
        let g_poly = PolyOverZq::from_str("2  1 2 mod 7681").unwrap();
        let modulus = Modulus::from(7681);

        let ntt_basis =
            NTTBasisPolynomialRingZq::init(4, 1925, &modulus, ConvolutionType::Negacyclic);

        let ghat = ntt_basis.ntt(&g_poly);
        let cmp_ghat = vec![Z::from(3851), Z::from(5256), Z::from(3832), Z::from(2427)];
        assert_eq!(cmp_ghat, ghat);
    }
}