x_wing/
lib.rs

1#![cfg_attr(not(test), no_std)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc = include_str!("../README.md")]
4#![doc(
5    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
6    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
7)]
8#![deny(missing_docs)]
9#![warn(clippy::pedantic)]
10
11//! # Usage
12//!
13//! This crate implements the X-Wing Key Encapsulation Method (X-Wing-KEM) algorithm.
//! X-Wing-KEM is a KEM in the sense that it creates a (decapsulation key, encapsulation key) pair,
15//! such that anyone can use the encapsulation key to establish a shared key with the holder of the
16//! decapsulation key. X-Wing-KEM is a general-purpose hybrid post-quantum KEM, combining x25519 and ML-KEM-768.
17#![cfg_attr(feature = "getrandom", doc = "```")]
18#![cfg_attr(not(feature = "getrandom"), doc = "```ignore")]
//! // NOTE: requires the `getrandom` feature to be enabled
20//! use kem::{Decapsulate, Encapsulate};
21//!
22//! let (sk, pk) = x_wing::generate_key_pair();
23//! let (ct, ss_sender) = pk.encapsulate();
24//! let ss_receiver = sk.decapsulate(&ct);
25//! assert_eq!(ss_sender, ss_receiver);
26//! ```
27
28pub use kem::{
29    self, Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, KemParams, Key, KeyExport,
30    KeyInit, KeySizeUser, TryKeyInit,
31};
32
33use ml_kem::{
34    EncodedSizeUser, KemCore, MlKem768, MlKem768Params,
35    array::{
36        Array, ArrayN, AsArrayRef,
37        sizes::{U32, U1120, U1184, U1216},
38    },
39};
40use rand_core::{CryptoRng, TryCryptoRng, TryRng};
41use sha3::{
42    Sha3_256, Shake256, Shake256Reader,
43    digest::{ExtendableOutput, XofReader},
44};
45use x25519_dalek::{PublicKey, StaticSecret};
46
47#[cfg(feature = "zeroize")]
48use zeroize::{Zeroize, ZeroizeOnDrop};
49
50type MlKem768DecapsulationKey = ml_kem::kem::DecapsulationKey<MlKem768Params>;
51type MlKem768EncapsulationKey = ml_kem::kem::EncapsulationKey<MlKem768Params>;
52
53const X_WING_LABEL: &[u8; 6] = br"\.//^\";
54
55/// Size in bytes of the `EncapsulationKey`.
56pub const ENCAPSULATION_KEY_SIZE: usize = 1216;
57/// Size in bytes of the `DecapsulationKey`.
58pub const DECAPSULATION_KEY_SIZE: usize = 32;
59/// Size in bytes of the `Ciphertext`.
60pub const CIPHERTEXT_SIZE: usize = 1120;
61/// Number of bytes necessary to encapsulate a key
62pub const ENCAPSULATION_RANDOMNESS_SIZE: usize = 64;
63
64/// Serialized ciphertext.
65pub type Ciphertext = Array<u8, U1120>;
66/// Shared secret key.
67pub type SharedSecret = Array<u8, U32>;
68
69// The naming convention of variables matches the RFC.
70// ss -> Shared Secret
71// ct -> Cipher Text
72// ek -> Ephemeral Key
73// pk -> Public Key
74// sk -> Secret Key
75// Postfixes:
76// _m -> ML-Kem related key
77// _x -> x25519 related key
78
79/// X-Wing encapsulation or public key.
80#[derive(Clone, Debug, Eq, PartialEq)]
81pub struct EncapsulationKey {
82    pk_m: MlKem768EncapsulationKey,
83    pk_x: PublicKey,
84}
85
86impl EncapsulationKey {
87    /// Encapsulates with the given randomness. Uses the first 32 bytes for ML-KEM and the last 32
88    /// bytes for x25519. This is useful for testing against known vectors.
89    ///
90    /// # Warning
91    /// Do NOT use this function unless you know what you're doing. If you fail to use all uniform
92    /// random bytes even once, you can have catastrophic security failure.
93    #[doc(hidden)]
94    #[cfg_attr(not(feature = "hazmat"), doc(hidden))]
95    #[expect(clippy::must_use_candidate)]
96    pub fn encapsulate_deterministic(
97        &self,
98        randomness: &ArrayN<u8, ENCAPSULATION_RANDOMNESS_SIZE>,
99    ) -> (Ciphertext, SharedSecret) {
100        // Split randomness into two 32-byte arrays
101        let (rand_m, rand_x) = randomness.split::<U32>();
102
103        // Encapsulate with ML-KEM first. This is infallible
104        let (ct_m, ss_m) = self.pk_m.encapsulate_deterministic(&rand_m);
105
106        let ek_x = StaticSecret::from(rand_x.0);
107        // Equal to ct_x = x25519(ek_x, BASE_POINT)
108        let ct_x = PublicKey::from(&ek_x);
109        // Equal to ss_x = x25519(ek_x, pk_x)
110        let ss_x = ek_x.diffie_hellman(&self.pk_x);
111
112        let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x);
113        let ct = CiphertextMessage { ct_m, ct_x };
114
115        (ct.into(), ss)
116    }
117}
118
119impl Encapsulate for EncapsulationKey {
120    fn encapsulate_with_rng<R>(&self, rng: &mut R) -> (Ciphertext, SharedSecret)
121    where
122        R: CryptoRng + ?Sized,
123    {
124        #[allow(unused_mut)]
125        let mut randomness = Array::generate_from_rng(rng);
126        let res = self.encapsulate_deterministic(&randomness);
127
128        #[cfg(feature = "zeroize")]
129        randomness.zeroize();
130
131        res
132    }
133}
134
135impl KemParams for EncapsulationKey {
136    type CiphertextSize = U1120;
137    type SharedSecretSize = U32;
138}
139
140impl KeySizeUser for EncapsulationKey {
141    type KeySize = U1216;
142}
143
144impl KeyExport for EncapsulationKey {
145    fn to_bytes(&self) -> Key<Self> {
146        let mut key_bytes = Key::<Self>::default();
147        let (m, x) = key_bytes.split_at_mut(1184);
148        m.copy_from_slice(&self.pk_m.to_encoded_bytes());
149        x.copy_from_slice(self.pk_x.as_bytes());
150        key_bytes
151    }
152}
153
154impl TryKeyInit for EncapsulationKey {
155    fn new(key_bytes: &Key<Self>) -> Result<Self, InvalidKey> {
156        let (m_bytes, x_bytes) = key_bytes.split_ref::<U1184>();
157
158        let pk_m = MlKem768EncapsulationKey::from_encoded_bytes(m_bytes).map_err(|_| InvalidKey)?;
159        let pk_x = PublicKey::from(x_bytes.0);
160
161        Ok(EncapsulationKey { pk_m, pk_x })
162    }
163}
164
165impl TryFrom<&[u8]> for EncapsulationKey {
166    type Error = InvalidKey;
167
168    fn try_from(key_bytes: &[u8]) -> Result<Self, InvalidKey> {
169        Self::new_from_slice(key_bytes)
170    }
171}
172
173/// X-Wing decapsulation key or private key.
174#[derive(Clone)]
175pub struct DecapsulationKey {
176    sk: [u8; DECAPSULATION_KEY_SIZE],
177    ek: EncapsulationKey,
178}
179
180impl DecapsulationKey {
181    /// Private key as bytes.
182    #[must_use]
183    pub fn as_bytes(&self) -> &[u8; DECAPSULATION_KEY_SIZE] {
184        &self.sk
185    }
186}
187
188impl Decapsulate for DecapsulationKey {
189    #[allow(clippy::similar_names)] // So we can use the names as in the RFC
190    fn decapsulate(&self, ct: &Ciphertext) -> SharedSecret {
191        let ct = CiphertextMessage::from(ct);
192        let (sk_m, sk_x, _pk_m, pk_x) = expand_key(&self.sk);
193
194        let ss_m = sk_m.decapsulate(&ct.ct_m);
195
196        // equal to ss_x = x25519(sk_x, ct_x)
197        let ss_x = sk_x.diffie_hellman(&ct.ct_x);
198
199        combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x)
200    }
201}
202
203impl Decapsulator for DecapsulationKey {
204    type Encapsulator = EncapsulationKey;
205
206    fn encapsulator(&self) -> &EncapsulationKey {
207        &self.ek
208    }
209}
210
211impl Drop for DecapsulationKey {
212    fn drop(&mut self) {
213        #[cfg(feature = "zeroize")]
214        self.sk.zeroize();
215    }
216}
217
218impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
219    fn from(sk: [u8; DECAPSULATION_KEY_SIZE]) -> Self {
220        DecapsulationKey::new(sk.as_array_ref())
221    }
222}
223
224impl Generate for DecapsulationKey {
225    fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
226    where
227        R: TryCryptoRng + ?Sized,
228    {
229        <[u8; DECAPSULATION_KEY_SIZE]>::try_generate_from_rng(rng).map(Into::into)
230    }
231}
232
233impl KeySizeUser for DecapsulationKey {
234    type KeySize = U32;
235}
236
237impl KeyInit for DecapsulationKey {
238    fn new(key: &ArrayN<u8, 32>) -> Self {
239        let (_sk_m, _sk_x, pk_m, pk_x) = expand_key(key.as_ref());
240        let ek = EncapsulationKey { pk_m, pk_x };
241        Self { sk: key.0, ek }
242    }
243}
244
245#[cfg(feature = "zeroize")]
246impl ZeroizeOnDrop for DecapsulationKey {}
247
248fn expand_key(
249    sk: &[u8; DECAPSULATION_KEY_SIZE],
250) -> (
251    MlKem768DecapsulationKey,
252    StaticSecret,
253    MlKem768EncapsulationKey,
254    PublicKey,
255) {
256    use sha3::digest::Update;
257    let mut shaker = Shake256::default();
258    shaker.update(sk);
259    let mut expanded: Shake256Reader = shaker.finalize_xof();
260
261    let seed = read_from(&mut expanded).into();
262    let (sk_m, pk_m) = MlKem768::from_seed(seed);
263
264    let sk_x = read_from(&mut expanded);
265    let sk_x = StaticSecret::from(sk_x);
266    let pk_x = PublicKey::from(&sk_x);
267
268    (sk_m, sk_x, pk_m, pk_x)
269}
270
271/// X-Wing ciphertext.
272#[derive(Clone, PartialEq, Eq)]
273pub struct CiphertextMessage {
274    ct_m: ArrayN<u8, 1088>,
275    ct_x: PublicKey,
276}
277
278impl CiphertextMessage {
279    /// Convert the ciphertext to the following format:
280    /// ML-KEM-768 ciphertext(1088 bytes) || X25519 ciphertext(32 bytes).
281    #[must_use]
282    pub fn to_bytes(&self) -> Ciphertext {
283        let mut buffer = Ciphertext::default();
284        buffer[0..1088].copy_from_slice(&self.ct_m);
285        buffer[1088..].copy_from_slice(self.ct_x.as_bytes());
286        buffer
287    }
288}
289
290impl From<&Ciphertext> for CiphertextMessage {
291    fn from(value: &Ciphertext) -> Self {
292        let mut ct_m = [0; 1088];
293        ct_m.copy_from_slice(&value[0..1088]);
294        let mut ct_x = [0; 32];
295        ct_x.copy_from_slice(&value[1088..]);
296
297        CiphertextMessage {
298            ct_m: ct_m.into(),
299            ct_x: ct_x.into(),
300        }
301    }
302}
303
304impl From<&CiphertextMessage> for Ciphertext {
305    #[inline]
306    fn from(msg: &CiphertextMessage) -> Self {
307        msg.to_bytes()
308    }
309}
310
311impl From<CiphertextMessage> for Ciphertext {
312    #[inline]
313    fn from(msg: CiphertextMessage) -> Self {
314        Self::from(&msg)
315    }
316}
317
318/// Generate a X-Wing key pair using `OsRng`.
319#[cfg(feature = "getrandom")]
320#[must_use]
321pub fn generate_key_pair() -> (DecapsulationKey, EncapsulationKey) {
322    let sk = DecapsulationKey::generate();
323    let pk = sk.encapsulator().clone();
324    (sk, pk)
325}
326
327/// Generate a X-Wing key pair using the provided rng.
328pub fn generate_key_pair_from_rng<R: CryptoRng + ?Sized>(
329    rng: &mut R,
330) -> (DecapsulationKey, EncapsulationKey) {
331    let sk = DecapsulationKey::generate_from_rng(rng);
332    let pk = sk.encapsulator().clone();
333    (sk, pk)
334}
335
336fn combiner(
337    ss_m: &ArrayN<u8, 32>,
338    ss_x: &x25519_dalek::SharedSecret,
339    ct_x: &PublicKey,
340    pk_x: &PublicKey,
341) -> SharedSecret {
342    use sha3::Digest;
343
344    let mut hasher = Sha3_256::new();
345    hasher.update(ss_m);
346    hasher.update(ss_x);
347    hasher.update(ct_x);
348    hasher.update(pk_x.as_bytes());
349    hasher.update(X_WING_LABEL);
350    hasher.finalize()
351}
352
353fn read_from<const N: usize>(reader: &mut Shake256Reader) -> [u8; N] {
354    let mut data = [0; N];
355    reader.read(&mut data);
356    data
357}
358
359#[cfg(test)]
360mod tests {
361    use core::convert::Infallible;
362    use getrandom::SysRng;
363    use ml_kem::array::Array;
364    use rand_core::{TryCryptoRng, TryRng, UnwrapErr, utils};
365    use serde::Deserialize;
366
367    use super::*;
368
369    pub(crate) struct SeedRng {
370        pub(crate) seed: Vec<u8>,
371    }
372
373    impl SeedRng {
374        fn new(seed: Vec<u8>) -> SeedRng {
375            SeedRng { seed }
376        }
377    }
378
379    impl TryRng for SeedRng {
380        type Error = Infallible;
381
382        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
383            utils::next_word_via_fill(self)
384        }
385
386        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
387            utils::next_word_via_fill(self)
388        }
389
390        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
391            dest.copy_from_slice(&self.seed[0..dest.len()]);
392            self.seed.drain(0..dest.len());
393            Ok(())
394        }
395    }
396
397    #[derive(Deserialize)]
398    struct TestVector {
399        #[serde(deserialize_with = "hex::serde::deserialize")]
400        seed: Vec<u8>,
401
402        #[serde(deserialize_with = "hex::serde::deserialize")]
403        eseed: Vec<u8>,
404
405        #[serde(deserialize_with = "hex::serde::deserialize")]
406        ss: [u8; 32],
407
408        #[serde(deserialize_with = "hex::serde::deserialize")]
409        sk: [u8; 32],
410
411        #[serde(deserialize_with = "hex::serde::deserialize")]
412        pk: Vec<u8>, //[u8; PUBLIC_KEY_SIZE],
413
414        #[serde(deserialize_with = "hex::serde::deserialize")]
415        ct: Vec<u8>, //[u8; 1120],
416    }
417
418    impl TryCryptoRng for SeedRng {}
419
420    /// Test with test vectors from: <https://github.com/dconnolly/draft-connolly-cfrg-xwing-kem/blob/main/spec/test-vectors.json>
421    #[test]
422    fn rfc_test_vectors() {
423        let test_vectors =
424            serde_json::from_str::<Vec<TestVector>>(include_str!("test-vectors.json")).unwrap();
425
426        for test_vector in test_vectors {
427            run_test(test_vector);
428        }
429    }
430
431    fn run_test(test_vector: TestVector) {
432        let mut seed = SeedRng::new(test_vector.seed);
433        let (sk, pk) = generate_key_pair_from_rng(&mut seed);
434
435        assert_eq!(sk.as_bytes(), &test_vector.sk);
436        assert_eq!(&*pk.to_bytes(), test_vector.pk.as_slice());
437
438        let mut eseed = SeedRng::new(test_vector.eseed);
439        let (ct, ss) = pk.encapsulate_with_rng(&mut eseed);
440
441        assert_eq!(ss, test_vector.ss);
442        assert_eq!(&*ct, test_vector.ct.as_slice());
443
444        let ss = sk.decapsulate(&ct);
445        assert_eq!(ss, test_vector.ss);
446    }
447
448    #[test]
449    fn ciphertext_serialize() {
450        let mut rng = UnwrapErr(SysRng);
451
452        let ct_a = CiphertextMessage {
453            ct_m: Array::generate_from_rng(&mut rng),
454            ct_x: <[u8; 32]>::generate_from_rng(&mut rng).into(),
455        };
456
457        let bytes = ct_a.to_bytes();
458        let ct_b = CiphertextMessage::from(&bytes);
459
460        assert!(ct_a == ct_b);
461    }
462
463    #[test]
464    fn key_serialize() {
465        let sk = DecapsulationKey::generate_from_rng(&mut UnwrapErr(SysRng));
466        let pk = sk.encapsulator().clone();
467
468        let sk_bytes = sk.as_bytes();
469        let pk_bytes = pk.to_bytes();
470
471        let sk_b = DecapsulationKey::from(*sk_bytes);
472        let pk_b = EncapsulationKey::new(&pk_bytes).unwrap();
473
474        assert_eq!(sk.sk, sk_b.sk);
475        assert!(pk == pk_b);
476    }
477}