use alloc::{string::ToString, vec::Vec};
use hpke_rs_crypto::{error::Error, types::KemAlgorithm, HpkeCrypto};
use crate::util::*;
use crate::{
kdf::{labeled_expand, labeled_extract},
kem::*,
};
/// Turns Diffie-Hellman output into the KEM shared secret.
///
/// Performs a labeled extract with label `"eae_prk"` over `pk` (the second
/// argument to `labeled_extract` is empty — presumably the salt), then a
/// labeled expand with label `"shared_secret"` over the resulting PRK, using
/// `kem_context` as the info input and `alg.shared_secret_len()` as the
/// output length.
///
/// Note: although the parameter is named `pk` and typed `PublicKey`, the
/// callers in this file pass the (possibly concatenated) result of
/// `Crypto::dh`.
///
/// # Errors
///
/// Propagates any error from `labeled_extract` or `labeled_expand`.
fn extract_and_expand<Crypto: HpkeCrypto>(
    alg: KemAlgorithm,
    pk: PublicKey,
    kem_context: &[u8],
    suite_id: &[u8],
) -> Result<Vec<u8>, Error> {
    let eae_prk = labeled_extract::<Crypto>(alg.into(), &[], suite_id, "eae_prk", &pk)?;
    // Queried only after a successful extract, so a failed extract returns
    // before the length is looked up.
    let secret_len = alg.shared_secret_len();
    let shared_secret = labeled_expand::<Crypto>(
        alg.into(),
        &eae_prk,
        suite_id,
        "shared_secret",
        kem_context,
        secret_len,
    );
    shared_secret
}
/// Serializes a public key.
///
/// For this KEM the serialized form is the raw key bytes, so this simply
/// returns an owned copy of `pk`.
#[inline(always)]
pub(super) fn serialize(pk: &[u8]) -> Vec<u8> {
    Vec::from(pk)
}
/// Deserializes an encapsulated public key.
///
/// The inverse of `serialize`: the encoded form is the raw key bytes, so
/// this returns an owned copy of `enc` without any validation.
#[inline(always)]
pub(super) fn deserialize(enc: &[u8]) -> Vec<u8> {
    Vec::from(enc)
}
/// Generates a fresh key pair for `alg`, drawing randomness from `prng`.
///
/// Delegates to `Crypto::kem_key_gen`, which yields `(public, secret)`, and
/// returns the pair in the opposite order: the secret key wrapped in
/// `PrivateKey` first, the public key bytes second.
///
/// # Errors
///
/// Returns whatever error `Crypto::kem_key_gen` reports.
pub(super) fn key_gen<Crypto: HpkeCrypto>(
    alg: KemAlgorithm,
    prng: &mut Crypto::HpkePrng,
) -> Result<(PrivateKey, Vec<u8>), Error> {
    Crypto::kem_key_gen(alg, prng).map(|(public, secret)| (PrivateKey(secret), public))
}
pub(super) fn derive_key_pair<Crypto: HpkeCrypto>(
alg: KemAlgorithm,
suite_id: &[u8],
ikm: &[u8],
) -> Result<(PublicKey, PrivateKey), Error> {
let dkp_prk = labeled_extract::<Crypto>(alg.into(), &[], suite_id, "dkp_prk", ikm)?;
let sk = match alg {
KemAlgorithm::DhKem25519 => PrivateKey(labeled_expand::<Crypto>(
alg.into(),
&dkp_prk,
suite_id,
"sk",
&[],
alg.private_key_len(),
)?),
KemAlgorithm::DhKemP256 | KemAlgorithm::DhKemK256 => {
let mut ctr = 0u8;
loop {
let candidate = labeled_expand::<Crypto>(
alg.into(),
&dkp_prk,
suite_id,
"candidate",
&ctr.to_be_bytes(),
alg.private_key_len(),
);
if let Ok(sk) = &candidate {
if let Ok(sk) = Crypto::dh_validate_sk(alg, sk) {
break PrivateKey(sk);
}
}
if ctr == u8::MAX {
return Err(Error::CryptoLibraryError(
"Unable to generate a valid private key".to_string(),
));
}
ctr += 1;
}
}
_ => {
panic!("This should be unreachable. Only x25519, P256, and K256 KEMs are implemented")
}
};
Ok((Crypto::secret_to_public(alg, &sk.0)?, sk))
}
/// Encapsulates a shared secret to the public key `pk_r`.
///
/// An ephemeral key pair is derived from `randomness` via `derive_key_pair`,
/// a DH operation is run between `pk_r` and the ephemeral secret key, and the
/// result is fed through `extract_and_expand` with the KEM context
/// `enc || serialize(pk_r)`, where `enc` is the serialized ephemeral public
/// key.
///
/// Returns `(shared_secret, enc)`.
///
/// In debug builds, asserts that `randomness` is exactly
/// `alg.private_key_len()` bytes long.
///
/// # Errors
///
/// Propagates errors from `derive_key_pair`, `Crypto::dh` and
/// `extract_and_expand`.
pub(super) fn encaps<Crypto: HpkeCrypto>(
    alg: KemAlgorithm,
    pk_r: &[u8],
    suite_id: &[u8],
    randomness: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), Error> {
    debug_assert_eq!(randomness.len(), alg.private_key_len());

    let (ephemeral_pk, ephemeral_sk) = derive_key_pair::<Crypto>(alg, suite_id, randomness)?;
    let dh = Crypto::dh(alg, pk_r, &ephemeral_sk.0)?;

    let enc = serialize(&ephemeral_pk);
    let recipient_pk = serialize(pk_r);
    let kem_context = concat(&[enc.as_slice(), recipient_pk.as_slice()]);

    extract_and_expand::<Crypto>(alg, dh, &kem_context, suite_id)
        .map(|shared_secret| (shared_secret, enc))
}
/// Recovers the shared secret from the encapsulation `enc` using the secret
/// key `sk_r`.
///
/// Runs a DH operation between the deserialized `enc` and `sk_r`, recomputes
/// the public key belonging to `sk_r`, and feeds the DH result through
/// `extract_and_expand` with the KEM context `enc || serialize(pk(sk_r))`.
///
/// # Errors
///
/// Propagates errors from `Crypto::dh`, `Crypto::secret_to_public` and
/// `extract_and_expand`.
pub(super) fn decaps<Crypto: HpkeCrypto>(
    alg: KemAlgorithm,
    enc: &[u8],
    sk_r: &[u8],
    suite_id: &[u8],
) -> Result<Vec<u8>, Error> {
    let ephemeral_pk = deserialize(enc);
    let dh = Crypto::dh(alg, &ephemeral_pk, sk_r)?;

    let recipient_pk = Crypto::secret_to_public(alg, sk_r)?;
    let recipient_pk_bytes = serialize(&recipient_pk);
    let kem_context = concat(&[enc, recipient_pk_bytes.as_slice()]);

    let shared_secret = extract_and_expand::<Crypto>(alg, dh, &kem_context, suite_id)?;
    Ok(shared_secret)
}
/// Encapsulates a shared secret to `pk_r` while also binding it to the
/// secret key `sk_s` (presumably the sender's static key).
///
/// An ephemeral key pair is derived from `randomness`. Two DH operations are
/// performed against `pk_r` — first with the ephemeral secret key, then with
/// `sk_s` — and their outputs are concatenated in that order. The result is
/// fed through `extract_and_expand` with the KEM context
/// `enc || serialize(pk_r) || serialize(pk(sk_s))`.
///
/// Returns `(shared_secret, enc)`.
///
/// In debug builds, asserts that `randomness` is exactly
/// `alg.private_key_len()` bytes long.
///
/// # Errors
///
/// Propagates errors from `derive_key_pair`, `Crypto::dh`,
/// `Crypto::secret_to_public` and `extract_and_expand`.
pub(super) fn auth_encaps<Crypto: HpkeCrypto>(
    alg: KemAlgorithm,
    pk_r: &[u8],
    sk_s: &[u8],
    suite_id: &[u8],
    randomness: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), Error> {
    debug_assert_eq!(randomness.len(), alg.private_key_len());

    let (ephemeral_pk, ephemeral_sk) = derive_key_pair::<Crypto>(alg, suite_id, randomness)?;

    // The ephemeral DH is evaluated before the static one; the concatenation
    // order below is part of the derived secret.
    let dh_ephemeral = Crypto::dh(alg, pk_r, &ephemeral_sk.0)?;
    let dh_static = Crypto::dh(alg, pk_r, sk_s)?;
    let dh = concat(&[&dh_ephemeral, &dh_static]);

    let enc = serialize(&ephemeral_pk);
    let recipient_pk = serialize(pk_r);
    let sender_pk = Crypto::secret_to_public(alg, sk_s)?;
    let sender_pk_bytes = serialize(&sender_pk);
    let kem_context = concat(&[
        enc.as_slice(),
        recipient_pk.as_slice(),
        sender_pk_bytes.as_slice(),
    ]);

    extract_and_expand::<Crypto>(alg, dh, &kem_context, suite_id)
        .map(|shared_secret| (shared_secret, enc))
}
/// Recovers the shared secret from `enc` using the secret key `sk_r`, while
/// also binding it to the public key `pk_s` (presumably the sender's static
/// key).
///
/// Two DH operations are performed with `sk_r` — first against the
/// deserialized `enc`, then against `pk_s` — and their outputs are
/// concatenated in that order. The result is fed through
/// `extract_and_expand` with the KEM context
/// `enc || serialize(pk(sk_r)) || serialize(pk_s)`.
///
/// # Errors
///
/// Propagates errors from `Crypto::dh`, `Crypto::secret_to_public` and
/// `extract_and_expand`.
pub(super) fn auth_decaps<Crypto: HpkeCrypto>(
    alg: KemAlgorithm,
    enc: &[u8],
    sk_r: &[u8],
    pk_s: &[u8],
    suite_id: &[u8],
) -> Result<Vec<u8>, Error> {
    let ephemeral_pk = deserialize(enc);

    // The ephemeral DH is evaluated before the static one; the concatenation
    // order below is part of the derived secret.
    let dh_ephemeral = Crypto::dh(alg, &ephemeral_pk, sk_r)?;
    let dh_static = Crypto::dh(alg, pk_s, sk_r)?;
    let dh = concat(&[&dh_ephemeral, &dh_static]);

    let recipient_pk = Crypto::secret_to_public(alg, sk_r)?;
    let recipient_pk_bytes = serialize(&recipient_pk);
    let sender_pk_bytes = serialize(pk_s);
    let kem_context = concat(&[
        enc,
        recipient_pk_bytes.as_slice(),
        sender_pk_bytes.as_slice(),
    ]);

    let shared_secret = extract_and_expand::<Crypto>(alg, dh, &kem_context, suite_id)?;
    Ok(shared_secret)
}