Skip to main content

cryptography/hash/
hkdf.rs

//! HKDF (RFC 5869) over the crate's digest/HMAC traits.
//!
//! This module exposes the two standard stages:
//! - extract: `PRK = HMAC(salt, IKM)`
//! - expand: `OKM = T(1) || T(2) || ...`, truncated to the requested length,
//!   where `T(0)` is empty and `T(i) = HMAC(PRK, T(i-1) || info || i)`
//!
//! The implementation is generic over any fixed-output digest `H` that
//! implements [`crate::hash::Digest`].
9
10use core::marker::PhantomData;
11
12use super::hmac::Hmac;
13use super::Digest;
14
15/// HKDF key schedule state holding one pseudorandom key (PRK).
16pub struct Hkdf<H: Digest> {
17    prk: Vec<u8>,
18    marker: PhantomData<H>,
19}
20
21impl<H: Digest> Hkdf<H> {
22    /// Extract one pseudorandom key from input keying material.
23    ///
24    /// If `salt` is `None`, RFC 5869 uses a digest-length all-zero salt.
25    #[must_use]
26    pub fn extract(salt: Option<&[u8]>, ikm: &[u8]) -> Self {
27        let zero_salt;
28        let salt = match salt {
29            Some(salt) => salt,
30            None => {
31                zero_salt = vec![0u8; H::OUTPUT_LEN];
32                &zero_salt
33            }
34        };
35        let prk = Hmac::<H>::compute(salt, ikm);
36        Self {
37            prk,
38            marker: PhantomData,
39        }
40    }
41
42    /// Build an HKDF state from a previously extracted PRK.
43    ///
44    /// RFC 5869 defines PRK as one digest-width string.
45    #[must_use]
46    pub fn from_prk(prk: &[u8]) -> Option<Self> {
47        if prk.len() != H::OUTPUT_LEN {
48            return None;
49        }
50        Some(Self {
51            prk: prk.to_vec(),
52            marker: PhantomData,
53        })
54    }
55
56    /// Return the extracted pseudorandom key.
57    #[must_use]
58    pub fn prk(&self) -> &[u8] {
59        &self.prk
60    }
61
62    /// Expand into `out` with caller-supplied context `info`.
63    ///
64    /// Returns `false` if `out` exceeds `255 * H::OUTPUT_LEN`, as required by
65    /// RFC 5869.
66    #[must_use]
67    pub fn expand(&self, info: &[u8], out: &mut [u8]) -> bool {
68        let max = 255usize
69            .checked_mul(H::OUTPUT_LEN)
70            .expect("digest output length should keep HKDF max bounded");
71        if out.len() > max {
72            return false;
73        }
74
75        let mut t = Vec::<u8>::new();
76        let mut generated = 0usize;
77        let mut counter = 1u8;
78
79        while generated < out.len() {
80            let mut data = Vec::with_capacity(t.len() + info.len() + 1);
81            data.extend_from_slice(&t);
82            data.extend_from_slice(info);
83            data.push(counter);
84
85            t = Hmac::<H>::compute(&self.prk, &data);
86            let take = core::cmp::min(out.len() - generated, t.len());
87            out[generated..generated + take].copy_from_slice(&t[..take]);
88            generated += take;
89            counter = counter.wrapping_add(1);
90        }
91
92        crate::ct::zeroize_slice(t.as_mut_slice());
93        true
94    }
95
96    /// Convenience one-shot HKDF (extract + expand).
97    #[must_use]
98    pub fn derive(salt: Option<&[u8]>, ikm: &[u8], info: &[u8], len: usize) -> Option<Vec<u8>> {
99        let hkdf = Self::extract(salt, ikm);
100        let mut out = vec![0u8; len];
101        if !hkdf.expand(info, &mut out) {
102            return None;
103        }
104        Some(out)
105    }
106}
107
108impl<H: Digest> Drop for Hkdf<H> {
109    fn drop(&mut self) {
110        crate::ct::zeroize_slice(self.prk.as_mut_slice());
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::Hkdf;
    use crate::{Sha1, Sha256};

    /// Lowercase hex encoding, used to compare against RFC 5869 vectors.
    fn hex(bytes: &[u8]) -> String {
        use core::fmt::Write;

        let mut out = String::with_capacity(bytes.len() * 2);
        for b in bytes {
            let _ = write!(&mut out, "{b:02x}");
        }
        out
    }

    /// Decode a hex string (test input only). Panics on non-hex digits and
    /// ignores a trailing odd nibble.
    fn unhex(input: &str) -> Vec<u8> {
        let mut out = Vec::with_capacity(input.len() / 2);
        let bytes = input.as_bytes();
        let mut i = 0usize;
        while i + 1 < bytes.len() {
            let hi = (bytes[i] as char).to_digit(16).expect("hex") as u8;
            let lo = (bytes[i + 1] as char).to_digit(16).expect("hex") as u8;
            out.push((hi << 4) | lo);
            i += 2;
        }
        out
    }

    #[test]
    fn rfc5869_case_1_sha256() {
        let ikm = vec![0x0b; 22];
        let salt = unhex("000102030405060708090a0b0c");
        let info = unhex("f0f1f2f3f4f5f6f7f8f9");

        let hkdf = Hkdf::<Sha256>::extract(Some(&salt), &ikm);
        assert_eq!(
            hex(hkdf.prk()),
            "077709362c2e32df0ddc3f0dc47bba63".to_owned() + "90b6c73bb50f9c3122ec844ad7c2b3e5"
        );

        let mut okm = vec![0u8; 42];
        assert!(hkdf.expand(&info, &mut okm));
        assert_eq!(
            hex(&okm),
            "3cb25f25faacd57a90434f64d0362f2a".to_owned()
                + "2d2d0a90cf1a5a4c5db02d56ecc4c5bf"
                + "34007208d5b887185865"
        );
    }

    #[test]
    fn rfc5869_case_2_sha256_long_inputs() {
        let ikm = unhex(
            "000102030405060708090a0b0c0d0e0f\
             101112131415161718191a1b1c1d1e1f\
             202122232425262728292a2b2c2d2e2f\
             303132333435363738393a3b3c3d3e3f\
             404142434445464748494a4b4c4d4e4f",
        );
        let salt = unhex(
            "606162636465666768696a6b6c6d6e6f\
             707172737475767778797a7b7c7d7e7f\
             808182838485868788898a8b8c8d8e8f\
             909192939495969798999a9b9c9d9e9f\
             a0a1a2a3a4a5a6a7a8a9aaabacadaeaf",
        );
        let info = unhex(
            "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf\
             c0c1c2c3c4c5c6c7c8c9cacbcccdcecf\
             d0d1d2d3d4d5d6d7d8d9dadbdcdddedf\
             e0e1e2e3e4e5e6e7e8e9eaebecedeeef\
             f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
        );

        let hkdf = Hkdf::<Sha256>::extract(Some(&salt), &ikm);
        assert_eq!(
            hex(hkdf.prk()),
            "06a6b88c5853361a06104c9ceb35b45c".to_owned() + "ef760014904671014a193f40c15fc244"
        );

        let mut okm = vec![0u8; 82];
        assert!(hkdf.expand(&info, &mut okm));
        assert_eq!(
            hex(&okm),
            "b11e398dc80327a1c8e7f78c596a4934".to_owned()
                + "4f012eda2d4efad8a050cc4c19afa97c"
                + "59045a99cac7827271cb41c65e590e09"
                + "da3275600c2f09b8367793a9aca3db71"
                + "cc30c58179ec3e87c14c01d5c1f3434f"
                + "1d87"
        );
    }

    #[test]
    fn rfc5869_case_3_sha256_zero_salt() {
        let ikm = vec![0x0b; 22];
        let info = [];
        let hkdf = Hkdf::<Sha256>::extract(None, &ikm);
        // RFC 5869 A.3 PRK; the defaulted zero salt must reproduce it.
        assert_eq!(
            hex(hkdf.prk()),
            "19ef24a32c717b167f33a91d6f648bdf".to_owned() + "96596776afdb6377ac434c1c293ccb04"
        );

        let mut okm = vec![0u8; 42];
        assert!(hkdf.expand(&info, &mut okm));
        assert_eq!(
            hex(&okm),
            "8da4e775a563c18f715f802a063c5a31".to_owned()
                + "b8a11f5c5ee1879ec3454e5f3c738d2d"
                + "9d201395faa4b61a96c8"
        );
    }

    #[test]
    fn rfc5869_case_4_sha1() {
        let ikm = unhex("0b0b0b0b0b0b0b0b0b0b0b");
        let salt = unhex("000102030405060708090a0b0c");
        let info = unhex("f0f1f2f3f4f5f6f7f8f9");

        let hkdf = Hkdf::<Sha1>::extract(Some(&salt), &ikm);
        assert_eq!(hex(hkdf.prk()), "9b6c18c432a7bf8f0e71c8eb88f4b30baa2ba243");

        let mut okm = vec![0u8; 42];
        assert!(hkdf.expand(&info, &mut okm));
        assert_eq!(
            hex(&okm),
            "085a01ea1b10f36933068b56efa5ad81".to_owned()
                + "a4f14b822f5b091568a9cdd4f155fda2"
                + "c22e422478d305f3f896"
        );
    }

    #[test]
    fn rfc5869_case_5_sha1_long_inputs() {
        let ikm = unhex(
            "000102030405060708090a0b0c0d0e0f\
             101112131415161718191a1b1c1d1e1f\
             202122232425262728292a2b2c2d2e2f\
             303132333435363738393a3b3c3d3e3f\
             404142434445464748494a4b4c4d4e4f",
        );
        let salt = unhex(
            "606162636465666768696a6b6c6d6e6f\
             707172737475767778797a7b7c7d7e7f\
             808182838485868788898a8b8c8d8e8f\
             909192939495969798999a9b9c9d9e9f\
             a0a1a2a3a4a5a6a7a8a9aaabacadaeaf",
        );
        let info = unhex(
            "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf\
             c0c1c2c3c4c5c6c7c8c9cacbcccdcecf\
             d0d1d2d3d4d5d6d7d8d9dadbdcdddedf\
             e0e1e2e3e4e5e6e7e8e9eaebecedeeef\
             f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
        );

        let hkdf = Hkdf::<Sha1>::extract(Some(&salt), &ikm);
        assert_eq!(hex(hkdf.prk()), "8adae09a2a307059478d309b26c4115a224cfaf6");

        let mut okm = vec![0u8; 82];
        assert!(hkdf.expand(&info, &mut okm));
        assert_eq!(
            hex(&okm),
            "0bd770a74d1160f7c9f12cd5912a06eb".to_owned()
                + "ff6adcae899d92191fe4305673ba2ffe"
                + "8fa3f1a4e5ad79f3f334b3b202b2173c"
                + "486ea37ce3d397ed034c7f9dfeb15c5e"
                + "927336d0441f4c4300e2cff0d0900b52"
                + "d3b4"
        );
    }

    #[test]
    fn rfc5869_case_6_sha1_zero_salt_info() {
        let ikm = vec![0x0b; 22];
        let info = [];

        let hkdf = Hkdf::<Sha1>::extract(Some(&[]), &ikm);
        assert_eq!(hex(hkdf.prk()), "da8c8a73c7fa77288ec6f5e7c297786aa0d32d01");

        let mut okm = vec![0u8; 42];
        assert!(hkdf.expand(&info, &mut okm));
        assert_eq!(
            hex(&okm),
            "0ac1af7002b3d761d1e55298da9d0506".to_owned()
                + "b9ae52057220a306e07b6b87e8df21d0"
                + "ea00033de03984d34918"
        );
    }

    #[test]
    fn rfc5869_case_7_sha1_no_salt() {
        let ikm = vec![0x0c; 22];
        let info = [];

        let hkdf = Hkdf::<Sha1>::extract(None, &ikm);
        assert_eq!(hex(hkdf.prk()), "2adccada18779e7c2077ad2eb19d3f3e731385dd");

        let mut okm = vec![0u8; 42];
        assert!(hkdf.expand(&info, &mut okm));
        assert_eq!(
            hex(&okm),
            "2c91117204d745f3500d636a62f64f0a".to_owned()
                + "b3bae548aa53d423b0d1f27ebba6f5e5"
                + "673a081d70cce7acfc48"
        );
    }

    #[test]
    fn expand_rejects_overlong_output() {
        let hkdf = Hkdf::<Sha256>::extract(Some(&[0x01, 0x02]), b"ikm");
        let mut out = vec![0u8; 255 * 32 + 1];
        assert!(!hkdf.expand(b"info", &mut out));
    }

    #[test]
    fn expand_accepts_maximum_output_length() {
        let hkdf = Hkdf::<Sha256>::extract(Some(&[0x01, 0x02]), b"ikm");
        let mut out = vec![0u8; 255 * 32];
        assert!(hkdf.expand(b"info", &mut out));
        // The 255th block must actually be written, not left zeroed.
        assert!(out[255 * 32 - 32..].iter().any(|&b| b != 0));
    }

    #[test]
    fn expand_empty_output_succeeds() {
        let hkdf = Hkdf::<Sha256>::extract(None, b"ikm");
        let mut out: [u8; 0] = [];
        assert!(hkdf.expand(b"info", &mut out));
    }

    #[test]
    fn from_prk_round_trips_extracted_key() {
        let salt: &[u8] = b"salt";
        let extracted = Hkdf::<Sha256>::extract(Some(salt), b"ikm");
        let rebuilt = Hkdf::<Sha256>::from_prk(extracted.prk()).expect("digest-width PRK");
        assert_eq!(rebuilt.prk(), extracted.prk());

        let mut from_extract = [0u8; 40];
        let mut from_rebuilt = [0u8; 40];
        assert!(extracted.expand(b"ctx", &mut from_extract));
        assert!(rebuilt.expand(b"ctx", &mut from_rebuilt));
        assert_eq!(from_extract, from_rebuilt);
    }

    #[test]
    fn from_prk_rejects_short_key() {
        assert!(Hkdf::<Sha256>::from_prk(&[0u8; 31]).is_none());
        assert!(Hkdf::<Sha1>::from_prk(&[]).is_none());
    }

    #[test]
    fn derive_rejects_overlong_output() {
        assert!(Hkdf::<Sha256>::derive(None, b"ikm", b"info", 255 * 32 + 1).is_none());
    }

    #[test]
    fn derive_matches_extract_expand() {
        let salt = b"salt";
        let ikm = b"ikm";
        let info = b"context";
        let direct = Hkdf::<Sha256>::derive(Some(salt), ikm, info, 48).expect("derive");

        let hkdf = Hkdf::<Sha256>::extract(Some(salt), ikm);
        let mut manual = vec![0u8; 48];
        assert!(hkdf.expand(info, &mut manual));
        assert_eq!(direct, manual);
    }
}