use alloc::vec;
use alloc::vec::Vec;
use hkdf::Hkdf;
use sha2::{Sha256, Sha512};
use crate::error::{Error, Result};
/// Largest output length, in bytes, that HKDF-SHA256 may produce.
///
/// RFC 5869 limits the expand output to `255 * HashLen`; a SHA-256 digest is 32 bytes.
pub const HKDF_MAX_OUTPUT_SHA256: usize = 32 * 255;
/// Largest output length, in bytes, that HKDF-SHA512 may produce.
///
/// RFC 5869 limits the expand output to `255 * HashLen`; a SHA-512 digest is 64 bytes.
pub const HKDF_MAX_OUTPUT_SHA512: usize = 64 * 255;
pub fn hkdf_sha256(ikm: &[u8], salt: Option<&[u8]>, info: &[u8], len: usize) -> Result<Vec<u8>> {
if len > HKDF_MAX_OUTPUT_SHA256 {
return Err(Error::Kdf("hkdf-sha256 output > 255 * 32 bytes"));
}
let hk = Hkdf::<Sha256>::new(salt, ikm);
let mut out = vec![0u8; len];
hk.expand(info, &mut out)
.map_err(|_| Error::Kdf("hkdf-sha256 expand"))?;
Ok(out)
}
pub fn hkdf_sha512(ikm: &[u8], salt: Option<&[u8]>, info: &[u8], len: usize) -> Result<Vec<u8>> {
if len > HKDF_MAX_OUTPUT_SHA512 {
return Err(Error::Kdf("hkdf-sha512 output > 255 * 64 bytes"));
}
let hk = Hkdf::<Sha512>::new(salt, ikm);
let mut out = vec![0u8; len];
hk.expand(info, &mut out)
.map_err(|_| Error::Kdf("hkdf-sha512 expand"))?;
Ok(out)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, unused_results)]
mod tests {
    use super::*;

    /// Decodes a hex test-vector string, panicking on malformed input.
    fn hex_to_bytes(s: &str) -> Vec<u8> {
        hex::decode(s).expect("valid hex")
    }

    /// RFC 5869 Test Case 1 (SHA-256, with salt and info).
    #[test]
    fn rfc5869_test_case_1_sha256() {
        let ikm = hex_to_bytes("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b");
        let salt = hex_to_bytes("000102030405060708090a0b0c");
        let info = hex_to_bytes("f0f1f2f3f4f5f6f7f8f9");
        let expected = hex_to_bytes(
            "3cb25f25faacd57a90434f64d0362f2a\
             2d2d0a90cf1a5a4c5db02d56ecc4c5bf\
             34007208d5b887185865",
        );
        let got = hkdf_sha256(&ikm, Some(&salt), &info, 42).unwrap();
        assert_eq!(got, expected);
    }

    /// RFC 5869 Test Case 3 (SHA-256, zero-length salt and info); `None` and an
    /// empty salt must both reproduce the published OKM.
    #[test]
    fn rfc5869_test_case_3_sha256_no_salt_no_info() {
        let ikm = hex_to_bytes("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b");
        let expected = hex_to_bytes(
            "8da4e775a563c18f715f802a063c5a31\
             b8a11f5c5ee1879ec3454e5f3c738d2d\
             9d201395faa4b61a96c8",
        );
        let got_none = hkdf_sha256(&ikm, None, &[], 42).unwrap();
        let got_empty = hkdf_sha256(&ikm, Some(&[]), &[], 42).unwrap();
        assert_eq!(got_none, expected);
        assert_eq!(got_empty, expected);
    }

    /// The SHA-512 wrapper must agree with calling the `hkdf` crate directly.
    #[test]
    fn hkdf_sha512_matches_upstream() {
        let ikm = hex_to_bytes("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b");
        let salt = hex_to_bytes("000102030405060708090a0b0c");
        let info = hex_to_bytes("f0f1f2f3f4f5f6f7f8f9");
        let got = hkdf_sha512(&ikm, Some(&salt), &info, 64).unwrap();
        let hk = Hkdf::<Sha512>::new(Some(&salt), &ikm);
        let mut expected = vec![0u8; 64];
        hk.expand(&info, &mut expected).unwrap();
        assert_eq!(got, expected);
        assert_eq!(got.len(), 64);
    }

    #[test]
    fn hkdf_sha256_zero_length_output() {
        let out = hkdf_sha256(&[0u8; 32], None, &[], 0).unwrap();
        assert!(out.is_empty());
    }

    /// Mirror of `hkdf_sha256_zero_length_output` for the SHA-512 variant.
    #[test]
    fn hkdf_sha512_zero_length_output() {
        let out = hkdf_sha512(&[0u8; 32], None, &[], 0).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn hkdf_sha256_max_length_accepted() {
        let out = hkdf_sha256(&[0u8; 32], None, &[], HKDF_MAX_OUTPUT_SHA256).unwrap();
        assert_eq!(out.len(), HKDF_MAX_OUTPUT_SHA256);
    }

    #[test]
    fn hkdf_sha256_over_max_rejected() {
        let err = hkdf_sha256(&[0u8; 32], None, &[], HKDF_MAX_OUTPUT_SHA256 + 1).unwrap_err();
        assert!(matches!(err, Error::Kdf(_)), "{err:?}");
    }

    #[test]
    fn hkdf_sha512_max_length_accepted() {
        let out = hkdf_sha512(&[0u8; 32], None, &[], HKDF_MAX_OUTPUT_SHA512).unwrap();
        assert_eq!(out.len(), HKDF_MAX_OUTPUT_SHA512);
    }

    #[test]
    fn hkdf_sha512_over_max_rejected() {
        let err = hkdf_sha512(&[0u8; 32], None, &[], HKDF_MAX_OUTPUT_SHA512 + 1).unwrap_err();
        assert!(matches!(err, Error::Kdf(_)), "{err:?}");
    }

    /// SHA-512 counterpart of the `None` vs empty-salt check in Test Case 3:
    /// HMAC zero-pads its key, so an absent salt and an empty salt must agree.
    #[test]
    fn hkdf_sha512_none_salt_matches_empty_salt() {
        let ikm = hex_to_bytes("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b");
        let got_none = hkdf_sha512(&ikm, None, &[], 42).unwrap();
        let got_empty = hkdf_sha512(&ikm, Some(&[]), &[], 42).unwrap();
        assert_eq!(got_none, got_empty);
    }

    /// HKDF-Expand's blocks T(i) do not depend on the requested length, so a
    /// shorter output must be a prefix of a longer one for identical inputs.
    #[test]
    fn shorter_output_is_prefix_of_longer() {
        let master = [0x42u8; 32];

        let short = hkdf_sha256(&master, Some(b"salt"), b"info", 16).unwrap();
        let long = hkdf_sha256(&master, Some(b"salt"), b"info", 100).unwrap();
        assert_eq!(short.len(), 16);
        assert!(long.starts_with(&short));

        let short = hkdf_sha512(&master, Some(b"salt"), b"info", 16).unwrap();
        let long = hkdf_sha512(&master, Some(b"salt"), b"info", 100).unwrap();
        assert_eq!(short.len(), 16);
        assert!(long.starts_with(&short));
    }

    #[test]
    fn different_info_produces_different_output() {
        let master = [0x42u8; 32];
        let a = hkdf_sha256(&master, None, b"info:a", 32).unwrap();
        let b = hkdf_sha256(&master, None, b"info:b", 32).unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn different_salt_produces_different_output() {
        let master = [0x42u8; 32];
        let a = hkdf_sha256(&master, Some(b"salt-a"), b"info", 32).unwrap();
        let b = hkdf_sha256(&master, Some(b"salt-b"), b"info", 32).unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn deterministic_in_inputs() {
        let master = [0x42u8; 32];
        let a = hkdf_sha256(&master, Some(b"salt"), b"info", 32).unwrap();
        let b = hkdf_sha256(&master, Some(b"salt"), b"info", 32).unwrap();
        assert_eq!(a, b);
    }
}