use pqcrypto_kyber::kyber768;
use pqcrypto_traits::kem::{Ciphertext, PublicKey, SecretKey, SharedSecret};
use x25519_dalek::{EphemeralSecret, PublicKey as X25519Public, SharedSecret as X25519Shared};
use crate::crypto::kdf::KeyDerivation;
use crate::error::{CryptoError, SrxError};
/// Public half of a hybrid key pair: a Kyber768 public key plus an X25519
/// public key.
///
/// Wire format written by [`HybridPublicKey::to_bytes`] and read by
/// [`HybridPublicKey::from_bytes`]:
///
/// ```text
/// [kyber_len: u16 big-endian][kyber_public: kyber_len bytes][ecdh_public: 32 bytes]
/// ```
#[derive(Clone)]
pub struct HybridPublicKey {
    /// Kyber768 public key bytes. `from_bytes` does not check that these form
    /// a well-formed Kyber768 key.
    pub kyber_public: Vec<u8>,
    /// X25519 public key bytes.
    pub ecdh_public: [u8; 32],
}

impl HybridPublicKey {
    /// Serializes the key using the length-prefixed wire format described on
    /// [`HybridPublicKey`].
    ///
    /// # Panics
    ///
    /// Panics if `kyber_public` is longer than `u16::MAX` bytes. That length
    /// cannot be represented in the 2-byte prefix, and truncating it would
    /// produce an encoding that `from_bytes` splits at the wrong offset.
    pub fn to_bytes(&self) -> Vec<u8> {
        let kyber_len = self.kyber_public.len();
        assert!(
            kyber_len <= usize::from(u16::MAX),
            "kyber public key is {kyber_len} bytes, too long for the u16 length prefix"
        );
        let mut buf = Vec::with_capacity(2 + kyber_len + 32);
        // Cannot truncate: bounded by the assertion above.
        buf.extend_from_slice(&(kyber_len as u16).to_be_bytes());
        buf.extend_from_slice(&self.kyber_public);
        buf.extend_from_slice(&self.ecdh_public);
        buf
    }

    /// Parses a key from the wire format written by [`HybridPublicKey::to_bytes`].
    ///
    /// Any bytes after the 32-byte X25519 key are ignored. The Kyber part is
    /// copied as-is; its length is not checked against the Kyber768 key size.
    ///
    /// # Errors
    ///
    /// Returns `SrxError::Crypto(CryptoError::KeyExchangeFailed(..))` if `data`
    /// is shorter than the 2-byte prefix. It returns the same error if `data`
    /// is shorter than the prefix plus the declared Kyber length plus 32 bytes.
    pub fn from_bytes(data: &[u8]) -> crate::error::Result<Self> {
        if data.len() < 2 {
            return Err(SrxError::Crypto(CryptoError::KeyExchangeFailed(
                "public key too short".into(),
            )));
        }
        let kyber_len = u16::from_be_bytes([data[0], data[1]]) as usize;
        // kyber_len <= 65535, so this sum cannot overflow usize.
        if data.len() < 2 + kyber_len + 32 {
            return Err(SrxError::Crypto(CryptoError::KeyExchangeFailed(
                "public key data truncated".into(),
            )));
        }
        let kyber_public = data[2..2 + kyber_len].to_vec();
        let mut ecdh_public = [0u8; 32];
        ecdh_public.copy_from_slice(&data[2 + kyber_len..2 + kyber_len + 32]);
        Ok(Self {
            kyber_public,
            ecdh_public,
        })
    }
}
/// Hybrid Kyber768 + X25519 key pair created by [`HybridKem::generate_keypair`].
///
/// The X25519 secret is held in an `Option` so that [`HybridKem::decapsulate`]
/// can `take()` it. A second decapsulation with the same key pair therefore
/// fails with "ECDH secret already consumed".
///
/// No traits are derived, so this type has no `Clone` or `Debug` impl.
pub struct HybridKeypair {
/// Public half, handed to the peer for [`HybridKem::encapsulate`].
pub public: HybridPublicKey,
/// Kyber768 secret key bytes, re-parsed on each `HybridKem::decapsulate` call.
// NOTE(review): a plain `Vec<u8>` is not wiped when dropped. Consider
// zeroizing this secret and confirm against the project's key-hygiene rules.
kyber_secret: Vec<u8>,
/// X25519 ephemeral secret; `None` once consumed by `HybridKem::decapsulate`.
ecdh_secret: Option<EphemeralSecret>,
}
/// Message produced by [`HybridKem::encapsulate`] for the key-pair owner, who
/// passes it to [`HybridKem::decapsulate`].
///
/// Wire format written by [`EncapsulatedKey::to_bytes`] and read by
/// [`EncapsulatedKey::from_bytes`]:
///
/// ```text
/// [ct_len: u16 big-endian][pqc_ciphertext: ct_len bytes][ecdh_public: 32 bytes]
/// ```
pub struct EncapsulatedKey {
    /// Kyber768 ciphertext bytes. `from_bytes` does not check their length
    /// against the Kyber768 ciphertext size.
    pub pqc_ciphertext: Vec<u8>,
    /// The encapsulator's ephemeral X25519 public key bytes.
    pub ecdh_public: [u8; 32],
}

impl EncapsulatedKey {
    /// Serializes the message using the length-prefixed wire format described
    /// on [`EncapsulatedKey`].
    ///
    /// # Panics
    ///
    /// Panics if `pqc_ciphertext` is longer than `u16::MAX` bytes. That length
    /// cannot be represented in the 2-byte prefix, and truncating it would
    /// produce an encoding that `from_bytes` splits at the wrong offset.
    pub fn to_bytes(&self) -> Vec<u8> {
        let ct_len = self.pqc_ciphertext.len();
        assert!(
            ct_len <= usize::from(u16::MAX),
            "pqc ciphertext is {ct_len} bytes, too long for the u16 length prefix"
        );
        let mut buf = Vec::with_capacity(2 + ct_len + 32);
        // Cannot truncate: bounded by the assertion above.
        buf.extend_from_slice(&(ct_len as u16).to_be_bytes());
        buf.extend_from_slice(&self.pqc_ciphertext);
        buf.extend_from_slice(&self.ecdh_public);
        buf
    }

    /// Parses a message from the wire format written by [`EncapsulatedKey::to_bytes`].
    ///
    /// Any bytes after the 32-byte X25519 key are ignored.
    ///
    /// # Errors
    ///
    /// Returns `SrxError::Crypto(CryptoError::KeyExchangeFailed(..))` if `data`
    /// is shorter than the 2-byte prefix. It returns the same error if `data`
    /// is shorter than the prefix plus the declared ciphertext length plus
    /// 32 bytes.
    pub fn from_bytes(data: &[u8]) -> crate::error::Result<Self> {
        if data.len() < 2 {
            return Err(SrxError::Crypto(CryptoError::KeyExchangeFailed(
                "encapsulated key too short".into(),
            )));
        }
        let ct_len = u16::from_be_bytes([data[0], data[1]]) as usize;
        // ct_len <= 65535, so this sum cannot overflow usize.
        if data.len() < 2 + ct_len + 32 {
            return Err(SrxError::Crypto(CryptoError::KeyExchangeFailed(
                "encapsulated key data truncated".into(),
            )));
        }
        let pqc_ciphertext = data[2..2 + ct_len].to_vec();
        let mut ecdh_public = [0u8; 32];
        ecdh_public.copy_from_slice(&data[2 + ct_len..2 + ct_len + 32]);
        Ok(Self {
            pqc_ciphertext,
            ecdh_public,
        })
    }
}
pub struct HybridKem;
impl HybridKem {
pub fn generate_keypair() -> HybridKeypair {
let (kyber_pk, kyber_sk) = kyber768::keypair();
let ecdh_secret = EphemeralSecret::random();
let ecdh_public = X25519Public::from(&ecdh_secret);
HybridKeypair {
public: HybridPublicKey {
kyber_public: kyber_pk.as_bytes().to_vec(),
ecdh_public: ecdh_public.to_bytes(),
},
kyber_secret: kyber_sk.as_bytes().to_vec(),
ecdh_secret: Some(ecdh_secret),
}
}
pub fn encapsulate(
peer_public: &HybridPublicKey,
) -> crate::error::Result<(EncapsulatedKey, [u8; 32])> {
let kyber_pk = kyber768::PublicKey::from_bytes(&peer_public.kyber_public)
.map_err(|e| SrxError::Crypto(CryptoError::PqcFailed(format!("{e:?}"))))?;
let (pqc_ss, pqc_ct) = kyber768::encapsulate(&kyber_pk);
let ecdh_secret = EphemeralSecret::random();
let ecdh_public = X25519Public::from(&ecdh_secret);
let peer_ecdh = X25519Public::from(peer_public.ecdh_public);
let ecdh_ss: X25519Shared = ecdh_secret.diffie_hellman(&peer_ecdh);
let master = KeyDerivation::combine_secrets(pqc_ss.as_bytes(), ecdh_ss.as_bytes())?;
let encap = EncapsulatedKey {
pqc_ciphertext: pqc_ct.as_bytes().to_vec(),
ecdh_public: ecdh_public.to_bytes(),
};
Ok((encap, master))
}
pub fn decapsulate(
keypair: &mut HybridKeypair,
encapsulated: &EncapsulatedKey,
) -> crate::error::Result<[u8; 32]> {
let kyber_sk = kyber768::SecretKey::from_bytes(&keypair.kyber_secret)
.map_err(|e| SrxError::Crypto(CryptoError::PqcFailed(format!("{e:?}"))))?;
let kyber_ct = kyber768::Ciphertext::from_bytes(&encapsulated.pqc_ciphertext)
.map_err(|e| SrxError::Crypto(CryptoError::PqcFailed(format!("{e:?}"))))?;
let pqc_ss = kyber768::decapsulate(&kyber_ct, &kyber_sk);
let ecdh_secret = keypair.ecdh_secret.take().ok_or_else(|| {
SrxError::Crypto(CryptoError::KeyExchangeFailed(
"ECDH secret already consumed".into(),
))
})?;
let peer_ecdh = X25519Public::from(encapsulated.ecdh_public);
let ecdh_ss = ecdh_secret.diffie_hellman(&peer_ecdh);
KeyDerivation::combine_secrets(pqc_ss.as_bytes(), ecdh_ss.as_bytes())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_keypair_generation() {
        let kp = HybridKem::generate_keypair();
        assert!(!kp.public.kyber_public.is_empty());
        assert_ne!(kp.public.ecdh_public, [0u8; 32]);
    }

    #[test]
    fn test_encapsulate_decapsulate() {
        let mut server_kp = HybridKem::generate_keypair();
        let (encap, client_master) = HybridKem::encapsulate(&server_kp.public).unwrap();
        let server_master = HybridKem::decapsulate(&mut server_kp, &encap).unwrap();
        assert_eq!(client_master, server_master);
    }

    #[test]
    fn test_different_keypairs_different_secrets() {
        let mut server_kp1 = HybridKem::generate_keypair();
        let mut server_kp2 = HybridKem::generate_keypair();
        let (encap1, master1) = HybridKem::encapsulate(&server_kp1.public).unwrap();
        let (encap2, master2) = HybridKem::encapsulate(&server_kp2.public).unwrap();
        let dec1 = HybridKem::decapsulate(&mut server_kp1, &encap1).unwrap();
        let dec2 = HybridKem::decapsulate(&mut server_kp2, &encap2).unwrap();
        assert_eq!(master1, dec1);
        assert_eq!(master2, dec2);
        assert_ne!(master1, master2);
    }

    #[test]
    fn test_public_key_serialization() {
        let kp = HybridKem::generate_keypair();
        let bytes = kp.public.to_bytes();
        let restored = HybridPublicKey::from_bytes(&bytes).unwrap();
        assert_eq!(kp.public.kyber_public, restored.kyber_public);
        assert_eq!(kp.public.ecdh_public, restored.ecdh_public);
    }

    #[test]
    fn test_encapsulated_key_serialization() {
        let server_kp = HybridKem::generate_keypair();
        let (encap, _) = HybridKem::encapsulate(&server_kp.public).unwrap();
        let bytes = encap.to_bytes();
        let restored = EncapsulatedKey::from_bytes(&bytes).unwrap();
        assert_eq!(encap.pqc_ciphertext, restored.pqc_ciphertext);
        assert_eq!(encap.ecdh_public, restored.ecdh_public);
    }

    #[test]
    fn test_ecdh_secret_consumed_once() {
        let mut server_kp = HybridKem::generate_keypair();
        let (encap, _) = HybridKem::encapsulate(&server_kp.public).unwrap();
        assert!(HybridKem::decapsulate(&mut server_kp, &encap).is_ok());
        assert!(HybridKem::decapsulate(&mut server_kp, &encap).is_err());
    }

    #[test]
    fn test_public_key_from_bytes_rejects_short_input() {
        assert!(HybridPublicKey::from_bytes(&[]).is_err());
        assert!(HybridPublicKey::from_bytes(&[0x00]).is_err());
        let bytes = HybridKem::generate_keypair().public.to_bytes();
        assert!(HybridPublicKey::from_bytes(&bytes[..bytes.len() - 1]).is_err());
    }

    #[test]
    fn test_encapsulated_key_from_bytes_rejects_short_input() {
        assert!(EncapsulatedKey::from_bytes(&[]).is_err());
        assert!(EncapsulatedKey::from_bytes(&[0x00]).is_err());
        let server_kp = HybridKem::generate_keypair();
        let (encap, _) = HybridKem::encapsulate(&server_kp.public).unwrap();
        let bytes = encap.to_bytes();
        assert!(EncapsulatedKey::from_bytes(&bytes[..bytes.len() - 1]).is_err());
    }

    #[test]
    fn test_round_trip_over_wire_format() {
        let mut server_kp = HybridKem::generate_keypair();
        let server_pub = HybridPublicKey::from_bytes(&server_kp.public.to_bytes()).unwrap();
        let (encap, client_master) = HybridKem::encapsulate(&server_pub).unwrap();
        let received = EncapsulatedKey::from_bytes(&encap.to_bytes()).unwrap();
        let server_master = HybridKem::decapsulate(&mut server_kp, &received).unwrap();
        assert_eq!(client_master, server_master);
    }

    #[test]
    fn test_encapsulate_rejects_invalid_kyber_key() {
        let valid_ecdh = HybridKem::generate_keypair().public.ecdh_public;
        let bogus = HybridPublicKey {
            kyber_public: vec![0u8; 10],
            ecdh_public: valid_ecdh,
        };
        assert!(HybridKem::encapsulate(&bogus).is_err());
    }

    #[test]
    fn test_decapsulate_with_wrong_keypair_disagrees() {
        let intended_kp = HybridKem::generate_keypair();
        let mut other_kp = HybridKem::generate_keypair();
        let (encap, client_master) = HybridKem::encapsulate(&intended_kp.public).unwrap();
        let other_master = HybridKem::decapsulate(&mut other_kp, &encap).unwrap();
        assert_ne!(client_master, other_master);
    }

    #[test]
    fn test_malformed_ciphertext_does_not_consume_ecdh_secret() {
        let mut server_kp = HybridKem::generate_keypair();
        let (encap, client_master) = HybridKem::encapsulate(&server_kp.public).unwrap();
        let malformed = EncapsulatedKey {
            pqc_ciphertext: vec![0u8; 10],
            ecdh_public: encap.ecdh_public,
        };
        assert!(HybridKem::decapsulate(&mut server_kp, &malformed).is_err());
        // The ciphertext is rejected before the ECDH secret is taken, so a
        // valid message still decapsulates afterwards.
        let server_master = HybridKem::decapsulate(&mut server_kp, &encap).unwrap();
        assert_eq!(client_master, server_master);
    }
}