use hkdf::Hkdf;
use sha2::Sha256;
use thiserror::Error;
/// Size in bytes of a SHA-256 digest (`HashLen` in RFC 5869).
const HASH_LEN: usize = 32;
/// Largest output RFC 5869 allows from a single expand call: `255 * HashLen`,
/// i.e. 8160 bytes for SHA-256.
const MAX_OUTPUT_LEN: usize = HASH_LEN * 255;
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum HkdfError {
#[error("hkdf-sha256: requested length {0} exceeds the {MAX_OUTPUT_LEN}-byte maximum")]
OutputTooLong(usize),
}
/// Derives `length` bytes of output keying material with HKDF-SHA256 (RFC 5869).
///
/// This is the one-shot form: the pseudorandom key produced by [`extract`]
/// from `salt` and `ikm` is handed straight to [`expand`] together with
/// `info`. Passing an empty `salt` gives the same result as a salt of 32 zero
/// bytes.
///
/// # Errors
///
/// Returns [`HkdfError::OutputTooLong`] when `length` exceeds 8160 bytes
/// (`255 * 32`).
pub fn hkdf_sha256(
    ikm: &[u8],
    salt: &[u8],
    info: &[u8],
    length: usize,
) -> Result<Vec<u8>, HkdfError> {
    expand(&extract(salt, ikm), info, length)
}
/// HKDF-Extract step: condenses `ikm` into a fixed-size pseudorandom key,
/// using `salt` as the HMAC-SHA256 key.
///
/// The returned PRK is meant to be fed to [`expand`]. An empty `salt` is
/// equivalent to 32 zero bytes, because HMAC zero-pads short keys.
#[must_use]
pub fn extract(salt: &[u8], ikm: &[u8]) -> [u8; HASH_LEN] {
    // The crate also hands back a ready-to-use expander; only the raw PRK is
    // exposed here, so that half of the tuple is ignored.
    let (pseudo_random_key, _) = Hkdf::<Sha256>::extract(Some(salt), ikm);
    pseudo_random_key.into()
}
/// HKDF-Expand step: stretches the pseudorandom key `prk` into `length` bytes
/// of output keying material bound to the context string `info`.
///
/// # Errors
///
/// Returns [`HkdfError::OutputTooLong`] when `length` is above 8160 bytes
/// (`255 * 32`), the RFC 5869 ceiling for SHA-256. The bound is checked before
/// the output buffer is allocated, so oversized requests never allocate.
pub fn expand(prk: &[u8; HASH_LEN], info: &[u8], length: usize) -> Result<Vec<u8>, HkdfError> {
    if length > MAX_OUTPUT_LEN {
        return Err(HkdfError::OutputTooLong(length));
    }

    let mut output = vec![0u8; length];
    // Neither `expect` can fire: the PRK type fixes its size at the digest
    // length, and `length` was bounded by the guard above.
    Hkdf::<Sha256>::from_prk(prk)
        .expect("a 32-byte PRK is always a valid SHA-256 PRK")
        .expand(info, &mut output)
        .expect("length is bounded above by the RFC 5869 maximum checked above");
    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only hex decoder (no hex crate available); panics on bad input.
    fn hex(s: &str) -> Vec<u8> {
        assert_eq!(s.len() % 2, 0, "odd-length hex string");
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).expect("invalid hex digit"))
            .collect()
    }

    /// RFC 5869 Appendix A.1: basic SHA-256 known-answer test, checking both
    /// the extracted PRK and the expanded OKM.
    #[test]
    fn rfc5869_test_case_1() {
        let ikm = [0x0b_u8; 22];
        let salt = hex("000102030405060708090a0b0c");
        let info = hex("f0f1f2f3f4f5f6f7f8f9");

        let prk = extract(&salt, &ikm);
        let expected_prk =
            hex("077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5");
        assert_eq!(prk.to_vec(), expected_prk);

        let expected_okm = hex(
            "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865",
        );
        assert_eq!(expand(&prk, &info, 42).unwrap(), expected_okm);
        assert_eq!(hkdf_sha256(&ikm, &salt, &info, 42).unwrap(), expected_okm);
    }

    /// RFC 5869 Appendix A.3: SHA-256 with zero-length salt and info.
    #[test]
    fn rfc5869_test_case_3_empty_salt_and_info() {
        let ikm = [0x0b_u8; 22];

        let prk = extract(&[], &ikm);
        let expected_prk =
            hex("19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04");
        assert_eq!(prk.to_vec(), expected_prk);

        let expected_okm = hex(
            "8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4b61a96c8",
        );
        assert_eq!(hkdf_sha256(&ikm, &[], &[], 42).unwrap(), expected_okm);
    }

    #[test]
    fn extract_then_expand_equals_combined() {
        let ikm = b"the quick brown fox";
        let salt = b"some salt";
        let info = b"context label";
        let combined = hkdf_sha256(ikm, salt, info, 40).unwrap();
        let prk = extract(salt, ikm);
        let staged = expand(&prk, info, 40).unwrap();
        assert_eq!(combined, staged);
    }

    #[test]
    fn empty_salt_is_the_zero_salt_case() {
        let ikm = b"input keying material";
        let info = b"info";
        let with_empty = hkdf_sha256(ikm, &[], info, 32).unwrap();
        let with_zeros = hkdf_sha256(ikm, &[0u8; 32], info, 32).unwrap();
        assert_eq!(with_empty, with_zeros);
    }

    #[test]
    fn output_length_is_respected() {
        let okm = hkdf_sha256(b"ikm", b"salt", b"info", 7).unwrap();
        assert_eq!(okm.len(), 7);
        let okm = hkdf_sha256(b"ikm", b"salt", b"info", 0).unwrap();
        assert_eq!(okm.len(), 0);
    }

    #[test]
    fn rejects_over_long_output() {
        let too_long = MAX_OUTPUT_LEN + 1;
        assert_eq!(
            hkdf_sha256(b"ikm", b"salt", b"info", too_long),
            Err(HkdfError::OutputTooLong(too_long)),
        );
        // The exact limit itself is still accepted.
        assert!(hkdf_sha256(b"ikm", b"salt", b"info", MAX_OUTPUT_LEN).is_ok());
    }
}