fhe_math/ntt/
native.rs

1use crate::zq::Modulus;
2use itertools::Itertools;
3use rand::{Rng, SeedableRng};
4use rand_chacha::ChaCha8Rng;
5use std::iter::successors;
6
7/// Number-Theoretic Transform operator.
8#[derive(Debug, Clone, PartialEq, Eq)]
9pub struct NttOperator {
10    p: Modulus,
11    p_twice: u64,
12    size: usize,
13    omegas: Box<[u64]>,
14    omegas_shoup: Box<[u64]>,
15    zetas_inv: Box<[u64]>,
16    zetas_inv_shoup: Box<[u64]>,
17    size_inv: u64,
18    size_inv_shoup: u64,
19}
20
21impl NttOperator {
22    /// Create an NTT operator given a modulus for a specific size.
23    ///
24    /// Aborts if the size is not a power of 2 that is >= 8 in debug mode.
25    /// Returns None if the modulus does not support the NTT for this specific
26    /// size.
27    pub fn new(p: &Modulus, size: usize) -> Option<Self> {
28        if !super::supports_ntt(p.p, size) {
29            None
30        } else {
31            let size_inv = p.inv(size as u64)?;
32
33            let omega = Self::primitive_root(size, p);
34            let omega_inv = p.inv(omega)?;
35
36            let powers = successors(Some(1u64), |n| Some(p.mul(*n, omega)))
37                .take(size)
38                .collect_vec();
39            let powers_inv = successors(Some(omega_inv), |n| Some(p.mul(*n, omega_inv)))
40                .take(size)
41                .collect_vec();
42
43            let (omegas, zetas_inv): (Vec<u64>, Vec<u64>) = (0..size)
44                .map(|i| {
45                    let j = i.reverse_bits() >> (size.leading_zeros() + 1);
46                    (powers[j], powers_inv[j])
47                })
48                .unzip();
49
50            let omegas_shoup = p.shoup_vec(&omegas);
51            let zetas_inv_shoup = p.shoup_vec(&zetas_inv);
52
53            Some(Self {
54                p: p.clone(),
55                p_twice: p.p * 2,
56                size,
57                omegas: omegas.into_boxed_slice(),
58                omegas_shoup: omegas_shoup.into_boxed_slice(),
59                zetas_inv: zetas_inv.into_boxed_slice(),
60                zetas_inv_shoup: zetas_inv_shoup.into_boxed_slice(),
61                size_inv,
62                size_inv_shoup: p.shoup(size_inv),
63            })
64        }
65    }
66
67    /// Compute the forward NTT in place.
68    /// Aborts if a is not of the size handled by the operator.
69    pub fn forward(&self, a: &mut [u64]) {
70        debug_assert_eq!(a.len(), self.size);
71
72        let mut l = self.size >> 1;
73        let mut k = 1;
74        while l > 0 {
75            for chunk in a.chunks_exact_mut(2 * l) {
76                let omega = self.omegas[k];
77                let omega_shoup = self.omegas_shoup[k];
78                k += 1;
79
80                let (left, right) = chunk.split_at_mut(l);
81                if l == 1 {
82                    // The last level should reduce the output
83                    self.butterfly(&mut left[0], &mut right[0], omega, omega_shoup);
84                    left[0] = self.reduce3(left[0]);
85                    right[0] = self.reduce3(right[0]);
86                } else {
87                    for (x, y) in left.iter_mut().zip(right.iter_mut()) {
88                        self.butterfly(x, y, omega, omega_shoup);
89                    }
90                }
91            }
92            l >>= 1;
93        }
94    }
95
96    /// Compute the backward NTT in place.
97    /// Aborts if a is not of the size handled by the operator.
98    pub fn backward(&self, a: &mut [u64]) {
99        debug_assert_eq!(a.len(), self.size);
100
101        let mut k = 0;
102        let mut l = 1;
103
104        while l < self.size {
105            for chunk in a.chunks_exact_mut(2 * l) {
106                let zeta_inv = self.zetas_inv[k];
107                let zeta_inv_shoup = self.zetas_inv_shoup[k];
108                k += 1;
109
110                let (left, right) = chunk.split_at_mut(l);
111                if l == 1 {
112                    self.inv_butterfly(&mut left[0], &mut right[0], zeta_inv, zeta_inv_shoup);
113                } else {
114                    for (x, y) in left.iter_mut().zip(right.iter_mut()) {
115                        self.inv_butterfly(x, y, zeta_inv, zeta_inv_shoup);
116                    }
117                }
118            }
119            l <<= 1;
120        }
121
122        a.iter_mut()
123            .for_each(|ai| *ai = self.p.mul_shoup(*ai, self.size_inv, self.size_inv_shoup));
124    }
125
126    /// Compute the forward NTT in place in variable time in a lazily fashion.
127    /// This means that the output coefficients may be up to 4 times the
128    /// modulus.
129    ///
130    /// # Safety
131    /// This function assumes that a_ptr points to at least `size` elements.
132    /// This function is not constant time and its timing may reveal information
133    /// about the value being reduced.
134    pub(crate) unsafe fn forward_vt_lazy(&self, a_ptr: *mut u64) {
135        let mut l = self.size >> 1;
136        let mut m = 1;
137        let mut k = 1;
138        while l > 0 {
139            for i in 0..m {
140                let omega = *self.omegas.get_unchecked(k);
141                let omega_shoup = *self.omegas_shoup.get_unchecked(k);
142                k += 1;
143
144                let s = 2 * i * l;
145                match l {
146                    1 => {
147                        self.butterfly_vt(
148                            &mut *a_ptr.add(s),
149                            &mut *a_ptr.add(s + l),
150                            omega,
151                            omega_shoup,
152                        );
153                    }
154                    _ => {
155                        for j in s..(s + l) {
156                            self.butterfly_vt(
157                                &mut *a_ptr.add(j),
158                                &mut *a_ptr.add(j + l),
159                                omega,
160                                omega_shoup,
161                            );
162                        }
163                    }
164                }
165            }
166            l >>= 1;
167            m <<= 1;
168        }
169    }
170
171    /// Compute the forward NTT in place in variable time.
172    ///
173    /// # Safety
174    /// This function assumes that a_ptr points to at least `size` elements.
175    /// This function is not constant time and its timing may reveal information
176    /// about the value being reduced.
177    pub unsafe fn forward_vt(&self, a_ptr: *mut u64) {
178        self.forward_vt_lazy(a_ptr);
179        for i in 0..self.size {
180            *a_ptr.add(i) = self.reduce3_vt(*a_ptr.add(i))
181        }
182    }
183
184    /// Compute the backward NTT in place in variable time.
185    ///
186    /// # Safety
187    /// This function assumes that a_ptr points to at least `size` elements.
188    /// This function is not constant time and its timing may reveal information
189    /// about the value being reduced.
190    pub unsafe fn backward_vt(&self, a_ptr: *mut u64) {
191        let mut k = 0;
192        let mut m = self.size >> 1;
193        let mut l = 1;
194        while m > 0 {
195            for i in 0..m {
196                let s = 2 * i * l;
197                let zeta_inv = *self.zetas_inv.get_unchecked(k);
198                let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k);
199                k += 1;
200                match l {
201                    1 => {
202                        self.inv_butterfly_vt(
203                            &mut *a_ptr.add(s),
204                            &mut *a_ptr.add(s + l),
205                            zeta_inv,
206                            zeta_inv_shoup,
207                        );
208                    }
209                    _ => {
210                        for j in s..(s + l) {
211                            self.inv_butterfly_vt(
212                                &mut *a_ptr.add(j),
213                                &mut *a_ptr.add(j + l),
214                                zeta_inv,
215                                zeta_inv_shoup,
216                            );
217                        }
218                    }
219                }
220            }
221            l <<= 1;
222            m >>= 1;
223        }
224
225        for i in 0..self.size as isize {
226            *a_ptr.offset(i) =
227                self.p
228                    .mul_shoup(*a_ptr.offset(i), self.size_inv, self.size_inv_shoup)
229        }
230    }
231
232    /// Reduce a modulo p.
233    ///
234    /// Aborts if a >= 4 * p.
235    const fn reduce3(&self, a: u64) -> u64 {
236        debug_assert!(a < 4 * self.p.p);
237
238        let y = Modulus::reduce1(a, self.p_twice);
239        Modulus::reduce1(y, self.p.p)
240    }
241
242    /// Reduce a modulo p in variable time.
243    ///
244    /// Aborts if a >= 4 * p.
245    const unsafe fn reduce3_vt(&self, a: u64) -> u64 {
246        debug_assert!(a < 4 * self.p.p);
247
248        let y = Modulus::reduce1_vt(a, self.p_twice);
249        Modulus::reduce1_vt(y, self.p.p)
250    }
251
252    /// NTT Butterfly.
253    fn butterfly(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
254        debug_assert!(*x < 4 * self.p.p);
255        debug_assert!(*y < 4 * self.p.p);
256        debug_assert!(w < self.p.p);
257        debug_assert_eq!(self.p.shoup(w), w_shoup);
258
259        *x = Modulus::reduce1(*x, self.p_twice);
260        let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
261        *y = *x + self.p_twice - t;
262        *x += t;
263
264        debug_assert!(*x < 4 * self.p.p);
265        debug_assert!(*y < 4 * self.p.p);
266    }
267
268    /// NTT Butterfly in variable time.
269    unsafe fn butterfly_vt(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
270        debug_assert!(*x < 4 * self.p.p);
271        debug_assert!(*y < 4 * self.p.p);
272        debug_assert!(w < self.p.p);
273        debug_assert_eq!(self.p.shoup(w), w_shoup);
274
275        *x = Modulus::reduce1_vt(*x, self.p_twice);
276        let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
277        *y = *x + self.p_twice - t;
278        *x += t;
279
280        debug_assert!(*x < 4 * self.p.p);
281        debug_assert!(*y < 4 * self.p.p);
282    }
283
284    /// Inverse NTT butterfly.
285    fn inv_butterfly(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
286        debug_assert!(*x < self.p_twice);
287        debug_assert!(*y < self.p_twice);
288        debug_assert!(z < self.p.p);
289        debug_assert_eq!(self.p.shoup(z), z_shoup);
290
291        let t = *x;
292        *x = Modulus::reduce1(*y + t, self.p_twice);
293        *y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
294
295        debug_assert!(*x < self.p_twice);
296        debug_assert!(*y < self.p_twice);
297    }
298
299    /// Inverse NTT butterfly in variable time
300    unsafe fn inv_butterfly_vt(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
301        debug_assert!(*x < self.p_twice);
302        debug_assert!(*y < self.p_twice);
303        debug_assert!(z < self.p.p);
304        debug_assert_eq!(self.p.shoup(z), z_shoup);
305
306        let t = *x;
307        *x = Modulus::reduce1_vt(*y + t, self.p_twice);
308        *y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
309
310        debug_assert!(*x < self.p_twice);
311        debug_assert!(*y < self.p_twice);
312    }
313
314    /// Returns a 2n-th primitive root modulo p.
315    ///
316    /// Aborts if p is not prime or n is not a power of 2 that is >= 8.
317    fn primitive_root(n: usize, p: &Modulus) -> u64 {
318        debug_assert!(super::supports_ntt(p.p, n));
319
320        let lambda = (p.p - 1) / (2 * n as u64);
321
322        let mut rng: ChaCha8Rng = SeedableRng::seed_from_u64(0);
323        for _ in 0..100 {
324            let mut root = rng.random_range(0..p.p);
325            root = p.pow(root, lambda);
326            if Self::is_primitive_root(root, 2 * n, p) {
327                return root;
328            }
329        }
330
331        debug_assert!(false, "Couldn't find primitive root");
332        0
333    }
334
335    /// Returns whether a is a n-th primitive root of unity.
336    ///
337    /// Aborts if a >= p in debug mode.
338    fn is_primitive_root(a: u64, n: usize, p: &Modulus) -> bool {
339        debug_assert!(a < p.p);
340        debug_assert!(super::supports_ntt(p.p, n >> 1)); // TODO: This is not exactly the right condition here.
341
342        // A primitive root of unity is such that x^n = 1 mod p, and x^(n/p) != 1 mod p
343        // for all prime p dividing n.
344        (p.pow(a, n as u64) == 1) && (p.pow(a, (n / 2) as u64) != 1)
345    }
346}