// ml_kem/kem.rs

1use core::convert::Infallible;
2use core::marker::PhantomData;
3use hybrid_array::typenum::{U32, U64};
4use rand_core::{CryptoRng, TryCryptoRng};
5use subtle::{ConditionallySelectable, ConstantTimeEq};
6
7use crate::crypto::{G, H, J, rand};
8use crate::param::{DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, KemParams};
9use crate::pke::{DecryptionKey, EncryptionKey};
10use crate::util::B32;
11use crate::{Encoded, EncodedSizeUser, Seed};
12
13#[cfg(feature = "zeroize")]
14use zeroize::{Zeroize, ZeroizeOnDrop};
15
16// Re-export traits from the `kem` crate
17pub use ::kem::{Decapsulate, Encapsulate};
18
/// A shared key resulting from an ML-KEM transaction: the 32-byte secret that encapsulation
/// produces and decapsulation recovers.
pub(crate) type SharedKey = B32;
21
22/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
23/// encapsulated shared key.
24#[derive(Clone, Debug)]
25pub struct DecapsulationKey<P>
26where
27    P: KemParams,
28{
29    dk_pke: DecryptionKey<P>,
30    ek: EncapsulationKey<P>,
31    d: Option<B32>,
32    z: B32,
33}
34
35// Handwritten to omit `d` in the comparisons, so keys initialized from seeds compare equally to
36// keys initialized from the expanded form
37impl<P> PartialEq for DecapsulationKey<P>
38where
39    P: KemParams,
40{
41    fn eq(&self, other: &Self) -> bool {
42        self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
43    }
44}
45
#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
    P: KemParams,
{
    /// Wipes every secret component of the key before its memory is released.
    fn drop(&mut self) {
        self.dk_pke.zeroize();
        // `d` is half of the seed returned by `to_seed` and is as sensitive as `z`, so it must be
        // wiped too. `Option<Z>` implements `Zeroize` whenever `Z` does, wiping the inner value.
        self.d.zeroize();
        self.z.zeroize();
    }
}

// Marker trait: dropping a `DecapsulationKey` zeroizes all of its secret material (`dk_pke`, `d`
// and `z`), as done by the `Drop` impl directly above.
#[cfg(feature = "zeroize")]
impl<P> ZeroizeOnDrop for DecapsulationKey<P> where P: KemParams {}
59
60impl<P> From<Seed> for DecapsulationKey<P>
61where
62    P: KemParams,
63{
64    fn from(seed: Seed) -> Self {
65        Self::from_seed(seed)
66    }
67}
68
69impl<P> EncodedSizeUser for DecapsulationKey<P>
70where
71    P: KemParams,
72{
73    type EncodedSize = DecapsulationKeySize<P>;
74
75    #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
76    fn from_bytes(enc: &Encoded<Self>) -> Self {
77        let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
78        let ek_pke = EncryptionKey::from_bytes(ek_pke);
79
80        // XXX(RLB): The encoding here is redundant, since `h` can be computed from `ek_pke`.
81        // Should we verify that the provided `h` value is valid?
82
83        Self {
84            dk_pke: DecryptionKey::from_bytes(dk_pke),
85            ek: EncapsulationKey {
86                ek_pke,
87                h: h.clone(),
88            },
89            d: None,
90            z: z.clone(),
91        }
92    }
93
94    fn as_bytes(&self) -> Encoded<Self> {
95        let dk_pke = self.dk_pke.as_bytes();
96        let ek = self.ek.as_bytes();
97        P::concat_dk(dk_pke, ek, self.ek.h.clone(), self.z.clone())
98    }
99}
100
101impl<P> ::kem::KeySizeUser for DecapsulationKey<P>
102where
103    P: KemParams,
104{
105    type KeySize = U64;
106}
107
108impl<P> ::kem::KeyInit for DecapsulationKey<P>
109where
110    P: KemParams,
111{
112    #[inline]
113    fn new(seed: &Seed) -> Self {
114        Self::from_seed(*seed)
115    }
116}
117
118impl<P> ::kem::Decapsulate<EncodedCiphertext<P>, SharedKey> for DecapsulationKey<P>
119where
120    P: KemParams,
121{
122    type Encapsulator = EncapsulationKey<P>;
123    type Error = Infallible;
124
125    fn decapsulate(
126        &self,
127        encapsulated_key: &EncodedCiphertext<P>,
128    ) -> Result<SharedKey, Self::Error> {
129        let mp = self.dk_pke.decrypt(encapsulated_key);
130        let (Kp, rp) = G(&[&mp, &self.ek.h]);
131        let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
132        let cp = self.ek.ek_pke.encrypt(&mp, &rp);
133        Ok(B32::conditional_select(
134            &Kbar,
135            &Kp,
136            cp.ct_eq(encapsulated_key),
137        ))
138    }
139
140    fn encapsulator(&self) -> EncapsulationKey<P> {
141        self.ek.clone()
142    }
143}
144
145impl<P> DecapsulationKey<P>
146where
147    P: KemParams,
148{
149    /// Create a [`DecapsulationKey`] instance from a 64-byte random seed value.
150    #[inline]
151    #[must_use]
152    pub fn from_seed(seed: Seed) -> Self {
153        let (d, z) = seed.split();
154        Self::generate_deterministic(d, z)
155    }
156
157    /// Serialize the [`Seed`] value: 64-bytes which can be used to reconstruct the
158    /// [`DecapsulationKey`].
159    ///
160    /// # ⚠️Warning!
161    ///
162    /// This value is key material. Please treat it with care.
163    ///
164    /// # Returns
165    /// - `Some` if the [`DecapsulationKey`] was initialized using `from_seed` or `generate`.
166    /// - `None` if the [`DecapsulationKey`] was initialized from the expanded form.
167    #[inline]
168    pub fn to_seed(&self) -> Option<Seed> {
169        self.d.map(|d| d.concat(self.z))
170    }
171
172    /// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`].
173    pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
174        &self.ek
175    }
176
177    #[inline]
178    pub(crate) fn generate<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
179        let d: B32 = rand(rng);
180        let z: B32 = rand(rng);
181        Self::generate_deterministic(d, z)
182    }
183
184    #[inline]
185    #[must_use]
186    #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
187    pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
188        let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
189        let ek = EncapsulationKey::new(ek_pke);
190        let d = Some(d);
191        Self { dk_pke, ek, d, z }
192    }
193}
194
195/// An `EncapsulationKey` provides the ability to encapsulate a shared key so that it can only be
196/// decapsulated by the holder of the corresponding decapsulation key.
197#[derive(Clone, Debug, PartialEq)]
198pub struct EncapsulationKey<P>
199where
200    P: KemParams,
201{
202    ek_pke: EncryptionKey<P>,
203    h: B32,
204}
205
206impl<P> EncapsulationKey<P>
207where
208    P: KemParams,
209{
210    pub(crate) fn new(ek_pke: EncryptionKey<P>) -> Self {
211        let h = H(ek_pke.as_bytes());
212        Self { ek_pke, h }
213    }
214
215    fn encapsulate_deterministic_inner(&self, m: &B32) -> (EncodedCiphertext<P>, SharedKey) {
216        let (K, r) = G(&[m, &self.h]);
217        let c = self.ek_pke.encrypt(m, &r);
218        (c, K)
219    }
220}
221
222impl<P> EncodedSizeUser for EncapsulationKey<P>
223where
224    P: KemParams,
225{
226    type EncodedSize = EncapsulationKeySize<P>;
227
228    fn from_bytes(enc: &Encoded<Self>) -> Self {
229        Self::new(EncryptionKey::from_bytes(enc))
230    }
231
232    fn as_bytes(&self) -> Encoded<Self> {
233        self.ek_pke.as_bytes()
234    }
235}
236
237impl<P> ::kem::Encapsulate<EncodedCiphertext<P>, SharedKey> for EncapsulationKey<P>
238where
239    P: KemParams,
240{
241    type Error = Infallible;
242
243    fn encapsulate<R: TryCryptoRng + ?Sized>(
244        &self,
245        rng: &mut R,
246    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
247        let m: B32 = rand(&mut rng.unwrap_mut());
248        Ok(self.encapsulate_deterministic_inner(&m))
249    }
250}
251
#[cfg(feature = "deterministic")]
impl<P> crate::EncapsulateDeterministic<EncodedCiphertext<P>, SharedKey> for EncapsulationKey<P>
where
    P: KemParams,
{
    type Error = Infallible;

    /// Encapsulate using the caller-supplied message `m` instead of fresh randomness.
    fn encapsulate_deterministic(
        &self,
        m: &B32,
    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
        let (ciphertext, shared_key) = self.encapsulate_deterministic_inner(m);
        Ok((ciphertext, shared_key))
    }
}
266
267/// An implementation of overall ML-KEM functionality.  Generic over parameter sets, but then ties
268/// together all of the other related types and sizes.
269#[derive(Clone)]
270pub struct Kem<P>
271where
272    P: KemParams,
273{
274    _phantom: PhantomData<P>,
275}
276
277impl<P> crate::KemCore for Kem<P>
278where
279    P: KemParams,
280{
281    type SharedKeySize = U32;
282    type CiphertextSize = P::CiphertextSize;
283    type DecapsulationKey = DecapsulationKey<P>;
284    type EncapsulationKey = EncapsulationKey<P>;
285
286    /// Generate a new (decapsulation, encapsulation) key pair
287    fn generate<R: CryptoRng + ?Sized>(
288        rng: &mut R,
289    ) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
290        let dk = Self::DecapsulationKey::generate(rng);
291        let ek = dk.encapsulation_key().clone();
292        (dk, ek)
293    }
294
295    fn from_seed(seed: Seed) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
296        let dk = Self::DecapsulationKey::from_seed(seed);
297        let ek = dk.encapsulation_key().clone();
298        (dk, ek)
299    }
300}
301
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem512Params, MlKem768Params, MlKem1024Params};
    use ::kem::{Decapsulate, Encapsulate};
    use rand_core::TryRngCore;

    /// Encapsulating to a freshly generated key and decapsulating must agree on the shared key.
    fn round_trip_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::rng();

        let dk = DecapsulationKey::<P>::generate(&mut rng);
        let ek = dk.encapsulation_key();

        let (ct, k_send) = ek.encapsulate(&mut rng).unwrap();
        let k_recv = dk.decapsulate(&ct).unwrap();
        assert_eq!(k_send, k_recv);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// A tampered ciphertext must be implicitly rejected. Decapsulation still succeeds, but it
    /// yields a key unrelated to the sender's, and the same key every time for the same input.
    fn implicit_rejection_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::rng();

        let dk = DecapsulationKey::<P>::generate(&mut rng);
        let (mut ct, k_send) = dk.encapsulation_key().encapsulate(&mut rng).unwrap();

        ct[0] ^= 0x01;
        let k_reject = dk.decapsulate(&ct).unwrap();
        assert_ne!(k_send, k_reject);
        assert_eq!(k_reject, dk.decapsulate(&ct).unwrap());
    }

    #[test]
    fn implicit_rejection() {
        implicit_rejection_test::<MlKem512Params>();
        implicit_rejection_test::<MlKem768Params>();
        implicit_rejection_test::<MlKem1024Params>();
    }

    /// Expanded encodings of both keys must round-trip. The decoded decapsulation key must stay
    /// fully functional even though it no longer knows its seed.
    fn expanded_key_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::rng();
        let dk_original = DecapsulationKey::<P>::generate(&mut rng);
        let ek_original = dk_original.encapsulation_key().clone();

        let dk_encoded = dk_original.as_bytes();
        let dk_decoded = DecapsulationKey::<P>::from_bytes(&dk_encoded);
        assert_eq!(dk_original, dk_decoded);
        // The expanded form does not carry `d`, so the seed cannot be recovered from it.
        assert!(dk_decoded.to_seed().is_none());

        let ek_encoded = ek_original.as_bytes();
        let ek_decoded = EncapsulationKey::<P>::from_bytes(&ek_encoded);
        assert_eq!(ek_original, ek_decoded);

        let (ct, k_send) = ek_decoded.encapsulate(&mut rng).unwrap();
        assert_eq!(k_send, dk_decoded.decapsulate(&ct).unwrap());
    }

    #[test]
    fn expanded_key() {
        expanded_key_test::<MlKem512Params>();
        expanded_key_test::<MlKem768Params>();
        expanded_key_test::<MlKem1024Params>();
    }

    /// A seed-derived key must hand back the exact seed, and re-deriving from it must be
    /// deterministic.
    fn seed_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::rng();
        let mut seed = Seed::default();
        rng.try_fill_bytes(&mut seed).unwrap();

        // `Seed` is `Copy`, so it remains usable after constructing the key.
        let dk = DecapsulationKey::<P>::from_seed(seed);
        let seed_encoded = dk.to_seed().unwrap();
        assert_eq!(seed, seed_encoded);

        assert_eq!(dk, DecapsulationKey::<P>::from_seed(seed));
    }

    #[test]
    fn seed() {
        seed_test::<MlKem512Params>();
        seed_test::<MlKem768Params>();
        seed_test::<MlKem1024Params>();
    }
}