use crate::util::write_u16_be;
use digest::{core_api::BlockSizeUser, Digest, OutputSizeUser};
use generic_array::GenericArray;
use hmac::SimpleHmac;
use sha2::{Sha256, Sha384, Sha512};
/// Version label mixed into every labeled HKDF call in this module: it is the
/// first IKM piece fed by `labeled_extract`, and it follows the two-byte length
/// prefix in the expand info assembled by `labeled_expand`.
const VERSION_LABEL: &[u8] = b"HPKE-v1";
/// Size in bytes of the largest digest among this module's KDFs (SHA-512, used
/// by `HkdfSha512`, outputs 64 bytes). NOTE(review): presumably used elsewhere
/// in the crate to size fixed buffers — confirm at the use sites.
pub(crate) const MAX_DIGEST_SIZE: usize = 64;
/// A key derivation function: the hash that HKDF is instantiated with, plus a
/// numeric identifier for the algorithm.
///
/// The helpers in this module (`labeled_extract`, `extract_and_expand`) are
/// generic over this trait and run HMAC-based HKDF over `HashImpl`.
pub trait Kdf {
/// The hash function underlying HKDF. Its bounds match what this module's
/// `LabeledExpand` implementation requires of the hash.
#[doc(hidden)]
type HashImpl: Clone + Digest + OutputSizeUser + BlockSizeUser;
/// Two-byte identifier of this KDF. The built-in implementations below use
/// the KDF IDs from RFC 9180 (HPKE).
const KDF_ID: u16;
}
use Kdf as KdfTrait;
// In these aliases the generic parameter is itself named `Kdf`, which shadows
// the trait; that is why the trait is referred to through the `KdfTrait` alias.
/// Byte array exactly as long as the output of `Kdf`'s hash (for example the
/// PRK returned by `labeled_extract`).
pub(crate) type DigestArray<Kdf> =
GenericArray<u8, <<Kdf as KdfTrait>::HashImpl as OutputSizeUser>::OutputSize>;
/// HKDF over `Kdf`'s hash, with HMAC computed by `hmac::SimpleHmac`.
pub(crate) type SimpleHkdf<Kdf> =
hkdf::Hkdf<<Kdf as KdfTrait>::HashImpl, SimpleHmac<<Kdf as KdfTrait>::HashImpl>>;
/// Incremental HKDF-Extract state over `Kdf`'s hash (HMAC via `SimpleHmac`);
/// it accepts the input keying material in several pieces.
type SimpleHkdfExtract<Kdf> =
hkdf::HkdfExtract<<Kdf as KdfTrait>::HashImpl, SimpleHmac<<Kdf as KdfTrait>::HashImpl>>;
/// HKDF instantiated with SHA-256.
pub struct HkdfSha256 {}
impl KdfTrait for HkdfSha256 {
#[doc(hidden)]
type HashImpl = Sha256;
// 0x0001 is the RFC 9180 KDF identifier assigned to HKDF-SHA256.
const KDF_ID: u16 = 0x0001;
}
/// HKDF instantiated with SHA-384.
pub struct HkdfSha384 {}
impl KdfTrait for HkdfSha384 {
#[doc(hidden)]
type HashImpl = Sha384;
// 0x0002 is the RFC 9180 KDF identifier assigned to HKDF-SHA384.
const KDF_ID: u16 = 0x0002;
}
/// HKDF instantiated with SHA-512.
pub struct HkdfSha512 {}
impl KdfTrait for HkdfSha512 {
#[doc(hidden)]
type HashImpl = Sha512;
// 0x0003 is the RFC 9180 KDF identifier assigned to HKDF-SHA512.
const KDF_ID: u16 = 0x0003;
}
/// Derives `out.len()` bytes of keying material from `ikm` in a single call.
///
/// This runs `labeled_extract` with an empty salt and the label `eae_prk`, then
/// runs `labeled_expand` on the resulting PRK with the label `shared_secret`
/// and the caller-supplied `info`.
///
/// # Errors
///
/// Propagates `hkdf::InvalidLength` from the expand step when `out` is longer
/// than it permits.
#[doc(hidden)]
pub fn extract_and_expand<Kdf: KdfTrait>(
    ikm: &[u8],
    suite_id: &[u8],
    info: &[u8],
    out: &mut [u8],
) -> Result<(), hkdf::InvalidLength> {
    let no_salt: &[u8] = &[];
    // Only the HKDF context is needed here; the raw PRK bytes (field 0) are
    // discarded.
    let prk = labeled_extract::<Kdf>(no_salt, suite_id, b"eae_prk", ikm).1;
    prk.labeled_expand(suite_id, b"shared_secret", info, out)
}
/// Runs HKDF-Extract keyed by `salt` over the labeled input
/// `VERSION_LABEL || suite_id || label || ikm`.
///
/// The pieces are streamed into the extract context one after another, so the
/// concatenation is never materialized in memory.
///
/// Returns the extracted digest-sized PRK together with an HKDF context built
/// from it, ready for expansion.
#[doc(hidden)]
pub fn labeled_extract<Kdf: KdfTrait>(
    salt: &[u8],
    suite_id: &[u8],
    label: &[u8],
    ikm: &[u8],
) -> (DigestArray<Kdf>, SimpleHkdf<Kdf>) {
    let mut extractor = SimpleHkdfExtract::<Kdf>::new(Some(salt));
    // Order matters: the HMAC input is the pieces concatenated in this order.
    for &piece in &[VERSION_LABEL, suite_id, label, ikm] {
        extractor.input_ikm(piece);
    }
    extractor.finalize()
}
/// HKDF-Expand with a labeled `info`; the counterpart of `labeled_extract`.
#[doc(hidden)]
pub trait LabeledExpand {
/// Fills `out` with keying material expanded under `label` and `info`.
///
/// The implementation for `hkdf::Hkdf` in this module builds the expand info
/// as `out.len()` (two bytes, big-endian via `write_u16_be`) ||
/// `VERSION_LABEL` || `suite_id` || `label` || `info`.
///
/// # Errors
///
/// Returns `hkdf::InvalidLength` when the requested output length is not
/// supported.
fn labeled_expand(
&self,
suite_id: &[u8],
label: &[u8],
info: &[u8],
out: &mut [u8],
) -> Result<(), hkdf::InvalidLength>;
}
impl<D> LabeledExpand for hkdf::Hkdf<D, SimpleHmac<D>>
where
    D: Clone + OutputSizeUser + Digest + BlockSizeUser,
{
    /// Expands this HKDF context into `out`, using as info the concatenation
    /// `out.len()` (two bytes, big-endian) || `VERSION_LABEL` || `suite_id` ||
    /// `label` || `info`.
    ///
    /// # Errors
    ///
    /// Returns `hkdf::InvalidLength` if `out.len()` does not fit in a `u16`.
    /// Also propagates the same error from `expand_multi_info`, which the hkdf
    /// crate returns when `out` exceeds 255 times the hash output size.
    fn labeled_expand(
        &self,
        suite_id: &[u8],
        label: &[u8],
        info: &[u8],
        out: &mut [u8],
    ) -> Result<(), hkdf::InvalidLength> {
        let requested = out.len();
        // The length prefix is only two bytes wide; reject anything the cast
        // below would otherwise silently truncate.
        if requested > usize::from(u16::MAX) {
            return Err(hkdf::InvalidLength);
        }
        let mut length_prefix = [0u8; 2];
        write_u16_be(&mut length_prefix, requested as u16);
        let info_parts: [&[u8]; 5] = [&length_prefix, VERSION_LABEL, suite_id, label, info];
        self.expand_multi_info(&info_parts, out)
    }
}