1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::hmac::hmac_sha256;
5
/// Errors returned by [`hkdf`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HkdfError {
    /// The requested output key length is not accepted (for example, zero).
    InvalidLength,
    /// The input keying material was empty.
    EmptyInput,
}
11
12impl fmt::Display for HkdfError {
13 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
14 match self {
15 HkdfError::InvalidLength => write!(f, "Invalid output key length"),
16 HkdfError::EmptyInput => write!(f, "Cannot derive key from empty input material"),
17 }
18 }
19}
20
21pub fn hkdf(
24 length: usize,
25 derive_from: &[u8],
26 salt: Option<&[u8]>,
27 context: Option<&[u8]>,
28) -> Result<Vec<u8>, HkdfError> {
29 let hash_len: usize = 32;
30
31 if length < 1 {
32 return Err(HkdfError::InvalidLength);
33 }
34
35 if derive_from.is_empty() {
36 return Err(HkdfError::EmptyInput);
37 }
38
39 let salt = match salt {
40 Some(s) if !s.is_empty() => s.to_vec(),
41 _ => alloc::vec![0u8; hash_len],
42 };
43
44 let context = context.unwrap_or(b"");
45
46 let prk = hmac_sha256(&salt, derive_from);
48
49 let mut block: Vec<u8> = Vec::new();
51 let mut derived = Vec::with_capacity(length);
52
53 let iterations = (length + hash_len - 1) / hash_len;
54 for i in 0..iterations {
55 let mut input = Vec::new();
56 input.extend_from_slice(&block);
57 input.extend_from_slice(context);
58 input.push(((i + 1) % 256) as u8);
59
60 block = hmac_sha256(&prk, &input).to_vec();
61 derived.extend_from_slice(&block);
62 }
63
64 derived.truncate(length);
65 Ok(derived)
66}
67
// Unit tests for `hkdf`: output lengths, salt/context handling, argument
// validation, and RFC 5869 known-answer vectors for HMAC-SHA-256.
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a hex string (even length, hex digits only) into bytes.
    fn hex(s: &str) -> Vec<u8> {
        assert!(s.len() % 2 == 0, "hex string must have even length");
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).expect("invalid hex digit"))
            .collect()
    }

    #[test]
    fn test_hkdf_32bytes() {
        let ikm = b"input key material";
        let salt = b"salt value";
        let result = hkdf(32, ikm, Some(&salt[..]), None).unwrap();
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn test_hkdf_64bytes() {
        let ikm = b"input key material";
        let salt = b"salt value";
        let result = hkdf(64, ikm, Some(&salt[..]), None).unwrap();
        assert_eq!(result.len(), 64);
    }

    #[test]
    fn test_hkdf_with_context() {
        let ikm = b"input key material";
        let salt = b"salt";
        let ctx = b"context info";
        let result = hkdf(32, ikm, Some(&salt[..]), Some(&ctx[..])).unwrap();
        assert_eq!(result.len(), 32);
        let result2 = hkdf(32, ikm, Some(&salt[..]), None).unwrap();
        assert_ne!(result, result2);
    }

    #[test]
    fn test_hkdf_none_salt() {
        let ikm = b"input key material";
        let result = hkdf(32, ikm, None, None).unwrap();
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn test_hkdf_empty_salt() {
        let ikm = b"input key material";
        let result1 = hkdf(32, ikm, Some(&b""[..]), None).unwrap();
        let result2 = hkdf(32, ikm, None, None).unwrap();
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_hkdf_invalid_length() {
        assert_eq!(hkdf(0, b"ikm", None, None), Err(HkdfError::InvalidLength));
    }

    #[test]
    fn test_hkdf_empty_ikm() {
        assert_eq!(hkdf(32, b"", None, None), Err(HkdfError::EmptyInput));
    }

    /// RFC 5869, Appendix A.1: basic test case with SHA-256.
    #[test]
    fn test_hkdf_rfc5869_case1() {
        let ikm = [0x0bu8; 22];
        let salt = hex("000102030405060708090a0b0c");
        let info = hex("f0f1f2f3f4f5f6f7f8f9");
        let okm = hkdf(42, &ikm, Some(salt.as_slice()), Some(info.as_slice())).unwrap();
        assert_eq!(
            okm,
            hex("3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865")
        );
    }

    /// RFC 5869, Appendix A.3: SHA-256 with zero-length salt and info.
    #[test]
    fn test_hkdf_rfc5869_case3() {
        let ikm = [0x0bu8; 22];
        let okm = hkdf(42, &ikm, Some(&b""[..]), Some(&b""[..])).unwrap();
        assert_eq!(
            okm,
            hex("8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4b61a96c8")
        );
    }
}