qfall_math/integer_mod_q/ntt_basis_polynomial_ring_zq/ntt.rs
1// Copyright © 2025 Marvin Beckmann
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
9//! This module contains implementations to perform the NTT for [`PolyOverZq`]
10//! in the respective polynomial ring.
11//! The implementation mostly follows the description in <https://higashi.blog/2023/06/23/ntt-02/>.
12//!
13//! The explicit functions contain the documentation.
14
15use super::{NTTBasisPolynomialRingZq, from::ConvolutionType};
16use crate::{
17 integer::Z,
18 integer_mod_q::{Modulus, PolyOverZq},
19 traits::GetCoefficient,
20 utils::index::bit_reverse_permutation,
21};
22use flint_sys::fmpz_mod::{fmpz_mod_add, fmpz_mod_ctx, fmpz_mod_mul, fmpz_mod_sub};
23
24impl NTTBasisPolynomialRingZq {
25 /// For a given polynomial viewed in the polynomial ring defined by the `self`, it computes the NTT.
26 ///
27 /// It computes the iterative Cooley-Tukey transformation algorithm to compute the transform.
28 /// Polynomials of degree smaller than `n-1` are with `0` coefficients.
29 ///
30 /// Parameters:
31 /// - `poly`: The polynomial for which it is desired to compute the NTT
32 ///
33 /// Returns the NTT of a polynomial in the context of the polynomial ring.
34 ///
35 /// # Examples
36 /// ```
37 /// use qfall_math::integer::Z;
38 /// use qfall_math::integer_mod_q::{Modulus, PolyOverZq, NTTBasisPolynomialRingZq, ConvolutionType};
39 /// use std::str::FromStr;
40 ///
41 /// let g_poly = PolyOverZq::from_str("4 1 2 3 4 mod 7681").unwrap();
42 /// let modulus = Modulus::from(7681);
43 ///
44 /// let ntt_basis =
45 /// NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);
46 ///
47 /// let ghat = ntt_basis.ntt(&g_poly);
48 ///
49 /// let cmp_ghat = vec![
50 /// Z::from(10),
51 /// Z::from(913),
52 /// Z::from(7679),
53 /// Z::from(6764),
54 /// ];
55 /// assert_eq!(cmp_ghat, ghat);
56 /// ```
57 /// ```
58 /// use qfall_math::integer::Z;
59 /// use qfall_math::integer_mod_q::{Modulus, PolyOverZq, NTTBasisPolynomialRingZq, ConvolutionType};
60 /// use std::str::FromStr;
61 ///
62 /// let g_poly = PolyOverZq::from_str("4 1 2 3 4 mod 7681").unwrap();
63 /// let modulus = Modulus::from(7681);
64 ///
65 /// let ntt_basis =
66 /// NTTBasisPolynomialRingZq::init(4, 1925, &modulus, ConvolutionType::Negacyclic);
67 ///
68 /// let ghat = ntt_basis.ntt(&g_poly);
69 ///
70 /// let cmp_ghat = vec![
71 /// Z::from(1467),
72 /// Z::from(2807),
73 /// Z::from(3471),
74 /// Z::from(7621),
75 /// ];
76 /// assert_eq!(cmp_ghat, ghat);
77 /// ```
78 ///
79 /// # Panics if ...
80 /// - it is not reduced, i.e. has a coefficient of degree > n
81 /// - the modulus differs from the modulus over which we view the polynomial
82 ///
83 /// The algorithm is based on the Cooley-Tukey algorithm:
84 /// -\[1\] Cooley, James W., and John W. Tukey.
85 /// "An algorithm for the machine calculation of complex Fourier series."
86 /// Mathematics of computation 19.90 (1965): 297-301.
87 pub fn ntt(&self, poly: &PolyOverZq) -> Vec<Z> {
88 assert!(poly.get_degree() < self.n);
89 assert_eq!(poly.get_mod(), self.modulus);
90 // we use the unsafe getter, because we know that all indices are in the range
91 // and no error can occur here
92 let mut poly_coeffs: Vec<Z> = (0..self.n)
93 .map(|i| unsafe { poly.get_coeff_unchecked(i) })
94 .collect();
95 for _ in poly_coeffs.len()..(self.n as usize) {
96 poly_coeffs.push(Z::default());
97 }
98
99 // Negacyclic: perform preprocessing
100 if self.convolution_type == ConvolutionType::Negacyclic {
101 for (i, x) in poly_coeffs.iter_mut().enumerate() {
102 unsafe {
103 fmpz_mod_mul(
104 &mut x.value,
105 &x.value,
106 &self.powers_of_psi[i].value,
107 self.modulus.get_fmpz_mod_ctx_struct(),
108 );
109 }
110 }
111 }
112
113 iterative_ntt(poly_coeffs, &self.powers_of_omega, &self.modulus)
114 }
115}
116
117/// This function essentially computes the included butterfly computations for each provided
118/// chunk.
119/// The chunk is double the size of the stride.
120/// The computation currently performs the standard butterfly operation from Cooley-Tukey
121unsafe fn ntt_stride_steps(
122 chunk: &mut [Z],
123 stride: usize,
124 power_pointer: i64,
125 modulus_pointer: &fmpz_mod_ctx,
126 powers_of_omega_pointers: &[Z],
127) {
128 for i in 0..stride {
129 // compute power of the current level
130 let current_power = &powers_of_omega_pointers[2_usize.pow(power_pointer as u32) * (i)];
131
132 // CT butterfly
133 // by using Z, we can manage not to initialize additional modulus objects in this part
134 // and save runtime
135 let mut temp = Z::default();
136
137 unsafe {
138 fmpz_mod_mul(
139 &mut temp.value,
140 ¤t_power.value,
141 &chunk[i + stride].value,
142 modulus_pointer,
143 );
144 fmpz_mod_sub(
145 &mut chunk[i + stride].value,
146 &chunk[i].value,
147 &temp.value,
148 modulus_pointer,
149 );
150 fmpz_mod_add(
151 &mut chunk[i].value,
152 &chunk[i].value,
153 &temp.value,
154 modulus_pointer,
155 )
156 }
157 }
158}
159
160/// This algorithm performs an iterative version of the Cooley-Tukey algorithm.
161/// It takes in the coefficients of the polynomial and the precomputed powers of the
162/// root of unity.
163/// Here, we assume that precomputation has been done, i.e.: if the algorithm is
164/// called for negatively wrapped convolution, then this has been accounted for in the previous algorithm.
165///
166/// The algorithm possesses the option to be multi-threaded, but benchmarking has shown,
167/// that it makes the algorithm less efficient, so we turned it off.
168fn iterative_ntt(coefficients: Vec<Z>, powers_of_omega: &[Z], modulus: &Modulus) -> Vec<Z> {
169 let n = coefficients.len();
170 let nr_iterations = n.ilog2() as i64;
171
172 // compute the bit reversed order of the coefficients
173 let mut res = coefficients;
174 bit_reverse_permutation(&mut res);
175 let modulus_pointer = modulus.get_fmpz_mod_ctx_struct();
176
177 let mut power_pointer: i64 = nr_iterations - 1;
178 let mut stride = 1;
179 // iterate through all layers
180 while stride < n {
181 // split into strides and perform action for each respective stride
182 res.chunks_mut(2 * stride).for_each(|chunk| unsafe {
183 ntt_stride_steps(
184 chunk,
185 stride,
186 power_pointer,
187 modulus_pointer,
188 powers_of_omega,
189 )
190 });
191 stride *= 2;
192 power_pointer -= 1;
193 }
194 res
195}
196
#[cfg(test)]
mod test_ntt {
    use crate::{
        integer::Z,
        integer_mod_q::{ConvolutionType, Modulus, NTTBasisPolynomialRingZq, PolyOverZq},
    };
    use std::str::FromStr;

    /// This example is taken from: https://eprint.iacr.org/2024/585.pdf Example 3.4
    #[test]
    fn example_34_multiplication_with_ntt() {
        let g_poly = PolyOverZq::from_str("4  1 2 3 4 mod 7681").unwrap();
        let modulus = Modulus::from(7681);

        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);

        let ghat = ntt_basis.ntt(&g_poly);
        let cmp_ghat = vec![Z::from(10), Z::from(913), Z::from(7679), Z::from(6764)];
        assert_eq!(cmp_ghat, ghat);
    }

    /// Ensure that NTT panics, if the degree of the polynomial is too high, i.e. not reduced.
    #[test]
    #[should_panic]
    fn degree_too_high() {
        let g_poly = PolyOverZq::from_str("5  1 2 3 4 5 mod 7681").unwrap();
        let modulus = Modulus::from(7681);

        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);

        let _ = ntt_basis.ntt(&g_poly);
    }

    /// Ensure that NTT panics, if the modulus of the polynomial is different.
    ///
    /// The basis uses the same valid parameters as the other tests and only the
    /// polynomial carries the differing modulus, so that the expected panic can only
    /// originate from the call to `ntt` and not from the setup of the basis.
    #[test]
    #[should_panic]
    fn different_modulus() {
        let g_poly = PolyOverZq::from_str("4  1 2 3 4 mod 7682").unwrap();
        let modulus = Modulus::from(7681);

        let ntt_basis = NTTBasisPolynomialRingZq::init(4, 3383, &modulus, ConvolutionType::Cyclic);

        let _ = ntt_basis.ntt(&g_poly);
    }

    /// Ensure that NTT works for smaller degree polynomials
    #[test]
    fn small_degree() {
        let g_poly = PolyOverZq::from_str("2  1 2 mod 7681").unwrap();
        let modulus = Modulus::from(7681);

        let ntt_basis =
            NTTBasisPolynomialRingZq::init(4, 1925, &modulus, ConvolutionType::Negacyclic);

        let ghat = ntt_basis.ntt(&g_poly);
        let cmp_ghat = vec![Z::from(3851), Z::from(5256), Z::from(3832), Z::from(2427)];
        assert_eq!(cmp_ghat, ghat);
    }
}
255}