#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use super::hmac_sm3;
use super::DIGEST_LEN;
#[cfg(feature = "alloc")]
use crate::error::Error;
/// HKDF-Extract step (RFC 5869 §2.2) instantiated with HMAC-SM3.
///
/// Computes `PRK = HMAC-SM3(salt, ikm)` and returns it as a fixed-size array.
/// When `salt` is `None`, an all-zero key of `DIGEST_LEN` bytes is used in its
/// place, so `hkdf_extract(None, ikm) == hkdf_extract(Some(&[0; DIGEST_LEN]), ikm)`.
pub fn hkdf_extract(salt: Option<&[u8]>, ikm: &[u8]) -> [u8; DIGEST_LEN] {
    const ZERO_SALT: [u8; DIGEST_LEN] = [0u8; DIGEST_LEN];
    match salt {
        Some(key) => hmac_sm3(key, ikm),
        None => hmac_sm3(&ZERO_SALT, ikm),
    }
}
#[cfg(feature = "alloc")]
/// HKDF-Expand step (RFC 5869 §2.3) instantiated with HMAC-SM3.
///
/// Derives `len` bytes of output keying material from the pseudorandom key
/// `prk` and the context string `info`:
///
/// ```text
/// T(0) = empty
/// T(i) = HMAC-SM3(prk, T(i-1) || info || i)   for i = 1..=ceil(len / DIGEST_LEN)
/// OKM  = first `len` bytes of T(1) || T(2) || ...
/// ```
///
/// `info` may be of any length, and every byte of it is fed into the HMAC.
/// Contexts that differ anywhere therefore yield different output. A
/// `len` of 0 returns an empty vector.
///
/// # Errors
///
/// Returns [`Error::InvalidInputLength`] if `len > 255 * DIGEST_LEN`. The
/// block counter `i` is a single byte, so at most 255 blocks can be produced.
pub fn hkdf_expand(prk: &[u8; DIGEST_LEN], info: &[u8], len: usize) -> Result<Vec<u8>, Error> {
    const MAX_LEN: usize = 255 * DIGEST_LEN;
    if len > MAX_LEN {
        return Err(Error::InvalidInputLength);
    }

    // The check above bounds `rounds` to at most 255. The cast is therefore
    // lossless, and the `u8` counter in the inclusive range below never wraps.
    let rounds = len.div_ceil(DIGEST_LEN) as u8;

    let mut okm = Vec::with_capacity(usize::from(rounds) * DIGEST_LEN);
    // One scratch buffer reused by every round, sized for the largest message
    // (T(i-1) || info || counter). It holds the full `info`, not a truncated
    // prefix.
    let mut input = Vec::with_capacity(DIGEST_LEN + info.len() + 1);

    for counter in 1..=rounds {
        input.clear();
        // T(i-1) is the block most recently appended to `okm`. For T(1),
        // `okm` is empty, so the slice is empty too.
        input.extend_from_slice(&okm[okm.len().saturating_sub(DIGEST_LEN)..]);
        input.extend_from_slice(info);
        input.push(counter);
        let block = hmac_sm3(prk, input.as_slice());
        okm.extend_from_slice(&block);
    }

    // The last block may overshoot `len`; keep only the requested prefix.
    okm.truncate(len);
    Ok(okm)
}
/// One-shot HKDF-SM3: HKDF-Extract followed by HKDF-Expand (RFC 5869).
///
/// Equivalent to `hkdf_expand(&hkdf_extract(salt, ikm), info, len)`. A `None`
/// salt is treated by [`hkdf_extract`] as an all-zero key of `DIGEST_LEN` bytes.
///
/// # Errors
///
/// Propagates the error from [`hkdf_expand`]: [`Error::InvalidInputLength`] when
/// `len` exceeds `255 * DIGEST_LEN`.
#[cfg(feature = "alloc")]
pub fn hkdf(salt: Option<&[u8]>, ikm: &[u8], info: &[u8], len: usize) -> Result<Vec<u8>, Error> {
    hkdf_expand(&hkdf_extract(salt, ikm), info, len)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hkdf_extract_deterministic() {
        let salt = b"test-salt";
        let ikm = b"input-key-material";
        let prk1 = hkdf_extract(Some(salt), ikm);
        let prk2 = hkdf_extract(Some(salt), ikm);
        assert_eq!(prk1, prk2);
        assert_eq!(prk1.len(), DIGEST_LEN);
    }

    #[test]
    fn test_hkdf_extract_none_salt_equals_zero_salt() {
        let ikm = b"some ikm";
        let zeros = [0u8; DIGEST_LEN];
        let prk_none = hkdf_extract(None, ikm);
        let prk_zero = hkdf_extract(Some(&zeros), ikm);
        assert_eq!(prk_none, prk_zero);
    }

    #[test]
    fn test_hkdf_extract_salt_changes_prk() {
        let ikm = b"same ikm";
        let prk_a = hkdf_extract(Some(&b"salt-a"[..]), ikm);
        let prk_b = hkdf_extract(Some(&b"salt-b"[..]), ikm);
        assert_ne!(prk_a, prk_b);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_hkdf_expand_length() {
        let prk = [0x42u8; DIGEST_LEN];
        let info = b"test-info";
        assert_eq!(hkdf_expand(&prk, info, 16).unwrap().len(), 16);
        assert_eq!(hkdf_expand(&prk, info, DIGEST_LEN).unwrap().len(), DIGEST_LEN);
        assert_eq!(hkdf_expand(&prk, info, 48).unwrap().len(), 48);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_hkdf_expand_zero_len() {
        let prk = [0x44u8; DIGEST_LEN];
        assert!(hkdf_expand(&prk, b"info", 0).unwrap().is_empty());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_hkdf_expand_deterministic() {
        let prk = [0x11u8; DIGEST_LEN];
        let info = b"ctx";
        let out1 = hkdf_expand(&prk, info, DIGEST_LEN).unwrap();
        let out2 = hkdf_expand(&prk, info, DIGEST_LEN).unwrap();
        assert_eq!(out1, out2);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_hkdf_expand_prefix_consistency() {
        let prk = [0x22u8; DIGEST_LEN];
        let info = b"prefix-test";
        let long = hkdf_expand(&prk, info, 2 * DIGEST_LEN).unwrap();
        // Block-aligned and non-aligned shorter outputs are both prefixes.
        for short_len in [DIGEST_LEN, DIGEST_LEN + 8] {
            let short = hkdf_expand(&prk, info, short_len).unwrap();
            assert_eq!(&long[..short_len], &short[..]);
        }
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_hkdf_expand_max_len_rejected() {
        let prk = [0u8; DIGEST_LEN];
        let result = hkdf_expand(&prk, b"", 255 * DIGEST_LEN + 1);
        assert!(result.is_err());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_hkdf_expand_max_len_accepted() {
        let prk = [0u8; DIGEST_LEN];
        let result = hkdf_expand(&prk, b"", 255 * DIGEST_LEN);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 255 * DIGEST_LEN);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_hkdf_different_info_different_output() {
        let prk = [0x33u8; DIGEST_LEN];
        let out1 = hkdf_expand(&prk, b"info-a", DIGEST_LEN).unwrap();
        let out2 = hkdf_expand(&prk, b"info-b", DIGEST_LEN).unwrap();
        assert_ne!(out1, out2);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_hkdf_roundtrip_salt_info() {
        let salt = b"tls13-early-secret-salt";
        let ikm = b"shared-secret-from-key-exchange";
        let prk = hkdf_extract(Some(salt), ikm);
        let key1 = hkdf_expand(&prk, b"tls13 key", 16).unwrap();
        let key2 = hkdf_expand(&prk, b"tls13 iv", 12).unwrap();
        assert_eq!(key1.len(), 16);
        assert_eq!(key2.len(), 12);
        assert_ne!(&key1[..12], &key2[..]);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_hkdf_matches_extract_then_expand() {
        let salt: &[u8] = b"one-shot-salt";
        let ikm = b"one-shot-ikm";
        let info = b"one-shot-info";
        let prk = hkdf_extract(Some(salt), ikm);
        let composed = hkdf_expand(&prk, info, 42).unwrap();
        let one_shot = hkdf(Some(salt), ikm, info, 42).unwrap();
        assert_eq!(one_shot, composed);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_hkdf_rejects_oversized_len() {
        assert!(hkdf(None, b"ikm", b"", 255 * DIGEST_LEN + 1).is_err());
    }
}