1use itertools::{izip, Itertools};
2use rand::{Rng, RngCore, SeedableRng};
3use rand_chacha::ChaCha8Rng;
4
5use crate::{
6 backend::{ArithmeticOps, ModInit, ModularOpsU64, Modulus},
7 utils::{mod_exponent, mod_inverse, ShoupMul},
8};
9
/// Constructor interface for NTT backends parameterised by a modulus type `M`.
pub trait NttInit<M> {
    /// Builds a backend from the modulus `q` and the size `n`.
    ///
    /// NOTE(review): in the `NttBackendU64` implementation in this file `n` is
    /// the number of elements of the transformed slices; confirm other
    /// implementors use it the same way.
    fn new(q: &M, n: usize) -> Self;
}
15
/// In-place forward and backward number-theoretic transforms over slices of
/// `Self::Element`.
///
/// The `*_lazy` variants may leave their outputs only partially reduced. For
/// the `NttBackendU64` implementation in this file the lazy outputs lie in
/// `[0, 2q)` and the non-lazy outputs in `[0, q)`.
pub trait Ntt {
    /// Scalar type of the values being transformed.
    type Element;
    /// Forward transform of `v` in place, output possibly not fully reduced.
    fn forward_lazy(&self, v: &mut [Self::Element]);
    /// Forward transform of `v` in place.
    fn forward(&self, v: &mut [Self::Element]);
    /// Backward (inverse) transform of `v` in place, output possibly not fully
    /// reduced.
    fn backward_lazy(&self, v: &mut [Self::Element]);
    /// Backward (inverse) transform of `v` in place.
    fn backward(&self, v: &mut [Self::Element]);
}
23
24pub fn forward_butterly_0_to_4q(
32 mut x: u64,
33 y: u64,
34 w: u64,
35 w_shoup: u64,
36 q: u64,
37 q_twice: u64,
38) -> (u64, u64) {
39 debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q);
40 debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q);
41
42 if x >= q_twice {
43 x = x - q_twice;
44 }
45
46 let t = ShoupMul::mul(y, w, w_shoup, q);
47
48 (x + t, x + q_twice - t)
49}
50
51pub fn forward_butterly_0_to_2q(
52 mut x: u64,
53 y: u64,
54 w: u64,
55 w_shoup: u64,
56 q: u64,
57 q_twice: u64,
58) -> (u64, u64) {
59 debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q);
60 debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q);
61
62 if x >= q_twice {
63 x = x - q_twice;
64 }
65
66 let t = ShoupMul::mul(y, w, w_shoup, q);
67
68 let ox = x.wrapping_add(t);
69 let oy = x.wrapping_sub(t);
70
71 (
72 (ox).min(ox.wrapping_sub(q_twice)),
73 oy.min(oy.wrapping_add(q_twice)),
74 )
75}
76
77pub fn inverse_butterfly_0_to_2q(
85 x: u64,
86 y: u64,
87 w_inv: u64,
88 w_inv_shoup: u64,
89 q: u64,
90 q_twice: u64,
91) -> (u64, u64) {
92 debug_assert!(x < q_twice, "{} >= (2q){q_twice}", x);
93 debug_assert!(y < q_twice, "{} >= (2q){q_twice}", y);
94
95 let mut x_dash = x + y;
96 if x_dash >= q_twice {
97 x_dash -= q_twice
98 }
99
100 let t = x + q_twice - y;
101 let y = ShoupMul::mul(t, w_inv, w_inv_shoup, q);
102
103 (x_dash, y)
104}
105
106pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) {
111 assert!(a.len() == psi.len());
112
113 let n = a.len();
114 let mut t = n;
115
116 let mut m = 1;
117 while m < n {
118 t >>= 1;
119 let w = &psi[m..];
120 let w_shoup = &psi_shoup[m..];
121
122 if t == 1 {
123 for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) {
124 let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice);
125 a[0] = ox;
126 a[1] = oy;
127 }
128 } else {
129 for i in 0..m {
130 let a = &mut a[2 * i * t..(2 * (i + 1) * t)];
131 let (left, right) = a.split_at_mut(t);
132
133 for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
134 let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice);
135 *x = ox;
136 *y = oy;
137 }
138 }
139 }
140
141 m <<= 1;
142 }
143}
144
145pub fn ntt(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) {
147 assert!(a.len() == psi.len());
148
149 let n = a.len();
150 let mut t = n;
151
152 let mut m = 1;
153 while m < n {
154 t >>= 1;
155 let w = &psi[m..];
156 let w_shoup = &psi_shoup[m..];
157
158 if t == 1 {
159 for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) {
160 let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice);
161 a[0] = ox.min(ox.wrapping_sub(q));
163 a[1] = oy.min(oy.wrapping_sub(q));
164 }
165 } else {
166 for i in 0..m {
167 let a = &mut a[2 * i * t..(2 * (i + 1) * t)];
168 let (left, right) = a.split_at_mut(t);
169
170 for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
171 let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice);
172 *x = ox;
173 *y = oy;
174 }
175 }
176 }
177
178 m <<= 1;
179 }
180}
181
182pub fn ntt_inv_lazy(
188 a: &mut [u64],
189 psi_inv: &[u64],
190 psi_inv_shoup: &[u64],
191 n_inv: u64,
192 n_inv_shoup: u64,
193 q: u64,
194 q_twice: u64,
195) {
196 assert!(a.len() == psi_inv.len());
197
198 let mut m = a.len() >> 1;
199 let mut t = 1;
200
201 while m > 0 {
202 if m == 1 {
203 let (left, right) = a.split_at_mut(t);
204
205 for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
206 let (ox, oy) =
207 inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice);
208 *x = ShoupMul::mul(ox, n_inv, n_inv_shoup, q);
209 *y = ShoupMul::mul(oy, n_inv, n_inv_shoup, q);
210 }
211 } else {
212 let w_inv = &psi_inv[m..];
213 let w_inv_shoup = &psi_inv_shoup[m..];
214 for i in 0..m {
215 let a = &mut a[2 * i * t..2 * (i + 1) * t];
216 let (left, right) = a.split_at_mut(t);
217
218 for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
219 let (ox, oy) =
220 inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice);
221 *x = ox;
222 *y = oy;
223 }
224 }
225 }
226
227 t *= 2;
228 m >>= 1;
229 }
230}
231
232pub fn ntt_inv(
234 a: &mut [u64],
235 psi_inv: &[u64],
236 psi_inv_shoup: &[u64],
237 n_inv: u64,
238 n_inv_shoup: u64,
239 q: u64,
240 q_twice: u64,
241) {
242 assert!(a.len() == psi_inv.len());
243
244 let mut m = a.len() >> 1;
245 let mut t = 1;
246
247 while m > 0 {
248 if m == 1 {
249 let (left, right) = a.split_at_mut(t);
250
251 for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
252 let (ox, oy) =
253 inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice);
254 let ox = ShoupMul::mul(ox, n_inv, n_inv_shoup, q);
255 let oy = ShoupMul::mul(oy, n_inv, n_inv_shoup, q);
256 *x = ox.min(ox.wrapping_sub(q));
257 *y = oy.min(oy.wrapping_sub(q));
258 }
259 } else {
260 let w_inv = &psi_inv[m..];
261 let w_inv_shoup = &psi_inv_shoup[m..];
262 for i in 0..m {
263 let a = &mut a[2 * i * t..2 * (i + 1) * t];
264 let (left, right) = a.split_at_mut(t);
265
266 for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
267 let (ox, oy) =
268 inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice);
269 *x = ox;
270 *y = oy;
271 }
272 }
273 }
274
275 t *= 2;
276 m >>= 1;
277 }
278}
279
280pub(crate) fn find_primitive_root<R: RngCore>(q: u64, n: u64, rng: &mut R) -> Option<u64> {
284 assert!(n.is_power_of_two(), "{n} is not power of two");
285
286 assert!(q % n == 1, "{n}^th root of unity in F_{q} does not exists");
288
289 let t = (q - 1) / n;
290
291 for _ in 0..100 {
292 let mut omega = rng.gen::<u64>() % q;
293
294 omega = mod_exponent(omega, t, q);
296
297 if mod_exponent(omega, n >> 1, q) == 1 {
300 continue;
301 } else {
302 return Some(omega);
303 }
304 }
305
306 None
307}
308
/// NTT backend for `u64` values holding the modulus and the precomputed
/// twiddle tables used by the free functions in this file.
#[derive(Debug)]
pub struct NttBackendU64 {
    /// Modulus.
    q: u64,
    /// `2 * q`, used by the lazy butterflies.
    q_twice: u64,
    /// Transform size the tables were built for.
    _n: u64,
    /// Inverse of the transform size modulo `q`.
    n_inv: u64,
    /// Shoup representation of `n_inv`.
    n_inv_shoup: u64,
    /// Powers of the 2n-th root of unity `psi`, stored in bit-reversed order.
    psi_powers_bo: Box<[u64]>,
    /// Powers of the inverse of `psi`, stored in bit-reversed order.
    psi_inv_powers_bo: Box<[u64]>,
    /// Shoup representations of `psi_powers_bo`, entry for entry.
    psi_powers_bo_shoup: Box<[u64]>,
    /// Shoup representations of `psi_inv_powers_bo`, entry for entry.
    psi_inv_powers_bo_shoup: Box<[u64]>,
}
321
322impl NttBackendU64 {
323 fn _new(q: u64, n: usize) -> Self {
324 let mut rng = ChaCha8Rng::from_seed([0u8; 32]);
326 let psi = find_primitive_root(q, (n * 2) as u64, &mut rng)
327 .expect("Unable to find 2n^th root of unity");
328 let psi_inv = mod_inverse(psi, q);
329
330 let modulus = ModularOpsU64::new(q);
336
337 let mut psi_powers = Vec::with_capacity(n as usize);
338 let mut psi_inv_powers = Vec::with_capacity(n as usize);
339 let mut running_psi = 1;
340 let mut running_psi_inv = 1;
341 for _ in 0..n {
342 psi_powers.push(running_psi);
343 psi_inv_powers.push(running_psi_inv);
344
345 running_psi = modulus.mul(&running_psi, &psi);
346 running_psi_inv = modulus.mul(&running_psi_inv, &psi_inv);
347 }
348
349 let mut psi_powers_bo = vec![0u64; n as usize];
351 let mut psi_inv_powers_bo = vec![0u64; n as usize];
352 let shift_by = n.leading_zeros() + 1;
353 for i in 0..n as usize {
354 let bo_index = i.reverse_bits() >> shift_by;
356
357 psi_powers_bo[bo_index] = psi_powers[i];
358 psi_inv_powers_bo[bo_index] = psi_inv_powers[i];
359 }
360
361 let psi_powers_bo_shoup = psi_powers_bo
363 .iter()
364 .map(|v| ShoupMul::representation(*v, q))
365 .collect_vec();
366 let psi_inv_powers_bo_shoup = psi_inv_powers_bo
367 .iter()
368 .map(|v| ShoupMul::representation(*v, q))
369 .collect_vec();
370
371 let n_inv = mod_inverse(n as u64, q);
373
374 NttBackendU64 {
375 q,
376 q_twice: 2 * q,
377 _n: n as u64,
378 n_inv,
379 n_inv_shoup: ShoupMul::representation(n_inv, q),
380 psi_powers_bo: psi_powers_bo.into_boxed_slice(),
381 psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(),
382 psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(),
383 psi_inv_powers_bo_shoup: psi_inv_powers_bo_shoup.into_boxed_slice(),
384 }
385 }
386}
387
388impl<M: Modulus<Element = u64>> NttInit<M> for NttBackendU64 {
389 fn new(q: &M, n: usize) -> Self {
390 assert!(!q.is_native());
392 NttBackendU64::_new(q.q().unwrap(), n)
393 }
394}
395
396impl Ntt for NttBackendU64 {
397 type Element = u64;
398
399 fn forward_lazy(&self, v: &mut [Self::Element]) {
400 ntt_lazy(
401 v,
402 &self.psi_powers_bo,
403 &self.psi_powers_bo_shoup,
404 self.q,
405 self.q_twice,
406 )
407 }
408
409 fn forward(&self, v: &mut [Self::Element]) {
410 ntt(
411 v,
412 &self.psi_powers_bo,
413 &self.psi_powers_bo_shoup,
414 self.q,
415 self.q_twice,
416 );
417 }
418
419 fn backward_lazy(&self, v: &mut [Self::Element]) {
420 ntt_inv_lazy(
421 v,
422 &self.psi_inv_powers_bo,
423 &self.psi_inv_powers_bo_shoup,
424 self.n_inv,
425 self.n_inv_shoup,
426 self.q,
427 self.q_twice,
428 )
429 }
430
431 fn backward(&self, v: &mut [Self::Element]) {
432 ntt_inv(
433 v,
434 &self.psi_inv_powers_bo,
435 &self.psi_inv_powers_bo_shoup,
436 self.n_inv,
437 self.n_inv_shoup,
438 self.q,
439 self.q_twice,
440 );
441 }
442}
443
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::{thread_rng, Rng};
    use rand_distr::Uniform;

    use super::NttBackendU64;
    use crate::{
        backend::{ModInit, ModularOpsU64, VectorOps},
        ntt::Ntt,
        utils::{generate_prime, negacyclic_mul},
    };

    /// Modulus used by the round-trip test.
    const Q_60_BITS: u64 = 1152921504606748673;
    /// Transform size used by all tests.
    const N: usize = 1 << 4;

    /// Number of random trials per test.
    const K: usize = 128;

    /// Samples `size` values uniformly from `[0, q)`.
    fn random_vec_in_fq(size: usize, q: u64) -> Vec<u64> {
        thread_rng()
            .sample_iter(Uniform::new(0, q))
            .take(size)
            .collect_vec()
    }

    /// Asserts that every element of `a` is at most `max_val`.
    fn assert_output_range(a: &[u64], max_val: u64) {
        a.iter()
            .for_each(|v| assert!(v <= &max_val, "{v} > {max_val}"));
    }

    /// Forward followed by backward must be the identity for every
    /// lazy / non-lazy combination, and outputs must stay in their ranges.
    #[test]
    fn native_ntt_backend_works() {
        let ntt_backend = NttBackendU64::_new(Q_60_BITS, N);
        for _ in 0..K {
            let mut a = random_vec_in_fq(N, Q_60_BITS);
            let a_clone = a.clone();

            // forward + backward: outputs in [0, q)
            ntt_backend.forward(&mut a);
            assert_output_range(a.as_ref(), Q_60_BITS - 1);
            assert_ne!(a, a_clone);
            ntt_backend.backward(&mut a);
            assert_output_range(a.as_ref(), Q_60_BITS - 1);
            assert_eq!(a, a_clone);

            // lazy forward (outputs in [0, 2q)) + backward
            ntt_backend.forward_lazy(&mut a);
            assert_output_range(a.as_ref(), (2 * Q_60_BITS) - 1);
            assert_ne!(a, a_clone);
            ntt_backend.backward(&mut a);
            assert_output_range(a.as_ref(), Q_60_BITS - 1);
            assert_eq!(a, a_clone);

            // forward + lazy backward (outputs in [0, 2q))
            ntt_backend.forward(&mut a);
            assert_output_range(a.as_ref(), Q_60_BITS - 1);
            ntt_backend.backward_lazy(&mut a);
            assert_output_range(a.as_ref(), (2 * Q_60_BITS) - 1);
            // Bring the lazy output back to [0, q) before comparing. This used
            // to be `*a0 -= *a0 - Q_60_BITS`, which set the value to q itself.
            a.iter_mut().for_each(|a0| {
                if *a0 >= Q_60_BITS {
                    *a0 -= Q_60_BITS;
                }
            });
            assert_eq!(a, a_clone);
        }
    }

    /// Pointwise multiplication in the NTT domain must agree with the
    /// reference `negacyclic_mul` for primes of several bit sizes.
    #[test]
    fn native_ntt_negacylic_mul() {
        let primes = [25, 40, 50, 60]
            .iter()
            .map(|bits| generate_prime(*bits, (2 * N) as u64, 1u64 << bits).unwrap())
            .collect_vec();

        for p in primes.into_iter() {
            let ntt_backend = NttBackendU64::_new(p, N);
            let modulus_backend = ModularOpsU64::new(p);
            for _ in 0..K {
                let a = random_vec_in_fq(N, p);
                let b = random_vec_in_fq(N, p);

                let mut a_clone = a.clone();
                let mut b_clone = b.clone();
                ntt_backend.forward_lazy(&mut a_clone);
                ntt_backend.forward_lazy(&mut b_clone);
                modulus_backend.elwise_mul_mut(&mut a_clone, &b_clone);
                ntt_backend.backward(&mut a_clone);

                // Reference scalar multiplication modulo `p` via u128.
                let mul = |a: &u64, b: &u64| {
                    let tmp = *a as u128 * *b as u128;
                    (tmp % p as u128) as u64
                };
                let expected_out = negacyclic_mul(&a, &b, mul, p);

                assert_eq!(a_clone, expected_out);
            }
        }
    }
}