Skip to main content

vaea_ntt/
pq.rs

1// Copyright (C) 2024-2026 Vaea SAS
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//
4// This file is part of VaeaNTT.
5//
6// VaeaNTT is free software: you can redistribute it and/or modify it under
7// the terms of the GNU Affero General Public License as published by the
8// Free Software Foundation, either version 3 of the License, or (at your
9// option) any later version.
10//
11// VaeaNTT is distributed in the hope that it will be useful, but WITHOUT
12// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
14// License for more details.
15//
16// You should have received a copy of the GNU Affero General Public License
17// along with VaeaNTT. If not, see <https://www.gnu.org/licenses/>.
18
19//! # Post-Quantum Cryptography Presets
20//!
21//! Pre-configured NTT contexts for NIST post-quantum standards.
//! One import, one line of code — instant access to the ML-DSA (FIPS 204) parameter sets.
23//!
24//! ```
25//! use vaea_ntt::pq::{PqScheme, PqNtt};
26//!
27//! // ML-DSA-65 (NIST Level 3 digital signatures)
28//! let ntt = PqNtt::new(PqScheme::MlDsa65);
29//! let mut poly = vec![0u32; ntt.n()];
30//! poly[0] = 42;
31//! ntt.forward(&mut poly);
32//! ntt.inverse(&mut poly);
33//! assert_eq!(poly[0], 42);
34//! ```
35//!
36//! ## Why this matters
37//!
38//! Other NTT libraries are single-scheme:
39//! - `mlkem-native` → ML-KEM only (q=3329, int16, incomplete NTT)
40//! - `pqcrystals-dilithium` → ML-DSA only
41//! - SEAL/OpenFHE → FHE only, no ARM NEON
42//!
43//! **VaeaNTT covers ML-DSA + custom lattice + FHE with a single engine**, NEON-optimized.
44//!
45//! ## Supported schemes
46//!
//! | Scheme | Standard | q | N | Notes |
//! |--------|----------|---|---|-------|
//! | ML-DSA-44 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
//! | ML-DSA-65 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
//! | ML-DSA-87 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
52//!
53//! ### ML-KEM Note
54//!
//! ML-KEM uses q=3329 with N=256, but its NTT is an **incomplete NTT** (a size-128
//! NTT whose outputs are coefficient pairs, i.e. degree-1 polynomials), not a standard
//! size-256 negacyclic NTT. A full negacyclic NTT of length N needs a primitive 2N-th
//! root of unity; since q−1 = 3328 = 2⁸×13, ℤ_q contains primitive 256th roots of unity
//! but no 512th root. A dedicated ML-KEM module with incomplete NTT support is planned.
59
60use crate::ntt32::Ntt32Context;
61
62// ===========================================================================
63// PqScheme — Post-Quantum scheme selector
64// ===========================================================================
65
/// Selects one of the NIST post-quantum parameter sets supported by this module.
///
/// A variant pins down every NTT parameter (polynomial degree N and prime
/// modulus q) for its standard, so callers never hand-enter those constants
/// and cannot pair a degree with the wrong modulus.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PqScheme {
    // ML-DSA: the FIPS 204 module-lattice signature scheme.
    /// ML-DSA-44: NIST security level 2 (roughly 128-bit classical security).
    ///
    /// Module dimensions (k, l) = (4, 4); N = 256, q = 8380417.
    MlDsa44,

    /// ML-DSA-65: NIST security level 3 (roughly 192-bit classical security).
    ///
    /// Module dimensions (k, l) = (6, 5); N = 256, q = 8380417.
    MlDsa65,

    /// ML-DSA-87: NIST security level 5 (roughly 256-bit classical security).
    ///
    /// Module dimensions (k, l) = (8, 7); N = 256, q = 8380417.
    MlDsa87,
}
88
89impl PqScheme {
90    /// Returns the polynomial degree N for this scheme.
91    #[inline]
92    pub const fn n(self) -> usize {
93        match self {
94            Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => 256,
95        }
96    }
97
98    /// Returns the prime modulus q for this scheme.
99    #[inline]
100    pub const fn q(self) -> u32 {
101        match self {
102            Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => 8380417,
103        }
104    }
105
106    /// Returns the module rank k (number of polynomials).
107    #[inline]
108    pub const fn k(self) -> usize {
109        match self {
110            Self::MlDsa44 => 4,
111            Self::MlDsa65 => 6,
112            Self::MlDsa87 => 8,
113        }
114    }
115
116    /// Returns the NIST security level (1–5).
117    #[inline]
118    pub const fn security_level(self) -> u8 {
119        match self {
120            Self::MlDsa44 => 2,
121            Self::MlDsa65 => 3,
122            Self::MlDsa87 => 5,
123        }
124    }
125
126    /// Returns a human-readable name for this scheme.
127    #[inline]
128    pub const fn name(self) -> &'static str {
129        match self {
130            Self::MlDsa44 => "ML-DSA-44",
131            Self::MlDsa65 => "ML-DSA-65",
132            Self::MlDsa87 => "ML-DSA-87",
133        }
134    }
135
136    /// Returns the NIST standard reference (e.g. "FIPS 204").
137    /// Note: this is a parameter reference, not a conformance or certification claim.
138    #[inline]
139    pub const fn fips(self) -> &'static str {
140        match self {
141            Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => "FIPS 204",
142        }
143    }
144}
145
146// ===========================================================================
147// PqNtt — Post-Quantum NTT engine
148// ===========================================================================
149
150/// A ready-to-use NTT engine configured for a specific post-quantum scheme.
151///
152/// Wraps [`Ntt32Context`] with scheme metadata for safety and convenience.
153///
154/// # Example
155///
156/// ```
157/// use vaea_ntt::pq::{PqScheme, PqNtt};
158///
159/// let ntt = PqNtt::new(PqScheme::MlDsa65);
160/// assert_eq!(ntt.scheme(), PqScheme::MlDsa65);
161/// assert_eq!(ntt.n(), 256);
162/// assert_eq!(ntt.q(), 8380417);
163///
164/// let mut data = vec![0u32; 256];
165/// data[0] = 1;
166/// ntt.forward(&mut data);
167/// ntt.inverse(&mut data);
168/// assert_eq!(data[0], 1);
169/// ```
170pub struct PqNtt {
171    /// The underlying NTT context.
172    ctx: Ntt32Context,
173    /// The scheme this context was created for.
174    scheme: PqScheme,
175}
176
177impl PqNtt {
178    /// Creates a new PQ-NTT engine for the given scheme.
179    ///
180    /// This precomputes all twiddle factors and modular arithmetic
181    /// constants. The context can be reused for multiple NTT calls.
182    #[inline]
183    pub fn new(scheme: PqScheme) -> Self {
184        let ctx = Ntt32Context::new(scheme.n(), scheme.q());
185        Self { ctx, scheme }
186    }
187
188    /// Returns the scheme this engine was configured for.
189    #[inline]
190    pub fn scheme(&self) -> PqScheme {
191        self.scheme
192    }
193
194    /// Returns the polynomial degree N.
195    #[inline]
196    pub fn n(&self) -> usize {
197        self.ctx.n
198    }
199
200    /// Returns the prime modulus q.
201    #[inline]
202    pub fn q(&self) -> u32 {
203        self.ctx.q
204    }
205
206    /// Returns the NIST security level.
207    #[inline]
208    pub fn security_level(&self) -> u8 {
209        self.scheme.security_level()
210    }
211
212    /// Returns a reference to the underlying [`Ntt32Context`].
213    #[inline]
214    pub fn context(&self) -> &Ntt32Context {
215        &self.ctx
216    }
217
218    /// Applies forward NTT in-place.
219    ///
220    /// Transforms from coefficient domain to evaluation (NTT) domain.
221    /// In NTT domain, polynomial multiplication is pointwise O(N).
222    ///
223    /// # Panics
224    /// If `data.len() != self.n()`.
225    #[inline]
226    pub fn forward(&self, data: &mut [u32]) {
227        self.ctx.forward(data);
228    }
229
230    /// Applies inverse NTT in-place.
231    ///
232    /// Transforms from evaluation (NTT) domain back to coefficient domain.
233    /// Includes the N⁻¹ normalization factor.
234    ///
235    /// # Panics
236    /// If `data.len() != self.n()`.
237    #[inline]
238    pub fn inverse(&self, data: &mut [u32]) {
239        self.ctx.inverse(data);
240    }
241
242    /// Computes negacyclic polynomial multiplication: `result = a × b mod (X^N + 1, q)`.
243    ///
244    /// Both inputs must be in coefficient domain (not NTT).
245    /// Result is in coefficient domain.
246    ///
247    /// # Panics
248    /// If `a.len() != self.n()` or `b.len() != self.n()`.
249    #[inline]
250    pub fn multiply(&self, a: &[u32], b: &[u32]) -> alloc::vec::Vec<u32> {
251        self.ctx.negacyclic_mul(a, b)
252    }
253
254    /// Computes negacyclic polynomial multiplication: `result = a × b mod (X^N + 1, q)`.
255    ///
256    /// Both `a` and `b` are consumed (transformed in-place as scratch space).
257    /// Result is written to `result`.
258    ///
259    /// # Panics
260    /// If any slice length != `self.n()`.
261    #[inline]
262    pub fn multiply_into(&self, a: &mut [u32], b: &mut [u32], result: &mut [u32]) {
263        self.ctx.negacyclic_mul_into(a, b, result);
264    }
265}
266
267impl core::fmt::Debug for PqNtt {
268    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
269        f.debug_struct("PqNtt")
270            .field("scheme", &self.scheme.name())
271            .field("n", &self.n())
272            .field("q", &self.q())
273            .field("security_level", &self.security_level())
274            .field("fips", &self.scheme.fips())
275            .finish()
276    }
277}
278
279// ===========================================================================
280// Tests
281// ===========================================================================
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Every scheme exposed by [`PqScheme`], for table-driven checks.
    const ALL_SCHEMES: [PqScheme; 3] = [PqScheme::MlDsa44, PqScheme::MlDsa65, PqScheme::MlDsa87];

    #[test]
    fn test_mldsa_44_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa44);
        assert_eq!(ntt.q(), 8380417);
        assert_eq!(ntt.n(), 256);
        assert_eq!(ntt.security_level(), 2);

        let mut data: alloc::vec::Vec<u32> = (0..256).map(|i| i * 1000 % 8380417).collect();
        let original = data.clone();
        ntt.forward(&mut data);
        assert_ne!(data, original, "NTT forward did nothing");
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_mldsa_65_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        assert_eq!(ntt.security_level(), 3);

        let mut data = alloc::vec![8380416u32; 256]; // q-1
        let original = data.clone();
        ntt.forward(&mut data);
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_mldsa_87_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa87);
        assert_eq!(ntt.security_level(), 5);

        let mut data = alloc::vec![0u32; 256];
        data[0] = 1;
        let original = data.clone();
        ntt.forward(&mut data);
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_engine_matches_scheme_parameters() {
        for &scheme in ALL_SCHEMES.iter() {
            let ntt = PqNtt::new(scheme);
            assert_eq!(ntt.scheme(), scheme);
            assert_eq!(ntt.n(), scheme.n());
            assert_eq!(ntt.q(), scheme.q());
            assert_eq!(ntt.security_level(), scheme.security_level());
        }
    }

    #[test]
    fn test_multiply() {
        let ntt = PqNtt::new(PqScheme::MlDsa44);

        // (1 + x)² = 1 + 2x + x²: no wrap-around, so all higher coefficients are zero.
        let mut a = alloc::vec![0u32; 256];
        a[0] = 1;
        a[1] = 1;
        let result = ntt.multiply(&a, &a);
        assert_eq!(&result[..3], &[1, 2, 1]);
        for (i, &c) in result.iter().enumerate().skip(3) {
            assert_eq!(c, 0, "unexpected non-zero at index {i}");
        }
    }

    #[test]
    fn test_multiply_negacyclic_wraparound() {
        // x^(N-1) · x = x^N ≡ −1 (mod X^N + 1), i.e. q − 1 at index 0.
        // A cyclic convolution would give +1 instead, so this pins the negacyclic reduction.
        let ntt = PqNtt::new(PqScheme::MlDsa44);
        let (n, q) = (ntt.n(), ntt.q());
        let mut a = alloc::vec![0u32; n];
        let mut b = alloc::vec![0u32; n];
        a[n - 1] = 1;
        b[1] = 1;
        let result = ntt.multiply(&a, &b);
        assert_eq!(result[0], q - 1);
        assert!(
            result[1..].iter().all(|&c| c == 0),
            "unexpected non-zero coefficient after wrap-around"
        );
    }

    #[test]
    fn test_multiply_into_matches_multiply() {
        let ntt = PqNtt::new(PqScheme::MlDsa87);
        let (n, q) = (ntt.n(), ntt.q());
        let mut a: alloc::vec::Vec<u32> = (0..n as u32).map(|i| (i * 31 + 7) % q).collect();
        let mut b: alloc::vec::Vec<u32> = (0..n as u32).map(|i| (i * 17 + 3) % q).collect();
        let expected = ntt.multiply(&a, &b);

        // `multiply_into` overwrites `a` and `b`; `expected` was computed first.
        let mut result = alloc::vec![0u32; n];
        ntt.multiply_into(&mut a, &mut b, &mut result);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_scheme_metadata() {
        assert_eq!(PqScheme::MlDsa44.name(), "ML-DSA-44");
        assert_eq!(PqScheme::MlDsa44.k(), 4);

        assert_eq!(PqScheme::MlDsa65.name(), "ML-DSA-65");
        assert_eq!(PqScheme::MlDsa65.k(), 6);

        assert_eq!(PqScheme::MlDsa87.name(), "ML-DSA-87");
        assert_eq!(PqScheme::MlDsa87.k(), 8);

        for &scheme in ALL_SCHEMES.iter() {
            assert_eq!(scheme.fips(), "FIPS 204");
        }
    }

    #[test]
    fn test_output_fully_reduced() {
        for &scheme in ALL_SCHEMES.iter() {
            let ntt = PqNtt::new(scheme);
            let q = ntt.q();
            let mut data: alloc::vec::Vec<u32> =
                (0..ntt.n()).map(|i| (i as u32 * 7 + 13) % q).collect();
            ntt.forward(&mut data);
            assert!(
                data.iter().all(|&x| x < q),
                "Output not fully reduced for {}",
                scheme.name()
            );
        }
    }

    #[test]
    fn test_debug_display() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        let debug = alloc::format!("{:?}", ntt);
        assert!(debug.contains("ML-DSA-65"));
        assert!(debug.contains("FIPS 204"));
    }
}