ml_kem/
kem.rs

1//! Key encapsulation mechanism implementation.
2
3// Re-export traits from the `kem` crate
4pub use ::kem::{
5    Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, Key, KeyExport, KeyInit,
6    KeySizeUser, TryKeyInit,
7};
8
9use crate::{
10    Encoded, EncodedSizeUser, KemCore, Seed,
11    crypto::{G, H, J},
12    param::{
13        DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, ExpandedDecapsulationKey,
14        KemParams,
15    },
16    pke::{DecryptionKey, EncryptionKey},
17    util::B32,
18};
19use array::typenum::{U32, U64};
20use core::marker::PhantomData;
21use rand_core::{CryptoRng, TryCryptoRng, TryRngCore};
22use subtle::{ConditionallySelectable, ConstantTimeEq};
23
24#[cfg(feature = "deterministic")]
25use core::convert::Infallible;
26#[cfg(feature = "zeroize")]
27use zeroize::{Zeroize, ZeroizeOnDrop};
28
29/// A shared key resulting from an ML-KEM transaction
30pub(crate) type SharedKey = B32;
31
32/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
33/// encapsulated shared key.
34#[derive(Clone, Debug)]
35pub struct DecapsulationKey<P>
36where
37    P: KemParams,
38{
39    dk_pke: DecryptionKey<P>,
40    ek: EncapsulationKey<P>,
41    d: Option<B32>,
42    z: B32,
43}
44
45// Handwritten to omit `d` in the comparisons, so keys initialized from seeds compare equally to
46// keys initialized from the expanded form
47impl<P> PartialEq for DecapsulationKey<P>
48where
49    P: KemParams,
50{
51    fn eq(&self, other: &Self) -> bool {
52        self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
53    }
54}
55
#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
    P: KemParams,
{
    /// Wipe all secret key material before the memory is released.
    ///
    /// `d` must be cleared too: it is the seed half that `dk_pke` is derived from
    /// (`DecryptionKey::generate(&d)`), so leaving it behind would leak the whole key. The public
    /// `ek` is intentionally left alone.
    fn drop(&mut self) {
        self.dk_pke.zeroize();
        // `Option<Z: Zeroize>` zeroizes the inner value and resets the option to `None`.
        self.d.zeroize();
        self.z.zeroize();
    }
}
66
// Marker trait: `DecapsulationKey` has a feature-gated `Drop` impl that zeroizes key material.
#[cfg(feature = "zeroize")]
impl<P: KemParams> ZeroizeOnDrop for DecapsulationKey<P> {}
69
70impl<P> From<Seed> for DecapsulationKey<P>
71where
72    P: KemParams,
73{
74    fn from(seed: Seed) -> Self {
75        Self::from_seed(seed)
76    }
77}
78
79impl<P> Decapsulate for DecapsulationKey<P>
80where
81    P: KemParams,
82{
83    fn decapsulate(&self, encapsulated_key: &EncodedCiphertext<P>) -> SharedKey {
84        let mp = self.dk_pke.decrypt(encapsulated_key);
85        let (Kp, rp) = G(&[&mp, &self.ek.h]);
86        let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
87        let cp = self.ek.ek_pke.encrypt(&mp, &rp);
88        B32::conditional_select(&Kbar, &Kp, cp.ct_eq(encapsulated_key))
89    }
90}
91
92impl<P> Decapsulator for DecapsulationKey<P>
93where
94    P: KemParams,
95{
96    type Encapsulator = EncapsulationKey<P>;
97
98    fn encapsulator(&self) -> &EncapsulationKey<P> {
99        &self.ek
100    }
101}
102
103impl<P> EncodedSizeUser for DecapsulationKey<P>
104where
105    P: KemParams,
106{
107    type EncodedSize = DecapsulationKeySize<P>;
108
109    fn from_encoded_bytes(expanded: &Encoded<Self>) -> Result<Self, InvalidKey> {
110        #[allow(deprecated)]
111        Self::from_expanded(expanded)
112    }
113
114    fn to_encoded_bytes(&self) -> Encoded<Self> {
115        let dk_pke = self.dk_pke.to_bytes();
116        let ek = self.ek.to_encoded_bytes();
117        P::concat_dk(dk_pke, ek, self.ek.h.clone(), self.z.clone())
118    }
119}
120
121impl<P> Generate for DecapsulationKey<P>
122where
123    P: KemParams,
124{
125    fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRngCore>::Error>
126    where
127        R: TryCryptoRng + ?Sized,
128    {
129        Self::try_generate_from_rng(rng)
130    }
131}
132
133impl<P> KeySizeUser for DecapsulationKey<P>
134where
135    P: KemParams,
136{
137    type KeySize = U64;
138}
139
140impl<P> KeyInit for DecapsulationKey<P>
141where
142    P: KemParams,
143{
144    #[inline]
145    fn new(seed: &Seed) -> Self {
146        Self::from_seed(*seed)
147    }
148}
149
150impl<P> DecapsulationKey<P>
151where
152    P: KemParams,
153{
154    /// Create a [`DecapsulationKey`] instance from a 64-byte random seed value.
155    #[inline]
156    #[must_use]
157    pub fn from_seed(seed: Seed) -> Self {
158        let (d, z) = seed.split();
159        Self::generate_deterministic(d, z)
160    }
161
162    /// Initialize a [`DecapsulationKey`] from the serialized expanded key form.
163    ///
164    /// Note that this form is deprecated in practice; prefer to use
165    /// [`DecapsulationKey::from_seed`].
166    ///
167    /// # Errors
168    /// - Returns [`InvalidKey`] in the event the expanded key failed validation
169    #[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
170    pub fn from_expanded(enc: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
171        let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
172        let ek_pke = EncryptionKey::from_bytes(ek_pke)?;
173
174        // XXX(RLB): The encoding here is redundant, since `h` can be computed from `ek_pke`.
175        // Should we verify that the provided `h` value is valid?
176
177        Ok(Self {
178            dk_pke: DecryptionKey::from_bytes(dk_pke),
179            ek: EncapsulationKey {
180                ek_pke,
181                h: h.clone(),
182            },
183            d: None,
184            z: z.clone(),
185        })
186    }
187
188    /// Serialize the [`Seed`] value: 64-bytes which can be used to reconstruct the
189    /// [`DecapsulationKey`].
190    ///
191    /// <div class="warning">
192    /// <b>Warning!</B>
193    ///
194    /// This value is key material. Please treat it with care.
195    /// </div>
196    ///
197    /// # Returns
198    /// - `Some` if the [`DecapsulationKey`] was initialized using `from_seed` or `generate`.
199    /// - `None` if the [`DecapsulationKey`] was initialized from the expanded form.
200    #[inline]
201    pub fn to_seed(&self) -> Option<Seed> {
202        self.d.map(|d| d.concat(self.z))
203    }
204
205    /// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`].
206    pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
207        &self.ek
208    }
209
210    #[inline]
211    pub(crate) fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRngCore>::Error>
212    where
213        R: TryCryptoRng + ?Sized,
214    {
215        let d = B32::try_generate_from_rng(rng)?;
216        let z = B32::try_generate_from_rng(rng)?;
217        Ok(Self::generate_deterministic(d, z))
218    }
219
220    #[inline]
221    #[must_use]
222    #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
223    pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
224        let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
225        let ek = EncapsulationKey::new(ek_pke);
226        let d = Some(d);
227        Self { dk_pke, ek, d, z }
228    }
229}
230
231/// An `EncapsulationKey` provides the ability to encapsulate a shared key so that it can only be
232/// decapsulated by the holder of the corresponding decapsulation key.
233#[derive(Clone, Debug)]
234pub struct EncapsulationKey<P>
235where
236    P: KemParams,
237{
238    ek_pke: EncryptionKey<P>,
239    h: B32,
240}
241
242impl<P> EncapsulationKey<P>
243where
244    P: KemParams,
245{
246    pub(crate) fn new(ek_pke: EncryptionKey<P>) -> Self {
247        let h = H(ek_pke.to_bytes());
248        Self { ek_pke, h }
249    }
250
251    fn encapsulate_deterministic_inner(&self, m: &B32) -> (EncodedCiphertext<P>, SharedKey) {
252        let (K, r) = G(&[m, &self.h]);
253        let c = self.ek_pke.encrypt(m, &r);
254        (c, K)
255    }
256}
257
258impl<P> Encapsulate for EncapsulationKey<P>
259where
260    P: KemParams,
261{
262    fn encapsulate_with_rng<R: TryCryptoRng + ?Sized>(
263        &self,
264        rng: &mut R,
265    ) -> Result<(EncodedCiphertext<P>, SharedKey), R::Error> {
266        let m = B32::try_generate_from_rng(rng)?;
267        Ok(self.encapsulate_deterministic_inner(&m))
268    }
269}
270
271impl<P> EncodedSizeUser for EncapsulationKey<P>
272where
273    P: KemParams,
274{
275    type EncodedSize = EncapsulationKeySize<P>;
276
277    fn from_encoded_bytes(enc: &Encoded<Self>) -> Result<Self, InvalidKey> {
278        Ok(Self::new(EncryptionKey::from_bytes(enc)?))
279    }
280
281    fn to_encoded_bytes(&self) -> Encoded<Self> {
282        self.ek_pke.to_bytes()
283    }
284}
285
286impl<P> ::kem::KemParams for EncapsulationKey<P>
287where
288    P: KemParams,
289{
290    type CiphertextSize = P::CiphertextSize;
291    type SharedSecretSize = U32;
292}
293
294impl<P> KeyExport for EncapsulationKey<P>
295where
296    P: KemParams,
297{
298    fn to_bytes(&self) -> Key<Self> {
299        self.ek_pke.to_bytes()
300    }
301}
302
303impl<P> KeySizeUser for EncapsulationKey<P>
304where
305    P: KemParams,
306{
307    type KeySize = EncapsulationKeySize<P>;
308}
309
310impl<P> TryKeyInit for EncapsulationKey<P>
311where
312    P: KemParams,
313{
314    fn new(encapsulation_key: &Key<Self>) -> Result<Self, InvalidKey> {
315        EncryptionKey::from_bytes(encapsulation_key)
316            .map(Self::new)
317            .map_err(|_| InvalidKey)
318    }
319}
320
321impl<P> Eq for EncapsulationKey<P> where P: KemParams {}
322impl<P> PartialEq for EncapsulationKey<P>
323where
324    P: KemParams,
325{
326    fn eq(&self, other: &Self) -> bool {
327        // Handwritten to avoid derive putting `Eq` bounds on `KemParams`
328        self.ek_pke == other.ek_pke && self.h == other.h
329    }
330}
331
#[cfg(feature = "deterministic")]
impl<P: KemParams> crate::EncapsulateDeterministic<EncodedCiphertext<P>, SharedKey>
    for EncapsulationKey<P>
{
    type Error = Infallible;

    /// Encapsulate with a caller-chosen message `m` instead of fresh randomness.
    ///
    /// This never fails.
    fn encapsulate_deterministic(
        &self,
        m: &B32,
    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
        let output = self.encapsulate_deterministic_inner(m);
        Ok(output)
    }
}
346
347/// An implementation of overall ML-KEM functionality.  Generic over parameter sets, but then ties
348/// together all of the other related types and sizes.
349#[derive(Clone)]
350pub struct Kem<P>
351where
352    P: KemParams,
353{
354    _phantom: PhantomData<P>,
355}
356
357impl<P> KemCore for Kem<P>
358where
359    P: KemParams,
360{
361    type SharedKeySize = U32;
362    type CiphertextSize = P::CiphertextSize;
363    type DecapsulationKey = DecapsulationKey<P>;
364    type EncapsulationKey = EncapsulationKey<P>;
365
366    /// Generate a new (decapsulation, encapsulation) key pair
367    fn generate<R: CryptoRng + ?Sized>(
368        rng: &mut R,
369    ) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
370        let dk = Self::DecapsulationKey::generate_from_rng(rng);
371        let ek = dk.encapsulation_key().clone();
372        (dk, ek)
373    }
374
375    fn from_seed(seed: Seed) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
376        let dk = Self::DecapsulationKey::from_seed(seed);
377        let ek = dk.encapsulation_key().clone();
378        (dk, ek)
379    }
380}
381
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem512Params, MlKem768Params, MlKem1024Params};
    use ::kem::{Decapsulate, Encapsulate, Generate};
    use getrandom::SysRng;
    use rand_core::TryRngCore;

    /// Encapsulating to a fresh key pair and decapsulating the ciphertext must agree on the
    /// shared key.
    fn round_trip_test<P>()
    where
        P: KemParams,
    {
        let mut rng = SysRng.unwrap_err();
        let dk = DecapsulationKey::<P>::generate_from_rng(&mut rng);

        let (ciphertext, sent) = dk.encapsulation_key().encapsulate_with_rng(&mut rng).unwrap();
        let received = dk.decapsulate(&ciphertext);
        assert_eq!(sent, received);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// Encoding either key and decoding it again must reproduce an equal key.
    fn expanded_key_test<P>()
    where
        P: KemParams,
    {
        let mut rng = SysRng.unwrap_err();
        let dk = DecapsulationKey::<P>::generate_from_rng(&mut rng);
        let ek = dk.encapsulation_key().clone();

        let dk_roundtrip = DecapsulationKey::from_encoded_bytes(&dk.to_encoded_bytes()).unwrap();
        assert_eq!(dk, dk_roundtrip);

        let ek_roundtrip = EncapsulationKey::from_encoded_bytes(&ek.to_encoded_bytes()).unwrap();
        assert_eq!(ek, ek_roundtrip);
    }

    #[test]
    fn expanded_key() {
        expanded_key_test::<MlKem512Params>();
        expanded_key_test::<MlKem768Params>();
        expanded_key_test::<MlKem1024Params>();
    }

    /// A key built from a seed must hand back exactly that seed.
    fn seed_test<P>()
    where
        P: KemParams,
    {
        let mut seed = Seed::default();
        SysRng.unwrap_err().try_fill_bytes(&mut seed).unwrap();

        let dk = DecapsulationKey::<P>::from_seed(seed.clone());
        assert_eq!(seed, dk.to_seed().unwrap());
    }

    #[test]
    fn seed() {
        seed_test::<MlKem512Params>();
        seed_test::<MlKem768Params>();
        seed_test::<MlKem1024Params>();
    }
}