use crate::error::{check, len_as_u32, WolfCryptError};
use generic_array::GenericArray;
use typenum::*;
/// Generates an HKDF (RFC 5869) wrapper type backed by wolfCrypt's
/// `wc_HKDF_Extract` / `wc_HKDF_Expand`.
///
/// * `$name` — name of the generated struct.
/// * `$hash_type` — wolfCrypt hash-type constant passed to both FFI calls.
/// * `$output_size` — `typenum` length of the stored PRK. It must equal the
///   digest length of `$hash_type`: it sizes the buffer `wc_HKDF_Extract`
///   writes into.
/// * `$cfg_gate` — attribute body (e.g. `cfg(...)`) applied to the struct and
///   every impl, so the type only exists under that configuration.
macro_rules! impl_hkdf {
    (
        $name:ident,
        $hash_type:expr,
        $output_size:ty,
        $cfg_gate:meta
    ) => {
        /// HKDF state holding a pseudorandom key (PRK) of exactly one hash
        /// output length. The PRK is zeroized when the value is dropped.
        #[$cfg_gate]
        pub struct $name {
            prk: GenericArray<u8, $output_size>,
        }

        #[$cfg_gate]
        impl $name {
            /// Runs HKDF-Extract over `salt` and `ikm` and returns an instance
            /// ready for [`Self::expand`].
            ///
            /// The standalone PRK copy produced by [`Self::extract`] is wiped
            /// before returning, because only the instance is handed back.
            ///
            /// # Panics
            ///
            /// Panics if `wc_HKDF_Extract` reports a failure (see
            /// [`Self::extract`]).
            pub fn new(salt: Option<&[u8]>, ikm: &[u8]) -> Self {
                use zeroize::Zeroize;
                let (mut prk, inst) = Self::extract(salt, ikm);
                // The instance keeps its own copy; don't leave a second copy of
                // the secret lying around un-wiped.
                prk.zeroize();
                inst
            }

            /// Performs HKDF-Extract (RFC 5869 §2.2).
            ///
            /// A `None` or empty `salt`, and an empty `ikm`, are passed to
            /// wolfCrypt as a null pointer with zero length.
            ///
            /// Returns the PRK together with an instance holding its own copy of
            /// it. The caller is responsible for wiping the returned PRK once it
            /// is no longer needed.
            ///
            /// # Panics
            ///
            /// Panics if `wc_HKDF_Extract` returns a non-zero status code.
            pub fn extract(
                salt: Option<&[u8]>,
                ikm: &[u8],
            ) -> (GenericArray<u8, $output_size>, Self) {
                let mut prk = GenericArray::<u8, $output_size>::default();
                let (salt_ptr, salt_len) = match salt {
                    Some(s) if !s.is_empty() => (s.as_ptr(), len_as_u32(s.len())),
                    _ => (core::ptr::null(), 0u32),
                };
                let (ikm_ptr, ikm_len) = if ikm.is_empty() {
                    (core::ptr::null(), 0u32)
                } else {
                    (ikm.as_ptr(), len_as_u32(ikm.len()))
                };
                // SAFETY: `salt_ptr`/`ikm_ptr` are either null with a zero length
                // or point into live slices of exactly the stated length. `prk`
                // is a writable buffer of `$output_size` bytes; this relies on
                // each macro invocation pairing `$hash_type` with its digest
                // length, which is how many bytes the extract step writes.
                let rc = unsafe {
                    wolfcrypt_rs::wc_HKDF_Extract(
                        $hash_type,
                        salt_ptr,
                        salt_len,
                        ikm_ptr,
                        ikm_len,
                        prk.as_mut_ptr(),
                    )
                };
                assert_eq!(rc, 0, "wc_HKDF_Extract failed (invalid hash type)");
                let inst = Self { prk: prk.clone() };
                (prk, inst)
            }

            /// Builds an instance directly from an existing PRK.
            ///
            /// # Errors
            ///
            /// Returns `WolfCryptError::INVALID_INPUT` unless `prk` is exactly
            /// one hash output long. Longer PRKs are rejected rather than
            /// truncated. HKDF-Expand keys HMAC with the *entire* PRK, so keeping
            /// only a prefix would silently derive different output keying
            /// material. This type can only store a PRK of exactly hash length.
            pub fn from_prk(prk: &[u8]) -> Result<Self, WolfCryptError> {
                let hash_len = <$output_size as typenum::Unsigned>::USIZE;
                if prk.len() != hash_len {
                    return Err(WolfCryptError::INVALID_INPUT);
                }
                let mut arr = GenericArray::<u8, $output_size>::default();
                arr.copy_from_slice(prk);
                Ok(Self { prk: arr })
            }

            /// Performs HKDF-Expand (RFC 5869 §2.3), filling all of `okm` with
            /// output keying material derived from the stored PRK and `info`.
            ///
            /// An empty `info` is passed to wolfCrypt as a null pointer with zero
            /// length.
            ///
            /// # Errors
            ///
            /// Returns the error produced by `check` when `wc_HKDF_Expand`
            /// returns a non-zero status. Presumably this includes `okm` longer
            /// than the RFC limit of 255 × hash length — confirm against
            /// wolfCrypt.
            pub fn expand(&self, info: &[u8], okm: &mut [u8]) -> Result<(), WolfCryptError> {
                let (info_ptr, info_len) = if info.is_empty() {
                    (core::ptr::null(), 0u32)
                } else {
                    (info.as_ptr(), len_as_u32(info.len()))
                };
                // SAFETY: the PRK pointer/length describe `self.prk`; `info_ptr`
                // is null with zero length or points into a live slice of
                // `info_len` bytes; `okm` is a writable slice whose length is
                // passed alongside its pointer.
                let rc = unsafe {
                    wolfcrypt_rs::wc_HKDF_Expand(
                        $hash_type,
                        self.prk.as_ptr(),
                        len_as_u32(self.prk.len()),
                        info_ptr,
                        info_len,
                        okm.as_mut_ptr(),
                        len_as_u32(okm.len()),
                    )
                };
                check(rc, "wc_HKDF_Expand")
            }
        }

        #[$cfg_gate]
        impl Drop for $name {
            // Wipe the secret PRK so it does not linger in freed memory.
            fn drop(&mut self) {
                use zeroize::Zeroize;
                self.prk.zeroize();
            }
        }
    };
}
// HKDF over SHA-256 (32-byte PRK); compiled only when the `wolfssl_hkdf` cfg is set.
impl_hkdf! {
    WolfHkdfSha256,
    wolfcrypt_rs::WC_HASH_TYPE_SHA256,
    U32,
    cfg(wolfssl_hkdf)
}
// HKDF over SHA-384 (48-byte PRK); compiled only when both the `wolfssl_hkdf`
// and `wolfssl_sha384` cfgs are set.
impl_hkdf! {
    WolfHkdfSha384,
    wolfcrypt_rs::WC_HASH_TYPE_SHA384,
    U48,
    cfg(all(wolfssl_hkdf, wolfssl_sha384))
}
// HKDF over SHA-512 (64-byte PRK); compiled only when both the `wolfssl_hkdf`
// and `wolfssl_sha512` cfgs are set.
impl_hkdf! {
    WolfHkdfSha512,
    wolfcrypt_rs::WC_HASH_TYPE_SHA512,
    U64,
    cfg(all(wolfssl_hkdf, wolfssl_sha512))
}