use alloc::vec::Vec;
use rand_core::{CryptoRng, RngCore};
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
use crate::HpkeError;
use crate::kem::Kem;
use crate::sealed::Sealed;
struct RngCompat10<'a, R: RngCore + CryptoRng>(pub(crate) &'a mut R);
impl<R: RngCore + CryptoRng> rand_core_10::TryRng for RngCompat10<'_, R> {
type Error = core::convert::Infallible;
fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
Ok(self.0.next_u32())
}
fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
Ok(self.0.next_u64())
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
self.0.fill_bytes(dest);
Ok(())
}
}
impl<R: RngCore + CryptoRng> rand_core_10::TryCryptoRng for RngCompat10<'_, R> {}
/// HPKE KEM marker type for X-Wing (draft 06), backed by the `x_wing` crate.
///
/// Zero-sized; all behaviour lives in its `Kem` implementation
/// (KEM id `0x647A`).
#[derive(Clone, Copy, Debug, Default)]
pub struct XWingDraft06;

impl Sealed for XWingDraft06 {}
/// X-Wing public (encapsulation) key.
///
/// Keeps both the serialized wire bytes (returned by `as_ref` and
/// `pk_to_bytes`) and the parsed `x_wing::EncapsulationKey`, so
/// encapsulation does not need to re-parse the bytes.
#[derive(Clone, Debug)]
pub struct XWingPublicKey {
    bytes: Vec<u8>,
    parsed: x_wing::EncapsulationKey,
}

impl AsRef<[u8]> for XWingPublicKey {
    /// Returns the serialized public key bytes.
    fn as_ref(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
pub struct XWingPrivateKey {
seed: [u8; 32],
dk: Option<x_wing::DecapsulationKey>,
}
impl Zeroize for XWingPrivateKey {
fn zeroize(&mut self) {
self.seed.zeroize();
self.dk = None;
}
}
impl ZeroizeOnDrop for XWingPrivateKey {}
impl Drop for XWingPrivateKey {
fn drop(&mut self) {
self.zeroize();
}
}
/// Serialized X-Wing ciphertext (the KEM's encapsulated key), held as raw bytes.
#[derive(Clone, Debug)]
pub struct XWingEncappedKey(Vec<u8>);

impl AsRef<[u8]> for XWingEncappedKey {
    /// Returns the raw ciphertext bytes.
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
pub struct XWingSharedSecret([u8; 32]);
impl AsRef<[u8]> for XWingSharedSecret {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
impl Zeroize for XWingSharedSecret {
fn zeroize(&mut self) {
self.0.zeroize();
}
}
impl Drop for XWingSharedSecret {
fn drop(&mut self) {
self.zeroize();
}
}
impl Kem for XWingDraft06 {
const ID: u16 = 0x647A;
const ENCAPPED_KEY_LEN: usize = x_wing::CIPHERTEXT_SIZE;
const PUBLIC_KEY_LEN: usize = x_wing::ENCAPSULATION_KEY_SIZE;
const PRIVATE_KEY_LEN: usize = x_wing::DECAPSULATION_KEY_SIZE;
const SHARED_SECRET_LEN: usize = 32;
type PublicKey = XWingPublicKey;
type PrivateKey = XWingPrivateKey;
type EncappedKey = XWingEncappedKey;
type SharedSecret = XWingSharedSecret;
fn generate<R: CryptoRng + RngCore>(
rng: &mut R,
) -> Result<(Self::PrivateKey, Self::PublicKey), HpkeError> {
let mut seed = [0u8; 32];
rng.fill_bytes(&mut seed);
Ok(keypair_from_seed(seed))
}
fn derive_key_pair(ikm: &[u8]) -> Result<(Self::PrivateKey, Self::PublicKey), HpkeError> {
use sha3::digest::{ExtendableOutput, Update, XofReader};
let mut hasher = sha3::Shake256::default();
hasher.update(ikm);
let mut reader = hasher.finalize_xof();
let mut seed = [0u8; 32];
reader.read(&mut seed);
Ok(keypair_from_seed(seed))
}
fn encap<R: CryptoRng + RngCore>(
rng: &mut R,
pk_r: &Self::PublicKey,
) -> Result<(Self::SharedSecret, Self::EncappedKey), HpkeError> {
use x_wing::Encapsulate;
let mut compat = RngCompat10(rng);
let (ct, ss) = pk_r.parsed.encapsulate_with_rng(&mut compat);
let mut ss_bytes = [0u8; 32];
ss_bytes.copy_from_slice(ss.as_ref());
Ok((XWingSharedSecret(ss_bytes), XWingEncappedKey(ct.to_vec())))
}
fn decap(
enc: &Self::EncappedKey,
sk_r: &Self::PrivateKey,
) -> Result<Self::SharedSecret, HpkeError> {
use x_wing::Decapsulate;
let dk = sk_r.dk.as_ref().ok_or(HpkeError::DecapError)?;
let ss = dk
.decapsulate_slice(enc.0.as_slice())
.map_err(|_| HpkeError::InvalidEncappedKey)?;
let mut ss_bytes = [0u8; 32];
ss_bytes.copy_from_slice(ss.as_ref());
Ok(XWingSharedSecret(ss_bytes))
}
fn pk_from_bytes(b: &[u8]) -> Result<Self::PublicKey, HpkeError> {
if b.len() != Self::PUBLIC_KEY_LEN {
return Err(HpkeError::InvalidPublicKey);
}
let parsed =
x_wing::EncapsulationKey::try_from(b).map_err(|_| HpkeError::InvalidPublicKey)?;
Ok(XWingPublicKey {
bytes: b.to_vec(),
parsed,
})
}
fn sk_from_bytes(b: &[u8]) -> Result<Self::PrivateKey, HpkeError> {
if b.len() != 32 {
return Err(HpkeError::InvalidPrivateKey);
}
let mut seed = [0u8; 32];
seed.copy_from_slice(b);
let dk = x_wing::DecapsulationKey::from(seed);
Ok(XWingPrivateKey { seed, dk: Some(dk) })
}
fn enc_from_bytes(b: &[u8]) -> Result<Self::EncappedKey, HpkeError> {
if b.len() != Self::ENCAPPED_KEY_LEN {
return Err(HpkeError::InvalidEncappedKey);
}
Ok(XWingEncappedKey(b.to_vec()))
}
fn pk_to_bytes(pk: &Self::PublicKey) -> Vec<u8> {
pk.bytes.clone()
}
fn sk_to_bytes(sk: &Self::PrivateKey) -> Zeroizing<Vec<u8>> {
Zeroizing::new(sk.seed.to_vec())
}
}
/// Expands a 32-byte X-Wing seed into a `(private, public)` key pair.
///
/// The private key keeps the seed (for `sk_to_bytes`) alongside the
/// expanded decapsulation key. The public key caches both its serialized
/// bytes and the parsed encapsulation key.
fn keypair_from_seed(seed: [u8; 32]) -> (XWingPrivateKey, XWingPublicKey) {
    use x_wing::{DecapsulationKey, Decapsulator, KeyExport};
    let decap_key = DecapsulationKey::from(seed);
    let encap_key = decap_key.encapsulation_key().clone();
    // Fields are evaluated in order, so the bytes are taken before `encap_key` moves.
    let public = XWingPublicKey {
        bytes: encap_key.to_bytes().to_vec(),
        parsed: encap_key,
    };
    let private = XWingPrivateKey {
        seed,
        dk: Some(decap_key),
    };
    (private, public)
}
pub struct MlKemSharedSecret(Vec<u8>);
impl AsRef<[u8]> for MlKemSharedSecret {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
impl Zeroize for MlKemSharedSecret {
fn zeroize(&mut self) {
self.0.zeroize();
}
}
impl Drop for MlKemSharedSecret {
fn drop(&mut self) {
self.zeroize();
}
}
/// Generates an HPKE `Kem` implementation for one ML-KEM parameter set.
///
/// Arguments, in order:
/// - marker type name, display name (used in docs), KEM id;
/// - encapsulated-key and public-key lengths;
/// - the ML-KEM decapsulation-key, encapsulation-key and ciphertext types;
/// - names for the public-key, private-key and encapped-key wrapper types;
/// - the name of the private seed-expansion helper function.
macro_rules! ml_kem_variant {
    (
        $marker:ident, $variant:literal, $id:expr, $nenc:expr, $npk:expr,
        $dk:ty, $ek:ty, $ct:ty,
        $pk_wrap:ident, $sk_wrap:ident, $enc_wrap:ident, $from_seed:ident $(,)?
    ) => {
        #[doc = concat!("`", $variant, "` (FIPS 203). Private keys are stored as the 64-byte (d, z) seed; the expanded decapsulation key is rebuilt from it.")]
        #[derive(Debug, Clone, Copy, Default)]
        pub struct $marker;
        impl Sealed for $marker {}

        #[doc = concat!("Public (encapsulation) key for `", $variant, "`. Parsed `EncapsulationKey` cached so `encap` skips the per-call decode of the wire bytes.")]
        #[derive(Clone, Debug)]
        pub struct $pk_wrap {
            bytes: Vec<u8>,
            parsed: $ek,
        }
        impl AsRef<[u8]> for $pk_wrap {
            fn as_ref(&self) -> &[u8] {
                &self.bytes
            }
        }

        #[doc = concat!("Private (decapsulation) key for `", $variant, "` — 64-byte `d || z` seed plus expanded `dk`.")]
        pub struct $sk_wrap {
            // Expanded decapsulation key; `None` once the key has been zeroized.
            dk: Option<$dk>,
            seed: [u8; 64],
        }
        impl Zeroize for $sk_wrap {
            fn zeroize(&mut self) {
                self.seed.zeroize();
                // Release the expanded key as well. Otherwise a zeroized key
                // keeps its full secret state and can still decapsulate.
                // Whether the dropped `dk` wipes its own memory depends on the
                // `ml_kem` build (NOTE(review): confirm its zeroize feature).
                self.dk = None;
            }
        }
        impl ZeroizeOnDrop for $sk_wrap {}
        impl Drop for $sk_wrap {
            fn drop(&mut self) {
                self.zeroize();
            }
        }

        #[doc = concat!("Encapsulated key (ciphertext) for `", $variant, "`.")]
        #[derive(Clone, Debug)]
        pub struct $enc_wrap(Vec<u8>);
        impl AsRef<[u8]> for $enc_wrap {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }

        impl Kem for $marker {
            const ID: u16 = $id;
            const ENCAPPED_KEY_LEN: usize = $nenc;
            const PUBLIC_KEY_LEN: usize = $npk;
            const PRIVATE_KEY_LEN: usize = 64;
            const SHARED_SECRET_LEN: usize = 32;
            type PublicKey = $pk_wrap;
            type PrivateKey = $sk_wrap;
            type EncappedKey = $enc_wrap;
            type SharedSecret = MlKemSharedSecret;

            // Draws a fresh 64-byte (d, z) seed from `rng`.
            fn generate<R: CryptoRng + RngCore>(
                rng: &mut R,
            ) -> Result<(Self::PrivateKey, Self::PublicKey), HpkeError> {
                let mut seed = [0u8; 64];
                rng.fill_bytes(&mut seed);
                Ok($from_seed(seed))
            }

            // `ikm` is used directly as the 64-byte seed; any other length
            // is rejected.
            fn derive_key_pair(
                ikm: &[u8],
            ) -> Result<(Self::PrivateKey, Self::PublicKey), HpkeError> {
                if ikm.len() != 64 {
                    return Err(HpkeError::DeriveKeyPairError);
                }
                let mut seed = [0u8; 64];
                seed.copy_from_slice(ikm);
                Ok($from_seed(seed))
            }

            fn encap<R: CryptoRng + RngCore>(
                rng: &mut R,
                pk_r: &Self::PublicKey,
            ) -> Result<(Self::SharedSecret, Self::EncappedKey), HpkeError> {
                use ml_kem::kem::Encapsulate as _;
                let mut compat = RngCompat10(rng);
                let (ct, ss) = pk_r.parsed.encapsulate_with_rng(&mut compat);
                Ok((MlKemSharedSecret(ss.to_vec()), $enc_wrap(ct.to_vec())))
            }

            // Errors with `InvalidEncappedKey` on a wrong-length ciphertext and
            // with `DecapError` if the private key has been zeroized.
            fn decap(
                enc: &Self::EncappedKey,
                sk_r: &Self::PrivateKey,
            ) -> Result<Self::SharedSecret, HpkeError> {
                use ml_kem::kem::Decapsulate as _;
                let ct: $ct = enc
                    .0
                    .as_slice()
                    .try_into()
                    .map_err(|_| HpkeError::InvalidEncappedKey)?;
                let dk = sk_r.dk.as_ref().ok_or(HpkeError::DecapError)?;
                let ss = dk.decapsulate(&ct);
                Ok(MlKemSharedSecret(ss.to_vec()))
            }

            fn pk_from_bytes(b: &[u8]) -> Result<Self::PublicKey, HpkeError> {
                if b.len() != Self::PUBLIC_KEY_LEN {
                    return Err(HpkeError::InvalidPublicKey);
                }
                let ek_bytes: ml_kem::kem::Key<$ek> = b
                    .try_into()
                    .map_err(|_| HpkeError::InvalidPublicKey)?;
                let parsed =
                    <$ek>::new(&ek_bytes).map_err(|_| HpkeError::InvalidPublicKey)?;
                Ok($pk_wrap {
                    bytes: b.to_vec(),
                    parsed,
                })
            }

            // Loads the 64-byte seed and re-expands the decapsulation key.
            fn sk_from_bytes(b: &[u8]) -> Result<Self::PrivateKey, HpkeError> {
                if b.len() != 64 {
                    return Err(HpkeError::InvalidPrivateKey);
                }
                let mut seed = [0u8; 64];
                seed.copy_from_slice(b);
                let (sk, _pk) = $from_seed(seed);
                Ok(sk)
            }

            fn enc_from_bytes(b: &[u8]) -> Result<Self::EncappedKey, HpkeError> {
                if b.len() != Self::ENCAPPED_KEY_LEN {
                    return Err(HpkeError::InvalidEncappedKey);
                }
                Ok($enc_wrap(b.to_vec()))
            }

            fn pk_to_bytes(pk: &Self::PublicKey) -> Vec<u8> {
                pk.bytes.clone()
            }

            fn sk_to_bytes(sk: &Self::PrivateKey) -> Zeroizing<Vec<u8>> {
                Zeroizing::new(sk.seed.to_vec())
            }
        }

        // Expands a 64-byte (d, z) seed into the private key (seed plus
        // expanded `dk`) and the matching public key (bytes plus parsed key).
        fn $from_seed(seed: [u8; 64]) -> ($sk_wrap, $pk_wrap) {
            use ml_kem::kem::KeyExport as _;
            let ml_seed: ml_kem::Seed = seed.into();
            let dk = <$dk>::from_seed(ml_seed);
            let ek = dk.encapsulation_key().clone();
            let ek_bytes: Vec<u8> = ek.to_bytes().to_vec();
            (
                $sk_wrap { dk: Some(dk), seed },
                $pk_wrap {
                    bytes: ek_bytes,
                    parsed: ek,
                },
            )
        }
    };
}
// ML-KEM-768 instantiation: KEM id 0x0041, 1088-byte encapsulated key,
// 1184-byte public key (these become ID / ENCAPPED_KEY_LEN / PUBLIC_KEY_LEN).
ml_kem_variant!(
MlKem768, // marker type
"ML-KEM-768", // display name used in generated docs
0x0041, // Kem::ID
1088, // Kem::ENCAPPED_KEY_LEN
1184, // Kem::PUBLIC_KEY_LEN
ml_kem::DecapsulationKey768, // expanded private key type
ml_kem::EncapsulationKey768, // parsed public key type
ml_kem::ml_kem_768::Ciphertext, // ciphertext type parsed in decap
MlKem768PublicKey, // generated public-key wrapper
MlKem768PrivateKey, // generated private-key wrapper
MlKem768EncappedKey, // generated encapped-key wrapper
ml_kem_768_from_seed, // generated seed-expansion helper
);
// ML-KEM-1024 instantiation: KEM id 0x0042, 1568-byte encapsulated key,
// 1568-byte public key (these become ID / ENCAPPED_KEY_LEN / PUBLIC_KEY_LEN).
ml_kem_variant!(
MlKem1024, // marker type
"ML-KEM-1024", // display name used in generated docs
0x0042, // Kem::ID
1568, // Kem::ENCAPPED_KEY_LEN
1568, // Kem::PUBLIC_KEY_LEN
ml_kem::DecapsulationKey1024, // expanded private key type
ml_kem::EncapsulationKey1024, // parsed public key type
ml_kem::ml_kem_1024::Ciphertext, // ciphertext type parsed in decap
MlKem1024PublicKey, // generated public-key wrapper
MlKem1024PrivateKey, // generated private-key wrapper
MlKem1024EncappedKey, // generated encapped-key wrapper
ml_kem_1024_from_seed, // generated seed-expansion helper
);
#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::{OsRng, TryRngCore as _};

    // Generates a key pair, round-trips the private key through
    // `sk_to_bytes`/`sk_from_bytes`, and checks that the reloaded key
    // decapsulates to the sender's shared secret.
    macro_rules! sk_roundtrip_test {
        ($name:ident, $kem:ty, $seed_len:expr) => {
            #[test]
            fn $name() {
                let mut os_rng = OsRng;
                let mut rng = os_rng.unwrap_mut();
                let (sk_r, pk_r) = <$kem>::generate(&mut rng).unwrap();
                let sk_bytes = <$kem>::sk_to_bytes(&sk_r);
                assert_eq!(sk_bytes.len(), $seed_len);
                let sk_loaded = <$kem>::sk_from_bytes(&sk_bytes).unwrap();
                let (ss_e, enc) = <$kem>::encap(&mut rng, &pk_r).unwrap();
                let ss_d = <$kem>::decap(&enc, &sk_loaded).unwrap();
                assert_eq!(ss_e.as_ref(), ss_d.as_ref());
            }
        };
    }

    sk_roundtrip_test!(xwing_sk_to_bytes_roundtrip, XWingDraft06, 32);
    sk_roundtrip_test!(ml_kem_768_sk_to_bytes_roundtrip, MlKem768, 64);
    sk_roundtrip_test!(ml_kem_1024_sk_to_bytes_roundtrip, MlKem1024, 64);

    #[test]
    fn xwing_derive_key_pair_roundtrip() {
        let ikm = b"x-wing test ikm";
        let (sk_r, pk_r) = XWingDraft06::derive_key_pair(ikm).unwrap();
        // Derivation is a pure function of `ikm`, so re-deriving must yield
        // the same public key.
        let (_, pk_again) = XWingDraft06::derive_key_pair(ikm).unwrap();
        assert_eq!(pk_r.as_ref(), pk_again.as_ref());
        let mut os_rng = OsRng;
        let mut rng = os_rng.unwrap_mut();
        let (ss_e, enc) = XWingDraft06::encap(&mut rng, &pk_r).unwrap();
        assert_eq!(
            XWingDraft06::decap(&enc, &sk_r).unwrap().as_ref(),
            ss_e.as_ref(),
        );
    }

    #[test]
    fn xwing_sk_from_bytes_rejects_wrong_length() {
        let buf = [0u8; 64];
        for bad_len in [0usize, 31, 33, 64] {
            assert!(matches!(
                XWingDraft06::sk_from_bytes(&buf[..bad_len]),
                Err(HpkeError::InvalidPrivateKey)
            ));
        }
    }

    #[test]
    fn ml_kem_768_and_1024_derive_distinct_keys_from_same_ikm() {
        let ikm = [0x5Au8; 64];
        let (_, pk_768) = MlKem768::derive_key_pair(&ikm).unwrap();
        let (_, pk_1024) = MlKem1024::derive_key_pair(&ikm).unwrap();
        let n = pk_768.as_ref().len().min(pk_1024.as_ref().len());
        assert_ne!(&pk_768.as_ref()[..n], &pk_1024.as_ref()[..n]);
    }

    // Checks that ML-KEM derivation is deterministic, that it stores `ikm`
    // verbatim as the seed, and that it rejects any `ikm` that is not 64 bytes.
    macro_rules! ml_kem_derive_seed_test {
        ($name:ident, $kem:ty) => {
            #[test]
            fn $name() {
                let ikm: [u8; 64] = core::array::from_fn(|i| u8::try_from(i).unwrap());
                let (sk1, pk1) = <$kem>::derive_key_pair(&ikm).unwrap();
                let (_, pk2) = <$kem>::derive_key_pair(&ikm).unwrap();
                assert_eq!(pk1.as_ref(), pk2.as_ref());
                assert_eq!(sk1.seed, ikm);
                // Slice a fixed stack buffer rather than allocating with
                // `vec!`, which may not be in scope in an `alloc`-based crate.
                let zeros = [0u8; 65];
                for bad_len in [0usize, 32, 63, 65] {
                    assert!(matches!(
                        <$kem>::derive_key_pair(&zeros[..bad_len]),
                        Err(HpkeError::DeriveKeyPairError)
                    ));
                }
            }
        };
    }

    ml_kem_derive_seed_test!(ml_kem_768_derive_key_pair_seed_invariants, MlKem768);
    ml_kem_derive_seed_test!(ml_kem_1024_derive_key_pair_seed_invariants, MlKem1024);
}