// vaea_ntt/pq.rs

// Copyright (C) 2024-2026 Vaea SAS
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This file is part of VaeaNTT.
//
// VaeaNTT is free software: you can redistribute it and/or modify it under
// the terms of the GNU Affero General Public License as published by the
// Free Software Foundation, either version 3 of the License, or (at your
// option) any later version.
//
// VaeaNTT is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
// License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with VaeaNTT. If not, see <https://www.gnu.org/licenses/>.

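//! # Post-Quantum Cryptography Presets
//!
//! Pre-configured NTT contexts for NIST post-quantum standards.
//! One import, one line of code — instant access to the ML-DSA parameter sets.
//! Custom lattice parameters that have no preset here can be used by building
//! an [`Ntt32Context`] directly from the ring degree `N` and modulus `q`.
//!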
//! ```
//! use vaea_ntt::pq::{PqScheme, PqNtt};
//!
//! // ML-DSA-65 (NIST Level 3 digital signatures)
//! let ntt = PqNtt::new(PqScheme::MlDsa65);
//! let mut poly = vec![0u32; ntt.n()];
//! poly[0] = 42;
//! ntt.forward(&mut poly);
//! ntt.inverse(&mut poly);
//! assert_eq!(poly[0], 42);
//! ```
//!
//! ## Why this matters
//!
//! Many NTT implementations are tied to a single scheme:
//! - `mlkem-native` → ML-KEM only (q=3329, int16, incomplete NTT)
//! - `pqcrystals-dilithium` → ML-DSA only
//! - SEAL/OpenFHE → FHE only, no ARM NEON
//!
//! **VaeaNTT covers ML-DSA + custom lattice + FHE with a single engine**, NEON-optimized.
//!
//! ## Supported schemes
//!
//! | Scheme | Standard | q | N | Notes |
//! |--------|----------|---|---|-------|
//! | ML-DSA-44 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
//! | ML-DSA-65 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
//! | ML-DSA-87 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
//!
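//! ### ML-KEM Note
//!
//! ML-KEM (FIPS 203) uses q = 3329 with N = 256, but its NTT is an
//! **incomplete NTT**: a size-128 transform whose output is 128 degree-1
//! polynomials (coefficient pairs), not a standard size-256 negacyclic NTT.
//! A full negacyclic NTT of size N needs a primitive 2N-th root of unity mod q.
//! Here q − 1 = 3328 = 2⁸ × 13, so Z_q has primitive 256th roots of unity but
//! no primitive 512th root (512 does not divide 3328).
//! A dedicated ML-KEM module with incomplete NTT support is planned.
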
use crate::ntt32::Ntt32Context;

// ===========================================================================
// PqScheme — Post-Quantum scheme selector
// ===========================================================================

/// NIST post-quantum cryptographic scheme.
///
/// Each variant fully specifies the NTT parameters (N, q) for a given
/// standard, eliminating the risk of misconfiguration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PqScheme {
    // ----- FIPS 204: ML-DSA (Module-Lattice Digital Signature) -----

    /// ML-DSA-44 — NIST Level 2 (128-bit classical security)
    ///
    /// (k,l) = (4,4), N=256, q=8380417.
    MlDsa44,

    /// ML-DSA-65 — NIST Level 3 (192-bit classical security)
    ///
    /// (k,l) = (6,5), N=256, q=8380417.
    MlDsa65,

    /// ML-DSA-87 — NIST Level 5 (256-bit classical security)
    ///
    /// (k,l) = (8,7), N=256, q=8380417.
    MlDsa87,
}

impl PqScheme {
    /// Returns the polynomial degree N for this scheme.
    #[inline]
    pub const fn n(self) -> usize {
        match self {
            Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => 256,
        }
    }

    /// Returns the prime modulus q for this scheme.
    ///
    /// For ML-DSA, `q - 1` is divisible by `2N = 512`, which is what allows a
    /// full size-N negacyclic NTT over `Z_q`.
    #[inline]
    pub const fn q(self) -> u32 {
        match self {
            Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => 8380417,
        }
    }

    /// Returns the module rank k: the number of rows of the `k × l` public
    /// matrix A (and the number of polynomials in vectors such as t).
    #[inline]
    pub const fn k(self) -> usize {
        match self {
            Self::MlDsa44 => 4,
            Self::MlDsa65 => 6,
            Self::MlDsa87 => 8,
        }
    }

    /// Returns l: the number of columns of the `k × l` public matrix A
    /// (and the number of polynomials in the secret vector s1).
    ///
    /// Together with [`PqScheme::k`] this gives the `(k,l)` pair listed on
    /// each variant.
    #[inline]
    pub const fn l(self) -> usize {
        match self {
            Self::MlDsa44 => 4,
            Self::MlDsa65 => 5,
            Self::MlDsa87 => 7,
        }
    }

    /// Returns the NIST security level (1–5).
    #[inline]
    pub const fn security_level(self) -> u8 {
        match self {
            Self::MlDsa44 => 2,
            Self::MlDsa65 => 3,
            Self::MlDsa87 => 5,
        }
    }

    /// Returns a human-readable name for this scheme.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            Self::MlDsa44 => "ML-DSA-44",
            Self::MlDsa65 => "ML-DSA-65",
            Self::MlDsa87 => "ML-DSA-87",
        }
    }

    /// Returns the NIST FIPS standard number.
    #[inline]
    pub const fn fips(self) -> &'static str {
        match self {
            Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => "FIPS 204",
        }
    }
}

147// ===========================================================================
148// PqNtt — Post-Quantum NTT engine
149// ===========================================================================
150
151/// A ready-to-use NTT engine configured for a specific post-quantum scheme.
152///
153/// Wraps [`Ntt32Context`] with scheme metadata for safety and convenience.
154///
155/// # Example
156///
157/// ```
158/// use vaea_ntt::pq::{PqScheme, PqNtt};
159///
160/// let ntt = PqNtt::new(PqScheme::MlDsa65);
161/// assert_eq!(ntt.scheme(), PqScheme::MlDsa65);
162/// assert_eq!(ntt.n(), 256);
163/// assert_eq!(ntt.q(), 8380417);
164///
165/// let mut data = vec![0u32; 256];
166/// data[0] = 1;
167/// ntt.forward(&mut data);
168/// ntt.inverse(&mut data);
169/// assert_eq!(data[0], 1);
170/// ```
171pub struct PqNtt {
172    /// The underlying NTT context.
173    ctx: Ntt32Context,
174    /// The scheme this context was created for.
175    scheme: PqScheme,
176}
177
178impl PqNtt {
179    /// Creates a new PQ-NTT engine for the given scheme.
180    ///
181    /// This precomputes all twiddle factors and modular arithmetic
182    /// constants. The context can be reused for multiple NTT calls.
183    #[inline]
184    pub fn new(scheme: PqScheme) -> Self {
185        let ctx = Ntt32Context::new(scheme.n(), scheme.q());
186        Self { ctx, scheme }
187    }
188
189    /// Returns the scheme this engine was configured for.
190    #[inline]
191    pub fn scheme(&self) -> PqScheme {
192        self.scheme
193    }
194
195    /// Returns the polynomial degree N.
196    #[inline]
197    pub fn n(&self) -> usize {
198        self.ctx.n
199    }
200
201    /// Returns the prime modulus q.
202    #[inline]
203    pub fn q(&self) -> u32 {
204        self.ctx.q
205    }
206
207    /// Returns the NIST security level.
208    #[inline]
209    pub fn security_level(&self) -> u8 {
210        self.scheme.security_level()
211    }
212
213    /// Returns a reference to the underlying [`Ntt32Context`].
214    #[inline]
215    pub fn context(&self) -> &Ntt32Context {
216        &self.ctx
217    }
218
219    /// Applies forward NTT in-place.
220    ///
221    /// Transforms from coefficient domain to evaluation (NTT) domain.
222    /// In NTT domain, polynomial multiplication is pointwise O(N).
223    ///
224    /// # Panics
225    /// If `data.len() != self.n()`.
226    #[inline]
227    pub fn forward(&self, data: &mut [u32]) {
228        self.ctx.forward(data);
229    }
230
231    /// Applies inverse NTT in-place.
232    ///
233    /// Transforms from evaluation (NTT) domain back to coefficient domain.
234    /// Includes the N⁻¹ normalization factor.
235    ///
236    /// # Panics
237    /// If `data.len() != self.n()`.
238    #[inline]
239    pub fn inverse(&self, data: &mut [u32]) {
240        self.ctx.inverse(data);
241    }
242
243    /// Computes negacyclic polynomial multiplication: `result = a × b mod (X^N + 1, q)`.
244    ///
245    /// Both inputs must be in coefficient domain (not NTT).
246    /// Result is in coefficient domain.
247    ///
248    /// # Panics
249    /// If `a.len() != self.n()` or `b.len() != self.n()`.
250    #[inline]
251    pub fn multiply(&self, a: &[u32], b: &[u32]) -> alloc::vec::Vec<u32> {
252        self.ctx.negacyclic_mul(a, b)
253    }
254
255    /// Computes negacyclic polynomial multiplication: `result = a × b mod (X^N + 1, q)`.
256    ///
257    /// Both `a` and `b` are consumed (transformed in-place as scratch space).
258    /// Result is written to `result`.
259    ///
260    /// # Panics
261    /// If any slice length != `self.n()`.
262    #[inline]
263    pub fn multiply_into(&self, a: &mut [u32], b: &mut [u32], result: &mut [u32]) {
264        self.ctx.negacyclic_mul_into(a, b, result);
265    }
266}
267
268impl core::fmt::Debug for PqNtt {
269    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
270        f.debug_struct("PqNtt")
271            .field("scheme", &self.scheme.name())
272            .field("n", &self.n())
273            .field("q", &self.q())
274            .field("security_level", &self.security_level())
275            .field("fips", &self.scheme.fips())
276            .finish()
277    }
278}
279
// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mldsa_44_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa44);
        assert_eq!(ntt.q(), 8380417);
        assert_eq!(ntt.n(), 256);
        assert_eq!(ntt.security_level(), 2);

        let mut data: alloc::vec::Vec<u32> = (0..256).map(|i| i * 1000 % 8380417).collect();
        let original = data.clone();
        ntt.forward(&mut data);
        assert_ne!(data, original, "NTT forward did nothing");
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_mldsa_65_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        assert_eq!(ntt.security_level(), 3);

        let mut data = alloc::vec![8380416u32; 256]; // q-1
        let original = data.clone();
        ntt.forward(&mut data);
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_mldsa_87_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa87);
        assert_eq!(ntt.security_level(), 5);

        let mut data = alloc::vec![0u32; 256];
        data[0] = 1;
        let original = data.clone();
        ntt.forward(&mut data);
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_multiply() {
        let ntt = PqNtt::new(PqScheme::MlDsa44);
        let q = ntt.q();
        let n = ntt.n();

        // (1 + x) × (1 + x) = 1 + 2x + x^2 — no reduction involved.
        let mut a = alloc::vec![0u32; n];
        a[0] = 1;
        a[1] = 1;
        let result = ntt.multiply(&a, &a);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 2);
        assert_eq!(result[2], 1);
        for (i, &c) in result.iter().enumerate().skip(3) {
            assert_eq!(c, 0, "unexpected non-zero at index {i}");
        }

        // x^(N-1) × x = x^N ≡ -1 mod (x^N + 1): the negacyclic wrap-around
        // must land on coefficient 0 as q - 1, not +1 as a cyclic NTT would.
        let mut top = alloc::vec![0u32; n];
        top[n - 1] = 1;
        let mut x = alloc::vec![0u32; n];
        x[1] = 1;
        let wrapped = ntt.multiply(&top, &x);
        assert_eq!(wrapped[0], q - 1, "x^N did not reduce to -1");
        for (i, &c) in wrapped.iter().enumerate().skip(1) {
            assert_eq!(c, 0, "unexpected non-zero at index {i}");
        }
    }

    #[test]
    fn test_multiply_into_matches_multiply() {
        let ntt = PqNtt::new(PqScheme::MlDsa87);
        let q = ntt.q();
        let a: alloc::vec::Vec<u32> = (0..ntt.n()).map(|i| (i as u32 * 31 + 5) % q).collect();
        let b: alloc::vec::Vec<u32> = (0..ntt.n()).map(|i| (i as u32 * 17 + 11) % q).collect();

        let expected = ntt.multiply(&a, &b);

        // multiply_into clobbers its inputs, so hand it copies.
        let mut a_scratch = a.clone();
        let mut b_scratch = b.clone();
        let mut result = alloc::vec![0u32; ntt.n()];
        ntt.multiply_into(&mut a_scratch, &mut b_scratch, &mut result);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_scheme_metadata() {
        assert_eq!(PqScheme::MlDsa44.name(), "ML-DSA-44");
        assert_eq!(PqScheme::MlDsa44.fips(), "FIPS 204");
        assert_eq!(PqScheme::MlDsa44.k(), 4);

        assert_eq!(PqScheme::MlDsa65.name(), "ML-DSA-65");
        assert_eq!(PqScheme::MlDsa65.k(), 6);

        assert_eq!(PqScheme::MlDsa87.name(), "ML-DSA-87");
        assert_eq!(PqScheme::MlDsa87.k(), 8);
    }

    #[test]
    fn test_output_fully_reduced() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        let mut data: alloc::vec::Vec<u32> =
            (0..ntt.n()).map(|i| (i as u32 * 7 + 13) % ntt.q()).collect();
        ntt.forward(&mut data);
        assert!(
            data.iter().all(|&x| x < ntt.q()),
            "Output not fully reduced for ML-DSA-65"
        );
    }

    #[test]
    fn test_debug_display() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        let debug = alloc::format!("{:?}", ntt);
        assert!(debug.contains("ML-DSA-65"));
        assert!(debug.contains("FIPS 204"));
    }
}