1use super::prime::NttRootTable;
28use super::scalar::compute_shoup;
29use alloc::vec;
30use alloc::vec::Vec;
31
/// Precomputed tables and constants for a 32-bit number-theoretic transform
/// of size `n` modulo the prime `q`.
///
/// Instances are normally built through `Ntt32Context::try_new`, which
/// validates the parameters and fills every table. All fields are public so
/// that the transform kernels can read them directly; mutating them by hand
/// can leave the tables inconsistent with `n` and `q`.
///
/// Equality compares every field, including the full twiddle tables.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Ntt32Context {
    /// Transform size. The constructor only accepts powers of two `>= 2`.
    pub n: usize,

    /// Copied from the root table's `log_n` (presumably `log2(n)` — confirm
    /// against `NttRootTable`).
    pub log_n: u32,

    /// Modulus. The constructor only accepts primes below `2^28` for which
    /// `2 * n` divides `q - 1`.
    pub q: u32,

    /// Cached `2 * q`; cannot overflow because `q < 2^28`.
    pub two_q: u32,

    /// Forward twiddle factors, taken verbatim from `NttRootTable::root_powers`
    /// (the ordering is whatever that table defines).
    pub root_powers: Vec<u32>,

    /// `compute_shoup(w, q)` for each entry `w` of `root_powers`, index for index.
    pub root_powers_shoup: Vec<u32>,

    /// `floor(w * 2^31 / q)` for each entry `w` of `root_powers`, stored as `i32`.
    /// Only present on aarch64; presumably consumed by the NEON kernels.
    #[cfg(target_arch = "aarch64")]
    pub root_powers_qmulh: Vec<i32>,

    /// Inverse twiddle factors, taken verbatim from
    /// `NttRootTable::inv_root_powers`.
    pub inv_root_powers: Vec<u32>,

    /// `compute_shoup(w, q)` for each entry `w` of `inv_root_powers`, index for index.
    pub inv_root_powers_shoup: Vec<u32>,

    /// `floor(w * 2^31 / q)` for each entry `w` of `inv_root_powers`, stored as
    /// `i32`. Only present on aarch64.
    #[cfg(target_arch = "aarch64")]
    pub inv_root_powers_qmulh: Vec<i32>,

    /// Scaling constant taken from `NttRootTable::n_inv`. Multiplying the
    /// output of the unnormalized inverse transform by this value modulo `q`
    /// recovers the original input, i.e. it acts as `n^{-1} mod q`.
    pub n_inv: u32,

    /// `compute_shoup(n_inv, q)`.
    pub n_inv_shoup: u32,
}
95
96impl Ntt32Context {
97 pub fn try_new(n: usize, q: u32) -> Result<Self, crate::NttError> {
111 if n < 2 || !n.is_power_of_two() {
112 return Err(crate::NttError::InvalidSize(n));
113 }
114 if q >= (1u32 << 28) {
115 return Err(crate::NttError::PrimeTooLarge(q as u64));
116 }
117 if !super::prime::is_prime_32(q) {
118 return Err(crate::NttError::NotPrime(q as u64));
119 }
120 if !((q - 1) as usize).is_multiple_of(2 * n) {
121 return Err(crate::NttError::NotNttFriendly { q: q as u64, n });
122 }
123
124 let base = NttRootTable::new(n, q);
126
127 let root_powers_shoup: Vec<u32> = base
128 .root_powers
129 .iter()
130 .map(|&w| compute_shoup(w, q))
131 .collect();
132
133 let inv_root_powers_shoup: Vec<u32> = base
134 .inv_root_powers
135 .iter()
136 .map(|&w| compute_shoup(w, q))
137 .collect();
138
139 let n_inv_shoup = compute_shoup(base.n_inv, q);
140
141 #[cfg(target_arch = "aarch64")]
142 let root_powers_qmulh: Vec<i32> = base
143 .root_powers
144 .iter()
145 .map(|&w| ((w as u64 * (1u64 << 31)) / q as u64) as i32)
146 .collect();
147
148 #[cfg(target_arch = "aarch64")]
149 let inv_root_powers_qmulh: Vec<i32> = base
150 .inv_root_powers
151 .iter()
152 .map(|&w| ((w as u64 * (1u64 << 31)) / q as u64) as i32)
153 .collect();
154
155 Ok(Self {
156 n,
157 log_n: base.log_n,
158 q,
159 two_q: 2 * q,
160 root_powers: base.root_powers,
161 root_powers_shoup,
162 #[cfg(target_arch = "aarch64")]
163 root_powers_qmulh,
164 inv_root_powers: base.inv_root_powers,
165 inv_root_powers_shoup,
166 #[cfg(target_arch = "aarch64")]
167 inv_root_powers_qmulh,
168 n_inv: base.n_inv,
169 n_inv_shoup,
170 })
171 }
172
173 pub fn new(n: usize, q: u32) -> Self {
188 Self::try_new(n, q).expect("Invalid NTT parameters")
189 }
190
191 #[inline]
196 pub fn forward(&self, data: &mut [u32]) {
197 #[cfg(target_arch = "aarch64")]
198 {
199 super::neon::ntt_fwd_neon(data, self);
200 }
201 #[cfg(not(target_arch = "aarch64"))]
202 {
203 super::scalar::ntt_forward_scalar(data, self);
204 }
205 }
206
207 #[inline]
213 pub fn inverse(&self, data: &mut [u32]) {
214 #[cfg(target_arch = "aarch64")]
215 {
216 super::neon::ntt_inv_neon(data, self);
217 }
218 #[cfg(not(target_arch = "aarch64"))]
219 {
220 super::scalar::ntt_inverse_scalar(data, self);
221 }
222 }
223
224 #[inline]
230 pub fn inverse_lazy(&self, data: &mut [u32]) {
231 #[cfg(target_arch = "aarch64")]
232 {
233 super::neon::ntt_inv_neon_lazy(data, self);
234 }
235 #[cfg(not(target_arch = "aarch64"))]
236 {
237 super::scalar::ntt_inverse_scalar_lazy(data, self);
238 }
239 }
240
241 #[inline]
243 pub fn n_inv(&self) -> u32 {
244 self.n_inv
245 }
246
247 #[inline]
249 pub fn n_inv_shoup(&self) -> u32 {
250 self.n_inv_shoup
251 }
252
253 pub fn pointwise_mul(&self, a: &[u32], b: &[u32], result: &mut [u32]) {
257 super::scalar::ntt_pointwise_mul_scalar(a, b, result, self.q, self.n);
258 }
259
260 pub fn negacyclic_mul(&self, a: &[u32], b: &[u32]) -> Vec<u32> {
268 let n = self.n;
269 assert_eq!(a.len(), n, "negacyclic_mul: a.len() must be N");
270 assert_eq!(b.len(), n, "negacyclic_mul: b.len() must be N");
271 let mut a_buf = a.to_vec();
272 let mut b_buf = b.to_vec();
273 let mut result = vec![0u32; n];
274 self.negacyclic_mul_into(&mut a_buf, &mut b_buf, &mut result);
275 result
276 }
277
278 pub fn negacyclic_mul_into(&self, a_buf: &mut [u32], b_buf: &mut [u32], result: &mut [u32]) {
304 let n = self.n;
305 assert_eq!(a_buf.len(), n, "a_buf.len()={} != N={n}", a_buf.len());
306 assert_eq!(b_buf.len(), n, "b_buf.len()={} != N={n}", b_buf.len());
307 assert_eq!(result.len(), n, "result.len()={} != N={n}", result.len());
308
309 self.forward(a_buf);
310 self.forward(b_buf);
311 self.pointwise_mul(a_buf, b_buf, result);
312 self.inverse(result);
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ntt32::prime::generate_primes_28;

    /// First prime produced by `generate_primes_28` for transform size `n`.
    fn test_prime(n: usize) -> u32 {
        generate_primes_28(n, 1)[0]
    }

    /// Coefficients `(i * slope + offset) mod q` for `i` in `0..n`.
    fn affine_poly(n: usize, q: u32, slope: u64, offset: u64) -> Vec<u32> {
        (0..n)
            .map(|i| ((i as u64 * slope + offset) % q as u64) as u32)
            .collect()
    }

    /// Deterministic, well-mixed test vector of `n` values below `q`.
    fn make_test_data(n: usize, q: u32) -> Vec<u32> {
        affine_poly(n, q, 314_159_265, 271_828_182)
    }

    /// Polynomial of length `n` whose only non-zero coefficient is `c` at index 0.
    fn constant_poly(n: usize, c: u32) -> Vec<u32> {
        let mut poly = vec![0u32; n];
        poly[0] = c;
        poly
    }

    /// Runs `forward` then `inverse` on a copy of `original`.
    ///
    /// Returns `(after_forward, after_inverse)`.
    fn roundtrip(ctx: &Ntt32Context, original: &[u32]) -> (Vec<u32>, Vec<u32>) {
        let mut data = original.to_vec();
        ctx.forward(&mut data);
        let spectrum = data.clone();
        ctx.inverse(&mut data);
        (spectrum, data)
    }

    /// Multiplies every element of `data` by `factor` modulo `q`.
    fn scale_mod(data: &[u32], factor: u32, q: u32) -> Vec<u32> {
        data.iter()
            .map(|&x| ((x as u64 * factor as u64) % q as u64) as u32)
            .collect()
    }

    /// Quadratic-time reference for multiplication modulo `x^n + 1` and `q`,
    /// where `n = a.len()`: terms that wrap past degree `n - 1` are subtracted.
    fn schoolbook_negacyclic(a: &[u32], b: &[u32], q: u32) -> Vec<u32> {
        let n = a.len();
        let modulus = q as u64;
        let mut acc = vec![0u32; n];
        for (i, &ai) in a.iter().enumerate() {
            for (j, &bj) in b.iter().enumerate() {
                let prod = (ai as u64 * bj as u64) % modulus;
                let degree = i + j;
                if degree < n {
                    acc[degree] = ((acc[degree] as u64 + prod) % modulus) as u32;
                } else {
                    let idx = degree - n;
                    acc[idx] = ((acc[idx] as u64 + modulus - prod) % modulus) as u32;
                }
            }
        }
        acc
    }

    /// Checks that multiplying `(17 * i + 5) mod q` by the polynomial `1`
    /// returns the operand unchanged; `msg` is the failure message.
    fn assert_mul_by_one_is_identity(n: usize, q: u32, msg: &str) {
        let ctx = Ntt32Context::new(n, q);
        let a = affine_poly(n, q, 17, 5);
        let one = constant_poly(n, 1);

        let result = ctx.negacyclic_mul(&a, &one);
        assert_eq!(result, a, "{msg}");
    }

    #[test]
    fn test_roundtrip_n2() {
        let q = 5u32;
        let ctx = Ntt32Context::new(2, q);
        let original = vec![1u32, 3];
        let (spectrum, recovered) = roundtrip(&ctx, &original);
        assert_ne!(spectrum, original, "NTT forward did nothing for N=2");
        assert_eq!(recovered, original, "NTT roundtrip failed for N=2");
    }

    #[test]
    fn test_roundtrip_n4() {
        let q = 17u32;
        let ctx = Ntt32Context::new(4, q);
        let original = vec![1u32, 5, 9, 13];
        let (spectrum, recovered) = roundtrip(&ctx, &original);
        assert_ne!(spectrum, original, "NTT forward did nothing for N=4");
        assert_eq!(recovered, original, "NTT roundtrip failed for N=4");
    }

    #[test]
    fn test_roundtrip_n16() {
        let n = 16;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);

        let (spectrum, recovered) = roundtrip(&ctx, &original);
        assert_ne!(spectrum, original, "NTT forward did nothing for N={n}");
        assert_eq!(recovered, original, "NTT roundtrip failed for N={n}");
    }

    #[test]
    fn test_roundtrip_n64() {
        let n = 64;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);

        let (_, recovered) = roundtrip(&ctx, &original);
        assert_eq!(recovered, original, "NTT roundtrip failed for N={n}");
    }

    #[test]
    fn test_roundtrip_n1024() {
        let n = 1024;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);

        let (_, recovered) = roundtrip(&ctx, &original);
        assert_eq!(recovered, original, "NTT roundtrip failed for N={n}");
    }

    #[test]
    fn test_roundtrip_n32768() {
        let n = 32768;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);

        let (_, recovered) = roundtrip(&ctx, &original);
        assert_eq!(recovered, original, "NTT roundtrip failed for N=32768");
    }

    #[test]
    fn test_roundtrip_zeros() {
        let n = 64;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let zeros = vec![0u32; n];

        let (_, recovered) = roundtrip(&ctx, &zeros);
        assert_eq!(recovered, zeros);
    }

    /// The transform of the constant polynomial `c` must be `c` in every slot.
    #[test]
    fn test_constant_polynomial() {
        let n = 64;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let c = 42u32;
        let mut data = constant_poly(n, c);

        ctx.forward(&mut data);
        for (i, &v) in data.iter().enumerate() {
            assert_eq!(v, c, "NTT of constant: data[{i}]={v}, expected {c}");
        }
    }

    #[test]
    fn test_negacyclic_mul_identity() {
        let n = 64;
        assert_mul_by_one_is_identity(n, test_prime(n), "Multiply by 1 is not identity");
    }

    #[test]
    fn test_negacyclic_mul_n16() {
        let n = 16;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);

        // a = 1, 2, ..., n and b = all ones.
        let a = affine_poly(n, q, 1, 1);
        let b = vec![1u32; n];
        let expected = schoolbook_negacyclic(&a, &b, q);

        let result = ctx.negacyclic_mul(&a, &b);
        assert_eq!(result, expected, "Negacyclic multiplication mismatch");
    }

    #[test]
    fn test_inverse_lazy_no_normalization() {
        let n = 256;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);

        let mut data = original.clone();
        ctx.forward(&mut data);
        ctx.inverse_lazy(&mut data);
        assert_ne!(
            data, original,
            "inverse_lazy should not match original (no N^{{-1}})"
        );

        // Applying the missing scaling by hand must complete the round trip.
        let normalized = scale_mod(&data, ctx.n_inv(), q);
        assert_eq!(
            normalized, original,
            "inverse_lazy + manual N^{{-1}} should match original"
        );
    }

    #[test]
    fn test_inverse_lazy_matches_concrete_ntt_style() {
        let n = 1024;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);

        let mut data_full = original.clone();
        let mut data_lazy = original.clone();

        ctx.forward(&mut data_full);
        ctx.forward(&mut data_lazy);

        ctx.inverse(&mut data_full);
        ctx.inverse_lazy(&mut data_lazy);

        let data_lazy_normalized = scale_mod(&data_lazy, ctx.n_inv(), q);
        assert_eq!(data_full, data_lazy_normalized);
    }

    #[test]
    fn test_negacyclic_mul_into_matches_negacyclic_mul() {
        let n = 256;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);

        let a = affine_poly(n, q, 17, 3);
        let b = affine_poly(n, q, 31, 7);

        let result_alloc = ctx.negacyclic_mul(&a, &b);

        let mut a_buf = a.clone();
        let mut b_buf = b.clone();
        let mut result_inplace = vec![0u32; n];
        ctx.negacyclic_mul_into(&mut a_buf, &mut b_buf, &mut result_inplace);

        assert_eq!(
            result_alloc, result_inplace,
            "negacyclic_mul_into must match negacyclic_mul"
        );
    }

    #[test]
    fn test_negacyclic_mul_into_reusable_buffers() {
        let n = 64;
        let q = test_prime(n);
        let ctx = Ntt32Context::new(n, q);

        let mut a_buf = vec![0u32; n];
        let mut b_buf = vec![0u32; n];
        let mut result = vec![0u32; n];

        for round in 0..3u32 {
            // Refill the same buffers each round; the previous call left
            // transformed data in them.
            a_buf.copy_from_slice(&affine_poly(n, q, round as u64 + 17, 3));
            b_buf.copy_from_slice(&affine_poly(n, q, round as u64 + 31, 7));
            let expected = ctx.negacyclic_mul(&a_buf, &b_buf);

            ctx.negacyclic_mul_into(&mut a_buf, &mut b_buf, &mut result);
            assert_eq!(
                result, expected,
                "Reusable buffer mismatch at round {round}"
            );
        }
    }

    #[test]
    fn test_pq_mldsa_roundtrip() {
        let q: u32 = 8_380_417;
        let n = 256;
        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");

        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);

        let (spectrum, recovered) = roundtrip(&ctx, &original);
        assert_ne!(spectrum, original, "Forward NTT should change data");
        assert_eq!(recovered, original, "ML-DSA roundtrip failed");
    }

    #[test]
    fn test_pq_mldsa_negacyclic_mul() {
        assert_mul_by_one_is_identity(256, 8_380_417, "ML-DSA: multiply by 1 is not identity");
    }

    #[test]
    fn test_pq_falcon512_roundtrip() {
        let q: u32 = 12_289;
        let n = 512;
        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");

        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);

        let (_, recovered) = roundtrip(&ctx, &original);
        assert_eq!(recovered, original, "Falcon-512 roundtrip failed");
    }

    #[test]
    fn test_pq_falcon1024_roundtrip() {
        let q: u32 = 12_289;
        let n = 1024;
        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");

        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);

        let (_, recovered) = roundtrip(&ctx, &original);
        assert_eq!(recovered, original, "Falcon-1024 roundtrip failed");
    }

    #[test]
    fn test_pq_falcon_negacyclic_mul() {
        assert_mul_by_one_is_identity(512, 12_289, "Falcon: multiply by 1 is not identity");
    }

    #[test]
    fn test_pq_mlkem_proxy_roundtrip() {
        let q: u32 = 3_329;
        let n = 128;
        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");

        let ctx = Ntt32Context::new(n, q);
        let original = make_test_data(n, q);

        let (_, recovered) = roundtrip(&ctx, &original);
        assert_eq!(recovered, original, "ML-KEM proxy roundtrip failed");
    }

    #[test]
    fn test_pq_mlkem_negacyclic_mul() {
        let q: u32 = 3_329;
        let n = 128;
        let ctx = Ntt32Context::new(n, q);

        // a = 1, 2, ..., n and b = all ones.
        let a = affine_poly(n, q, 1, 1);
        let b = vec![1u32; n];
        let expected = schoolbook_negacyclic(&a, &b, q);

        let result = ctx.negacyclic_mul(&a, &b);
        assert_eq!(
            result, expected,
            "ML-KEM negacyclic multiplication mismatch"
        );
    }

    // Compile-time proof that the context is `Send + Sync`. Nothing here is
    // ever called — type-checking `check` is the whole test — so the
    // dead-code lint is silenced for this item only.
    #[allow(dead_code)]
    const _: () = {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        fn check() {
            assert_send::<super::Ntt32Context>();
            assert_sync::<super::Ntt32Context>();
        }
    };
}