use oxicrypto_core::{CryptoError, KeyAgreement, SecretVec};
use p256::elliptic_curve::sec1::ToSec1Point;
use x25519_dalek::{x25519, X25519_BASEPOINT_BYTES};
use super::ids::{i2osp, kem_suite_id, KemId};
use super::labeled::HpkeKdf;
use crate::{EcdhP256, X25519};
/// Diffie-Hellman based KEM (DHKEM) handle, selected by its [`KemId`].
///
/// This is a small `Copy` value wrapping only the KEM identifier; every
/// operation is dispatched on that identifier.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DhKem {
    /// Which DHKEM variant (curve + KDF) this handle implements.
    kem: KemId,
}
impl DhKem {
    /// Creates a handle for the DHKEM variant identified by `kem`.
    #[must_use]
    pub const fn new(kem: KemId) -> Self {
        Self { kem }
    }

    /// Returns the [`KemId`] this handle was created with.
    #[must_use]
    pub const fn kem_id(self) -> KemId {
        self.kem
    }

    /// KDF used for this KEM's internal `LabeledExtract` / `LabeledExpand`.
    ///
    /// Both supported variants are HKDF-SHA256 based, so this is constant.
    #[inline]
    const fn kdf(self) -> HpkeKdf {
        HpkeKdf::HkdfSha256
    }

    /// Validates a public key and returns its canonical encoding (`SerializePublicKey`).
    ///
    /// * X25519: `pk` must be exactly 32 bytes and is returned unchanged.
    /// * P-256: `pk` must be the 65-byte uncompressed SEC1 form (`0x04 || X || Y`)
    ///   and be accepted by `p256::PublicKey::from_sec1_bytes`; the re-encoded
    ///   uncompressed point is returned.
    ///
    /// # Errors
    ///
    /// [`CryptoError::InvalidKey`] if `pk` is not a valid encoding for this KEM.
    pub fn serialize_public_key(self, pk: &[u8]) -> Result<Vec<u8>, CryptoError> {
        match self.kem {
            KemId::DhkemX25519HkdfSha256 => {
                if pk.len() != 32 {
                    return Err(CryptoError::InvalidKey);
                }
                Ok(pk.to_vec())
            }
            KemId::DhkemP256HkdfSha256 => p256_uncompressed_from_serialized(pk),
        }
    }

    /// Parses an encoded public key (`DeserializePublicKey`).
    ///
    /// The wire encoding and the canonical encoding coincide for both supported
    /// KEMs, so this performs exactly the validation of
    /// [`Self::serialize_public_key`] and returns the same bytes.
    ///
    /// # Errors
    ///
    /// [`CryptoError::InvalidKey`] if `enc` is not a valid encoding for this KEM.
    pub fn deserialize_public_key(self, enc: &[u8]) -> Result<Vec<u8>, CryptoError> {
        self.serialize_public_key(enc)
    }

    /// Computes the serialized public key `pk(sk)` for a 32-byte secret key.
    ///
    /// * X25519: `x25519(sk, basepoint)`.
    /// * P-256: `sk` must be accepted by `p256::SecretKey::from_bytes`; the public
    ///   key is returned in 65-byte uncompressed SEC1 form.
    ///
    /// # Errors
    ///
    /// [`CryptoError::InvalidKey`] if `sk` is not 32 bytes, or (P-256) is rejected
    /// by `p256::SecretKey::from_bytes`.
    pub fn public_key_from_secret(self, sk: &[u8]) -> Result<Vec<u8>, CryptoError> {
        // Both supported KEMs use 32-byte secret keys.
        let scalar: [u8; 32] = sk.try_into().map_err(|_| CryptoError::InvalidKey)?;
        match self.kem {
            KemId::DhkemX25519HkdfSha256 => Ok(x25519(scalar, X25519_BASEPOINT_BYTES).to_vec()),
            KemId::DhkemP256HkdfSha256 => {
                let secret = p256::SecretKey::from_bytes(&p256::FieldBytes::from(scalar))
                    .map_err(|_| CryptoError::InvalidKey)?;
                let pk = secret.public_key();
                Ok(pk.to_sec1_point(false).as_bytes().to_vec())
            }
        }
    }

    /// Computes the raw Diffie-Hellman value `DH(sk, pk)` (32 bytes for both KEMs).
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying key agreement. For X25519, also
    /// returns [`CryptoError::InvalidKey`] if the output is all zeros; RFC 9180
    /// §7.1.4 requires aborting in that case.
    fn dh(self, sk: &[u8], pk: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let mut out = vec![0u8; 32];
        match self.kem {
            KemId::DhkemX25519HkdfSha256 => {
                X25519.agree(sk, pk, &mut out)?;
                // Defense in depth against low-order peer points. NOTE(review):
                // `X25519.agree` may already reject this; confirm. The OR-fold visits
                // every byte, so the check does not short-circuit on secret data.
                if out.iter().fold(0u8, |acc, &b| acc | b) == 0 {
                    return Err(CryptoError::InvalidKey);
                }
            }
            KemId::DhkemP256HkdfSha256 => EcdhP256.agree(sk, pk, &mut out)?,
        }
        Ok(out)
    }

    /// Deterministically derives `(sk, pk)` from `ikm` (`DeriveKeyPair`).
    ///
    /// `pk` is the serialized public key.
    ///
    /// * X25519: `sk = LabeledExpand(dkp_prk, "sk", "", 32)`.
    /// * P-256: rejection sampling over
    ///   `LabeledExpand(dkp_prk, "candidate", I2OSP(counter, 1), 32)` for
    ///   `counter` in `0..=255`. The first candidate accepted by
    ///   `p256::SecretKey::from_bytes` is used.
    ///
    /// # Errors
    ///
    /// Propagates KDF and key errors. Returns [`CryptoError::Internal`] if all
    /// 256 P-256 candidates are rejected.
    pub fn derive_key_pair(self, ikm: &[u8]) -> Result<(SecretVec, Vec<u8>), CryptoError> {
        let suite = kem_suite_id(self.kem);
        let dkp_prk = self.kdf().labeled_extract(&suite, b"", b"dkp_prk", ikm);
        match self.kem {
            KemId::DhkemX25519HkdfSha256 => {
                let sk = self
                    .kdf()
                    .labeled_expand(&suite, &dkp_prk, b"sk", b"", 32)?;
                let pk = self.public_key_from_secret(&sk)?;
                Ok((SecretVec::new(sk), pk))
            }
            KemId::DhkemP256HkdfSha256 => {
                // The counter is encoded as a single byte, so it ranges over all of u8.
                for counter in 0u8..=u8::MAX {
                    let candidate = self.kdf().labeled_expand(
                        &suite,
                        &dkp_prk,
                        b"candidate",
                        &i2osp(u128::from(counter), 1),
                        32,
                    )?;
                    // RFC 9180's bitmask for P-256 is 0xFF, so the spec's
                    // `bytes[0] &= bitmask` step is a no-op and is omitted.
                    let fb: [u8; 32] = candidate
                        .as_slice()
                        .try_into()
                        .map_err(|_| CryptoError::InvalidKey)?;
                    if let Ok(secret) = p256::SecretKey::from_bytes(&p256::FieldBytes::from(fb)) {
                        let pk = secret.public_key().to_sec1_point(false).as_bytes().to_vec();
                        return Ok((SecretVec::new(candidate), pk));
                    }
                }
                Err(CryptoError::Internal(
                    "HPKE DeriveKeyPair: rejection sampling exhausted",
                ))
            }
        }
    }

    /// `ExtractAndExpand`: derives the `n_secret()`-byte shared secret from the
    /// raw DH output(s) and the KEM context.
    fn extract_and_expand(self, dh: &[u8], kem_context: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let suite = kem_suite_id(self.kem);
        let eae_prk = self.kdf().labeled_extract(&suite, b"", b"eae_prk", dh);
        self.kdf().labeled_expand(
            &suite,
            &eae_prk,
            b"shared_secret",
            kem_context,
            self.kem.n_secret(),
        )
    }

    /// Base-mode `Encap` with the ephemeral key pair derived from `ikm_e`.
    ///
    /// Returns `(shared_secret, enc)`, where `enc` is the serialized ephemeral
    /// public key.
    ///
    /// # Errors
    ///
    /// [`CryptoError::InvalidKey`] for a malformed `pk_r`. Also propagates
    /// key-derivation, DH, and KDF errors.
    pub(crate) fn encap_with_ikm(
        self,
        pk_r: &[u8],
        ikm_e: &[u8],
    ) -> Result<(SecretVec, Vec<u8>), CryptoError> {
        let pk_rm = self.deserialize_public_key(pk_r)?;
        let (sk_e, enc) = self.derive_key_pair(ikm_e)?;
        let dh = self.dh(sk_e.as_bytes(), &pk_rm)?;
        let kem_context = [enc.as_slice(), pk_rm.as_slice()].concat();
        let shared_secret = self.extract_and_expand(&dh, &kem_context)?;
        Ok((SecretVec::new(shared_secret), enc))
    }

    /// Base-mode `Decap`: recovers the shared secret from `enc` using the
    /// recipient secret key `sk_r`.
    ///
    /// # Errors
    ///
    /// [`CryptoError::InvalidKey`] for a malformed `enc` or `sk_r`. Also
    /// propagates DH and KDF errors.
    pub fn decap(self, enc: &[u8], sk_r: &[u8]) -> Result<SecretVec, CryptoError> {
        let pk_e = self.deserialize_public_key(enc)?;
        let dh = self.dh(sk_r, &pk_e)?;
        let pk_rm = self.public_key_from_secret(sk_r)?;
        // `enc` passed validation above, and for both KEMs it equals its canonical
        // serialization.
        let kem_context = [enc, pk_rm.as_slice()].concat();
        let shared_secret = self.extract_and_expand(&dh, &kem_context)?;
        Ok(SecretVec::new(shared_secret))
    }

    /// Auth-mode `AuthEncap`, authenticating the sender with `sk_s`. The
    /// ephemeral key pair is derived from `ikm_e`.
    ///
    /// The DH input is `DH(skE, pkR) || DH(skS, pkR)`, and the KEM context is
    /// `enc || pkRm || pkSm`. Returns `(shared_secret, enc)`.
    ///
    /// # Errors
    ///
    /// [`CryptoError::InvalidKey`] for a malformed `pk_r` or `sk_s`. Also
    /// propagates key-derivation, DH, and KDF errors.
    pub(crate) fn auth_encap_with_ikm(
        self,
        pk_r: &[u8],
        sk_s: &[u8],
        ikm_e: &[u8],
    ) -> Result<(SecretVec, Vec<u8>), CryptoError> {
        let pk_rm = self.deserialize_public_key(pk_r)?;
        let (sk_e, enc) = self.derive_key_pair(ikm_e)?;
        // Array elements are evaluated left to right, preserving the DH order.
        let dh = [
            self.dh(sk_e.as_bytes(), &pk_rm)?,
            self.dh(sk_s, &pk_rm)?,
        ]
        .concat();
        let pk_sm = self.public_key_from_secret(sk_s)?;
        let kem_context = [enc.as_slice(), pk_rm.as_slice(), pk_sm.as_slice()].concat();
        let shared_secret = self.extract_and_expand(&dh, &kem_context)?;
        Ok((SecretVec::new(shared_secret), enc))
    }

    /// Auth-mode `AuthDecap`: recovers the shared secret and authenticates the
    /// sender public key `pk_s`.
    ///
    /// The DH input is `DH(skR, pkE) || DH(skR, pkS)`, and the KEM context is
    /// `enc || pkRm || pkSm`.
    ///
    /// # Errors
    ///
    /// [`CryptoError::InvalidKey`] for a malformed `enc`, `sk_r`, or `pk_s`. Also
    /// propagates DH and KDF errors.
    pub fn auth_decap(
        self,
        enc: &[u8],
        sk_r: &[u8],
        pk_s: &[u8],
    ) -> Result<SecretVec, CryptoError> {
        let pk_e = self.deserialize_public_key(enc)?;
        let pk_sm = self.deserialize_public_key(pk_s)?;
        let dh = [self.dh(sk_r, &pk_e)?, self.dh(sk_r, &pk_sm)?].concat();
        let pk_rm = self.public_key_from_secret(sk_r)?;
        let kem_context = [enc, pk_rm.as_slice(), pk_sm.as_slice()].concat();
        let shared_secret = self.extract_and_expand(&dh, &kem_context)?;
        Ok(SecretVec::new(shared_secret))
    }
}
/// Validates a P-256 public key in uncompressed SEC1 form and returns its
/// re-encoded uncompressed bytes.
///
/// Only the 65-byte `0x04 || X || Y` layout is accepted. Compressed (33-byte)
/// and any other length or tag are rejected before decoding. The bytes must
/// then be accepted by `p256::PublicKey::from_sec1_bytes`.
///
/// # Errors
///
/// [`CryptoError::InvalidKey`] on a wrong length or tag, or if decoding fails.
fn p256_uncompressed_from_serialized(enc: &[u8]) -> Result<Vec<u8>, CryptoError> {
    let is_uncompressed = matches!(enc, [0x04, coords @ ..] if coords.len() == 64);
    if !is_uncompressed {
        return Err(CryptoError::InvalidKey);
    }
    p256::PublicKey::from_sec1_bytes(enc)
        .map(|point| point.to_sec1_point(false).as_bytes().to_vec())
        .map_err(|_| CryptoError::InvalidKey)
}
// Known-answer and round-trip tests for `DhKem`. Per their names, the `_a_1_1` and
// `_a_3_1` tests use RFC 9180 Appendix A.1.1 (DHKEM(X25519, HKDF-SHA256)) and
// A.3.1 (DHKEM(P-256, HKDF-SHA256)) base-mode vectors.
#[cfg(test)]
mod kem_tests {
use super::*;
/// Decodes a hex string from a test vector, panicking on malformed hex.
fn hx(s: &str) -> Vec<u8> {
hex::decode(s).expect("valid hex in test vector")
}
/// X25519 DeriveKeyPair: the ephemeral (ikmE) and recipient (ikmR) key pairs
/// must match the expected skEm/pkEm and skRm/pkRm.
#[test]
fn derive_key_pair_x25519_a_1_1() {
let kem = DhKem::new(KemId::DhkemX25519HkdfSha256);
let ikm_e = hx("7268600d403fce431561aef583ee1613527cff655c1343f29812e66706df3234");
let sk_em = hx("52c4a758a802cd8b936eceea314432798d5baf2d7e9235dc084ab1b9cfa2f736");
let pk_em = hx("37fda3567bdbd628e88668c3c8d7e97d1d1253b6d4ea6d44c150f741f1bf4431");
let (sk, pk) = kem.derive_key_pair(&ikm_e).expect("derive eph");
assert_eq!(sk.as_bytes(), sk_em.as_slice(), "X25519 skEm mismatch");
assert_eq!(pk, pk_em, "X25519 pkEm mismatch");
let ikm_r = hx("6db9df30aa07dd42ee5e8181afdb977e538f5e1fec8a06223f33f7013e525037");
let sk_rm = hx("4612c550263fc8ad58375df3f557aac531d26850903e55a9f23f21d8534e8ac8");
let pk_rm = hx("3948cfe0ad1ddb695d780e59077195da6c56506b027329794ab02bca80815c4d");
let (skr, pkr) = kem.derive_key_pair(&ikm_r).expect("derive recip");
assert_eq!(skr.as_bytes(), sk_rm.as_slice(), "X25519 skRm mismatch");
assert_eq!(pkr, pk_rm, "X25519 pkRm mismatch");
}
/// P-256 DeriveKeyPair (rejection sampling): the derived key pairs must match
/// the expected vectors. Public keys must be 65-byte uncompressed points (0x04 prefix).
#[test]
fn derive_key_pair_p256_a_3_1() {
let kem = DhKem::new(KemId::DhkemP256HkdfSha256);
let ikm_e = hx("4270e54ffd08d79d5928020af4686d8f6b7d35dbe470265f1f5aa22816ce860e");
let sk_em = hx("4995788ef4b9d6132b249ce59a77281493eb39af373d236a1fe415cb0c2d7beb");
let pk_em = hx("04a92719c6195d5085104f469a8b9814d5838ff72b60501e2c4466e5e67b325ac98536d7b61a1af4b78e5b7f951c0900be863c403ce65c9bfcb9382657222d18c4");
let (sk, pk) = kem.derive_key_pair(&ikm_e).expect("derive eph");
assert_eq!(sk.as_bytes(), sk_em.as_slice(), "P-256 skEm mismatch");
assert_eq!(pk.len(), 65, "P-256 pk must be uncompressed (65 bytes)");
assert_eq!(pk.first(), Some(&0x04));
assert_eq!(pk, pk_em, "P-256 pkEm mismatch");
let ikm_r = hx("668b37171f1072f3cf12ea8a236a45df23fc13b82af3609ad1e354f6ef817550");
let sk_rm = hx("f3ce7fdae57e1a310d87f1ebbde6f328be0a99cdbcadf4d6589cf29de4b8ffd2");
let pk_rm = hx("04fe8c19ce0905191ebc298a9245792531f26f0cece2460639e8bc39cb7f706a826a779b4cf969b8a0e539c7f62fb3d30ad6aa8f80e30f1d128aafd68a2ce72ea0");
let (skr, pkr) = kem.derive_key_pair(&ikm_r).expect("derive recip");
assert_eq!(skr.as_bytes(), sk_rm.as_slice(), "P-256 skRm mismatch");
assert_eq!(pkr, pk_rm, "P-256 pkRm mismatch");
}
/// X25519 Encap with a fixed ikmE must reproduce the expected enc and
/// shared_secret. Decap with skRm must then recover the same shared secret.
#[test]
fn encap_x25519_a_1_1() {
let kem = DhKem::new(KemId::DhkemX25519HkdfSha256);
let ikm_e = hx("7268600d403fce431561aef583ee1613527cff655c1343f29812e66706df3234");
let pk_rm = hx("3948cfe0ad1ddb695d780e59077195da6c56506b027329794ab02bca80815c4d");
let enc_expected = hx("37fda3567bdbd628e88668c3c8d7e97d1d1253b6d4ea6d44c150f741f1bf4431");
let ss_expected = hx("fe0e18c9f024ce43799ae393c7e8fe8fce9d218875e8227b0187c04e7d2ea1fc");
let (ss, enc) = kem.encap_with_ikm(&pk_rm, &ikm_e).expect("encap");
assert_eq!(enc, enc_expected, "enc mismatch");
assert_eq!(
ss.as_bytes(),
ss_expected.as_slice(),
"shared_secret mismatch"
);
let sk_rm = hx("4612c550263fc8ad58375df3f557aac531d26850903e55a9f23f21d8534e8ac8");
let ss_dec = kem.decap(&enc, &sk_rm).expect("decap");
assert_eq!(ss_dec.as_bytes(), ss_expected.as_slice());
}
/// P-256 Encap/Decap known-answer test, mirroring the X25519 case above.
#[test]
fn encap_p256_a_3_1() {
let kem = DhKem::new(KemId::DhkemP256HkdfSha256);
let ikm_e = hx("4270e54ffd08d79d5928020af4686d8f6b7d35dbe470265f1f5aa22816ce860e");
let pk_rm = hx("04fe8c19ce0905191ebc298a9245792531f26f0cece2460639e8bc39cb7f706a826a779b4cf969b8a0e539c7f62fb3d30ad6aa8f80e30f1d128aafd68a2ce72ea0");
let enc_expected = hx("04a92719c6195d5085104f469a8b9814d5838ff72b60501e2c4466e5e67b325ac98536d7b61a1af4b78e5b7f951c0900be863c403ce65c9bfcb9382657222d18c4");
let ss_expected = hx("c0d26aeab536609a572b07695d933b589dcf363ff9d93c93adea537aeabb8cb8");
let (ss, enc) = kem.encap_with_ikm(&pk_rm, &ikm_e).expect("encap");
assert_eq!(enc, enc_expected, "P-256 enc mismatch");
assert_eq!(
ss.as_bytes(),
ss_expected.as_slice(),
"P-256 shared_secret mismatch"
);
let sk_rm = hx("f3ce7fdae57e1a310d87f1ebbde6f328be0a99cdbcadf4d6589cf29de4b8ffd2");
let ss_dec = kem.decap(&enc, &sk_rm).expect("decap");
assert_eq!(ss_dec.as_bytes(), ss_expected.as_slice());
}
/// P-256 `deserialize_public_key` must accept only the 65-byte uncompressed
/// form. It must reject the 33-byte compressed encoding of a valid point and a
/// truncated (64-byte) uncompressed encoding.
#[test]
fn p256_rejects_compressed_enc() {
let kem = DhKem::new(KemId::DhkemP256HkdfSha256);
let pk_rm = hx("04fe8c19ce0905191ebc298a9245792531f26f0cece2460639e8bc39cb7f706a826a779b4cf969b8a0e539c7f62fb3d30ad6aa8f80e30f1d128aafd68a2ce72ea0");
let compressed = p256::PublicKey::from_sec1_bytes(&pk_rm)
.expect("valid")
.to_sec1_point(true)
.as_bytes()
.to_vec();
assert_eq!(compressed.len(), 33);
assert_eq!(
kem.deserialize_public_key(&compressed),
Err(CryptoError::InvalidKey)
);
assert_eq!(
kem.deserialize_public_key(&pk_rm[..64]),
Err(CryptoError::InvalidKey)
);
}
/// X25519 AuthEncap/AuthDecap round trip: both sides must derive the same
/// shared secret. The sender IKM is an arbitrary fixed value and no expected
/// secret is asserted.
// NOTE(review): a round trip cannot detect a mistake made identically on both
// sides; an Auth-mode known-answer vector would close that gap.
#[test]
fn auth_round_trip_x25519() {
let kem = DhKem::new(KemId::DhkemX25519HkdfSha256);
let ikm_e = hx("7268600d403fce431561aef583ee1613527cff655c1343f29812e66706df3234");
let (sk_r, pk_r) = kem
.derive_key_pair(&hx(
"6db9df30aa07dd42ee5e8181afdb977e538f5e1fec8a06223f33f7013e525037",
))
.expect("recip");
let (sk_s, pk_s) = kem
.derive_key_pair(&hx(
"1111111111111111111111111111111111111111111111111111111111111111",
))
.expect("sender");
let (ss_enc, enc) = kem
.auth_encap_with_ikm(&pk_r, sk_s.as_bytes(), &ikm_e)
.expect("auth_encap");
let ss_dec = kem
.auth_decap(&enc, sk_r.as_bytes(), &pk_s)
.expect("auth_decap");
assert_eq!(ss_enc.as_bytes(), ss_dec.as_bytes());
}
/// P-256 AuthEncap/AuthDecap round trip. Same shape and the same consistency-only
/// caveat as the X25519 version.
#[test]
fn auth_round_trip_p256() {
let kem = DhKem::new(KemId::DhkemP256HkdfSha256);
let ikm_e = hx("4270e54ffd08d79d5928020af4686d8f6b7d35dbe470265f1f5aa22816ce860e");
let (sk_r, pk_r) = kem
.derive_key_pair(&hx(
"668b37171f1072f3cf12ea8a236a45df23fc13b82af3609ad1e354f6ef817550",
))
.expect("recip");
let (sk_s, pk_s) = kem
.derive_key_pair(&hx(
"2222222222222222222222222222222222222222222222222222222222222222",
))
.expect("sender");
let (ss_enc, enc) = kem
.auth_encap_with_ikm(&pk_r, sk_s.as_bytes(), &ikm_e)
.expect("auth_encap");
let ss_dec = kem
.auth_decap(&enc, sk_r.as_bytes(), &pk_s)
.expect("auth_decap");
assert_eq!(ss_enc.as_bytes(), ss_dec.as_bytes());
}
}