orion/hazardous/kem/ml_kem/
mlkem768.rs

1// MIT License
2
3// Copyright (c) 2025 The orion Developers
4
5// Permission is hereby granted, free of charge, to any person obtaining a copy
6// of this software and associated documentation files (the "Software"), to deal
7// in the Software without restriction, including without limitation the rights
8// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9// copies of the Software, and to permit persons to whom the Software is
10// furnished to do so, subject to the following conditions:
11
12// The above copyright notice and this permission notice shall be included in
13// all copies or substantial portions of the Software.
14
15// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21// SOFTWARE.
22
23//! ### ML-KEM key usage recommendations
24//!
25//! In general, it is highly recommended to use the [`KeyPair`] type to deal with decapsulating operations, or decapsulation keys in general.
26//!
//! A [`KeyPair`] requires, or automatically generates, a [`Seed`]. It cannot be constructed solely from an encoded/serialized decapsulation key unless the corresponding [`Seed`] is also provided.
28//!
29//! A seed is only 64 bytes, is fully FIPS compliant, and hardens against attacks described [here](https://eprint.iacr.org/2024/523).
30//!
31//! #### Serialized decapsulation keys
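//! It is possible to instantiate a [`DecapsulationKey`] directly, if strictly required, using [`DecapsulationKey::unchecked_from_slice()`]. A key created this way only receives the key checks from FIPS-203, section 7.3, and is not MAL-BIND-K-CT secure.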
33//!
34//! # Parameters:
//! - `ek`: The public encapsulation key, for which a shared secret and ciphertext are generated.
//! - `dk`: The secret decapsulation key, which is used together with a ciphertext to derive a shared secret.
37//! - `c`: The public ciphertext, sent to the decapsulating party.
38//! - `m`: Explicit randomness used for encapsulation.
39//!
40//! # Errors:
41//! An error will be returned if:
42//! - [`getrandom::fill()`] fails during encapsulation.
43//! - `m` is not 32 bytes.
44//!
45//! # Panics:
46//! A panic will occur if:
47//! - [`getrandom::fill()`] fails during [`KeyPair::generate()`].
48//!
49//! # Security:
//! - It is critical that both the seed and the explicit randomness `m`, used for key generation and
//! encapsulation respectively, are generated using a strong CSPRNG.
52//! - Users should always prefer encapsulation without specifying explicit randomness, if possible. `encap_deterministic()`
53//! exists mainly for `no_std` usage.
//! - Prefer using [`KeyPair`] to create and use ML-KEM keys, as it is MAL-BIND-K-CT secure.
55//!
56//! # Example:
57//! ```rust
58//! # #[cfg(feature = "safe_api")] {
59//! use orion::hazardous::kem::mlkem768::*;
60//!
61//! let keypair = KeyPair::generate()?;
62//!
63//! let (sender_shared_secret, sender_ciphertext) = MlKem768::encap(keypair.public())?;
64//! let recipient_shared_secret = MlKem768::decap(keypair.private(), &sender_ciphertext)?;
65//!
66//! assert_eq!(sender_shared_secret, recipient_shared_secret);
67//! # }
68//! # Ok::<(), orion::errors::UnknownCryptoError>(())
69//! ```
70//! [`getrandom::fill()`]: getrandom::fill
71//! [`encap()`]: mlkem768::MlKem768::encap
72//! [`decap()`]: mlkem768::MlKem768::decap
73//! [`KeyPair::generate()`]: mlkem768::KeyPair::generate
74//! [`KeyPair`]: mlkem768::KeyPair
75//! [`Seed`]: mlkem768::Seed
76//! [`DecapsulationKey`]: mlkem768::DecapsulationKey
77//! [`DecapsulationKey::unchecked_from_slice()`]:  mlkem768::DecapsulationKey::unchecked_from_slice
78
79use crate::errors::UnknownCryptoError;
80use crate::hazardous::kem::ml_kem::internal::*;
81pub use crate::hazardous::kem::ml_kem::Seed;
82use zeroize::Zeroize;
83
construct_secret_key! {
    /// A type to represent the `SharedSecret` that ML-KEM-768 produces.
    ///
    /// This type simply holds bytes. Creating an instance from slices or similar
    /// performs no checks whatsoever, other than the length check listed below.
    ///
    /// # Errors:
    /// An error will be returned if:
    /// - `slice` is not 32 bytes.
    (SharedSecret, test_shared_key, MlKem768Internal::SHARED_SECRET_SIZE, MlKem768Internal::SHARED_SECRET_SIZE)
}

// NOTE(review): presumably generates a `From<[u8; SHARED_SECRET_SIZE]>` conversion for
// `SharedSecret`; the macro is defined elsewhere in the crate — confirm.
impl_from_trait!(SharedSecret, MlKem768Internal::SHARED_SECRET_SIZE);
97
construct_public! {
    /// A type to represent the KEM `Ciphertext` that ML-KEM-768 returns.
    ///
    /// A ciphertext is produced by encapsulation and sent to the decapsulating party,
    /// which uses it together with its decapsulation key to derive the shared secret.
    ///
    /// # Errors:
    /// An error will be returned if:
    /// - `slice` is not 1088 bytes.
    (Ciphertext, test_kem_ciphertext, MlKem768Internal::CIPHERTEXT_SIZE, MlKem768Internal::CIPHERTEXT_SIZE)
}

// NOTE(review): presumably generates a `From<[u8; CIPHERTEXT_SIZE]>` conversion for
// `Ciphertext`; the macro is defined elsewhere in the crate — confirm.
impl_from_trait!(Ciphertext, MlKem768Internal::CIPHERTEXT_SIZE);
108
109#[derive(Debug, PartialEq)]
110/// A keypair of ML-KEM-768 keys, that are derived from a given seed.
111pub struct KeyPair {
112    seed: Seed,
113    dk: DecapsulationKey,
114}
115
116impl KeyPair {
117    #[cfg(feature = "safe_api")]
118    #[cfg_attr(docsrs, doc(cfg(feature = "safe_api")))]
119    /// Generate a fresh [KeyPair].
120    pub fn generate() -> Result<Self, UnknownCryptoError> {
121        let seed = Seed::generate();
122        let (ek, dk) = KeyPairInternal::<MlKem768Internal>::from_seed::<3, 1184, 2400>(&seed)?;
123
124        Ok(Self {
125            seed,
126            dk: DecapsulationKey {
127                value: dk,
128                cached_ek: EncapsulationKey { value: ek },
129            },
130        })
131    }
132
133    #[cfg(feature = "safe_api")]
134    #[cfg_attr(docsrs, doc(cfg(feature = "safe_api")))]
135    /// Instantiate a [KeyPair] with all key validation checks, described
136    /// in FIPS-203, Section 7.1, 7.2 and 7.3.
137    ///
138    /// The output keypair is the equivalent of using `KeyPair::try_from(seed: &Seed)`, but this
139    /// can be used, in order to check whether a decapsulation key
140    /// is valid in relation to the `seed` provided.
141    pub fn from_keys(seed: &Seed, dk: &DecapsulationKey) -> Result<Self, UnknownCryptoError> {
142        let unchecked_ek = EncapsulationKey::try_from(dk)?;
143        let (ek, dk) = KeyPairInternal::<MlKem768Internal>::from_keys::<3, 1184, 2400, 1088>(
144            seed,
145            &unchecked_ek.value,
146            &dk.value,
147        )?;
148
149        Ok(Self {
150            seed: Seed::from_slice(seed.unprotected_as_bytes()).unwrap(),
151            dk: DecapsulationKey {
152                value: dk,
153                cached_ek: EncapsulationKey { value: ek },
154            },
155        })
156    }
157
158    /// Get the [Seed] used to generate this keypair. Use this function in order to store
159    /// the private part of the keypair and regenerate it, when needed.
160    pub fn seed(&self) -> &Seed {
161        &self.seed
162    }
163
164    /// Get the public [EncapsulationKey] corresponding to this keypair.
165    pub fn public(&self) -> &EncapsulationKey {
166        &self.dk.cached_ek
167    }
168
169    /// Get the private [DecapsulationKey] used to generate this keypair. In order to store the private
170    /// part of this [KeyPair], use [KeyPair::seed()] instead.
171    pub fn private(&self) -> &DecapsulationKey {
172        &self.dk
173    }
174}
175
176impl TryFrom<&Seed> for KeyPair {
177    type Error = UnknownCryptoError;
178
179    fn try_from(value: &Seed) -> Result<Self, Self::Error> {
180        let (ek, dk) = KeyPairInternal::<MlKem768Internal>::from_seed::<3, 1184, 2400>(value)?;
181
182        Ok(Self {
183            seed: Seed::from_slice(value.unprotected_as_bytes()).unwrap(),
184            dk: DecapsulationKey {
185                value: dk,
186                cached_ek: EncapsulationKey { value: ek },
187            },
188        })
189    }
190}
191
192#[derive(Debug, PartialEq)]
193/// A type to represent the `DecapsulationKey` that ML-KEM-768 produces.
194pub struct DecapsulationKey {
195    pub(crate) value: DecapKey<3, 1184, 2400, MlKem768Internal>,
196    // NOTE(brycx): This is simply a cache of the encapsulation key, so we avoid recomputing it
197    // on decap() operations. This is not a part of PartialEq, AsRef<> implementations or other logic
198    // pertaining to the `DecapsulationKey`, serving a purely internal purpose.
199    pub(crate) cached_ek: EncapsulationKey,
200}
201
202impl PartialEq<&[u8]> for DecapsulationKey {
203    fn eq(&self, other: &&[u8]) -> bool {
204        // Defer to DecapKey<> impl ct-eq
205        self.value == *other
206    }
207}
208
209impl DecapsulationKey {
210    /// Instantiate a [DecapsulationKey] with only key-checks from FIPS-203, section 7.3. Not MAL-BIND-K-CT secure.
211    pub fn unchecked_from_slice(slice: &[u8]) -> Result<Self, UnknownCryptoError> {
212        let dk_unchecked =
213            DecapKey::<3, 1184, 2400, MlKem768Internal>::unchecked_from_slice(slice)?;
214        let ek_unchecked =
215            EncapsulationKey::from_slice(dk_unchecked.get_encapsulation_key_bytes())?;
216
217        Ok(Self {
218            value: dk_unchecked,
219            cached_ek: ek_unchecked,
220        })
221    }
222
223    /// Perform decapsulation of a [Ciphertext].
224    pub fn decap(&self, c: &Ciphertext) -> Result<SharedSecret, UnknownCryptoError> {
225        let mut c_prime_buf = [0u8; MlKem768Internal::CIPHERTEXT_SIZE];
226        let mut k_internal = self.value.mlkem_decap_internal_with_ek(
227            c.as_ref(),
228            &mut c_prime_buf,
229            &self.cached_ek.value,
230        )?;
231        let k = SharedSecret::from_slice(&k_internal)?;
232        k_internal.zeroize();
233
234        Ok(k)
235    }
236}
237
238#[derive(Debug, PartialEq, Clone)]
239/// A type to represent the `EncapsulationKey` that ML-KEM-768 returns.
240pub struct EncapsulationKey {
241    pub(crate) value: EncapKey<3, 1184, MlKem768Internal>,
242}
243
244impl PartialEq<&[u8]> for EncapsulationKey {
245    fn eq(&self, other: &&[u8]) -> bool {
246        self.value == *other
247    }
248}
249
250impl TryFrom<&DecapsulationKey> for EncapsulationKey {
251    type Error = UnknownCryptoError;
252
253    fn try_from(value: &DecapsulationKey) -> Result<Self, Self::Error> {
254        Ok(Self {
255            value: EncapKey::<3, 1184, MlKem768Internal>::from_slice(
256                value.value.get_encapsulation_key_bytes(),
257            )?,
258        })
259    }
260}
261
262impl AsRef<[u8]> for EncapsulationKey {
263    fn as_ref(&self) -> &[u8] {
264        self.value.as_ref()
265    }
266}
267
268impl EncapsulationKey {
269    /// Instantiate a [EncapsulationKey] with key-checks from FIPS-203, section 7.2.
270    pub fn from_slice(slice: &[u8]) -> Result<Self, UnknownCryptoError> {
271        Ok(Self {
272            value: EncapKey::<3, 1184, MlKem768Internal>::from_slice(slice)?,
273        })
274    }
275
276    #[cfg(feature = "safe_api")]
277    #[cfg_attr(docsrs, doc(cfg(feature = "safe_api")))]
278    /// Given the [EncapsulationKey], generate a [SharedSecret] and associated [Ciphertext].
279    pub fn encap(&self) -> Result<(SharedSecret, Ciphertext), UnknownCryptoError> {
280        use zeroize::Zeroizing;
281
282        let mut m = Zeroizing::new([0u8; 32]);
283        getrandom::fill(m.as_mut())?;
284
285        self.encap_deterministic(m.as_ref())
286    }
287
288    /// Given the [EncapsulationKey] and randomness `m`, generate a [SharedSecret] and associated [Ciphertext].
289    pub fn encap_deterministic(
290        &self,
291        m: &[u8],
292    ) -> Result<(SharedSecret, Ciphertext), UnknownCryptoError> {
293        if m.len() != 32 {
294            return Err(UnknownCryptoError);
295        }
296
297        let mut c = Ciphertext::from_slice(&[0u8; MlKem768Internal::CIPHERTEXT_SIZE])?;
298        let mut k_internal = self.value.mlkem_encap_internal(m.as_ref(), &mut c.value)?;
299        let k = SharedSecret::from_slice(k_internal.as_slice())?;
300        k_internal.zeroize();
301
302        Ok((k, c))
303    }
304}
305
/// ML-KEM-768.
///
/// A unit type that groups the ML-KEM-768 size constants with convenience
/// encapsulation/decapsulation functions.
#[derive(Debug, PartialEq)]
pub struct MlKem768;
309
310impl MlKem768 {
311    /// Encapsulation key size (bytes).
312    pub const EK_SIZE: usize = MlKem768Internal::EK_SIZE;
313    /// Decapsulation key size (bytes).
314    pub const DK_SIZE: usize = MlKem768Internal::DK_SIZE;
315    /// Ciphertext size (bytes).
316    pub const CIPHERTEXT_SIZE: usize = MlKem768Internal::CIPHERTEXT_SIZE;
317    /// Shared Secret size (bytes).
318    pub const SHARED_SECRET_SIZE: usize = MlKem768Internal::SHARED_SECRET_SIZE;
319
320    #[cfg(feature = "safe_api")]
321    #[cfg_attr(docsrs, doc(cfg(feature = "safe_api")))]
322    /// Given the [EncapsulationKey], generate a [SharedSecret] and associated [Ciphertext].
323    pub fn encap(ek: &EncapsulationKey) -> Result<(SharedSecret, Ciphertext), UnknownCryptoError> {
324        ek.encap()
325    }
326
327    /// Given the [DecapsulationKey], produce a [SharedSecret] using the [Ciphertext].
328    pub fn decap(
329        dk: &DecapsulationKey,
330        c: &Ciphertext,
331    ) -> Result<SharedSecret, UnknownCryptoError> {
332        dk.decap(c)
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "safe_api")]
    use crate::test_framework::kem_interface::{KemTester, TestableKem};

    // Adapter exposing ML-KEM-768 to the shared KEM test framework, which works on
    // raw byte encodings of the keys.
    #[cfg(feature = "safe_api")]
    impl TestableKem<SharedSecret, Ciphertext> for MlKem768 {
        // Derive a keypair from `seed` and return (encoded ek, encoded dk).
        fn keygen(seed: &[u8]) -> Result<(Vec<u8>, Vec<u8>), UnknownCryptoError> {
            let kp = KeyPair::try_from(&Seed::from_slice(seed).unwrap()).unwrap();

            Ok((
                kp.dk.cached_ek.as_ref().to_vec(),
                kp.dk.value.unprotected_as_bytes().to_vec(),
            ))
        }

        fn ciphertext_from_bytes(b: &[u8]) -> Result<Ciphertext, UnknownCryptoError> {
            Ciphertext::from_slice(b)
        }

        // Randomized encapsulation against an encoded encapsulation key.
        fn encap(ek: &[u8]) -> Result<(SharedSecret, Ciphertext), UnknownCryptoError> {
            let ek = EncapsulationKey::from_slice(ek).unwrap();
            ek.encap()
        }

        // Decapsulation from an encoded decapsulation key, using the seedless
        // `unchecked_from_slice()` constructor.
        fn decap(dk: &[u8], c: &Ciphertext) -> Result<SharedSecret, UnknownCryptoError> {
            let dk = DecapsulationKey::unchecked_from_slice(dk).unwrap();
            dk.decap(c)
        }
    }

    // The keypair's public key must be the key cached inside its decapsulation key.
    #[test]
    fn test_keypair_dk_ek_match_internal() {
        let seed = Seed::from_slice(&[128u8; 64]).unwrap();
        let kp = KeyPair::try_from(&seed).unwrap();
        assert_eq!(kp.public(), &kp.private().cached_ek);
    }

    // Decapsulating via the internal path (which re-derives the ek) and via the public
    // API (which uses `cached_ek`) must both recover the encapsulated shared secret.
    #[test]
    #[cfg(feature = "safe_api")]
    fn test_dk_cached_ek() {
        let seed = Seed::from_slice(&[128u8; 64]).unwrap();
        let kp = KeyPair::try_from(&seed).unwrap();
        let (ss_pubapi, ct_pubapi) = kp.public().encap_deterministic(&[125u8; 32]).unwrap();
        let mut c_prime = [0u8; MlKem768Internal::CIPHERTEXT_SIZE];
        // This call re-computes encap key internally from the bytes a decapkey would store.
        let ss_privapi = kp
            .private()
            .value
            .mlkem_decap_internal(ct_pubapi.as_ref(), &mut c_prime)
            .unwrap();
        assert_eq!(ss_privapi.as_ref(), ss_pubapi.unprotected_as_bytes());
        assert_eq!(
            MlKem768::decap(kp.private(), &ct_pubapi).unwrap(),
            ss_pubapi
        );
    }

    // Re-extracting the ek from a dk must match the cached ek.
    #[cfg(feature = "safe_api")]
    #[test]
    fn test_dk_to_ek_conversions() {
        let kp = KeyPair::generate().unwrap();
        assert_eq!(
            kp.dk.cached_ek,
            EncapsulationKey::try_from(kp.private()).unwrap()
        );
    }

    // `encap_deterministic()` only accepts a 32-byte `m`, and identical `m` gives
    // identical output.
    #[cfg(feature = "safe_api")]
    #[test]
    fn test_bad_m_length() {
        let kp = KeyPair::generate().unwrap();
        let mut m = [0u8; 32];
        getrandom::fill(m.as_mut()).unwrap();

        // Using the same deterministic seed is in fact deterministic,
        // also using correct length.
        assert_eq!(
            kp.public().encap_deterministic(&m).unwrap(),
            kp.public().encap_deterministic(&m).unwrap()
        );
        assert!(kp.public().encap_deterministic(&[0u8; 31]).is_err());
        assert!(kp.public().encap_deterministic(&[0u8; 33]).is_err());
    }

    // Keys derived from the same seed compare equal to each other's raw bytes.
    #[cfg(feature = "safe_api")]
    #[test]
    fn test_dk_ek_partialeq() {
        let s0 = Seed::generate();
        let kp = KeyPair::try_from(&s0).unwrap();

        let dk_bytes = kp.private().value.bytes;
        let ek_bytes = kp.public().value.bytes;

        assert_eq!(
            KeyPair::try_from(&s0).unwrap().private(),
            &dk_bytes.as_ref()
        );
        assert_eq!(KeyPair::try_from(&s0).unwrap().public(), &ek_bytes.as_ref());
    }

    // `from_keys()` accepts only a (seed, dk) pair that belong together, and then
    // reproduces the keypair given by `try_from(seed)`.
    #[cfg(feature = "safe_api")]
    #[test]
    fn test_keypair_from_keys() {
        let s0 = Seed::generate();
        let s1 = Seed::generate();

        let kp0 = KeyPair::try_from(&s0).unwrap();
        let kp1 = KeyPair::try_from(&s1).unwrap();
        assert_eq!(kp0.seed(), &s0);
        assert_eq!(kp1.seed(), &s1);

        assert!(KeyPair::from_keys(&s0, kp0.private()).is_ok());
        assert!(KeyPair::from_keys(&s1, kp1.private()).is_ok());
        assert!(KeyPair::from_keys(&s1, kp0.private()).is_err());
        assert!(KeyPair::from_keys(&s0, kp1.private()).is_err());

        let kp0_keys = KeyPair::from_keys(&s0, kp0.private()).unwrap();
        let kp1_keys = KeyPair::from_keys(&s1, kp1.private()).unwrap();
        assert_eq!(kp0.seed(), kp0_keys.seed());
        assert_eq!(kp1.seed(), kp1_keys.seed());

        assert_eq!(kp0.private(), kp0_keys.private());
        assert_eq!(kp0.public(), kp0_keys.public());
    }

    // Run the generic KEM test suite from the shared test framework.
    #[cfg(feature = "safe_api")]
    #[test]
    fn run_basic_kem_tests() {
        let seed = Seed::generate();
        KemTester::<MlKem768, SharedSecret, Ciphertext>::run_all_tests(seed.unprotected_as_bytes());
    }

    #[test]
    /// Basic no_std-compatible test.
    fn basic_roundtrip() {
        let seed = Seed::from_slice(&[127u8; 64]).unwrap();
        let kp = KeyPair::try_from(&seed).unwrap();

        let (k, c) = kp.public().encap_deterministic(&[255u8; 32]).unwrap();
        let k_prime = kp.private().decap(&c).unwrap();

        assert_eq!(k, k_prime);
    }
}