ml_kem/
kem.rs

1use core::marker::PhantomData;
2use hybrid_array::typenum::U32;
3use rand_core::CryptoRngCore;
4
5use crate::crypto::{rand, G, H, J};
6use crate::param::{DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, KemParams};
7use crate::pke::{DecryptionKey, EncryptionKey};
8use crate::util::B32;
9use crate::{Encoded, EncodedSizeUser};
10
11#[cfg(feature = "zeroize")]
12use zeroize::{Zeroize, ZeroizeOnDrop};
13
14// Re-export traits from the `kem` crate
15pub use ::kem::{Decapsulate, Encapsulate};
16
17/// A shared key resulting from an ML-KEM transaction
18pub(crate) type SharedKey = B32;
19
20/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
21/// encapsulated shared key.
22#[derive(Clone, Debug, PartialEq)]
23pub struct DecapsulationKey<P>
24where
25    P: KemParams,
26{
27    dk_pke: DecryptionKey<P>,
28    ek: EncapsulationKey<P>,
29    z: B32,
30}
31
/// Wipes the secret components when a `DecapsulationKey` goes out of scope.
///
/// The embedded encapsulation key is public and is left as-is.
#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
    P: KemParams,
{
    fn drop(&mut self) {
        // Bind the secret fields by mutable reference; nothing is moved out of `self`.
        let Self { dk_pke, z, .. } = self;
        dk_pke.zeroize();
        z.zeroize();
    }
}
42
/// Marker asserting that dropping a [`DecapsulationKey`] zeroizes its secret material, which
/// is done by its `Drop` implementation.
#[cfg(feature = "zeroize")]
impl<P> ZeroizeOnDrop for DecapsulationKey<P>
where
    P: KemParams,
{
}
45
46impl<P> EncodedSizeUser for DecapsulationKey<P>
47where
48    P: KemParams,
49{
50    type EncodedSize = DecapsulationKeySize<P>;
51
52    #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
53    fn from_bytes(enc: &Encoded<Self>) -> Self {
54        let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
55        let ek_pke = EncryptionKey::from_bytes(ek_pke);
56
57        // XXX(RLB): The encoding here is redundant, since `h` can be computed from `ek_pke`.
58        // Should we verify that the provided `h` value is valid?
59
60        Self {
61            dk_pke: DecryptionKey::from_bytes(dk_pke),
62            ek: EncapsulationKey {
63                ek_pke,
64                h: h.clone(),
65            },
66            z: z.clone(),
67        }
68    }
69
70    fn as_bytes(&self) -> Encoded<Self> {
71        let dk_pke = self.dk_pke.as_bytes();
72        let ek = self.ek.as_bytes();
73        P::concat_dk(dk_pke, ek, self.ek.h.clone(), self.z.clone())
74    }
75}
76
/// Branch-free byte comparison: returns `0xff` when `x == y` and `0x00` otherwise.
///
/// `d = x ^ y` is zero exactly when the bytes match. For `d == 0`, both `!d` and the wrapping
/// `d - 1` are `0xff`, so bit 7 of their AND is set. For any nonzero `d`, at least one of the
/// two has bit 7 clear. Shifting that bit down to 0/1 and negating it spreads it into a full
/// byte mask.
fn constant_time_eq(x: u8, y: u8) -> u8 {
    let d = x ^ y;
    let top_bit = (!d & d.wrapping_sub(1)) >> 7;
    top_bit.wrapping_neg()
}
83
84impl<P> ::kem::Decapsulate<EncodedCiphertext<P>, SharedKey> for DecapsulationKey<P>
85where
86    P: KemParams,
87{
88    // Decapsulation is infallible
89    // XXX(RLB): Maybe we should reflect decryption failure as an error?
90    // TODO(RLB) Make Infallible
91    type Error = ();
92
93    fn decapsulate(&self, encapsulated_key: &EncodedCiphertext<P>) -> Result<SharedKey, ()> {
94        let mp = self.dk_pke.decrypt(encapsulated_key);
95        let (Kp, rp) = G(&[&mp, &self.ek.h]);
96        let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
97        let cp = self.ek.ek_pke.encrypt(&mp, &rp);
98
99        // Constant-time version of:
100        //
101        // if cp == *ct {
102        //     Kp
103        // } else {
104        //     Kbar
105        // }
106        let equal = cp
107            .iter()
108            .zip(encapsulated_key.iter())
109            .map(|(&x, &y)| constant_time_eq(x, y))
110            .fold(0xff, |x, y| x & y);
111        Ok(Kp
112            .iter()
113            .zip(Kbar.iter())
114            .map(|(x, y)| (equal & x) | (!equal & y))
115            .collect())
116    }
117}
118
119impl<P> DecapsulationKey<P>
120where
121    P: KemParams,
122{
123    /// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`].
124    pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
125        &self.ek
126    }
127
128    pub(crate) fn generate(rng: &mut impl CryptoRngCore) -> Self {
129        let d: B32 = rand(rng);
130        let z: B32 = rand(rng);
131        Self::generate_deterministic(&d, &z)
132    }
133
134    #[must_use]
135    #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
136    pub(crate) fn generate_deterministic(d: &B32, z: &B32) -> Self {
137        let (dk_pke, ek_pke) = DecryptionKey::generate(d);
138        let ek = EncapsulationKey::new(ek_pke);
139        let z = z.clone();
140        Self { dk_pke, ek, z }
141    }
142}
143
144/// An `EncapsulationKey` provides the ability to encapsulate a shared key so that it can only be
145/// decapsulated by the holder of the corresponding decapsulation key.
146#[derive(Clone, Debug, PartialEq)]
147pub struct EncapsulationKey<P>
148where
149    P: KemParams,
150{
151    ek_pke: EncryptionKey<P>,
152    h: B32,
153}
154
155impl<P> EncapsulationKey<P>
156where
157    P: KemParams,
158{
159    fn new(ek_pke: EncryptionKey<P>) -> Self {
160        let h = H(ek_pke.as_bytes());
161        Self { ek_pke, h }
162    }
163
164    fn encapsulate_deterministic_inner(&self, m: &B32) -> (EncodedCiphertext<P>, SharedKey) {
165        let (K, r) = G(&[m, &self.h]);
166        let c = self.ek_pke.encrypt(m, &r);
167        (c, K)
168    }
169}
170
171impl<P> EncodedSizeUser for EncapsulationKey<P>
172where
173    P: KemParams,
174{
175    type EncodedSize = EncapsulationKeySize<P>;
176
177    fn from_bytes(enc: &Encoded<Self>) -> Self {
178        Self::new(EncryptionKey::from_bytes(enc))
179    }
180
181    fn as_bytes(&self) -> Encoded<Self> {
182        self.ek_pke.as_bytes()
183    }
184}
185
186impl<P> ::kem::Encapsulate<EncodedCiphertext<P>, SharedKey> for EncapsulationKey<P>
187where
188    P: KemParams,
189{
190    // TODO(RLB) Make Infallible
191    // TODO(RLB) Swap the order of the
192    type Error = ();
193
194    fn encapsulate(
195        &self,
196        rng: &mut impl CryptoRngCore,
197    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
198        let m: B32 = rand(rng);
199        Ok(self.encapsulate_deterministic_inner(&m))
200    }
201}
202
#[cfg(feature = "deterministic")]
impl<P> crate::EncapsulateDeterministic<EncodedCiphertext<P>, SharedKey> for EncapsulationKey<P>
where
    P: KemParams,
{
    // Always succeeds; `()` is a placeholder error type.
    // TODO(RLB) Make Infallible
    type Error = ();

    /// Encapsulate using the caller-supplied message `m` instead of fresh randomness.
    ///
    /// The result depends only on this key and `m`.
    fn encapsulate_deterministic(
        &self,
        m: &B32,
    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
        let ciphertext_and_key = self.encapsulate_deterministic_inner(m);
        Ok(ciphertext_and_key)
    }
}
218
219/// An implementation of overall ML-KEM functionality.  Generic over parameter sets, but then ties
220/// together all of the other related types and sizes.
221pub struct Kem<P>
222where
223    P: KemParams,
224{
225    _phantom: PhantomData<P>,
226}
227
228impl<P> crate::KemCore for Kem<P>
229where
230    P: KemParams,
231{
232    type SharedKeySize = U32;
233    type CiphertextSize = P::CiphertextSize;
234    type DecapsulationKey = DecapsulationKey<P>;
235    type EncapsulationKey = EncapsulationKey<P>;
236
237    /// Generate a new (decapsulation, encapsulation) key pair
238    fn generate(rng: &mut impl CryptoRngCore) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
239        let dk = Self::DecapsulationKey::generate(rng);
240        let ek = dk.encapsulation_key().clone();
241        (dk, ek)
242    }
243
244    #[cfg(feature = "deterministic")]
245    fn generate_deterministic(
246        d: &B32,
247        z: &B32,
248    ) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
249        let dk = Self::DecapsulationKey::generate_deterministic(d, z);
250        let ek = dk.encapsulation_key().clone();
251        (dk, ek)
252    }
253}
254
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem1024Params, MlKem512Params, MlKem768Params};
    use ::kem::{Decapsulate, Encapsulate};

    /// Encapsulate against a fresh key pair and check decapsulation recovers the same key.
    fn round_trip_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::thread_rng();

        let dk = DecapsulationKey::<P>::generate(&mut rng);
        let (ct, k_send) = dk.encapsulation_key().encapsulate(&mut rng).unwrap();
        let k_recv = dk.decapsulate(&ct).unwrap();

        assert_eq!(k_send, k_recv);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// Encode and decode both halves of a fresh key pair and check nothing is lost.
    fn codec_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::thread_rng();
        let dk = DecapsulationKey::<P>::generate(&mut rng);
        let ek = dk.encapsulation_key().clone();

        assert_eq!(dk, DecapsulationKey::<P>::from_bytes(&dk.as_bytes()));
        assert_eq!(ek, EncapsulationKey::<P>::from_bytes(&ek.as_bytes()));
    }

    #[test]
    fn codec() {
        codec_test::<MlKem512Params>();
        codec_test::<MlKem768Params>();
        codec_test::<MlKem1024Params>();
    }
}