use hkdf::Hkdf;
use hmac::{Hmac, Mac};
use rand_core::RngCore;
use ring::agreement::{self, EphemeralPrivateKey, UnparsedPublicKey, X25519};
use ring::rand::SystemRandom;
use sha2::Sha384;
use subtle::ConstantTimeEq;
use zeroize::{Zeroize, ZeroizeOnDrop};
use libcrux_ml_kem::mlkem768::{self, MlKem768Ciphertext, MlKem768KeyPair, MlKem768PublicKey};
/// Length in bytes of a [`SessionKey`], and of the HKDF output requested by
/// `derive_session_key`.
pub const SESSION_KEY_LEN: usize = 32;
/// Exact length in bytes that `mlkem_encapsulate` requires of an encoded ML-KEM-768 public key.
pub const MLKEM_PK_LEN: usize = 1184;
/// Exact length in bytes that `MlKemKeyPair::decapsulate` requires of an ML-KEM-768 ciphertext.
pub const MLKEM_CT_LEN: usize = 1088;
/// Length in bytes of an encoded X25519 public key.
pub const X25519_PK_LEN: usize = 32;
/// Length in bytes of the HMAC-SHA-384 tag returned by `compute_cookie`.
pub const HMAC_LEN: usize = 48;
/// Fixed-size session key of `SESSION_KEY_LEN` bytes.
///
/// The derives make the bytes zeroizable on demand and wipe them automatically
/// when the value is dropped. The inner array is public, so callers can read
/// (and copy) the raw key material directly.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct SessionKey(pub [u8; SESSION_KEY_LEN]);
impl SessionKey {
    /// Compares two session keys using `subtle`'s constant-time byte comparison.
    ///
    /// Returns `true` when both keys hold identical bytes.
    // The lint is silenced so the helper may exist without an in-crate caller.
    #[allow(dead_code)]
    pub fn ct_eq(&self, other: &SessionKey) -> bool {
        let equal = self.0.ct_eq(&other.0);
        bool::from(equal)
    }
}
impl std::fmt::Debug for SessionKey {
    /// Emits a fixed placeholder so key bytes never reach logs or panic messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SessionKey([REDACTED])")
    }
}
/// Heap-allocated secret byte string (used in this file for shared secrets).
///
/// The derives make the buffer zeroizable on demand and wipe it automatically
/// when the value is dropped. The inner `Vec` is public.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct SecretBytes(pub Vec<u8>);
/// An X25519 key pair backed by a `ring` ephemeral private key.
pub struct X25519KeyPair {
/// The private key; set to `Some` by `generate` and taken out by `diffie_hellman`.
private_key: Option<EphemeralPrivateKey>,
/// Encoded public key matching `private_key`, `X25519_PK_LEN` bytes long.
pub public_key_bytes: [u8; X25519_PK_LEN],
}
impl X25519KeyPair {
    /// Creates a fresh X25519 key pair using `ring`'s system random source.
    ///
    /// # Errors
    /// Returns [`CryptoError::KeyGen`] if generating the private key or
    /// computing its public key fails.
    pub fn generate() -> Result<Self, CryptoError> {
        let private_key = EphemeralPrivateKey::generate(&X25519, &SystemRandom::new())
            .map_err(|_| CryptoError::KeyGen)?;

        let mut public_key_bytes = [0u8; X25519_PK_LEN];
        public_key_bytes.copy_from_slice(
            private_key
                .compute_public_key()
                .map_err(|_| CryptoError::KeyGen)?
                .as_ref(),
        );

        Ok(Self {
            private_key: Some(private_key),
            public_key_bytes,
        })
    }

    /// Performs the X25519 agreement with the peer's public key, consuming
    /// this key pair, and returns the raw shared secret.
    ///
    /// # Errors
    /// * [`CryptoError::AlreadyUsed`] if the private key slot is empty.
    /// * [`CryptoError::Ecdh`] if `ring` rejects the agreement.
    pub fn diffie_hellman(
        mut self,
        peer_public_key_bytes: &[u8; X25519_PK_LEN],
    ) -> Result<SecretBytes, CryptoError> {
        let private_key = match self.private_key.take() {
            Some(key) => key,
            None => return Err(CryptoError::AlreadyUsed),
        };
        let peer_pk = UnparsedPublicKey::new(&X25519, peer_public_key_bytes.as_ref());

        // The closure only copies the secret out, so it cannot fail itself;
        // the sole error source is the agreement operation.
        agreement::agree_ephemeral(private_key, &peer_pk, |raw| SecretBytes(raw.to_vec()))
            .map_err(|_| CryptoError::Ecdh)
    }
}
/// Wrapper around a libcrux ML-KEM-768 key pair.
///
/// NOTE(review): whether `MlKem768KeyPair` wipes its private key on drop is not
/// visible from this file — confirm against the libcrux documentation.
pub struct MlKemKeyPair {
/// The underlying libcrux key pair (public and private halves).
inner: MlKem768KeyPair,
}
impl MlKemKeyPair {
pub fn generate() -> Result<Self, CryptoError> {
let mut seed = [0u8; 64];
rand_core::OsRng.fill_bytes(&mut seed);
let kp = mlkem768::generate_key_pair(seed);
Ok(MlKemKeyPair { inner: kp })
}
pub fn public_key_bytes(&self) -> Vec<u8> {
self.inner.public_key().as_ref().to_vec()
}
pub fn decapsulate(&self, ciphertext: &[u8]) -> Result<SecretBytes, CryptoError> {
if ciphertext.len() != MLKEM_CT_LEN {
return Err(CryptoError::InvalidCiphertext);
}
let ct_arr: [u8; MLKEM_CT_LEN] = ciphertext.try_into().unwrap();
let ct = MlKem768Ciphertext::from(ct_arr);
let ss = mlkem768::decapsulate(self.inner.private_key(), &ct);
Ok(SecretBytes(ss.as_ref().to_vec()))
}
}
pub fn mlkem_encapsulate(
public_key_bytes: &[u8],
) -> Result<(Vec<u8>, SecretBytes), CryptoError> {
if public_key_bytes.len() != MLKEM_PK_LEN {
return Err(CryptoError::InvalidPublicKey);
}
let pk_arr: [u8; MLKEM_PK_LEN] = public_key_bytes.try_into().unwrap();
let pk = MlKem768PublicKey::from(pk_arr);
let mut rand_bytes = [0u8; 32];
rand_core::OsRng.fill_bytes(&mut rand_bytes);
let (ct, ss) = mlkem768::encapsulate(&pk, rand_bytes);
Ok((ct.as_ref().to_vec(), SecretBytes(ss.as_ref().to_vec())))
}
/// Fixed salt passed to HKDF-extract in `derive_session_key`.
const HYBRID_SALT: &[u8] = b"WPA-Next-v1-hybrid-salt";
/// Fixed `info` string passed to HKDF-expand in `derive_session_key`.
const HYBRID_INFO: &[u8] = b"WPA-Next-v1-session-key";
/// Derives the session key from the classical and post-quantum shared secrets.
///
/// The input keying material is `classical_ss || pq_ss`, run through
/// HKDF-SHA-384 with `HYBRID_SALT` as salt and `HYBRID_INFO` as info, producing
/// `SESSION_KEY_LEN` bytes.
///
/// # Errors
/// Returns [`CryptoError::Hkdf`] if HKDF-expand rejects the requested output
/// length.
pub fn derive_session_key(
    classical_ss: &SecretBytes,
    pq_ss: &SecretBytes,
) -> Result<SessionKey, CryptoError> {
    // Exact capacity up front so the buffer never reallocates and leaves an
    // unscrubbed copy of the secrets behind on the heap.
    let mut ikm = Vec::with_capacity(classical_ss.0.len() + pq_ss.0.len());
    ikm.extend_from_slice(&classical_ss.0);
    ikm.extend_from_slice(&pq_ss.0);

    let hk = Hkdf::<Sha384>::new(Some(HYBRID_SALT), &ikm);
    // Extract is done, so the concatenated secrets are no longer needed.
    // Scrubbing here (before the fallible expand) covers the error path too.
    ikm.zeroize();

    // Expand straight into the zeroize-on-drop wrapper: no plain stack array
    // holding the key is left behind, and an early return wipes the buffer.
    let mut key = SessionKey([0u8; SESSION_KEY_LEN]);
    hk.expand(HYBRID_INFO, &mut key.0)
        .map_err(|_| CryptoError::Hkdf)?;
    Ok(key)
}
/// HMAC instantiated with SHA-384; the MAC used by `compute_cookie`.
type HmacSha384 = Hmac<Sha384>;
/// Computes the cookie for a peer: HMAC-SHA-384 keyed with `ap_secret` over
/// `peer_addr || sequence_id (big-endian) || "WPA-Next-cookie-v1"`.
///
/// Returns the full `HMAC_LEN`-byte tag.
pub fn compute_cookie(
    ap_secret: &[u8; 32],
    peer_addr: &[u8; 6],
    sequence_id: u32,
) -> [u8; HMAC_LEN] {
    let seq_be = sequence_id.to_be_bytes();
    let label: &[u8] = b"WPA-Next-cookie-v1";

    let mut hmac = <HmacSha384 as Mac>::new_from_slice(ap_secret)
        .expect("HMAC accepts any key length");
    // Feed the MAC input parts in their fixed order.
    for part in [&peer_addr[..], &seq_be[..], label] {
        hmac.update(part);
    }

    let mut cookie = [0u8; HMAC_LEN];
    cookie.copy_from_slice(&hmac.finalize().into_bytes());
    cookie
}
/// Checks `candidate` against the cookie recomputed from the same inputs.
///
/// The comparison uses `subtle`'s constant-time equality. Returns `true` only
/// when all `HMAC_LEN` bytes match.
pub fn verify_cookie(
    ap_secret: &[u8; 32],
    peer_addr: &[u8; 6],
    sequence_id: u32,
    candidate: &[u8; HMAC_LEN],
) -> bool {
    let verdict = compute_cookie(ap_secret, peer_addr, sequence_id).ct_eq(candidate);
    bool::from(verdict)
}
/// Errors returned by the key-exchange helpers in this file.
#[derive(Debug, thiserror::Error)]
pub enum CryptoError {
/// Returned by `X25519KeyPair::generate` when private-key generation or
/// public-key computation fails.
#[error("Key generation failed")]
KeyGen,
/// Returned by `X25519KeyPair::diffie_hellman` when the agreement fails.
#[error("X25519 ECDH agreement failed")]
Ecdh,
/// Returned by `derive_session_key` when HKDF-expand fails.
#[error("HKDF expansion failed (output too long)")]
Hkdf,
/// Returned by `X25519KeyPair::diffie_hellman` when the private key slot is empty.
#[error("Private key already consumed (single-use ephemeral)")]
AlreadyUsed,
/// Returned by `MlKemKeyPair::decapsulate` when the ciphertext is not
/// `MLKEM_CT_LEN` bytes long.
#[error("Invalid ML-KEM ciphertext length")]
InvalidCiphertext,
/// Returned by `mlkem_encapsulate` when the public key is not
/// `MLKEM_PK_LEN` bytes long.
#[error("Invalid ML-KEM public key length")]
InvalidPublicKey,
}