Skip to main content

x_wing/
lib.rs

1#![cfg_attr(not(test), no_std)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc = include_str!("../README.md")]
4#![doc(
5    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
6    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
7)]
8
9//! # Usage
10//!
//! This crate implements the X-Wing Key Encapsulation Method (X-Wing-KEM).
//! X-Wing-KEM is a KEM in the sense that it creates a (decapsulation key, encapsulation key) pair
//! such that anyone can use the encapsulation key to establish a shared key with the holder of the
//! decapsulation key. X-Wing-KEM is a general-purpose hybrid post-quantum KEM combining X25519 and
//! ML-KEM-768.
15#![cfg_attr(feature = "getrandom", doc = "```")]
16#![cfg_attr(not(feature = "getrandom"), doc = "```ignore")]
//! // NOTE: requires the `getrandom` feature to be enabled
18//! use x_wing::{
19//!     XWingKem,
20//!     kem::{Decapsulate, Encapsulate, Kem}
21//! };
22//!
23//! let (sk, pk) = XWingKem::generate_keypair();
24//! let (ct, sk_sender) = pk.encapsulate();
25//! let sk_receiver = sk.decapsulate(&ct);
26//! assert_eq!(sk_sender, sk_receiver);
27//! ```
28
29pub use kem::{
30    self, Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, Kem, Key, KeyExport,
31    KeyInit, KeySizeUser, TryKeyInit,
32};
33
34use core::fmt::{self, Debug};
35use ml_kem::{
36    FromSeed, MlKem768,
37    array::{
38        Array, ArrayN, AsArrayRef,
39        sizes::{U32, U1120, U1184, U1216},
40    },
41    ml_kem_768,
42};
43use rand_core::{CryptoRng, TryCryptoRng, TryRng};
44use sha3::{
45    Sha3_256, Shake256, Shake256Reader,
46    digest::{ExtendableOutput, XofReader},
47};
48use x25519_dalek::{PublicKey, StaticSecret};
49
50#[cfg(feature = "zeroize")]
51use zeroize::{Zeroize, ZeroizeOnDrop};
52
// Shorthand names for the ML-KEM-768 key types; these are the `_m` components used below.
type MlKem768DecapsulationKey = ml_kem_768::DecapsulationKey;
type MlKem768EncapsulationKey = ml_kem_768::EncapsulationKey;

/// Domain-separation label, hashed last by `combiner` after all key/ciphertext inputs.
const X_WING_LABEL: &[u8; 6] = br"\.//^\";
57
/// Size in bytes of the [`EncapsulationKey`]: a 1184-byte ML-KEM-768 encapsulation key
/// followed by a 32-byte X25519 public key.
pub const ENCAPSULATION_KEY_SIZE: usize = 1184 + 32;
/// Size in bytes of the [`DecapsulationKey`]: the 32-byte seed from which every component
/// secret key is derived.
pub const DECAPSULATION_KEY_SIZE: usize = 32;
/// Size in bytes of the [`Ciphertext`]: a 1088-byte ML-KEM-768 ciphertext followed by a
/// 32-byte X25519 ephemeral public key.
pub const CIPHERTEXT_SIZE: usize = 1088 + 32;
/// Number of random bytes consumed by one encapsulation: 32 for ML-KEM-768 plus 32 for X25519.
pub const ENCAPSULATION_RANDOMNESS_SIZE: usize = 32 + 32;
66
67/// Serialized ciphertext.
68pub type Ciphertext = kem::Ciphertext<XWingKem>;
69/// Shared secret key.
70pub type SharedKey = Array<u8, U32>;
71
72/// X-Wing Key Encapsulation Method (X-Wing-KEM).
73#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
74pub struct XWingKem;
75
76impl Kem for XWingKem {
77    type DecapsulationKey = DecapsulationKey;
78    type EncapsulationKey = EncapsulationKey;
79    type CiphertextSize = U1120;
80    type SharedKeySize = U32;
81}
82
// Variable names follow the X-Wing specification (draft-connolly-cfrg-xwing-kem):
// ss -> shared secret
// ct -> ciphertext
// ek -> ephemeral key
// pk -> public key
// sk -> secret key
// Suffixes:
// _m -> ML-KEM-768 component
// _x -> X25519 component
92
/// X-Wing encapsulation or public key.
///
/// Pairs an ML-KEM-768 encapsulation key with an X25519 public key. Its serialized form
/// (see the `KeyExport` impl) is `pk_m || pk_x`, [`ENCAPSULATION_KEY_SIZE`] bytes in total.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EncapsulationKey {
    /// ML-KEM-768 encapsulation key.
    pk_m: MlKem768EncapsulationKey,
    /// X25519 public key.
    pk_x: PublicKey,
}
99
100impl EncapsulationKey {
101    /// Encapsulates with the given randomness. Uses the first 32 bytes for ML-KEM and the last 32
102    /// bytes for x25519. This is useful for testing against known vectors.
103    ///
104    /// # Warning
105    /// Do NOT use this function unless you know what you're doing. If you fail to use all uniform
106    /// random bytes even once, you can have catastrophic security failure.
107    #[doc(hidden)]
108    #[cfg_attr(not(feature = "hazmat"), doc(hidden))]
109    #[expect(clippy::must_use_candidate)]
110    pub fn encapsulate_deterministic(
111        &self,
112        randomness: &ArrayN<u8, ENCAPSULATION_RANDOMNESS_SIZE>,
113    ) -> (Ciphertext, SharedKey) {
114        // Split randomness into two 32-byte arrays
115        let (rand_m, rand_x) = randomness.split::<U32>();
116
117        // Encapsulate with ML-KEM first. This is infallible
118        let (ct_m, ss_m) = self.pk_m.encapsulate_deterministic(&rand_m);
119
120        let ek_x = StaticSecret::from(rand_x.0);
121        // Equal to ct_x = x25519(ek_x, BASE_POINT)
122        let ct_x = PublicKey::from(&ek_x);
123        // Equal to ss_x = x25519(ek_x, pk_x)
124        let ss_x = ek_x.diffie_hellman(&self.pk_x);
125
126        let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x);
127        let ct = CiphertextMessage { ct_m, ct_x };
128
129        (ct.into(), ss)
130    }
131}
132
133impl Encapsulate for EncapsulationKey {
134    type Kem = XWingKem;
135
136    fn encapsulate_with_rng<R>(&self, rng: &mut R) -> (Ciphertext, SharedKey)
137    where
138        R: CryptoRng + ?Sized,
139    {
140        #[allow(unused_mut)]
141        let mut randomness = Array::generate_from_rng(rng);
142        let res = self.encapsulate_deterministic(&randomness);
143
144        #[cfg(feature = "zeroize")]
145        randomness.zeroize();
146
147        res
148    }
149}
150
151impl KeySizeUser for EncapsulationKey {
152    type KeySize = U1216;
153}
154
155impl KeyExport for EncapsulationKey {
156    fn to_bytes(&self) -> Key<Self> {
157        let mut key_bytes = Key::<Self>::default();
158        let (m, x) = key_bytes.split_at_mut(1184);
159        m.copy_from_slice(&self.pk_m.to_bytes());
160        x.copy_from_slice(self.pk_x.as_bytes());
161        key_bytes
162    }
163}
164
165impl TryKeyInit for EncapsulationKey {
166    fn new(key_bytes: &Key<Self>) -> Result<Self, InvalidKey> {
167        let (m_bytes, x_bytes) = key_bytes.split_ref::<U1184>();
168
169        let pk_m = MlKem768EncapsulationKey::new(m_bytes)?;
170        let pk_x = PublicKey::from(x_bytes.0);
171
172        Ok(EncapsulationKey { pk_m, pk_x })
173    }
174}
175
176impl TryFrom<&[u8]> for EncapsulationKey {
177    type Error = InvalidKey;
178
179    fn try_from(key_bytes: &[u8]) -> Result<Self, InvalidKey> {
180        Self::new_from_slice(key_bytes)
181    }
182}
183
/// X-Wing decapsulation key or private key.
///
/// Only the 32-byte seed `sk` is stored. The ML-KEM-768 and X25519 secret keys are re-derived
/// from it (via `expand_key`) on each decapsulation, while the matching encapsulation key is kept
/// in `ek`.
#[derive(Clone)]
pub struct DecapsulationKey {
    /// Secret seed; expanded with SHAKE256 into the component secret keys.
    sk: [u8; DECAPSULATION_KEY_SIZE],
    /// Encapsulation key derived from `sk` at construction time.
    ek: EncapsulationKey,
}
190
191impl DecapsulationKey {
192    /// Private key as bytes.
193    #[must_use]
194    pub fn as_bytes(&self) -> &[u8; DECAPSULATION_KEY_SIZE] {
195        &self.sk
196    }
197}
198
199impl Debug for DecapsulationKey {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        f.debug_struct("DecapsulationKey")
202            .field("ek", &self.ek)
203            .finish_non_exhaustive()
204    }
205}
206
207impl Decapsulate for DecapsulationKey {
208    #[allow(clippy::similar_names)] // So we can use the names as in the RFC
209    fn decapsulate(&self, ct: &Ciphertext) -> SharedKey {
210        let ct = CiphertextMessage::from(ct);
211        let (sk_m, sk_x, _pk_m, pk_x) = expand_key(&self.sk);
212
213        let ss_m = sk_m.decapsulate(&ct.ct_m);
214
215        // equal to ss_x = x25519(sk_x, ct_x)
216        let ss_x = sk_x.diffie_hellman(&ct.ct_x);
217
218        combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x)
219    }
220}
221
222impl Decapsulator for DecapsulationKey {
223    type Kem = XWingKem;
224
225    fn encapsulation_key(&self) -> &EncapsulationKey {
226        &self.ek
227    }
228}
229
230impl Drop for DecapsulationKey {
231    fn drop(&mut self) {
232        #[cfg(feature = "zeroize")]
233        self.sk.zeroize();
234    }
235}
236
237impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
238    fn from(sk: [u8; DECAPSULATION_KEY_SIZE]) -> Self {
239        DecapsulationKey::new(sk.as_array_ref())
240    }
241}
242
243impl Generate for DecapsulationKey {
244    fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
245    where
246        R: TryCryptoRng + ?Sized,
247    {
248        <[u8; DECAPSULATION_KEY_SIZE]>::try_generate_from_rng(rng).map(Into::into)
249    }
250}
251
252impl KeySizeUser for DecapsulationKey {
253    type KeySize = U32;
254}
255
256impl KeyInit for DecapsulationKey {
257    fn new(key: &Key<Self>) -> Self {
258        let (_sk_m, _sk_x, pk_m, pk_x) = expand_key(key.as_ref());
259        let ek = EncapsulationKey { pk_m, pk_x };
260        Self { sk: key.0, ek }
261    }
262}
263
264impl KeyExport for DecapsulationKey {
265    fn to_bytes(&self) -> Key<Self> {
266        self.sk.into()
267    }
268}
269
270#[cfg(feature = "zeroize")]
271impl ZeroizeOnDrop for DecapsulationKey {}
272
273fn expand_key(
274    sk: &[u8; DECAPSULATION_KEY_SIZE],
275) -> (
276    MlKem768DecapsulationKey,
277    StaticSecret,
278    MlKem768EncapsulationKey,
279    PublicKey,
280) {
281    use sha3::digest::Update;
282    let mut shaker = Shake256::default();
283    shaker.update(sk);
284    let mut expanded: Shake256Reader = shaker.finalize_xof();
285
286    let seed = read_from(&mut expanded).into();
287    let (sk_m, pk_m) = MlKem768::from_seed(&seed);
288
289    let sk_x = read_from(&mut expanded);
290    let sk_x = StaticSecret::from(sk_x);
291    let pk_x = PublicKey::from(&sk_x);
292
293    (sk_m, sk_x, pk_m, pk_x)
294}
295
/// X-Wing ciphertext.
///
/// Parsed form of a [`Ciphertext`]: the ML-KEM-768 ciphertext followed by the X25519 ephemeral
/// public key.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CiphertextMessage {
    /// ML-KEM-768 ciphertext (1088 bytes).
    ct_m: ArrayN<u8, 1088>,
    /// X25519 ephemeral public key, `x25519(ek_x, BASE_POINT)`.
    ct_x: PublicKey,
}
302
303impl CiphertextMessage {
304    /// Convert the ciphertext to the following format:
305    /// ML-KEM-768 ciphertext(1088 bytes) || X25519 ciphertext(32 bytes).
306    #[must_use]
307    pub fn to_bytes(&self) -> Ciphertext {
308        let mut buffer = Ciphertext::default();
309        buffer[0..1088].copy_from_slice(&self.ct_m);
310        buffer[1088..].copy_from_slice(self.ct_x.as_bytes());
311        buffer
312    }
313}
314
315impl From<&Ciphertext> for CiphertextMessage {
316    fn from(value: &Ciphertext) -> Self {
317        let mut ct_m = [0; 1088];
318        ct_m.copy_from_slice(&value[0..1088]);
319        let mut ct_x = [0; 32];
320        ct_x.copy_from_slice(&value[1088..]);
321
322        CiphertextMessage {
323            ct_m: ct_m.into(),
324            ct_x: ct_x.into(),
325        }
326    }
327}
328
329impl From<&CiphertextMessage> for Ciphertext {
330    #[inline]
331    fn from(msg: &CiphertextMessage) -> Self {
332        msg.to_bytes()
333    }
334}
335
336impl From<CiphertextMessage> for Ciphertext {
337    #[inline]
338    fn from(msg: CiphertextMessage) -> Self {
339        Self::from(&msg)
340    }
341}
342
343fn combiner(
344    ss_m: &ArrayN<u8, 32>,
345    ss_x: &x25519_dalek::SharedSecret,
346    ct_x: &PublicKey,
347    pk_x: &PublicKey,
348) -> SharedKey {
349    use sha3::Digest;
350
351    let mut hasher = Sha3_256::new();
352    hasher.update(ss_m);
353    hasher.update(ss_x);
354    hasher.update(ct_x);
355    hasher.update(pk_x.as_bytes());
356    hasher.update(X_WING_LABEL);
357    hasher.finalize()
358}
359
360fn read_from<const N: usize>(reader: &mut Shake256Reader) -> [u8; N] {
361    let mut data = [0; N];
362    reader.read(&mut data);
363    data
364}
365
#[cfg(test)]
mod tests {
    use crate::{Kem, XWingKem};
    use core::convert::Infallible;
    use getrandom::SysRng;
    use ml_kem::array::Array;
    use rand_core::{TryCryptoRng, TryRng, UnwrapErr, utils};
    use serde::Deserialize;

    use super::*;

    /// Deterministic byte source that replays a fixed seed, used to reproduce the test vectors.
    ///
    /// Bytes are handed out front to back. Requesting more bytes than remain panics (slice index
    /// out of range), so a mismatch between a vector and its consumer fails loudly.
    pub(crate) struct SeedRng {
        pub(crate) seed: Vec<u8>,
    }

    impl SeedRng {
        fn new(seed: Vec<u8>) -> SeedRng {
            SeedRng { seed }
        }
    }

    impl TryRng for SeedRng {
        type Error = Infallible;

        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            utils::next_word_via_fill(self)
        }

        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            utils::next_word_via_fill(self)
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
            // Hand out the next `dest.len()` seed bytes, then drop them from the front.
            dest.copy_from_slice(&self.seed[0..dest.len()]);
            self.seed.drain(0..dest.len());
            Ok(())
        }
    }

    // Lets `SeedRng` be passed where a cryptographic RNG is required (e.g.
    // `encapsulate_with_rng`). It is test-only: replaying a known seed is not secure.
    impl TryCryptoRng for SeedRng {}

    /// One entry of the X-Wing test-vector file; every field is hex-encoded.
    #[derive(Deserialize)]
    struct TestVector {
        /// Randomness consumed by key generation.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        seed: Vec<u8>,

        /// Randomness consumed by encapsulation.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        eseed: Vec<u8>,

        /// Expected shared secret.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ss: [u8; 32],

        /// Expected decapsulation key (32-byte seed).
        #[serde(deserialize_with = "hex::serde::deserialize")]
        sk: [u8; 32],

        /// Expected serialized encapsulation key (`ENCAPSULATION_KEY_SIZE` bytes).
        #[serde(deserialize_with = "hex::serde::deserialize")]
        pk: Vec<u8>,

        /// Expected serialized ciphertext (`CIPHERTEXT_SIZE` bytes).
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ct: Vec<u8>,
    }

    /// Test with test vectors from: <https://github.com/dconnolly/draft-connolly-cfrg-xwing-kem/blob/main/spec/test-vectors.json>
    #[test]
    fn rfc_test_vectors() {
        let test_vectors =
            serde_json::from_str::<Vec<TestVector>>(include_str!("test-vectors.json")).unwrap();

        for test_vector in test_vectors {
            run_test(test_vector);
        }
    }

    /// Checks key generation, encapsulation and decapsulation against one vector.
    fn run_test(test_vector: TestVector) {
        let mut seed = SeedRng::new(test_vector.seed);
        let (sk, pk) = XWingKem::generate_keypair_from_rng(&mut seed);

        assert_eq!(sk.as_bytes(), &test_vector.sk);
        assert_eq!(&*pk.to_bytes(), test_vector.pk.as_slice());

        let mut eseed = SeedRng::new(test_vector.eseed);
        let (ct, ss) = pk.encapsulate_with_rng(&mut eseed);

        assert_eq!(ss, test_vector.ss);
        assert_eq!(&*ct, test_vector.ct.as_slice());

        let ss = sk.decapsulate(&ct);
        assert_eq!(ss, test_vector.ss);
    }

    #[test]
    fn ciphertext_serialize() {
        let mut rng = UnwrapErr(SysRng);

        let ct_a = CiphertextMessage {
            ct_m: Array::generate_from_rng(&mut rng),
            ct_x: <[u8; 32]>::generate_from_rng(&mut rng).into(),
        };

        let bytes = ct_a.to_bytes();
        let ct_b = CiphertextMessage::from(&bytes);

        // `assert_eq!` prints both values on failure, unlike `assert!(a == b)`.
        assert_eq!(ct_a, ct_b);
    }

    #[test]
    #[cfg(feature = "getrandom")]
    fn key_serialize() {
        let (sk, pk) = XWingKem::generate_keypair();

        let sk_bytes = sk.as_bytes();
        let pk_bytes = pk.to_bytes();

        let sk_b = DecapsulationKey::from(*sk_bytes);
        let pk_b = EncapsulationKey::new(&pk_bytes).unwrap();

        assert_eq!(sk.sk, sk_b.sk);
        assert_eq!(pk, pk_b);
    }
}