use crate::mac::{ HmacMd2, HmacMd4, HmacMd5, HmacSm3, HmacSha1, HmacSha224, HmacSha256, HmacSha384, HmacSha512, };
// Generates a public HKDF (HMAC-based extract-then-expand key derivation) type
// named `$name` on top of the HMAC implementation `$hmac`.
//
// The generated code relies on `$hmac` providing the associated constants
// `BLOCK_LEN` and `TAG_LEN` as well as `new`, `update`, `finalize` and `oneshot`.
macro_rules! impl_hkdf_with_hmac {
    ($name:tt, $hmac:tt) => {
        /// HKDF state: stores the pseudorandom key (PRK) that the expand step is keyed with.
        #[derive(Clone)]
        pub struct $name {
            prk: [u8; Self::TAG_LEN],
        }

        impl $name {
            /// `BLOCK_LEN` of the underlying HMAC, re-exported unchanged.
            pub const BLOCK_LEN: usize = $hmac::BLOCK_LEN;
            /// `TAG_LEN` of the underlying HMAC; also the length of the PRK in bytes.
            pub const TAG_LEN: usize = $hmac::TAG_LEN;

            /// Extract step: computes the PRK as the HMAC of `ikm` keyed with `salt`.
            ///
            /// An empty `salt` is replaced by `TAG_LEN` zero bytes.
            pub fn new(salt: &[u8], ikm: &[u8]) -> Self {
                let zero_salt = [0u8; Self::TAG_LEN];
                let key: &[u8] = if salt.is_empty() { &zero_salt } else { salt };
                Self {
                    prk: $hmac::oneshot(key, ikm),
                }
            }

            /// Returns the pseudorandom key held by this instance.
            pub fn prk(&self) -> &[u8; Self::TAG_LEN] {
                &self.prk
            }

            /// Builds an instance from an existing PRK, skipping the extract step.
            ///
            /// # Panics
            ///
            /// Panics if `prk_in.len() != Self::TAG_LEN`.
            pub fn from_prk(prk_in: &[u8]) -> Self {
                assert_eq!(prk_in.len(), Self::TAG_LEN);
                let mut hkdf = Self {
                    prk: [0u8; Self::TAG_LEN],
                };
                hkdf.prk.copy_from_slice(prk_in);
                hkdf
            }

            /// Expand step with a single `info` value; fills all of `okm`.
            ///
            /// # Panics
            ///
            /// Panics if `okm.len() > 255 * Self::TAG_LEN`.
            pub fn expand(&self, info: &[u8], okm: &mut [u8]) {
                let components = [info];
                self.expand_multi_info(&components, okm)
            }

            /// Expand step where `info` is supplied in pieces.
            ///
            /// The pieces are fed to the HMAC one after another, in order, so the
            /// caller does not have to concatenate them into one buffer first.
            /// Output block `i` (counting from 1) is the HMAC, keyed with the PRK,
            /// of: the previous block (nothing for the first block), then every
            /// info piece, then the single byte `i`. The last block is truncated
            /// to whatever space remains in `okm`.
            ///
            /// # Panics
            ///
            /// Panics if `okm.len() > 255 * Self::TAG_LEN`, because the block
            /// counter is a single byte.
            pub fn expand_multi_info(&self, info_components: &[&[u8]], okm: &mut [u8]) {
                assert!(okm.len() <= Self::TAG_LEN * 255);

                // First block: there is no previous block to feed in. It is
                // computed even when `okm` is empty (nothing is copied then).
                let mut block = {
                    let mut mac = $hmac::new(&self.prk);
                    for part in info_components {
                        mac.update(part);
                    }
                    mac.update(&[1]);
                    mac.finalize()
                };
                let head = core::cmp::min(okm.len(), block.len());
                okm[..head].copy_from_slice(&block[..head]);

                // Remaining blocks: each one chains in the previous block. The
                // length assertion above bounds the number of chunks so that
                // `counter` never exceeds 255.
                let mut counter = 1u8;
                for chunk in okm[head..].chunks_mut(Self::TAG_LEN) {
                    counter += 1;
                    let mut mac = $hmac::new(&self.prk);
                    mac.update(&block);
                    for part in info_components {
                        mac.update(part);
                    }
                    mac.update(&[counter]);
                    block = mac.finalize();
                    // Only the final chunk can be shorter than a full block.
                    chunk.copy_from_slice(&block[..chunk.len()]);
                }
            }

            /// Convenience wrapper: extract with `salt`/`ikm`, then expand with
            /// `info` into `okm`.
            ///
            /// # Panics
            ///
            /// Panics if `okm.len() > 255 * Self::TAG_LEN`.
            pub fn oneshot(salt: &[u8], ikm: &[u8], info: &[u8], okm: &mut [u8]) {
                Self::new(salt, ikm).expand(info, okm);
            }
        }
    };
}
// Concrete HKDF types: one per HMAC implementation imported from `crate::mac`.
// Each invocation expands to a `pub struct` with the constants `BLOCK_LEN` /
// `TAG_LEN` and the methods `new`, `prk`, `from_prk`, `expand`,
// `expand_multi_info` and `oneshot`.
impl_hkdf_with_hmac!(HkdfMd2, HmacMd2);
impl_hkdf_with_hmac!(HkdfMd4, HmacMd4);
impl_hkdf_with_hmac!(HkdfMd5, HmacMd5);
impl_hkdf_with_hmac!(HkdfSm3, HmacSm3);
impl_hkdf_with_hmac!(HkdfSha1, HmacSha1);
impl_hkdf_with_hmac!(HkdfSha224, HmacSha224);
impl_hkdf_with_hmac!(HkdfSha256, HmacSha256);
impl_hkdf_with_hmac!(HkdfSha384, HmacSha384);
impl_hkdf_with_hmac!(HkdfSha512, HmacSha512);
/// Test helper: decodes a hex string into bytes.
///
/// Every `0x` occurrence is removed first, then all spaces, `\n` and `\r`
/// characters are dropped before the remaining text is handed to `hex::decode`.
///
/// # Panics
///
/// Panics if the cleaned-up text is rejected by `hex::decode`.
#[cfg(test)]
fn hexdecode(s: &str) -> Vec<u8> {
    let without_prefix = s.replace("0x", "");
    let digits: String = without_prefix
        .chars()
        .filter(|c| ![' ', '\n', '\r'].contains(c))
        .collect();
    hex::decode(&digits).unwrap()
}
/// Known-answer tests for `HkdfSha256` and `HkdfSha1`.
///
/// The vectors match test cases 1-7 of RFC 5869, Appendix A.
#[test]
fn test_hkdf() {
    type OneshotFn = fn(&[u8], &[u8], &[u8], &mut [u8]);

    // Derives `len` bytes with the given one-shot function and returns them hex-encoded.
    fn derive_hex(kdf: OneshotFn, salt: &[u8], ikm: &[u8], info: &[u8], len: usize) -> String {
        let mut okm = vec![0u8; len];
        kdf(salt, ikm, info, &mut okm);
        hex::encode(&okm)
    }

    let empty: [u8; 0] = [];

    // Case 1: SHA-256, basic inputs. Exercises the two-step API and checks the PRK too.
    {
        let ikm = hexdecode("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b");
        let salt = hexdecode("000102030405060708090a0b0c");
        let info = hexdecode("f0f1f2f3f4f5f6f7f8f9");
        assert_eq!(ikm.len(), 22);
        assert_eq!(salt.len(), 13);
        assert_eq!(info.len(), 10);

        let hkdf = HkdfSha256::new(&salt, &ikm);
        let expected_prk =
            hex::decode("077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5")
                .unwrap();
        assert_eq!(&hkdf.prk()[..], &expected_prk[..]);

        let mut okm = vec![0u8; 42];
        hkdf.expand(&info, &mut okm);
        let expected_okm = hex::decode(
            "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865",
        )
        .unwrap();
        assert_eq!(&okm[..], &expected_okm[..]);
    }

    // Long inputs shared by the SHA-256 case 2 and the SHA-1 case 5.
    let long_ikm = hexdecode(
        "0x000102030405060708090a0b0c0d0e0f\
         101112131415161718191a1b1c1d1e1f\
         202122232425262728292a2b2c2d2e2f\
         303132333435363738393a3b3c3d3e3f\
         404142434445464748494a4b4c4d4e4f",
    );
    let long_salt = hexdecode(
        "0x606162636465666768696a6b6c6d6e6f\
         707172737475767778797a7b7c7d7e7f\
         808182838485868788898a8b8c8d8e8f\
         909192939495969798999a9b9c9d9e9f\
         a0a1a2a3a4a5a6a7a8a9aaabacadaeaf",
    );
    let long_info = hexdecode(
        "0xb0b1b2b3b4b5b6b7b8b9babbbcbdbebf\
         c0c1c2c3c4c5c6c7c8c9cacbcccdcecf\
         d0d1d2d3d4d5d6d7d8d9dadbdcdddedf\
         e0e1e2e3e4e5e6e7e8e9eaebecedeeef\
         f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
    );

    // Case 2: SHA-256, longer inputs and output.
    assert_eq!(
        derive_hex(HkdfSha256::oneshot, &long_salt, &long_ikm, &long_info, 82),
        "b11e398dc80327a1c8e7f78c596a4934\
         4f012eda2d4efad8a050cc4c19afa97c\
         59045a99cac7827271cb41c65e590e09\
         da3275600c2f09b8367793a9aca3db71\
         cc30c58179ec3e87c14c01d5c1f3434f\
         1d87"
    );

    // Case 3: SHA-256, empty salt and empty info.
    let ikm = hexdecode("0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b");
    assert_eq!(
        derive_hex(HkdfSha256::oneshot, &empty, &ikm, &empty, 42),
        "8da4e775a563c18f715f802a063c5a31\
         b8a11f5c5ee1879ec3454e5f3c738d2d\
         9d201395faa4b61a96c8"
    );

    // Case 4: SHA-1, basic inputs.
    let ikm = hexdecode("0x0b0b0b0b0b0b0b0b0b0b0b");
    let salt = hexdecode("0x000102030405060708090a0b0c");
    let info = hexdecode("0xf0f1f2f3f4f5f6f7f8f9");
    assert_eq!(
        derive_hex(HkdfSha1::oneshot, &salt, &ikm, &info, 42),
        "085a01ea1b10f36933068b56efa5ad81\
         a4f14b822f5b091568a9cdd4f155fda2\
         c22e422478d305f3f896"
    );

    // Case 5: SHA-1, longer inputs and output.
    assert_eq!(
        derive_hex(HkdfSha1::oneshot, &long_salt, &long_ikm, &long_info, 82),
        "0bd770a74d1160f7c9f12cd5912a06eb\
         ff6adcae899d92191fe4305673ba2ffe\
         8fa3f1a4e5ad79f3f334b3b202b2173c\
         486ea37ce3d397ed034c7f9dfeb15c5e\
         927336d0441f4c4300e2cff0d0900b52\
         d3b4"
    );

    // Case 6: SHA-1, empty salt and empty info.
    let ikm = hexdecode("0x0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b");
    assert_eq!(
        derive_hex(HkdfSha1::oneshot, &empty, &ikm, &empty, 42),
        "0ac1af7002b3d761d1e55298da9d0506\
         b9ae52057220a306e07b6b87e8df21d0\
         ea00033de03984d34918"
    );

    // Case 7: SHA-1, empty salt and info, different input key material.
    let ikm = hexdecode("0x0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c");
    assert_eq!(
        derive_hex(HkdfSha1::oneshot, &empty, &ikm, &empty, 42),
        "2c91117204d745f3500d636a62f64f0a\
         b3bae548aa53d423b0d1f27ebba6f5e5\
         673a081d70cce7acfc48"
    );
}