fhe_math/ntt/
mod.rs

1//! Number-Theoretic Transform in ZZ_q.
2
3use fhe_util::is_prime;
4
5mod native;
6
7#[cfg(any(feature = "tfhe-ntt", feature = "tfhe-ntt-nightly"))]
8mod tfhe;
9
10#[cfg(not(any(feature = "tfhe-ntt", feature = "tfhe-ntt-nightly")))]
11pub use native::NttOperator;
12#[cfg(any(feature = "tfhe-ntt", feature = "tfhe-ntt-nightly"))]
13pub use tfhe::NttOperator;
14
15/// Returns whether a modulus p is prime and supports the Number Theoretic
16/// Transform of size n.
17///
18/// Aborts if n is not a power of 2 that is >= 8.
19pub(crate) fn supports_ntt(p: u64, n: usize) -> bool {
20    assert!(n >= 8 && n.is_power_of_two());
21
22    p % ((n as u64) << 1) == 1 && is_prime(p)
23}
24
#[cfg(test)]
mod tests {
    use rand::rng;

    use super::{supports_ntt, NttOperator};
    use crate::zq::Modulus;

    /// Transform sizes exercised by every test.
    const SIZES: [usize; 2] = [32, 1024];
    /// Moduli exercised by every test.
    const MODULI: [u64; 2] = [1153, 4611686018326724609];
    /// Number of random vectors drawn per supported (size, modulus) pair.
    const ROUNDS: usize = 100;

    /// `NttOperator::new` must succeed exactly when `supports_ntt` says the
    /// (modulus, size) pair is usable.
    #[test]
    fn constructor() {
        for size in SIZES {
            for p in MODULI {
                let q = Modulus::new(p).unwrap();
                let expected = supports_ntt(p, size);

                let op = NttOperator::new(&q, size);

                assert_eq!(op.is_some(), expected);
            }
        }
    }

    /// Forward followed by backward must restore the input, and the
    /// raw-pointer (`_vt`) variants must agree with the slice variants.
    #[test]
    fn bijection() {
        let mut prng = rng();

        for size in SIZES {
            for p in MODULI {
                let q = Modulus::new(p).unwrap();

                if !supports_ntt(p, size) {
                    continue;
                }
                let op = NttOperator::new(&q, size).unwrap();

                for _ in 0..ROUNDS {
                    let original = q.random_vec(size, &mut prng);
                    let mut via_slice = original.clone();
                    let mut via_ptr = original.clone();

                    // The transform must actually change the data.
                    op.forward(&mut via_slice);
                    assert_ne!(via_slice, original);

                    // SAFETY: `via_ptr` was requested with `size` elements, the
                    // same size the operator was built with — presumably what
                    // the raw-pointer API requires; confirm against its docs.
                    unsafe { op.forward_vt(via_ptr.as_mut_ptr()) }
                    assert_eq!(via_slice, via_ptr);

                    op.backward(&mut via_slice);
                    assert_eq!(via_slice, original);

                    // SAFETY: same buffer and size as for `forward_vt` above.
                    unsafe { op.backward_vt(via_ptr.as_mut_ptr()) }
                    assert_eq!(via_slice, via_ptr);
                }
            }
        }
    }

    /// The lazy forward transform, once reduced modulo `q`, must match the
    /// regular forward transform.
    #[test]
    fn forward_lazy() {
        let mut prng = rng();

        for size in SIZES {
            for p in MODULI {
                let q = Modulus::new(p).unwrap();

                if !supports_ntt(p, size) {
                    continue;
                }
                let op = NttOperator::new(&q, size).unwrap();

                for _ in 0..ROUNDS {
                    let mut strict = q.random_vec(size, &mut prng);
                    let mut lazy = strict.clone();

                    op.forward(&mut strict);

                    // SAFETY: `lazy` was requested with `size` elements, the
                    // same size the operator was built with — presumably what
                    // the raw-pointer API requires; confirm against its docs.
                    unsafe {
                        op.forward_vt_lazy(lazy.as_mut_ptr());
                        q.reduce_vec(&mut lazy);
                    }

                    assert_eq!(strict, lazy);
                }
            }
        }
    }
}