1pub use ::kem::{
5 Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, Key, KeyExport, KeyInit,
6 KeySizeUser, TryKeyInit,
7};
8
9use crate::{
10 Encoded, EncodedSizeUser, KemCore, Seed,
11 crypto::{G, H, J},
12 param::{
13 DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, ExpandedDecapsulationKey,
14 KemParams,
15 },
16 pke::{DecryptionKey, EncryptionKey},
17 util::B32,
18};
19use array::typenum::{U32, U64};
20use core::marker::PhantomData;
21use rand_core::{CryptoRng, TryCryptoRng, TryRng};
22use subtle::{ConditionallySelectable, ConstantTimeEq};
23
24#[cfg(feature = "zeroize")]
25use zeroize::{Zeroize, ZeroizeOnDrop};
26
27pub(crate) type SharedKey = B32;
29
30#[derive(Clone, Debug)]
33pub struct DecapsulationKey<P>
34where
35 P: KemParams,
36{
37 dk_pke: DecryptionKey<P>,
38 ek: EncapsulationKey<P>,
39 d: Option<B32>,
40 z: B32,
41}
42
43impl<P> PartialEq for DecapsulationKey<P>
46where
47 P: KemParams,
48{
49 fn eq(&self, other: &Self) -> bool {
50 self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
51 }
52}
53
#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
    P: KemParams,
{
    /// Wipe all secret key material when the key is dropped.
    fn drop(&mut self) {
        self.dk_pke.zeroize();
        // `d` is the seed from which the entire key pair can be re-derived, so it is at least as
        // sensitive as `dk_pke`. Zeroizing the `Option` scrubs the inner bytes and leaves `None`.
        self.d.zeroize();
        self.z.zeroize();
    }
}
64
// Marker trait only: the actual wiping is performed by the `Drop` impl for `DecapsulationKey`.
#[cfg(feature = "zeroize")]
impl<P: KemParams> ZeroizeOnDrop for DecapsulationKey<P> {}
67
68impl<P> From<Seed> for DecapsulationKey<P>
69where
70 P: KemParams,
71{
72 fn from(seed: Seed) -> Self {
73 Self::from_seed(seed)
74 }
75}
76
77impl<P> Decapsulate for DecapsulationKey<P>
78where
79 P: KemParams,
80{
81 fn decapsulate(&self, encapsulated_key: &EncodedCiphertext<P>) -> SharedKey {
82 let mp = self.dk_pke.decrypt(encapsulated_key);
83 let (Kp, rp) = G(&[&mp, &self.ek.h]);
84 let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
85 let cp = self.ek.ek_pke.encrypt(&mp, &rp);
86 B32::conditional_select(&Kbar, &Kp, cp.ct_eq(encapsulated_key))
87 }
88}
89
90impl<P> Decapsulator for DecapsulationKey<P>
91where
92 P: KemParams,
93{
94 type Encapsulator = EncapsulationKey<P>;
95
96 fn encapsulator(&self) -> &EncapsulationKey<P> {
97 &self.ek
98 }
99}
100
101impl<P> EncodedSizeUser for DecapsulationKey<P>
102where
103 P: KemParams,
104{
105 type EncodedSize = DecapsulationKeySize<P>;
106
107 fn from_encoded_bytes(expanded: &Encoded<Self>) -> Result<Self, InvalidKey> {
108 #[allow(deprecated)]
109 Self::from_expanded(expanded)
110 }
111
112 fn to_encoded_bytes(&self) -> Encoded<Self> {
113 let dk_pke = self.dk_pke.to_bytes();
114 let ek = self.ek.to_encoded_bytes();
115 P::concat_dk(dk_pke, ek, self.ek.h.clone(), self.z.clone())
116 }
117}
118
119impl<P> Generate for DecapsulationKey<P>
120where
121 P: KemParams,
122{
123 fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
124 where
125 R: TryCryptoRng + ?Sized,
126 {
127 Self::try_generate_from_rng(rng)
128 }
129}
130
131impl<P> KeySizeUser for DecapsulationKey<P>
132where
133 P: KemParams,
134{
135 type KeySize = U64;
136}
137
138impl<P> KeyInit for DecapsulationKey<P>
139where
140 P: KemParams,
141{
142 #[inline]
143 fn new(seed: &Seed) -> Self {
144 Self::from_seed(*seed)
145 }
146}
147
148impl<P> DecapsulationKey<P>
149where
150 P: KemParams,
151{
152 #[inline]
154 #[must_use]
155 pub fn from_seed(seed: Seed) -> Self {
156 let (d, z) = seed.split();
157 Self::generate_deterministic(d, z)
158 }
159
160 #[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
168 pub fn from_expanded(enc: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
169 let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
170 let ek_pke = EncryptionKey::from_bytes(ek_pke)?;
171
172 Ok(Self {
176 dk_pke: DecryptionKey::from_bytes(dk_pke),
177 ek: EncapsulationKey {
178 ek_pke,
179 h: h.clone(),
180 },
181 d: None,
182 z: z.clone(),
183 })
184 }
185
186 #[inline]
199 pub fn to_seed(&self) -> Option<Seed> {
200 self.d.map(|d| d.concat(self.z))
201 }
202
203 pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
205 &self.ek
206 }
207
208 #[inline]
209 pub(crate) fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
210 where
211 R: TryCryptoRng + ?Sized,
212 {
213 let d = B32::try_generate_from_rng(rng)?;
214 let z = B32::try_generate_from_rng(rng)?;
215 Ok(Self::generate_deterministic(d, z))
216 }
217
218 #[inline]
219 #[must_use]
220 #[allow(clippy::similar_names)] pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
222 let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
223 let ek = EncapsulationKey::new(ek_pke);
224 let d = Some(d);
225 Self { dk_pke, ek, d, z }
226 }
227}
228
229#[derive(Clone, Debug)]
232pub struct EncapsulationKey<P>
233where
234 P: KemParams,
235{
236 ek_pke: EncryptionKey<P>,
237 h: B32,
238}
239
240impl<P> EncapsulationKey<P>
241where
242 P: KemParams,
243{
244 pub(crate) fn new(ek_pke: EncryptionKey<P>) -> Self {
245 let h = H(ek_pke.to_bytes());
246 Self { ek_pke, h }
247 }
248
249 #[cfg_attr(not(feature = "hazmat"), doc(hidden))]
255 pub fn encapsulate_deterministic(&self, m: &B32) -> (EncodedCiphertext<P>, SharedKey) {
256 let (K, r) = G(&[m, &self.h]);
257 let c = self.ek_pke.encrypt(m, &r);
258 (c, K)
259 }
260}
261
262impl<P> Encapsulate for EncapsulationKey<P>
263where
264 P: KemParams,
265{
266 fn encapsulate_with_rng<R>(&self, rng: &mut R) -> (EncodedCiphertext<P>, SharedKey)
267 where
268 R: CryptoRng + ?Sized,
269 {
270 let m = B32::generate_from_rng(rng);
271 self.encapsulate_deterministic(&m)
272 }
273}
274
275impl<P> EncodedSizeUser for EncapsulationKey<P>
276where
277 P: KemParams,
278{
279 type EncodedSize = EncapsulationKeySize<P>;
280
281 fn from_encoded_bytes(enc: &Encoded<Self>) -> Result<Self, InvalidKey> {
282 Ok(Self::new(EncryptionKey::from_bytes(enc)?))
283 }
284
285 fn to_encoded_bytes(&self) -> Encoded<Self> {
286 self.ek_pke.to_bytes()
287 }
288}
289
290impl<P> ::kem::KemParams for EncapsulationKey<P>
291where
292 P: KemParams,
293{
294 type CiphertextSize = P::CiphertextSize;
295 type SharedSecretSize = U32;
296}
297
298impl<P> KeyExport for EncapsulationKey<P>
299where
300 P: KemParams,
301{
302 fn to_bytes(&self) -> Key<Self> {
303 self.ek_pke.to_bytes()
304 }
305}
306
307impl<P> KeySizeUser for EncapsulationKey<P>
308where
309 P: KemParams,
310{
311 type KeySize = EncapsulationKeySize<P>;
312}
313
314impl<P> TryKeyInit for EncapsulationKey<P>
315where
316 P: KemParams,
317{
318 fn new(encapsulation_key: &Key<Self>) -> Result<Self, InvalidKey> {
319 EncryptionKey::from_bytes(encapsulation_key)
320 .map(Self::new)
321 .map_err(|_| InvalidKey)
322 }
323}
324
325impl<P> Eq for EncapsulationKey<P> where P: KemParams {}
326impl<P> PartialEq for EncapsulationKey<P>
327where
328 P: KemParams,
329{
330 fn eq(&self, other: &Self) -> bool {
331 self.ek_pke == other.ek_pke && self.h == other.h
333 }
334}
335
336#[derive(Clone)]
339pub struct Kem<P>
340where
341 P: KemParams,
342{
343 _phantom: PhantomData<P>,
344}
345
346impl<P> KemCore for Kem<P>
347where
348 P: KemParams,
349{
350 type SharedKeySize = U32;
351 type CiphertextSize = P::CiphertextSize;
352 type DecapsulationKey = DecapsulationKey<P>;
353 type EncapsulationKey = EncapsulationKey<P>;
354
355 fn generate<R: CryptoRng + ?Sized>(
357 rng: &mut R,
358 ) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
359 let dk = Self::DecapsulationKey::generate_from_rng(rng);
360 let ek = dk.encapsulation_key().clone();
361 (dk, ek)
362 }
363
364 fn from_seed(seed: Seed) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
365 let dk = Self::DecapsulationKey::from_seed(seed);
366 let ek = dk.encapsulation_key().clone();
367 (dk, ek)
368 }
369}
370
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem512Params, MlKem768Params, MlKem1024Params};
    use ::kem::{Decapsulate, Encapsulate, Generate};
    use getrandom::SysRng;
    use rand_core::UnwrapErr;

    /// Encapsulating to a fresh key and decapsulating must yield the same shared key.
    fn round_trip_test<P>()
    where
        P: KemParams,
    {
        let mut rng = UnwrapErr(SysRng);

        let dk = DecapsulationKey::<P>::generate_from_rng(&mut rng);
        let ek = dk.encapsulation_key();

        let (ct, k_send) = ek.encapsulate_with_rng(&mut rng);
        let k_recv = dk.decapsulate(&ct);
        assert_eq!(k_send, k_recv);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// A tampered ciphertext must not decapsulate to the sender's key (implicit rejection),
    /// and the rejection key must be the same each time that ciphertext is decapsulated.
    fn implicit_rejection_test<P>()
    where
        P: KemParams,
    {
        let mut rng = UnwrapErr(SysRng);

        let dk = DecapsulationKey::<P>::generate_from_rng(&mut rng);
        let (mut ct, k_send) = dk.encapsulation_key().encapsulate_with_rng(&mut rng);
        ct[0] ^= 1;

        let k_reject = dk.decapsulate(&ct);
        assert_ne!(k_send, k_reject);
        assert_eq!(k_reject, dk.decapsulate(&ct));
    }

    #[test]
    fn implicit_rejection() {
        implicit_rejection_test::<MlKem512Params>();
        implicit_rejection_test::<MlKem768Params>();
        implicit_rejection_test::<MlKem1024Params>();
    }

    /// Expanded encodings of both key types must round-trip.
    fn expanded_key_test<P>()
    where
        P: KemParams,
    {
        let mut rng = UnwrapErr(SysRng);
        let dk_original = DecapsulationKey::<P>::generate_from_rng(&mut rng);
        let ek_original = dk_original.encapsulation_key().clone();

        let dk_encoded = dk_original.to_encoded_bytes();
        let dk_decoded = DecapsulationKey::from_encoded_bytes(&dk_encoded).unwrap();
        assert_eq!(dk_original, dk_decoded);

        let ek_encoded = ek_original.to_encoded_bytes();
        let ek_decoded = EncapsulationKey::from_encoded_bytes(&ek_encoded).unwrap();
        assert_eq!(ek_original, ek_decoded);
    }

    #[test]
    fn expanded_key() {
        expanded_key_test::<MlKem512Params>();
        expanded_key_test::<MlKem768Params>();
        expanded_key_test::<MlKem1024Params>();
    }

    /// A seed-derived key must return exactly the seed it was built from.
    fn seed_test<P>()
    where
        P: KemParams,
    {
        let mut rng = UnwrapErr(SysRng);
        let mut seed = Seed::default();
        rng.try_fill_bytes(&mut seed).unwrap();

        // `Seed` is `Copy`, so passing it by value leaves `seed` usable for the comparison.
        let dk = DecapsulationKey::<P>::from_seed(seed);
        let seed_encoded = dk.to_seed().unwrap();
        assert_eq!(seed, seed_encoded);
    }

    #[test]
    fn seed() {
        seed_test::<MlKem512Params>();
        seed_test::<MlKem768Params>();
        seed_test::<MlKem1024Params>();
    }
}