1use crate::zq::Modulus;
2use itertools::Itertools;
3use rand::{Rng, SeedableRng};
4use rand_chacha::ChaCha8Rng;
5use std::iter::successors;
6
/// Number-Theoretic Transform (NTT) operator for vectors of a fixed length
/// modulo a fixed [`Modulus`].
///
/// All twiddle-factor tables are precomputed once by `NttOperator::new`, so the
/// transforms themselves only perform modular additions, lazy reductions and
/// Shoup-style multiplications (`Modulus::mul_shoup` / `lazy_mul_shoup`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NttOperator {
    /// The modulus `p` the transform is computed over.
    p: Modulus,
    /// `2 * p`; the bound used for the lazy (partial) reductions in the butterflies.
    p_twice: u64,
    /// Length of the vectors this operator transforms.
    size: usize,
    /// Powers `omega^j` of the primitive root, stored in bit-reversed index
    /// order. The forward transform reads entries `1..size`, one per butterfly
    /// group.
    omegas: Box<[u64]>,
    /// Shoup precomputations for `omegas` (same indexing).
    omegas_shoup: Box<[u64]>,
    /// Powers `omega^-(j + 1)` (starting at `omega^-1`, not 1), stored in
    /// bit-reversed index order. The backward transform reads entries
    /// `0..size - 1`, one per butterfly group.
    zetas_inv: Box<[u64]>,
    /// Shoup precomputations for `zetas_inv` (same indexing).
    zetas_inv_shoup: Box<[u64]>,
    /// `size^-1 mod p`, multiplied into every coefficient at the end of the
    /// backward transform.
    size_inv: u64,
    /// Shoup precomputation for `size_inv`.
    size_inv_shoup: u64,
}
20
21impl NttOperator {
22 pub fn new(p: &Modulus, size: usize) -> Option<Self> {
28 if !super::supports_ntt(p.p, size) {
29 None
30 } else {
31 let size_inv = p.inv(size as u64)?;
32
33 let omega = Self::primitive_root(size, p);
34 let omega_inv = p.inv(omega)?;
35
36 let powers = successors(Some(1u64), |n| Some(p.mul(*n, omega)))
37 .take(size)
38 .collect_vec();
39 let powers_inv = successors(Some(omega_inv), |n| Some(p.mul(*n, omega_inv)))
40 .take(size)
41 .collect_vec();
42
43 let (omegas, zetas_inv): (Vec<u64>, Vec<u64>) = (0..size)
44 .map(|i| {
45 let j = i.reverse_bits() >> (size.leading_zeros() + 1);
46 (powers[j], powers_inv[j])
47 })
48 .unzip();
49
50 let omegas_shoup = p.shoup_vec(&omegas);
51 let zetas_inv_shoup = p.shoup_vec(&zetas_inv);
52
53 Some(Self {
54 p: p.clone(),
55 p_twice: p.p * 2,
56 size,
57 omegas: omegas.into_boxed_slice(),
58 omegas_shoup: omegas_shoup.into_boxed_slice(),
59 zetas_inv: zetas_inv.into_boxed_slice(),
60 zetas_inv_shoup: zetas_inv_shoup.into_boxed_slice(),
61 size_inv,
62 size_inv_shoup: p.shoup(size_inv),
63 })
64 }
65 }
66
67 pub fn forward(&self, a: &mut [u64]) {
70 debug_assert_eq!(a.len(), self.size);
71
72 let mut l = self.size >> 1;
73 let mut k = 1;
74 while l > 0 {
75 for chunk in a.chunks_exact_mut(2 * l) {
76 let omega = self.omegas[k];
77 let omega_shoup = self.omegas_shoup[k];
78 k += 1;
79
80 let (left, right) = chunk.split_at_mut(l);
81 if l == 1 {
82 self.butterfly(&mut left[0], &mut right[0], omega, omega_shoup);
84 left[0] = self.reduce3(left[0]);
85 right[0] = self.reduce3(right[0]);
86 } else {
87 for (x, y) in left.iter_mut().zip(right.iter_mut()) {
88 self.butterfly(x, y, omega, omega_shoup);
89 }
90 }
91 }
92 l >>= 1;
93 }
94 }
95
96 pub fn backward(&self, a: &mut [u64]) {
99 debug_assert_eq!(a.len(), self.size);
100
101 let mut k = 0;
102 let mut l = 1;
103
104 while l < self.size {
105 for chunk in a.chunks_exact_mut(2 * l) {
106 let zeta_inv = self.zetas_inv[k];
107 let zeta_inv_shoup = self.zetas_inv_shoup[k];
108 k += 1;
109
110 let (left, right) = chunk.split_at_mut(l);
111 if l == 1 {
112 self.inv_butterfly(&mut left[0], &mut right[0], zeta_inv, zeta_inv_shoup);
113 } else {
114 for (x, y) in left.iter_mut().zip(right.iter_mut()) {
115 self.inv_butterfly(x, y, zeta_inv, zeta_inv_shoup);
116 }
117 }
118 }
119 l <<= 1;
120 }
121
122 a.iter_mut()
123 .for_each(|ai| *ai = self.p.mul_shoup(*ai, self.size_inv, self.size_inv_shoup));
124 }
125
126 pub(crate) unsafe fn forward_vt_lazy(&self, a_ptr: *mut u64) {
135 let mut l = self.size >> 1;
136 let mut m = 1;
137 let mut k = 1;
138 while l > 0 {
139 for i in 0..m {
140 let omega = *self.omegas.get_unchecked(k);
141 let omega_shoup = *self.omegas_shoup.get_unchecked(k);
142 k += 1;
143
144 let s = 2 * i * l;
145 match l {
146 1 => {
147 self.butterfly_vt(
148 &mut *a_ptr.add(s),
149 &mut *a_ptr.add(s + l),
150 omega,
151 omega_shoup,
152 );
153 }
154 _ => {
155 for j in s..(s + l) {
156 self.butterfly_vt(
157 &mut *a_ptr.add(j),
158 &mut *a_ptr.add(j + l),
159 omega,
160 omega_shoup,
161 );
162 }
163 }
164 }
165 }
166 l >>= 1;
167 m <<= 1;
168 }
169 }
170
171 pub unsafe fn forward_vt(&self, a_ptr: *mut u64) {
178 self.forward_vt_lazy(a_ptr);
179 for i in 0..self.size {
180 *a_ptr.add(i) = self.reduce3_vt(*a_ptr.add(i))
181 }
182 }
183
184 pub unsafe fn backward_vt(&self, a_ptr: *mut u64) {
191 let mut k = 0;
192 let mut m = self.size >> 1;
193 let mut l = 1;
194 while m > 0 {
195 for i in 0..m {
196 let s = 2 * i * l;
197 let zeta_inv = *self.zetas_inv.get_unchecked(k);
198 let zeta_inv_shoup = *self.zetas_inv_shoup.get_unchecked(k);
199 k += 1;
200 match l {
201 1 => {
202 self.inv_butterfly_vt(
203 &mut *a_ptr.add(s),
204 &mut *a_ptr.add(s + l),
205 zeta_inv,
206 zeta_inv_shoup,
207 );
208 }
209 _ => {
210 for j in s..(s + l) {
211 self.inv_butterfly_vt(
212 &mut *a_ptr.add(j),
213 &mut *a_ptr.add(j + l),
214 zeta_inv,
215 zeta_inv_shoup,
216 );
217 }
218 }
219 }
220 }
221 l <<= 1;
222 m >>= 1;
223 }
224
225 for i in 0..self.size as isize {
226 *a_ptr.offset(i) =
227 self.p
228 .mul_shoup(*a_ptr.offset(i), self.size_inv, self.size_inv_shoup)
229 }
230 }
231
232 const fn reduce3(&self, a: u64) -> u64 {
236 debug_assert!(a < 4 * self.p.p);
237
238 let y = Modulus::reduce1(a, self.p_twice);
239 Modulus::reduce1(y, self.p.p)
240 }
241
242 const unsafe fn reduce3_vt(&self, a: u64) -> u64 {
246 debug_assert!(a < 4 * self.p.p);
247
248 let y = Modulus::reduce1_vt(a, self.p_twice);
249 Modulus::reduce1_vt(y, self.p.p)
250 }
251
252 fn butterfly(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
254 debug_assert!(*x < 4 * self.p.p);
255 debug_assert!(*y < 4 * self.p.p);
256 debug_assert!(w < self.p.p);
257 debug_assert_eq!(self.p.shoup(w), w_shoup);
258
259 *x = Modulus::reduce1(*x, self.p_twice);
260 let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
261 *y = *x + self.p_twice - t;
262 *x += t;
263
264 debug_assert!(*x < 4 * self.p.p);
265 debug_assert!(*y < 4 * self.p.p);
266 }
267
268 unsafe fn butterfly_vt(&self, x: &mut u64, y: &mut u64, w: u64, w_shoup: u64) {
270 debug_assert!(*x < 4 * self.p.p);
271 debug_assert!(*y < 4 * self.p.p);
272 debug_assert!(w < self.p.p);
273 debug_assert_eq!(self.p.shoup(w), w_shoup);
274
275 *x = Modulus::reduce1_vt(*x, self.p_twice);
276 let t = self.p.lazy_mul_shoup(*y, w, w_shoup);
277 *y = *x + self.p_twice - t;
278 *x += t;
279
280 debug_assert!(*x < 4 * self.p.p);
281 debug_assert!(*y < 4 * self.p.p);
282 }
283
284 fn inv_butterfly(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
286 debug_assert!(*x < self.p_twice);
287 debug_assert!(*y < self.p_twice);
288 debug_assert!(z < self.p.p);
289 debug_assert_eq!(self.p.shoup(z), z_shoup);
290
291 let t = *x;
292 *x = Modulus::reduce1(*y + t, self.p_twice);
293 *y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
294
295 debug_assert!(*x < self.p_twice);
296 debug_assert!(*y < self.p_twice);
297 }
298
299 unsafe fn inv_butterfly_vt(&self, x: &mut u64, y: &mut u64, z: u64, z_shoup: u64) {
301 debug_assert!(*x < self.p_twice);
302 debug_assert!(*y < self.p_twice);
303 debug_assert!(z < self.p.p);
304 debug_assert_eq!(self.p.shoup(z), z_shoup);
305
306 let t = *x;
307 *x = Modulus::reduce1_vt(*y + t, self.p_twice);
308 *y = self.p.lazy_mul_shoup(self.p_twice + t - *y, z, z_shoup);
309
310 debug_assert!(*x < self.p_twice);
311 debug_assert!(*y < self.p_twice);
312 }
313
314 fn primitive_root(n: usize, p: &Modulus) -> u64 {
318 debug_assert!(super::supports_ntt(p.p, n));
319
320 let lambda = (p.p - 1) / (2 * n as u64);
321
322 let mut rng: ChaCha8Rng = SeedableRng::seed_from_u64(0);
323 for _ in 0..100 {
324 let mut root = rng.random_range(0..p.p);
325 root = p.pow(root, lambda);
326 if Self::is_primitive_root(root, 2 * n, p) {
327 return root;
328 }
329 }
330
331 debug_assert!(false, "Couldn't find primitive root");
332 0
333 }
334
335 fn is_primitive_root(a: u64, n: usize, p: &Modulus) -> bool {
339 debug_assert!(a < p.p);
340 debug_assert!(super::supports_ntt(p.p, n >> 1)); (p.pow(a, n as u64) == 1) && (p.pow(a, (n / 2) as u64) != 1)
345 }
346}