use cypher::Digest;
use crate::ChainingKey;
/// Computes `HMAC-HASH(key, inputs[0] || inputs[1] || ...)` as defined by RFC 2104,
/// using the digest `D`.
///
/// A key longer than `D::BLOCK_LEN` is replaced by its digest. A shorter key is
/// implicitly zero-padded: pad bytes beyond the key length keep their bare
/// `0x36`/`0x5c` value.
///
/// The pads live in fixed 128-byte buffers that are sliced to `D::BLOCK_LEN`.
/// A digest with a block size above 128 bytes would therefore panic here.
fn hmac_hash<D: Digest>(
    key: impl AsRef<[u8]>,
    inputs: impl IntoIterator<Item = impl AsRef<[u8]>>,
) -> D::Output {
    let raw_key = key.as_ref();

    // Over-long keys are shortened by hashing them (RFC 2104, section 2).
    let key_digest;
    let key_bytes: &[u8] = if raw_key.len() > D::BLOCK_LEN {
        key_digest = D::digest(raw_key);
        key_digest.as_ref()
    } else {
        raw_key
    };

    let mut inner_pad = [0x36u8; 128];
    let mut outer_pad = [0x5cu8; 128];
    for ((inner_byte, outer_byte), &key_byte) in inner_pad
        .iter_mut()
        .zip(outer_pad.iter_mut())
        .zip(key_bytes)
    {
        *inner_byte ^= key_byte;
        *outer_byte ^= key_byte;
    }

    // Inner hash: H((K ^ ipad) || message).
    let mut inner = D::new();
    inner.input(&inner_pad[..D::BLOCK_LEN]);
    for chunk in inputs {
        inner.input(chunk);
    }
    let inner_hash = inner.finalize();

    // Outer hash: H((K ^ opad) || inner_hash).
    let mut outer = D::new();
    outer.input(&outer_pad[..D::BLOCK_LEN]);
    outer.input(inner_hash.as_ref());
    outer.finalize()
}
/// Shared core of the Noise `HKDF` function.
///
/// Returns `(temp_key, output1, output2)` where:
/// - `temp_key = HMAC-HASH(chaining_key, input_material)`
/// - `output1  = HMAC-HASH(temp_key, 0x01)`
/// - `output2  = HMAC-HASH(temp_key, output1 || 0x02)`
///
/// `temp_key` is handed back so that `hkdf_3` can derive a third output from it.
fn _hkdf<D: Digest>(
    chaining_key: ChainingKey<D>,
    input_material: impl AsRef<[u8]>,
) -> (D::Output, D::Output, D::Output) {
    let prk = hmac_hash::<D>(chaining_key, [input_material]);

    let first_counter: &[u8] = &[0x01];
    let first = hmac_hash::<D>(prk.as_ref(), [first_counter]);

    let second_counter: &[u8] = &[0x02];
    let second = hmac_hash::<D>(prk.as_ref(), [first.as_ref(), second_counter]);

    (prk, first, second)
}
/// Noise `HKDF(chaining_key, input_key_material, 2)`: derives two hash-length outputs.
///
/// Returns `(output1, output2)` in derivation order. The intermediate `temp_key`
/// is not exposed.
pub(crate) fn hkdf_2<D: Digest>(
    chaining_key: ChainingKey<D>,
    input_material: impl AsRef<[u8]>,
) -> (D::Output, D::Output) {
    let derived = _hkdf::<D>(chaining_key, input_material);
    (derived.1, derived.2)
}
/// Noise `HKDF(chaining_key, input_key_material, 3)`: derives three hash-length outputs.
///
/// `output1` and `output2` come from the shared extract/expand step in `_hkdf`, and
/// `output3 = HMAC-HASH(temp_key, output2 || 0x03)`.
///
/// Returns `(output1, output2, output3)` in derivation order, as the Noise
/// specification requires. Callers destructure this tuple positionally, so any
/// reordering would silently hand them the wrong key material.
pub(crate) fn hkdf_3<D: Digest>(
    chaining_key: ChainingKey<D>,
    input_material: impl AsRef<[u8]>,
) -> (D::Output, D::Output, D::Output) {
    let (temp_key, output1, output2) = _hkdf::<D>(chaining_key, input_material);
    let output3 = hmac_hash::<D>(temp_key, [output2.as_ref(), &[3][..]]);
    (output1, output2, output3)
}
#[cfg(test)]
mod test {
    use amplify::hex::FromHex;
    use cypher::Sha256;

    use super::hkdf_2;

    /// RFC 5869 Appendix A.3: SHA-256 with an empty salt and empty info.
    ///
    /// An empty HMAC salt is equivalent to an all-zero key of hash length, so a
    /// zeroed 32-byte chaining key reproduces the RFC's extract step. Noise's
    /// `HKDF` uses empty info, so `output1 || output2` is the RFC's OKM prefix.
    #[test]
    fn rfc_5869_test_vector_3() {
        let ikm = Vec::<u8>::from_hex("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b").unwrap();
        let expected_okm = Vec::<u8>::from_hex("8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4b61a96c8").unwrap();

        let (t1, t2) = hkdf_2::<Sha256>([0u8; 32], &ikm);
        let okm: Vec<u8> = t1
            .iter()
            .chain(t2.iter())
            .copied()
            .take(expected_okm.len())
            .collect();

        assert_eq!(okm, expected_okm);
    }
}