ml_kem/
kem.rs

1//! Key encapsulation mechanism implementation.
2
3// Re-export traits from the `kem` crate
4pub use ::kem::{
5    Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, Key, KeyExport, KeyInit,
6    KeySizeUser, TryKeyInit,
7};
8
9use crate::{
10    Encoded, EncodedSizeUser, KemCore, Seed,
11    crypto::{G, H, J},
12    param::{
13        DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, ExpandedDecapsulationKey,
14        KemParams,
15    },
16    pke::{DecryptionKey, EncryptionKey},
17    util::B32,
18};
19use array::typenum::{U32, U64};
20use core::marker::PhantomData;
21use rand_core::{CryptoRng, TryCryptoRng, TryRng};
22use subtle::{ConditionallySelectable, ConstantTimeEq};
23
24#[cfg(feature = "zeroize")]
25use zeroize::{Zeroize, ZeroizeOnDrop};
26
27/// A shared key resulting from an ML-KEM transaction
28pub(crate) type SharedKey = B32;
29
30/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
31/// encapsulated shared key.
32#[derive(Clone, Debug)]
33pub struct DecapsulationKey<P>
34where
35    P: KemParams,
36{
37    dk_pke: DecryptionKey<P>,
38    ek: EncapsulationKey<P>,
39    d: Option<B32>,
40    z: B32,
41}
42
43// Handwritten to omit `d` in the comparisons, so keys initialized from seeds compare equally to
44// keys initialized from the expanded form
45impl<P> PartialEq for DecapsulationKey<P>
46where
47    P: KemParams,
48{
49    fn eq(&self, other: &Self) -> bool {
50        self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
51    }
52}
53
#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
    P: KemParams,
{
    /// Wipes every secret field: the K-PKE decryption key, the seed half `d` (when present),
    /// and `z`.
    fn drop(&mut self) {
        self.dk_pke.zeroize();
        // `d` together with `z` is the full seed that `from_seed` expands into this key, so it
        // must be wiped as well; otherwise `ZeroizeOnDrop` would be a false promise.
        self.d.zeroize();
        self.z.zeroize();
    }
}

#[cfg(feature = "zeroize")]
impl<P> ZeroizeOnDrop for DecapsulationKey<P> where P: KemParams {}
67
68impl<P> From<Seed> for DecapsulationKey<P>
69where
70    P: KemParams,
71{
72    fn from(seed: Seed) -> Self {
73        Self::from_seed(seed)
74    }
75}
76
77impl<P> Decapsulate for DecapsulationKey<P>
78where
79    P: KemParams,
80{
81    fn decapsulate(&self, encapsulated_key: &EncodedCiphertext<P>) -> SharedKey {
82        let mp = self.dk_pke.decrypt(encapsulated_key);
83        let (Kp, rp) = G(&[&mp, &self.ek.h]);
84        let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
85        let cp = self.ek.ek_pke.encrypt(&mp, &rp);
86        B32::conditional_select(&Kbar, &Kp, cp.ct_eq(encapsulated_key))
87    }
88}
89
90impl<P> Decapsulator for DecapsulationKey<P>
91where
92    P: KemParams,
93{
94    type Encapsulator = EncapsulationKey<P>;
95
96    fn encapsulator(&self) -> &EncapsulationKey<P> {
97        &self.ek
98    }
99}
100
101impl<P> EncodedSizeUser for DecapsulationKey<P>
102where
103    P: KemParams,
104{
105    type EncodedSize = DecapsulationKeySize<P>;
106
107    fn from_encoded_bytes(expanded: &Encoded<Self>) -> Result<Self, InvalidKey> {
108        #[allow(deprecated)]
109        Self::from_expanded(expanded)
110    }
111
112    fn to_encoded_bytes(&self) -> Encoded<Self> {
113        let dk_pke = self.dk_pke.to_bytes();
114        let ek = self.ek.to_encoded_bytes();
115        P::concat_dk(dk_pke, ek, self.ek.h.clone(), self.z.clone())
116    }
117}
118
119impl<P> Generate for DecapsulationKey<P>
120where
121    P: KemParams,
122{
123    fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
124    where
125        R: TryCryptoRng + ?Sized,
126    {
127        Self::try_generate_from_rng(rng)
128    }
129}
130
131impl<P> KeySizeUser for DecapsulationKey<P>
132where
133    P: KemParams,
134{
135    type KeySize = U64;
136}
137
138impl<P> KeyInit for DecapsulationKey<P>
139where
140    P: KemParams,
141{
142    #[inline]
143    fn new(seed: &Seed) -> Self {
144        Self::from_seed(*seed)
145    }
146}
147
148impl<P> DecapsulationKey<P>
149where
150    P: KemParams,
151{
152    /// Create a [`DecapsulationKey`] instance from a 64-byte random seed value.
153    #[inline]
154    #[must_use]
155    pub fn from_seed(seed: Seed) -> Self {
156        let (d, z) = seed.split();
157        Self::generate_deterministic(d, z)
158    }
159
160    /// Initialize a [`DecapsulationKey`] from the serialized expanded key form.
161    ///
162    /// Note that this form is deprecated in practice; prefer to use
163    /// [`DecapsulationKey::from_seed`].
164    ///
165    /// # Errors
166    /// - Returns [`InvalidKey`] in the event the expanded key failed validation
167    #[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
168    pub fn from_expanded(enc: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
169        let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
170        let ek_pke = EncryptionKey::from_bytes(ek_pke)?;
171
172        // XXX(RLB): The encoding here is redundant, since `h` can be computed from `ek_pke`.
173        // Should we verify that the provided `h` value is valid?
174
175        Ok(Self {
176            dk_pke: DecryptionKey::from_bytes(dk_pke),
177            ek: EncapsulationKey {
178                ek_pke,
179                h: h.clone(),
180            },
181            d: None,
182            z: z.clone(),
183        })
184    }
185
186    /// Serialize the [`Seed`] value: 64-bytes which can be used to reconstruct the
187    /// [`DecapsulationKey`].
188    ///
189    /// <div class="warning">
190    /// <b>Warning!</B>
191    ///
192    /// This value is key material. Please treat it with care.
193    /// </div>
194    ///
195    /// # Returns
196    /// - `Some` if the [`DecapsulationKey`] was initialized using `from_seed` or `generate`.
197    /// - `None` if the [`DecapsulationKey`] was initialized from the expanded form.
198    #[inline]
199    pub fn to_seed(&self) -> Option<Seed> {
200        self.d.map(|d| d.concat(self.z))
201    }
202
203    /// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`].
204    pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
205        &self.ek
206    }
207
208    #[inline]
209    pub(crate) fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
210    where
211        R: TryCryptoRng + ?Sized,
212    {
213        let d = B32::try_generate_from_rng(rng)?;
214        let z = B32::try_generate_from_rng(rng)?;
215        Ok(Self::generate_deterministic(d, z))
216    }
217
218    #[inline]
219    #[must_use]
220    #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
221    pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
222        let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
223        let ek = EncapsulationKey::new(ek_pke);
224        let d = Some(d);
225        Self { dk_pke, ek, d, z }
226    }
227}
228
229/// An `EncapsulationKey` provides the ability to encapsulate a shared key so that it can only be
230/// decapsulated by the holder of the corresponding decapsulation key.
231#[derive(Clone, Debug)]
232pub struct EncapsulationKey<P>
233where
234    P: KemParams,
235{
236    ek_pke: EncryptionKey<P>,
237    h: B32,
238}
239
240impl<P> EncapsulationKey<P>
241where
242    P: KemParams,
243{
244    pub(crate) fn new(ek_pke: EncryptionKey<P>) -> Self {
245        let h = H(ek_pke.to_bytes());
246        Self { ek_pke, h }
247    }
248
249    /// Encapsulates with the given randomness. This is useful for testing against known vectors.
250    ///
251    /// # Warning
252    /// Do NOT use this function unless you know what you're doing. If you fail to use all uniform
253    /// random bytes even once, you can have catastrophic security failure.
254    #[cfg_attr(not(feature = "hazmat"), doc(hidden))]
255    pub fn encapsulate_deterministic(&self, m: &B32) -> (EncodedCiphertext<P>, SharedKey) {
256        let (K, r) = G(&[m, &self.h]);
257        let c = self.ek_pke.encrypt(m, &r);
258        (c, K)
259    }
260}
261
262impl<P> Encapsulate for EncapsulationKey<P>
263where
264    P: KemParams,
265{
266    fn encapsulate_with_rng<R>(&self, rng: &mut R) -> (EncodedCiphertext<P>, SharedKey)
267    where
268        R: CryptoRng + ?Sized,
269    {
270        let m = B32::generate_from_rng(rng);
271        self.encapsulate_deterministic(&m)
272    }
273}
274
275impl<P> EncodedSizeUser for EncapsulationKey<P>
276where
277    P: KemParams,
278{
279    type EncodedSize = EncapsulationKeySize<P>;
280
281    fn from_encoded_bytes(enc: &Encoded<Self>) -> Result<Self, InvalidKey> {
282        Ok(Self::new(EncryptionKey::from_bytes(enc)?))
283    }
284
285    fn to_encoded_bytes(&self) -> Encoded<Self> {
286        self.ek_pke.to_bytes()
287    }
288}
289
290impl<P> ::kem::KemParams for EncapsulationKey<P>
291where
292    P: KemParams,
293{
294    type CiphertextSize = P::CiphertextSize;
295    type SharedSecretSize = U32;
296}
297
298impl<P> KeyExport for EncapsulationKey<P>
299where
300    P: KemParams,
301{
302    fn to_bytes(&self) -> Key<Self> {
303        self.ek_pke.to_bytes()
304    }
305}
306
307impl<P> KeySizeUser for EncapsulationKey<P>
308where
309    P: KemParams,
310{
311    type KeySize = EncapsulationKeySize<P>;
312}
313
314impl<P> TryKeyInit for EncapsulationKey<P>
315where
316    P: KemParams,
317{
318    fn new(encapsulation_key: &Key<Self>) -> Result<Self, InvalidKey> {
319        EncryptionKey::from_bytes(encapsulation_key)
320            .map(Self::new)
321            .map_err(|_| InvalidKey)
322    }
323}
324
325impl<P> Eq for EncapsulationKey<P> where P: KemParams {}
326impl<P> PartialEq for EncapsulationKey<P>
327where
328    P: KemParams,
329{
330    fn eq(&self, other: &Self) -> bool {
331        // Handwritten to avoid derive putting `Eq` bounds on `KemParams`
332        self.ek_pke == other.ek_pke && self.h == other.h
333    }
334}
335
336/// An implementation of overall ML-KEM functionality.  Generic over parameter sets, but then ties
337/// together all of the other related types and sizes.
338#[derive(Clone)]
339pub struct Kem<P>
340where
341    P: KemParams,
342{
343    _phantom: PhantomData<P>,
344}
345
346impl<P> KemCore for Kem<P>
347where
348    P: KemParams,
349{
350    type SharedKeySize = U32;
351    type CiphertextSize = P::CiphertextSize;
352    type DecapsulationKey = DecapsulationKey<P>;
353    type EncapsulationKey = EncapsulationKey<P>;
354
355    /// Generate a new (decapsulation, encapsulation) key pair
356    fn generate<R: CryptoRng + ?Sized>(
357        rng: &mut R,
358    ) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
359        let dk = Self::DecapsulationKey::generate_from_rng(rng);
360        let ek = dk.encapsulation_key().clone();
361        (dk, ek)
362    }
363
364    fn from_seed(seed: Seed) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
365        let dk = Self::DecapsulationKey::from_seed(seed);
366        let ek = dk.encapsulation_key().clone();
367        (dk, ek)
368    }
369}
370
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem512Params, MlKem768Params, MlKem1024Params};
    use ::kem::{Decapsulate, Encapsulate, Generate};
    use getrandom::SysRng;
    use rand_core::UnwrapErr;

    /// Encapsulating to a freshly generated key and decapsulating must agree on the shared key.
    fn round_trip_test<P>()
    where
        P: KemParams,
    {
        let mut rng = UnwrapErr(SysRng);

        let dk = DecapsulationKey::<P>::generate_from_rng(&mut rng);
        let ek = dk.encapsulation_key();

        let (ct, k_send) = ek.encapsulate_with_rng(&mut rng);
        let k_recv = dk.decapsulate(&ct);
        assert_eq!(k_send, k_recv);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// Expanded-form and encapsulation-key encodings must round-trip. A key decoded from the
    /// expanded form must report no seed and still decapsulate correctly.
    fn expanded_key_test<P>()
    where
        P: KemParams,
    {
        let mut rng = UnwrapErr(SysRng);
        let dk_original = DecapsulationKey::<P>::generate_from_rng(&mut rng);
        let ek_original = dk_original.encapsulation_key().clone();

        let dk_encoded = dk_original.to_encoded_bytes();
        let dk_decoded = DecapsulationKey::from_encoded_bytes(&dk_encoded).unwrap();
        assert_eq!(dk_original, dk_decoded);

        // The expanded form carries no `d`, so the seed cannot be recovered.
        assert!(dk_decoded.to_seed().is_none());

        // The decoded key must still recover shared keys encapsulated to the original key.
        let (ct, k_send) = ek_original.encapsulate_with_rng(&mut rng);
        assert_eq!(k_send, dk_decoded.decapsulate(&ct));

        let ek_encoded = ek_original.to_encoded_bytes();
        let ek_decoded = EncapsulationKey::from_encoded_bytes(&ek_encoded).unwrap();
        assert_eq!(ek_original, ek_decoded);
    }

    #[test]
    fn expanded_key() {
        expanded_key_test::<MlKem512Params>();
        expanded_key_test::<MlKem768Params>();
        expanded_key_test::<MlKem1024Params>();
    }

    /// A key built from a seed must return that same seed and be deterministic in it.
    fn seed_test<P>()
    where
        P: KemParams,
    {
        let mut rng = UnwrapErr(SysRng);
        let mut seed = Seed::default();
        rng.try_fill_bytes(&mut seed).unwrap();

        // `Seed` is `Copy`, so it can be reused below without cloning.
        let dk = DecapsulationKey::<P>::from_seed(seed);
        assert_eq!(dk.to_seed(), Some(seed));
        assert_eq!(dk, DecapsulationKey::<P>::from_seed(seed));
    }

    #[test]
    fn seed() {
        seed_test::<MlKem512Params>();
        seed_test::<MlKem768Params>();
        seed_test::<MlKem1024Params>();
    }
}