Skip to main content

vaea_ntt/ntt32/
context.rs

1// Copyright (C) 2024-2026 Vaea SAS
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//
4// This file is part of VaeaNTT.
5//
6// VaeaNTT is free software: you can redistribute it and/or modify it under
7// the terms of the GNU Affero General Public License as published by the
8// Free Software Foundation, either version 3 of the License, or (at your
9// option) any later version.
10//
11// VaeaNTT is distributed in the hope that it will be useful, but WITHOUT
12// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
14// License for more details.
15//
16// You should have received a copy of the GNU Affero General Public License
17// along with VaeaNTT. If not, see <https://www.gnu.org/licenses/>.
18
19//! # Ntt32Context — Unified NTT Context for 28-bit Primes
20//!
21//! Combines the root table from `NttSmallCtx` with the Shoup precomputed
22//! quotients from `ShoupCtx` into a single unified context struct.
23//!
24//! The `forward()` and `inverse()` methods automatically dispatch to
25//! NEON on `aarch64` targets, falling back to scalar code otherwise.
26
27use super::prime::NttRootTable;
28use super::scalar::compute_shoup;
29use alloc::vec;
30use alloc::vec::Vec;
31
32// ===========================================================================
33// Ntt32Context — the unified context
34// ===========================================================================
35
/// Precomputed tables for running NTTs modulo a single prime `q < 2^28`.
///
/// Holds the forward and inverse twiddle factors (root powers, in
/// Longa-Naehrig ordering) together with their Shoup quotients, so that a
/// multiplication by a twiddle factor needs no division at transform time.
///
/// # Examples
/// ```
/// use vaea_ntt::ntt32::{Ntt32Context, generate_primes_28};
///
/// let primes = generate_primes_28(1024, 1);
/// let ctx = Ntt32Context::new(1024, primes[0]);
///
/// let mut data = vec![0u32; 1024];
/// data[0] = 42;
/// ctx.forward(&mut data);   // NTT forward
/// ctx.inverse(&mut data);   // NTT inverse (data restored)
/// assert_eq!(data[0], 42);
/// ```
#[derive(Debug, Clone)]
pub struct Ntt32Context {
    /// Transform length `N` (a power of two).
    pub n: usize,

    /// `log2(n)`.
    pub log_n: u32,

    /// The modulus: a prime below `2^28`.
    pub q: u32,

    /// `2 · q`, cached for the Harvey lazy butterfly.
    pub two_q: u32,

    /// Twiddle factors for the forward transform (Longa-Naehrig ordering).
    pub root_powers: Vec<u32>,

    /// Shoup quotient of each forward twiddle:
    /// `floor(root_powers[i] · 2^32 / q)`.
    pub root_powers_shoup: Vec<u32>,

    /// Signed doubling-multiply-high quotient of each forward twiddle
    /// (only built on `aarch64`, for the NEON path):
    /// `root_powers_qmulh[i] = floor(root_powers[i] · 2^31 / q)` as `i32`.
    #[cfg(target_arch = "aarch64")]
    pub root_powers_qmulh: Vec<i32>,

    /// Twiddle factors for the inverse transform.
    pub inv_root_powers: Vec<u32>,

    /// Shoup quotient of each inverse twiddle.
    pub inv_root_powers_shoup: Vec<u32>,

    /// Signed doubling-multiply-high quotient of each inverse twiddle
    /// (only built on `aarch64`, for the NEON path).
    #[cfg(target_arch = "aarch64")]
    pub inv_root_powers_qmulh: Vec<i32>,

    /// `N^{-1} mod q`, the scaling factor applied by the inverse transform.
    pub n_inv: u32,

    /// Shoup quotient of `n_inv`.
    pub n_inv_shoup: u32,
}
95
96impl Ntt32Context {
97    /// Fallible constructor for an NTT context for a 28-bit prime.
98    ///
99    /// Validates all preconditions and returns an error instead of panicking.
100    ///
101    /// # Arguments
102    /// - `n` — polynomial size, must be a power of 2 ≥ 2
103    /// - `q` — prime < 2^28, must satisfy `q ≡ 1 (mod 2N)`
104    ///
105    /// # Errors
106    /// - [`crate::NttError::InvalidSize`] if `n` is not a power of 2 ≥ 2
107    /// - [`crate::NttError::PrimeTooLarge`] if `q ≥ 2^28`
108    /// - [`crate::NttError::NotPrime`] if `q` is not prime
109    /// - [`crate::NttError::NotNttFriendly`] if `(q - 1)` is not divisible by `2N`
110    pub fn try_new(n: usize, q: u32) -> Result<Self, crate::NttError> {
111        if n < 2 || !n.is_power_of_two() {
112            return Err(crate::NttError::InvalidSize(n));
113        }
114        if q >= (1u32 << 28) {
115            return Err(crate::NttError::PrimeTooLarge(q as u64));
116        }
117        if !super::prime::is_prime_32(q) {
118            return Err(crate::NttError::NotPrime(q as u64));
119        }
120        if !((q - 1) as usize).is_multiple_of(2 * n) {
121            return Err(crate::NttError::NotNttFriendly { q: q as u64, n });
122        }
123
124        // All preconditions verified — build the root table
125        let base = NttRootTable::new(n, q);
126
127        let root_powers_shoup: Vec<u32> = base
128            .root_powers
129            .iter()
130            .map(|&w| compute_shoup(w, q))
131            .collect();
132
133        let inv_root_powers_shoup: Vec<u32> = base
134            .inv_root_powers
135            .iter()
136            .map(|&w| compute_shoup(w, q))
137            .collect();
138
139        let n_inv_shoup = compute_shoup(base.n_inv, q);
140
141        #[cfg(target_arch = "aarch64")]
142        let root_powers_qmulh: Vec<i32> = base
143            .root_powers
144            .iter()
145            .map(|&w| ((w as u64 * (1u64 << 31)) / q as u64) as i32)
146            .collect();
147
148        #[cfg(target_arch = "aarch64")]
149        let inv_root_powers_qmulh: Vec<i32> = base
150            .inv_root_powers
151            .iter()
152            .map(|&w| ((w as u64 * (1u64 << 31)) / q as u64) as i32)
153            .collect();
154
155        Ok(Self {
156            n,
157            log_n: base.log_n,
158            q,
159            two_q: 2 * q,
160            root_powers: base.root_powers,
161            root_powers_shoup,
162            #[cfg(target_arch = "aarch64")]
163            root_powers_qmulh,
164            inv_root_powers: base.inv_root_powers,
165            inv_root_powers_shoup,
166            #[cfg(target_arch = "aarch64")]
167            inv_root_powers_qmulh,
168            n_inv: base.n_inv,
169            n_inv_shoup,
170        })
171    }
172
173    /// Creates a new NTT context for a 28-bit prime.
174    ///
175    /// Computes primitive roots, twiddle factors (Longa-Naehrig ordering),
176    /// and precomputes all Shoup quotients.
177    ///
178    /// # Arguments
179    /// - `n` — polynomial size, must be a power of 2 ≥ 2
180    /// - `q` — prime < 2^28, must satisfy `q ≡ 1 (mod 2N)`
181    ///
182    /// # Panics
183    /// - If `n` is not a power of 2 ≥ 2
184    /// - If `q ≥ 2^28`
185    /// - If `q` is not prime
186    /// - If `(q - 1)` is not divisible by `2N`
187    pub fn new(n: usize, q: u32) -> Self {
188        Self::try_new(n, q).expect("Invalid NTT parameters")
189    }
190
191    /// Applies the NTT forward transform in-place.
192    ///
193    /// On `aarch64`, dispatches to the fully-vectorized NEON implementation.
194    /// On other architectures, uses the scalar Shoup NTT.
195    #[inline]
196    pub fn forward(&self, data: &mut [u32]) {
197        #[cfg(target_arch = "aarch64")]
198        {
199            super::neon::ntt_fwd_neon(data, self);
200        }
201        #[cfg(not(target_arch = "aarch64"))]
202        {
203            super::scalar::ntt_forward_scalar(data, self);
204        }
205    }
206
207    /// Applies the NTT inverse transform in-place (with N⁻¹ normalization).
208    ///
209    /// Output coefficients are fully normalized to `[0, q)`.
210    /// On `aarch64`, dispatches to the NEON implementation.
211    /// On other architectures, uses the scalar Shoup NTT.
212    #[inline]
213    pub fn inverse(&self, data: &mut [u32]) {
214        #[cfg(target_arch = "aarch64")]
215        {
216            super::neon::ntt_inv_neon(data, self);
217        }
218        #[cfg(not(target_arch = "aarch64"))]
219        {
220            super::scalar::ntt_inverse_scalar(data, self);
221        }
222    }
223
224    /// Applies the NTT inverse transform **without** N⁻¹ normalization.
225    ///
226    /// Output coefficients are scaled by N relative to the true INTT.
227    /// Use this when chaining operations where normalization can be deferred,
228    /// or when matching libraries that don't normalize (e.g., concrete-ntt).
229    #[inline]
230    pub fn inverse_lazy(&self, data: &mut [u32]) {
231        #[cfg(target_arch = "aarch64")]
232        {
233            super::neon::ntt_inv_neon_lazy(data, self);
234        }
235        #[cfg(not(target_arch = "aarch64"))]
236        {
237            super::scalar::ntt_inverse_scalar_lazy(data, self);
238        }
239    }
240
241    /// Returns N⁻¹ mod q — useful for manual normalization after `inverse_lazy()`.
242    #[inline]
243    pub fn n_inv(&self) -> u32 {
244        self.n_inv
245    }
246
247    /// Returns the Shoup quotient for N⁻¹ — for manual Shoup normalization.
248    #[inline]
249    pub fn n_inv_shoup(&self) -> u32 {
250        self.n_inv_shoup
251    }
252
253    /// Pointwise multiplication of two vectors in the NTT domain.
254    ///
255    /// Computes `result[i] = a[i] · b[i] mod q` for each coefficient.
256    pub fn pointwise_mul(&self, a: &[u32], b: &[u32], result: &mut [u32]) {
257        super::scalar::ntt_pointwise_mul_scalar(a, b, result, self.q, self.n);
258    }
259
260    /// Negacyclic polynomial multiplication in Z_q\[X\]/(X^N + 1).
261    ///
262    /// Computes `result = a · b mod (X^N + 1)` using forward NTT,
263    /// pointwise multiplication, and inverse NTT.
264    ///
265    /// # Returns
266    /// A new vector of length N containing the product.
267    pub fn negacyclic_mul(&self, a: &[u32], b: &[u32]) -> Vec<u32> {
268        let n = self.n;
269        assert_eq!(a.len(), n, "negacyclic_mul: a.len() must be N");
270        assert_eq!(b.len(), n, "negacyclic_mul: b.len() must be N");
271        let mut a_buf = a.to_vec();
272        let mut b_buf = b.to_vec();
273        let mut result = vec![0u32; n];
274        self.negacyclic_mul_into(&mut a_buf, &mut b_buf, &mut result);
275        result
276    }
277
278    /// Zero-allocation negacyclic multiplication.
279    ///
280    /// The caller provides pre-allocated buffers:
281    /// - `a_buf` / `b_buf`: input polynomials (overwritten with NTT-domain values)
282    /// - `result`: output buffer for the product
283    ///
284    /// All buffers must have length N. After the call, `a_buf` and `b_buf`
285    /// contain NTT-domain data (destroyed); `result` contains the product
286    /// in coefficient domain.
287    ///
288    /// # Example
289    /// ```
290    /// use vaea_ntt::ntt32::{Ntt32Context, generate_primes_28};
291    ///
292    /// let primes = generate_primes_28(256, 1);
293    /// let ctx = Ntt32Context::new(256, primes[0]);
294    ///
295    /// let mut a = vec![1u32; 256];
296    /// let mut b = vec![2u32; 256];
297    /// let mut result = vec![0u32; 256];
298    ///
299    /// ctx.negacyclic_mul_into(&mut a, &mut b, &mut result);
300    /// // result now contains a·b mod (X^256 + 1)
301    /// // a and b are now in NTT domain (overwritten)
302    /// ```
303    pub fn negacyclic_mul_into(&self, a_buf: &mut [u32], b_buf: &mut [u32], result: &mut [u32]) {
304        let n = self.n;
305        assert_eq!(a_buf.len(), n, "a_buf.len()={} != N={n}", a_buf.len());
306        assert_eq!(b_buf.len(), n, "b_buf.len()={} != N={n}", b_buf.len());
307        assert_eq!(result.len(), n, "result.len()={} != N={n}", result.len());
308
309        self.forward(a_buf);
310        self.forward(b_buf);
311        self.pointwise_mul(a_buf, b_buf, result);
312        self.inverse(result);
313    }
314}
315
316// ===========================================================================
317// Tests
318// ===========================================================================
319
320#[cfg(test)]
321#[allow(unused_variables, clippy::needless_range_loop, dead_code)]
322mod tests {
323    use super::*;
324    use crate::ntt32::prime::generate_primes_28;
325
326    fn test_prime(n: usize) -> u32 {
327        generate_primes_28(n, 1)[0]
328    }
329
330    fn make_test_data(n: usize, q: u32) -> Vec<u32> {
331        (0..n)
332            .map(|i| ((i as u64 * 314_159_265 + 271_828_182) % q as u64) as u32)
333            .collect()
334    }
335
336    #[test]
337    fn test_roundtrip_n2() {
338        // N=2: edge case, must fall back to scalar on NEON
339        let q = 5u32; // smallest NTT-friendly prime for N=2: q ≡ 1 (mod 4), q=5 works
340        let ctx = Ntt32Context::new(2, q);
341        let original = vec![1u32, 3];
342        let mut data = original.clone();
343        ctx.forward(&mut data);
344        assert_ne!(data, original, "NTT forward did nothing for N=2");
345        ctx.inverse(&mut data);
346        assert_eq!(data, original, "NTT roundtrip failed for N=2");
347    }
348
349    #[test]
350    fn test_roundtrip_n4() {
351        // N=4: edge case, must fall back to scalar on NEON
352        let q = 17u32; // q ≡ 1 (mod 8): 17 works
353        let ctx = Ntt32Context::new(4, q);
354        let original = vec![1u32, 5, 9, 13];
355        let mut data = original.clone();
356        ctx.forward(&mut data);
357        assert_ne!(data, original, "NTT forward did nothing for N=4");
358        ctx.inverse(&mut data);
359        assert_eq!(data, original, "NTT roundtrip failed for N=4");
360    }
361
362    #[test]
363    fn test_roundtrip_n16() {
364        let n = 16;
365        let q = test_prime(n);
366        let ctx = Ntt32Context::new(n, q);
367        let original = make_test_data(n, q);
368        let mut data = original.clone();
369
370        ctx.forward(&mut data);
371        assert_ne!(data, original, "NTT forward did nothing for N={n}");
372        ctx.inverse(&mut data);
373        assert_eq!(data, original, "NTT roundtrip failed for N={n}");
374    }
375
376    #[test]
377    fn test_roundtrip_n64() {
378        let n = 64;
379        let q = test_prime(n);
380        let ctx = Ntt32Context::new(n, q);
381        let original = make_test_data(n, q);
382        let mut data = original.clone();
383
384        ctx.forward(&mut data);
385        ctx.inverse(&mut data);
386        assert_eq!(data, original, "NTT roundtrip failed for N={n}");
387    }
388
389    #[test]
390    fn test_roundtrip_n1024() {
391        let n = 1024;
392        let q = test_prime(n);
393        let ctx = Ntt32Context::new(n, q);
394        let original = make_test_data(n, q);
395        let mut data = original.clone();
396
397        ctx.forward(&mut data);
398        ctx.inverse(&mut data);
399        assert_eq!(data, original, "NTT roundtrip failed for N={n}");
400    }
401
402    #[test]
403    fn test_roundtrip_n32768() {
404        let n = 32768;
405        let q = test_prime(n);
406        let ctx = Ntt32Context::new(n, q);
407        let original = make_test_data(n, q);
408        let mut data = original.clone();
409
410        ctx.forward(&mut data);
411        ctx.inverse(&mut data);
412        assert_eq!(data, original, "NTT roundtrip failed for N=32768");
413    }
414
415    #[test]
416    fn test_roundtrip_zeros() {
417        let n = 64;
418        let q = test_prime(n);
419        let ctx = Ntt32Context::new(n, q);
420        let mut data = vec![0u32; n];
421        ctx.forward(&mut data);
422        ctx.inverse(&mut data);
423        assert_eq!(data, vec![0u32; n]);
424    }
425
426    #[test]
427    fn test_constant_polynomial() {
428        // NTT of [c, 0, 0, ...] should give [c, c, c, ...]
429        let n = 64;
430        let q = test_prime(n);
431        let ctx = Ntt32Context::new(n, q);
432        let c = 42u32;
433        let mut data = vec![0u32; n];
434        data[0] = c;
435
436        ctx.forward(&mut data);
437        for (i, &v) in data.iter().enumerate() {
438            assert_eq!(v, c, "NTT of constant: data[{i}]={v}, expected {c}");
439        }
440    }
441
442    #[test]
443    fn test_negacyclic_mul_identity() {
444        // Multiply by [1, 0, 0, ...] should be identity
445        let n = 64;
446        let q = test_prime(n);
447        let ctx = Ntt32Context::new(n, q);
448
449        let a: Vec<u32> = (0..n)
450            .map(|i| ((i as u64 * 17 + 5) % q as u64) as u32)
451            .collect();
452        let mut one = vec![0u32; n];
453        one[0] = 1;
454
455        let result = ctx.negacyclic_mul(&a, &one);
456        assert_eq!(result, a, "Multiply by 1 is not identity");
457    }
458
459    #[test]
460    fn test_negacyclic_mul_n16() {
461        let n = 16;
462        let q = test_prime(n);
463        let ctx = Ntt32Context::new(n, q);
464
465        let a: Vec<u32> = (0..n).map(|i| (i as u32 + 1) % q).collect();
466        let b: Vec<u32> = vec![1u32; n];
467
468        // Naive reference
469        let mut expected = vec![0u32; n];
470        for i in 0..n {
471            for j in 0..n {
472                let prod = (a[i] as u64 * b[j] as u64) % q as u64;
473                if i + j < n {
474                    expected[i + j] = ((expected[i + j] as u64 + prod) % q as u64) as u32;
475                } else {
476                    let idx = i + j - n;
477                    expected[idx] = ((expected[idx] as u64 + q as u64 - prod) % q as u64) as u32;
478                }
479            }
480        }
481
482        let result = ctx.negacyclic_mul(&a, &b);
483        assert_eq!(result, expected, "Negacyclic multiplication mismatch");
484    }
485
486    #[test]
487    fn test_inverse_lazy_no_normalization() {
488        let n = 256;
489        let q = test_prime(n);
490        let ctx = Ntt32Context::new(n, q);
491        let original = make_test_data(n, q);
492
493        // inverse_lazy should NOT equal original (missing N^{-1})
494        let mut data = original.clone();
495        ctx.forward(&mut data);
496        ctx.inverse_lazy(&mut data);
497        assert_ne!(
498            data, original,
499            "inverse_lazy should not match original (no N^{{-1}})"
500        );
501
502        // But after manual N^{-1} normalization, it should match
503        let n_inv = ctx.n_inv();
504        for x in data.iter_mut() {
505            *x = ((*x as u64 * n_inv as u64) % q as u64) as u32;
506        }
507        assert_eq!(
508            data, original,
509            "inverse_lazy + manual N^{{-1}} should match original"
510        );
511    }
512
513    #[test]
514    fn test_inverse_lazy_matches_concrete_ntt_style() {
515        // Verify that inverse_lazy() is exactly inverse() without N^{-1}
516        let n = 1024;
517        let q = test_prime(n);
518        let ctx = Ntt32Context::new(n, q);
519        let original = make_test_data(n, q);
520
521        let mut data_full = original.clone();
522        let mut data_lazy = original.clone();
523
524        ctx.forward(&mut data_full);
525        ctx.forward(&mut data_lazy);
526
527        ctx.inverse(&mut data_full);
528        ctx.inverse_lazy(&mut data_lazy);
529
530        // data_lazy * N^{-1} should equal data_full
531        let n_inv = ctx.n_inv();
532        let data_lazy_normalized: Vec<u32> = data_lazy
533            .iter()
534            .map(|&x| ((x as u64 * n_inv as u64) % q as u64) as u32)
535            .collect();
536        assert_eq!(data_full, data_lazy_normalized);
537    }
538
539    #[test]
540    fn test_negacyclic_mul_into_matches_negacyclic_mul() {
541        let n = 256;
542        let q = test_prime(n);
543        let ctx = Ntt32Context::new(n, q);
544
545        let a: Vec<u32> = (0..n)
546            .map(|i| ((i as u64 * 17 + 3) % q as u64) as u32)
547            .collect();
548        let b: Vec<u32> = (0..n)
549            .map(|i| ((i as u64 * 31 + 7) % q as u64) as u32)
550            .collect();
551
552        // Allocating version
553        let result_alloc = ctx.negacyclic_mul(&a, &b);
554
555        // Zero-alloc version
556        let mut a_buf = a.clone();
557        let mut b_buf = b.clone();
558        let mut result_inplace = vec![0u32; n];
559        ctx.negacyclic_mul_into(&mut a_buf, &mut b_buf, &mut result_inplace);
560
561        assert_eq!(
562            result_alloc, result_inplace,
563            "negacyclic_mul_into must match negacyclic_mul"
564        );
565    }
566
567    #[test]
568    fn test_negacyclic_mul_into_reusable_buffers() {
569        // Verify that buffers can be reused across calls
570        let n = 64;
571        let q = test_prime(n);
572        let ctx = Ntt32Context::new(n, q);
573
574        let mut a_buf = vec![0u32; n];
575        let mut b_buf = vec![0u32; n];
576        let mut result = vec![0u32; n];
577
578        for round in 0..3u32 {
579            // Fill buffers with different data each round
580            for i in 0..n {
581                a_buf[i] = ((i as u64 * (round as u64 + 17) + 3) % q as u64) as u32;
582                b_buf[i] = ((i as u64 * (round as u64 + 31) + 7) % q as u64) as u32;
583            }
584            let expected = ctx.negacyclic_mul(&a_buf, &b_buf);
585
586            ctx.negacyclic_mul_into(&mut a_buf, &mut b_buf, &mut result);
587            assert_eq!(
588                result, expected,
589                "Reusable buffer mismatch at round {round}"
590            );
591
592            // Re-fill for next round (a_buf/b_buf were destroyed)
593        }
594    }
595
596    // ===================================================================
597    // NIST Post-Quantum Standard Primes
598    // ===================================================================
599
600    #[test]
601    fn test_pq_mldsa_roundtrip() {
602        // ML-DSA (FIPS 204): q = 8380417 = 2^23 - 2^13 + 1, N = 256
603        let q: u32 = 8_380_417;
604        let n = 256;
605        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");
606
607        let ctx = Ntt32Context::new(n, q);
608        let original = make_test_data(n, q);
609        let mut data = original.clone();
610
611        ctx.forward(&mut data);
612        assert_ne!(data, original, "Forward NTT should change data");
613        ctx.inverse(&mut data);
614        assert_eq!(data, original, "ML-DSA roundtrip failed");
615    }
616
617    #[test]
618    fn test_pq_mldsa_negacyclic_mul() {
619        let q: u32 = 8_380_417;
620        let n = 256;
621        let ctx = Ntt32Context::new(n, q);
622
623        // Multiply by [1, 0, 0, ...] should be identity
624        let a: Vec<u32> = (0..n)
625            .map(|i| ((i as u64 * 17 + 5) % q as u64) as u32)
626            .collect();
627        let mut one = vec![0u32; n];
628        one[0] = 1;
629
630        let result = ctx.negacyclic_mul(&a, &one);
631        assert_eq!(result, a, "ML-DSA: multiply by 1 is not identity");
632    }
633
634    #[test]
635    fn test_pq_falcon512_roundtrip() {
636        // Falcon-512: q = 12289, N = 512
637        let q: u32 = 12_289;
638        let n = 512;
639        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");
640
641        let ctx = Ntt32Context::new(n, q);
642        let original = make_test_data(n, q);
643        let mut data = original.clone();
644
645        ctx.forward(&mut data);
646        ctx.inverse(&mut data);
647        assert_eq!(data, original, "Falcon-512 roundtrip failed");
648    }
649
650    #[test]
651    fn test_pq_falcon1024_roundtrip() {
652        // Falcon-1024: q = 12289, N = 1024
653        let q: u32 = 12_289;
654        let n = 1024;
655        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");
656
657        let ctx = Ntt32Context::new(n, q);
658        let original = make_test_data(n, q);
659        let mut data = original.clone();
660
661        ctx.forward(&mut data);
662        ctx.inverse(&mut data);
663        assert_eq!(data, original, "Falcon-1024 roundtrip failed");
664    }
665
666    #[test]
667    fn test_pq_falcon_negacyclic_mul() {
668        let q: u32 = 12_289;
669        let n = 512;
670        let ctx = Ntt32Context::new(n, q);
671
672        let a: Vec<u32> = (0..n)
673            .map(|i| ((i as u64 * 17 + 5) % q as u64) as u32)
674            .collect();
675        let mut one = vec![0u32; n];
676        one[0] = 1;
677
678        let result = ctx.negacyclic_mul(&a, &one);
679        assert_eq!(result, a, "Falcon: multiply by 1 is not identity");
680    }
681
682    #[test]
683    fn test_pq_mlkem_proxy_roundtrip() {
684        // ML-KEM proxy: q = 3329, N = 128 (Kyber uses incomplete 128-point NTT)
685        // 3329 - 1 = 3328 = 2^8 × 13, and 2×128 = 256 | 3328 ✓
686        let q: u32 = 3_329;
687        let n = 128;
688        assert_eq!((q - 1) % (2 * n as u32), 0, "q-1 must be divisible by 2N");
689
690        let ctx = Ntt32Context::new(n, q);
691        let original = make_test_data(n, q);
692        let mut data = original.clone();
693
694        ctx.forward(&mut data);
695        ctx.inverse(&mut data);
696        assert_eq!(data, original, "ML-KEM proxy roundtrip failed");
697    }
698
699    #[test]
700    fn test_pq_mlkem_negacyclic_mul() {
701        let q: u32 = 3_329;
702        let n = 128;
703        let ctx = Ntt32Context::new(n, q);
704
705        // Verify against naive O(N²) multiplication
706        let a: Vec<u32> = (0..n).map(|i| (i as u32 + 1) % q).collect();
707        let b: Vec<u32> = vec![1u32; n];
708
709        let mut expected = vec![0u32; n];
710        for i in 0..n {
711            for j in 0..n {
712                let prod = (a[i] as u64 * b[j] as u64) % q as u64;
713                if i + j < n {
714                    expected[i + j] = ((expected[i + j] as u64 + prod) % q as u64) as u32;
715                } else {
716                    let idx = i + j - n;
717                    expected[idx] = ((expected[idx] as u64 + q as u64 - prod) % q as u64) as u32;
718                }
719            }
720        }
721
722        let result = ctx.negacyclic_mul(&a, &b);
723        assert_eq!(
724            result, expected,
725            "ML-KEM negacyclic multiplication mismatch"
726        );
727    }
728
729    /// Exhaustive NEON-vs-scalar regression test.
730    ///
731    /// Validates that the NEON NTT path produces identical results to the
732    /// scalar path for all sizes and representative primes. This is the
733    /// permanent canary for the vqdmulhq_s32 overflow bug (N≥16384, q~2^28).
734    #[test]
735    fn test_neon_vs_scalar_exhaustive() {
736        use super::super::prime::generate_primes_28;
737
738        let sizes = [256, 1024, 4096, 8192, 16384, 32768];
739        let num_primes = 10;
740
741        for &n in &sizes {
742            let primes = generate_primes_28(n, num_primes);
743            for &q in &primes {
744                let ctx = super::Ntt32Context::new(n, q);
745
746                // Test data: deterministic pseudo-random values in [0, q)
747                let data: Vec<u32> = (0..n)
748                    .map(|i| ((i as u64 * 7 + 13) % q as u64) as u32)
749                    .collect();
750
751                // Forward NTT
752                let mut neon_fwd = data.clone();
753                let mut scalar_fwd = data.clone();
754                ctx.forward(&mut neon_fwd);
755                super::super::scalar::ntt_forward_scalar(&mut scalar_fwd, &ctx);
756                assert_eq!(
757                    neon_fwd, scalar_fwd,
758                    "NEON vs scalar FORWARD mismatch: N={n}, q={q}"
759                );
760
761                // Inverse NTT
762                let mut neon_inv = neon_fwd.clone();
763                let mut scalar_inv = scalar_fwd.clone();
764                ctx.inverse(&mut neon_inv);
765                super::super::scalar::ntt_inverse_scalar(&mut scalar_inv, &ctx);
766                assert_eq!(
767                    neon_inv, scalar_inv,
768                    "NEON vs scalar INVERSE mismatch: N={n}, q={q}"
769                );
770
771                // Roundtrip: should recover original data
772                for i in 0..n {
773                    assert_eq!(
774                        neon_inv[i] % q, data[i] % q,
775                        "Roundtrip mismatch at index {i}: N={n}, q={q}"
776                    );
777                }
778            }
779        }
780    }
781
782    // Compile-time check: Ntt32Context must be Send + Sync
783    // (required for safe sharing across threads in crypto applications)
784    const _: () = {
785        fn assert_send<T: Send>() {}
786        fn assert_sync<T: Sync>() {}
787        fn check() {
788            assert_send::<super::Ntt32Context>();
789            assert_sync::<super::Ntt32Context>();
790        }
791    };
792}