Skip to main content

mlkem/
lib.rs

//! # mlkem
//!
//! post-quantum key encapsulation per [fips 203][1] (ml-kem, formerly kyber),
//! in pure rust, covering all three standardized parameter sets.
//!
//! ## quick start
//!
//! bob generates a key pair and publishes `pk`; alice encapsulates against
//! `pk` and sends `ct` back; bob decapsulates `ct` and arrives at the same
//! shared secret.
//!
//! ```
//! use mlkem::MlKem768;
//! use rand::thread_rng;
//!
//! let mut rng = thread_rng();
//! let (pk, sk) = MlKem768::keygen(&mut rng);
//! let (ct, ss_alice) = MlKem768::encapsulate(&pk, &mut rng);
//! let ss_bob = MlKem768::decapsulate(&sk, &ct);
//! assert_eq!(ss_alice.as_bytes(), ss_bob.as_bytes());
//! ```
//!
//! ## variants
//!
//! sizes are in bytes (public key / secret key / ciphertext); every variant
//! produces a 32-byte shared secret.
//!
//! - [`MlKem512`]: nist security category 1 (~ aes-128). pk 800, sk 1632, ct 768.
//! - [`MlKem768`]: nist security category 3 (~ aes-192). pk 1184, sk 2400, ct 1088.
//! - [`MlKem1024`]: nist security category 5 (~ aes-256). pk 1568, sk 3168, ct 1568.
//!
//! all three implement the [`Kem`] trait, so callers can be generic over the
//! parameter set.
//!
//! ## features
//!
//! - `std` (default): enables the `std::error::Error` impl on [`LengthError`] and
//!   uses the std versions of the crypto dependencies.
//! - `serde`: implements `Serialize` + `Deserialize` on every key, ciphertext,
//!   and shared-secret newtype across all three parameter sets.
//!
//! ## correctness
//!
//! - all 180 official nist acvp test vectors pass byte-for-byte (75 keygen,
//!   75 encapsulation, 30 decapsulation, distributed evenly across the three
//!   parameter sets).
//! - 3000-seed cross-check against the audited [`ml-kem`][2] crate.
//! - 24000 stable-rust stress iterations on every `cargo test`.
//! - cargo-fuzz harness in `fuzz/`.
//!
//! ## security and stability
//!
//! this crate is **not audited**. for production cryptography, use the
//! audited [`ml-kem`][2] crate from rustcrypto. this implementation exists
//! to be readable end-to-end, and is suitable for study, tooling, and tests.
//!
//! [1]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.pdf
//! [2]: https://crates.io/crates/ml-kem

#![cfg_attr(not(feature = "std"), no_std)]
#![warn(clippy::all, clippy::pedantic)]
#![warn(missing_debug_implementations)]
// crate-wide lint exceptions. every lint not listed below stays enabled; the
// allowed ones fire repeatedly on patterns that are intentional here and are
// reviewed case-by-case rather than rewritten project-wide.
//
// integer casts: the algebraic / ntt code converts between integer widths on
// purpose.
#![allow(clippy::cast_lossless)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_possible_wrap)]
// math-style code: explicit index loops, short paired names (`pk`/`sk`,
// `d`/`z`/`m`, ...) and long numeric literals.
#![allow(clippy::needless_range_loop)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::similar_names)]
#![allow(clippy::unreadable_literal)]
// documentation / api-shape lints: the docs are lowercase prose without
// backticks on every term, and internal paths such as `mlkem::MlKem` repeat
// the module name.
#![allow(clippy::doc_markdown)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
// code-layout lints.
#![allow(clippy::inline_always)]
#![allow(clippy::items_after_statements)]

74extern crate alloc;
75
76mod compress;
77mod field;
78mod hash;
79mod kpke;
80mod mlkem;
81mod ntt;
82mod params;
83mod poly;
84mod sample;
85mod serialize;
86
87use rand_core::{CryptoRng, RngCore};
88use subtle::ConstantTimeEq;
89use zeroize::{Zeroize, ZeroizeOnDrop};
90
91pub use params::{Params, Params1024, Params512, Params768};
92pub use poly::MAX_K;
93
/// returned by the `TryFrom<&[u8]>` conversions on this crate's fixed-size
/// byte newtypes when the input slice has the wrong length.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct LengthError {
    /// the length, in bytes, that the target type requires.
    pub expected: usize,
    /// the length, in bytes, of the slice that was actually supplied.
    pub got: usize,
}

impl core::fmt::Display for LengthError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { expected, got } = *self;
        write!(f, "wrong byte length: expected {}, got {}", expected, got)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for LengthError {}

115/// generic interface implemented by `MlKem512`, `MlKem768` and `MlKem1024`.
116/// lets you write code that picks a parameter set at instantiation time.
117///
118/// ```
119/// use mlkem::{Kem, MlKem768};
120/// use rand::thread_rng;
121///
122/// fn handshake<K: Kem>() -> bool {
123///     let mut rng = thread_rng();
124///     let (pk, sk) = K::keygen(&mut rng);
125///     let (ct, ss_a) = K::encapsulate(&pk, &mut rng);
126///     let ss_b = K::decapsulate(&sk, &ct);
127///     ss_a.as_ref() == ss_b.as_ref()
128/// }
129///
130/// assert!(handshake::<MlKem768>());
131/// ```
132pub trait Kem {
133    type PublicKey: Clone + AsRef<[u8]>;
134    type SecretKey: Clone;
135    type Ciphertext: Clone + AsRef<[u8]>;
136    type SharedSecret: Clone + AsRef<[u8]>;
137
138    const PUBLIC_KEY_SIZE: usize;
139    const SECRET_KEY_SIZE: usize;
140    const CIPHERTEXT_SIZE: usize;
141    const SHARED_SECRET_SIZE: usize = 32;
142
143    fn keygen<R: RngCore + CryptoRng>(rng: &mut R) -> (Self::PublicKey, Self::SecretKey);
144    fn encapsulate<R: RngCore + CryptoRng>(
145        pk: &Self::PublicKey,
146        rng: &mut R,
147    ) -> (Self::Ciphertext, Self::SharedSecret);
148    fn decapsulate(sk: &Self::SecretKey, ct: &Self::Ciphertext) -> Self::SharedSecret;
149}
150
// defines the complete public api for one ml-kem parameter set.
//
// - `$name`: zero-sized entry point (`MlKem512`, `MlKem768`, `MlKem1024`).
// - `$params`: internal parameter type handed to `mlkem::MlKem`.
// - `$pkty` / `$skty` / `$ctty` / `$ssty`: names of the generated public-key,
//   secret-key, ciphertext and shared-secret newtypes.
// - `$pk` / `$sk` / `$ct`: byte lengths of the public key, secret key and
//   ciphertext. the shared secret is always 32 bytes.
macro_rules! mlkem_api {
    ($name:ident, $params:ty, $pkty:ident, $skty:ident, $ctty:ident, $ssty:ident,
     $pk:expr, $sk:expr, $ct:expr) => {
        /// ml-kem entry point for one parameter set. every operation is an
        /// associated function; the type itself carries no state.
        #[derive(Debug)]
        pub struct $name;

        impl $name {
            /// public (encapsulation) key length in bytes.
            pub const PUBLIC_KEY_SIZE: usize = $pk;
            /// secret (decapsulation) key length in bytes.
            pub const SECRET_KEY_SIZE: usize = $sk;
            /// ciphertext length in bytes.
            pub const CIPHERTEXT_SIZE: usize = $ct;
            /// shared secret length in bytes.
            pub const SHARED_SECRET_SIZE: usize = 32;

            /// deterministic keygen from a 64-byte seed (d || z).
            ///
            /// the internal copies of `d` and `z` are wiped before returning;
            /// wiping `seed` itself is the caller's responsibility.
            pub fn keygen_deterministic(seed: &[u8; 64]) -> ($pkty, $skty) {
                let mut d = [0u8; 32];
                let mut z = [0u8; 32];
                d.copy_from_slice(&seed[..32]);
                z.copy_from_slice(&seed[32..]);
                let (pk, sk) = mlkem::MlKem::<$params>::keygen(&d, &z);
                // `d` and `z` are secret keygen randomness: don't leave them on
                // the stack once they have been consumed.
                d.zeroize();
                z.zeroize();

                let mut pk_arr = [0u8; $pk];
                let mut sk_arr = [0u8; $sk];
                pk_arr.copy_from_slice(&pk);
                sk_arr.copy_from_slice(&sk);
                // NOTE(review): the `sk` buffer returned by the internal keygen
                // still holds the secret key when it is dropped here; wipe it as
                // well once its type is confirmed to implement `Zeroize`.
                let sk_out = $skty(sk_arr);
                // arrays are `Copy`, so the newtype owns its own copy; wipe the
                // staging buffer.
                sk_arr.zeroize();
                ($pkty(pk_arr), sk_out)
            }

            /// randomized keygen: draws the 64-byte (d || z) seed from `rng` and
            /// wipes it after use.
            pub fn keygen<R: RngCore + CryptoRng>(rng: &mut R) -> ($pkty, $skty) {
                let mut seed = [0u8; 64];
                rng.fill_bytes(&mut seed);
                let keys = Self::keygen_deterministic(&seed);
                seed.zeroize();
                keys
            }

            /// deterministic encapsulation against `pk` using the caller-supplied
            /// 32-byte message `m`. [`Self::encapsulate`] is the randomized
            /// wrapper that draws `m` from an rng.
            pub fn encapsulate_deterministic(pk: &$pkty, m: &[u8; 32]) -> ($ctty, $ssty) {
                let (ct, ss) = mlkem::MlKem::<$params>::encapsulate(&pk.0, m);
                let mut ct_arr = [0u8; $ct];
                ct_arr.copy_from_slice(&ct);
                ($ctty(ct_arr), $ssty(ss))
            }

            /// randomized encapsulation: draws the 32-byte message from `rng` and
            /// wipes it after use.
            pub fn encapsulate<R: RngCore + CryptoRng>(pk: &$pkty, rng: &mut R) -> ($ctty, $ssty) {
                let mut m = [0u8; 32];
                rng.fill_bytes(&mut m);
                let out = Self::encapsulate_deterministic(pk, &m);
                // `m` determines the shared secret; don't leave it behind.
                m.zeroize();
                out
            }

            /// recovers the shared secret from `ct` using `sk`.
            ///
            /// infallible by signature. NOTE(review): fips 203 prescribes implicit
            /// rejection (a pseudo-random secret) for invalid ciphertexts; confirm
            /// that `mlkem::MlKem::decapsulate` implements it.
            pub fn decapsulate(sk: &$skty, ct: &$ctty) -> $ssty {
                $ssty(mlkem::MlKem::<$params>::decapsulate(&sk.0, &ct.0))
            }
        }

        /// public (encapsulation) key bytes.
        ///
        /// construction only checks the length. NOTE(review): confirm whether the
        /// fips 203 section 7.2 modulus check happens inside encapsulation.
        #[derive(Clone)]
        pub struct $pkty(pub(crate) [u8; $pk]);

        /// secret (decapsulation) key bytes; zeroized on drop.
        #[derive(Clone, ZeroizeOnDrop)]
        pub struct $skty(pub(crate) [u8; $sk]);

        /// ciphertext bytes.
        #[derive(Clone)]
        pub struct $ctty(pub(crate) [u8; $ct]);

        /// 32-byte shared secret; zeroized on drop.
        #[derive(Clone, ZeroizeOnDrop)]
        pub struct $ssty(pub(crate) [u8; 32]);

        #[cfg(feature = "serde")]
        const _: () = {
            use serde::de::{Error as DeError, IgnoredAny, SeqAccess, Visitor};
            use serde::{Deserialize, Deserializer, Serialize, Serializer};

            // serializes a fixed-size byte newtype as a byte string and accepts
            // either a byte string or a sequence of exactly `$n` bytes back.
            macro_rules! serde_byte_array {
                ($t:ident, $n:expr) => {
                    impl Serialize for $t {
                        fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
                            s.serialize_bytes(&self.0)
                        }
                    }
                    impl<'de> Deserialize<'de> for $t {
                        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
                            struct BytesVisitor;
                            impl<'de> Visitor<'de> for BytesVisitor {
                                type Value = [u8; $n];
                                fn expecting(
                                    &self,
                                    f: &mut core::fmt::Formatter<'_>,
                                ) -> core::fmt::Result {
                                    f.write_str(concat!("a byte sequence of length ", stringify!($n)))
                                }
                                fn visit_bytes<E: DeError>(self, v: &[u8]) -> Result<[u8; $n], E> {
                                    if v.len() != $n {
                                        return Err(E::invalid_length(v.len(), &self));
                                    }
                                    let mut a = [0u8; $n];
                                    a.copy_from_slice(v);
                                    Ok(a)
                                }
                                fn visit_seq<A: SeqAccess<'de>>(
                                    self,
                                    mut seq: A,
                                ) -> Result<[u8; $n], A::Error> {
                                    let mut a = [0u8; $n];
                                    for i in 0..$n {
                                        a[i] = seq
                                            .next_element()?
                                            .ok_or_else(|| A::Error::invalid_length(i, &self))?;
                                    }
                                    // reject over-long sequences instead of silently
                                    // truncating them to the first `$n` bytes.
                                    if seq.next_element::<IgnoredAny>()?.is_some() {
                                        return Err(A::Error::invalid_length($n + 1, &self));
                                    }
                                    Ok(a)
                                }
                            }
                            d.deserialize_bytes(BytesVisitor).map($t)
                        }
                    }
                };
            }

            serde_byte_array!($pkty, $pk);
            serde_byte_array!($skty, $sk);
            serde_byte_array!($ctty, $ct);
            serde_byte_array!($ssty, 32);
        };

        impl $pkty {
            /// borrows the raw key bytes.
            pub fn as_bytes(&self) -> &[u8; $pk] {
                &self.0
            }
            /// wraps raw key bytes (length is enforced by the array type).
            pub fn from_bytes(b: &[u8; $pk]) -> Self {
                Self(*b)
            }
        }
        impl $skty {
            /// borrows the raw key bytes.
            pub fn as_bytes(&self) -> &[u8; $sk] {
                &self.0
            }
            /// wraps raw key bytes (length is enforced by the array type).
            pub fn from_bytes(b: &[u8; $sk]) -> Self {
                Self(*b)
            }
        }
        impl $ctty {
            /// borrows the raw ciphertext bytes.
            pub fn as_bytes(&self) -> &[u8; $ct] {
                &self.0
            }
            /// wraps raw ciphertext bytes (length is enforced by the array type).
            pub fn from_bytes(b: &[u8; $ct]) -> Self {
                Self(*b)
            }
        }
        impl $ssty {
            /// borrows the raw shared-secret bytes.
            pub fn as_bytes(&self) -> &[u8; 32] {
                &self.0
            }
        }

        // equality goes through `subtle` so comparing secrets does not
        // short-circuit on the first differing byte.
        impl PartialEq for $pkty {
            fn eq(&self, other: &Self) -> bool {
                self.0.ct_eq(&other.0).into()
            }
        }
        impl Eq for $pkty {}
        impl PartialEq for $skty {
            fn eq(&self, other: &Self) -> bool {
                self.0.as_slice().ct_eq(other.0.as_slice()).into()
            }
        }
        impl Eq for $skty {}
        impl PartialEq for $ctty {
            fn eq(&self, other: &Self) -> bool {
                self.0.ct_eq(&other.0).into()
            }
        }
        impl Eq for $ctty {}
        impl PartialEq for $ssty {
            fn eq(&self, other: &Self) -> bool {
                self.0.ct_eq(&other.0).into()
            }
        }
        impl Eq for $ssty {}

        // debug output never prints key or secret material.
        impl core::fmt::Debug for $pkty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(
                    f,
                    concat!(stringify!($pkty), "(..{} bytes..)"),
                    self.0.len()
                )
            }
        }
        impl core::fmt::Debug for $skty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(f, concat!(stringify!($skty), "(..REDACTED..)"))
            }
        }
        impl core::fmt::Debug for $ctty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(
                    f,
                    concat!(stringify!($ctty), "(..{} bytes..)"),
                    self.0.len()
                )
            }
        }
        impl core::fmt::Debug for $ssty {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(f, concat!(stringify!($ssty), "(..REDACTED..)"))
            }
        }

        impl Zeroize for $skty {
            fn zeroize(&mut self) {
                self.0.zeroize();
            }
        }
        impl Zeroize for $ssty {
            fn zeroize(&mut self) {
                self.0.zeroize();
            }
        }

        impl TryFrom<&[u8]> for $pkty {
            type Error = LengthError;
            fn try_from(b: &[u8]) -> Result<Self, LengthError> {
                if b.len() != $pk {
                    return Err(LengthError {
                        expected: $pk,
                        got: b.len(),
                    });
                }
                let mut a = [0u8; $pk];
                a.copy_from_slice(b);
                Ok(Self(a))
            }
        }
        impl TryFrom<&[u8]> for $skty {
            type Error = LengthError;
            fn try_from(b: &[u8]) -> Result<Self, LengthError> {
                if b.len() != $sk {
                    return Err(LengthError {
                        expected: $sk,
                        got: b.len(),
                    });
                }
                let mut a = [0u8; $sk];
                a.copy_from_slice(b);
                let key = Self(a);
                // arrays are `Copy`; wipe the staging copy of the secret key.
                a.zeroize();
                Ok(key)
            }
        }
        impl TryFrom<&[u8]> for $ctty {
            type Error = LengthError;
            fn try_from(b: &[u8]) -> Result<Self, LengthError> {
                if b.len() != $ct {
                    return Err(LengthError {
                        expected: $ct,
                        got: b.len(),
                    });
                }
                let mut a = [0u8; $ct];
                a.copy_from_slice(b);
                Ok(Self(a))
            }
        }
        impl TryFrom<&[u8]> for $ssty {
            type Error = LengthError;
            fn try_from(b: &[u8]) -> Result<Self, LengthError> {
                if b.len() != 32 {
                    return Err(LengthError {
                        expected: 32,
                        got: b.len(),
                    });
                }
                let mut a = [0u8; 32];
                a.copy_from_slice(b);
                let secret = Self(a);
                // arrays are `Copy`; wipe the staging copy of the secret.
                a.zeroize();
                Ok(secret)
            }
        }

        impl AsRef<[u8]> for $pkty {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }
        impl AsRef<[u8]> for $ctty {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }
        impl AsRef<[u8]> for $skty {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }
        impl AsRef<[u8]> for $ssty {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }

        impl Kem for $name {
            type PublicKey = $pkty;
            type SecretKey = $skty;
            type Ciphertext = $ctty;
            type SharedSecret = $ssty;
            const PUBLIC_KEY_SIZE: usize = $pk;
            const SECRET_KEY_SIZE: usize = $sk;
            const CIPHERTEXT_SIZE: usize = $ct;

            // forward to the inherent associated functions above.
            fn keygen<R: RngCore + CryptoRng>(rng: &mut R) -> ($pkty, $skty) {
                <$name>::keygen(rng)
            }
            fn encapsulate<R: RngCore + CryptoRng>(pk: &$pkty, rng: &mut R) -> ($ctty, $ssty) {
                <$name>::encapsulate(pk, rng)
            }
            fn decapsulate(sk: &$skty, ct: &$ctty) -> $ssty {
                <$name>::decapsulate(sk, ct)
            }
        }
    };
}

452// ml-kem-512: pk 800, sk 1632, ct 768. fips 203 table 3, security category 1.
453mlkem_api!(
454    MlKem512,
455    Params512,
456    PublicKey512,
457    SecretKey512,
458    Ciphertext512,
459    SharedSecret512,
460    800,
461    1632,
462    768
463);
464
465// ml-kem-768: pk 1184, sk 2400, ct 1088. security category 3 (default if you must pick one).
466mlkem_api!(
467    MlKem768,
468    Params768,
469    PublicKey768,
470    SecretKey768,
471    Ciphertext768,
472    SharedSecret768,
473    1184,
474    2400,
475    1088
476);
477
478// ml-kem-1024: pk 1568, sk 3168, ct 1568. security category 5.
479mlkem_api!(
480    MlKem1024,
481    Params1024,
482    PublicKey1024,
483    SecretKey1024,
484    Ciphertext1024,
485    SharedSecret1024,
486    1568,
487    3168,
488    1568
489);
490
491// back-compat aliases for the old 0.1 api.
492pub type PublicKey = PublicKey768;
493pub type SecretKey = SecretKey768;
494pub type Ciphertext = Ciphertext768;
495pub type SharedSecret = SharedSecret768;