use std::marker::PhantomData;
use boring::hash::MessageDigest;
use rustls::crypto::tls13::{self, Hkdf as RustlsHkdf};
use zeroize::Zeroizing;
use crate::helper::{cvt, cvt_p};
/// Compile-time selector for the message digest that [`Hkdf`] operates with.
///
/// Implementors in this file are field-less marker types; the `Send + Sync`
/// supertraits are part of the contract every implementor must satisfy.
pub trait BoringHash: Send + Sync {
/// Returns the `MessageDigest` that the implementing marker type stands for.
fn new_hash() -> MessageDigest;
}
/// Marker type that selects SHA-256 as the digest for [`Hkdf`].
///
/// The type carries no data. The common traits are derived so it can be
/// logged, copied and default-constructed like any other public value type.
#[derive(Debug, Clone, Copy, Default)]
pub struct Sha256();

impl BoringHash for Sha256 {
    /// Returns the SHA-256 `MessageDigest`.
    fn new_hash() -> MessageDigest {
        MessageDigest::sha256()
    }
}
/// Marker type that selects SHA-384 as the digest for [`Hkdf`].
///
/// The type carries no data. The common traits are derived so it can be
/// logged, copied and default-constructed like any other public value type.
#[derive(Debug, Clone, Copy, Default)]
pub struct Sha384();

impl BoringHash for Sha384 {
    /// Returns the SHA-384 `MessageDigest`.
    fn new_hash() -> MessageDigest {
        MessageDigest::sha384()
    }
}
/// HKDF implementation backed by BoringSSL, parameterised over the digest
/// marker type `T`.
///
/// The struct stores no data; `T` only selects which digest is used.
pub struct Hkdf<T: BoringHash>(PhantomData<T>);

impl<T: BoringHash> Hkdf<T> {
    /// Ready-made instance, usable in `const` and `static` contexts.
    pub const DEFAULT: Self = Self(PhantomData);
}

// The impls below are written by hand instead of derived so that they do not
// require `T: Debug` / `T: Clone` / `T: Default` from the marker type.

impl<T: BoringHash> std::fmt::Debug for Hkdf<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Hkdf")
    }
}

impl<T: BoringHash> Clone for Hkdf<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T: BoringHash> Copy for Hkdf<T> {}

impl<T: BoringHash> Default for Hkdf<T> {
    /// Same value as [`Hkdf::DEFAULT`].
    fn default() -> Self {
        Self::DEFAULT
    }
}
impl<T: BoringHash> RustlsHkdf for Hkdf<T> {
    /// HKDF-Extract where the input keying material is a run of zero bytes
    /// as long as the digest output; delegates to `extract_from_secret`.
    fn extract_from_zero_ikm(
        &self,
        salt: Option<&[u8]>,
    ) -> Box<dyn rustls::crypto::tls13::HkdfExpander> {
        let zero_ikm = [0u8; boring_sys::EVP_MAX_MD_SIZE as usize];
        let ikm_len = T::new_hash().size();

        self.extract_from_secret(salt, &zero_ikm[..ikm_len])
    }

    /// HKDF-Extract over `secret`, returning an expander that owns the
    /// resulting pseudorandom key.
    ///
    /// # Panics
    ///
    /// Panics with "HKDF_extract failed" if BoringSSL reports an error.
    fn extract_from_secret(
        &self,
        salt: Option<&[u8]>,
        secret: &[u8],
    ) -> Box<dyn rustls::crypto::tls13::HkdfExpander> {
        let digest = T::new_hash();

        // A missing salt is replaced by digest-length zero bytes, which is
        // the default RFC 5869 prescribes.
        let zero_salt = [0u8; boring_sys::EVP_MAX_MD_SIZE as usize];
        let salt = match salt {
            Some(provided) => provided,
            None => &zero_salt[..digest.size()],
        };

        let mut prk = Zeroizing::new([0u8; boring_sys::EVP_MAX_MD_SIZE as usize]);
        let mut prk_len = 0usize;

        // SAFETY: every pointer is derived from a live slice or array and is
        // paired with that slice's own length. `prk` is `EVP_MAX_MD_SIZE`
        // bytes, which BoringSSL documents as the maximum `HKDF_extract`
        // output size.
        let status = unsafe {
            boring_sys::HKDF_extract(
                prk.as_mut_ptr(),
                &mut prk_len,
                digest.as_ptr(),
                secret.as_ptr(),
                secret.len(),
                salt.as_ptr(),
                salt.len(),
            )
        };
        cvt(status).expect("HKDF_extract failed");

        Box::new(HkdfExpander {
            prk,
            prk_len,
            digest,
        })
    }

    /// Wraps an already derived output block as the pseudorandom key of a
    /// new expander, without running HKDF-Extract.
    fn expander_for_okm(
        &self,
        okm: &rustls::crypto::tls13::OkmBlock,
    ) -> Box<dyn rustls::crypto::tls13::HkdfExpander> {
        let block: &[u8] = okm.as_ref();

        // Copy into a zeroizing fixed-size buffer; the slice index panics if
        // the block were ever longer than the buffer.
        let mut prk = Zeroizing::new([0u8; boring_sys::EVP_MAX_MD_SIZE as usize]);
        prk[..block.len()].copy_from_slice(block);

        Box::new(HkdfExpander {
            prk,
            prk_len: block.len(),
            digest: T::new_hash(),
        })
    }

    /// Computes the HMAC of `message` under `key` with this instance's
    /// digest.
    ///
    /// # Panics
    ///
    /// Panics with "HMAC failed" if BoringSSL returns a null pointer.
    fn hmac_sign(
        &self,
        key: &rustls::crypto::tls13::OkmBlock,
        message: &[u8],
    ) -> rustls::crypto::hmac::Tag {
        let digest = T::new_hash();
        let key_bytes: &[u8] = key.as_ref();

        let mut tag = Zeroizing::new([0u8; boring_sys::EVP_MAX_MD_SIZE as usize]);
        let mut tag_len = 0u32;

        // SAFETY: key and message pointers come from live slices and are
        // passed with their own lengths. `tag` is `EVP_MAX_MD_SIZE` bytes,
        // the most BoringSSL documents `HMAC` as writing to its output.
        let written = unsafe {
            boring_sys::HMAC(
                digest.as_ptr(),
                key_bytes.as_ptr() as _,
                key_bytes.len(),
                message.as_ptr(),
                message.len(),
                tag.as_mut_ptr(),
                &mut tag_len,
            )
        };
        cvt_p(written).expect("HMAC failed");

        rustls::crypto::hmac::Tag::new(&tag[..tag_len as usize])
    }

    /// Reports whether this crate was compiled with its `fips` feature.
    fn fips(&self) -> bool {
        cfg!(feature = "fips")
    }
}
/// Pseudorandom key together with the digest it belongs to; boxed and
/// returned as a `tls13::HkdfExpander` by the `Hkdf` methods in this file.
struct HkdfExpander {
/// Fixed-size key buffer wrapped in `Zeroizing`; only the first `prk_len`
/// bytes are handed to `HKDF_expand`.
prk: Zeroizing<[u8; boring_sys::EVP_MAX_MD_SIZE as usize]>,
/// Number of meaningful bytes at the start of `prk`.
prk_len: usize,
/// Digest passed to `HKDF_expand`; its `size()` is what `hash_len()` reports.
digest: MessageDigest,
}
impl tls13::HkdfExpander for HkdfExpander {
    /// HKDF-Expand: fills `output` using the stored pseudorandom key and the
    /// concatenation of all `info` parts.
    ///
    /// # Errors
    ///
    /// Returns `OutputLengthError` when `output` is longer than
    /// `255 * hash_len()`, when that product overflows, or when BoringSSL
    /// reports a failure.
    fn expand_slice(
        &self,
        info: &[&[u8]],
        output: &mut [u8],
    ) -> Result<(), tls13::OutputLengthError> {
        // Reject oversized requests up front: HKDF can emit at most
        // 255 blocks of digest-length output.
        let fits = matches!(
            self.hash_len().checked_mul(255),
            Some(limit) if output.len() <= limit
        );
        if !fits {
            return Err(tls13::OutputLengthError);
        }

        // BoringSSL takes one contiguous info buffer.
        let info_concat = info.concat();

        // SAFETY: all pointers are derived from live slices/buffers and are
        // passed with matching lengths; `prk_len` never exceeds the size of
        // the `prk` array it indexes into.
        let status = unsafe {
            boring_sys::HKDF_expand(
                output.as_mut_ptr(),
                output.len(),
                self.digest.as_ptr(),
                self.prk.as_ptr(),
                self.prk_len,
                info_concat.as_ptr(),
                info_concat.len(),
            )
        };

        match cvt(status) {
            Ok(_) => Ok(()),
            Err(_) => Err(tls13::OutputLengthError),
        }
    }

    /// HKDF-Expand producing exactly one digest-length block.
    ///
    /// # Panics
    ///
    /// Panics with "failed hkdf expand" if `expand_slice` returns an error.
    fn expand_block(&self, info: &[&[u8]]) -> tls13::OkmBlock {
        let block_len = self.hash_len();
        let mut okm = Zeroizing::new([0u8; boring_sys::EVP_MAX_MD_SIZE as usize]);

        self.expand_slice(info, &mut okm[..block_len])
            .expect("failed hkdf expand");

        tls13::OkmBlock::new(&okm[..block_len])
    }

    /// Output length in bytes of the digest behind this expander.
    fn hash_len(&self) -> usize {
        self.digest.size()
    }
}
#[cfg(test)]
mod tests {
    use boring::hash::MessageDigest;
    use rustls::crypto::tls13::Hkdf as _;

    use super::{Hkdf, Sha256};

    /// One byte beyond `255 * HashLen` must be refused.
    #[test]
    fn expand_slice_rejects_output_larger_than_rfc_limit() {
        let rfc_limit = MessageDigest::sha256().size() * 255;
        let expander = Hkdf::<Sha256>::DEFAULT.extract_from_secret(None, b"ikm");

        let mut too_long = vec![0u8; rfc_limit + 1];
        let outcome = expander.expand_slice(&[b"info"], &mut too_long);

        assert!(outcome.is_err());
    }

    /// Exactly `255 * HashLen` bytes is still a valid request.
    #[test]
    fn expand_slice_accepts_output_at_rfc_limit() {
        let rfc_limit = MessageDigest::sha256().size() * 255;
        let expander = Hkdf::<Sha256>::DEFAULT.extract_from_secret(None, b"ikm");

        let mut okm = vec![0u8; rfc_limit];
        expander
            .expand_slice(&[b"info"], &mut okm)
            .expect("HKDF expand at RFC limit should succeed");

        // An all-zero buffer would mean nothing was written.
        assert!(okm.iter().any(|&byte| byte != 0));
    }
}