1pub use ::kem::{
5 Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, Key, KeyExport, KeyInit,
6 KeySizeUser, TryKeyInit,
7};
8
9use crate::{
10 Encoded, EncodedSizeUser, KemCore, Seed,
11 crypto::{G, H, J},
12 param::{
13 DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, ExpandedDecapsulationKey,
14 KemParams,
15 },
16 pke::{DecryptionKey, EncryptionKey},
17 util::B32,
18};
19use array::typenum::{U32, U64};
20use core::marker::PhantomData;
21use rand_core::{CryptoRng, TryCryptoRng, TryRngCore};
22use subtle::{ConditionallySelectable, ConstantTimeEq};
23
24#[cfg(feature = "deterministic")]
25use core::convert::Infallible;
26#[cfg(feature = "zeroize")]
27use zeroize::{Zeroize, ZeroizeOnDrop};
28
/// The 32-byte shared key produced by encapsulation and recovered by decapsulation.
pub(crate) type SharedKey = B32;
31
32#[derive(Clone, Debug)]
35pub struct DecapsulationKey<P>
36where
37 P: KemParams,
38{
39 dk_pke: DecryptionKey<P>,
40 ek: EncapsulationKey<P>,
41 d: Option<B32>,
42 z: B32,
43}
44
45impl<P> PartialEq for DecapsulationKey<P>
48where
49 P: KemParams,
50{
51 fn eq(&self, other: &Self) -> bool {
52 self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
53 }
54}
55
/// Wipes the secret components of the key when it is dropped.
///
/// The public encapsulation key `ek` is left untouched.
#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
    P: KemParams,
{
    fn drop(&mut self) {
        self.dk_pke.zeroize();
        // `d` is the first half of the seed `(d, z)`: together with `z` it re-derives the entire
        // key pair, so it must be wiped along with the rest of the secret material.
        self.d.zeroize();
        self.z.zeroize();
    }
}
66
/// Marker asserting that `DecapsulationKey` wipes its secret material in its `Drop` impl.
#[cfg(feature = "zeroize")]
impl<P: KemParams> ZeroizeOnDrop for DecapsulationKey<P> {}
69
70impl<P> From<Seed> for DecapsulationKey<P>
71where
72 P: KemParams,
73{
74 fn from(seed: Seed) -> Self {
75 Self::from_seed(seed)
76 }
77}
78
79impl<P> Decapsulate for DecapsulationKey<P>
80where
81 P: KemParams,
82{
83 fn decapsulate(&self, encapsulated_key: &EncodedCiphertext<P>) -> SharedKey {
84 let mp = self.dk_pke.decrypt(encapsulated_key);
85 let (Kp, rp) = G(&[&mp, &self.ek.h]);
86 let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
87 let cp = self.ek.ek_pke.encrypt(&mp, &rp);
88 B32::conditional_select(&Kbar, &Kp, cp.ct_eq(encapsulated_key))
89 }
90}
91
92impl<P> Decapsulator for DecapsulationKey<P>
93where
94 P: KemParams,
95{
96 type Encapsulator = EncapsulationKey<P>;
97
98 fn encapsulator(&self) -> &EncapsulationKey<P> {
99 &self.ek
100 }
101}
102
103impl<P> EncodedSizeUser for DecapsulationKey<P>
104where
105 P: KemParams,
106{
107 type EncodedSize = DecapsulationKeySize<P>;
108
109 fn from_encoded_bytes(expanded: &Encoded<Self>) -> Result<Self, InvalidKey> {
110 #[allow(deprecated)]
111 Self::from_expanded(expanded)
112 }
113
114 fn to_encoded_bytes(&self) -> Encoded<Self> {
115 let dk_pke = self.dk_pke.to_bytes();
116 let ek = self.ek.to_encoded_bytes();
117 P::concat_dk(dk_pke, ek, self.ek.h.clone(), self.z.clone())
118 }
119}
120
121impl<P> Generate for DecapsulationKey<P>
122where
123 P: KemParams,
124{
125 fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRngCore>::Error>
126 where
127 R: TryCryptoRng + ?Sized,
128 {
129 Self::try_generate_from_rng(rng)
130 }
131}
132
133impl<P> KeySizeUser for DecapsulationKey<P>
134where
135 P: KemParams,
136{
137 type KeySize = U64;
138}
139
140impl<P> KeyInit for DecapsulationKey<P>
141where
142 P: KemParams,
143{
144 #[inline]
145 fn new(seed: &Seed) -> Self {
146 Self::from_seed(*seed)
147 }
148}
149
150impl<P> DecapsulationKey<P>
151where
152 P: KemParams,
153{
154 #[inline]
156 #[must_use]
157 pub fn from_seed(seed: Seed) -> Self {
158 let (d, z) = seed.split();
159 Self::generate_deterministic(d, z)
160 }
161
162 #[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
170 pub fn from_expanded(enc: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
171 let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
172 let ek_pke = EncryptionKey::from_bytes(ek_pke)?;
173
174 Ok(Self {
178 dk_pke: DecryptionKey::from_bytes(dk_pke),
179 ek: EncapsulationKey {
180 ek_pke,
181 h: h.clone(),
182 },
183 d: None,
184 z: z.clone(),
185 })
186 }
187
188 #[inline]
201 pub fn to_seed(&self) -> Option<Seed> {
202 self.d.map(|d| d.concat(self.z))
203 }
204
205 pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
207 &self.ek
208 }
209
210 #[inline]
211 pub(crate) fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRngCore>::Error>
212 where
213 R: TryCryptoRng + ?Sized,
214 {
215 let d = B32::try_generate_from_rng(rng)?;
216 let z = B32::try_generate_from_rng(rng)?;
217 Ok(Self::generate_deterministic(d, z))
218 }
219
220 #[inline]
221 #[must_use]
222 #[allow(clippy::similar_names)] pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
224 let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
225 let ek = EncapsulationKey::new(ek_pke);
226 let d = Some(d);
227 Self { dk_pke, ek, d, z }
228 }
229}
230
231#[derive(Clone, Debug)]
234pub struct EncapsulationKey<P>
235where
236 P: KemParams,
237{
238 ek_pke: EncryptionKey<P>,
239 h: B32,
240}
241
242impl<P> EncapsulationKey<P>
243where
244 P: KemParams,
245{
246 pub(crate) fn new(ek_pke: EncryptionKey<P>) -> Self {
247 let h = H(ek_pke.to_bytes());
248 Self { ek_pke, h }
249 }
250
251 fn encapsulate_deterministic_inner(&self, m: &B32) -> (EncodedCiphertext<P>, SharedKey) {
252 let (K, r) = G(&[m, &self.h]);
253 let c = self.ek_pke.encrypt(m, &r);
254 (c, K)
255 }
256}
257
258impl<P> Encapsulate for EncapsulationKey<P>
259where
260 P: KemParams,
261{
262 fn encapsulate_with_rng<R: TryCryptoRng + ?Sized>(
263 &self,
264 rng: &mut R,
265 ) -> Result<(EncodedCiphertext<P>, SharedKey), R::Error> {
266 let m = B32::try_generate_from_rng(rng)?;
267 Ok(self.encapsulate_deterministic_inner(&m))
268 }
269}
270
271impl<P> EncodedSizeUser for EncapsulationKey<P>
272where
273 P: KemParams,
274{
275 type EncodedSize = EncapsulationKeySize<P>;
276
277 fn from_encoded_bytes(enc: &Encoded<Self>) -> Result<Self, InvalidKey> {
278 Ok(Self::new(EncryptionKey::from_bytes(enc)?))
279 }
280
281 fn to_encoded_bytes(&self) -> Encoded<Self> {
282 self.ek_pke.to_bytes()
283 }
284}
285
286impl<P> ::kem::KemParams for EncapsulationKey<P>
287where
288 P: KemParams,
289{
290 type CiphertextSize = P::CiphertextSize;
291 type SharedSecretSize = U32;
292}
293
294impl<P> KeyExport for EncapsulationKey<P>
295where
296 P: KemParams,
297{
298 fn to_bytes(&self) -> Key<Self> {
299 self.ek_pke.to_bytes()
300 }
301}
302
303impl<P> KeySizeUser for EncapsulationKey<P>
304where
305 P: KemParams,
306{
307 type KeySize = EncapsulationKeySize<P>;
308}
309
310impl<P> TryKeyInit for EncapsulationKey<P>
311where
312 P: KemParams,
313{
314 fn new(encapsulation_key: &Key<Self>) -> Result<Self, InvalidKey> {
315 EncryptionKey::from_bytes(encapsulation_key)
316 .map(Self::new)
317 .map_err(|_| InvalidKey)
318 }
319}
320
321impl<P> Eq for EncapsulationKey<P> where P: KemParams {}
322impl<P> PartialEq for EncapsulationKey<P>
323where
324 P: KemParams,
325{
326 fn eq(&self, other: &Self) -> bool {
327 self.ek_pke == other.ek_pke && self.h == other.h
329 }
330}
331
#[cfg(feature = "deterministic")]
impl<P: KemParams> crate::EncapsulateDeterministic<EncodedCiphertext<P>, SharedKey>
    for EncapsulationKey<P>
{
    /// This operation cannot fail.
    type Error = Infallible;

    /// Encapsulate using the caller-supplied message `m` instead of one drawn from an RNG.
    fn encapsulate_deterministic(
        &self,
        m: &B32,
    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
        let ciphertext_and_key = self.encapsulate_deterministic_inner(m);
        Ok(ciphertext_and_key)
    }
}
346
347#[derive(Clone)]
350pub struct Kem<P>
351where
352 P: KemParams,
353{
354 _phantom: PhantomData<P>,
355}
356
357impl<P> KemCore for Kem<P>
358where
359 P: KemParams,
360{
361 type SharedKeySize = U32;
362 type CiphertextSize = P::CiphertextSize;
363 type DecapsulationKey = DecapsulationKey<P>;
364 type EncapsulationKey = EncapsulationKey<P>;
365
366 fn generate<R: CryptoRng + ?Sized>(
368 rng: &mut R,
369 ) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
370 let dk = Self::DecapsulationKey::generate_from_rng(rng);
371 let ek = dk.encapsulation_key().clone();
372 (dk, ek)
373 }
374
375 fn from_seed(seed: Seed) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
376 let dk = Self::DecapsulationKey::from_seed(seed);
377 let ek = dk.encapsulation_key().clone();
378 (dk, ek)
379 }
380}
381
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem512Params, MlKem768Params, MlKem1024Params};
    use ::kem::{Decapsulate, Encapsulate, Generate};
    use getrandom::SysRng;
    use rand_core::TryRngCore;

    /// Encapsulating to a fresh key and decapsulating the result yields the sender's shared key.
    fn round_trip_test<P>()
    where
        P: KemParams,
    {
        let mut rng = SysRng.unwrap_err();

        let dk = DecapsulationKey::<P>::generate_from_rng(&mut rng);
        let ek = dk.encapsulation_key();

        let (ct, k_send) = ek.encapsulate_with_rng(&mut rng).unwrap();
        let k_recv = dk.decapsulate(&ct);
        assert_eq!(k_send, k_recv);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// A corrupted ciphertext must not decapsulate to the sender's key. Decapsulation instead
    /// returns the implicit-rejection key, which is the same on every call for a given key and
    /// ciphertext.
    fn implicit_rejection_test<P>()
    where
        P: KemParams,
    {
        let mut rng = SysRng.unwrap_err();

        let dk = DecapsulationKey::<P>::generate_from_rng(&mut rng);
        let (mut ct, k_send) = dk
            .encapsulation_key()
            .encapsulate_with_rng(&mut rng)
            .unwrap();

        // Flip one bit so that re-encryption inside `decapsulate` no longer reproduces `ct`.
        ct[0] ^= 1;

        let k_reject = dk.decapsulate(&ct);
        assert_ne!(k_send, k_reject);
        assert_eq!(k_reject, dk.decapsulate(&ct));
    }

    #[test]
    fn implicit_rejection() {
        implicit_rejection_test::<MlKem512Params>();
        implicit_rejection_test::<MlKem768Params>();
        implicit_rejection_test::<MlKem1024Params>();
    }

    /// Both key types survive a round trip through their encoded forms.
    fn expanded_key_test<P>()
    where
        P: KemParams,
    {
        let mut rng = SysRng.unwrap_err();
        let dk_original = DecapsulationKey::<P>::generate_from_rng(&mut rng);
        let ek_original = dk_original.encapsulation_key().clone();

        let dk_encoded = dk_original.to_encoded_bytes();
        let dk_decoded = DecapsulationKey::from_encoded_bytes(&dk_encoded).unwrap();
        assert_eq!(dk_original, dk_decoded);

        let ek_encoded = ek_original.to_encoded_bytes();
        let ek_decoded = EncapsulationKey::from_encoded_bytes(&ek_encoded).unwrap();
        assert_eq!(ek_original, ek_decoded);
    }

    #[test]
    fn expanded_key() {
        expanded_key_test::<MlKem512Params>();
        expanded_key_test::<MlKem768Params>();
        expanded_key_test::<MlKem1024Params>();
    }

    /// A key built from a seed hands the same seed back from `to_seed`.
    fn seed_test<P>()
    where
        P: KemParams,
    {
        let mut rng = SysRng.unwrap_err();
        let mut seed = Seed::default();
        rng.try_fill_bytes(&mut seed).unwrap();

        let dk = DecapsulationKey::<P>::from_seed(seed.clone());
        let seed_encoded = dk.to_seed().unwrap();
        assert_eq!(seed, seed_encoded);
    }

    #[test]
    fn seed() {
        seed_test::<MlKem512Params>();
        seed_test::<MlKem768Params>();
        seed_test::<MlKem1024Params>();
    }
}