use borsh::{BorshDeserialize, BorshSerialize};
use hkdf::Hkdf;
use ml_kem::array::Array;
use ml_kem::kem::{Decapsulate, Encapsulate};
use ml_kem::{Encoded, EncodedSizeUser, KemCore, MlKem768};
use rand::rngs::OsRng;
use sha2::Sha256;
use std::fmt;
use zeroize::ZeroizeOnDrop;
#[cfg(not(feature = "fips"))]
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
#[cfg(feature = "fips")]
use aws_lc_rs::{
agreement::{self, agree, EphemeralPrivateKey, PrivateKey, UnparsedPublicKey, ECDH_P256},
rand::SystemRandom,
};
/// ML-KEM-768 decapsulation (private) key type from the `ml_kem` crate.
type MlKem768DecapKey = <MlKem768 as KemCore>::DecapsulationKey;
/// ML-KEM-768 encapsulation (public) key type from the `ml_kem` crate.
type MlKem768EncapKey = <MlKem768 as KemCore>::EncapsulationKey;
/// Length of the classical public key carried in key packages and ciphertexts:
/// 32 bytes for an X25519 public key (default build).
#[cfg(not(feature = "fips"))]
pub const CLASSICAL_PK_BYTES: usize = 32;
/// Length of the classical public key carried in key packages and ciphertexts:
/// 65 bytes for an uncompressed SEC1 P-256 point (`fips` build).
#[cfg(feature = "fips")]
pub const CLASSICAL_PK_BYTES: usize = 65;
// HKDF `info` label used when combining the classical and ML-KEM secrets; it
// differs per backend so the two builds never derive the same key. The label
// still says "Kyber768" even though the KEM is ML-KEM-768: changing the bytes
// would change every derived secret, so they must stay as-is for interop.
#[cfg(not(feature = "fips"))]
const COMBINE_LABEL: &[u8] = b"HybridKEM_X25519_Kyber768";
#[cfg(feature = "fips")]
const COMBINE_LABEL: &[u8] = b"HybridKEM_P256_Kyber768";
/// Secret half of a hybrid key pair: a classical ECDH private key plus an
/// ML-KEM-768 decapsulation key. Created by `HybridSecretKey::generate`.
///
/// `ZeroizeOnDrop` wipes every field not marked `#[zeroize(skip)]`, i.e. the
/// X25519 `classical_sk` in the default build. NOTE(review): the skipped fields
/// (the aws-lc-rs `PrivateKey` and the boxed ML-KEM key) are not wiped by this
/// derive — confirm whether those types scrub themselves on drop (e.g. via the
/// `ml-kem` crate's `zeroize` feature).
#[derive(ZeroizeOnDrop)]
pub struct HybridSecretKey {
/// X25519 static secret (default build).
#[cfg(not(feature = "fips"))]
pub classical_sk: StaticSecret,
/// aws-lc-rs ECDH P-256 private key (`fips` build).
#[cfg(feature = "fips")]
#[zeroize(skip)] pub classical_sk: PrivateKey,
/// Heap-allocated ML-KEM-768 decapsulation key.
#[zeroize(skip)] pub ml_kem_dk: Box<MlKem768DecapKey>,
}
impl HybridSecretKey {
    /// Generates a fresh hybrid key pair.
    ///
    /// Returns the secret half together with the [`HybridKeyPackage`] that a
    /// sender passes to [`HybridKeyPackage::encapsulate`].
    ///
    /// # Panics
    ///
    /// In the `fips` build, panics if aws-lc-rs fails to generate the P-256
    /// private key or to compute its public key (this function returns no
    /// `Result`, so there is no other way to report it).
    pub fn generate() -> (Self, HybridKeyPackage) {
        let mut rng = OsRng;

        #[cfg(not(feature = "fips"))]
        let (classical_sk, classical_pk_bytes) = {
            // `OsRng` is `Copy`, so passing it by value leaves `rng` usable for
            // ML-KEM key generation below.
            let sk = StaticSecret::random_from_rng(rng);
            let pk = X25519PublicKey::from(&sk);
            (sk, *pk.as_bytes())
        };

        #[cfg(feature = "fips")]
        let (classical_sk, classical_pk_bytes) = {
            #[allow(clippy::expect_used)]
            let sk = PrivateKey::generate(&ECDH_P256)
                .expect("aws-lc-rs ECDH-P-256 generate must succeed");
            #[allow(clippy::expect_used)]
            let pk = sk
                .compute_public_key()
                .expect("aws-lc-rs ECDH-P-256 compute_public_key must succeed");
            let mut bytes = [0u8; CLASSICAL_PK_BYTES];
            bytes.copy_from_slice(pk.as_ref());
            (sk, bytes)
        };

        let (dk, ek) = MlKem768::generate(&mut rng);
        let secret_key = HybridSecretKey {
            classical_sk,
            ml_kem_dk: Box::new(dk),
        };
        let key_package = HybridKeyPackage {
            classical_pk: classical_pk_bytes,
            ml_kem_pk: ek.as_bytes().to_vec(),
        };
        (secret_key, key_package)
    }

    /// Recovers the 32-byte shared secret from a [`HybridCiphertext`] produced
    /// by [`HybridKeyPackage::encapsulate`] against this key's package.
    ///
    /// # Errors
    ///
    /// - X25519 build: the ciphertext's classical key is a low-order point, so
    ///   the Diffie-Hellman output would be non-contributory (all zeros).
    /// - `fips` build: aws-lc-rs rejects the peer's P-256 public key.
    /// - `ml_kem_ct` is not exactly the ML-KEM-768 ciphertext length.
    /// - ML-KEM decapsulation reports an error.
    pub fn decapsulate(&self, ciphertext: &HybridCiphertext) -> Result<[u8; 32], anyhow::Error> {
        // The raw ECDH output is held in `Zeroizing` so it is wiped when this
        // function returns instead of lingering in an unmanaged stack copy.
        #[cfg(not(feature = "fips"))]
        let classical_shared: zeroize::Zeroizing<[u8; 32]> = {
            let peer = X25519PublicKey::from(ciphertext.classical_pk);
            let shared = self.classical_sk.diffie_hellman(&peer);
            // A low-order peer point forces an all-zero output regardless of our
            // secret, which would silently drop the classical half of the
            // hybrid. Reject it (RFC 7748 section 6.1 permits this check).
            if !shared.was_contributory() {
                return Err(anyhow::anyhow!(
                    "X25519 peer public key is low-order (non-contributory shared secret)"
                ));
            }
            zeroize::Zeroizing::new(*shared.as_bytes())
        };

        #[cfg(feature = "fips")]
        let classical_shared: zeroize::Zeroizing<[u8; 32]> = {
            let peer = UnparsedPublicKey::new(&ECDH_P256, &ciphertext.classical_pk[..]);
            zeroize::Zeroizing::new(agree(
                &self.classical_sk,
                peer,
                anyhow::anyhow!("aws-lc-rs ECDH-P-256 agree failed (peer key parse)"),
                |km| -> Result<[u8; 32], anyhow::Error> {
                    let mut out = [0u8; 32];
                    out.copy_from_slice(km);
                    Ok(out)
                },
            )?)
        };

        let ct_array = decode_ml_kem_ciphertext(&ciphertext.ml_kem_ct)
            .ok_or_else(|| anyhow::anyhow!("invalid ML-KEM-768 ciphertext length"))?;
        let ml_kem_shared = self
            .ml_kem_dk
            .decapsulate(&ct_array)
            .map_err(|e| anyhow::anyhow!("ML-KEM decapsulation failed: {:?}", e))?;
        Self::combine_secrets(&classical_shared[..], ml_kem_shared.as_slice())
    }

    /// Derives the 32-byte hybrid shared secret from the classical and ML-KEM
    /// shared secrets.
    ///
    /// Runs HKDF-SHA256 with no salt over `ecc_secret || pq_secret` (in that
    /// order), using `COMBINE_LABEL` as the `info` string. The concatenated
    /// input keying material is wiped when the function returns.
    ///
    /// # Errors
    ///
    /// Returns an error only if HKDF expansion fails. A 32-byte output is far
    /// below HKDF-SHA256's 255 * 32-byte limit, so this is not expected in practice.
    pub(crate) fn combine_secrets(
        ecc_secret: &[u8],
        pq_secret: &[u8],
    ) -> Result<[u8; 32], anyhow::Error> {
        let ikm = zeroize::Zeroizing::new([ecc_secret, pq_secret].concat());
        let hkdf = Hkdf::<Sha256>::new(None, &ikm);
        let mut okm = [0u8; 32];
        hkdf.expand(COMBINE_LABEL, &mut okm)
            .map_err(|_| anyhow::anyhow!("HKDF expansion failed"))?;
        Ok(okm)
    }
}
impl fmt::Debug for HybridSecretKey {
    /// Prints the struct shape with every secret field masked, so key material
    /// can never leak through `{:?}` formatting or logging.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MASK: &str = "REDACTED";
        let mut builder = f.debug_struct("HybridSecretKey");
        builder.field("classical_sk", &MASK);
        builder.field("ml_kem_dk", &MASK);
        builder.finish()
    }
}
/// Public half of a hybrid key pair, handed to senders so they can call
/// `HybridKeyPackage::encapsulate`. Borsh-serializable for transport.
///
/// Fields are public and not validated on construction or deserialization;
/// `encapsulate` checks the ML-KEM key length before using it.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone)]
pub struct HybridKeyPackage {
/// Classical public key: X25519 (32 bytes) or, in the `fips` build, an
/// uncompressed SEC1 P-256 point (65 bytes).
pub classical_pk: [u8; CLASSICAL_PK_BYTES],
/// Encoded ML-KEM-768 encapsulation key (1184 bytes when well-formed).
pub ml_kem_pk: Vec<u8>,
}
impl HybridKeyPackage {
pub fn encapsulate(&self) -> Result<([u8; 32], HybridCiphertext), anyhow::Error> {
let mut rng = OsRng;
#[cfg(not(feature = "fips"))]
let (eph_pk_bytes, classical_shared) = {
let eph_sk = StaticSecret::random_from_rng(rng);
let eph_pk = X25519PublicKey::from(&eph_sk);
let peer = X25519PublicKey::from(self.classical_pk);
let shared = eph_sk.diffie_hellman(&peer);
(*eph_pk.as_bytes(), *shared.as_bytes())
};
#[cfg(feature = "fips")]
let (eph_pk_bytes, classical_shared): ([u8; CLASSICAL_PK_BYTES], [u8; 32]) = {
let aws_rng = SystemRandom::new();
let eph_sk = EphemeralPrivateKey::generate(&ECDH_P256, &aws_rng)
.map_err(|e| anyhow::anyhow!("aws-lc-rs ECDH-P-256 ephemeral generate: {:?}", e))?;
let eph_pk = eph_sk
.compute_public_key()
.map_err(|e| anyhow::anyhow!("compute_public_key: {:?}", e))?;
let mut pk_bytes = [0u8; CLASSICAL_PK_BYTES];
pk_bytes.copy_from_slice(eph_pk.as_ref());
let peer = UnparsedPublicKey::new(&ECDH_P256, &self.classical_pk[..]);
let shared = agreement::agree_ephemeral(
eph_sk,
peer,
anyhow::anyhow!("aws-lc-rs ECDH-P-256 agree_ephemeral failed (peer parse)"),
|km| -> Result<[u8; 32], anyhow::Error> {
let mut o = [0u8; 32];
o.copy_from_slice(km);
Ok(o)
},
)?;
(pk_bytes, shared)
};
let ek_array = decode_ml_kem_encap_key(&self.ml_kem_pk)
.ok_or_else(|| anyhow::anyhow!("invalid ML-KEM-768 public key length"))?;
let ek = MlKem768EncapKey::from_bytes(&ek_array);
let (ct, ml_kem_shared) = ek
.encapsulate(&mut rng)
.map_err(|e| anyhow::anyhow!("ML-KEM encapsulation failed: {:?}", e))?;
let shared_secret =
HybridSecretKey::combine_secrets(&classical_shared, ml_kem_shared.as_slice())?;
let ciphertext = HybridCiphertext {
classical_pk: eph_pk_bytes,
ml_kem_ct: ct.as_slice().to_vec(),
};
Ok((shared_secret, ciphertext))
}
}
/// Hybrid KEM ciphertext produced by `HybridKeyPackage::encapsulate` and
/// consumed by `HybridSecretKey::decapsulate`. Borsh-serializable for transport.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone)]
pub struct HybridCiphertext {
/// The sender's ephemeral classical public key (same encoding as
/// `HybridKeyPackage::classical_pk`).
pub classical_pk: [u8; CLASSICAL_PK_BYTES],
/// ML-KEM-768 ciphertext bytes (1088 bytes when well-formed).
pub ml_kem_ct: Vec<u8>,
}
fn decode_ml_kem_encap_key(bytes: &[u8]) -> Option<Encoded<MlKem768EncapKey>> {
Encoded::<MlKem768EncapKey>::try_from(bytes).ok()
}
/// Copies `bytes` into a fixed-size ML-KEM-768 ciphertext array.
///
/// Yields `None` unless `bytes.len()` equals the ML-KEM-768 ciphertext size.
fn decode_ml_kem_ciphertext(
    bytes: &[u8],
) -> Option<Array<u8, <MlKem768 as KemCore>::CiphertextSize>> {
    let parsed: Result<Array<u8, <MlKem768 as KemCore>::CiphertextSize>, _> = bytes.try_into();
    parsed.ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn hybrid_kem_round_trip() {
        let (sk, pk) = HybridSecretKey::generate();
        let (ss_send, ct) = pk.encapsulate().expect("encap");
        let ss_recv = sk.decapsulate(&ct).expect("decap");
        assert_eq!(
            ss_send, ss_recv,
            "encap/decap must agree on the shared secret"
        );
    }

    #[test]
    fn hybrid_kem_two_handshakes_yield_distinct_secrets() {
        let (_sk, pk) = HybridSecretKey::generate();
        let (ss1, _ct1) = pk.encapsulate().expect("first encap");
        let (ss2, _ct2) = pk.encapsulate().expect("second encap");
        assert_ne!(ss1, ss2);
    }

    #[test]
    fn ml_kem_ciphertext_size_matches_fips_203() {
        let (_sk, pk) = HybridSecretKey::generate();
        let (_ss, ct) = pk.encapsulate().expect("encap");
        assert_eq!(ct.ml_kem_ct.len(), 1088);
    }

    #[test]
    fn ml_kem_public_key_size_matches_fips_203() {
        let (_sk, pk) = HybridSecretKey::generate();
        assert_eq!(pk.ml_kem_pk.len(), 1184);
    }

    /// Two encapsulations to the same recipient must yield different secrets,
    /// and each ciphertext must decapsulate back to its own secret.
    #[test]
    fn hybrid_kem_two_secrets_distinct_under_same_recipient_key() {
        let (sk, pk) = HybridSecretKey::generate();
        let (ss1, ct1) = pk.encapsulate().expect("encap1");
        let (ss2, ct2) = pk.encapsulate().expect("encap2");
        assert_ne!(ss1, ss2, "independent encapsulations must not repeat a secret");
        assert_eq!(sk.decapsulate(&ct1).expect("decap1"), ss1);
        assert_eq!(sk.decapsulate(&ct2).expect("decap2"), ss2);
    }

    #[test]
    fn decapsulate_rejects_wrong_length_ml_kem_ciphertext() {
        let (sk, pk) = HybridSecretKey::generate();
        let (_ss, mut ct) = pk.encapsulate().expect("encap");
        ct.ml_kem_ct.push(0);
        assert!(sk.decapsulate(&ct).is_err());
    }

    #[test]
    fn encapsulate_rejects_wrong_length_ml_kem_public_key() {
        let (_sk, mut pk) = HybridSecretKey::generate();
        pk.ml_kem_pk.push(0);
        assert!(pk.encapsulate().is_err());
    }

    /// Flipping a ciphertext bit must never reproduce the sender's secret,
    /// whether decapsulation errors or returns a different key.
    #[test]
    fn tampered_ml_kem_ciphertext_does_not_reproduce_secret() {
        let (sk, pk) = HybridSecretKey::generate();
        let (ss, mut ct) = pk.encapsulate().expect("encap");
        if let Some(byte) = ct.ml_kem_ct.first_mut() {
            *byte ^= 0x01;
        }
        if let Ok(tampered) = sk.decapsulate(&ct) {
            assert_ne!(tampered, ss);
        }
    }

    #[test]
    fn decapsulating_with_unrelated_key_does_not_reproduce_secret() {
        let (_sk_a, pk_a) = HybridSecretKey::generate();
        let (sk_b, _pk_b) = HybridSecretKey::generate();
        let (ss, ct) = pk_a.encapsulate().expect("encap");
        if let Ok(other) = sk_b.decapsulate(&ct) {
            assert_ne!(other, ss);
        }
    }

    #[test]
    fn combine_secrets_is_deterministic_and_order_sensitive() {
        let ecc = [0x11u8; 32];
        let pq = [0x22u8; 32];
        let first = HybridSecretKey::combine_secrets(&ecc, &pq).expect("combine");
        let again = HybridSecretKey::combine_secrets(&ecc, &pq).expect("combine again");
        let swapped = HybridSecretKey::combine_secrets(&pq, &ecc).expect("combine swapped");
        assert_eq!(first, again);
        assert_ne!(first, swapped);
    }

    #[test]
    fn classical_public_key_size_matches_backend() {
        let (_sk, pk) = HybridSecretKey::generate();
        assert_eq!(pk.classical_pk.len(), CLASSICAL_PK_BYTES);
        #[cfg(not(feature = "fips"))]
        assert_eq!(CLASSICAL_PK_BYTES, 32, "X25519 public key is 32 bytes");
        #[cfg(feature = "fips")]
        assert_eq!(
            CLASSICAL_PK_BYTES, 65,
            "ECDH-P-256 uncompressed SEC1 public key is 65 bytes"
        );
    }

    #[cfg(feature = "fips")]
    #[test]
    fn fips_classical_public_key_is_uncompressed_sec1() {
        let (_sk, pk) = HybridSecretKey::generate();
        assert_eq!(
            pk.classical_pk[0], 0x04,
            "uncompressed SEC1 P-256 key must lead with 0x04"
        );
    }
}