1#![cfg_attr(not(test), no_std)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc = include_str!("../README.md")]
4#![doc(
5 html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
6 html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
7)]
8#![deny(missing_docs)]
9#![warn(clippy::pedantic)]
10
11#![cfg_attr(feature = "getrandom", doc = "```")]
18#![cfg_attr(not(feature = "getrandom"), doc = "```ignore")]
19pub use kem::{
29 self, Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, KemParams, Key, KeyExport,
30 KeyInit, KeySizeUser, TryKeyInit,
31};
32
33use ml_kem::{
34 EncodedSizeUser, KemCore, MlKem768, MlKem768Params,
35 array::{
36 Array, ArrayN, AsArrayRef,
37 sizes::{U32, U1120, U1184, U1216},
38 },
39};
40use rand_core::{CryptoRng, TryCryptoRng, TryRng};
41use sha3::{
42 Sha3_256, Shake256, Shake256Reader,
43 digest::{ExtendableOutput, XofReader},
44};
45use x25519_dalek::{PublicKey, StaticSecret};
46
47#[cfg(feature = "zeroize")]
48use zeroize::{Zeroize, ZeroizeOnDrop};
49
50type MlKem768DecapsulationKey = ml_kem::kem::DecapsulationKey<MlKem768Params>;
51type MlKem768EncapsulationKey = ml_kem::kem::EncapsulationKey<MlKem768Params>;
52
/// Domain-separation label hashed last by the X-Wing combiner: the six ASCII bytes
/// `\.//^\` (hex `5c2e2f2f5e5c`).
const X_WING_LABEL: &[u8; 6] = b"\\.//^\\";
54
/// Size in bytes of a serialized [`EncapsulationKey`]: the 1184-byte ML-KEM-768
/// encapsulation key followed by the 32-byte X25519 public key.
pub const ENCAPSULATION_KEY_SIZE: usize = 1184 + 32;
/// Size in bytes of a serialized [`DecapsulationKey`], which is its 32-byte secret seed.
pub const DECAPSULATION_KEY_SIZE: usize = 32;
/// Size in bytes of a serialized [`Ciphertext`]: the 1088-byte ML-KEM-768 ciphertext
/// followed by the 32-byte X25519 ephemeral public key.
pub const CIPHERTEXT_SIZE: usize = 1088 + 32;
/// Number of random bytes consumed by one encapsulation: 32 for ML-KEM-768 and 32 for X25519.
pub const ENCAPSULATION_RANDOMNESS_SIZE: usize = 32 + 32;
63
64pub type Ciphertext = Array<u8, U1120>;
66pub type SharedSecret = Array<u8, U32>;
68
69#[derive(Clone, Debug, Eq, PartialEq)]
81pub struct EncapsulationKey {
82 pk_m: MlKem768EncapsulationKey,
83 pk_x: PublicKey,
84}
85
86impl EncapsulationKey {
87 #[doc(hidden)]
94 #[cfg_attr(not(feature = "hazmat"), doc(hidden))]
95 #[expect(clippy::must_use_candidate)]
96 pub fn encapsulate_deterministic(
97 &self,
98 randomness: &ArrayN<u8, ENCAPSULATION_RANDOMNESS_SIZE>,
99 ) -> (Ciphertext, SharedSecret) {
100 let (rand_m, rand_x) = randomness.split::<U32>();
102
103 let (ct_m, ss_m) = self.pk_m.encapsulate_deterministic(&rand_m);
105
106 let ek_x = StaticSecret::from(rand_x.0);
107 let ct_x = PublicKey::from(&ek_x);
109 let ss_x = ek_x.diffie_hellman(&self.pk_x);
111
112 let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x);
113 let ct = CiphertextMessage { ct_m, ct_x };
114
115 (ct.into(), ss)
116 }
117}
118
119impl Encapsulate for EncapsulationKey {
120 fn encapsulate_with_rng<R>(&self, rng: &mut R) -> (Ciphertext, SharedSecret)
121 where
122 R: CryptoRng + ?Sized,
123 {
124 #[allow(unused_mut)]
125 let mut randomness = Array::generate_from_rng(rng);
126 let res = self.encapsulate_deterministic(&randomness);
127
128 #[cfg(feature = "zeroize")]
129 randomness.zeroize();
130
131 res
132 }
133}
134
135impl KemParams for EncapsulationKey {
136 type CiphertextSize = U1120;
137 type SharedSecretSize = U32;
138}
139
140impl KeySizeUser for EncapsulationKey {
141 type KeySize = U1216;
142}
143
144impl KeyExport for EncapsulationKey {
145 fn to_bytes(&self) -> Key<Self> {
146 let mut key_bytes = Key::<Self>::default();
147 let (m, x) = key_bytes.split_at_mut(1184);
148 m.copy_from_slice(&self.pk_m.to_encoded_bytes());
149 x.copy_from_slice(self.pk_x.as_bytes());
150 key_bytes
151 }
152}
153
154impl TryKeyInit for EncapsulationKey {
155 fn new(key_bytes: &Key<Self>) -> Result<Self, InvalidKey> {
156 let (m_bytes, x_bytes) = key_bytes.split_ref::<U1184>();
157
158 let pk_m = MlKem768EncapsulationKey::from_encoded_bytes(m_bytes).map_err(|_| InvalidKey)?;
159 let pk_x = PublicKey::from(x_bytes.0);
160
161 Ok(EncapsulationKey { pk_m, pk_x })
162 }
163}
164
165impl TryFrom<&[u8]> for EncapsulationKey {
166 type Error = InvalidKey;
167
168 fn try_from(key_bytes: &[u8]) -> Result<Self, InvalidKey> {
169 Self::new_from_slice(key_bytes)
170 }
171}
172
173#[derive(Clone)]
175pub struct DecapsulationKey {
176 sk: [u8; DECAPSULATION_KEY_SIZE],
177 ek: EncapsulationKey,
178}
179
180impl DecapsulationKey {
181 #[must_use]
183 pub fn as_bytes(&self) -> &[u8; DECAPSULATION_KEY_SIZE] {
184 &self.sk
185 }
186}
187
188impl Decapsulate for DecapsulationKey {
189 #[allow(clippy::similar_names)] fn decapsulate(&self, ct: &Ciphertext) -> SharedSecret {
191 let ct = CiphertextMessage::from(ct);
192 let (sk_m, sk_x, _pk_m, pk_x) = expand_key(&self.sk);
193
194 let ss_m = sk_m.decapsulate(&ct.ct_m);
195
196 let ss_x = sk_x.diffie_hellman(&ct.ct_x);
198
199 combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x)
200 }
201}
202
203impl Decapsulator for DecapsulationKey {
204 type Encapsulator = EncapsulationKey;
205
206 fn encapsulator(&self) -> &EncapsulationKey {
207 &self.ek
208 }
209}
210
211impl Drop for DecapsulationKey {
212 fn drop(&mut self) {
213 #[cfg(feature = "zeroize")]
214 self.sk.zeroize();
215 }
216}
217
218impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
219 fn from(sk: [u8; DECAPSULATION_KEY_SIZE]) -> Self {
220 DecapsulationKey::new(sk.as_array_ref())
221 }
222}
223
224impl Generate for DecapsulationKey {
225 fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
226 where
227 R: TryCryptoRng + ?Sized,
228 {
229 <[u8; DECAPSULATION_KEY_SIZE]>::try_generate_from_rng(rng).map(Into::into)
230 }
231}
232
233impl KeySizeUser for DecapsulationKey {
234 type KeySize = U32;
235}
236
237impl KeyInit for DecapsulationKey {
238 fn new(key: &ArrayN<u8, 32>) -> Self {
239 let (_sk_m, _sk_x, pk_m, pk_x) = expand_key(key.as_ref());
240 let ek = EncapsulationKey { pk_m, pk_x };
241 Self { sk: key.0, ek }
242 }
243}
244
245#[cfg(feature = "zeroize")]
246impl ZeroizeOnDrop for DecapsulationKey {}
247
248fn expand_key(
249 sk: &[u8; DECAPSULATION_KEY_SIZE],
250) -> (
251 MlKem768DecapsulationKey,
252 StaticSecret,
253 MlKem768EncapsulationKey,
254 PublicKey,
255) {
256 use sha3::digest::Update;
257 let mut shaker = Shake256::default();
258 shaker.update(sk);
259 let mut expanded: Shake256Reader = shaker.finalize_xof();
260
261 let seed = read_from(&mut expanded).into();
262 let (sk_m, pk_m) = MlKem768::from_seed(seed);
263
264 let sk_x = read_from(&mut expanded);
265 let sk_x = StaticSecret::from(sk_x);
266 let pk_x = PublicKey::from(&sk_x);
267
268 (sk_m, sk_x, pk_m, pk_x)
269}
270
271#[derive(Clone, PartialEq, Eq)]
273pub struct CiphertextMessage {
274 ct_m: ArrayN<u8, 1088>,
275 ct_x: PublicKey,
276}
277
278impl CiphertextMessage {
279 #[must_use]
282 pub fn to_bytes(&self) -> Ciphertext {
283 let mut buffer = Ciphertext::default();
284 buffer[0..1088].copy_from_slice(&self.ct_m);
285 buffer[1088..].copy_from_slice(self.ct_x.as_bytes());
286 buffer
287 }
288}
289
290impl From<&Ciphertext> for CiphertextMessage {
291 fn from(value: &Ciphertext) -> Self {
292 let mut ct_m = [0; 1088];
293 ct_m.copy_from_slice(&value[0..1088]);
294 let mut ct_x = [0; 32];
295 ct_x.copy_from_slice(&value[1088..]);
296
297 CiphertextMessage {
298 ct_m: ct_m.into(),
299 ct_x: ct_x.into(),
300 }
301 }
302}
303
304impl From<&CiphertextMessage> for Ciphertext {
305 #[inline]
306 fn from(msg: &CiphertextMessage) -> Self {
307 msg.to_bytes()
308 }
309}
310
311impl From<CiphertextMessage> for Ciphertext {
312 #[inline]
313 fn from(msg: CiphertextMessage) -> Self {
314 Self::from(&msg)
315 }
316}
317
/// Generates a fresh X-Wing key pair using the operating system's RNG.
///
/// Returns `(decapsulation_key, encapsulation_key)`.
#[cfg(feature = "getrandom")]
#[must_use]
pub fn generate_key_pair() -> (DecapsulationKey, EncapsulationKey) {
    let decapsulation_key = DecapsulationKey::generate();
    let encapsulation_key = decapsulation_key.encapsulator().clone();
    (decapsulation_key, encapsulation_key)
}
326
327pub fn generate_key_pair_from_rng<R: CryptoRng + ?Sized>(
329 rng: &mut R,
330) -> (DecapsulationKey, EncapsulationKey) {
331 let sk = DecapsulationKey::generate_from_rng(rng);
332 let pk = sk.encapsulator().clone();
333 (sk, pk)
334}
335
336fn combiner(
337 ss_m: &ArrayN<u8, 32>,
338 ss_x: &x25519_dalek::SharedSecret,
339 ct_x: &PublicKey,
340 pk_x: &PublicKey,
341) -> SharedSecret {
342 use sha3::Digest;
343
344 let mut hasher = Sha3_256::new();
345 hasher.update(ss_m);
346 hasher.update(ss_x);
347 hasher.update(ct_x);
348 hasher.update(pk_x.as_bytes());
349 hasher.update(X_WING_LABEL);
350 hasher.finalize()
351}
352
353fn read_from<const N: usize>(reader: &mut Shake256Reader) -> [u8; N] {
354 let mut data = [0; N];
355 reader.read(&mut data);
356 data
357}
358
#[cfg(test)]
mod tests {
    use core::convert::Infallible;
    use getrandom::SysRng;
    use ml_kem::array::Array;
    use rand_core::{TryCryptoRng, TryRng, UnwrapErr, utils};
    use serde::Deserialize;

    use super::*;

    /// Deterministic "RNG" that hands out the bytes of a fixed seed, in order.
    ///
    /// Panics if more bytes are requested than remain in the seed.
    pub(crate) struct SeedRng {
        pub(crate) seed: Vec<u8>,
    }

    impl SeedRng {
        fn new(seed: Vec<u8>) -> SeedRng {
            SeedRng { seed }
        }
    }

    impl TryRng for SeedRng {
        type Error = Infallible;

        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            utils::next_word_via_fill(self)
        }

        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            utils::next_word_via_fill(self)
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
            let wanted = dest.len();
            dest.copy_from_slice(&self.seed[..wanted]);
            // Consume the bytes just handed out so the next call continues after them.
            self.seed.drain(..wanted);
            Ok(())
        }
    }

    impl TryCryptoRng for SeedRng {}

    /// One entry of `test-vectors.json`; every field is hex-encoded in the file.
    #[derive(Deserialize)]
    struct TestVector {
        #[serde(deserialize_with = "hex::serde::deserialize")]
        seed: Vec<u8>,

        #[serde(deserialize_with = "hex::serde::deserialize")]
        eseed: Vec<u8>,

        #[serde(deserialize_with = "hex::serde::deserialize")]
        ss: [u8; 32],

        #[serde(deserialize_with = "hex::serde::deserialize")]
        sk: [u8; 32],

        #[serde(deserialize_with = "hex::serde::deserialize")]
        pk: Vec<u8>,

        #[serde(deserialize_with = "hex::serde::deserialize")]
        ct: Vec<u8>,
    }

    #[test]
    fn rfc_test_vectors() {
        let test_vectors =
            serde_json::from_str::<Vec<TestVector>>(include_str!("test-vectors.json")).unwrap();

        for test_vector in test_vectors {
            run_test(test_vector);
        }
    }

    /// Checks key generation, encapsulation and decapsulation against one vector.
    fn run_test(test_vector: TestVector) {
        let mut seed = SeedRng::new(test_vector.seed);
        let (sk, pk) = generate_key_pair_from_rng(&mut seed);

        assert_eq!(sk.as_bytes(), &test_vector.sk);
        assert_eq!(&*pk.to_bytes(), test_vector.pk.as_slice());

        let mut eseed = SeedRng::new(test_vector.eseed);
        let (ct, ss) = pk.encapsulate_with_rng(&mut eseed);

        assert_eq!(ss, test_vector.ss);
        assert_eq!(&*ct, test_vector.ct.as_slice());

        let ss = sk.decapsulate(&ct);
        assert_eq!(ss, test_vector.ss);
    }

    #[test]
    fn ciphertext_serialize() {
        let mut rng = UnwrapErr(SysRng);

        let ct_a = CiphertextMessage {
            ct_m: Array::generate_from_rng(&mut rng),
            ct_x: <[u8; 32]>::generate_from_rng(&mut rng).into(),
        };

        let bytes = ct_a.to_bytes();
        let ct_b = CiphertextMessage::from(&bytes);

        assert!(ct_a == ct_b);
    }

    #[test]
    fn key_serialize() {
        let sk = DecapsulationKey::generate_from_rng(&mut UnwrapErr(SysRng));
        let pk = sk.encapsulator().clone();

        let sk_bytes = sk.as_bytes();
        let pk_bytes = pk.to_bytes();

        let sk_b = DecapsulationKey::from(*sk_bytes);
        let pk_b = EncapsulationKey::new(&pk_bytes).unwrap();

        assert_eq!(sk.sk, sk_b.sk);
        assert_eq!(pk, pk_b);
    }
}