use crate::hash::Algorithm as HashAlgorithm;
use rustls::crypto::hash::Hash as _;
use rustls::crypto::hmac::{Hmac as _, Tag};
use rustls::crypto::tls13::{
Hkdf as RustlsHkdf, HkdfExpander as RustlsHkdfExpander, OkmBlock, OutputLengthError,
};
use windows::core::{Owned, PCWSTR};
use windows::Win32::Foundation::NTSTATUS;
use windows::Win32::Security::Cryptography::{
BCryptBuffer, BCryptBufferDesc, BCryptGenerateSymmetricKey, BCryptKeyDerivation,
BCryptSetProperty as Win32BCryptSetProperty, BCRYPTBUFFER_VERSION, BCRYPT_HANDLE,
BCRYPT_HKDF_ALG_HANDLE, BCRYPT_HKDF_HASH_ALGORITHM, BCRYPT_HKDF_PRK_AND_FINALIZE,
BCRYPT_HKDF_SALT_AND_FINALIZE, BCRYPT_KEY_HANDLE, KDF_HKDF_INFO,
};
#[derive(Debug, Copy, Clone)]
pub(crate) struct Hkdf<const HASH_SIZE: usize>(pub(crate) HashAlgorithm<HASH_SIZE>);
struct HkdfExpander<const HASH_SIZE: usize> {
key_handle: Owned<BCRYPT_KEY_HANDLE>,
hash: HashAlgorithm<HASH_SIZE>,
}
unsafe impl<const HASH_SIZE: usize> Send for HkdfExpander<HASH_SIZE> {}
unsafe impl<const HASH_SIZE: usize> Sync for HkdfExpander<HASH_SIZE> {}
impl<const HASH_SIZE: usize> RustlsHkdf for Hkdf<HASH_SIZE> {
fn extract_from_zero_ikm(&self, salt: Option<&[u8]>) -> Box<dyn RustlsHkdfExpander> {
let secret = [0u8; HASH_SIZE];
self.extract_from_secret(salt, &secret)
}
fn extract_from_secret(
&self,
salt: Option<&[u8]>,
secret: &[u8],
) -> Box<dyn RustlsHkdfExpander> {
let mut key_handle = Owned::default();
unsafe {
BCryptGenerateSymmetricKey(BCRYPT_HKDF_ALG_HANDLE, &mut *key_handle, None, secret, 0)
.ok()
.unwrap();
let bcrypt_handle = BCRYPT_HANDLE(&mut *key_handle.0);
Win32BCryptSetProperty(
bcrypt_handle,
BCRYPT_HKDF_HASH_ALGORITHM,
self.0.id_bytes,
0,
)
.ok()
.unwrap();
Win32BCryptSetProperty(
bcrypt_handle,
BCRYPT_HKDF_SALT_AND_FINALIZE,
salt.unwrap_or(&[0u8; HASH_SIZE]),
0,
)
.ok()
.unwrap();
};
Box::new(HkdfExpander::<HASH_SIZE> {
key_handle,
hash: self.0,
})
}
fn expander_for_okm(&self, okm: &OkmBlock) -> Box<dyn RustlsHkdfExpander> {
let okm = okm.as_ref();
let mut key_handle = Owned::default();
unsafe {
BCryptGenerateSymmetricKey(BCRYPT_HKDF_ALG_HANDLE, &mut *key_handle, None, okm, 0)
.ok()
.unwrap();
let bcrypt_handle = BCRYPT_HANDLE(&mut *key_handle.0);
Win32BCryptSetProperty(
bcrypt_handle,
BCRYPT_HKDF_HASH_ALGORITHM,
self.0.id_bytes,
0,
)
.ok()
.unwrap();
BCryptSetProperty(
bcrypt_handle,
BCRYPT_HKDF_PRK_AND_FINALIZE,
std::ptr::null(),
0,
0,
)
.ok()
.unwrap();
};
Box::new(HkdfExpander::<HASH_SIZE> {
key_handle,
hash: self.0,
})
}
fn hmac_sign(&self, key: &OkmBlock, message: &[u8]) -> Tag {
self.0.with_key(key.as_ref()).sign(&[message])
}
fn fips(&self) -> bool {
crate::fips::enabled()
}
}
extern "system" {
fn BCryptSetProperty(
hobject: BCRYPT_HANDLE,
pszproperty: PCWSTR,
pbinput: *const u8,
cbinput: u32,
dwflags: u32,
) -> NTSTATUS;
}
impl<const HASH_SIZE: usize> RustlsHkdfExpander for HkdfExpander<HASH_SIZE> {
fn expand_slice(&self, info: &[&[u8]], output: &mut [u8]) -> Result<(), OutputLengthError> {
let info = info
.iter()
.flat_map(|info| info.iter())
.copied()
.collect::<Vec<u8>>();
let mut buffer = BCryptBuffer {
cbBuffer: info.len() as u32,
BufferType: KDF_HKDF_INFO,
pvBuffer: info.as_ptr() as *mut _,
};
let params = if info.is_empty() {
BCryptBufferDesc::default()
} else {
BCryptBufferDesc {
ulVersion: BCRYPTBUFFER_VERSION,
cBuffers: 1u32,
pBuffers: &mut buffer,
}
};
let mut size = 0u32;
unsafe {
BCryptKeyDerivation(*self.key_handle, Some(¶ms), output, &mut size, 0)
.ok()
.map_err(|e| {
dbg!(e);
OutputLengthError
})?;
};
if size != output.len() as u32 {
return Err(OutputLengthError);
}
Ok(())
}
fn expand_block(&self, info: &[&[u8]]) -> OkmBlock {
let mut output = [0u8; HASH_SIZE];
self.expand_slice(info, &mut output)
.expect("HDKF-Expand failed");
OkmBlock::new(&output)
}
fn hash_len(&self) -> usize {
self.hash.output_len()
}
}
#[cfg(test)]
mod test {
use rustls::crypto::tls13::Hkdf as _;
use wycheproof::{hkdf::TestName, TestResult};
use super::Hkdf;
fn test_hkdf<const HASH_SIZE: usize>(hkdf: Hkdf<HASH_SIZE>, test_name: TestName) {
let test_set = wycheproof::hkdf::TestSet::load(test_name).unwrap();
for test_group in test_set.test_groups {
for test in test_group.tests {
dbg!(&test);
let prk_expander = hkdf.extract_from_secret(Some(&test.salt), &test.ikm);
let mut okm = vec![0; test.size];
let infos = test.info.chunks(2).collect::<Vec<_>>();
let res = prk_expander.expand_slice(&infos, &mut okm);
match &test.result {
TestResult::Acceptable | TestResult::Valid => {
assert!(res.is_ok());
assert_eq!(okm[..], test.okm[..], "Failed test: {}", test.comment);
}
TestResult::Invalid => {
dbg!(&res);
assert!(res.is_err(), "Failed test: {}", test.comment);
}
}
}
}
}
#[test]
fn hkdf_sha256() {
let hkdf = Hkdf(crate::hash::SHA256);
test_hkdf(hkdf, TestName::HkdfSha256);
}
#[test]
fn hkdf_sha384() {
let hkdf = Hkdf(crate::hash::SHA384);
test_hkdf(hkdf, TestName::HkdfSha384);
}
}