phantom_zone/
ntt.rs

1use itertools::{izip, Itertools};
2use rand::{Rng, RngCore, SeedableRng};
3use rand_chacha::ChaCha8Rng;
4
5use crate::{
6    backend::{ArithmeticOps, ModInit, ModularOpsU64, Modulus},
7    utils::{mod_exponent, mod_inverse, ShoupMul},
8};
9
/// Constructor interface for NTT backends parameterised by a modulus of type
/// `M`.
pub trait NttInit<M> {
    /// Creates an NTT instance for modulus `q` and transform size `n`.
    ///
    /// NTT instances must be compatible across different instances created
    /// with the same `q` and `n`.
    fn new(q: &M, n: usize) -> Self;
}
15
/// In-place number theoretic transform over slices of `Element`.
///
/// The `*_lazy` variants are allowed to leave their output only partially
/// reduced, while the non-lazy variants return fully reduced values. For the
/// `u64` backend in this file that means `[0, 2q)` versus `[0, q)`.
pub trait Ntt {
    /// Scalar type of the coefficients being transformed.
    type Element;
    /// Forward transform of `v` in place; output may be lazily reduced.
    fn forward_lazy(&self, v: &mut [Self::Element]);
    /// Forward transform of `v` in place; output is fully reduced.
    fn forward(&self, v: &mut [Self::Element]);
    /// Inverse transform of `v` in place; output may be lazily reduced.
    fn backward_lazy(&self, v: &mut [Self::Element]);
    /// Inverse transform of `v` in place; output is fully reduced.
    fn backward(&self, v: &mut [Self::Element]);
}
23
24/// Forward butterfly routine for Number theoretic transform. Given inputs `x <
25/// 4q` and `y < 4q` mutates x and y in place to equal x' and y' where
26/// x' = x + wy
27/// y' = x - wy
28/// and both x' and y' are \in [0, 4q)
29///
30/// Implements Algorithm 4 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf)
31pub fn forward_butterly_0_to_4q(
32    mut x: u64,
33    y: u64,
34    w: u64,
35    w_shoup: u64,
36    q: u64,
37    q_twice: u64,
38) -> (u64, u64) {
39    debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q);
40    debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q);
41
42    if x >= q_twice {
43        x = x - q_twice;
44    }
45
46    let t = ShoupMul::mul(y, w, w_shoup, q);
47
48    (x + t, x + q_twice - t)
49}
50
51pub fn forward_butterly_0_to_2q(
52    mut x: u64,
53    y: u64,
54    w: u64,
55    w_shoup: u64,
56    q: u64,
57    q_twice: u64,
58) -> (u64, u64) {
59    debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q);
60    debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q);
61
62    if x >= q_twice {
63        x = x - q_twice;
64    }
65
66    let t = ShoupMul::mul(y, w, w_shoup, q);
67
68    let ox = x.wrapping_add(t);
69    let oy = x.wrapping_sub(t);
70
71    (
72        (ox).min(ox.wrapping_sub(q_twice)),
73        oy.min(oy.wrapping_add(q_twice)),
74    )
75}
76
77/// Inverse butterfly routine of Inverse Number theoretic transform. Given
78/// inputs `x < 2q` and `y < 2q` mutates x and y to equal x' and y' where
79/// x'= x + y
80/// y' = w(x - y)
81/// and both x' and y' are \in [0, 2q)
82///
83/// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf)
84pub fn inverse_butterfly_0_to_2q(
85    x: u64,
86    y: u64,
87    w_inv: u64,
88    w_inv_shoup: u64,
89    q: u64,
90    q_twice: u64,
91) -> (u64, u64) {
92    debug_assert!(x < q_twice, "{} >= (2q){q_twice}", x);
93    debug_assert!(y < q_twice, "{} >= (2q){q_twice}", y);
94
95    let mut x_dash = x + y;
96    if x_dash >= q_twice {
97        x_dash -= q_twice
98    }
99
100    let t = x + q_twice - y;
101    let y = ShoupMul::mul(t, w_inv, w_inv_shoup, q);
102
103    (x_dash, y)
104}
105
106/// Number theoretic transform of vector `a` where each element can be in range
107/// [0, 2q). Outputs NTT(a) where each element is in range [0,2q)
108///
109/// Implements Cooley-tukey based forward NTT as given in Algorithm 1 of https://eprint.iacr.org/2016/504.pdf.
110pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) {
111    assert!(a.len() == psi.len());
112
113    let n = a.len();
114    let mut t = n;
115
116    let mut m = 1;
117    while m < n {
118        t >>= 1;
119        let w = &psi[m..];
120        let w_shoup = &psi_shoup[m..];
121
122        if t == 1 {
123            for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) {
124                let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice);
125                a[0] = ox;
126                a[1] = oy;
127            }
128        } else {
129            for i in 0..m {
130                let a = &mut a[2 * i * t..(2 * (i + 1) * t)];
131                let (left, right) = a.split_at_mut(t);
132
133                for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
134                    let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice);
135                    *x = ox;
136                    *y = oy;
137                }
138            }
139        }
140
141        m <<= 1;
142    }
143}
144
145/// Same as `ntt_lazy` with output in range [0, q)
146pub fn ntt(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) {
147    assert!(a.len() == psi.len());
148
149    let n = a.len();
150    let mut t = n;
151
152    let mut m = 1;
153    while m < n {
154        t >>= 1;
155        let w = &psi[m..];
156        let w_shoup = &psi_shoup[m..];
157
158        if t == 1 {
159            for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) {
160                let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice);
161                // reduce from range [0, 2q) to [0, q)
162                a[0] = ox.min(ox.wrapping_sub(q));
163                a[1] = oy.min(oy.wrapping_sub(q));
164            }
165        } else {
166            for i in 0..m {
167                let a = &mut a[2 * i * t..(2 * (i + 1) * t)];
168                let (left, right) = a.split_at_mut(t);
169
170                for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
171                    let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice);
172                    *x = ox;
173                    *y = oy;
174                }
175            }
176        }
177
178        m <<= 1;
179    }
180}
181
182/// Inverse number theoretic transform of input vector `a` with each element can
183/// be in range [0, 2q). Outputs vector INTT(a) with each element in range [0,
184/// 2q)
185///
186/// Implements backward number theorectic transform using GS algorithm as given in Algorithm 2 of https://eprint.iacr.org/2016/504.pdf
187pub fn ntt_inv_lazy(
188    a: &mut [u64],
189    psi_inv: &[u64],
190    psi_inv_shoup: &[u64],
191    n_inv: u64,
192    n_inv_shoup: u64,
193    q: u64,
194    q_twice: u64,
195) {
196    assert!(a.len() == psi_inv.len());
197
198    let mut m = a.len() >> 1;
199    let mut t = 1;
200
201    while m > 0 {
202        if m == 1 {
203            let (left, right) = a.split_at_mut(t);
204
205            for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
206                let (ox, oy) =
207                    inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice);
208                *x = ShoupMul::mul(ox, n_inv, n_inv_shoup, q);
209                *y = ShoupMul::mul(oy, n_inv, n_inv_shoup, q);
210            }
211        } else {
212            let w_inv = &psi_inv[m..];
213            let w_inv_shoup = &psi_inv_shoup[m..];
214            for i in 0..m {
215                let a = &mut a[2 * i * t..2 * (i + 1) * t];
216                let (left, right) = a.split_at_mut(t);
217
218                for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
219                    let (ox, oy) =
220                        inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice);
221                    *x = ox;
222                    *y = oy;
223                }
224            }
225        }
226
227        t *= 2;
228        m >>= 1;
229    }
230}
231
232/// Same as `ntt_inv_lazy` with output in range [0, q)
233pub fn ntt_inv(
234    a: &mut [u64],
235    psi_inv: &[u64],
236    psi_inv_shoup: &[u64],
237    n_inv: u64,
238    n_inv_shoup: u64,
239    q: u64,
240    q_twice: u64,
241) {
242    assert!(a.len() == psi_inv.len());
243
244    let mut m = a.len() >> 1;
245    let mut t = 1;
246
247    while m > 0 {
248        if m == 1 {
249            let (left, right) = a.split_at_mut(t);
250
251            for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
252                let (ox, oy) =
253                    inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice);
254                let ox = ShoupMul::mul(ox, n_inv, n_inv_shoup, q);
255                let oy = ShoupMul::mul(oy, n_inv, n_inv_shoup, q);
256                *x = ox.min(ox.wrapping_sub(q));
257                *y = oy.min(oy.wrapping_sub(q));
258            }
259        } else {
260            let w_inv = &psi_inv[m..];
261            let w_inv_shoup = &psi_inv_shoup[m..];
262            for i in 0..m {
263                let a = &mut a[2 * i * t..2 * (i + 1) * t];
264                let (left, right) = a.split_at_mut(t);
265
266                for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
267                    let (ox, oy) =
268                        inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice);
269                    *x = ox;
270                    *y = oy;
271                }
272            }
273        }
274
275        t *= 2;
276        m >>= 1;
277    }
278}
279
280/// Find n^{th} root of unity in field F_q, if one exists
281///
282/// Note: n^{th} root of unity exists if and only if $q = 1 \mod{n}$
283pub(crate) fn find_primitive_root<R: RngCore>(q: u64, n: u64, rng: &mut R) -> Option<u64> {
284    assert!(n.is_power_of_two(), "{n} is not power of two");
285
286    // n^th root of unity only exists if n|(q-1)
287    assert!(q % n == 1, "{n}^th root of unity in F_{q} does not exists");
288
289    let t = (q - 1) / n;
290
291    for _ in 0..100 {
292        let mut omega = rng.gen::<u64>() % q;
293
294        // \omega = \omega^t. \omega is now n^th root of unity
295        omega = mod_exponent(omega, t, q);
296
297        // We restrict n to be power of 2. Thus checking whether \omega is primitive
298        // n^th root of unity is as simple as checking: \omega^{n/2} != 1
299        if mod_exponent(omega, n >> 1, q) == 1 {
300            continue;
301        } else {
302            return Some(omega);
303        }
304    }
305
306    None
307}
308
/// NTT backend for a `u64` modulus `q`, holding the precomputed tables used by
/// the forward and inverse transforms.
#[derive(Debug)]
pub struct NttBackendU64 {
    /// Modulus `q`.
    q: u64,
    /// `2 * q`, cached for the lazy butterflies.
    q_twice: u64,
    /// Transform size `n` (currently only stored, not read).
    _n: u64,
    /// `n^{-1} mod q`.
    n_inv: u64,
    /// Shoup representation of `n_inv`.
    n_inv_shoup: u64,
    /// Powers of \psi (the 2n^{th} primitive root of unity) in bit-reversed
    /// order.
    psi_powers_bo: Box<[u64]>,
    /// Powers of \psi^{-1} in bit-reversed order.
    psi_inv_powers_bo: Box<[u64]>,
    /// Shoup representations of `psi_powers_bo`, element for element.
    psi_powers_bo_shoup: Box<[u64]>,
    /// Shoup representations of `psi_inv_powers_bo`, element for element.
    psi_inv_powers_bo_shoup: Box<[u64]>,
}
321
322impl NttBackendU64 {
323    fn _new(q: u64, n: usize) -> Self {
324        // \psi = 2n^{th} primitive root of unity in F_q
325        let mut rng = ChaCha8Rng::from_seed([0u8; 32]);
326        let psi = find_primitive_root(q, (n * 2) as u64, &mut rng)
327            .expect("Unable to find 2n^th root of unity");
328        let psi_inv = mod_inverse(psi, q);
329
330        // assert!(
331        //     ((psi_inv as u128 * psi as u128) % q as u128) == 1,
332        //     "psi:{psi}, psi_inv:{psi_inv}"
333        // );
334
335        let modulus = ModularOpsU64::new(q);
336
337        let mut psi_powers = Vec::with_capacity(n as usize);
338        let mut psi_inv_powers = Vec::with_capacity(n as usize);
339        let mut running_psi = 1;
340        let mut running_psi_inv = 1;
341        for _ in 0..n {
342            psi_powers.push(running_psi);
343            psi_inv_powers.push(running_psi_inv);
344
345            running_psi = modulus.mul(&running_psi, &psi);
346            running_psi_inv = modulus.mul(&running_psi_inv, &psi_inv);
347        }
348
349        // powers stored in bit reversed order
350        let mut psi_powers_bo = vec![0u64; n as usize];
351        let mut psi_inv_powers_bo = vec![0u64; n as usize];
352        let shift_by = n.leading_zeros() + 1;
353        for i in 0..n as usize {
354            // i in bit reversed order
355            let bo_index = i.reverse_bits() >> shift_by;
356
357            psi_powers_bo[bo_index] = psi_powers[i];
358            psi_inv_powers_bo[bo_index] = psi_inv_powers[i];
359        }
360
361        // shoup representation
362        let psi_powers_bo_shoup = psi_powers_bo
363            .iter()
364            .map(|v| ShoupMul::representation(*v, q))
365            .collect_vec();
366        let psi_inv_powers_bo_shoup = psi_inv_powers_bo
367            .iter()
368            .map(|v| ShoupMul::representation(*v, q))
369            .collect_vec();
370
371        // n^{-1} \mod{q}
372        let n_inv = mod_inverse(n as u64, q);
373
374        NttBackendU64 {
375            q,
376            q_twice: 2 * q,
377            _n: n as u64,
378            n_inv,
379            n_inv_shoup: ShoupMul::representation(n_inv, q),
380            psi_powers_bo: psi_powers_bo.into_boxed_slice(),
381            psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(),
382            psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(),
383            psi_inv_powers_bo_shoup: psi_inv_powers_bo_shoup.into_boxed_slice(),
384        }
385    }
386}
387
388impl<M: Modulus<Element = u64>> NttInit<M> for NttBackendU64 {
389    fn new(q: &M, n: usize) -> Self {
390        // This NTT does not support native modulus
391        assert!(!q.is_native());
392        NttBackendU64::_new(q.q().unwrap(), n)
393    }
394}
395
396impl Ntt for NttBackendU64 {
397    type Element = u64;
398
399    fn forward_lazy(&self, v: &mut [Self::Element]) {
400        ntt_lazy(
401            v,
402            &self.psi_powers_bo,
403            &self.psi_powers_bo_shoup,
404            self.q,
405            self.q_twice,
406        )
407    }
408
409    fn forward(&self, v: &mut [Self::Element]) {
410        ntt(
411            v,
412            &self.psi_powers_bo,
413            &self.psi_powers_bo_shoup,
414            self.q,
415            self.q_twice,
416        );
417    }
418
419    fn backward_lazy(&self, v: &mut [Self::Element]) {
420        ntt_inv_lazy(
421            v,
422            &self.psi_inv_powers_bo,
423            &self.psi_inv_powers_bo_shoup,
424            self.n_inv,
425            self.n_inv_shoup,
426            self.q,
427            self.q_twice,
428        )
429    }
430
431    fn backward(&self, v: &mut [Self::Element]) {
432        ntt_inv(
433            v,
434            &self.psi_inv_powers_bo,
435            &self.psi_inv_powers_bo_shoup,
436            self.n_inv,
437            self.n_inv_shoup,
438            self.q,
439            self.q_twice,
440        );
441    }
442}
443
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::{thread_rng, Rng};
    use rand_distr::Uniform;

    use super::NttBackendU64;
    use crate::{
        backend::{ModInit, ModularOpsU64, VectorOps},
        ntt::Ntt,
        utils::{generate_prime, negacyclic_mul},
    };

    const Q_60_BITS: u64 = 1152921504606748673;
    const N: usize = 1 << 4;

    /// Number of random trials per test.
    const K: usize = 128;

    /// Samples `size` values uniformly from [0, q).
    fn random_vec_in_fq(size: usize, q: u64) -> Vec<u64> {
        thread_rng()
            .sample_iter(Uniform::new(0, q))
            .take(size)
            .collect_vec()
    }

    /// Asserts that every element of `a` is at most `max_val` (inclusive).
    fn assert_output_range(a: &[u64], max_val: u64) {
        a.iter()
            .for_each(|v| assert!(v <= &max_val, "{v} > {max_val}"));
    }

    #[test]
    fn native_ntt_backend_works() {
        // TODO(Jay): Improve tests. Add tests for different primes and ring size.
        let ntt_backend = NttBackendU64::_new(Q_60_BITS, N);
        for _ in 0..K {
            let mut a = random_vec_in_fq(N, Q_60_BITS);
            let a_clone = a.clone();

            // forward -> backward round trip, fully reduced
            ntt_backend.forward(&mut a);
            assert_output_range(a.as_ref(), Q_60_BITS - 1);
            assert_ne!(a, a_clone);
            ntt_backend.backward(&mut a);
            assert_output_range(a.as_ref(), Q_60_BITS - 1);
            assert_eq!(a, a_clone);

            // lazy forward -> backward round trip
            ntt_backend.forward_lazy(&mut a);
            assert_output_range(a.as_ref(), (2 * Q_60_BITS) - 1);
            assert_ne!(a, a_clone);
            ntt_backend.backward(&mut a);
            assert_output_range(a.as_ref(), Q_60_BITS - 1);
            assert_eq!(a, a_clone);

            // forward -> lazy backward round trip
            ntt_backend.forward(&mut a);
            assert_output_range(a.as_ref(), Q_60_BITS - 1);
            ntt_backend.backward_lazy(&mut a);
            assert_output_range(a.as_ref(), (2 * Q_60_BITS) - 1);
            // reduce from [0, 2q) to [0, q) before comparing
            a.iter_mut().for_each(|a0| {
                if *a0 >= Q_60_BITS {
                    *a0 -= Q_60_BITS;
                }
            });
            assert_eq!(a, a_clone);
        }
    }

    #[test]
    fn native_ntt_negacylic_mul() {
        let primes = [25, 40, 50, 60]
            .iter()
            .map(|bits| generate_prime(*bits, (2 * N) as u64, 1u64 << bits).unwrap())
            .collect_vec();

        for p in primes.into_iter() {
            let ntt_backend = NttBackendU64::_new(p, N);
            let modulus_backend = ModularOpsU64::new(p);
            for _ in 0..K {
                let a = random_vec_in_fq(N, p);
                let b = random_vec_in_fq(N, p);

                // Multiply in the NTT domain and transform back.
                let mut a_clone = a.clone();
                let mut b_clone = b.clone();
                ntt_backend.forward_lazy(&mut a_clone);
                ntt_backend.forward_lazy(&mut b_clone);
                modulus_backend.elwise_mul_mut(&mut a_clone, &b_clone);
                ntt_backend.backward(&mut a_clone);

                // Reference result from the schoolbook negacyclic product.
                let mul = |a: &u64, b: &u64| {
                    let tmp = *a as u128 * *b as u128;
                    (tmp % p as u128) as u64
                };
                let expected_out = negacyclic_mul(&a, &b, mul, p);

                assert_eq!(a_clone, expected_out);
            }
        }
    }
}