libcrux_hkdf/
hkdf.rs

1//! # HKDF
2//!
3//! This crate implements HKDF ([RFC 5869](https://tools.ietf.org/html/rfc5869)) on SHA2-256, SHA2-384, and SHA2-512.
//! The implementation is based on verified code extracted from the [HACL* project](https://hacl-star.github.io).
5//!
6//! ## Examples
7//!
8//! ### Using the typed SHA2-256 API
9//!
10//! ```
11//! use libcrux_hkdf::{Hkdf, Sha2_256};
12//! use libcrux_secrets::{U8, Classify, ClassifyRef, DeclassifyRef};
13//!
14//! // Input key material and salt
15//! let ikm = &[0x0b; 22].classify(); // 22 bytes of 0x0b
16//! let salt = b"salt".classify_ref();
17//!
18//! // Extract phase: derive pseudorandom key
19//! let mut prk = [0u8; 32].classify(); // SHA2-256 output length
20//! Hkdf::<Sha2_256>::extract(&mut prk, salt, ikm).unwrap();
21//!
22//! // Expand phase: derive keys for different purposes
23//! let mut encrypt_key = [0u8; 16].classify();
24//! let mut mac_key = [0u8; 16].classify();
25//!
26//! Hkdf::<Sha2_256>::expand(&mut encrypt_key, &prk, b"encrypt").unwrap();
27//! Hkdf::<Sha2_256>::expand(&mut mac_key, &prk, b"mac").unwrap();
28//! ```
29//!
30//! ### Using the dynamic API
31//!
32//! ```
33//! use libcrux_hkdf::{extract, expand, Algorithm};
34//! use libcrux_secrets::{U8, Classify, ClassifyRef, DeclassifyRef};
35//!
36//! // Input key material and salt
37//! let ikm = &[0x0b; 22].classify();
38//! let salt = b"salt".classify_ref();
39//!
40//! // Extract phase using SHA2-512
41//! let mut prk = [0u8; 64].classify(); // SHA2-512 output length
42//! extract(Algorithm::Sha512, &mut prk, salt, ikm).unwrap();
43//!
44//! // Expand phase: derive keys for different purposes
45//! let mut encrypt_key = [0u8; 32].classify();
46//! let mut mac_key = [0u8; 32].classify();
47//!
48//! expand(Algorithm::Sha512, &mut encrypt_key, &prk, b"encrypt").unwrap();
49//! expand(Algorithm::Sha512, &mut mac_key, &prk, b"mac").unwrap();
50//! ```
51#![no_std]
52
53use core::marker::PhantomData;
54
55use libcrux_secrets::{Classify, DeclassifyRef, DeclassifyRefMut, U8};
56
57pub mod hacl;
58
/// The HKDF algorithm defining the used hash function. Only needed for the functions with dynamic
/// algorithm selection.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Algorithm {
    /// HKDF instantiated with SHA2-256 (32-byte hash length).
    Sha256,
    /// HKDF instantiated with SHA2-384 (48-byte hash length).
    Sha384,
    /// HKDF instantiated with SHA2-512 (64-byte hash length).
    Sha512,
}
67
68/// HKDF extract using the `salt` and the input key material `ikm`.
69/// The result is written to `prk`.
70/// The `algo` argument is used for dynamic algorithm selection.
71///
72/// Returns nothing on success.
73/// Returns [`ExtractError::ArgumentTooLong`] if one of `ikm` or `salt` is longer than [`u32::MAX`]
74/// bytes.
75/// Returns [`ExtractError::PrkTooShort`] if `prk` is shorter than the hash length.
76pub fn extract(
77    algo: Algorithm,
78    prk: &mut [U8],
79    salt: &[U8],
80    ikm: &[U8],
81) -> Result<(), ExtractError> {
82    match algo {
83        Algorithm::Sha256 => sha2_256::extract(prk, salt, ikm),
84        Algorithm::Sha384 => sha2_384::extract(prk, salt, ikm),
85        Algorithm::Sha512 => sha2_512::extract(prk, salt, ikm),
86    }
87}
88
89/// HKDF expand. The argument names match the specification.
90/// The result is written to `okm`.
91/// The `algo` argument is used for dynamic algorithm selection.
92///
93/// Returns nothing on success.
94/// Returns [`ExpandError::ArgumentTooLong`] if one of `prk` or `info` is longer than
95/// [`u32::MAX`] bytes.
96/// Returns [`ExpandError::PrkTooShort`] if `okm` is shorter than hash length.
97/// Returns [`ExpandError::OutputTooLong`] if `okm` is longer than 255 times the respective hash
98/// length.
99pub fn expand(algo: Algorithm, okm: &mut [U8], prk: &[U8], info: &[u8]) -> Result<(), ExpandError> {
100    match algo {
101        Algorithm::Sha256 => sha2_256::expand(okm, prk, info),
102        Algorithm::Sha384 => sha2_384::expand(okm, prk, info),
103        Algorithm::Sha512 => sha2_512::expand(okm, prk, info),
104    }
105}
106
107/// Full HKDF, i.e. both extract and expand, using the `salt` and the input key material `ikm`.
108/// The argument names match the specification. The result is written to `okm`.
109/// The `algo` argument is used for dynamic algorithm selection.
110///
111/// Returns nothing on success.
112/// Returns [`ExpandError::ArgumentTooLong`] if one of `prk` or `info` is longer than
113/// [`u32::MAX`] bytes.
114/// Returns [`ExpandError::PrkTooShort`] if `okm` is shorter than hash length.
115/// Returns [`ExpandError::OutputTooLong`] if `okm` is longer than 255 times the respective hash
116/// length.
117pub fn hkdf(
118    algo: Algorithm,
119    okm: &mut [U8],
120    salt: &[U8],
121    ikm: &[U8],
122    info: &[u8],
123) -> Result<(), ExpandError> {
124    match algo {
125        Algorithm::Sha256 => sha2_256::hkdf(okm, salt, ikm, info),
126        Algorithm::Sha384 => sha2_384::hkdf(okm, salt, ikm, info),
127        Algorithm::Sha512 => sha2_512::hkdf(okm, salt, ikm, info),
128    }
129}
130
/// Type marker for SHA2-256 hash algorithm.
///
/// This struct is used as a type parameter for [`Hkdf<Sha2_256>`] to provide
/// compile-time selection of the SHA2-256 algorithm for HKDF operations.
/// SHA2-256 produces 32-byte (256-bit) hash outputs.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Sha2_256;

/// Type marker for SHA2-384 hash algorithm.
///
/// This struct is used as a type parameter for [`Hkdf<Sha2_384>`] to provide
/// compile-time selection of the SHA2-384 algorithm for HKDF operations.
/// SHA2-384 produces 48-byte (384-bit) hash outputs.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Sha2_384;

/// Type marker for SHA2-512 hash algorithm.
///
/// This struct is used as a type parameter for [`Hkdf<Sha2_512>`] to provide
/// compile-time selection of the SHA2-512 algorithm for HKDF operations.
/// SHA2-512 produces 64-byte (512-bit) hash outputs.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Sha2_512;
151
/// Typed HKDF front-end whose hash function is fixed at compile time.
///
/// The type parameter is one of the zero-sized marker types [`Sha2_256`], [`Sha2_384`] or
/// [`Sha2_512`]. The associated functions available on `Hkdf<Algo>` are provided per marker
/// and follow RFC 5869. They are backed by verified code from the HACL* project.
///
/// # Type Parameters
///
/// * `Algo` - Marker type selecting the hash algorithm (e.g., [`Sha2_256`])
///
/// # Examples
///
/// ```
/// use libcrux_hkdf::{Hkdf, Sha2_256};
/// use libcrux_secrets::{U8, Classify, ClassifyRef};
///
/// let ikm = &[0x0b; 22].classify();
/// let salt = b"salt".classify_ref();
///
/// let mut prk = [0u8; 32].classify();
/// Hkdf::<Sha2_256>::extract(&mut prk, salt, ikm).unwrap();
/// ```
pub struct Hkdf<Algo>(PhantomData<Algo>);
178
179impl Algorithm {
180    /// Returns the digest length of the underlying hash function.
181    pub const fn hash_len(self) -> usize {
182        match self {
183            Algorithm::Sha256 => 32,
184            Algorithm::Sha384 => 48,
185            Algorithm::Sha512 => 64,
186        }
187    }
188}
189
/// Generates HKDF implementation modules for specific hash algorithms.
///
/// This macro creates a complete HKDF implementation module for a specific SHA2 algorithm,
/// including both typed API methods on the [`Hkdf`] struct and standalone module functions.
/// It generates implementations for extract, expand, and full HKDF operations with both
/// fixed-size array references and variable-size slice variants.
///
/// Parameters:
///
/// * `$struct_name` - The path to the hash algorithm type marker (e.g., `crate::Sha2_256`)
/// * `$name` - The module name to generate (e.g., `sha2_256`)
/// * `$string_name` - A string literal describing the algorithm (e.g., `"SHA2-256"`)
/// * `$mode` - The corresponding [`Algorithm`] enum variant (e.g., `Algorithm::Sha256`).
///   It is currently not referenced by the generated code.
/// * `$extract` - The name of the HACL extract function (e.g., `extract_sha2_256`)
/// * `$expand` - The name of the HACL expand function (e.g., `expand_sha2_256`)
/// * `$hash_len` - The hash output length in bytes as a literal (e.g., `32`)
///
/// Example invocation:
///
/// ```ignore
/// impl_hkdf!(
///     crate::Sha2_256,
///     sha2_256,
///     "SHA2-256",
///     Algorithm::Sha256,
///     extract_sha2_256,
///     expand_sha2_256,
///     32
/// );
/// ```
///
/// This generates the `sha2_256` module and implements all HKDF methods for `Hkdf<Sha2_256>`.
macro_rules! impl_hkdf {
    ($struct_name:path, $name:ident, $string_name:literal, $mode:path, $extract:ident, $expand:ident, $hash_len:literal) => {
        #[doc = concat!("HKDF implementation for ", $string_name, ".")]
        ///
        /// This module provides HKDF (HMAC-based Key Derivation Function) operations
        /// specifically for the underlying hash algorithm. It includes both standalone
        /// functions and methods on the typed [`Hkdf`] struct.
        ///
        /// The `_arrayref` variants take the PRK as an array of exactly the hash length, so its
        /// size is checked at compile time, while the regular variants accept slices and
        /// validate the PRK length at runtime.
        pub mod $name {
            use libcrux_secrets::U8;

            use super::*;

            impl Hkdf<$struct_name> {
                /// HKDF extract using the `salt` and the input key material `ikm`.
                /// The result is written to `prk`, which has exactly the hash length.
                ///
                /// Returns nothing on success.
                /// Returns [`ArrayReferenceExtractError::ArgumentTooLong`] if one of `ikm` or
                /// `salt` is longer than [`u32::MAX`] bytes.
                #[inline(always)]
                pub fn extract_arrayref(
                    prk: &mut [U8; $hash_len],
                    salt: &[U8],
                    ikm: &[U8],
                ) -> Result<(), ArrayReferenceExtractError> {
                    extract_arrayref(prk, salt, ikm)
                }

                /// HKDF extract using the `salt` and the input key material `ikm`.
                /// The result is written to the first hash-length bytes of `prk`.
                ///
                /// Returns nothing on success.
                /// Returns [`ExtractError::ArgumentTooLong`] if one of `ikm` or `salt` is longer than
                /// [`u32::MAX`] bytes.
                /// Returns [`ExtractError::PrkTooShort`] if `prk` is shorter than hash length.
                #[inline(always)]
                pub fn extract(
                    prk: &mut [U8],
                    salt: &[U8],
                    ikm: &[U8],
                ) -> Result<(), ExtractError> {
                    extract(prk, salt, ikm)
                }

                /// HKDF expand using the pseudorandom key `prk` and `info`.
                /// The output is written to `okm`; its length determines the output length.
                ///
                /// Returns nothing on success.
                /// Returns [`ArrayReferenceExpandError::OutputTooLong`] if `okm` is too long
                /// (longer than `255 * hash_length`).
                /// Returns [`ArrayReferenceExpandError::ArgumentTooLong`] if `info` is longer
                /// than [`u32::MAX`] bytes.
                #[inline(always)]
                pub fn expand_arrayref(
                    okm: &mut [U8],
                    prk: &[U8; $hash_len],
                    info: &[u8],
                ) -> Result<(), ArrayReferenceExpandError> {
                    // The output length limit is enforced by the free function.
                    expand_arrayref(okm, prk, info)
                }

                /// HKDF expand using the pseudorandom key `prk` and `info`.
                /// The output is written to `okm`; its length determines the output length.
                /// Only the first hash-length bytes of `prk` are used.
                ///
                /// Returns nothing on success.
                /// Returns [`ExpandError::OutputTooLong`] if `okm` is too long (longer than
                /// `255 * hash_length`).
                /// Returns [`ExpandError::ArgumentTooLong`] if `info` is longer than
                /// [`u32::MAX`] bytes.
                /// Returns [`ExpandError::PrkTooShort`] if `prk` is shorter than hash length.
                #[inline(always)]
                pub fn expand(okm: &mut [U8], prk: &[U8], info: &[u8]) -> Result<(), ExpandError> {
                    expand(okm, prk, info)
                }

                /// Full HKDF using the `salt`, input key material `ikm`, `info`.
                /// The result is written to `okm`.
                /// The output length is defined through the length of `okm`.
                /// Calls `extract` and `expand` with the given input.
                ///
                /// Returns nothing on success.
                /// Returns [`ExpandError::OutputTooLong`] if `okm` is too long (longer than
                /// `255 * hash_length`).
                /// Returns [`ExpandError::ArgumentTooLong`] if one of `salt`, `ikm`, or `info`
                /// is longer than [`u32::MAX`] bytes.
                #[inline(always)]
                pub fn hkdf(
                    okm: &mut [U8],
                    salt: &[U8],
                    ikm: &[U8],
                    info: &[u8],
                ) -> Result<(), ExpandError> {
                    hkdf(okm, salt, ikm, info)
                }
            }

            /// HKDF extract using the `salt` and the input key material `ikm`.
            /// The result is written to `prk`, which has exactly the hash length.
            ///
            /// Returns nothing on success.
            /// Returns [`ArrayReferenceExtractError::ArgumentTooLong`] if one of `ikm` or `salt`
            /// is longer than [`u32::MAX`] bytes.
            #[inline(always)]
            pub fn extract_arrayref(
                prk: &mut [U8; $hash_len],
                salt: &[U8],
                ikm: &[U8],
            ) -> Result<(), ArrayReferenceExtractError> {
                Ok(crate::hacl::$extract(
                    prk.declassify_ref_mut(),
                    salt.declassify_ref(),
                    checked_u32(salt.len())?,
                    ikm.declassify_ref(),
                    checked_u32(ikm.len())?,
                ))
            }

            /// HKDF extract using the `salt` and the input key material `ikm`.
            /// The result is written to the first hash-length bytes of `prk`.
            ///
            /// Returns nothing on success.
            /// Returns [`ExtractError::ArgumentTooLong`] if one of `ikm` or `salt` is longer than
            /// [`u32::MAX`] bytes.
            /// Returns [`ExtractError::PrkTooShort`] if `prk` is shorter than hash length.
            #[inline(always)]
            pub fn extract(prk: &mut [U8], salt: &[U8], ikm: &[U8]) -> Result<(), ExtractError> {
                let (prk, _) = prk
                    .split_at_mut_checked($hash_len)
                    .ok_or(ExtractError::PrkTooShort)?;
                // Cannot fail: the split above produced exactly `$hash_len` elements.
                let prk: &mut [U8; $hash_len] =
                    prk.try_into().map_err(|_| ExtractError::Unknown)?;

                extract_arrayref(prk, salt, ikm).map_err(ExtractError::from)
            }

            /// HKDF expand using the pseudorandom key `prk` and `info`.
            /// The output is written to `okm`; its length determines the output length.
            ///
            /// Returns nothing on success.
            /// Returns [`ArrayReferenceExpandError::OutputTooLong`] if `okm` is too long
            /// (longer than `255 * hash_length`).
            /// Returns [`ArrayReferenceExpandError::ArgumentTooLong`] if `info` is longer than
            /// [`u32::MAX`] bytes.
            #[inline(always)]
            pub fn expand_arrayref(
                mut okm: &mut [U8],
                prk: &[U8; $hash_len],
                info: &[u8],
            ) -> Result<(), ArrayReferenceExpandError> {
                let okm_len = okm.len();
                if okm_len > 255 * $hash_len {
                    // Output size is too large (RFC 5869 limits L to 255 * HashLen).
                    // HACL doesn't catch this.
                    return Err(ArrayReferenceExpandError::OutputTooLong);
                }

                Ok(crate::hacl::$expand(
                    okm.declassify_ref_mut(),
                    prk.declassify_ref(),
                    checked_u32(prk.len())?,
                    info,
                    checked_u32(info.len())?,
                    checked_u32(okm_len)?,
                ))
            }

            /// HKDF expand using the pseudorandom key `prk` and `info`.
            /// The output is written to `okm`; its length determines the output length.
            /// Only the first hash-length bytes of `prk` are used.
            ///
            /// Returns nothing on success.
            /// Returns [`ExpandError::OutputTooLong`] if `okm` is too long (longer than
            /// `255 * hash_length`).
            /// Returns [`ExpandError::ArgumentTooLong`] if `info` is longer than
            /// [`u32::MAX`] bytes.
            /// Returns [`ExpandError::PrkTooShort`] if `prk` is shorter than hash length.
            #[inline(always)]
            pub fn expand(okm: &mut [U8], prk: &[U8], info: &[u8]) -> Result<(), ExpandError> {
                let (prk, _) = prk
                    .split_at_checked($hash_len)
                    .ok_or(ExpandError::PrkTooShort)?;
                // Cannot fail: the split above produced exactly `$hash_len` elements.
                let prk: &[U8; $hash_len] = prk.try_into().map_err(|_| ExpandError::Unknown)?;

                expand_arrayref(okm, prk, info).map_err(ExpandError::from)
            }

            /// Full HKDF using the `salt`, input key material `ikm`, `info`.
            /// The result is written to `okm`.
            /// The output length is defined through the length of `okm`.
            /// Runs extract into an internal hash-length PRK buffer, then expands it.
            ///
            /// Returns nothing on success.
            /// Returns [`ExpandError::OutputTooLong`] if `okm` is too long (longer than
            /// `255 * hash_length`).
            /// Returns [`ExpandError::ArgumentTooLong`] if one of `salt`, `ikm`, or `info` is
            /// longer than [`u32::MAX`] bytes.
            #[inline(always)]
            pub fn hkdf(
                okm: &mut [U8],
                salt: &[U8],
                ikm: &[U8],
                info: &[u8],
            ) -> Result<(), ExpandError> {
                // The PRK buffer has exactly the hash length, so the array-reference variants
                // can be used directly without the runtime PRK length checks of the slice API.
                let mut prk = [0u8; $hash_len].classify();
                extract_arrayref(&mut prk, salt, ikm).map_err(ExtractError::from)?;
                expand_arrayref(okm, &prk, info).map_err(ExpandError::from)
            }
        }
    };
}
425
426impl_hkdf!(
427    crate::Sha2_256,
428    sha2_256,
429    "SHA2-256",
430    Algorithm::Sha256,
431    extract_sha2_256,
432    expand_sha2_256,
433    32
434);
435
436impl_hkdf!(
437    crate::Sha2_384,
438    sha2_384,
439    "SHA2-384",
440    Algorithm::Sha384,
441    extract_sha2_384,
442    expand_sha2_384,
443    48
444);
445
446impl_hkdf!(
447    crate::Sha2_512,
448    sha2_512,
449    "SHA2-512",
450    Algorithm::Sha512,
451    extract_sha2_512,
452    expand_sha2_512,
453    64
454);
455
456fn checked_u32(num: usize) -> Result<u32, ArgumentsTooLongError> {
457    num.try_into().map_err(|_| ArgumentsTooLongError)
458}
459
/// Errors returned by the `extract_arrayref` functions.
///
/// Their PRK output is a fixed-size array of the hash length, so unlike [`ExtractError`]
/// there is no "PRK too short" case.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ArrayReferenceExtractError {
    /// `salt` or `ikm` is longer than [`u32::MAX`] bytes.
    ArgumentTooLong,
    /// Catch-all for otherwise unclassified failures.
    Unknown,
}
465
/// Errors returned by the slice-based HKDF extract functions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExtractError {
    /// The `prk` output buffer is shorter than the hash length.
    PrkTooShort,
    /// `salt` or `ikm` is longer than [`u32::MAX`] bytes.
    ArgumentTooLong,
    /// Catch-all for otherwise unclassified failures.
    Unknown,
}
472
/// Errors returned by the `expand_arrayref` functions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ArrayReferenceExpandError {
    /// The requested output is longer than 255 times the hash length.
    OutputTooLong,
    /// An input length does not fit into a `u32`.
    ArgumentTooLong,
    /// Catch-all for otherwise unclassified failures.
    Unknown,
}
479
/// Errors returned by the slice-based HKDF expand functions and by full HKDF.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExpandError {
    /// The requested output is longer than 255 times the hash length.
    OutputTooLong,
    /// The `prk` input is shorter than the hash length.
    PrkTooShort,
    /// An input length does not fit into a `u32`.
    ArgumentTooLong,
    /// Catch-all for otherwise unclassified failures.
    Unknown,
}
487
/// Internal error signalling that a length did not fit into a `u32`.
#[derive(Clone, Copy, Debug, PartialEq)]
struct ArgumentsTooLongError;
490
491impl From<ArrayReferenceExtractError> for ExtractError {
492    fn from(err: ArrayReferenceExtractError) -> Self {
493        match err {
494            ArrayReferenceExtractError::ArgumentTooLong => ExtractError::ArgumentTooLong,
495            ArrayReferenceExtractError::Unknown => ExtractError::Unknown,
496        }
497    }
498}
499
500impl From<ArrayReferenceExpandError> for ExpandError {
501    fn from(err: ArrayReferenceExpandError) -> Self {
502        match err {
503            ArrayReferenceExpandError::OutputTooLong => ExpandError::OutputTooLong,
504            ArrayReferenceExpandError::ArgumentTooLong => ExpandError::ArgumentTooLong,
505            ArrayReferenceExpandError::Unknown => ExpandError::Unknown,
506        }
507    }
508}
509
510impl From<ExtractError> for ExpandError {
511    fn from(err: ExtractError) -> Self {
512        match err {
513            ExtractError::PrkTooShort => ExpandError::PrkTooShort,
514            ExtractError::ArgumentTooLong => ExpandError::ArgumentTooLong,
515            ExtractError::Unknown => ExpandError::Unknown,
516        }
517    }
518}
519
520impl From<ArgumentsTooLongError> for ArrayReferenceExtractError {
521    fn from(_: ArgumentsTooLongError) -> Self {
522        ArrayReferenceExtractError::ArgumentTooLong
523    }
524}
525impl From<ArgumentsTooLongError> for ArrayReferenceExpandError {
526    fn from(_: ArgumentsTooLongError) -> Self {
527        ArrayReferenceExpandError::ArgumentTooLong
528    }
529}
530impl From<ArgumentsTooLongError> for ExtractError {
531    fn from(_: ArgumentsTooLongError) -> Self {
532        ExtractError::ArgumentTooLong
533    }
534}
535impl From<ArgumentsTooLongError> for ExpandError {
536    fn from(_: ArgumentsTooLongError) -> Self {
537        ExpandError::ArgumentTooLong
538    }
539}