libcrux_kem/
kem.rs

1//! # Key Encapsulation Mechanism
2//!
3//! A KEM interface.
4//!
5//! For ECDH structs, check the [`libcrux_ecdh`] crate.
6//!
7//! Available algorithms:
8//! * [`Algorithm::X25519`]\: x25519 ECDH KEM. Also see [`libcrux::ecdh#x25519`].
9//! * [`Algorithm::Secp256r1`]\: NIST P256 ECDH KEM. Also see [`libcrux::ecdh#P256`].
10//! * [`Algorithm::MlKem512`]\: ML-KEM 512 from [FIPS 203].
11//! * [`Algorithm::MlKem768`]\: ML-KEM 768 from [FIPS 203].
12//! * [`Algorithm::MlKem1024`]\: ML-KEM 1024 from [FIPS 203].
13//! * [`Algorithm::X25519MlKem768Draft00`]\: Hybrid x25519 - ML-KEM 768 [draft kem for hpke](https://www.ietf.org/archive/id/draft-westerbaan-cfrg-hpke-xyber768d00-00.html).
14//! * [`Algorithm::XWingKemDraft06`]\: Hybrid x25519 - ML-KEM 768 [draft xwing kem for hpke](https://www.ietf.org/archive/id/draft-connolly-cfrg-xwing-kem-06.html).
15//!
16//! ```
17//! use libcrux_kem::*;
18//! use rand::TryRngCore;
19//! use rand::rngs::OsRng;
20//!
21//! let mut os_rng = OsRng;
22//! let mut rng = os_rng.unwrap_mut();
23//!
24//! let (sk_a, pk_a) = key_gen(Algorithm::MlKem768, &mut rng).unwrap();
25//! let received_pk = pk_a.encode();
26//!
27//! let pk = PublicKey::decode(Algorithm::MlKem768, &received_pk).unwrap();
28//! let (ss_b, ct_b) = pk.encapsulate(&mut rng).unwrap();
29//! let received_ct = ct_b.encode();
30//!
31//! let ct_a = Ct::decode(Algorithm::MlKem768, &received_ct).unwrap();
32//! let ss_a = ct_a.decapsulate(&sk_a).unwrap();
33//! assert_eq!(ss_b.encode(), ss_a.encode());
34//! ```
35//!
36//! [FIPS 203]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.ipd.pdf
37#![no_std]
38
39extern crate alloc;
40
41use alloc::vec::Vec;
42
43use rand::{CryptoRng, TryRngCore};
44
45use libcrux_ecdh::{p256_derive, p256_secret_to_public, x25519_derive, x25519_secret_to_public};
46use libcrux_ecdh::{
47    P256PrivateKey, P256PublicKey, P256SharedSecret, X25519PrivateKey, X25519PublicKey,
48    X25519SharedSecret,
49};
50use libcrux_sha3 as sha3;
51
52use libcrux_ml_kem::{mlkem1024, mlkem512, mlkem768};
53
// TODO: The derandomized ML-KEM functions re-exported below are only public
// so that NIST KAT testing is possible without an implementation of the NIST
// AES-CTR DRBG. Remove them (and change the visibility of the exported
// functions to pub(crate)) as soon as such a DRBG is available.
// This is tracked by: https://github.com/cryspen/libcrux/issues/36
#[cfg(feature = "tests")]
pub mod deterministic {
    pub use libcrux_ml_kem::{
        mlkem1024::{
            decapsulate as mlkem1024_decapsulate_derand,
            encapsulate as mlkem1024_encapsulate_derand,
            generate_key_pair as mlkem1024_generate_keypair_derand,
        },
        mlkem512::{
            decapsulate as mlkem512_decapsulate_derand,
            encapsulate as mlkem512_encapsulate_derand,
            generate_key_pair as mlkem512_generate_keypair_derand,
        },
        mlkem768::{
            decapsulate as mlkem768_decapsulate_derand,
            encapsulate as mlkem768_encapsulate_derand,
            generate_key_pair as mlkem768_generate_keypair_derand,
        },
    };
}
71
72use libcrux_ml_kem::MlKemSharedSecret;
73pub use libcrux_ml_kem::{
74    mlkem1024::{MlKem1024Ciphertext, MlKem1024PrivateKey, MlKem1024PublicKey},
75    mlkem512::{MlKem512Ciphertext, MlKem512PrivateKey, MlKem512PublicKey},
76    mlkem768::{MlKem768Ciphertext, MlKem768PrivateKey, MlKem768PublicKey},
77    MlKemCiphertext, MlKemKeyPair,
78};
79
80#[cfg(feature = "tests")]
81pub use libcrux_ml_kem::{
82    mlkem1024::validate_public_key as ml_kem1024_validate_public_key,
83    mlkem512::validate_public_key as ml_kem512_validate_public_key,
84    mlkem768::validate_public_key as ml_kem768_validate_public_key,
85};
86use xwing::XWingSharedSecret;
87
/// KEM Algorithms
///
/// This includes named elliptic curves or dedicated KEM algorithms like ML-KEM.
///
/// Not every variant is usable with every operation in this module: the key
/// generation and decoding functions report [`Error::UnsupportedAlgorithm`]
/// for [`Algorithm::X448`], [`Algorithm::Secp384r1`] and
/// [`Algorithm::Secp521r1`].
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Algorithm {
    /// ECDH over x25519, used as a KEM.
    X25519,
    /// ECDH over x448 (identifier only).
    X448,
    /// ECDH over NIST P-256, used as a KEM.
    Secp256r1,
    /// ECDH over NIST P-384 (identifier only).
    Secp384r1,
    /// ECDH over NIST P-521 (identifier only).
    Secp521r1,
    /// ML-KEM 512.
    MlKem512,
    /// ML-KEM 768.
    MlKem768,
    /// Hybrid of x25519 and ML-KEM 768 (draft 00 of the HPKE hybrid KEM).
    X25519MlKem768Draft00,
    /// X-Wing (draft 06): hybrid of x25519 and ML-KEM 768.
    XWingKemDraft06,
    /// ML-KEM 1024.
    MlKem1024,
}
104
105#[derive(Debug, PartialEq, Eq)]
106pub enum Error {
107    EcDhError(libcrux_ecdh::Error),
108    KeyGen,
109    Encapsulate,
110    Decapsulate,
111    UnsupportedAlgorithm,
112    InvalidPrivateKey,
113    InvalidPublicKey,
114    InvalidCiphertext,
115}
116
117impl TryFrom<Algorithm> for libcrux_ecdh::Algorithm {
118    type Error = &'static str;
119
120    fn try_from(value: Algorithm) -> Result<Self, Self::Error> {
121        match value {
122            Algorithm::X25519 => Ok(libcrux_ecdh::Algorithm::X25519),
123            Algorithm::X448 => Ok(libcrux_ecdh::Algorithm::X448),
124            Algorithm::Secp256r1 => Ok(libcrux_ecdh::Algorithm::P256),
125            Algorithm::Secp384r1 => Ok(libcrux_ecdh::Algorithm::P384),
126            Algorithm::Secp521r1 => Ok(libcrux_ecdh::Algorithm::P521),
127            Algorithm::X25519MlKem768Draft00 => Ok(libcrux_ecdh::Algorithm::X25519),
128            Algorithm::XWingKemDraft06 => Ok(libcrux_ecdh::Algorithm::X25519),
129            _ => Err("provided algorithm is not an ECDH algorithm"),
130        }
131    }
132}
133
134impl From<libcrux_ecdh::Error> for Error {
135    fn from(value: libcrux_ecdh::Error) -> Self {
136        Error::EcDhError(value)
137    }
138}
139
140/// An ML-KEM768-x25519 private key.
141pub struct X25519MlKem768Draft00PrivateKey {
142    pub mlkem: MlKem768PrivateKey,
143    pub x25519: X25519PrivateKey,
144}
145
146impl X25519MlKem768Draft00PrivateKey {
147    pub fn decode(bytes: &[u8]) -> Result<Self, Error> {
148        Ok(Self {
149            mlkem: bytes[..2400]
150                .try_into()
151                .map_err(|_| Error::InvalidPrivateKey)?,
152            x25519: bytes[2400..]
153                .try_into()
154                .map_err(|_| Error::InvalidPrivateKey)?,
155        })
156    }
157
158    pub fn encode(&self) -> Vec<u8> {
159        let mut out = self.mlkem.as_ref().to_vec();
160        out.extend_from_slice(&self.x25519.0);
161        out
162    }
163}
164
/// An X-Wing private key.
///
/// Only the seed is stored; decapsulation expands it into the actual key
/// material.
pub struct XWingKemDraft06PrivateKey {
    /// The 32 byte seed of the decapsulation key.
    pub seed: [u8; 32],
}
169
170impl XWingKemDraft06PrivateKey {
171    pub fn decode(bytes: &[u8]) -> Result<Self, Error> {
172        Ok(Self {
173            seed: bytes.try_into().map_err(|_| Error::InvalidPrivateKey)?,
174        })
175    }
176
177    pub fn encode(&self) -> Vec<u8> {
178        self.seed.as_ref().to_vec()
179    }
180}
181
182/// A KEM private key.
183pub enum PrivateKey {
184    X25519(X25519PrivateKey),
185    P256(P256PrivateKey),
186    MlKem512(MlKem512PrivateKey),
187    MlKem768(MlKem768PrivateKey),
188    X25519MlKem768Draft00(X25519MlKem768Draft00PrivateKey),
189    XWingKemDraft06(XWingKemDraft06PrivateKey),
190    MlKem1024(MlKem1024PrivateKey),
191}
192
193/// An ML-KEM768-x25519 public key.
194pub struct X25519MlKem768Draft00PublicKey {
195    pub mlkem: MlKem768PublicKey,
196    pub x25519: X25519PublicKey,
197}
198
199impl X25519MlKem768Draft00PublicKey {
200    pub fn decode(bytes: &[u8]) -> Result<Self, Error> {
201        Ok(Self {
202            mlkem: {
203                let key = MlKem768PublicKey::try_from(&bytes[..1184])
204                    .map_err(|_| Error::InvalidPublicKey)?;
205                if !mlkem768::validate_public_key(&key) {
206                    return Err(Error::InvalidPublicKey);
207                }
208                key
209            },
210            x25519: bytes[1184..]
211                .try_into()
212                .map_err(|_| Error::InvalidPublicKey)?,
213        })
214    }
215
216    pub fn encode(&self) -> Vec<u8> {
217        let mut out = self.mlkem.as_ref().to_vec();
218        out.extend_from_slice(&self.x25519.0);
219        out
220    }
221}
222
223/// An X-Wing public key.
224pub struct XWingKemDraft06PublicKey {
225    pub pk_m: MlKem768PublicKey,
226    pub pk_x: X25519PublicKey,
227}
228
229impl XWingKemDraft06PublicKey {
230    pub fn decode(bytes: &[u8]) -> Result<Self, Error> {
231        Ok(Self {
232            pk_m: {
233                let key = MlKem768PublicKey::try_from(&bytes[0..1184])
234                    .map_err(|_| Error::InvalidPublicKey)?;
235                if !mlkem768::validate_public_key(&key) {
236                    return Err(Error::InvalidPublicKey);
237                }
238                key
239            },
240            pk_x: bytes[1184..]
241                .try_into()
242                .map_err(|_| Error::InvalidPublicKey)?,
243        })
244    }
245
246    pub fn encode(&self) -> Vec<u8> {
247        let mut out = self.pk_m.as_ref().to_vec();
248        out.extend_from_slice(self.pk_x.0.to_vec().as_ref());
249        out
250    }
251}
252
253/// A KEM public key.
254pub enum PublicKey {
255    X25519(X25519PublicKey),
256    P256(P256PublicKey),
257    MlKem512(MlKem512PublicKey),
258    MlKem768(MlKem768PublicKey),
259    X25519MlKem768Draft00(X25519MlKem768Draft00PublicKey),
260    XWingKemDraft06(XWingKemDraft06PublicKey),
261    MlKem1024(MlKem1024PublicKey),
262}
263
264/// A KEM ciphertext
265pub enum Ct {
266    X25519(X25519PublicKey),
267    P256(P256PublicKey),
268    MlKem512(MlKem512Ciphertext),
269    MlKem768(MlKem768Ciphertext),
270    X25519MlKem768Draft00(MlKem768Ciphertext, X25519PublicKey),
271    XWingKemDraft06(MlKem768Ciphertext, X25519PublicKey),
272    MlKem1024(MlKem1024Ciphertext),
273}
274
275impl Ct {
276    /// Decapsulate the shared secret in `ct` using the private key `sk`.
277    pub fn decapsulate(&self, sk: &PrivateKey) -> Result<Ss, Error> {
278        match self {
279            Ct::X25519(ct) => {
280                let sk = if let PrivateKey::X25519(k) = sk {
281                    k
282                } else {
283                    return Err(Error::InvalidPrivateKey);
284                };
285                x25519_derive(ct, sk).map_err(|e| e.into()).map(Ss::X25519)
286            }
287            Ct::P256(ct) => {
288                let sk = if let PrivateKey::P256(k) = sk {
289                    k
290                } else {
291                    return Err(Error::InvalidPrivateKey);
292                };
293                p256_derive(ct, sk).map_err(|e| e.into()).map(Ss::P256)
294            }
295            Ct::MlKem512(ct) => {
296                let sk = if let PrivateKey::MlKem512(k) = sk {
297                    k
298                } else {
299                    return Err(Error::InvalidPrivateKey);
300                };
301                let ss = libcrux_ml_kem::mlkem512::decapsulate(sk, ct);
302
303                Ok(Ss::MlKem768(ss))
304            }
305            Ct::MlKem768(ct) => {
306                let sk = if let PrivateKey::MlKem768(k) = sk {
307                    k
308                } else {
309                    return Err(Error::InvalidPrivateKey);
310                };
311                let ss = mlkem768::decapsulate(sk, ct);
312
313                Ok(Ss::MlKem768(ss))
314            }
315            Ct::X25519MlKem768Draft00(kct, xct) => {
316                let (ksk, xsk) =
317                    if let PrivateKey::X25519MlKem768Draft00(X25519MlKem768Draft00PrivateKey {
318                        mlkem: kk,
319                        x25519: xk,
320                    }) = sk
321                    {
322                        (kk, xk)
323                    } else {
324                        return Err(Error::InvalidPrivateKey);
325                    };
326                let kss = mlkem768::decapsulate(ksk, kct);
327                let xss = x25519_derive(xct, xsk)?;
328
329                Ok(Ss::X25519MlKem768Draft00(kss, xss))
330            }
331
332            Ct::XWingKemDraft06(ct_m, ct_x) => {
333                let seed =
334                    if let PrivateKey::XWingKemDraft06(XWingKemDraft06PrivateKey { seed }) = sk {
335                        seed
336                    } else {
337                        return Err(Error::InvalidPrivateKey);
338                    };
339
340                let (kp_m, pk_x, sk_x) = xwing::expand_decap_key(seed)?;
341
342                let ss_m = mlkem768::decapsulate(kp_m.private_key(), ct_m);
343                let ss_x = x25519_derive(ct_x, &sk_x)?;
344
345                Ok(Ss::XWingKemDraft06(xwing::combiner(
346                    &ss_m,
347                    ss_x.as_ref(),
348                    ct_x.as_ref(),
349                    pk_x.as_ref(),
350                )))
351            }
352
353            Ct::MlKem1024(ct) => {
354                let sk = if let PrivateKey::MlKem1024(k) = sk {
355                    k
356                } else {
357                    return Err(Error::InvalidPrivateKey);
358                };
359                let ss = libcrux_ml_kem::mlkem1024::decapsulate(sk, ct);
360
361                Ok(Ss::MlKem1024(ss))
362            }
363        }
364    }
365}
366
367/// A KEM shared secret
368pub enum Ss {
369    X25519(X25519SharedSecret),
370    P256(P256SharedSecret),
371    MlKem512(MlKemSharedSecret),
372    MlKem768(MlKemSharedSecret),
373    X25519MlKem768Draft00(MlKemSharedSecret, X25519SharedSecret),
374    XWingKemDraft06(XWingSharedSecret),
375    MlKem1024(MlKemSharedSecret),
376}
377
378impl PrivateKey {
379    /// Encode a private key.
380    pub fn encode(&self) -> Vec<u8> {
381        match self {
382            PrivateKey::X25519(k) => k.0.to_vec(),
383            PrivateKey::P256(k) => k.0.to_vec(),
384            PrivateKey::MlKem512(k) => k.as_slice().to_vec(),
385            PrivateKey::MlKem768(k) => k.as_slice().to_vec(),
386            PrivateKey::X25519MlKem768Draft00(k) => k.encode(),
387            PrivateKey::XWingKemDraft06(k) => k.encode(),
388            PrivateKey::MlKem1024(k) => k.as_slice().to_vec(),
389        }
390    }
391
392    /// Decode a private key.
393    pub fn decode(alg: Algorithm, bytes: &[u8]) -> Result<Self, Error> {
394        match alg {
395            Algorithm::X25519 => bytes
396                .try_into()
397                .map_err(|_| Error::InvalidPrivateKey)
398                .map(Self::X25519),
399            Algorithm::Secp256r1 => bytes
400                .try_into()
401                .map_err(|_| Error::InvalidPrivateKey)
402                .map(Self::P256),
403            Algorithm::MlKem512 => bytes
404                .try_into()
405                .map_err(|_| Error::InvalidPrivateKey)
406                .map(Self::MlKem512),
407            Algorithm::MlKem768 => bytes
408                .try_into()
409                .map_err(|_| Error::InvalidPrivateKey)
410                .map(Self::MlKem768),
411            Algorithm::X25519MlKem768Draft00 => X25519MlKem768Draft00PrivateKey::decode(bytes)
412                .map_err(|_| Error::InvalidPrivateKey)
413                .map(Self::X25519MlKem768Draft00),
414            Algorithm::XWingKemDraft06 => {
415                let pk = XWingKemDraft06PrivateKey::decode(bytes)
416                    .map_err(|_| Error::InvalidPrivateKey)?;
417                Ok(Self::XWingKemDraft06(pk))
418            }
419
420            Algorithm::MlKem1024 => bytes
421                .try_into()
422                .map_err(|_| Error::InvalidPrivateKey)
423                .map(Self::MlKem1024),
424            _ => Err(Error::UnsupportedAlgorithm),
425        }
426    }
427}
428
429impl PublicKey {
430    /// Encapsulate a shared secret to the provided `pk` and return the `(Key, Enc)` tuple.
431    pub fn encapsulate(&self, rng: &mut impl CryptoRng) -> Result<(Ss, Ct), Error> {
432        match self {
433            PublicKey::X25519(pk) => {
434                let (new_sk, new_pk) = libcrux_ecdh::x25519_key_gen(rng)?;
435                let gxy = x25519_derive(pk, &new_sk)?;
436                Ok((Ss::X25519(gxy), Ct::X25519(new_pk)))
437            }
438            PublicKey::P256(pk) => {
439                let (new_sk, new_pk) = libcrux_ecdh::p256_key_gen(rng)?;
440                let gxy = p256_derive(pk, &new_sk)?;
441                Ok((Ss::P256(gxy), Ct::P256(new_pk)))
442            }
443
444            PublicKey::MlKem512(pk) => {
445                let seed = mlkem_rand(rng)?;
446                let (ct, ss) = libcrux_ml_kem::mlkem512::encapsulate(pk, seed);
447                Ok((Ss::MlKem512(ss), Ct::MlKem512(ct)))
448            }
449
450            PublicKey::MlKem768(pk) => {
451                let seed = mlkem_rand(rng)?;
452                let (ct, ss) = mlkem768::encapsulate(pk, seed);
453                Ok((Ss::MlKem768(ss), Ct::MlKem768(ct)))
454            }
455
456            PublicKey::MlKem1024(pk) => {
457                let seed = mlkem_rand(rng)?;
458                let (ct, ss) = mlkem1024::encapsulate(pk, seed);
459                Ok((Ss::MlKem1024(ss), Ct::MlKem1024(ct)))
460            }
461
462            PublicKey::X25519MlKem768Draft00(X25519MlKem768Draft00PublicKey {
463                mlkem: kpk,
464                x25519: xpk,
465            }) => {
466                let seed = mlkem_rand(rng)?;
467                let (mlkem_ct, mlkem_ss) = mlkem768::encapsulate(kpk, seed);
468                let (x_sk, x_pk) = libcrux_ecdh::x25519_key_gen(rng)?;
469                let x_ss = x25519_derive(xpk, &x_sk)?;
470
471                Ok((
472                    Ss::X25519MlKem768Draft00(mlkem_ss, x_ss),
473                    Ct::X25519MlKem768Draft00(mlkem_ct, x_pk),
474                ))
475            }
476
477            PublicKey::XWingKemDraft06(XWingKemDraft06PublicKey { pk_m, pk_x }) => {
478                let seed = mlkem_rand(rng)?;
479                let (ct_m, ss_m) = mlkem768::encapsulate(pk_m, seed);
480                let (ek_x, ct_x) = libcrux_ecdh::x25519_key_gen(rng)?;
481                let ss_x = x25519_derive(pk_x, &ek_x)?;
482
483                Ok((
484                    Ss::XWingKemDraft06(xwing::combiner(
485                        &ss_m,
486                        ss_x.as_ref(),
487                        ct_x.as_ref(),
488                        pk_x.as_ref(),
489                    )),
490                    Ct::XWingKemDraft06(ct_m, X25519PublicKey(ct_x.0)),
491                ))
492            }
493        }
494    }
495
496    /// Encapsulate a shared secret to the provided `pk` and return the `(Key, Enc)` tuple.
497    pub fn encapsulate_derand(&self, seed: &[u8]) -> Result<(Ss, Ct), Error> {
498        match self {
499            PublicKey::X25519(pk) => {
500                let new_sk = X25519PrivateKey::try_from(seed)?; // clamps
501                let new_pk = x25519_secret_to_public(&new_sk)?;
502                let gxy = x25519_derive(pk, &new_sk)?;
503                Ok((Ss::X25519(gxy), Ct::X25519(new_pk.try_into().unwrap())))
504            }
505
506            PublicKey::P256(pk) => {
507                let new_sk = P256PrivateKey::try_from(seed)?;
508                let new_pk = p256_secret_to_public(&new_sk)?;
509
510                let gxy = p256_derive(pk, &new_sk)?;
511                Ok((Ss::P256(gxy), Ct::P256(new_pk)))
512            }
513
514            PublicKey::MlKem512(pk) => {
515                let (ct, ss) = libcrux_ml_kem::mlkem512::encapsulate(
516                    pk,
517                    seed.try_into().map_err(|_| Error::KeyGen)?,
518                );
519                Ok((Ss::MlKem512(ss), Ct::MlKem512(ct)))
520            }
521
522            PublicKey::MlKem768(pk) => {
523                let (ct, ss) =
524                    mlkem768::encapsulate(pk, seed.try_into().map_err(|_| Error::KeyGen)?);
525                Ok((Ss::MlKem768(ss), Ct::MlKem768(ct)))
526            }
527
528            PublicKey::MlKem1024(pk) => {
529                let (ct, ss) =
530                    mlkem1024::encapsulate(pk, seed.try_into().map_err(|_| Error::KeyGen)?);
531                Ok((Ss::MlKem1024(ss), Ct::MlKem1024(ct)))
532            }
533
534            PublicKey::X25519MlKem768Draft00(X25519MlKem768Draft00PublicKey {
535                mlkem: kpk,
536                x25519: xpk,
537            }) => {
538                // seed = mlkem_seed || x_sk
539                let (mlkem_ct, mlkem_ss) =
540                    mlkem768::encapsulate(kpk, seed[0..32].try_into().map_err(|_| Error::KeyGen)?);
541
542                let x_sk = X25519PrivateKey::try_from(&seed[32..])?; // clamps
543                let x_pk = x25519_secret_to_public(&x_sk)?;
544
545                let x_ss = x25519_derive(xpk, &x_sk)?;
546
547                Ok((
548                    Ss::X25519MlKem768Draft00(mlkem_ss, x_ss),
549                    Ct::X25519MlKem768Draft00(mlkem_ct, x_pk),
550                ))
551            }
552
553            PublicKey::XWingKemDraft06(XWingKemDraft06PublicKey { pk_m, pk_x }) => {
554                let (ct_m, ss_m) =
555                    mlkem768::encapsulate(pk_m, seed[0..32].try_into().map_err(|_| Error::KeyGen)?);
556
557                let ek_x = X25519PrivateKey::try_from(&seed[32..])?; // clamps
558                let ct_x = x25519_secret_to_public(&ek_x)?;
559
560                let ss_x = x25519_derive(pk_x, &ek_x)?;
561
562                Ok((
563                    Ss::XWingKemDraft06(xwing::combiner(
564                        &ss_m,
565                        ss_x.as_ref(),
566                        ct_x.as_ref(),
567                        pk_x.as_ref(),
568                    )),
569                    Ct::XWingKemDraft06(ct_m, X25519PublicKey(ct_x.0)),
570                ))
571            }
572        }
573    }
574
575    /// Encode public key.
576    pub fn encode(&self) -> Vec<u8> {
577        match self {
578            PublicKey::X25519(k) => k.0.to_vec(),
579            PublicKey::P256(k) => k.0.to_vec(),
580            PublicKey::MlKem512(k) => k.as_ref().to_vec(),
581            PublicKey::MlKem768(k) => k.as_ref().to_vec(),
582            PublicKey::X25519MlKem768Draft00(k) => k.encode(),
583            PublicKey::XWingKemDraft06(k) => k.encode(),
584            PublicKey::MlKem1024(k) => k.as_ref().to_vec(),
585        }
586    }
587
588    /// Decode a public key.
589    pub fn decode(alg: Algorithm, bytes: &[u8]) -> Result<Self, Error> {
590        match alg {
591            Algorithm::X25519 => bytes
592                .try_into()
593                .map(Self::X25519)
594                .map_err(|_| Error::InvalidPublicKey),
595            Algorithm::Secp256r1 => bytes
596                .try_into()
597                .map(Self::P256)
598                .map_err(|_| Error::InvalidPublicKey),
599            Algorithm::MlKem512 => {
600                let key =
601                    MlKem512PublicKey::try_from(bytes).map_err(|_| Error::InvalidPublicKey)?;
602                if !mlkem512::validate_public_key(&key) {
603                    return Err(Error::InvalidPublicKey);
604                }
605                Ok(Self::MlKem512(key))
606            }
607            Algorithm::MlKem768 => {
608                let key =
609                    MlKem768PublicKey::try_from(bytes).map_err(|_| Error::InvalidPublicKey)?;
610                if !mlkem768::validate_public_key(&key) {
611                    return Err(Error::InvalidPublicKey);
612                }
613                Ok(Self::MlKem768(key))
614            }
615            Algorithm::X25519MlKem768Draft00 => {
616                X25519MlKem768Draft00PublicKey::decode(bytes).map(Self::X25519MlKem768Draft00)
617            }
618            Algorithm::XWingKemDraft06 => {
619                XWingKemDraft06PublicKey::decode(bytes).map(Self::XWingKemDraft06)
620            }
621            Algorithm::MlKem1024 => {
622                let key =
623                    MlKem1024PublicKey::try_from(bytes).map_err(|_| Error::InvalidPublicKey)?;
624                if !mlkem1024::validate_public_key(&key) {
625                    return Err(Error::InvalidPublicKey);
626                }
627                Ok(Self::MlKem1024(key))
628            }
629            _ => Err(Error::UnsupportedAlgorithm),
630        }
631    }
632}
633
634impl Ss {
635    /// Encode a shared secret.
636    pub fn encode(&self) -> Vec<u8> {
637        match self {
638            Ss::X25519(k) => k.0.to_vec(),
639            Ss::P256(k) => k.0.to_vec(),
640            Ss::MlKem512(k) => k.as_ref().to_vec(),
641            Ss::MlKem768(k) => k.as_ref().to_vec(),
642            Ss::X25519MlKem768Draft00(kk, xk) => {
643                let mut out = kk.to_vec();
644                out.extend_from_slice(xk.0.as_ref());
645                out
646            }
647            Ss::XWingKemDraft06(ss) => ss.value.into(),
648            Ss::MlKem1024(k) => k.as_ref().to_vec(),
649        }
650    }
651}
652
653impl Ct {
654    /// Encode a ciphertext.
655    pub fn encode(&self) -> Vec<u8> {
656        match self {
657            Ct::X25519(k) => k.0.to_vec(),
658            Ct::P256(k) => k.0.to_vec(),
659            Ct::MlKem512(k) => k.as_ref().to_vec(),
660            Ct::MlKem768(k) => k.as_ref().to_vec(),
661            Ct::X25519MlKem768Draft00(kk, xk) => {
662                let mut out = kk.as_ref().to_vec();
663                out.extend_from_slice(xk.0.as_ref());
664                out
665            }
666            Ct::XWingKemDraft06(ct_m, ct_x) => {
667                let mut out = ct_m.as_ref().to_vec();
668                out.extend_from_slice(ct_x.as_ref());
669                out
670            }
671            Ct::MlKem1024(k) => k.as_ref().to_vec(),
672        }
673    }
674
675    /// Decode a ciphertext.
676    pub fn decode(alg: Algorithm, bytes: &[u8]) -> Result<Self, Error> {
677        match alg {
678            Algorithm::X25519 => bytes
679                .try_into()
680                .map_err(|_| Error::InvalidCiphertext)
681                .map(Self::X25519),
682            Algorithm::Secp256r1 => bytes
683                .try_into()
684                .map_err(|_| Error::InvalidCiphertext)
685                .map(Self::P256),
686            Algorithm::MlKem512 => bytes
687                .try_into()
688                .map_err(|_| Error::InvalidCiphertext)
689                .map(Self::MlKem512),
690            Algorithm::MlKem768 => bytes
691                .try_into()
692                .map_err(|_| Error::InvalidCiphertext)
693                .map(Self::MlKem768),
694            Algorithm::X25519MlKem768Draft00 => {
695                let key: [u8; MlKem768Ciphertext::len() + 32] =
696                    bytes.try_into().map_err(|_| Error::InvalidCiphertext)?;
697                let (kct, xct) = key.split_at(1088);
698                Ok(Self::X25519MlKem768Draft00(
699                    kct.try_into().map_err(|_| Error::InvalidCiphertext)?,
700                    xct.try_into().map_err(|_| Error::InvalidCiphertext)?,
701                ))
702            }
703            Algorithm::XWingKemDraft06 => {
704                let key: [u8; MlKem768Ciphertext::len() + 32] =
705                    bytes.try_into().map_err(|_| Error::InvalidCiphertext)?;
706                let (ct_m, ct_x) = key.split_at(MlKem768Ciphertext::len());
707                Ok(Self::XWingKemDraft06(
708                    ct_m.try_into().map_err(|_| Error::InvalidCiphertext)?,
709                    ct_x.try_into().map_err(|_| Error::InvalidCiphertext)?,
710                ))
711            }
712            Algorithm::MlKem1024 => bytes
713                .try_into()
714                .map_err(|_| Error::InvalidCiphertext)
715                .map(Self::MlKem1024),
716            _ => Err(Error::UnsupportedAlgorithm),
717        }
718    }
719}
720
721/// Compute the public key for a private key of the given [`Algorithm`].
722/// Applicable only to X25519 and secp256r1.
723pub fn secret_to_public(alg: Algorithm, sk: impl AsRef<[u8]>) -> Result<Vec<u8>, Error> {
724    match alg {
725        Algorithm::X25519 | Algorithm::Secp256r1 => {
726            libcrux_ecdh::secret_to_public(alg.try_into().unwrap(), sk.as_ref())
727                .map_err(|e| e.into())
728        }
729        _ => Err(Error::UnsupportedAlgorithm),
730    }
731}
732
733fn gen_mlkem768(
734    rng: &mut impl CryptoRng,
735) -> Result<(MlKem768PrivateKey, MlKem768PublicKey), Error> {
736    Ok(mlkem768::generate_key_pair(random_array(rng)?).into_parts())
737}
738
739fn random_array<const L: usize>(rng: &mut impl CryptoRng) -> Result<[u8; L], Error> {
740    let mut seed = [0; L];
741    rng.try_fill_bytes(&mut seed).map_err(|_| Error::KeyGen)?;
742    Ok(seed)
743}
744
745/// Generate a key pair for the [`Algorithm`] using the provided rng.
746///
747/// The function returns a fresh key or a [`Error::KeyGen`] error if
748/// * not enough entropy was available
749/// * it was not possible to generate a valid key within a reasonable amount of iterations.
750pub fn key_gen(alg: Algorithm, rng: &mut impl CryptoRng) -> Result<(PrivateKey, PublicKey), Error> {
751    match alg {
752        Algorithm::X25519 => libcrux_ecdh::x25519_key_gen(rng)
753            .map_err(|e| e.into())
754            .map(|(private, public)| (PrivateKey::X25519(private), PublicKey::X25519(public))),
755        Algorithm::Secp256r1 => libcrux_ecdh::p256_key_gen(rng)
756            .map_err(|e| e.into())
757            .map(|(private, public)| (PrivateKey::P256(private), PublicKey::P256(public))),
758        Algorithm::MlKem512 => {
759            let (sk, pk) = mlkem512::generate_key_pair(random_array(rng)?).into_parts();
760            Ok((PrivateKey::MlKem512(sk), PublicKey::MlKem512(pk)))
761        }
762        Algorithm::MlKem768 => {
763            let (sk, pk) = mlkem768::generate_key_pair(random_array(rng)?).into_parts();
764            Ok((PrivateKey::MlKem768(sk), PublicKey::MlKem768(pk)))
765        }
766        Algorithm::MlKem1024 => {
767            let (sk, pk) = mlkem1024::generate_key_pair(random_array(rng)?).into_parts();
768            Ok((PrivateKey::MlKem1024(sk), PublicKey::MlKem1024(pk)))
769        }
770        Algorithm::X25519MlKem768Draft00 => {
771            let (mlkem_private, mlkem_public) = gen_mlkem768(rng)?;
772            let (x25519_private, x25519_public) = libcrux_ecdh::x25519_key_gen(rng)?;
773
774            Ok((
775                PrivateKey::X25519MlKem768Draft00(X25519MlKem768Draft00PrivateKey {
776                    mlkem: mlkem_private,
777                    x25519: x25519_private,
778                }),
779                PublicKey::X25519MlKem768Draft00(X25519MlKem768Draft00PublicKey {
780                    mlkem: mlkem_public,
781                    x25519: x25519_public,
782                }),
783            ))
784        }
785
786        Algorithm::XWingKemDraft06 => {
787            let mut seed = [0u8; 32];
788            rng.fill_bytes(&mut seed);
789
790            let (kp_m, pk_x, _) = xwing::expand_decap_key(&seed)?;
791
792            Ok((
793                PrivateKey::XWingKemDraft06(XWingKemDraft06PrivateKey { seed }),
794                PublicKey::XWingKemDraft06(XWingKemDraft06PublicKey {
795                    pk_m: kp_m.pk().into(),
796                    // unwrap is ok here because it comes from the secret to pub above
797                    pk_x,
798                }),
799            ))
800        }
801        _ => Err(Error::UnsupportedAlgorithm),
802    }
803}
804
805/// Generate a key pair for the [`Algorithm`] using the provided rng.
806///
807/// The function returns a fresh key or a [`Error::KeyGen`] error if
808/// * the `seed` wasn't long enough
809/// * it was not possible to generate a valid key within a reasonable amount of iterations.
810pub fn key_gen_derand(alg: Algorithm, seed: &[u8]) -> Result<(PrivateKey, PublicKey), Error> {
811    match alg {
812        Algorithm::X25519 => {
813            let sk = X25519PrivateKey::try_from(seed)?;
814            let pk = x25519_secret_to_public(&sk)?;
815            Ok((PrivateKey::X25519(sk), PublicKey::X25519(pk)))
816        }
817
818        Algorithm::Secp256r1 => {
819            let sk = P256PrivateKey::try_from(seed)?;
820            let pk = p256_secret_to_public(&sk)?;
821            Ok((PrivateKey::P256(sk), PublicKey::P256(pk)))
822        }
823
824        Algorithm::MlKem512 => {
825            let (sk, pk) = mlkem512::generate_key_pair(seed.try_into().map_err(|_| Error::KeyGen)?)
826                .into_parts();
827            Ok((PrivateKey::MlKem512(sk), PublicKey::MlKem512(pk)))
828        }
829
830        Algorithm::MlKem768 => {
831            let (sk, pk) = mlkem768::generate_key_pair(seed.try_into().map_err(|_| Error::KeyGen)?)
832                .into_parts();
833            Ok((PrivateKey::MlKem768(sk), PublicKey::MlKem768(pk)))
834        }
835
836        Algorithm::MlKem1024 => {
837            let (sk, pk) =
838                mlkem1024::generate_key_pair(seed.try_into().map_err(|_| Error::KeyGen)?)
839                    .into_parts();
840            Ok((PrivateKey::MlKem1024(sk), PublicKey::MlKem1024(pk)))
841        }
842
843        Algorithm::X25519MlKem768Draft00 => {
844            let (mlkem_private, mlkem_public) =
845                mlkem768::generate_key_pair(seed.try_into().map_err(|_| Error::KeyGen)?)
846                    .into_parts();
847            let x25519_private = X25519PrivateKey::try_from(seed)?;
848            let x25519_public = x25519_secret_to_public(&x25519_private)?;
849
850            Ok((
851                PrivateKey::X25519MlKem768Draft00(X25519MlKem768Draft00PrivateKey {
852                    mlkem: mlkem_private,
853                    x25519: x25519_private,
854                }),
855                PublicKey::X25519MlKem768Draft00(X25519MlKem768Draft00PublicKey {
856                    mlkem: mlkem_public,
857                    x25519: x25519_public,
858                }),
859            ))
860        }
861
862        Algorithm::XWingKemDraft06 => {
863            let seed: [u8; 32] = seed.try_into().map_err(|_| Error::KeyGen)?;
864            let (kp_m, pk_x, _) = xwing::expand_decap_key(&seed)?;
865
866            Ok((
867                PrivateKey::XWingKemDraft06(XWingKemDraft06PrivateKey { seed }),
868                PublicKey::XWingKemDraft06(XWingKemDraft06PublicKey {
869                    pk_m: kp_m.pk().into(),
870                    // unwrap is ok here because it comes from the secret to pub above
871                    pk_x,
872                }),
873            ))
874        }
875
876        _ => Err(Error::UnsupportedAlgorithm),
877    }
878}
879
880/// The XWing KEM combiner with ML-KEM 768 and x25519
881mod xwing {
882    use libcrux_ecdh::X25519PrivateKey;
883
884    use super::*;
885
886    pub struct XWingSharedSecret {
887        pub(super) value: [u8; 32],
888    }
889
890    /// Expand the `seed` to the ML-KEM and x25519 key pairs.
891    pub(super) fn expand_decap_key(
892        seed: &[u8; 32],
893    ) -> Result<(MlKemKeyPair<2400, 1184>, X25519PublicKey, X25519PrivateKey), Error> {
894        let expanded: [u8; 96] = libcrux_sha3::shake256(seed);
895        let kp_m =
896            mlkem768::generate_key_pair(expanded[..64].try_into().map_err(|_| Error::KeyGen)?);
897        let mut sk_x = [0u8; 32];
898        sk_x.copy_from_slice(&expanded[64..]);
899        let sk_x = X25519PrivateKey::from(&sk_x); // clamps
900        let pk_x = x25519_secret_to_public(&sk_x)?;
901        Ok((kp_m, pk_x, sk_x))
902    }
903
904    pub(super) fn combiner(
905        ss_m: &[u8],
906        ss_x: &[u8],
907        ct_x: &[u8],
908        pk_x: &[u8],
909    ) -> XWingSharedSecret {
910        // label:
911        // \./
912        // /^\
913        // 5c2e2f2f5e5c
914        let mut input = ss_m.to_vec();
915        input.extend_from_slice(ss_x);
916        input.extend_from_slice(ct_x);
917        input.extend_from_slice(pk_x);
918        input.extend_from_slice(&[0x5c, 0x2e, 0x2f, 0x2f, 0x5e, 0x5c]);
919        XWingSharedSecret {
920            value: sha3::sha256(&input),
921        }
922    }
923
924    impl<'a> TryFrom<&'a [u8]> for XWingSharedSecret {
925        type Error = <[u8; 32] as TryFrom<&'a [u8]>>::Error;
926
927        fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
928            Ok(Self {
929                value: value.try_into()?,
930            })
931        }
932    }
933}
934
935fn mlkem_rand(rng: &mut impl CryptoRng) -> Result<[u8; libcrux_ml_kem::SHARED_SECRET_SIZE], Error> {
936    let mut seed = [0; libcrux_ml_kem::SHARED_SECRET_SIZE];
937    rng.try_fill_bytes(&mut seed).map_err(|_| Error::KeyGen)?;
938    Ok(seed)
939}
940
941impl TryInto<libcrux_ecdh::X25519PublicKey> for PublicKey {
942    type Error = libcrux_ecdh::Error;
943
944    fn try_into(self) -> Result<libcrux_ecdh::X25519PublicKey, libcrux_ecdh::Error> {
945        if let PublicKey::X25519(k) = self {
946            Ok(k)
947        } else {
948            Err(libcrux_ecdh::Error::InvalidPoint)
949        }
950    }
951}
952
953impl TryInto<libcrux_ecdh::X25519PrivateKey> for PrivateKey {
954    type Error = libcrux_ecdh::Error;
955
956    fn try_into(self) -> Result<libcrux_ecdh::X25519PrivateKey, libcrux_ecdh::Error> {
957        if let PrivateKey::X25519(k) = self {
958            Ok(k)
959        } else {
960            Err(libcrux_ecdh::Error::InvalidPoint)
961        }
962    }
963}