use hkdf::Hkdf;
use sha3::Sha3_512;
use zeroize::Zeroizing;
use rand_chacha::ChaCha20Rng;
use rand_core::SeedableRng;
use frodo_kem::{
Algorithm, Ciphertext as FrodoCiphertext, DecryptionKey as FrodoDecryptionKey,
EncryptionKey as FrodoEncryptionKey,
};
use ml_kem::{
Ciphertext, MlKem1024, Seed as MlKemSeed,
array::Array,
kem::{Decapsulate, KeyExport},
};
use x25519_dalek::{PublicKey, StaticSecret};
/// Fixed 64-byte domain-separation string for K-Wing.
///
/// It is appended as the last component of the HKDF `info` input when deriving the output key,
/// so that keys derived here are separated from other HKDF uses of the same material.
/// `'static` is implied for references in `const` items, so it is not spelled out.
const K_WING_OKM_CONTEXT: &[u8; 64] = &[
    23, 18, 198, 136, 205, 78, 247, 102, 135, 178, 234, 65, 223, 184, 208, 126, 20, 210, 94, 166,
    168, 92, 94, 241, 48, 209, 96, 164, 56, 106, 245, 205, 94, 113, 223, 88, 245, 94, 152, 82, 1,
    243, 111, 55, 252, 234, 237, 104, 244, 74, 251, 49, 208, 140, 49, 164, 217, 58, 35, 189, 66, 7,
    225, 167,
];
/// Errors returned by K-Wing key generation, encapsulation and decapsulation.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Error {
    /// An encapsulation step failed (FrodoKEM encapsulation or an internal size conversion).
    EncapsulateError,
    /// A decapsulation step failed (FrodoKEM decapsulation or an internal size conversion).
    DecapsulateError,
    /// The X25519 exchange produced a non-contributory (all-zero) shared secret.
    LowEntropyKey,
    /// An input had the wrong length or could not be parsed into its component keys.
    InvalidFormat,
}

impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Error::EncapsulateError => write!(f, "Encapsulation failed"),
            Error::DecapsulateError => write!(f, "Decapsulation failed"),
            Error::LowEntropyKey => write!(f, "Low entropy or non-contributory key detected"),
            Error::InvalidFormat => write!(f, "Invalid format or size"),
        }
    }
}

// Lets callers box this as `dyn std::error::Error` and propagate it with `?` into generic
// error types; there is no underlying cause, so the default `source()` (None) is correct.
impl std::error::Error for Error {}
fn derive_key(
dh_ss: Zeroizing<[u8; 32]>,
ml_kem_ss: Zeroizing<[u8; 32]>,
frodo_ss: Zeroizing<[u8; 32]>,
salt: &[u8; 32],
dh_eph_pub: &[u8; 32],
ml_kem_ct: &[u8; 1568],
frodo_ct: &[u8],
ek: &[u8],
) -> Result<[u8; 64], Error> {
let mut ikm = Zeroizing::new([0u8; 96]);
ikm[0..32].copy_from_slice(&*dh_ss);
ikm[32..64].copy_from_slice(&*ml_kem_ss);
ikm[64..96].copy_from_slice(&*frodo_ss);
drop((dh_ss, ml_kem_ss, frodo_ss));
let hkdf = Hkdf::<Sha3_512>::new(Some(salt), &*ikm);
drop(ikm);
let mut okm_info = Vec::with_capacity(
32 + ml_kem_ct.len() + frodo_ct.len() + ek.len() + K_WING_OKM_CONTEXT.len(),
);
okm_info.extend_from_slice(dh_eph_pub);
okm_info.extend_from_slice(ml_kem_ct);
okm_info.extend_from_slice(frodo_ct);
okm_info.extend_from_slice(ek);
okm_info.extend_from_slice(K_WING_OKM_CONTEXT);
let mut okm = [0u8; 64];
hkdf.expand(&okm_info, &mut okm)
.map_err(|_| Error::InvalidFormat)?;
Ok(okm)
}
/// A K-Wing recipient: a hybrid KEM key pair built from X25519, ML-KEM-1024 and
/// FrodoKEM-1344-SHAKE.
///
/// Holds the three decapsulation secrets together with the serialized composite
/// encapsulation key derived from them. The type deliberately does not derive `Debug`.
/// NOTE(review): presumably this is so the secret fields are never printed; confirm.
pub struct KWing {
    /// X25519 static secret, wrapped in `Zeroizing` so it is wiped when dropped.
    dh_secret: Zeroizing<StaticSecret>,
    /// ML-KEM-1024 decapsulation key.
    ml_kem_dk: ml_kem::DecapsulationKey<MlKem1024>,
    /// FrodoKEM decapsulation (secret) key.
    frodo_sk: FrodoDecryptionKey,
    /// Serialized public key: X25519 public key || ML-KEM encapsulation key || FrodoKEM public key.
    composite_pk: Vec<u8>,
}
impl KWing {
pub const ENCAPSULATION_KEY_SIZE: usize = 23120;
pub const CIPHERTEXT_SIZE: usize = 23328;
pub fn from_seed(secret_seed: &[u8; 128]) -> Result<Self, Error> {
let dh_secret = Zeroizing::new(StaticSecret::from(
<[u8; 32]>::try_from(&secret_seed[0..32]).map_err(|_| Error::InvalidFormat)?,
));
let dh_pub = PublicKey::from(&*dh_secret);
let ml_kem_d = Zeroizing::new(
<[u8; 32]>::try_from(&secret_seed[32..64]).map_err(|_| Error::InvalidFormat)?,
);
let ml_kem_z = Zeroizing::new(
<[u8; 32]>::try_from(&secret_seed[64..96]).map_err(|_| Error::InvalidFormat)?,
);
let mut ml_kem_seed = MlKemSeed::default();
ml_kem_seed[..32].copy_from_slice(&*ml_kem_d);
ml_kem_seed[32..].copy_from_slice(&*ml_kem_z);
let ml_kem_dk = ml_kem::DecapsulationKey::<MlKem1024>::from_seed(ml_kem_seed);
let ml_ek = ml_kem_dk.encapsulation_key();
let frodo_seed =
<[u8; 32]>::try_from(&secret_seed[96..128]).map_err(|_| Error::InvalidFormat)?;
let mut frodo_rng = ChaCha20Rng::from_seed(frodo_seed);
let frodo = Algorithm::FrodoKem1344Shake;
let (frodo_pk, frodo_sk) = frodo.generate_keypair(&mut frodo_rng);
let mut composite_pk = Vec::with_capacity(Self::ENCAPSULATION_KEY_SIZE);
composite_pk.extend_from_slice(dh_pub.as_bytes());
composite_pk.extend_from_slice(&ml_ek.to_bytes());
composite_pk.extend_from_slice(frodo_pk.value());
Ok(Self {
dh_secret,
ml_kem_dk,
frodo_sk,
composite_pk,
})
}
#[must_use]
pub fn get_pub_key(&self) -> &[u8] {
&self.composite_pk
}
pub fn encapsulate(encaps_seed: &[u8; 128], ek: &[u8]) -> Result<(Vec<u8>, [u8; 64]), Error> {
if ek.len() != Self::ENCAPSULATION_KEY_SIZE {
return Err(Error::InvalidFormat);
}
let frodo = Algorithm::FrodoKem1344Shake;
let dh_pub =
PublicKey::from(<[u8; 32]>::try_from(&ek[0..32]).map_err(|_| Error::InvalidFormat)?);
let ml_kem_ek_bytes: ml_kem::kem::Key<ml_kem::EncapsulationKey<MlKem1024>> = Array(
ek[32..1600].try_into().map_err(|_| Error::InvalidFormat)?,
);
let ml_kem_ek = ml_kem::EncapsulationKey::<MlKem1024>::new(&ml_kem_ek_bytes)
.map_err(|_| Error::InvalidFormat)?;
let frodo_pk =
FrodoEncryptionKey::from_bytes(frodo, &ek[1600..]).map_err(|_| Error::InvalidFormat)?;
let dh_eph_secret = Zeroizing::new(StaticSecret::from(
<[u8; 32]>::try_from(&encaps_seed[0..32]).map_err(|_| Error::InvalidFormat)?,
));
let ml_kem_m = Zeroizing::new(
<[u8; 32]>::try_from(&encaps_seed[32..64]).map_err(|_| Error::InvalidFormat)?,
);
let frodo_rng_seed =
<[u8; 32]>::try_from(&encaps_seed[64..96]).map_err(|_| Error::InvalidFormat)?;
let salt = <[u8; 32]>::try_from(&encaps_seed[96..128]).map_err(|_| Error::InvalidFormat)?;
let dh_eph_pub = PublicKey::from(&*dh_eph_secret);
let dh_eph_pub_bytes = dh_eph_pub.as_bytes();
let dh_ss = Zeroizing::new(dh_eph_secret.diffie_hellman(&dh_pub));
if !dh_ss.was_contributory() {
return Err(Error::LowEntropyKey);
}
let (ml_kem_ct, ml_kem_ss) = ml_kem_ek
.encapsulate_deterministic(&Array(*ml_kem_m));
let ml_kem_ss: Zeroizing<[u8; 32]> = Zeroizing::new(ml_kem_ss.into());
let mut frodo_rng = ChaCha20Rng::from_seed(frodo_rng_seed);
let (frodo_ct, frodo_ss) = frodo
.encapsulate_with_rng(&frodo_pk, &mut frodo_rng)
.map_err(|_| Error::EncapsulateError)?;
let frodo_ss: Zeroizing<[u8; 32]> = Zeroizing::new(
frodo_ss
.value()
.try_into()
.map_err(|_| Error::EncapsulateError)?,
);
let frodo_ct_arr = frodo_ct.value().to_vec();
let ml_kem_ct_bytes: [u8; 1568] = ml_kem_ct.0
.try_into()
.map_err(|_| Error::EncapsulateError)?;
let okm = derive_key(
Zeroizing::new(dh_ss.to_bytes()),
ml_kem_ss,
frodo_ss,
&salt,
dh_eph_pub_bytes,
&ml_kem_ct_bytes,
&frodo_ct_arr,
ek,
)?;
let mut ciphertext = Vec::with_capacity(Self::CIPHERTEXT_SIZE);
ciphertext.extend_from_slice(dh_eph_pub_bytes);
ciphertext.extend_from_slice(&salt);
ciphertext.extend_from_slice(&ml_kem_ct_bytes);
ciphertext.extend_from_slice(&frodo_ct_arr);
Ok((ciphertext, okm))
}
pub fn decapsulate(&self, ct: &[u8]) -> Result<[u8; 64], Error> {
if ct.len() != Self::CIPHERTEXT_SIZE {
return Err(Error::InvalidFormat);
}
let frodo = Algorithm::FrodoKem1344Shake;
let dh_eph_pub =
PublicKey::from(<[u8; 32]>::try_from(&ct[0..32]).map_err(|_| Error::InvalidFormat)?);
let salt: [u8; 32] = ct[32..64].try_into().map_err(|_| Error::InvalidFormat)?;
let ml_kem_ct: Ciphertext<MlKem1024> = Array(ct[64..1632].try_into().map_err(|_| Error::InvalidFormat)?);
let frodo_ct_bytes = &ct[1632..];
let frodo_ct = FrodoCiphertext::from_bytes(frodo, &frodo_ct_bytes)
.map_err(|_| Error::InvalidFormat)?;
let dh_ss = Zeroizing::new(self.dh_secret.diffie_hellman(&dh_eph_pub));
if !dh_ss.was_contributory() {
return Err(Error::LowEntropyKey);
}
let ml_kem_ss: Zeroizing<[u8; 32]> = Zeroizing::new(
self.ml_kem_dk
.decapsulate(&ml_kem_ct)
.into(),
);
let frodo_ss: Zeroizing<[u8; 32]> = Zeroizing::new(
frodo
.decapsulate(&self.frodo_sk, &frodo_ct)
.map_err(|_| Error::DecapsulateError)?
.0
.value()
.try_into()
.map_err(|_| Error::DecapsulateError)?,
);
let ml_kem_ct_bytes: [u8; 1568] = ml_kem_ct.0
.try_into()
.map_err(|_| Error::DecapsulateError)?;
let okm = derive_key(
Zeroizing::new(dh_ss.to_bytes()),
ml_kem_ss,
frodo_ss,
&salt,
dh_eph_pub.as_bytes(),
&ml_kem_ct_bytes,
frodo_ct_bytes,
&self.get_pub_key(),
)?;
Ok(okm)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::LazyLock;

    static SECRET_SEED: [u8; 128] = [0x42; 128];
    static ENCAPS_SEED: [u8; 128] = [0x84; 128];
    // Shared fixtures are built lazily once and reused by every test.
    static K_WING: LazyLock<KWing> = LazyLock::new(|| KWing::from_seed(&SECRET_SEED).unwrap());
    static ENCAPS_RESULT: LazyLock<(Vec<u8>, [u8; 64])> =
        LazyLock::new(|| KWing::encapsulate(&ENCAPS_SEED, K_WING.get_pub_key()).unwrap());

    #[test]
    fn test_happy_path_round_trip() {
        let pk = K_WING.get_pub_key();
        assert_eq!(pk.len(), KWing::ENCAPSULATION_KEY_SIZE);
        let (ct, okm_encapsulated) = &*ENCAPS_RESULT;
        assert_eq!(ct.len(), KWing::CIPHERTEXT_SIZE);
        let okm_decapsulated = K_WING
            .decapsulate(ct)
            .expect("Decapsulation should succeed");
        assert_eq!(
            okm_encapsulated, &okm_decapsulated,
            "Decapsulated OKM must exactly match the Encapsulated OKM"
        );
    }

    #[test]
    fn test_strict_determinism() {
        let binding = KWing::from_seed(&SECRET_SEED).unwrap();
        let pk2 = binding.get_pub_key();
        assert_eq!(
            K_WING.get_pub_key(),
            pk2,
            "Public keys must be identical for the same seed"
        );
        let (ct2, okm2) = KWing::encapsulate(&ENCAPS_SEED, pk2).unwrap();
        assert_eq!(
            ENCAPS_RESULT.0, ct2,
            "Ciphertexts must be identical for the same seeds"
        );
        assert_eq!(
            ENCAPS_RESULT.1, okm2,
            "OKMs must be identical for the same seeds"
        );
    }

    #[test]
    fn test_invalid_public_key_length() {
        let bad_pk = vec![0u8; KWing::ENCAPSULATION_KEY_SIZE - 1];
        let result = KWing::encapsulate(&ENCAPS_SEED, &bad_pk);
        assert_eq!(
            result,
            Err(Error::InvalidFormat),
            "Encapsulate must reject invalid public key lengths immediately"
        );
    }

    #[test]
    fn test_invalid_ciphertext_length() {
        let bad_ct = vec![0u8; KWing::CIPHERTEXT_SIZE + 5];
        let result = K_WING.decapsulate(&bad_ct);
        assert_eq!(
            result,
            Err(Error::InvalidFormat),
            "Decapsulate must reject invalid ciphertext lengths immediately"
        );
    }

    #[test]
    fn test_low_entropy_key_encapsulate_rejection() {
        let mut pk = K_WING.get_pub_key().to_vec();
        pk[0..32].fill(0);
        let result = KWing::encapsulate(&ENCAPS_SEED, &pk);
        assert_eq!(
            result,
            Err(Error::LowEntropyKey),
            "Encapsulate must reject mathematical weak points (all-zero DH shared secret)"
        );
    }

    #[test]
    fn test_low_entropy_key_decapsulate_rejection() {
        let mut ct = ENCAPS_RESULT.0.clone();
        ct[0..32].fill(0);
        let result = K_WING.decapsulate(&ct);
        assert_eq!(
            result,
            Err(Error::LowEntropyKey),
            "Decapsulate must reject mathematical weak points injected by an attacker"
        );
    }

    #[test]
    fn test_tampered_ciphertext_does_not_reproduce_okm() {
        let (ct, okm) = &*ENCAPS_RESULT;
        // One probe per component: ephemeral X25519 key, salt, ML-KEM ciphertext, and the
        // first and last byte of the FrodoKEM ciphertext. Decapsulation may either fail or
        // return a different key, but must never return the original one.
        let probes = [0, 32, 64, 1632, KWing::CIPHERTEXT_SIZE - 1];
        for index in probes {
            let mut tampered = ct.clone();
            tampered[index] ^= 0x01;
            assert_ne!(
                K_WING.decapsulate(&tampered),
                Ok(*okm),
                "Flipping a bit at ciphertext byte {index} must not reproduce the original OKM"
            );
        }
    }

    #[test]
    fn test_distinct_encaps_seeds_give_distinct_outputs() {
        let other_seed = [0x85u8; 128];
        let (ct2, okm2) = KWing::encapsulate(&other_seed, K_WING.get_pub_key()).unwrap();
        assert_ne!(
            ENCAPS_RESULT.0, ct2,
            "Different encapsulation seeds must give different ciphertexts"
        );
        assert_ne!(
            ENCAPS_RESULT.1, okm2,
            "Different encapsulation seeds must give different OKMs"
        );
        assert_eq!(
            K_WING.decapsulate(&ct2),
            Ok(okm2),
            "Round trip must hold for any encapsulation seed"
        );
    }

    #[test]
    fn test_wrong_recipient_does_not_recover_okm() {
        let other = KWing::from_seed(&[0x43; 128]).unwrap();
        let (ct, okm) = &*ENCAPS_RESULT;
        assert_ne!(
            other.decapsulate(ct),
            Ok(*okm),
            "A different key pair must not recover the OKM encapsulated to K_WING"
        );
    }
}