1use core::convert::Infallible;
2use core::marker::PhantomData;
3use hybrid_array::typenum::{U32, U64};
4use rand_core::{CryptoRng, TryCryptoRng};
5use subtle::{ConditionallySelectable, ConstantTimeEq};
6
7use crate::crypto::{G, H, J, rand};
8use crate::param::{DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, KemParams};
9use crate::pke::{DecryptionKey, EncryptionKey};
10use crate::util::B32;
11use crate::{Encoded, EncodedSizeUser, Seed};
12
13#[cfg(feature = "zeroize")]
14use zeroize::{Zeroize, ZeroizeOnDrop};
15
16pub use ::kem::{Decapsulate, Encapsulate};
18
/// The shared key that encapsulation produces and decapsulation recovers; an alias for [`B32`].
pub(crate) type SharedKey = B32;
21
/// A decapsulation (private) key for the parameter set `P`.
///
/// It bundles the PKE decryption key, the matching [`EncapsulationKey`], the secret value `z`
/// and, when the key was derived from a seed, the seed component `d`.
// NOTE(review): the derived `Debug` formats every field, which includes secret key material —
// confirm that this is intended before logging values of this type.
#[derive(Clone, Debug)]
pub struct DecapsulationKey<P>
where
    P: KemParams,
{
    /// Decryption key of the underlying public-key encryption scheme.
    dk_pke: DecryptionKey<P>,
    /// Encapsulation key that belongs to this decapsulation key.
    ek: EncapsulationKey<P>,
    /// Seed component the key pair was generated from; `None` when decoded with `from_bytes`.
    d: Option<B32>,
    /// Secret value hashed together with the ciphertext in `decapsulate` to derive the
    /// fallback key that is returned when the ciphertext check fails.
    z: B32,
}
34
35impl<P> PartialEq for DecapsulationKey<P>
38where
39 P: KemParams,
40{
41 fn eq(&self, other: &Self) -> bool {
42 self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
43 }
44}
45
#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
    P: KemParams,
{
    /// Erases all secret key material when the key goes out of scope.
    fn drop(&mut self) {
        self.dk_pke.zeroize();
        // `d` is the seed the whole key pair is generated from, so leaving it in memory would
        // defeat wiping the other fields. `Option<B32>` zeroizes the inner value if present.
        self.d.zeroize();
        self.z.zeroize();
    }
}
56
// Marker trait: advertises that dropping a `DecapsulationKey` wipes key material, which is
// done by its `Drop` implementation.
#[cfg(feature = "zeroize")]
impl<P: KemParams> ZeroizeOnDrop for DecapsulationKey<P> {}
59
60impl<P> From<Seed> for DecapsulationKey<P>
61where
62 P: KemParams,
63{
64 fn from(seed: Seed) -> Self {
65 Self::from_seed(seed)
66 }
67}
68
69impl<P> EncodedSizeUser for DecapsulationKey<P>
70where
71 P: KemParams,
72{
73 type EncodedSize = DecapsulationKeySize<P>;
74
75 #[allow(clippy::similar_names)] fn from_bytes(enc: &Encoded<Self>) -> Self {
77 let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
78 let ek_pke = EncryptionKey::from_bytes(ek_pke);
79
80 Self {
84 dk_pke: DecryptionKey::from_bytes(dk_pke),
85 ek: EncapsulationKey {
86 ek_pke,
87 h: h.clone(),
88 },
89 d: None,
90 z: z.clone(),
91 }
92 }
93
94 fn as_bytes(&self) -> Encoded<Self> {
95 let dk_pke = self.dk_pke.as_bytes();
96 let ek = self.ek.as_bytes();
97 P::concat_dk(dk_pke, ek, self.ek.h.clone(), self.z.clone())
98 }
99}
100
101impl<P> ::kem::KeySizeUser for DecapsulationKey<P>
102where
103 P: KemParams,
104{
105 type KeySize = U64;
106}
107
108impl<P> ::kem::KeyInit for DecapsulationKey<P>
109where
110 P: KemParams,
111{
112 #[inline]
113 fn new(seed: &Seed) -> Self {
114 Self::from_seed(*seed)
115 }
116}
117
118impl<P> ::kem::Decapsulate<EncodedCiphertext<P>, SharedKey> for DecapsulationKey<P>
119where
120 P: KemParams,
121{
122 type Encapsulator = EncapsulationKey<P>;
123 type Error = Infallible;
124
125 fn decapsulate(
126 &self,
127 encapsulated_key: &EncodedCiphertext<P>,
128 ) -> Result<SharedKey, Self::Error> {
129 let mp = self.dk_pke.decrypt(encapsulated_key);
130 let (Kp, rp) = G(&[&mp, &self.ek.h]);
131 let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
132 let cp = self.ek.ek_pke.encrypt(&mp, &rp);
133 Ok(B32::conditional_select(
134 &Kbar,
135 &Kp,
136 cp.ct_eq(encapsulated_key),
137 ))
138 }
139
140 fn encapsulator(&self) -> EncapsulationKey<P> {
141 self.ek.clone()
142 }
143}
144
145impl<P> DecapsulationKey<P>
146where
147 P: KemParams,
148{
149 #[inline]
151 #[must_use]
152 pub fn from_seed(seed: Seed) -> Self {
153 let (d, z) = seed.split();
154 Self::generate_deterministic(d, z)
155 }
156
157 #[inline]
168 pub fn to_seed(&self) -> Option<Seed> {
169 self.d.map(|d| d.concat(self.z))
170 }
171
172 pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
174 &self.ek
175 }
176
177 #[inline]
178 pub(crate) fn generate<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
179 let d: B32 = rand(rng);
180 let z: B32 = rand(rng);
181 Self::generate_deterministic(d, z)
182 }
183
184 #[inline]
185 #[must_use]
186 #[allow(clippy::similar_names)] pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
188 let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
189 let ek = EncapsulationKey::new(ek_pke);
190 let d = Some(d);
191 Self { dk_pke, ek, d, z }
192 }
193}
194
/// An encapsulation (public) key for the parameter set `P`.
///
/// Besides the PKE encryption key it caches `h`, the hash of the encoded encryption key
/// (computed in `new`), which is fed into `G` during encapsulation and decapsulation.
#[derive(Clone, Debug, PartialEq)]
pub struct EncapsulationKey<P>
where
    P: KemParams,
{
    /// Encryption key of the underlying public-key encryption scheme.
    ek_pke: EncryptionKey<P>,
    /// Hash `H` of the encoded `ek_pke` when built with `new`; copied verbatim from the
    /// encoding when a `DecapsulationKey` is decoded with `from_bytes`.
    h: B32,
}
205
206impl<P> EncapsulationKey<P>
207where
208 P: KemParams,
209{
210 pub(crate) fn new(ek_pke: EncryptionKey<P>) -> Self {
211 let h = H(ek_pke.as_bytes());
212 Self { ek_pke, h }
213 }
214
215 fn encapsulate_deterministic_inner(&self, m: &B32) -> (EncodedCiphertext<P>, SharedKey) {
216 let (K, r) = G(&[m, &self.h]);
217 let c = self.ek_pke.encrypt(m, &r);
218 (c, K)
219 }
220}
221
222impl<P> EncodedSizeUser for EncapsulationKey<P>
223where
224 P: KemParams,
225{
226 type EncodedSize = EncapsulationKeySize<P>;
227
228 fn from_bytes(enc: &Encoded<Self>) -> Self {
229 Self::new(EncryptionKey::from_bytes(enc))
230 }
231
232 fn as_bytes(&self) -> Encoded<Self> {
233 self.ek_pke.as_bytes()
234 }
235}
236
237impl<P> ::kem::Encapsulate<EncodedCiphertext<P>, SharedKey> for EncapsulationKey<P>
238where
239 P: KemParams,
240{
241 type Error = Infallible;
242
243 fn encapsulate<R: TryCryptoRng + ?Sized>(
244 &self,
245 rng: &mut R,
246 ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
247 let m: B32 = rand(&mut rng.unwrap_mut());
248 Ok(self.encapsulate_deterministic_inner(&m))
249 }
250}
251
#[cfg(feature = "deterministic")]
impl<P> crate::EncapsulateDeterministic<EncodedCiphertext<P>, SharedKey> for EncapsulationKey<P>
where
    P: KemParams,
{
    type Error = Infallible;

    /// Encapsulates with the caller-chosen message `m` in place of random input, so the same
    /// `m` always yields the same ciphertext and shared key for this key.
    fn encapsulate_deterministic(
        &self,
        m: &B32,
    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
        let encapsulated = self.encapsulate_deterministic_inner(m);
        Ok(encapsulated)
    }
}
266
/// Type-level handle for the KEM instantiated with parameter set `P`.
///
/// It stores no data; `P` is only carried through `PhantomData` so that the
/// `KemCore` implementation can be selected by type.
#[derive(Clone)]
pub struct Kem<P>
where
    P: KemParams,
{
    /// Ties the otherwise unused parameter `P` to the type.
    _phantom: PhantomData<P>,
}
276
277impl<P> crate::KemCore for Kem<P>
278where
279 P: KemParams,
280{
281 type SharedKeySize = U32;
282 type CiphertextSize = P::CiphertextSize;
283 type DecapsulationKey = DecapsulationKey<P>;
284 type EncapsulationKey = EncapsulationKey<P>;
285
286 fn generate<R: CryptoRng + ?Sized>(
288 rng: &mut R,
289 ) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
290 let dk = Self::DecapsulationKey::generate(rng);
291 let ek = dk.encapsulation_key().clone();
292 (dk, ek)
293 }
294
295 fn from_seed(seed: Seed) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
296 let dk = Self::DecapsulationKey::from_seed(seed);
297 let ek = dk.encapsulation_key().clone();
298 (dk, ek)
299 }
300}
301
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem512Params, MlKem768Params, MlKem1024Params};
    use ::kem::{Decapsulate, Encapsulate};
    use rand_core::TryRngCore;

    /// Encapsulating to a freshly generated key and decapsulating the result must give both
    /// sides the same shared key.
    fn round_trip_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::rng();

        let decapsulation_key = DecapsulationKey::<P>::generate(&mut rng);
        let encapsulation_key = decapsulation_key.encapsulation_key();

        let (ciphertext, sent_key) = encapsulation_key.encapsulate(&mut rng).unwrap();
        let received_key = decapsulation_key.decapsulate(&ciphertext).unwrap();
        assert_eq!(sent_key, received_key);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// Both key types must survive an encode/decode cycle through their expanded byte form.
    fn expanded_key_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::rng();
        let dk_before = DecapsulationKey::<P>::generate(&mut rng);
        let ek_before = dk_before.encapsulation_key().clone();

        // Decapsulation key: the comparison ignores `d`, which is lost by the encoding.
        let dk_bytes = dk_before.as_bytes();
        let dk_after = DecapsulationKey::from_bytes(&dk_bytes);
        assert_eq!(dk_before, dk_after);

        // Encapsulation key.
        let ek_bytes = ek_before.as_bytes();
        let ek_after = EncapsulationKey::from_bytes(&ek_bytes);
        assert_eq!(ek_before, ek_after);
    }

    #[test]
    fn expanded_key() {
        expanded_key_test::<MlKem512Params>();
        expanded_key_test::<MlKem768Params>();
        expanded_key_test::<MlKem1024Params>();
    }

    /// A key created from a seed must hand back exactly that seed.
    fn seed_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::rng();
        let mut seed = Seed::default();
        rng.try_fill_bytes(&mut seed).unwrap();

        // `Seed` is `Copy`, so the original stays available for the comparison.
        let decapsulation_key = DecapsulationKey::<P>::from_seed(seed);
        let recovered_seed = decapsulation_key.to_seed().unwrap();
        assert_eq!(seed, recovered_seed);
    }

    #[test]
    fn seed() {
        seed_test::<MlKem512Params>();
        seed_test::<MlKem768Params>();
        seed_test::<MlKem1024Params>();
    }
}