vaea_ntt/pq.rs
1// Copyright (C) 2024-2026 Vaea SAS
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//
4// This file is part of VaeaNTT.
5//
6// VaeaNTT is free software: you can redistribute it and/or modify it under
7// the terms of the GNU Affero General Public License as published by the
8// Free Software Foundation, either version 3 of the License, or (at your
9// option) any later version.
10//
11// VaeaNTT is distributed in the hope that it will be useful, but WITHOUT
12// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
14// License for more details.
15//
16// You should have received a copy of the GNU Affero General Public License
17// along with VaeaNTT. If not, see <https://www.gnu.org/licenses/>.
18
19
20//! # Post-Quantum Cryptography Presets
21//!
22//! Pre-configured NTT contexts for NIST post-quantum standards.
23//! One import, one line of code — instant access to ML-DSA and custom lattice schemes.
24//!
25//! ```
26//! use vaea_ntt::pq::{PqScheme, PqNtt};
27//!
28//! // ML-DSA-65 (NIST Level 3 digital signatures)
29//! let ntt = PqNtt::new(PqScheme::MlDsa65);
30//! let mut poly = vec![0u32; ntt.n()];
31//! poly[0] = 42;
32//! ntt.forward(&mut poly);
33//! ntt.inverse(&mut poly);
34//! assert_eq!(poly[0], 42);
35//! ```
36//!
37//! ## Why this matters
38//!
39//! Other NTT libraries are single-scheme:
40//! - `mlkem-native` → ML-KEM only (q=3329, int16, incomplete NTT)
41//! - `pqcrystals-dilithium` → ML-DSA only
42//! - SEAL/OpenFHE → FHE only, no ARM NEON
43//!
44//! **VaeaNTT covers ML-DSA + custom lattice + FHE with a single engine**, NEON-optimized.
45//!
46//! ## Supported schemes
47//!
48//! | Scheme | Standard | q | N | Notes |
49//! |--------|----------|---|---|-------|
50//! | ML-DSA-44 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
51//! | ML-DSA-65 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
52//! | ML-DSA-87 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
53//!
54//! ### ML-KEM Note
55//!
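//! ML-KEM uses q=3329 with N=256, but its NTT is an **incomplete NTT** (size-128
//! NTT over coefficient pairs), not a standard size-256 negacyclic NTT. A full
//! negacyclic NTT of size N needs a primitive 2N-th (here 512th) root of unity
//! modulo q, which exists only if 2N divides q−1. Since q−1 = 3328 = 2⁸×13,
//! Z_q has a primitive 256th root of unity but no 512th one.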
59//! A dedicated ML-KEM module with incomplete NTT support is planned.
60
61use crate::ntt32::Ntt32Context;
62
63// ===========================================================================
64// PqScheme — Post-Quantum scheme selector
65// ===========================================================================
66
/// Selector for a NIST post-quantum cryptographic parameter set.
///
/// A variant pins down every NTT parameter (degree N and modulus q) that its
/// standard requires. Callers therefore never pass those numbers by hand and
/// cannot misconfigure them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PqScheme {
    // ----- FIPS 204: ML-DSA (Module-Lattice-Based Digital Signature Algorithm) -----
    /// ML-DSA-44: NIST Level 2 (128-bit classical security).
    ///
    /// Module dimensions (k, l) = (4, 4); N = 256, q = 8380417.
    MlDsa44,

    /// ML-DSA-65: NIST Level 3 (192-bit classical security).
    ///
    /// Module dimensions (k, l) = (6, 5); N = 256, q = 8380417.
    MlDsa65,

    /// ML-DSA-87: NIST Level 5 (256-bit classical security).
    ///
    /// Module dimensions (k, l) = (8, 7); N = 256, q = 8380417.
    MlDsa87,
}
90
91impl PqScheme {
92 /// Returns the polynomial degree N for this scheme.
93 #[inline]
94 pub const fn n(self) -> usize {
95 match self {
96 Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => 256,
97 }
98 }
99
100 /// Returns the prime modulus q for this scheme.
101 #[inline]
102 pub const fn q(self) -> u32 {
103 match self {
104 Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => 8380417,
105 }
106 }
107
108 /// Returns the module rank k (number of polynomials).
109 #[inline]
110 pub const fn k(self) -> usize {
111 match self {
112 Self::MlDsa44 => 4,
113 Self::MlDsa65 => 6,
114 Self::MlDsa87 => 8,
115 }
116 }
117
118 /// Returns the NIST security level (1–5).
119 #[inline]
120 pub const fn security_level(self) -> u8 {
121 match self {
122 Self::MlDsa44 => 2,
123 Self::MlDsa65 => 3,
124 Self::MlDsa87 => 5,
125 }
126 }
127
128 /// Returns a human-readable name for this scheme.
129 #[inline]
130 pub const fn name(self) -> &'static str {
131 match self {
132 Self::MlDsa44 => "ML-DSA-44",
133 Self::MlDsa65 => "ML-DSA-65",
134 Self::MlDsa87 => "ML-DSA-87",
135 }
136 }
137
138 /// Returns the NIST FIPS standard number.
139 #[inline]
140 pub const fn fips(self) -> &'static str {
141 match self {
142 Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => "FIPS 204",
143 }
144 }
145}
146
147// ===========================================================================
148// PqNtt — Post-Quantum NTT engine
149// ===========================================================================
150
151/// A ready-to-use NTT engine configured for a specific post-quantum scheme.
152///
153/// Wraps [`Ntt32Context`] with scheme metadata for safety and convenience.
154///
155/// # Example
156///
157/// ```
158/// use vaea_ntt::pq::{PqScheme, PqNtt};
159///
160/// let ntt = PqNtt::new(PqScheme::MlDsa65);
161/// assert_eq!(ntt.scheme(), PqScheme::MlDsa65);
162/// assert_eq!(ntt.n(), 256);
163/// assert_eq!(ntt.q(), 8380417);
164///
165/// let mut data = vec![0u32; 256];
166/// data[0] = 1;
167/// ntt.forward(&mut data);
168/// ntt.inverse(&mut data);
169/// assert_eq!(data[0], 1);
170/// ```
171pub struct PqNtt {
172 /// The underlying NTT context.
173 ctx: Ntt32Context,
174 /// The scheme this context was created for.
175 scheme: PqScheme,
176}
177
178impl PqNtt {
179 /// Creates a new PQ-NTT engine for the given scheme.
180 ///
181 /// This precomputes all twiddle factors and modular arithmetic
182 /// constants. The context can be reused for multiple NTT calls.
183 #[inline]
184 pub fn new(scheme: PqScheme) -> Self {
185 let ctx = Ntt32Context::new(scheme.n(), scheme.q());
186 Self { ctx, scheme }
187 }
188
189 /// Returns the scheme this engine was configured for.
190 #[inline]
191 pub fn scheme(&self) -> PqScheme {
192 self.scheme
193 }
194
195 /// Returns the polynomial degree N.
196 #[inline]
197 pub fn n(&self) -> usize {
198 self.ctx.n
199 }
200
201 /// Returns the prime modulus q.
202 #[inline]
203 pub fn q(&self) -> u32 {
204 self.ctx.q
205 }
206
207 /// Returns the NIST security level.
208 #[inline]
209 pub fn security_level(&self) -> u8 {
210 self.scheme.security_level()
211 }
212
213 /// Returns a reference to the underlying [`Ntt32Context`].
214 #[inline]
215 pub fn context(&self) -> &Ntt32Context {
216 &self.ctx
217 }
218
219 /// Applies forward NTT in-place.
220 ///
221 /// Transforms from coefficient domain to evaluation (NTT) domain.
222 /// In NTT domain, polynomial multiplication is pointwise O(N).
223 ///
224 /// # Panics
225 /// If `data.len() != self.n()`.
226 #[inline]
227 pub fn forward(&self, data: &mut [u32]) {
228 self.ctx.forward(data);
229 }
230
231 /// Applies inverse NTT in-place.
232 ///
233 /// Transforms from evaluation (NTT) domain back to coefficient domain.
234 /// Includes the N⁻¹ normalization factor.
235 ///
236 /// # Panics
237 /// If `data.len() != self.n()`.
238 #[inline]
239 pub fn inverse(&self, data: &mut [u32]) {
240 self.ctx.inverse(data);
241 }
242
243 /// Computes negacyclic polynomial multiplication: `result = a × b mod (X^N + 1, q)`.
244 ///
245 /// Both inputs must be in coefficient domain (not NTT).
246 /// Result is in coefficient domain.
247 ///
248 /// # Panics
249 /// If `a.len() != self.n()` or `b.len() != self.n()`.
250 #[inline]
251 pub fn multiply(&self, a: &[u32], b: &[u32]) -> alloc::vec::Vec<u32> {
252 self.ctx.negacyclic_mul(a, b)
253 }
254
255 /// Computes negacyclic polynomial multiplication: `result = a × b mod (X^N + 1, q)`.
256 ///
257 /// Both `a` and `b` are consumed (transformed in-place as scratch space).
258 /// Result is written to `result`.
259 ///
260 /// # Panics
261 /// If any slice length != `self.n()`.
262 #[inline]
263 pub fn multiply_into(&self, a: &mut [u32], b: &mut [u32], result: &mut [u32]) {
264 self.ctx.negacyclic_mul_into(a, b, result);
265 }
266}
267
268impl core::fmt::Debug for PqNtt {
269 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
270 f.debug_struct("PqNtt")
271 .field("scheme", &self.scheme.name())
272 .field("n", &self.n())
273 .field("q", &self.q())
274 .field("security_level", &self.security_level())
275 .field("fips", &self.scheme.fips())
276 .finish()
277 }
278}
279
280// ===========================================================================
281// Tests
282// ===========================================================================
283
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mldsa_44_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa44);
        assert_eq!(ntt.q(), 8380417);
        assert_eq!(ntt.n(), 256);
        assert_eq!(ntt.security_level(), 2);

        let mut data: alloc::vec::Vec<u32> = (0..256).map(|i| i * 1000 % 8380417).collect();
        let original = data.clone();
        ntt.forward(&mut data);
        assert_ne!(data, original, "NTT forward did nothing");
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_mldsa_65_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        assert_eq!(ntt.security_level(), 3);

        let mut data = alloc::vec![8380416u32; 256]; // q-1
        let original = data.clone();
        ntt.forward(&mut data);
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_mldsa_87_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa87);
        assert_eq!(ntt.security_level(), 5);

        let mut data = alloc::vec![0u32; 256];
        data[0] = 1;
        let original = data.clone();
        ntt.forward(&mut data);
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_multiply() {
        let ntt = PqNtt::new(PqScheme::MlDsa44);

        // (1 + x) × (1 + x) = 1 + 2x + x^2; degree < N, so no wrap-around.
        let mut a = alloc::vec![0u32; 256];
        a[0] = 1;
        a[1] = 1;
        let result = ntt.multiply(&a, &a);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 2);
        assert_eq!(result[2], 1);
        for (i, &coeff) in result.iter().enumerate().skip(3) {
            assert_eq!(coeff, 0, "unexpected non-zero at index {i}");
        }
    }

    #[test]
    fn test_multiply_negacyclic_wraparound() {
        let ntt = PqNtt::new(PqScheme::MlDsa44);
        let q = ntt.q();
        let n = ntt.n();

        // x^(N-1) × x = x^N ≡ -1 mod (X^N + 1), i.e. q - 1 in canonical form.
        let mut a = alloc::vec![0u32; n];
        a[n - 1] = 1;
        let mut b = alloc::vec![0u32; n];
        b[1] = 1;
        let result = ntt.multiply(&a, &b);
        assert_eq!(result[0], q - 1, "x^N must reduce to -1");
        assert!(
            result[1..].iter().all(|&c| c == 0),
            "only the constant term should be non-zero"
        );
    }

    #[test]
    fn test_multiply_into_matches_multiply() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        let q = ntt.q();
        let n = ntt.n();

        let a: alloc::vec::Vec<u32> = (0..n as u32).map(|i| (i * 31 + 7) % q).collect();
        let b: alloc::vec::Vec<u32> = (0..n as u32).map(|i| (i * 17 + 3) % q).collect();
        let expected = ntt.multiply(&a, &b);

        // multiply_into overwrites its inputs, so hand it copies.
        let mut scratch_a = a.clone();
        let mut scratch_b = b.clone();
        let mut out = alloc::vec![0u32; n];
        ntt.multiply_into(&mut scratch_a, &mut scratch_b, &mut out);
        assert_eq!(out, expected);
    }

    #[test]
    fn test_scheme_metadata() {
        assert_eq!(PqScheme::MlDsa44.name(), "ML-DSA-44");
        assert_eq!(PqScheme::MlDsa44.fips(), "FIPS 204");
        assert_eq!(PqScheme::MlDsa44.k(), 4);

        assert_eq!(PqScheme::MlDsa65.name(), "ML-DSA-65");
        assert_eq!(PqScheme::MlDsa65.k(), 6);

        assert_eq!(PqScheme::MlDsa87.name(), "ML-DSA-87");
        assert_eq!(PqScheme::MlDsa87.k(), 8);
    }

    #[test]
    fn test_output_fully_reduced() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        let mut data: alloc::vec::Vec<u32> =
            (0..ntt.n()).map(|i| (i as u32 * 7 + 13) % ntt.q()).collect();
        ntt.forward(&mut data);
        assert!(
            data.iter().all(|&x| x < ntt.q()),
            "Output not fully reduced for ML-DSA-65"
        );
    }

    #[test]
    fn test_debug_display() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        let debug = alloc::format!("{:?}", ntt);
        assert!(debug.contains("ML-DSA-65"));
        assert!(debug.contains("FIPS 204"));
    }
}