1use fhe_util::is_prime;
4
5mod native;
6
7#[cfg(any(feature = "tfhe-ntt", feature = "tfhe-ntt-nightly"))]
8mod tfhe;
9
10#[cfg(not(any(feature = "tfhe-ntt", feature = "tfhe-ntt-nightly")))]
11pub use native::NttOperator;
12#[cfg(any(feature = "tfhe-ntt", feature = "tfhe-ntt-nightly"))]
13pub use tfhe::NttOperator;
14
15pub(crate) fn supports_ntt(p: u64, n: usize) -> bool {
20 assert!(n >= 8 && n.is_power_of_two());
21
22 p % ((n as u64) << 1) == 1 && is_prime(p)
23}
24
#[cfg(test)]
mod tests {
    use rand::rng;

    use super::{supports_ntt, NttOperator};
    use crate::zq::Modulus;

    /// Number of random vectors checked per supported (modulus, size) pair.
    const NTESTS: usize = 100;

    /// Transform sizes exercised by every test.
    const SIZES: [usize; 2] = [32, 1024];

    /// Moduli exercised by every test. `1153 % 64 == 1` but `1153 < 2048`, so
    /// it accepts size 32 and rejects size 1024.
    const MODULI: [u64; 2] = [1153, 4611686018326724609];

    /// Invokes `body` with the modulus, a freshly built operator and the size,
    /// for every pair from `SIZES` x `MODULI` that `supports_ntt` accepts.
    fn for_each_supported(mut body: impl FnMut(&Modulus, &NttOperator, usize)) {
        for size in SIZES {
            for p in MODULI {
                let modulus = Modulus::new(p).unwrap();
                if !supports_ntt(p, size) {
                    continue;
                }
                let operator = NttOperator::new(&modulus, size).unwrap();
                body(&modulus, &operator, size);
            }
        }
    }

    #[test]
    fn constructor() {
        for size in SIZES {
            for p in MODULI {
                let modulus = Modulus::new(p).unwrap();
                let expected = supports_ntt(p, size);

                let operator = NttOperator::new(&modulus, size);

                // An operator must be produced exactly when the parameters
                // support the NTT.
                assert_eq!(operator.is_some(), expected);
            }
        }
    }

    #[test]
    fn bijection() {
        let mut rand_gen = rng();

        for_each_supported(|modulus, operator, size| {
            for _ in 0..NTESTS {
                let original = modulus.random_vec(size, &mut rand_gen);
                let mut safe_path = original.clone();
                let mut raw_path = original.clone();

                operator.forward(&mut safe_path);
                assert_ne!(safe_path, original);

                // SAFETY (assumed contract — confirm against `NttOperator`):
                // `raw_path` owns `size` initialised elements, the size the
                // operator was built for, and is not aliased elsewhere.
                unsafe { operator.forward_vt(raw_path.as_mut_ptr()) }
                assert_eq!(safe_path, raw_path);

                operator.backward(&mut safe_path);
                assert_eq!(safe_path, original);

                // SAFETY: same buffer and length as the forward call above.
                unsafe { operator.backward_vt(raw_path.as_mut_ptr()) }
                assert_eq!(safe_path, raw_path);
            }
        });
    }

    #[test]
    fn forward_lazy() {
        let mut rand_gen = rng();

        for_each_supported(|modulus, operator, size| {
            for _ in 0..NTESTS {
                let mut eager = modulus.random_vec(size, &mut rand_gen);
                let mut lazy = eager.clone();

                operator.forward(&mut eager);

                // SAFETY (assumed contract — confirm against `NttOperator`):
                // `lazy` owns `size` initialised elements, matching the
                // operator's size. The lazy result is reduced before
                // comparison; presumably it may lie outside the canonical range.
                unsafe {
                    operator.forward_vt_lazy(lazy.as_mut_ptr());
                    modulus.reduce_vec(&mut lazy);
                }

                assert_eq!(eager, lazy);
            }
        });
    }
}