1use core::marker::PhantomData;
2use hybrid_array::typenum::U32;
3use rand_core::CryptoRngCore;
4
5use crate::crypto::{rand, G, H, J};
6use crate::param::{DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, KemParams};
7use crate::pke::{DecryptionKey, EncryptionKey};
8use crate::util::B32;
9use crate::{Encoded, EncodedSizeUser};
10
11#[cfg(feature = "zeroize")]
12use zeroize::{Zeroize, ZeroizeOnDrop};
13
14pub use ::kem::{Decapsulate, Encapsulate};
16
17pub(crate) type SharedKey = B32;
19
/// Secret decapsulation key for ML-KEM with parameter set `P`.
///
/// Bundles the K-PKE decryption key, a copy of the public [`EncapsulationKey`]
/// (which carries the hash `h` used during decapsulation), and the seed `z` from which
/// the rejection key `J(z || c)` is derived when a ciphertext fails re-encryption.
///
/// NOTE(review): `Debug` and `PartialEq` are derived, so `{:?}` formats every field
/// (including `dk_pke` and `z`) and `==` is a plain field-wise comparison that is
/// presumably not constant-time — avoid logging this value or comparing it on
/// secret-dependent paths.
#[derive(Clone, Debug, PartialEq)]
pub struct DecapsulationKey<P>
where
    P: KemParams,
{
    // K-PKE decryption key (secret).
    dk_pke: DecryptionKey<P>,
    // Matching public encapsulation key, including its cached hash `h`.
    ek: EncapsulationKey<P>,
    // Implicit-rejection seed (secret); mixed into `J(z || c)` in `decapsulate`.
    z: B32,
}
31
#[cfg(feature = "zeroize")]
impl<P: KemParams> Drop for DecapsulationKey<P> {
    /// Wipes the secret material — the K-PKE decryption key and the seed `z` — when the
    /// key is dropped. The public encapsulation key is left as-is.
    fn drop(&mut self) {
        let Self { dk_pke, z, .. } = self;
        dk_pke.zeroize();
        z.zeroize();
    }
}
42
// Marker impl: the `Drop` impl for `DecapsulationKey` zeroizes its secret fields, so the
// type advertises that guarantee to generic code.
#[cfg(feature = "zeroize")]
impl<P: KemParams> ZeroizeOnDrop for DecapsulationKey<P> {}
45
46impl<P> EncodedSizeUser for DecapsulationKey<P>
47where
48 P: KemParams,
49{
50 type EncodedSize = DecapsulationKeySize<P>;
51
52 #[allow(clippy::similar_names)] fn from_bytes(enc: &Encoded<Self>) -> Self {
54 let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
55 let ek_pke = EncryptionKey::from_bytes(ek_pke);
56
57 Self {
61 dk_pke: DecryptionKey::from_bytes(dk_pke),
62 ek: EncapsulationKey {
63 ek_pke,
64 h: h.clone(),
65 },
66 z: z.clone(),
67 }
68 }
69
70 fn as_bytes(&self) -> Encoded<Self> {
71 let dk_pke = self.dk_pke.as_bytes();
72 let ek = self.ek.as_bytes();
73 P::concat_dk(dk_pke, ek, self.ek.h.clone(), self.z.clone())
74 }
75}
76
/// Compares two bytes without data-dependent branches in the source.
///
/// Returns `0xff` when `x == y` and `0x00` otherwise, suitable as a selection mask.
///
/// With `d = x ^ y`, the top bit of `!d & (d - 1)` is set only when `d == 0`. For any
/// nonzero `d`, `!d` has its top bit set only if `d < 0x80`, while `d - 1` has it set
/// only if `d > 0x80`. Shifting that bit down and negating spreads it across the byte.
fn constant_time_eq(x: u8, y: u8) -> u8 {
    let d = x ^ y;
    let msb = (d.wrapping_sub(1) & !d) >> 7;
    msb.wrapping_neg()
}
83
84impl<P> ::kem::Decapsulate<EncodedCiphertext<P>, SharedKey> for DecapsulationKey<P>
85where
86 P: KemParams,
87{
88 type Error = ();
92
93 fn decapsulate(&self, encapsulated_key: &EncodedCiphertext<P>) -> Result<SharedKey, ()> {
94 let mp = self.dk_pke.decrypt(encapsulated_key);
95 let (Kp, rp) = G(&[&mp, &self.ek.h]);
96 let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
97 let cp = self.ek.ek_pke.encrypt(&mp, &rp);
98
99 let equal = cp
107 .iter()
108 .zip(encapsulated_key.iter())
109 .map(|(&x, &y)| constant_time_eq(x, y))
110 .fold(0xff, |x, y| x & y);
111 Ok(Kp
112 .iter()
113 .zip(Kbar.iter())
114 .map(|(x, y)| (equal & x) | (!equal & y))
115 .collect())
116 }
117}
118
119impl<P> DecapsulationKey<P>
120where
121 P: KemParams,
122{
123 pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
125 &self.ek
126 }
127
128 pub(crate) fn generate(rng: &mut impl CryptoRngCore) -> Self {
129 let d: B32 = rand(rng);
130 let z: B32 = rand(rng);
131 Self::generate_deterministic(&d, &z)
132 }
133
134 #[must_use]
135 #[allow(clippy::similar_names)] pub(crate) fn generate_deterministic(d: &B32, z: &B32) -> Self {
137 let (dk_pke, ek_pke) = DecryptionKey::generate(d);
138 let ek = EncapsulationKey::new(ek_pke);
139 let z = z.clone();
140 Self { dk_pke, ek, z }
141 }
142}
143
/// Public encapsulation key for ML-KEM with parameter set `P`.
///
/// Pairs the K-PKE encryption key with `h`, the hash `H` of its byte encoding.
/// `EncapsulationKey::new` computes `h`. When this key is decoded as part of a
/// `DecapsulationKey`, `h` is instead taken from the stored bytes.
#[derive(Clone, Debug, PartialEq)]
pub struct EncapsulationKey<P>
where
    P: KemParams,
{
    // K-PKE encryption key.
    ek_pke: EncryptionKey<P>,
    // Hash `H` of `ek_pke`'s encoding; fed into `G` during encapsulation/decapsulation.
    h: B32,
}
154
155impl<P> EncapsulationKey<P>
156where
157 P: KemParams,
158{
159 fn new(ek_pke: EncryptionKey<P>) -> Self {
160 let h = H(ek_pke.as_bytes());
161 Self { ek_pke, h }
162 }
163
164 fn encapsulate_deterministic_inner(&self, m: &B32) -> (EncodedCiphertext<P>, SharedKey) {
165 let (K, r) = G(&[m, &self.h]);
166 let c = self.ek_pke.encrypt(m, &r);
167 (c, K)
168 }
169}
170
171impl<P> EncodedSizeUser for EncapsulationKey<P>
172where
173 P: KemParams,
174{
175 type EncodedSize = EncapsulationKeySize<P>;
176
177 fn from_bytes(enc: &Encoded<Self>) -> Self {
178 Self::new(EncryptionKey::from_bytes(enc))
179 }
180
181 fn as_bytes(&self) -> Encoded<Self> {
182 self.ek_pke.as_bytes()
183 }
184}
185
186impl<P> ::kem::Encapsulate<EncodedCiphertext<P>, SharedKey> for EncapsulationKey<P>
187where
188 P: KemParams,
189{
190 type Error = ();
193
194 fn encapsulate(
195 &self,
196 rng: &mut impl CryptoRngCore,
197 ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
198 let m: B32 = rand(rng);
199 Ok(self.encapsulate_deterministic_inner(&m))
200 }
201}
202
#[cfg(feature = "deterministic")]
impl<P> crate::EncapsulateDeterministic<EncodedCiphertext<P>, SharedKey> for EncapsulationKey<P>
where
    P: KemParams,
{
    type Error = ();

    /// Encapsulates using the caller-provided message `m` instead of fresh randomness.
    ///
    /// The same `m` always yields the same output. Never returns `Err`.
    fn encapsulate_deterministic(
        &self,
        m: &B32,
    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
        let (ciphertext, shared_key) = self.encapsulate_deterministic_inner(m);
        Ok((ciphertext, shared_key))
    }
}
218
219pub struct Kem<P>
222where
223 P: KemParams,
224{
225 _phantom: PhantomData<P>,
226}
227
228impl<P> crate::KemCore for Kem<P>
229where
230 P: KemParams,
231{
232 type SharedKeySize = U32;
233 type CiphertextSize = P::CiphertextSize;
234 type DecapsulationKey = DecapsulationKey<P>;
235 type EncapsulationKey = EncapsulationKey<P>;
236
237 fn generate(rng: &mut impl CryptoRngCore) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
239 let dk = Self::DecapsulationKey::generate(rng);
240 let ek = dk.encapsulation_key().clone();
241 (dk, ek)
242 }
243
244 #[cfg(feature = "deterministic")]
245 fn generate_deterministic(
246 d: &B32,
247 z: &B32,
248 ) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
249 let dk = Self::DecapsulationKey::generate_deterministic(d, z);
250 let ek = dk.encapsulation_key().clone();
251 (dk, ek)
252 }
253}
254
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem1024Params, MlKem512Params, MlKem768Params};
    use ::kem::{Decapsulate, Encapsulate};

    /// An honest encapsulate/decapsulate exchange must agree on the shared key.
    fn round_trip_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::thread_rng();

        let dk = DecapsulationKey::<P>::generate(&mut rng);
        let ek = dk.encapsulation_key();

        let (ct, k_send) = ek.encapsulate(&mut rng).unwrap();
        let k_recv = dk.decapsulate(&ct).unwrap();
        assert_eq!(k_send, k_recv);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// A ciphertext that does not re-encrypt to itself must be implicitly rejected.
    ///
    /// Decapsulation then yields `J(z || c)` rather than the sender's key. This
    /// exercises the mask-selection path in `decapsulate` that `round_trip` never reaches.
    fn implicit_rejection_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::thread_rng();

        let dk = DecapsulationKey::<P>::generate(&mut rng);
        let (mut ct, k_send) = dk.encapsulation_key().encapsulate(&mut rng).unwrap();

        // Flip one bit. Re-encryption of the decrypted message can then no longer equal
        // the tampered ciphertext (except with negligible probability).
        ct[0] ^= 0x01;

        let k_recv = dk.decapsulate(&ct).unwrap();
        assert_ne!(k_send, k_recv);
        assert_eq!(k_recv, J(&[dk.z.as_slice(), ct.as_ref()]));
    }

    #[test]
    fn implicit_rejection() {
        implicit_rejection_test::<MlKem512Params>();
        implicit_rejection_test::<MlKem768Params>();
        implicit_rejection_test::<MlKem1024Params>();
    }

    /// Encoding then decoding must reproduce equal keys.
    ///
    /// Re-encoding the decoded keys must also reproduce the exact original bytes.
    fn codec_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::thread_rng();
        let dk_original = DecapsulationKey::<P>::generate(&mut rng);
        let ek_original = dk_original.encapsulation_key().clone();

        let dk_encoded = dk_original.as_bytes();
        let dk_decoded = DecapsulationKey::from_bytes(&dk_encoded);
        assert_eq!(dk_original, dk_decoded);
        assert_eq!(dk_decoded.as_bytes(), dk_encoded);

        let ek_encoded = ek_original.as_bytes();
        let ek_decoded = EncapsulationKey::from_bytes(&ek_encoded);
        assert_eq!(ek_original, ek_decoded);
        assert_eq!(ek_decoded.as_bytes(), ek_encoded);
    }

    #[test]
    fn codec() {
        codec_test::<MlKem512Params>();
        codec_test::<MlKem768Params>();
        codec_test::<MlKem1024Params>();
    }
}