vaea_ntt/pq.rs
1// Copyright (C) 2024-2026 Vaea SAS
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//
4// This file is part of VaeaNTT.
5//
6// VaeaNTT is free software: you can redistribute it and/or modify it under
7// the terms of the GNU Affero General Public License as published by the
8// Free Software Foundation, either version 3 of the License, or (at your
9// option) any later version.
10//
11// VaeaNTT is distributed in the hope that it will be useful, but WITHOUT
12// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
14// License for more details.
15//
16// You should have received a copy of the GNU Affero General Public License
17// along with VaeaNTT. If not, see <https://www.gnu.org/licenses/>.
18
19//! # Post-Quantum Cryptography Presets
20//!
21//! Pre-configured NTT contexts for NIST post-quantum standards.
//! One import, one line of code — instant access to the ML-DSA parameter sets.
//! Custom lattice parameters (other N and q) can be configured on [`Ntt32Context`] directly.
23//!
24//! ```
25//! use vaea_ntt::pq::{PqScheme, PqNtt};
26//!
27//! // ML-DSA-65 (NIST Level 3 digital signatures)
28//! let ntt = PqNtt::new(PqScheme::MlDsa65);
29//! let mut poly = vec![0u32; ntt.n()];
30//! poly[0] = 42;
31//! ntt.forward(&mut poly);
32//! ntt.inverse(&mut poly);
33//! assert_eq!(poly[0], 42);
34//! ```
35//!
36//! ## Why this matters
37//!
38//! Other NTT libraries are single-scheme:
39//! - `mlkem-native` → ML-KEM only (q=3329, int16, incomplete NTT)
40//! - `pqcrystals-dilithium` → ML-DSA only
41//! - SEAL/OpenFHE → FHE only, no ARM NEON
42//!
43//! **VaeaNTT covers ML-DSA + custom lattice + FHE with a single engine**, NEON-optimized.
44//!
45//! ## Supported schemes
46//!
47//! | Scheme | Standard | q | N | Notes |
48//! |--------|----------|---|---|-------|
//! | ML-DSA-44 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
//! | ML-DSA-65 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
//! | ML-DSA-87 | FIPS 204 | 8380417 | 256 | Full negacyclic NTT |
52//!
53//! ### ML-KEM Note
54//!
55//! ML-KEM uses q=3329 with N=256, but its NTT is an **incomplete NTT** (size-128
56//! NTT over coefficient pairs), not a standard size-256 negacyclic NTT. This is
//! because 512 does not divide q−1 = 3328 = 2⁸×13: ℤ_q contains primitive 256th roots
//! of unity but no primitive 512th root, which a full size-256 negacyclic NTT requires.
58//! A dedicated ML-KEM module with incomplete NTT support is planned.
59
60use crate::ntt32::Ntt32Context;
61
62// ===========================================================================
63// PqScheme — Post-Quantum scheme selector
64// ===========================================================================
65
/// NIST post-quantum cryptographic scheme.
///
/// Each variant fully specifies the NTT parameters (N, q) for a given
/// standard, eliminating the risk of misconfiguration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PqScheme {
    // ----- ML-DSA (NIST post-quantum signature standard, FIPS 204) -----
    /// ML-DSA-44 — NIST Level 2 (128-bit classical security)
    ///
    /// (k,l) = (4,4), N=256, q=8380417.
    MlDsa44,

    /// ML-DSA-65 — NIST Level 3 (192-bit classical security)
    ///
    /// (k,l) = (6,5), N=256, q=8380417.
    MlDsa65,

    /// ML-DSA-87 — NIST Level 5 (256-bit classical security)
    ///
    /// (k,l) = (8,7), N=256, q=8380417.
    MlDsa87,
}

impl PqScheme {
    /// Returns the polynomial degree N for this scheme.
    #[inline]
    #[must_use]
    pub const fn n(self) -> usize {
        match self {
            Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => 256,
        }
    }

    /// Returns the prime modulus q for this scheme.
    ///
    /// All ML-DSA parameter sets share q = 2²³ − 2¹³ + 1 = 8380417.
    #[inline]
    #[must_use]
    pub const fn q(self) -> u32 {
        match self {
            Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => 8380417,
        }
    }

    /// Returns `k`, the first entry of the scheme's `(k, l)` dimension pair
    /// (the number of rows of the public matrix A in FIPS 204).
    #[inline]
    #[must_use]
    pub const fn k(self) -> usize {
        match self {
            Self::MlDsa44 => 4,
            Self::MlDsa65 => 6,
            Self::MlDsa87 => 8,
        }
    }

    /// Returns `l`, the second entry of the scheme's `(k, l)` dimension pair
    /// (the number of columns of the public matrix A in FIPS 204).
    #[inline]
    #[must_use]
    pub const fn l(self) -> usize {
        match self {
            Self::MlDsa44 => 4,
            Self::MlDsa65 => 5,
            Self::MlDsa87 => 7,
        }
    }

    /// Returns the NIST security category (1–5).
    #[inline]
    #[must_use]
    pub const fn security_level(self) -> u8 {
        match self {
            Self::MlDsa44 => 2,
            Self::MlDsa65 => 3,
            Self::MlDsa87 => 5,
        }
    }

    /// Returns a human-readable name for this scheme (e.g. `"ML-DSA-65"`).
    #[inline]
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            Self::MlDsa44 => "ML-DSA-44",
            Self::MlDsa65 => "ML-DSA-65",
            Self::MlDsa87 => "ML-DSA-87",
        }
    }

    /// Returns the NIST standard reference (e.g. "FIPS 204").
    ///
    /// Note: this is a parameter reference, not a conformance or certification claim.
    #[inline]
    #[must_use]
    pub const fn fips(self) -> &'static str {
        match self {
            Self::MlDsa44 | Self::MlDsa65 | Self::MlDsa87 => "FIPS 204",
        }
    }
}

impl core::fmt::Display for PqScheme {
    /// Formats the scheme as its standard name, identical to [`PqScheme::name`].
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(self.name())
    }
}
145
146// ===========================================================================
147// PqNtt — Post-Quantum NTT engine
148// ===========================================================================
149
150/// A ready-to-use NTT engine configured for a specific post-quantum scheme.
151///
152/// Wraps [`Ntt32Context`] with scheme metadata for safety and convenience.
153///
154/// # Example
155///
156/// ```
157/// use vaea_ntt::pq::{PqScheme, PqNtt};
158///
159/// let ntt = PqNtt::new(PqScheme::MlDsa65);
160/// assert_eq!(ntt.scheme(), PqScheme::MlDsa65);
161/// assert_eq!(ntt.n(), 256);
162/// assert_eq!(ntt.q(), 8380417);
163///
164/// let mut data = vec![0u32; 256];
165/// data[0] = 1;
166/// ntt.forward(&mut data);
167/// ntt.inverse(&mut data);
168/// assert_eq!(data[0], 1);
169/// ```
170pub struct PqNtt {
171 /// The underlying NTT context.
172 ctx: Ntt32Context,
173 /// The scheme this context was created for.
174 scheme: PqScheme,
175}
176
177impl PqNtt {
178 /// Creates a new PQ-NTT engine for the given scheme.
179 ///
180 /// This precomputes all twiddle factors and modular arithmetic
181 /// constants. The context can be reused for multiple NTT calls.
182 #[inline]
183 pub fn new(scheme: PqScheme) -> Self {
184 let ctx = Ntt32Context::new(scheme.n(), scheme.q());
185 Self { ctx, scheme }
186 }
187
188 /// Returns the scheme this engine was configured for.
189 #[inline]
190 pub fn scheme(&self) -> PqScheme {
191 self.scheme
192 }
193
194 /// Returns the polynomial degree N.
195 #[inline]
196 pub fn n(&self) -> usize {
197 self.ctx.n
198 }
199
200 /// Returns the prime modulus q.
201 #[inline]
202 pub fn q(&self) -> u32 {
203 self.ctx.q
204 }
205
206 /// Returns the NIST security level.
207 #[inline]
208 pub fn security_level(&self) -> u8 {
209 self.scheme.security_level()
210 }
211
212 /// Returns a reference to the underlying [`Ntt32Context`].
213 #[inline]
214 pub fn context(&self) -> &Ntt32Context {
215 &self.ctx
216 }
217
218 /// Applies forward NTT in-place.
219 ///
220 /// Transforms from coefficient domain to evaluation (NTT) domain.
221 /// In NTT domain, polynomial multiplication is pointwise O(N).
222 ///
223 /// # Panics
224 /// If `data.len() != self.n()`.
225 #[inline]
226 pub fn forward(&self, data: &mut [u32]) {
227 self.ctx.forward(data);
228 }
229
230 /// Applies inverse NTT in-place.
231 ///
232 /// Transforms from evaluation (NTT) domain back to coefficient domain.
233 /// Includes the N⁻¹ normalization factor.
234 ///
235 /// # Panics
236 /// If `data.len() != self.n()`.
237 #[inline]
238 pub fn inverse(&self, data: &mut [u32]) {
239 self.ctx.inverse(data);
240 }
241
242 /// Computes negacyclic polynomial multiplication: `result = a × b mod (X^N + 1, q)`.
243 ///
244 /// Both inputs must be in coefficient domain (not NTT).
245 /// Result is in coefficient domain.
246 ///
247 /// # Panics
248 /// If `a.len() != self.n()` or `b.len() != self.n()`.
249 #[inline]
250 pub fn multiply(&self, a: &[u32], b: &[u32]) -> alloc::vec::Vec<u32> {
251 self.ctx.negacyclic_mul(a, b)
252 }
253
254 /// Computes negacyclic polynomial multiplication: `result = a × b mod (X^N + 1, q)`.
255 ///
256 /// Both `a` and `b` are consumed (transformed in-place as scratch space).
257 /// Result is written to `result`.
258 ///
259 /// # Panics
260 /// If any slice length != `self.n()`.
261 #[inline]
262 pub fn multiply_into(&self, a: &mut [u32], b: &mut [u32], result: &mut [u32]) {
263 self.ctx.negacyclic_mul_into(a, b, result);
264 }
265}
266
267impl core::fmt::Debug for PqNtt {
268 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
269 f.debug_struct("PqNtt")
270 .field("scheme", &self.scheme.name())
271 .field("n", &self.n())
272 .field("q", &self.q())
273 .field("security_level", &self.security_level())
274 .field("fips", &self.scheme.fips())
275 .finish()
276 }
277}
278
279// ===========================================================================
280// Tests
281// ===========================================================================
282
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mldsa_44_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa44);
        assert_eq!(ntt.q(), 8380417);
        assert_eq!(ntt.n(), 256);
        assert_eq!(ntt.security_level(), 2);

        let mut data: alloc::vec::Vec<u32> = (0..256).map(|i| i * 1000 % 8380417).collect();
        let original = data.clone();
        ntt.forward(&mut data);
        assert_ne!(data, original, "NTT forward did nothing");
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_mldsa_65_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        assert_eq!(ntt.security_level(), 3);

        // Every coefficient at the maximum reduced value q - 1.
        let mut data = alloc::vec![8380416u32; 256];
        let original = data.clone();
        ntt.forward(&mut data);
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_mldsa_87_roundtrip() {
        let ntt = PqNtt::new(PqScheme::MlDsa87);
        assert_eq!(ntt.security_level(), 5);

        let mut data = alloc::vec![0u32; 256];
        data[0] = 1;
        let original = data.clone();
        ntt.forward(&mut data);
        ntt.inverse(&mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn test_multiply() {
        let ntt = PqNtt::new(PqScheme::MlDsa44);

        // (1 + x) × (1 + x) mod (x^256 + 1, q) = 1 + 2x + x^2 (no wraparound).
        let mut a = alloc::vec![0u32; 256];
        a[0] = 1;
        a[1] = 1;
        let result = ntt.multiply(&a, &a);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 2);
        assert_eq!(result[2], 1);
        for (i, &c) in result.iter().enumerate().skip(3) {
            assert_eq!(c, 0, "unexpected non-zero at index {i}");
        }
    }

    #[test]
    fn test_multiply_negacyclic_wraparound() {
        // x^(N-1) × x = x^N ≡ -1 (mod x^N + 1), i.e. q - 1 at index 0.
        // This is the property that distinguishes negacyclic from cyclic convolution.
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        let n = ntt.n();
        let q = ntt.q();

        let mut a = alloc::vec![0u32; n];
        let mut b = alloc::vec![0u32; n];
        a[n - 1] = 1;
        b[1] = 1;
        let result = ntt.multiply(&a, &b);
        assert_eq!(result[0], q - 1, "x^N must wrap around to -1");
        for (i, &c) in result.iter().enumerate().skip(1) {
            assert_eq!(c, 0, "unexpected non-zero at index {i}");
        }
    }

    #[test]
    fn test_multiply_into_matches_multiply() {
        let ntt = PqNtt::new(PqScheme::MlDsa87);
        let q = ntt.q();

        let a: alloc::vec::Vec<u32> = (0..ntt.n()).map(|i| (i as u32 * 31 + 7) % q).collect();
        let b: alloc::vec::Vec<u32> = (0..ntt.n())
            .map(|i| (i as u32 * i as u32 + 3) % q)
            .collect();
        let expected = ntt.multiply(&a, &b);

        // multiply_into overwrites its inputs, so hand it copies.
        let mut a_scratch = a.clone();
        let mut b_scratch = b.clone();
        let mut result = alloc::vec![0u32; ntt.n()];
        ntt.multiply_into(&mut a_scratch, &mut b_scratch, &mut result);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_scheme_metadata() {
        assert_eq!(PqScheme::MlDsa44.name(), "ML-DSA-44");
        assert_eq!(PqScheme::MlDsa44.fips(), "FIPS 204");
        assert_eq!(PqScheme::MlDsa44.k(), 4);

        assert_eq!(PqScheme::MlDsa65.name(), "ML-DSA-65");
        assert_eq!(PqScheme::MlDsa65.k(), 6);

        assert_eq!(PqScheme::MlDsa87.name(), "ML-DSA-87");
        assert_eq!(PqScheme::MlDsa87.k(), 8);
    }

    #[test]
    fn test_output_fully_reduced() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        let mut data: alloc::vec::Vec<u32> = (0..ntt.n())
            .map(|i| (i as u32 * 7 + 13) % ntt.q())
            .collect();
        ntt.forward(&mut data);
        assert!(
            data.iter().all(|&x| x < ntt.q()),
            "Output not fully reduced for ML-DSA-65"
        );
    }

    #[test]
    fn test_debug_display() {
        let ntt = PqNtt::new(PqScheme::MlDsa65);
        let debug = alloc::format!("{:?}", ntt);
        assert!(debug.contains("ML-DSA-65"));
        assert!(debug.contains("FIPS 204"));
    }
}