1#![no_std]
38
39extern crate alloc;
40
41use alloc::{vec, vec::Vec};
42
43use rand::{CryptoRng, Rng, TryRngCore};
44
45use libcrux_ecdh::{p256_derive, x25519_derive};
46use libcrux_ecdh::{
47    P256PrivateKey, P256PublicKey, P256SharedSecret, X25519PrivateKey, X25519PublicKey,
48    X25519SharedSecret,
49};
50use libcrux_sha3 as sha3;
51
52use libcrux_ml_kem::{mlkem1024, mlkem512, mlkem768};
53
54#[cfg(feature = "kyber")]
55use libcrux_ml_kem::kyber768;
56
/// Derandomized ML-KEM entry points, exposed for testing.
///
/// These are the `libcrux_ml_kem` key-generation, encapsulation and
/// decapsulation functions (which take their randomness as explicit
/// arguments), re-exported under `*_derand` aliases. Only compiled with the
/// `tests` feature.
#[cfg(feature = "tests")]
pub mod deterministic {
    pub use libcrux_ml_kem::mlkem512::{
        decapsulate as mlkem512_decapsulate_derand, encapsulate as mlkem512_encapsulate_derand,
        generate_key_pair as mlkem512_generate_keypair_derand,
    };
    pub use libcrux_ml_kem::mlkem768::{
        decapsulate as mlkem768_decapsulate_derand, encapsulate as mlkem768_encapsulate_derand,
        generate_key_pair as mlkem768_generate_keypair_derand,
    };
    pub use libcrux_ml_kem::mlkem1024::{
        decapsulate as mlkem1024_decapsulate_derand, encapsulate as mlkem1024_encapsulate_derand,
        generate_key_pair as mlkem1024_generate_keypair_derand,
    };
}
74
75use libcrux_ml_kem::MlKemSharedSecret;
76pub use libcrux_ml_kem::{
77    mlkem1024::{MlKem1024Ciphertext, MlKem1024PrivateKey, MlKem1024PublicKey},
78    mlkem512::{MlKem512Ciphertext, MlKem512PrivateKey, MlKem512PublicKey},
79    mlkem768::{MlKem768Ciphertext, MlKem768PrivateKey, MlKem768PublicKey},
80    MlKemCiphertext, MlKemKeyPair,
81};
82
83#[cfg(feature = "tests")]
84pub use libcrux_ml_kem::{
85    mlkem1024::validate_public_key as ml_kem1024_validate_public_key,
86    mlkem512::validate_public_key as ml_kem512_validate_public_key,
87    mlkem768::validate_public_key as ml_kem768_validate_public_key,
88};
89
/// The KEM algorithms known to this crate.
///
/// Besides plain ECDH used as a KEM and ML-KEM, this includes hybrid
/// constructions that combine ML-KEM 768 (or Kyber 768 behind the `kyber`
/// feature) with X25519. `Eq` and `Hash` are derived so algorithms can be used
/// as keys in maps and sets.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum Algorithm {
    /// X25519 Diffie-Hellman used as a KEM.
    X25519,
    /// X448. The key/ciphertext decoders in this module reject it with
    /// `Error::UnsupportedAlgorithm`.
    X448,
    /// NIST P-256 Diffie-Hellman used as a KEM.
    Secp256r1,
    /// NIST P-384. Rejected by the decoders in this module.
    Secp384r1,
    /// NIST P-521. Rejected by the decoders in this module.
    Secp521r1,
    /// ML-KEM 512.
    MlKem512,
    /// ML-KEM 768.
    MlKem768,
    /// Hybrid of ML-KEM 768 and X25519; encodings put the ML-KEM part first.
    X25519MlKem768Draft00,
    /// X-Wing hybrid (draft 02): ML-KEM 768 + X25519 with a SHA3-256 combiner.
    XWingKemDraft02,
    /// Hybrid of Kyber 768 and X25519; its ciphertext and shared secret put the
    /// X25519 part first.
    #[cfg(feature = "kyber")]
    X25519Kyber768Draft00,
    /// The X-Wing construction instantiated with Kyber 768.
    #[cfg(feature = "kyber")]
    XWingKyberDraft02,
    /// ML-KEM 1024.
    MlKem1024,
}
110
/// Errors returned by the KEM operations in this module.
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
    /// An error propagated from `libcrux_ecdh` (via `From<libcrux_ecdh::Error>`).
    EcDhError(libcrux_ecdh::Error),
    /// Key generation failed. Also returned when the RNG cannot fill a seed
    /// in `random_array`.
    KeyGen,
    /// Encapsulation failed.
    Encapsulate,
    /// Decapsulation failed.
    Decapsulate,
    /// The operation is not available for the given `Algorithm`
    /// (e.g. decoding `X448` keys).
    UnsupportedAlgorithm,
    /// A private key could not be decoded, or its algorithm does not match the
    /// ciphertext being decapsulated.
    InvalidPrivateKey,
    /// A public key could not be decoded or failed ML-KEM public-key validation.
    InvalidPublicKey,
    /// A ciphertext could not be decoded (e.g. wrong length for its algorithm).
    InvalidCiphertext,
}
122
123impl TryFrom<Algorithm> for libcrux_ecdh::Algorithm {
124    type Error = &'static str;
125
126    fn try_from(value: Algorithm) -> Result<Self, Self::Error> {
127        match value {
128            Algorithm::X25519 => Ok(libcrux_ecdh::Algorithm::X25519),
129            Algorithm::X448 => Ok(libcrux_ecdh::Algorithm::X448),
130            Algorithm::Secp256r1 => Ok(libcrux_ecdh::Algorithm::P256),
131            Algorithm::Secp384r1 => Ok(libcrux_ecdh::Algorithm::P384),
132            Algorithm::Secp521r1 => Ok(libcrux_ecdh::Algorithm::P521),
133            Algorithm::X25519MlKem768Draft00 => Ok(libcrux_ecdh::Algorithm::X25519),
134            Algorithm::XWingKemDraft02 => Ok(libcrux_ecdh::Algorithm::X25519),
135            #[cfg(feature = "kyber")]
136            Algorithm::XWingKyberDraft02 | Algorithm::X25519Kyber768Draft00 => {
137                Ok(libcrux_ecdh::Algorithm::X25519)
138            }
139            _ => Err("provided algorithm is not an ECDH algorithm"),
140        }
141    }
142}
143
144impl From<libcrux_ecdh::Error> for Error {
145    fn from(value: libcrux_ecdh::Error) -> Self {
146        Error::EcDhError(value)
147    }
148}
149
/// Private key for the X25519 + ML-KEM 768 hybrid (`X25519MlKem768Draft00`).
///
/// With the `kyber` feature it is also used for `X25519Kyber768Draft00`.
pub struct X25519MlKem768Draft00PrivateKey {
    /// ML-KEM 768 private key (the first 2400 bytes of the encoding).
    pub mlkem: MlKem768PrivateKey,
    /// X25519 private key (the remaining bytes of the encoding).
    pub x25519: X25519PrivateKey,
}
155
156impl X25519MlKem768Draft00PrivateKey {
157    pub fn decode(bytes: &[u8]) -> Result<Self, Error> {
158        Ok(Self {
159            mlkem: bytes[..2400]
160                .try_into()
161                .map_err(|_| Error::InvalidPrivateKey)?,
162            x25519: bytes[2400..]
163                .try_into()
164                .map_err(|_| Error::InvalidPrivateKey)?,
165        })
166    }
167
168    pub fn encode(&self) -> Vec<u8> {
169        let mut out = self.mlkem.as_ref().to_vec();
170        out.extend_from_slice(&self.x25519.0);
171        out
172    }
173}
174
/// Private key for the X-Wing hybrid KEM (`XWingKemDraft02`).
///
/// Besides the two secret keys it stores the X25519 public key, which
/// decapsulation feeds into the shared secret (`Ss::XWingKemDraft02`).
/// With the `kyber` feature it is also used for `XWingKyberDraft02`.
pub struct XWingKemDraft02PrivateKey {
    /// ML-KEM 768 private key (encoded bytes 0..2400).
    pub sk_m: MlKem768PrivateKey,
    /// X25519 private key (encoded bytes 2400..2432).
    pub sk_x: X25519PrivateKey,
    /// X25519 public key matching `sk_x` (encoded bytes 2432..2464).
    pub pk_x: X25519PublicKey,
}
181
182impl XWingKemDraft02PrivateKey {
183    pub fn decode(bytes: &[u8]) -> Result<Self, Error> {
184        Ok(Self {
185            sk_m: bytes[..2400]
186                .try_into()
187                .map_err(|_| Error::InvalidPrivateKey)?,
188            sk_x: bytes[2400..2432]
189                .try_into()
190                .map_err(|_| Error::InvalidPrivateKey)?,
191            pk_x: bytes[2432..2464]
192                .try_into()
193                .map_err(|_| Error::InvalidPrivateKey)?,
194        })
195    }
196
197    pub fn encode(&self) -> Vec<u8> {
198        let mut out = self.sk_m.as_ref().to_vec();
199        out.extend_from_slice(self.sk_x.0.to_vec().as_ref());
200        out.extend_from_slice(self.pk_x.0.to_vec().as_ref());
201        out
202    }
203}
204
/// A private (decapsulation) key for one of the supported algorithms.
pub enum PrivateKey {
    /// X25519 private key.
    X25519(X25519PrivateKey),
    /// P-256 private key.
    P256(P256PrivateKey),
    /// ML-KEM 512 private key.
    MlKem512(MlKem512PrivateKey),
    /// ML-KEM 768 private key.
    MlKem768(MlKem768PrivateKey),
    /// X25519 + ML-KEM 768 hybrid private key.
    X25519MlKem768Draft00(X25519MlKem768Draft00PrivateKey),
    /// X-Wing (ML-KEM 768) private key.
    XWingKemDraft02(XWingKemDraft02PrivateKey),
    /// X25519 + Kyber 768 hybrid; reuses the ML-KEM 768 hybrid key struct.
    #[cfg(feature = "kyber")]
    X25519Kyber768Draft00(X25519MlKem768Draft00PrivateKey),
    /// X-Wing instantiated with Kyber 768; reuses the X-Wing key struct.
    #[cfg(feature = "kyber")]
    XWingKyberDraft02(XWingKemDraft02PrivateKey),
    /// ML-KEM 1024 private key.
    MlKem1024(MlKem1024PrivateKey),
}
219
/// Public key for the X25519 + ML-KEM 768 hybrid (`X25519MlKem768Draft00`).
///
/// With the `kyber` feature it is also used for `X25519Kyber768Draft00`.
pub struct X25519MlKem768Draft00PublicKey {
    /// ML-KEM 768 public key (the first 1184 bytes of the encoding).
    pub mlkem: MlKem768PublicKey,
    /// X25519 public key (the remaining bytes of the encoding).
    pub x25519: X25519PublicKey,
}
225
226impl X25519MlKem768Draft00PublicKey {
227    pub fn decode(bytes: &[u8]) -> Result<Self, Error> {
228        Ok(Self {
229            mlkem: {
230                let key = MlKem768PublicKey::try_from(&bytes[..1184])
231                    .map_err(|_| Error::InvalidPublicKey)?;
232                if !mlkem768::validate_public_key(&key) {
233                    return Err(Error::InvalidPublicKey);
234                }
235                key
236            },
237            x25519: bytes[1184..]
238                .try_into()
239                .map_err(|_| Error::InvalidPublicKey)?,
240        })
241    }
242
243    pub fn encode(&self) -> Vec<u8> {
244        let mut out = self.mlkem.as_ref().to_vec();
245        out.extend_from_slice(&self.x25519.0);
246        out
247    }
248}
249
/// Public key for the X-Wing hybrid KEM (`XWingKemDraft02`).
///
/// With the `kyber` feature it is also used for `XWingKyberDraft02`.
pub struct XWingKemDraft02PublicKey {
    /// ML-KEM 768 public key (the first 1184 bytes of the encoding).
    pub pk_m: MlKem768PublicKey,
    /// X25519 public key (the remaining bytes of the encoding).
    pub pk_x: X25519PublicKey,
}
255
256impl XWingKemDraft02PublicKey {
257    pub fn decode(bytes: &[u8]) -> Result<Self, Error> {
258        Ok(Self {
259            pk_m: {
260                let key = MlKem768PublicKey::try_from(&bytes[0..1184])
261                    .map_err(|_| Error::InvalidPublicKey)?;
262                if !mlkem768::validate_public_key(&key) {
263                    return Err(Error::InvalidPublicKey);
264                }
265                key
266            },
267            pk_x: bytes[1184..]
268                .try_into()
269                .map_err(|_| Error::InvalidPublicKey)?,
270        })
271    }
272
273    pub fn encode(&self) -> Vec<u8> {
274        let mut out = self.pk_m.as_ref().to_vec();
275        out.extend_from_slice(self.pk_x.0.to_vec().as_ref());
276        out
277    }
278}
279
/// A public (encapsulation) key for one of the supported algorithms.
pub enum PublicKey {
    /// X25519 public key.
    X25519(X25519PublicKey),
    /// P-256 public key.
    P256(P256PublicKey),
    /// ML-KEM 512 public key.
    MlKem512(MlKem512PublicKey),
    /// ML-KEM 768 public key.
    MlKem768(MlKem768PublicKey),
    /// X25519 + ML-KEM 768 hybrid public key.
    X25519MlKem768Draft00(X25519MlKem768Draft00PublicKey),
    /// X-Wing (ML-KEM 768) public key.
    XWingKemDraft02(XWingKemDraft02PublicKey),
    /// X25519 + Kyber 768 hybrid; reuses the ML-KEM 768 hybrid key struct.
    #[cfg(feature = "kyber")]
    X25519Kyber768Draft00(X25519MlKem768Draft00PublicKey),
    /// X-Wing instantiated with Kyber 768; reuses the X-Wing key struct.
    #[cfg(feature = "kyber")]
    XWingKyberDraft02(XWingKemDraft02PublicKey),
    /// ML-KEM 1024 public key.
    MlKem1024(MlKem1024PublicKey),
}
294
/// A KEM ciphertext.
///
/// For the plain ECDH variants the "ciphertext" is the sender's ephemeral
/// public key. The hybrid variants carry the ML-KEM (or Kyber) ciphertext
/// together with the ephemeral X25519 public key.
pub enum Ct {
    /// Ephemeral X25519 public key.
    X25519(X25519PublicKey),
    /// Ephemeral P-256 public key.
    P256(P256PublicKey),
    /// ML-KEM 512 ciphertext.
    MlKem512(MlKem512Ciphertext),
    /// ML-KEM 768 ciphertext.
    MlKem768(MlKem768Ciphertext),
    /// `(ML-KEM 768 ciphertext, ephemeral X25519 public key)`.
    X25519MlKem768Draft00(MlKem768Ciphertext, X25519PublicKey),
    /// `(ML-KEM 768 ciphertext, ephemeral X25519 public key)`.
    XWingKemDraft02(MlKem768Ciphertext, X25519PublicKey),
    /// `(Kyber 768 ciphertext, ephemeral X25519 public key)`.
    #[cfg(feature = "kyber")]
    X25519Kyber768Draft00(MlKem768Ciphertext, X25519PublicKey),
    /// `(Kyber 768 ciphertext, ephemeral X25519 public key)`.
    #[cfg(feature = "kyber")]
    XWingKyberDraft02(MlKem768Ciphertext, X25519PublicKey),
    /// ML-KEM 1024 ciphertext.
    MlKem1024(MlKem1024Ciphertext),
}
309
310impl Ct {
311    pub fn decapsulate(&self, sk: &PrivateKey) -> Result<Ss, Error> {
313        match self {
314            Ct::X25519(ct) => {
315                let sk = if let PrivateKey::X25519(k) = sk {
316                    k
317                } else {
318                    return Err(Error::InvalidPrivateKey);
319                };
320                x25519_derive(ct, sk).map_err(|e| e.into()).map(Ss::X25519)
321            }
322            Ct::P256(ct) => {
323                let sk = if let PrivateKey::P256(k) = sk {
324                    k
325                } else {
326                    return Err(Error::InvalidPrivateKey);
327                };
328                p256_derive(ct, sk).map_err(|e| e.into()).map(Ss::P256)
329            }
330            Ct::MlKem512(ct) => {
331                let sk = if let PrivateKey::MlKem512(k) = sk {
332                    k
333                } else {
334                    return Err(Error::InvalidPrivateKey);
335                };
336                let ss = libcrux_ml_kem::mlkem512::decapsulate(sk, ct);
337
338                Ok(Ss::MlKem768(ss))
339            }
340            Ct::MlKem768(ct) => {
341                let sk = if let PrivateKey::MlKem768(k) = sk {
342                    k
343                } else {
344                    return Err(Error::InvalidPrivateKey);
345                };
346                let ss = mlkem768::decapsulate(sk, ct);
347
348                Ok(Ss::MlKem768(ss))
349            }
350            Ct::X25519MlKem768Draft00(kct, xct) => {
351                let (ksk, xsk) =
352                    if let PrivateKey::X25519MlKem768Draft00(X25519MlKem768Draft00PrivateKey {
353                        mlkem: kk,
354                        x25519: xk,
355                    }) = sk
356                    {
357                        (kk, xk)
358                    } else {
359                        return Err(Error::InvalidPrivateKey);
360                    };
361                let kss = mlkem768::decapsulate(ksk, kct);
362                let xss = x25519_derive(xct, xsk)?;
363
364                Ok(Ss::X25519MlKem768Draft00(kss, xss))
365            }
366            Ct::XWingKemDraft02(ct_m, ct_x) => {
367                let (sk_m, sk_x, pk_x) =
368                    if let PrivateKey::XWingKemDraft02(XWingKemDraft02PrivateKey {
369                        sk_m,
370                        sk_x,
371                        pk_x,
372                    }) = sk
373                    {
374                        (sk_m, sk_x, pk_x)
375                    } else {
376                        return Err(Error::InvalidPrivateKey);
377                    };
378                let ss_m = mlkem768::decapsulate(sk_m, ct_m);
379                let ss_x = x25519_derive(ct_x, sk_x)?;
380
381                Ok(Ss::XWingKemDraft02(
382                    ss_m,
383                    ss_x,
384                    X25519PublicKey(ct_x.0),
385                    X25519PublicKey(pk_x.0),
386                ))
387            }
388            Ct::MlKem1024(ct) => {
389                let sk = if let PrivateKey::MlKem1024(k) = sk {
390                    k
391                } else {
392                    return Err(Error::InvalidPrivateKey);
393                };
394                let ss = libcrux_ml_kem::mlkem1024::decapsulate(sk, ct);
395
396                Ok(Ss::MlKem1024(ss))
397            }
398            #[cfg(feature = "kyber")]
399            Ct::X25519Kyber768Draft00(kct, xct) => {
400                let (ksk, xsk) =
401                    if let PrivateKey::X25519Kyber768Draft00(X25519MlKem768Draft00PrivateKey {
402                        mlkem: kk,
403                        x25519: xk,
404                    }) = sk
405                    {
406                        (kk, xk)
407                    } else {
408                        return Err(Error::InvalidPrivateKey);
409                    };
410                let kss = kyber768::decapsulate(ksk, kct);
411                let xss = x25519_derive(xct, xsk)?;
412
413                Ok(Ss::X25519Kyber768Draft00(kss, xss))
414            }
415            #[cfg(feature = "kyber")]
416            Ct::XWingKyberDraft02(ct_m, ct_x) => {
417                let (sk_m, sk_x, pk_x) =
418                    if let PrivateKey::XWingKyberDraft02(XWingKemDraft02PrivateKey {
419                        sk_m,
420                        sk_x,
421                        pk_x,
422                    }) = sk
423                    {
424                        (sk_m, sk_x, pk_x)
425                    } else {
426                        return Err(Error::InvalidPrivateKey);
427                    };
428                let ss_m = kyber768::decapsulate(sk_m, ct_m);
429                let ss_x = x25519_derive(ct_x, sk_x)?;
430
431                Ok(Ss::XWingKyberDraft02(
432                    ss_m,
433                    ss_x,
434                    X25519PublicKey(ct_x.0.clone()),
435                    X25519PublicKey(pk_x.0.clone()),
436                ))
437            }
438        }
439    }
440}
441
/// Shared secrets produced by `PublicKey::encapsulate` and `Ct::decapsulate`.
///
/// Hybrid variants keep their component secrets separately; `Ss::encode`
/// combines them into the final byte string.
pub enum Ss {
    /// X25519 shared secret.
    X25519(X25519SharedSecret),
    /// P-256 shared secret.
    P256(P256SharedSecret),
    /// ML-KEM 512 shared secret.
    MlKem512(MlKemSharedSecret),
    /// ML-KEM 768 shared secret.
    MlKem768(MlKemSharedSecret),
    /// `(ML-KEM 768 secret, X25519 secret)`.
    X25519MlKem768Draft00(MlKemSharedSecret, X25519SharedSecret),
    /// `(ML-KEM 768 secret, X25519 secret, X25519 ciphertext, recipient X25519
    /// public key)`; the last two are inputs to the X-Wing combiner.
    XWingKemDraft02(
        MlKemSharedSecret,  X25519SharedSecret, X25519PublicKey,    X25519PublicKey,    ),
    /// `(Kyber 768 secret, X25519 secret)`.
    #[cfg(feature = "kyber")]
    X25519Kyber768Draft00(MlKemSharedSecret, X25519SharedSecret),
    /// Same layout as `XWingKemDraft02`, with a Kyber 768 secret.
    #[cfg(feature = "kyber")]
    XWingKyberDraft02(
        MlKemSharedSecret,  X25519SharedSecret, X25519PublicKey,    X25519PublicKey,    ),
    /// ML-KEM 1024 shared secret.
    MlKem1024(MlKemSharedSecret),
}
466
467impl PrivateKey {
468    pub fn encode(&self) -> Vec<u8> {
470        match self {
471            PrivateKey::X25519(k) => k.0.to_vec(),
472            PrivateKey::P256(k) => k.0.to_vec(),
473            PrivateKey::MlKem512(k) => k.as_slice().to_vec(),
474            PrivateKey::MlKem768(k) => k.as_slice().to_vec(),
475            PrivateKey::X25519MlKem768Draft00(k) => k.encode(),
476            PrivateKey::XWingKemDraft02(k) => k.encode(),
477            PrivateKey::MlKem1024(k) => k.as_slice().to_vec(),
478            #[cfg(feature = "kyber")]
479            PrivateKey::X25519Kyber768Draft00(k) => k.encode(),
480            #[cfg(feature = "kyber")]
481            PrivateKey::XWingKyberDraft02(k) => k.encode(),
482        }
483    }
484
485    pub fn decode(alg: Algorithm, bytes: &[u8]) -> Result<Self, Error> {
487        match alg {
488            Algorithm::X25519 => bytes
489                .try_into()
490                .map_err(|_| Error::InvalidPrivateKey)
491                .map(Self::X25519),
492            Algorithm::Secp256r1 => bytes
493                .try_into()
494                .map_err(|_| Error::InvalidPrivateKey)
495                .map(Self::P256),
496            Algorithm::MlKem512 => bytes
497                .try_into()
498                .map_err(|_| Error::InvalidPrivateKey)
499                .map(Self::MlKem512),
500            Algorithm::MlKem768 => bytes
501                .try_into()
502                .map_err(|_| Error::InvalidPrivateKey)
503                .map(Self::MlKem768),
504            Algorithm::X25519MlKem768Draft00 => X25519MlKem768Draft00PrivateKey::decode(bytes)
505                .map_err(|_| Error::InvalidPrivateKey)
506                .map(Self::X25519MlKem768Draft00),
507            Algorithm::XWingKemDraft02 => {
508                let pk = XWingKemDraft02PrivateKey::decode(bytes)
509                    .map_err(|_| Error::InvalidPrivateKey)?;
510                Ok(Self::XWingKemDraft02(pk))
511            }
512            #[cfg(feature = "kyber")]
513            Algorithm::X25519Kyber768Draft00 => X25519MlKem768Draft00PrivateKey::decode(bytes)
514                .map_err(|_| Error::InvalidPrivateKey)
515                .map(Self::X25519Kyber768Draft00),
516            #[cfg(feature = "kyber")]
517            Algorithm::XWingKyberDraft02 => {
518                let pk = XWingKemDraft02PrivateKey::decode(bytes)
519                    .map_err(|_| Error::InvalidPrivateKey)?;
520                Ok(Self::XWingKyberDraft02(pk))
521            }
522            Algorithm::MlKem1024 => bytes
523                .try_into()
524                .map_err(|_| Error::InvalidPrivateKey)
525                .map(Self::MlKem1024),
526            _ => Err(Error::UnsupportedAlgorithm),
527        }
528    }
529}
530
531impl PublicKey {
532    pub fn encapsulate(&self, rng: &mut (impl CryptoRng + Rng)) -> Result<(Ss, Ct), Error> {
534        match self {
535            PublicKey::X25519(pk) => {
536                let (new_sk, new_pk) = libcrux_ecdh::x25519_key_gen(rng)?;
537                let gxy = x25519_derive(pk, &new_sk)?;
538                Ok((Ss::X25519(gxy), Ct::X25519(new_pk)))
539            }
540            PublicKey::P256(pk) => {
541                let (new_sk, new_pk) = libcrux_ecdh::p256_key_gen(rng)?;
542                let gxy = p256_derive(pk, &new_sk)?;
543                Ok((Ss::P256(gxy), Ct::P256(new_pk)))
544            }
545
546            PublicKey::MlKem512(pk) => {
547                let seed = mlkem_rand(rng)?;
548                let (ct, ss) = libcrux_ml_kem::mlkem512::encapsulate(pk, seed);
549                Ok((Ss::MlKem512(ss), Ct::MlKem512(ct)))
550            }
551
552            PublicKey::MlKem768(pk) => {
553                let seed = mlkem_rand(rng)?;
554                let (ct, ss) = mlkem768::encapsulate(pk, seed);
555                Ok((Ss::MlKem768(ss), Ct::MlKem768(ct)))
556            }
557
558            PublicKey::MlKem1024(pk) => {
559                let seed = mlkem_rand(rng)?;
560                let (ct, ss) = mlkem1024::encapsulate(pk, seed);
561                Ok((Ss::MlKem1024(ss), Ct::MlKem1024(ct)))
562            }
563
564            PublicKey::X25519MlKem768Draft00(X25519MlKem768Draft00PublicKey {
565                mlkem: kpk,
566                x25519: xpk,
567            }) => {
568                let seed = mlkem_rand(rng)?;
569                let (mlkem_ct, mlkem_ss) = mlkem768::encapsulate(kpk, seed);
570                let (x_sk, x_pk) = libcrux_ecdh::x25519_key_gen(rng)?;
571                let x_ss = x25519_derive(xpk, &x_sk)?;
572
573                Ok((
574                    Ss::X25519MlKem768Draft00(mlkem_ss, x_ss),
575                    Ct::X25519MlKem768Draft00(mlkem_ct, x_pk),
576                ))
577            }
578
579            PublicKey::XWingKemDraft02(XWingKemDraft02PublicKey { pk_m, pk_x }) => {
580                let seed = mlkem_rand(rng)?;
581                let (ct_m, ss_m) = mlkem768::encapsulate(pk_m, seed);
582                let (ek_x, ct_x) = libcrux_ecdh::x25519_key_gen(rng)?;
583                let ss_x = x25519_derive(pk_x, &ek_x)?;
584
585                Ok((
586                    Ss::XWingKemDraft02(
587                        ss_m,
588                        ss_x,
589                        X25519PublicKey(ct_x.0.clone()),
590                        X25519PublicKey(pk_x.0.clone()),
591                    ),
592                    Ct::XWingKemDraft02(ct_m, X25519PublicKey(ct_x.0.clone())),
593                ))
594            }
595
596            #[cfg(feature = "kyber")]
597            PublicKey::X25519Kyber768Draft00(X25519MlKem768Draft00PublicKey {
598                mlkem: kpk,
599                x25519: xpk,
600            }) => {
601                let seed = mlkem_rand(rng)?;
602                let (mlkem_ct, mlkem_ss) = kyber768::encapsulate(kpk, seed);
603                let (x_sk, x_pk) = libcrux_ecdh::x25519_key_gen(rng)?;
604                let x_ss = x25519_derive(xpk, &x_sk)?;
605
606                Ok((
607                    Ss::X25519Kyber768Draft00(mlkem_ss, x_ss),
608                    Ct::X25519Kyber768Draft00(mlkem_ct, x_pk),
609                ))
610            }
611
612            #[cfg(feature = "kyber")]
613            PublicKey::XWingKyberDraft02(XWingKemDraft02PublicKey { pk_m, pk_x }) => {
614                let seed = mlkem_rand(rng)?;
615                let (ct_m, ss_m) = kyber768::encapsulate(pk_m, seed);
616                let (ek_x, ct_x) = libcrux_ecdh::x25519_key_gen(rng)?;
617                let ss_x = x25519_derive(pk_x, &ek_x)?;
618
619                Ok((
620                    Ss::XWingKyberDraft02(
621                        ss_m,
622                        ss_x,
623                        X25519PublicKey(ct_x.0.clone()),
624                        X25519PublicKey(pk_x.0.clone()),
625                    ),
626                    Ct::XWingKyberDraft02(ct_m, X25519PublicKey(ct_x.0.clone())),
627                ))
628            }
629        }
630    }
631
632    pub fn encode(&self) -> Vec<u8> {
634        match self {
635            PublicKey::X25519(k) => k.0.to_vec(),
636            PublicKey::P256(k) => k.0.to_vec(),
637            PublicKey::MlKem512(k) => k.as_ref().to_vec(),
638            PublicKey::MlKem768(k) => k.as_ref().to_vec(),
639            PublicKey::X25519MlKem768Draft00(k) => k.encode(),
640            PublicKey::XWingKemDraft02(k) => k.encode(),
641            PublicKey::MlKem1024(k) => k.as_ref().to_vec(),
642            #[cfg(feature = "kyber")]
643            PublicKey::X25519Kyber768Draft00(k) => k.encode(),
644            #[cfg(feature = "kyber")]
645            PublicKey::XWingKyberDraft02(k) => k.encode(),
646        }
647    }
648
649    pub fn decode(alg: Algorithm, bytes: &[u8]) -> Result<Self, Error> {
651        match alg {
652            Algorithm::X25519 => bytes
653                .try_into()
654                .map(Self::X25519)
655                .map_err(|_| Error::InvalidPublicKey),
656            Algorithm::Secp256r1 => bytes
657                .try_into()
658                .map(Self::P256)
659                .map_err(|_| Error::InvalidPublicKey),
660            Algorithm::MlKem512 => {
661                let key =
662                    MlKem512PublicKey::try_from(bytes).map_err(|_| Error::InvalidPublicKey)?;
663                if !mlkem512::validate_public_key(&key) {
664                    return Err(Error::InvalidPublicKey);
665                }
666                Ok(Self::MlKem512(key))
667            }
668            Algorithm::MlKem768 => {
669                let key =
670                    MlKem768PublicKey::try_from(bytes).map_err(|_| Error::InvalidPublicKey)?;
671                if !mlkem768::validate_public_key(&key) {
672                    return Err(Error::InvalidPublicKey);
673                }
674                Ok(Self::MlKem768(key))
675            }
676            Algorithm::X25519MlKem768Draft00 => {
677                X25519MlKem768Draft00PublicKey::decode(bytes).map(Self::X25519MlKem768Draft00)
678            }
679            Algorithm::XWingKemDraft02 => {
680                XWingKemDraft02PublicKey::decode(bytes).map(Self::XWingKemDraft02)
681            }
682            #[cfg(feature = "kyber")]
683            Algorithm::X25519Kyber768Draft00 => {
684                X25519MlKem768Draft00PublicKey::decode(bytes).map(Self::X25519Kyber768Draft00)
685            }
686            #[cfg(feature = "kyber")]
687            Algorithm::XWingKyberDraft02 => {
688                XWingKemDraft02PublicKey::decode(bytes).map(Self::XWingKyberDraft02)
689            }
690            Algorithm::MlKem1024 => {
691                let key =
692                    MlKem1024PublicKey::try_from(bytes).map_err(|_| Error::InvalidPublicKey)?;
693                if !mlkem1024::validate_public_key(&key) {
694                    return Err(Error::InvalidPublicKey);
695                }
696                Ok(Self::MlKem1024(key))
697            }
698            _ => Err(Error::UnsupportedAlgorithm),
699        }
700    }
701}
702
703impl Ss {
704    pub fn encode(&self) -> Vec<u8> {
706        match self {
707            Ss::X25519(k) => k.0.to_vec(),
708            Ss::P256(k) => k.0.to_vec(),
709            Ss::MlKem512(k) => k.as_ref().to_vec(),
710            Ss::MlKem768(k) => k.as_ref().to_vec(),
711            Ss::X25519MlKem768Draft00(kk, xk) => {
712                let mut out = kk.to_vec();
713                out.extend_from_slice(xk.0.as_ref());
714                out
715            }
716            Ss::XWingKemDraft02(ss_m, ss_x, ct_x, pk_x) => {
717                let mut input = vec![0x5c, 0x2e, 0x2f, 0x2f, 0x5e, 0x5c];
721                input.extend_from_slice(ss_m.as_ref());
722                input.extend_from_slice(ss_x.as_ref());
723                input.extend_from_slice(ct_x.0.as_ref());
724                input.extend_from_slice(pk_x.0.as_ref());
725                sha3::sha256(&input).to_vec()
726            }
727            #[cfg(feature = "kyber")]
728            Ss::X25519Kyber768Draft00(kk, xk) => {
729                let mut out = xk.0.to_vec();
730                out.extend_from_slice(kk.as_ref());
731                out
732            }
733            #[cfg(feature = "kyber")]
734            Ss::XWingKyberDraft02(ss_m, ss_x, ct_x, pk_x) => {
735                let mut input = vec![0x5c, 0x2e, 0x2f, 0x2f, 0x5e, 0x5c];
739                input.extend_from_slice(ss_m.as_ref());
740                input.extend_from_slice(ss_x.as_ref());
741                input.extend_from_slice(ct_x.0.as_ref());
742                input.extend_from_slice(pk_x.0.as_ref());
743                sha3::sha256(&input).to_vec()
744            }
745            Ss::MlKem1024(k) => k.as_ref().to_vec(),
746        }
747    }
748}
749
750impl Ct {
751    pub fn encode(&self) -> Vec<u8> {
753        match self {
754            Ct::X25519(k) => k.0.to_vec(),
755            Ct::P256(k) => k.0.to_vec(),
756            Ct::MlKem512(k) => k.as_ref().to_vec(),
757            Ct::MlKem768(k) => k.as_ref().to_vec(),
758            Ct::X25519MlKem768Draft00(kk, xk) => {
759                let mut out = kk.as_ref().to_vec();
760                out.extend_from_slice(xk.0.as_ref());
761                out
762            }
763            Ct::XWingKemDraft02(ct_m, ct_x) => {
764                let mut out = ct_m.as_ref().to_vec();
765                out.extend_from_slice(ct_x.as_ref());
766                out
767            }
768            #[cfg(feature = "kyber")]
769            Ct::X25519Kyber768Draft00(kk, xk) => {
770                let mut out = xk.0.to_vec();
771                out.extend_from_slice(kk.as_ref());
772                out
773            }
774            #[cfg(feature = "kyber")]
775            Ct::XWingKyberDraft02(ct_m, ct_x) => {
776                let mut out = ct_m.as_ref().to_vec();
777                out.extend_from_slice(ct_x.as_ref());
778                out
779            }
780            Ct::MlKem1024(k) => k.as_ref().to_vec(),
781        }
782    }
783
784    pub fn decode(alg: Algorithm, bytes: &[u8]) -> Result<Self, Error> {
786        match alg {
787            Algorithm::X25519 => bytes
788                .try_into()
789                .map_err(|_| Error::InvalidCiphertext)
790                .map(Self::X25519),
791            Algorithm::Secp256r1 => bytes
792                .try_into()
793                .map_err(|_| Error::InvalidCiphertext)
794                .map(Self::P256),
795            Algorithm::MlKem512 => bytes
796                .try_into()
797                .map_err(|_| Error::InvalidCiphertext)
798                .map(Self::MlKem512),
799            Algorithm::MlKem768 => bytes
800                .try_into()
801                .map_err(|_| Error::InvalidCiphertext)
802                .map(Self::MlKem768),
803            Algorithm::X25519MlKem768Draft00 => {
804                let key: [u8; MlKem768Ciphertext::len() + 32] =
805                    bytes.try_into().map_err(|_| Error::InvalidCiphertext)?;
806                let (kct, xct) = key.split_at(1088);
807                Ok(Self::X25519MlKem768Draft00(
808                    kct.try_into().map_err(|_| Error::InvalidCiphertext)?,
809                    xct.try_into().map_err(|_| Error::InvalidCiphertext)?,
810                ))
811            }
812            Algorithm::XWingKemDraft02 => {
813                let key: [u8; MlKem768Ciphertext::len() + 32] =
814                    bytes.try_into().map_err(|_| Error::InvalidCiphertext)?;
815                let (ct_m, ct_x) = key.split_at(MlKem768Ciphertext::len());
816                Ok(Self::XWingKemDraft02(
817                    ct_m.try_into().map_err(|_| Error::InvalidCiphertext)?,
818                    ct_x.try_into().map_err(|_| Error::InvalidCiphertext)?,
819                ))
820            }
821            #[cfg(feature = "kyber")]
822            Algorithm::X25519Kyber768Draft00 => {
823                let key: [u8; MlKem768Ciphertext::len() + 32] =
824                    bytes.try_into().map_err(|_| Error::InvalidCiphertext)?;
825                let (xct, kct) = key.split_at(32);
826                Ok(Self::X25519Kyber768Draft00(
827                    kct.try_into().map_err(|_| Error::InvalidCiphertext)?,
828                    xct.try_into().map_err(|_| Error::InvalidCiphertext)?,
829                ))
830            }
831            #[cfg(feature = "kyber")]
832            Algorithm::XWingKyberDraft02 => {
833                let key: [u8; MlKem768Ciphertext::len() + 32] =
834                    bytes.try_into().map_err(|_| Error::InvalidCiphertext)?;
835                let (ct_m, ct_x) = key.split_at(MlKem768Ciphertext::len());
836                Ok(Self::XWingKyberDraft02(
837                    ct_m.try_into().map_err(|_| Error::InvalidCiphertext)?,
838                    ct_x.try_into().map_err(|_| Error::InvalidCiphertext)?,
839                ))
840            }
841            Algorithm::MlKem1024 => bytes
842                .try_into()
843                .map_err(|_| Error::InvalidCiphertext)
844                .map(Self::MlKem1024),
845            _ => Err(Error::UnsupportedAlgorithm),
846        }
847    }
848}
849
850pub fn secret_to_public(alg: Algorithm, sk: impl AsRef<[u8]>) -> Result<Vec<u8>, Error> {
853    match alg {
854        Algorithm::X25519 | Algorithm::Secp256r1 => {
855            libcrux_ecdh::secret_to_public(alg.try_into().unwrap(), sk.as_ref())
856                .map_err(|e| e.into())
857        }
858        _ => Err(Error::UnsupportedAlgorithm),
859    }
860}
861
862fn gen_mlkem768(
863    rng: &mut (impl CryptoRng + Rng),
864) -> Result<(MlKem768PrivateKey, MlKem768PublicKey), Error> {
865    Ok(mlkem768::generate_key_pair(random_array(rng)?).into_parts())
866}
867
868fn random_array<const L: usize>(rng: &mut (impl CryptoRng + Rng)) -> Result<[u8; L], Error> {
869    let mut seed = [0; L];
870    rng.try_fill_bytes(&mut seed).map_err(|_| Error::KeyGen)?;
871    Ok(seed)
872}
873
874pub fn key_gen(
880    alg: Algorithm,
881    rng: &mut (impl CryptoRng + Rng),
882) -> Result<(PrivateKey, PublicKey), Error> {
883    match alg {
884        Algorithm::X25519 => libcrux_ecdh::x25519_key_gen(rng)
885            .map_err(|e| e.into())
886            .map(|(private, public)| (PrivateKey::X25519(private), PublicKey::X25519(public))),
887        Algorithm::Secp256r1 => libcrux_ecdh::p256_key_gen(rng)
888            .map_err(|e| e.into())
889            .map(|(private, public)| (PrivateKey::P256(private), PublicKey::P256(public))),
890        Algorithm::MlKem512 => {
891            let (sk, pk) = mlkem512::generate_key_pair(random_array(rng)?).into_parts();
892            Ok((PrivateKey::MlKem512(sk), PublicKey::MlKem512(pk)))
893        }
894        Algorithm::MlKem768 => {
895            let (sk, pk) = mlkem768::generate_key_pair(random_array(rng)?).into_parts();
896            Ok((PrivateKey::MlKem768(sk), PublicKey::MlKem768(pk)))
897        }
898        Algorithm::MlKem1024 => {
899            let (sk, pk) = mlkem1024::generate_key_pair(random_array(rng)?).into_parts();
900            Ok((PrivateKey::MlKem1024(sk), PublicKey::MlKem1024(pk)))
901        }
902        Algorithm::X25519MlKem768Draft00 => {
903            let (mlkem_private, mlkem_public) = gen_mlkem768(rng)?;
904            let (x25519_private, x25519_public) = libcrux_ecdh::x25519_key_gen(rng)?;
905
906            Ok((
907                PrivateKey::X25519MlKem768Draft00(X25519MlKem768Draft00PrivateKey {
908                    mlkem: mlkem_private,
909                    x25519: x25519_private,
910                }),
911                PublicKey::X25519MlKem768Draft00(X25519MlKem768Draft00PublicKey {
912                    mlkem: mlkem_public,
913                    x25519: x25519_public,
914                }),
915            ))
916        }
917        Algorithm::XWingKemDraft02 => {
918            let (sk_m, pk_m) = gen_mlkem768(rng)?;
919            let (sk_x, pk_x) = libcrux_ecdh::x25519_key_gen(rng)?;
920            Ok((
921                PrivateKey::XWingKemDraft02(XWingKemDraft02PrivateKey {
922                    sk_m,
923                    sk_x,
924                    pk_x: X25519PublicKey(pk_x.0.clone()),
925                }),
926                PublicKey::XWingKemDraft02(XWingKemDraft02PublicKey { pk_m, pk_x }),
927            ))
928        }
929        #[cfg(feature = "kyber")]
930        Algorithm::X25519Kyber768Draft00 => {
931            let (mlkem_private, mlkem_public) = gen_mlkem768(rng)?;
932            let (x25519_private, x25519_public) = libcrux_ecdh::x25519_key_gen(rng)?;
933            Ok((
934                PrivateKey::X25519Kyber768Draft00(X25519MlKem768Draft00PrivateKey {
935                    mlkem: mlkem_private,
936                    x25519: x25519_private,
937                }),
938                PublicKey::X25519Kyber768Draft00(X25519MlKem768Draft00PublicKey {
939                    mlkem: mlkem_public,
940                    x25519: x25519_public,
941                }),
942            ))
943        }
944        #[cfg(feature = "kyber")]
945        Algorithm::XWingKyberDraft02 => {
946            let (sk_m, pk_m) = gen_mlkem768(rng)?;
947            let (sk_x, pk_x) = libcrux_ecdh::x25519_key_gen(rng)?;
948            Ok((
949                PrivateKey::XWingKyberDraft02(XWingKemDraft02PrivateKey {
950                    sk_m,
951                    sk_x,
952                    pk_x: X25519PublicKey(pk_x.0.clone()),
953                }),
954                PublicKey::XWingKyberDraft02(XWingKemDraft02PublicKey { pk_m, pk_x }),
955            ))
956        }
957        _ => Err(Error::UnsupportedAlgorithm),
958    }
959}
960
961fn mlkem_rand(
962    rng: &mut (impl CryptoRng + Rng),
963) -> Result<[u8; libcrux_ml_kem::SHARED_SECRET_SIZE], Error> {
964    let mut seed = [0; libcrux_ml_kem::SHARED_SECRET_SIZE];
965    rng.try_fill_bytes(&mut seed).map_err(|_| Error::KeyGen)?;
966    Ok(seed)
967}
968
969impl TryInto<libcrux_ecdh::X25519PublicKey> for PublicKey {
970    type Error = libcrux_ecdh::Error;
971
972    fn try_into(self) -> Result<libcrux_ecdh::X25519PublicKey, libcrux_ecdh::Error> {
973        if let PublicKey::X25519(k) = self {
974            Ok(k)
975        } else {
976            Err(libcrux_ecdh::Error::InvalidPoint)
977        }
978    }
979}
980
981impl TryInto<libcrux_ecdh::X25519PrivateKey> for PrivateKey {
982    type Error = libcrux_ecdh::Error;
983
984    fn try_into(self) -> Result<libcrux_ecdh::X25519PrivateKey, libcrux_ecdh::Error> {
985        if let PrivateKey::X25519(k) = self {
986            Ok(k)
987        } else {
988            Err(libcrux_ecdh::Error::InvalidPoint)
989        }
990    }
991}