1use crate::aws_lc::{HKDF_expand, HKDF};
41use crate::error::Unspecified;
42use crate::fips::indicator_check;
43use crate::{digest, hmac};
44use alloc::sync::Arc;
45use core::fmt;
46use zeroize::Zeroize;
47
48#[derive(Clone, Copy, Debug, Eq, PartialEq)]
50pub struct Algorithm(hmac::Algorithm);
51
52impl Algorithm {
53 #[inline]
55 #[must_use]
56 pub fn hmac_algorithm(&self) -> hmac::Algorithm {
57 self.0
58 }
59}
60
/// HKDF using HMAC-SHA1. As the name says, intended for legacy use only.
pub const HKDF_SHA1_FOR_LEGACY_USE_ONLY: Algorithm = Algorithm(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY);

/// HKDF using HMAC-SHA256.
pub const HKDF_SHA256: Algorithm = Algorithm(hmac::HMAC_SHA256);

/// HKDF using HMAC-SHA384.
pub const HKDF_SHA384: Algorithm = Algorithm(hmac::HMAC_SHA384);

/// HKDF using HMAC-SHA512.
pub const HKDF_SHA512: Algorithm = Algorithm(hmac::HMAC_SHA512);

/// Largest total `info` length (in bytes) that `concatenate_info` assembles in a
/// fixed-size stack buffer; longer `info` inputs are concatenated into a heap `Vec`.
const HKDF_INFO_DEFAULT_CAPACITY_LEN: usize = 80;

/// Size of the inline key buffer in `PrkMode::Expand`, and therefore the longest
/// PRK accepted by `Prk::new_less_safe` (equal to `digest::MAX_OUTPUT_LEN`).
const MAX_HKDF_PRK_LEN: usize = digest::MAX_OUTPUT_LEN;
81
82impl KeyType for Algorithm {
83 fn len(&self) -> usize {
84 self.0.digest_algorithm().output_len
85 }
86}
87
/// An HKDF salt: the salt bytes paired with the HKDF algorithm they are used with.
pub struct Salt {
    algorithm: Algorithm,
    // Reference-counted so every `Prk` produced by `extract` can share the salt
    // without copying it.
    bytes: Arc<[u8]>,
}
93
94#[allow(clippy::missing_fields_in_debug)]
95impl fmt::Debug for Salt {
96 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97 f.debug_struct("hkdf::Salt")
98 .field("algorithm", &self.algorithm.0)
99 .finish()
100 }
101}
102
103impl Salt {
104 #[must_use]
122 pub fn new(algorithm: Algorithm, value: &[u8]) -> Self {
123 Self {
124 algorithm,
125 bytes: Arc::from(value),
126 }
127 }
128
129 #[inline]
136 #[must_use]
137 pub fn extract(&self, secret: &[u8]) -> Prk {
138 Prk {
139 algorithm: self.algorithm,
140 mode: PrkMode::ExtractExpand {
141 secret: Arc::new(ZeroizeBoxSlice::from(secret)),
142 salt: Arc::clone(&self.bytes),
143 },
144 }
145 }
146
147 #[inline]
149 #[must_use]
150 pub fn algorithm(&self) -> Algorithm {
151 Algorithm(self.algorithm.hmac_algorithm())
152 }
153}
154
155impl From<Okm<'_, Algorithm>> for Salt {
156 fn from(okm: Okm<'_, Algorithm>) -> Self {
157 let algorithm = okm.prk.algorithm;
158 let salt_len = okm.len().len();
159 let mut salt_bytes = vec![0u8; salt_len];
160 okm.fill(&mut salt_bytes).unwrap();
161 Self {
162 algorithm,
163 bytes: Arc::from(salt_bytes.as_slice()),
164 }
165 }
166}
167
// The trait describes a requested output length, not a container, so an
// `is_empty` companion to `len` would not be meaningful.
#[allow(clippy::len_without_is_empty)]
/// A requested length of output keying material, passed to `Prk::expand`.
pub trait KeyType {
    /// The number of bytes of output keying material to produce.
    fn len(&self) -> usize;
}
174
/// How a `Prk` produces output keying material.
#[derive(Clone)]
enum PrkMode {
    /// The PRK bytes are known directly (from `Prk::new_less_safe` or `From<Okm>`);
    /// output is produced with `HKDF_expand` alone.
    Expand {
        /// PRK bytes; only the first `key_len` bytes are meaningful.
        key_bytes: [u8; MAX_HKDF_PRK_LEN],
        /// Number of valid bytes at the start of `key_bytes`.
        key_len: usize,
    },
    /// Created by `Salt::extract`: the secret and salt are retained, and each fill
    /// performs extract-then-expand in a single `HKDF` call.
    ExtractExpand {
        /// Input keying material, wiped by `ZeroizeBoxSlice` when the last `Arc` is dropped.
        secret: Arc<ZeroizeBoxSlice<u8>>,
        /// Salt bytes, shared with the originating `Salt`.
        salt: Arc<[u8]>,
    },
}
186
187impl PrkMode {
188 fn fill(&self, algorithm: Algorithm, out: &mut [u8], info: &[u8]) -> Result<(), Unspecified> {
189 let digest = digest::match_digest_type(&algorithm.0.digest_algorithm().id).as_const_ptr();
190
191 match &self {
192 PrkMode::Expand { key_bytes, key_len } => unsafe {
193 if 1 != indicator_check!(HKDF_expand(
194 out.as_mut_ptr(),
195 out.len(),
196 digest,
197 key_bytes.as_ptr(),
198 *key_len,
199 info.as_ptr(),
200 info.len(),
201 )) {
202 return Err(Unspecified);
203 }
204 },
205 PrkMode::ExtractExpand { secret, salt } => {
206 if 1 != indicator_check!(unsafe {
207 HKDF(
208 out.as_mut_ptr(),
209 out.len(),
210 digest,
211 secret.as_ptr(),
212 secret.len(),
213 salt.as_ptr(),
214 salt.len(),
215 info.as_ptr(),
216 info.len(),
217 )
218 }) {
219 return Err(Unspecified);
220 }
221 }
222 }
223
224 Ok(())
225 }
226}
227
228impl fmt::Debug for PrkMode {
229 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230 match self {
231 Self::Expand { .. } => f.debug_struct("Expand").finish_non_exhaustive(),
232 Self::ExtractExpand { .. } => f.debug_struct("ExtractExpand").finish_non_exhaustive(),
233 }
234 }
235}
236
237struct ZeroizeBoxSlice<T: Zeroize>(Box<[T]>);
238
239impl<T: Zeroize> core::ops::Deref for ZeroizeBoxSlice<T> {
240 type Target = [T];
241
242 fn deref(&self) -> &Self::Target {
243 &self.0
244 }
245}
246
247impl<T: Clone + Zeroize> From<&[T]> for ZeroizeBoxSlice<T> {
248 fn from(value: &[T]) -> Self {
249 Self(Vec::from(value).into_boxed_slice())
250 }
251}
252
253impl<T: Zeroize> Drop for ZeroizeBoxSlice<T> {
254 fn drop(&mut self) {
255 self.0.zeroize();
256 }
257}
258
/// An HKDF pseudorandom key (PRK).
///
/// Created by `Salt::extract`, by `Prk::new_less_safe`, or by expanding an
/// `Okm<Algorithm>`. Inline key bytes are zeroized when the `Prk` is dropped.
#[derive(Clone)]
pub struct Prk {
    algorithm: Algorithm,
    mode: PrkMode,
}
265
266impl Drop for Prk {
267 fn drop(&mut self) {
268 if let PrkMode::Expand {
269 ref mut key_bytes, ..
270 } = self.mode
271 {
272 key_bytes.zeroize();
273 }
274 }
275}
276
277#[allow(clippy::missing_fields_in_debug)]
278impl fmt::Debug for Prk {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 f.debug_struct("hkdf::Prk")
281 .field("algorithm", &self.algorithm.0)
282 .field("mode", &self.mode)
283 .finish()
284 }
285}
286
287impl Prk {
288 #[must_use]
306 pub fn new_less_safe(algorithm: Algorithm, value: &[u8]) -> Self {
307 Prk::try_new_less_safe(algorithm, value).expect("Prk length limit exceeded.")
308 }
309
310 fn try_new_less_safe(algorithm: Algorithm, value: &[u8]) -> Result<Prk, Unspecified> {
311 let key_len = value.len();
312 if key_len > MAX_HKDF_PRK_LEN {
313 return Err(Unspecified);
314 }
315 let mut key_bytes = [0u8; MAX_HKDF_PRK_LEN];
316 key_bytes[0..key_len].copy_from_slice(value);
317 Ok(Self {
318 algorithm,
319 mode: PrkMode::Expand { key_bytes, key_len },
320 })
321 }
322
323 #[inline]
336 pub fn expand<'a, L: KeyType>(
337 &'a self,
338 info: &'a [&'a [u8]],
339 len: L,
340 ) -> Result<Okm<'a, L>, Unspecified> {
341 let len_cached = len.len();
342 if len_cached > 255 * self.algorithm.0.digest_algorithm().output_len {
343 return Err(Unspecified);
344 }
345 Ok(Okm {
346 prk: self,
347 info,
348 len,
349 })
350 }
351}
352
353impl From<Okm<'_, Algorithm>> for Prk {
354 fn from(okm: Okm<Algorithm>) -> Self {
355 let algorithm = okm.len;
356 let key_len = okm.len.len();
357 let mut key_bytes = [0u8; MAX_HKDF_PRK_LEN];
358 okm.fill(&mut key_bytes[0..key_len]).unwrap();
359
360 Self {
361 algorithm,
362 mode: PrkMode::Expand { key_bytes, key_len },
363 }
364 }
365}
366
/// A pending request for HKDF output keying material (OKM).
///
/// It holds the PRK, the `info` pieces and the requested length `L`. Nothing is
/// computed until `Okm::fill` consumes it.
pub struct Okm<'a, L: KeyType> {
    prk: &'a Prk,
    // Pieces of `info`, concatenated in order when `fill` runs.
    info: &'a [&'a [u8]],
    len: L,
}
376
377impl<L: KeyType> fmt::Debug for Okm<'_, L> {
378 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
379 f.debug_struct("hkdf::Okm").field("prk", &self.prk).finish()
380 }
381}
382
383#[inline]
387fn concatenate_info<F, R>(info: &[&[u8]], f: F) -> R
388where
389 F: FnOnce(&[u8]) -> R,
390{
391 let info_len: usize = info.iter().map(|s| s.len()).sum();
392
393 if info_len <= HKDF_INFO_DEFAULT_CAPACITY_LEN {
395 let mut stack_buf = [0u8; HKDF_INFO_DEFAULT_CAPACITY_LEN];
397 let mut pos = 0;
398 for &slice in info {
399 stack_buf[pos..pos + slice.len()].copy_from_slice(slice);
400 pos += slice.len();
401 }
402
403 f(&stack_buf[..info_len])
404 } else {
405 let mut heap_buf = Vec::with_capacity(info_len);
407 for &slice in info {
408 heap_buf.extend_from_slice(slice);
409 }
410
411 f(&heap_buf)
412 }
413}
414
415impl<L: KeyType> Okm<'_, L> {
416 #[inline]
418 pub fn len(&self) -> &L {
419 &self.len
420 }
421
422 #[inline]
440 pub fn fill(self, out: &mut [u8]) -> Result<(), Unspecified> {
441 if out.len() != self.len.len() {
442 return Err(Unspecified);
443 }
444
445 concatenate_info(self.info, |info_bytes| {
446 self.prk.mode.fill(self.prk.algorithm, out, info_bytes)
447 })
448 }
449}
450
#[cfg(test)]
mod tests {
    use crate::hkdf::{Salt, HKDF_SHA256, HKDF_SHA384};

    // FIPS-specific tests live in a separate submodule, built only with the `fips` feature.
    #[cfg(feature = "fips")]
    mod fips;

    /// Exercises the derived `PartialEq` and `Debug` impls on `Algorithm`.
    #[test]
    fn hkdf_coverage() {
        assert_ne!(HKDF_SHA256, HKDF_SHA384);
        // Derived `Debug` shows the HKDF newtype wrapping the HMAC algorithm.
        assert_eq!("Algorithm(Algorithm(SHA256))", format!("{HKDF_SHA256:?}"));
    }

    /// Checks that the `Debug` output of `Salt`, `Prk` and `Okm` contains no key, salt or info bytes.
    #[test]
    fn test_debug() {
        const SALT: &[u8; 32] = &[
            29, 113, 120, 243, 11, 202, 39, 222, 206, 81, 163, 184, 122, 153, 52, 192, 98, 195,
            240, 32, 34, 19, 160, 128, 178, 111, 97, 232, 113, 101, 221, 143,
        ];
        const SECRET1: &[u8; 32] = &[
            157, 191, 36, 107, 110, 131, 193, 6, 175, 226, 193, 3, 168, 133, 165, 181, 65, 120,
            194, 152, 31, 92, 37, 191, 73, 222, 41, 112, 207, 236, 196, 174,
        ];

        // `info` supplied as two pieces, concatenated when the OKM is filled.
        const INFO1: &[&[u8]] = &[
            &[
                2, 130, 61, 83, 192, 248, 63, 60, 211, 73, 169, 66, 101, 160, 196, 212, 250, 113,
            ],
            &[
                80, 46, 248, 123, 78, 204, 171, 178, 67, 204, 96, 27, 131, 24,
            ],
        ];

        let alg = HKDF_SHA256;
        let salt = Salt::new(alg, SALT);
        let prk = salt.extract(SECRET1);
        let okm = prk.expand(INFO1, alg).unwrap();

        assert_eq!(
            "hkdf::Salt { algorithm: Algorithm(SHA256) }",
            format!("{salt:?}")
        );
        assert_eq!(
            "hkdf::Prk { algorithm: Algorithm(SHA256), mode: ExtractExpand { .. } }",
            format!("{prk:?}")
        );
        assert_eq!(
            "hkdf::Okm { prk: hkdf::Prk { algorithm: Algorithm(SHA256), mode: ExtractExpand { .. } } }",
            format!("{okm:?}")
        );
    }

    /// Salts far longer than the digest output are accepted, and two different
    /// long salts yield different output for the same secret and info.
    #[test]
    fn test_long_salt() {
        // 100-byte salt.
        let long_salt = vec![0x42u8; 100];

        let salt = Salt::new(HKDF_SHA256, &long_salt);

        let secret = b"test secret key material";
        let prk = salt.extract(secret);

        let info_data = b"test context info";
        let info = [info_data.as_slice()];
        let okm = prk.expand(&info, HKDF_SHA256).unwrap();

        // One SHA-256 digest's worth of output (32 bytes).
        let mut output = [0u8; 32];
        okm.fill(&mut output).unwrap();

        // A different, even longer (500-byte) salt with the same secret and info.
        let very_long_salt = vec![0x55u8; 500];
        let very_long_salt_obj = Salt::new(HKDF_SHA256, &very_long_salt);
        let prk2 = very_long_salt_obj.extract(secret);
        let okm2 = prk2.expand(&info, HKDF_SHA256).unwrap();
        let mut output2 = [0u8; 32];
        okm2.fill(&mut output2).unwrap();

        // The salt must influence the derived output.
        assert_ne!(output, output2);
    }
}