pub use hkdf::InvalidLength;
use core::iter;
use hkdf::{Hkdf, hmac::digest::block_api::EagerHash};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
/// Bytes reserved for label prefixes in the labeled-IKM scratch buffer.
const PREFIXES_MAX: usize = 256;
/// Total capacity of the labeled-IKM scratch buffer: the prefix budget plus
/// 64 more bytes. Only the combined length is limited; neither part is
/// checked on its own.
const LABELED_INPUT_MAX: usize = 64 + PREFIXES_MAX;
/// HPKE version label, `"HPKE-v1"` (RFC 9180 §4).
const HPKE_VERSION_ID: &[u8] = b"HPKE-v1";
/// KEM-level HPKE `suite_id = concat("KEM", I2OSP(kem_id, 2))`, with
/// `kem_id = 0x0010` (DHKEM(P-256, HKDF-SHA256) in the RFC 9180 KEM registry).
const HPKE_SUITE_ID: &[u8] = &[b'K', b'E', b'M', 0x00, 0x10];
/// HKDF (RFC 5869) helper: HKDF-Extract runs once at construction, after which
/// any number of HKDF-Expand calls can be made through `&self`.
///
/// Besides plain HKDF it provides the HPKE (RFC 9180 §4) `LabeledExtract` /
/// `LabeledExpand` framing, using this module's fixed version and KEM suite
/// identifiers.
#[derive(Debug)]
pub struct Expander<D: EagerHash> {
    /// HKDF state created from the salt and (possibly labeled) input key material.
    hkdf: Hkdf<D>,
}

impl<D: EagerHash> Expander<D> {
    /// Runs HKDF-Extract over `input_key_material` using `salt`.
    #[must_use]
    pub fn new(salt: &[u8], input_key_material: &[u8]) -> Self {
        Self {
            hkdf: Hkdf::<D>::new(Some(salt), input_key_material),
        }
    }

    /// Runs HKDF-Extract over `prefixes[0] || … || prefixes[n-1] || input_key_material`.
    ///
    /// The concatenation is assembled in a local fixed-size buffer of
    /// `LABELED_INPUT_MAX` bytes. With the `zeroize` feature enabled, that buffer
    /// is wiped before returning, on both the success and the error path.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidLength`] if the combined length of `prefixes` and
    /// `input_key_material` exceeds `LABELED_INPUT_MAX` bytes.
    pub fn new_prefixed(
        salt: &[u8],
        prefixes: &[&[u8]],
        input_key_material: &[u8],
    ) -> Result<Self, InvalidLength> {
        let mut labeled_ikm_buf = [0u8; LABELED_INPUT_MAX];
        let segments = prefixes
            .iter()
            .copied()
            .chain(iter::once(input_key_material));
        // Do not `?` here: an early return would skip the wipe below. The buffer
        // must be cleared no matter how much of the input was copied into it.
        let result = concat_slices(segments, &mut labeled_ikm_buf)
            .map(|labeled_ikm| Self::new(salt, labeled_ikm));
        #[cfg(feature = "zeroize")]
        labeled_ikm_buf.zeroize();
        result
    }

    /// HPKE `LabeledExtract`: HKDF-Extract over
    /// `"HPKE-v1" || suite_id || label || input_key_material`.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidLength`] if `label` and `input_key_material`, together
    /// with the fixed HPKE prefixes, exceed the labeled-input buffer.
    pub fn new_labeled_hpke(
        salt: &[u8],
        label: &[u8],
        input_key_material: &[u8],
    ) -> Result<Self, InvalidLength> {
        Self::new_prefixed(
            salt,
            &[HPKE_VERSION_ID, HPKE_SUITE_ID, label],
            input_key_material,
        )
    }

    /// Runs HKDF-Expand with `info`, filling all of `okm`.
    ///
    /// # Errors
    ///
    /// Propagates [`InvalidLength`] from HKDF when `okm` is longer than HKDF can
    /// produce for this hash (RFC 5869 caps it at 255 × the hash output size).
    pub fn expand(&self, info: &[u8], okm: &mut [u8]) -> Result<(), InvalidLength> {
        self.hkdf.expand(info, okm)
    }

    /// Like [`Self::expand`], but `info` is supplied as pieces that HKDF treats
    /// as one concatenated string, so no temporary buffer is needed.
    ///
    /// # Errors
    ///
    /// Propagates [`InvalidLength`] from HKDF when `okm` is too long for this hash.
    pub fn expand_multi_info(
        &self,
        info_components: &[&[u8]],
        okm: &mut [u8],
    ) -> Result<(), InvalidLength> {
        self.hkdf.expand_multi_info(info_components, okm)
    }

    /// HPKE `LabeledExpand`: HKDF-Expand with
    /// `I2OSP(okm.len(), 2) || "HPKE-v1" || suite_id || label || info` as info.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidLength`] in two cases:
    /// - `okm.len()` does not fit in the two-byte big-endian length prefix HPKE
    ///   requires.
    /// - `okm` exceeds the HKDF output limit for this hash.
    pub fn expand_labeled_hpke(
        &self,
        label: &[u8],
        info: &[u8],
        okm: &mut [u8],
    ) -> Result<(), InvalidLength> {
        // HPKE encodes the requested length as I2OSP(L, 2), so L must fit in a u16.
        let okm_len = u16::try_from(okm.len()).map_err(|_| InvalidLength)?;
        self.hkdf.expand_multi_info(
            &[
                &okm_len.to_be_bytes(),
                HPKE_VERSION_ID,
                HPKE_SUITE_ID,
                label,
                info,
            ],
            okm,
        )
    }
}
/// Writes every segment yielded by `slices`, in order, to the front of `out`.
///
/// Returns the filled prefix of `out`. Fails with [`InvalidLength`] if the
/// running total would overflow `usize` or run past the end of `out`. On
/// failure, segments that fit before the offending one have already been
/// copied, and the offending segment is not written at all.
fn concat_slices<'a, I>(mut slices: I, out: &mut [u8]) -> Result<&[u8], InvalidLength>
where
    I: Iterator<Item = &'a [u8]>,
{
    let filled = slices.try_fold(0usize, |start, piece| -> Result<usize, InvalidLength> {
        let end = start.checked_add(piece.len()).ok_or(InvalidLength)?;
        let dest = out.get_mut(start..end).ok_or(InvalidLength)?;
        dest.copy_from_slice(piece);
        Ok(end)
    })?;
    Ok(&out[..filled])
}