use ml_kem::array::sizes::U64;
use ml_kem::array::Array;
use ml_kem::kem::Decapsulate;
use ml_kem::{DecapsulationKey, EncapsulationKey, MlKem768};
use sha3::digest::{ExtendableOutput, Update as _, XofReader};
use sha3::{Digest, Sha3_256, Shake256};
use subtle::ConstantTimeEq;
use thiserror::Error;
use x25519_dalek::x25519;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Length in bytes of a raw X25519 scalar or u-coordinate.
pub const X25519_KEY_LENGTH: usize = 32;
/// Length in bytes of the ML-KEM-768 encapsulation key that opens a hybrid public key.
const MLKEM_EK_LENGTH: usize = 1184;
/// Length in bytes of the ML-KEM-768 ciphertext that opens a hybrid `enc`.
const MLKEM_CT_LENGTH: usize = 1088;
/// Hybrid public key: ML-KEM-768 encapsulation key followed by the X25519 public key.
pub const MLKEM768X25519_PUBLIC_KEY_LENGTH: usize = MLKEM_EK_LENGTH + X25519_KEY_LENGTH;
/// Hybrid ciphertext: ML-KEM-768 ciphertext followed by the X25519 ephemeral public key.
pub const MLKEM768X25519_ENC_LENGTH: usize = MLKEM_CT_LENGTH + X25519_KEY_LENGTH;
/// Length in bytes of the combined shared secret.
pub const MLKEM768X25519_SHARED_SECRET_LENGTH: usize = 32;
/// Length in bytes of the secret seed that decapsulation expands into both sub-keys.
pub const MLKEM768X25519_SK_SEED_LENGTH: usize = 32;
/// Encapsulation randomness: 32-byte ML-KEM message followed by the 32-byte X25519 scalar.
pub const MLKEM768X25519_ESEED_LENGTH: usize = 64;
/// SHAKE256 output length: 64-byte ML-KEM-768 seed followed by the 32-byte X25519 scalar.
const XWING_EXPANDED_SEED_LENGTH: usize = 64 + X25519_KEY_LENGTH;
/// X-Wing combiner label: the six ASCII bytes `\.//^\`.
const XWING_COMBINER_LABEL: &[u8] = b"\\.//^\\";
/// Errors reported by the X25519 helpers and the ML-KEM-768 + X25519 hybrid KEM.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum KemError {
    /// The X25519 shared secret was all zero, i.e. the peer supplied a small-order point.
    #[error("x25519 ECDH rejected: peer public key is a small-order point")]
    X25519LowOrderPoint,
    /// An X25519 scalar or public key had the wrong length; carries the length received.
    #[error("x25519 key must be 32 bytes, got {0}")]
    InvalidX25519KeyLength(usize),

    /// The hybrid public key had the wrong length; carries the length received.
    #[error("mlkem768x25519 public key must be 1216 bytes, got {0}")]
    InvalidPublicKeyLength(usize),
    /// The hybrid ciphertext had the wrong length; carries the length received.
    #[error("mlkem768x25519 enc must be 1120 bytes, got {0}")]
    InvalidEncLength(usize),
    /// The decapsulation secret seed had the wrong length; carries the length received.
    #[error("mlkem768x25519 secret seed must be 32 bytes, got {0}")]
    InvalidSecretSeedLength(usize),
    /// The encapsulation randomness had the wrong length; carries the length received.
    #[error("mlkem768x25519 eseed must be 64 bytes, got {0}")]
    InvalidEseedLength(usize),

    /// The ML-KEM-768 half of the public key was rejected by the ML-KEM implementation.
    #[error("mlkem768x25519 public key contains an invalid ML-KEM-768 encapsulation key")]
    InvalidMlKemEncapsulationKey,
}
/// Computes the X25519 public key for `secret_key` by multiplying the X25519 base point.
///
/// # Errors
///
/// Returns [`KemError::InvalidX25519KeyLength`] if `secret_key` is not 32 bytes.
pub fn x25519_public_key(secret_key: &[u8]) -> Result<[u8; 32], KemError> {
    let mut scalar =
        to_array32(secret_key).ok_or(KemError::InvalidX25519KeyLength(secret_key.len()))?;
    let public_key = x25519(scalar, x25519_dalek::X25519_BASEPOINT_BYTES);
    // `scalar` is a stack copy of the caller's secret; wipe it like the rest of this module does.
    scalar.zeroize();
    Ok(public_key)
}
pub fn x25519_ecdh(secret_key: &[u8], their_public_key: &[u8]) -> Result<[u8; 32], KemError> {
let scalar =
to_array32(secret_key).ok_or(KemError::InvalidX25519KeyLength(secret_key.len()))?;
let point = to_array32(their_public_key)
.ok_or(KemError::InvalidX25519KeyLength(their_public_key.len()))?;
let shared = x25519(scalar, point);
if shared.ct_eq(&[0u8; 32]).into() {
return Err(KemError::X25519LowOrderPoint);
}
Ok(shared)
}
/// Result of [`mlkem768x25519_encapsulate`]: the ciphertext to transmit plus the shared secret.
///
/// Dropping a value zeroizes `ss`. `enc` is marked `#[zeroize(skip)]` and is left as-is.
#[derive(Clone, ZeroizeOnDrop)]
pub struct Mlkem768X25519Encapsulation {
    /// Combined 32-byte shared secret; wiped on drop.
    pub ss: [u8; MLKEM768X25519_SHARED_SECRET_LENGTH],
    /// ML-KEM-768 ciphertext followed by the X25519 ephemeral public key.
    #[zeroize(skip)]
    pub enc: [u8; MLKEM768X25519_ENC_LENGTH],
}
/// Deterministically encapsulates to an ML-KEM-768 + X25519 (X-Wing) public key.
///
/// `public_key` is the 1184-byte ML-KEM-768 encapsulation key followed by the 32-byte
/// X25519 public key. `eseed` carries all the encapsulation randomness:
/// * bytes `0..32` are the ML-KEM message;
/// * bytes `32..64` are the X25519 ephemeral scalar.
///
/// Identical inputs always yield an identical result, so an `eseed` must never be reused.
///
/// # Errors
///
/// * [`KemError::InvalidPublicKeyLength`] / [`KemError::InvalidEseedLength`] on wrong lengths.
/// * [`KemError::InvalidMlKemEncapsulationKey`] if `EncapsulationKey::new` rejects the
///   ML-KEM half of `public_key`.
pub fn mlkem768x25519_encapsulate(
    public_key: &[u8],
    eseed: &[u8],
) -> Result<Mlkem768X25519Encapsulation, KemError> {
    if public_key.len() != MLKEM768X25519_PUBLIC_KEY_LENGTH {
        return Err(KemError::InvalidPublicKeyLength(public_key.len()));
    }
    if eseed.len() != MLKEM768X25519_ESEED_LENGTH {
        return Err(KemError::InvalidEseedLength(eseed.len()));
    }
    let (ek_mlkem, pk_x25519) = public_key.split_at(MLKEM_EK_LENGTH);
    let (mlkem_message, x_ephemeral_seed) = eseed.split_at(32);

    // Validate the recipient key before copying any secret eseed material, so the
    // `?` early return below leaves no unwiped secrets on the stack.
    let ek_key = <&Array<u8, _>>::try_from(ek_mlkem)
        .expect("the ML-KEM encapsulation key slice is exactly 1184 bytes");
    let ek = EncapsulationKey::<MlKem768>::new(ek_key)
        .map_err(|_| KemError::InvalidMlKemEncapsulationKey)?;

    // ML-KEM half: the message is secret randomness, so wipe it once it has been consumed.
    let mut m: Array<u8, _> =
        Array::try_from(mlkem_message).expect("the ML-KEM message slice is exactly 32 bytes");
    let (ct_mlkem, mut mlkem_ss_array) = ek.encapsulate_deterministic(&m);
    m.as_mut_slice().zeroize();
    let mut ss_mlkem =
        to_array32(mlkem_ss_array.as_slice()).expect("ML-KEM shared key is 32 bytes");
    mlkem_ss_array.as_mut_slice().zeroize();

    // X25519 half: ephemeral public key plus DH with the recipient's X25519 key.
    let mut x_ephemeral_scalar =
        to_array32(x_ephemeral_seed).expect("eseed tail is exactly 32 bytes");
    let ct_x25519 = x25519(x_ephemeral_scalar, x25519_dalek::X25519_BASEPOINT_BYTES);
    let mut ss_x25519 = x25519(
        x_ephemeral_scalar,
        to_array32(pk_x25519).expect("pk tail is 32 bytes"),
    );
    x_ephemeral_scalar.zeroize();

    let mut enc = [0u8; MLKEM768X25519_ENC_LENGTH];
    enc[..MLKEM_CT_LENGTH].copy_from_slice(ct_mlkem.as_slice());
    enc[MLKEM_CT_LENGTH..].copy_from_slice(&ct_x25519);
    let ss = xwing_combine(&ss_mlkem, &ss_x25519, &ct_x25519, pk_x25519);
    ss_mlkem.zeroize();
    ss_x25519.zeroize();
    Ok(Mlkem768X25519Encapsulation { enc, ss })
}
/// Decapsulates an ML-KEM-768 + X25519 (X-Wing) ciphertext using a 32-byte secret seed.
///
/// The seed is expanded with SHAKE256:
/// * bytes `0..64` become the ML-KEM-768 seed;
/// * bytes `64..96` become the X25519 scalar.
///
/// `enc` is the 1088-byte ML-KEM-768 ciphertext followed by the 32-byte X25519 ephemeral
/// public key.
///
/// # Errors
///
/// [`KemError::InvalidSecretSeedLength`] / [`KemError::InvalidEncLength`] on wrong lengths.
pub fn mlkem768x25519_decapsulate(secret_seed: &[u8], enc: &[u8]) -> Result<[u8; 32], KemError> {
    if secret_seed.len() != MLKEM768X25519_SK_SEED_LENGTH {
        return Err(KemError::InvalidSecretSeedLength(secret_seed.len()));
    }
    if enc.len() != MLKEM768X25519_ENC_LENGTH {
        return Err(KemError::InvalidEncLength(enc.len()));
    }
    // Borrow the caller's seed as a fixed-size array rather than copying it: a by-value
    // copy of the long-term secret would otherwise linger unwiped on this stack frame.
    let seed: &[u8; 32] = secret_seed
        .try_into()
        .expect("secret seed length checked above");
    let mut expanded = expand_xwing_seed(seed);
    let mlkem_seed: Array<u8, U64> =
        Array::try_from(&expanded[0..64]).expect("the expansion yields a 64-byte ML-KEM seed");
    let dk = DecapsulationKey::<MlKem768>::from_seed(mlkem_seed);
    let mut x_scalar = to_array32(&expanded[64..96]).expect("the expansion tail is 32 bytes");
    // Both sub-keys have been copied out; wipe the expansion now instead of at the end.
    expanded.zeroize();

    let pk_x25519 = x25519(x_scalar, x25519_dalek::X25519_BASEPOINT_BYTES);
    let (ct_mlkem, ct_x25519) = enc.split_at(MLKEM_CT_LENGTH);

    let mut mlkem_ss_array = dk
        .decapsulate_slice(ct_mlkem)
        .expect("the ML-KEM ciphertext slice is exactly 1088 bytes");
    let mut ss_mlkem =
        to_array32(mlkem_ss_array.as_slice()).expect("ML-KEM shared key is 32 bytes");
    mlkem_ss_array.as_mut_slice().zeroize();

    let mut ss_x25519 = x25519(x_scalar, to_array32(ct_x25519).expect("ct tail is 32 bytes"));
    x_scalar.zeroize();

    let ss = xwing_combine(&ss_mlkem, &ss_x25519, ct_x25519, &pk_x25519);
    ss_mlkem.zeroize();
    ss_x25519.zeroize();
    Ok(ss)
}
/// X-Wing combiner: SHA3-256 over `ss_mlkem || ss_x25519 || ct_x25519 || pk_x25519 || label`.
///
/// Callers in this module pass 32-byte values for every input. The slices are hashed as
/// given, with no length checks.
fn xwing_combine(
    ss_mlkem: &[u8],
    ss_x25519: &[u8],
    ct_x25519: &[u8],
    pk_x25519: &[u8],
) -> [u8; 32] {
    let transcript: [&[u8]; 5] = [
        ss_mlkem,
        ss_x25519,
        ct_x25519,
        pk_x25519,
        XWING_COMBINER_LABEL,
    ];
    transcript
        .iter()
        .fold(Sha3_256::new(), |state, part| state.chain_update(part))
        .finalize()
        .into()
}
/// Stretches a 32-byte X-Wing secret seed into 96 bytes of SHAKE256 output.
///
/// In this module, bytes `0..64` feed the ML-KEM-768 key and bytes `64..96` become the
/// X25519 scalar.
fn expand_xwing_seed(seed: &[u8; 32]) -> [u8; XWING_EXPANDED_SEED_LENGTH] {
    let mut output = [0u8; XWING_EXPANDED_SEED_LENGTH];
    let mut xof = Shake256::default();
    xof.update(seed);
    xof.finalize_xof().read(&mut output);
    output
}
/// Copies `bytes` into an owned 32-byte array, or yields `None` when the length is not 32.
fn to_array32(bytes: &[u8]) -> Option<[u8; 32]> {
    <[u8; 32]>::try_from(bytes).ok()
}
#[cfg(test)]
mod tests {
    use super::*;
    use ml_kem::KeyExport;

    /// Builds the hybrid public key that matches an X-Wing secret seed, using the same
    /// expansion as decapsulation.
    fn public_key_for_seed(seed: &[u8; 32]) -> [u8; MLKEM768X25519_PUBLIC_KEY_LENGTH] {
        let expanded = expand_xwing_seed(seed);
        let mlkem_seed: Array<u8, U64> = Array::try_from(&expanded[0..64]).unwrap();
        let dk = DecapsulationKey::<MlKem768>::from_seed(mlkem_seed);
        let ek_bytes = dk.encapsulation_key().to_bytes();
        let pk_x25519 = x25519_public_key(&expanded[64..96]).unwrap();
        let mut public_key = [0u8; MLKEM768X25519_PUBLIC_KEY_LENGTH];
        public_key[..MLKEM_EK_LENGTH].copy_from_slice(ek_bytes.as_slice());
        public_key[MLKEM_EK_LENGTH..].copy_from_slice(&pk_x25519);
        public_key
    }

    #[test]
    fn x25519_rejects_the_all_zero_point() {
        let secret = [9u8; 32];
        assert_eq!(
            x25519_ecdh(&secret, &[0u8; 32]),
            Err(KemError::X25519LowOrderPoint),
        );
    }

    #[test]
    fn x25519_rejects_the_order_four_point_u_equals_one() {
        // u = 1 is a small-order point; a clamped scalar maps it to the identity (u = 0).
        let mut point = [0u8; 32];
        point[0] = 1;
        assert_eq!(
            x25519_ecdh(&[9u8; 32], &point),
            Err(KemError::X25519LowOrderPoint),
        );
    }

    #[test]
    fn x25519_rejects_wrong_length_keys() {
        assert_eq!(
            x25519_public_key(&[0u8; 31]),
            Err(KemError::InvalidX25519KeyLength(31))
        );
        assert_eq!(
            x25519_ecdh(&[0u8; 33], &[0u8; 32]),
            Err(KemError::InvalidX25519KeyLength(33)),
        );
    }

    #[test]
    fn x25519_reports_the_secret_length_before_the_public_length() {
        assert_eq!(
            x25519_ecdh(&[0u8; 31], &[0u8; 5]),
            Err(KemError::InvalidX25519KeyLength(31)),
        );
        assert_eq!(
            x25519_ecdh(&[0u8; 32], &[0u8; 5]),
            Err(KemError::InvalidX25519KeyLength(5)),
        );
    }

    #[test]
    fn x25519_roundtrips_between_two_parties() {
        let alice = [1u8; 32];
        let bob = [2u8; 32];
        let alice_pub = x25519_public_key(&alice).unwrap();
        let bob_pub = x25519_public_key(&bob).unwrap();
        let from_alice = x25519_ecdh(&alice, &bob_pub).unwrap();
        let from_bob = x25519_ecdh(&bob, &alice_pub).unwrap();
        assert_eq!(from_alice, from_bob);
    }

    #[test]
    fn xwing_encaps_decaps_agree() {
        let seed = [42u8; 32];
        let public_key = public_key_for_seed(&seed);
        let eseed = [7u8; 64];
        let encaps = mlkem768x25519_encapsulate(&public_key, &eseed).unwrap();
        let recovered = mlkem768x25519_decapsulate(&seed, &encaps.enc).unwrap();
        assert_eq!(recovered, encaps.ss);
    }

    #[test]
    fn xwing_encapsulation_is_deterministic_in_eseed() {
        let public_key = public_key_for_seed(&[42u8; 32]);
        let first = mlkem768x25519_encapsulate(&public_key, &[7u8; 64]).unwrap();
        let second = mlkem768x25519_encapsulate(&public_key, &[7u8; 64]).unwrap();
        assert_eq!(first.enc, second.enc);
        assert_eq!(first.ss, second.ss);
    }

    #[test]
    fn xwing_tampered_mlkem_ciphertext_changes_the_shared_secret() {
        let seed = [42u8; 32];
        let public_key = public_key_for_seed(&seed);
        let encaps = mlkem768x25519_encapsulate(&public_key, &[7u8; 64]).unwrap();
        // `enc` is `Copy`, so this copies rather than moving out of a `Drop` type.
        let mut tampered = encaps.enc;
        tampered[0] ^= 0x01;
        let recovered = mlkem768x25519_decapsulate(&seed, &tampered).unwrap();
        assert_ne!(recovered, encaps.ss);
    }

    #[test]
    fn xwing_rejects_wrong_length_inputs() {
        assert_eq!(
            mlkem768x25519_encapsulate(&[0u8; 1215], &[0u8; 64]).err(),
            Some(KemError::InvalidPublicKeyLength(1215)),
        );
        assert_eq!(
            mlkem768x25519_encapsulate(&[0u8; 1216], &[0u8; 63]).err(),
            Some(KemError::InvalidEseedLength(63)),
        );
        assert_eq!(
            mlkem768x25519_decapsulate(&[0u8; 31], &[0u8; 1120]),
            Err(KemError::InvalidSecretSeedLength(31)),
        );
        assert_eq!(
            mlkem768x25519_decapsulate(&[0u8; 32], &[0u8; 1119]),
            Err(KemError::InvalidEncLength(1119)),
        );
    }
}