x_wing/
lib.rs

1#![cfg_attr(not(test), no_std)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc = include_str!("../README.md")]
4#![doc(
5    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
6    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
7)]
8#![deny(missing_docs)]
9#![warn(clippy::pedantic)]
10
11//! # Usage
12//!
//! This crate implements the X-Wing Key Encapsulation Method (X-Wing-KEM) algorithm.
//! X-Wing-KEM is a KEM in the sense that it creates a (decapsulation key, encapsulation key) pair,
//! such that anyone can use the encapsulation key to establish a shared key with the holder of the
//! decapsulation key. X-Wing-KEM is a general-purpose hybrid post-quantum KEM, combining X25519
//! and ML-KEM-768.
17//!
18//! ```
19//! use kem::{Decapsulate, Encapsulate};
20//! use rand_core::TryRngCore;
21//!
22//! let mut rng = &mut rand::rngs::OsRng.unwrap_err();
23//! let (sk, pk) = x_wing::generate_key_pair(rng);
24//! let (ct, ss_sender) = pk.encapsulate(rng).unwrap();
25//! let ss_receiver = sk.decapsulate(&ct).unwrap();
26//! assert_eq!(ss_sender, ss_receiver);
27//! ```
28
29pub use kem::{self, Decapsulate, Encapsulate};
30
31use core::convert::Infallible;
32use ml_kem::array::{ArrayN, typenum::consts::U32};
33use ml_kem::{B32, EncodedSizeUser, KemCore, MlKem768, MlKem768Params};
34use rand_core::{CryptoRng, TryCryptoRng};
35#[cfg(feature = "os_rng")]
36use rand_core::{OsRng, TryRngCore};
37use sha3::digest::{ExtendableOutput, XofReader};
38use sha3::{Sha3_256, Shake256, Shake256Reader};
39use x25519_dalek::{EphemeralSecret, PublicKey, StaticSecret};
40
41#[cfg(feature = "zeroize")]
42use zeroize::{Zeroize, ZeroizeOnDrop};
43
// ML-KEM-768 decapsulation (private) key type from the `ml_kem` crate.
type MlKem768DecapsulationKey = ml_kem::kem::DecapsulationKey<MlKem768Params>;
// ML-KEM-768 encapsulation (public) key type from the `ml_kem` crate.
type MlKem768EncapsulationKey = ml_kem::kem::EncapsulationKey<MlKem768Params>;
46
/// Domain-separation label absorbed last by the X-Wing combiner: the six ASCII bytes `\.//^\`.
const X_WING_LABEL: &[u8; 6] = b"\\.//^\\";

/// Size in bytes of the `EncapsulationKey`:
/// an ML-KEM-768 encapsulation key (1184 bytes) followed by an X25519 public key (32 bytes).
pub const ENCAPSULATION_KEY_SIZE: usize = 1184 + 32;
/// Size in bytes of the `DecapsulationKey`, which is a single 32-byte seed.
pub const DECAPSULATION_KEY_SIZE: usize = 32;
/// Size in bytes of the `Ciphertext`:
/// an ML-KEM-768 ciphertext (1088 bytes) followed by an X25519 ephemeral public key (32 bytes).
pub const CIPHERTEXT_SIZE: usize = 1088 + 32;

/// Shared secret key (32 bytes) agreed on by encapsulation and decapsulation.
pub type SharedSecret = [u8; 32];
58
// The naming convention of variables matches the X-Wing RFC draft (draft-connolly-cfrg-xwing-kem).
// ss -> Shared Secret
// ct -> Ciphertext
// ek -> Ephemeral Key
// pk -> Public Key
// sk -> Secret Key
// Postfixes:
// _m -> ML-KEM related key
// _x -> X25519 related key
68
69/// X-Wing encapsulation or public key.
70#[derive(Clone, PartialEq)]
71pub struct EncapsulationKey {
72    pk_m: MlKem768EncapsulationKey,
73    pk_x: PublicKey,
74}
75
76impl Encapsulate<Ciphertext, SharedSecret> for EncapsulationKey {
77    type Error = Infallible;
78
79    fn encapsulate<R: TryCryptoRng + ?Sized>(
80        &self,
81        rng: &mut R,
82    ) -> Result<(Ciphertext, SharedSecret), Self::Error> {
83        // Swapped order of operations compared to RFC, so that usage of the rng matches the RFC
84        let (ct_m, ss_m) = self.pk_m.encapsulate(rng)?;
85
86        let ek_x = EphemeralSecret::random_from_rng(&mut rng.unwrap_mut());
87        // Equal to ct_x = x25519(ek_x, BASE_POINT)
88        let ct_x = PublicKey::from(&ek_x);
89        // Equal to ss_x = x25519(ek_x, pk_x)
90        let ss_x = ek_x.diffie_hellman(&self.pk_x);
91
92        let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x);
93        let ct = Ciphertext { ct_m, ct_x };
94        Ok((ct, ss))
95    }
96}
97
98impl EncapsulationKey {
99    /// Convert the key to the following format:
100    /// ML-KEM-768 public key(1184 bytes) || X25519 public key(32 bytes).
101    #[must_use]
102    pub fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_SIZE] {
103        let mut buffer = [0u8; ENCAPSULATION_KEY_SIZE];
104        buffer[0..1184].copy_from_slice(&self.pk_m.as_bytes());
105        buffer[1184..1216].copy_from_slice(self.pk_x.as_bytes());
106        buffer
107    }
108}
109
110impl From<&[u8; ENCAPSULATION_KEY_SIZE]> for EncapsulationKey {
111    fn from(value: &[u8; ENCAPSULATION_KEY_SIZE]) -> Self {
112        let mut pk_m = [0; 1184];
113        pk_m.copy_from_slice(&value[0..1184]);
114        let pk_m = MlKem768EncapsulationKey::from_bytes(&pk_m.into());
115
116        let mut pk_x = [0; 32];
117        pk_x.copy_from_slice(&value[1184..]);
118        let pk_x = PublicKey::from(pk_x);
119        EncapsulationKey { pk_m, pk_x }
120    }
121}
122
/// X-Wing decapsulation key or private key.
///
/// Only a 32-byte seed is stored. The ML-KEM-768 and X25519 component keys are re-derived from
/// it with SHAKE256 each time they are needed (decapsulation, computing the encapsulation key).
#[derive(Clone)]
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
// Equality is derived only in test builds.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct DecapsulationKey {
    // The 32-byte seed from which all component keys are expanded.
    sk: [u8; DECAPSULATION_KEY_SIZE],
}
130
131impl Decapsulate<Ciphertext, SharedSecret> for DecapsulationKey {
132    type Encapsulator = EncapsulationKey;
133    type Error = Infallible;
134
135    #[allow(clippy::similar_names)] // So we can use the names as in the RFC
136    fn decapsulate(&self, ct: &Ciphertext) -> Result<SharedSecret, Self::Error> {
137        let (sk_m, sk_x, _pk_m, pk_x) = self.expand_key();
138
139        let ss_m = sk_m.decapsulate(&ct.ct_m)?;
140
141        // equal to ss_x = x25519(sk_x, ct_x)
142        let ss_x = sk_x.diffie_hellman(&ct.ct_x);
143
144        let ss = combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x);
145        Ok(ss)
146    }
147
148    fn encapsulator(&self) -> EncapsulationKey {
149        self.encapsulation_key()
150    }
151}
152
// Declares to the `kem` traits that a `DecapsulationKey` is initialized from 32 bytes (its seed).
impl ::kem::KeySizeUser for DecapsulationKey {
    type KeySize = U32;
}
156
157impl ::kem::KeyInit for DecapsulationKey {
158    fn new(key: &ArrayN<u8, 32>) -> Self {
159        Self { sk: key.0 }
160    }
161}
162
163impl DecapsulationKey {
164    /// Generate a new `DecapsulationKey` using `OsRng`.
165    #[cfg(feature = "os_rng")]
166    #[must_use]
167    pub fn generate_from_os_rng() -> DecapsulationKey {
168        Self::generate(&mut OsRng.unwrap_err())
169    }
170
171    /// Generate a new `DecapsulationKey` using the provided RNG.
172    pub fn generate<R: CryptoRng + ?Sized>(rng: &mut R) -> DecapsulationKey {
173        let sk = generate(rng);
174        DecapsulationKey { sk }
175    }
176
177    /// Provide the matching `EncapsulationKey`.
178    #[must_use]
179    pub fn encapsulation_key(&self) -> EncapsulationKey {
180        let (_sk_m, _sk_x, pk_m, pk_x) = self.expand_key();
181        EncapsulationKey { pk_m, pk_x }
182    }
183
184    fn expand_key(
185        &self,
186    ) -> (
187        MlKem768DecapsulationKey,
188        StaticSecret,
189        MlKem768EncapsulationKey,
190        PublicKey,
191    ) {
192        use sha3::digest::Update;
193        let mut shaker = Shake256::default();
194        shaker.update(&self.sk);
195        let mut expanded: Shake256Reader = shaker.finalize_xof();
196
197        let seed = read_from(&mut expanded).into();
198        let (sk_m, pk_m) = MlKem768::from_seed(seed);
199
200        let sk_x = read_from(&mut expanded);
201        let sk_x = StaticSecret::from(sk_x);
202        let pk_x = PublicKey::from(&sk_x);
203
204        (sk_m, sk_x, pk_m, pk_x)
205    }
206
207    /// Private key as bytes.
208    #[must_use]
209    pub fn as_bytes(&self) -> &[u8; DECAPSULATION_KEY_SIZE] {
210        &self.sk
211    }
212}
213
214impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
215    fn from(sk: [u8; DECAPSULATION_KEY_SIZE]) -> Self {
216        DecapsulationKey { sk }
217    }
218}
219
220/// X-Wing ciphertext.
221#[derive(Clone, PartialEq, Eq)]
222#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
223pub struct Ciphertext {
224    ct_m: ArrayN<u8, 1088>,
225    ct_x: PublicKey,
226}
227
228impl Ciphertext {
229    /// Convert the ciphertext to the following format:
230    /// ML-KEM-768 ciphertext(1088 bytes) || X25519 ciphertext(32 bytes).
231    #[must_use]
232    pub fn to_bytes(&self) -> [u8; CIPHERTEXT_SIZE] {
233        let mut buffer = [0; CIPHERTEXT_SIZE];
234        buffer[0..1088].copy_from_slice(&self.ct_m);
235        buffer[1088..].copy_from_slice(self.ct_x.as_bytes());
236        buffer
237    }
238}
239
240impl From<&[u8; CIPHERTEXT_SIZE]> for Ciphertext {
241    fn from(value: &[u8; CIPHERTEXT_SIZE]) -> Self {
242        let mut ct_m = [0; 1088];
243        ct_m.copy_from_slice(&value[0..1088]);
244        let mut ct_x = [0; 32];
245        ct_x.copy_from_slice(&value[1088..]);
246
247        Ciphertext {
248            ct_m: ct_m.into(),
249            ct_x: ct_x.into(),
250        }
251    }
252}
253
/// Generate an X-Wing key pair from the operating system RNG (`OsRng`, wrapped with
/// `unwrap_err`).
#[cfg(feature = "os_rng")]
#[must_use]
pub fn generate_key_pair_from_os_rng() -> (DecapsulationKey, EncapsulationKey) {
    let mut os_rng = OsRng.unwrap_err();
    generate_key_pair(&mut os_rng)
}
260
261/// Generate a X-Wing key pair using the provided rng.
262pub fn generate_key_pair<R: CryptoRng + ?Sized>(
263    rng: &mut R,
264) -> (DecapsulationKey, EncapsulationKey) {
265    let sk = DecapsulationKey::generate(rng);
266    let pk = sk.encapsulation_key();
267    (sk, pk)
268}
269
270fn combiner(
271    ss_m: &B32,
272    ss_x: &x25519_dalek::SharedSecret,
273    ct_x: &PublicKey,
274    pk_x: &PublicKey,
275) -> SharedSecret {
276    use sha3::Digest;
277
278    let mut hasher = Sha3_256::new();
279    hasher.update(ss_m);
280    hasher.update(ss_x);
281    hasher.update(ct_x);
282    hasher.update(pk_x.as_bytes());
283    hasher.update(X_WING_LABEL);
284    hasher.finalize().into()
285}
286
287fn read_from<const N: usize>(reader: &mut Shake256Reader) -> [u8; N] {
288    let mut data = [0; N];
289    reader.read(&mut data);
290    data
291}
292
293fn generate<const N: usize, R: CryptoRng + ?Sized>(rng: &mut R) -> [u8; N] {
294    let mut random = [0; N];
295    rng.fill_bytes(&mut random);
296    random
297}
298
#[cfg(test)]
mod tests {
    use rand_core::{CryptoRng, OsRng, RngCore, TryRngCore, impls};
    use serde::Deserialize;

    use super::*;

    /// Deterministic test RNG that hands out the bytes of `seed` in order.
    ///
    /// Each `fill_bytes` call copies the next `dest.len()` bytes and removes them from the front
    /// of `seed`. It panics (slice index out of range) if fewer bytes remain than requested.
    pub(crate) struct SeedRng {
        pub(crate) seed: Vec<u8>,
    }

    impl SeedRng {
        fn new(seed: Vec<u8>) -> SeedRng {
            SeedRng { seed }
        }
    }

    impl RngCore for SeedRng {
        fn next_u32(&mut self) -> u32 {
            impls::next_u32_via_fill(self)
        }

        fn next_u64(&mut self) -> u64 {
            impls::next_u64_via_fill(self)
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            dest.copy_from_slice(&self.seed[0..dest.len()]);
            // Drop the bytes just handed out so the next call continues after them.
            self.seed.drain(0..dest.len());
        }
    }

    /// One entry of `test-vectors.json`; every field is hex-encoded in the JSON.
    #[derive(Deserialize)]
    struct TestVector {
        // Randomness consumed by key generation.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        seed: Vec<u8>,

        // Randomness consumed by encapsulation.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        eseed: Vec<u8>,

        // Expected shared secret.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ss: [u8; 32],

        // Expected decapsulation key (seed).
        #[serde(deserialize_with = "hex::serde::deserialize")]
        sk: [u8; 32],

        // Expected encapsulation key, compared against `EncapsulationKey::to_bytes`.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        pk: Vec<u8>,

        // Expected ciphertext, compared against `Ciphertext::to_bytes`.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ct: Vec<u8>,
    }

    // Marker impl so `SeedRng` can be passed to `generate_key_pair` and `encapsulate` below.
    // It is deterministic and NOT cryptographically secure; it exists only to replay vectors.
    impl CryptoRng for SeedRng {}

    /// Test with test vectors from: <https://github.com/dconnolly/draft-connolly-cfrg-xwing-kem/blob/main/spec/test-vectors.json>
    #[test]
    fn rfc_test_vectors() {
        let test_vectors =
            serde_json::from_str::<Vec<TestVector>>(include_str!("test-vectors.json")).unwrap();

        for test_vector in test_vectors {
            run_test(test_vector);
        }
    }

    /// Replays one vector: key generation, encapsulation and decapsulation must all reproduce
    /// the expected bytes.
    fn run_test(test_vector: TestVector) {
        // Key generation draws all of its randomness from the vector's `seed`.
        let mut seed = SeedRng::new(test_vector.seed);
        let (sk, pk) = generate_key_pair(&mut seed);

        assert_eq!(sk.as_bytes(), &test_vector.sk);
        assert_eq!(&pk.to_bytes(), test_vector.pk.as_slice());

        // Encapsulation draws all of its randomness from the vector's `eseed`.
        let mut eseed = SeedRng::new(test_vector.eseed);
        let (ct, ss) = pk.encapsulate(&mut eseed).unwrap();

        assert_eq!(ss, test_vector.ss);
        assert_eq!(&ct.to_bytes(), test_vector.ct.as_slice());

        // The receiver must recover the same shared secret.
        let ss = sk.decapsulate(&ct).unwrap();
        assert_eq!(ss, test_vector.ss);
    }

    /// `Ciphertext::to_bytes` followed by `Ciphertext::from` must round-trip.
    #[test]
    fn ciphertext_serialize() {
        let mut rng = OsRng.unwrap_err();

        // Random bytes stand in for a real ciphertext; only the byte-level round trip is checked.
        let ct_a = Ciphertext {
            ct_m: generate(&mut rng).into(),
            ct_x: generate(&mut rng).into(),
        };

        let bytes = ct_a.to_bytes();

        let ct_b = Ciphertext::from(&bytes);

        assert!(ct_a == ct_b);
    }

    /// Both key types must round-trip through their byte encodings.
    #[test]
    fn key_serialize() {
        let sk = DecapsulationKey::generate(&mut OsRng.unwrap_err());
        let pk = sk.encapsulation_key();

        let sk_bytes = sk.as_bytes();
        let pk_bytes = pk.to_bytes();

        let sk_b = DecapsulationKey::from(*sk_bytes);
        let pk_b = EncapsulationKey::from(&pk_bytes);

        assert!(sk == sk_b);
        assert!(pk == pk_b);
    }
}