dcrypt_algorithms/kdf/hkdf/
mod.rs

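//! HMAC-based Key Derivation Function (HKDF)
//!
//! This module implements HKDF as defined in RFC 5869.
//! HKDF takes input keying material (IKM) that is not necessarily uniform
//! and produces output keying material (OKM) suitable for use in cryptographic
//! contexts. It works in two stages: Extract condenses the IKM (with an optional
//! salt) into a pseudorandom key (PRK), and Expand stretches the PRK (with optional
//! context `info`) into the requested number of output bytes.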
7
8use crate::error::{validate, Error, Result};
9use crate::hash::HashFunction;
10use crate::kdf::{KdfAlgorithm, KdfOperation, KeyDerivationFunction, ParamProvider, SecurityLevel};
11use crate::mac::hmac::Hmac;
12use crate::types::salt::HkdfCompatible;
13use crate::types::Salt;
14
// Import security types from dcrypt-common
16use dcrypt_common::security::{EphemeralSecret, SecureZeroingType};
17
18use rand::{CryptoRng, RngCore};
19use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
20
21#[cfg(not(feature = "std"))]
22use alloc::vec::Vec;
23use std::marker::PhantomData;
24
25/// Type-level constants for HKDF algorithm
26pub enum HkdfAlgorithm<H: HashFunction> {
27    /// Phantom field for the hash function
28    _Hash(PhantomData<H>),
29}
30
31impl<H: HashFunction> KdfAlgorithm for HkdfAlgorithm<H> {
32    const MIN_SALT_SIZE: usize = 16;
33    const DEFAULT_OUTPUT_SIZE: usize = 32;
34    const ALGORITHM_ID: &'static str = "HKDF";
35
36    fn name() -> String {
37        format!("{}-{}", Self::ALGORITHM_ID, H::name())
38    }
39
40    fn security_level() -> SecurityLevel {
41        match H::output_size() * 8 {
42            bits if bits >= 512 => SecurityLevel::L256,
43            bits if bits >= 384 => SecurityLevel::L192,
44            bits if bits >= 256 => SecurityLevel::L128,
45            bits => SecurityLevel::Custom(bits as u32 / 2),
46        }
47    }
48}
49
50/// Parameters for HKDF
51#[derive(Clone, Debug, Zeroize)]
52pub struct HkdfParams<const S: usize = 16> {
53    /// Optional default salt (can be overridden in derive_key)
54    pub salt: Option<Salt<S>>,
55    /// Optional default info (context, can be overridden in derive_key)
56    pub info: Option<Zeroizing<Vec<u8>>>,
57}
58
59impl<const S: usize> Default for HkdfParams<S> {
60    fn default() -> Self {
61        Self {
62            salt: None,
63            info: None,
64        }
65    }
66}
67
68/// HKDF implementation using any hash function
69#[derive(Clone, Zeroize, ZeroizeOnDrop)]
70pub struct Hkdf<H: HashFunction, const S: usize = 16> {
71    _hash_type: PhantomData<H>,
72    params: HkdfParams<S>,
73}
74
75/// Operation for HKDF operations
76pub struct HkdfOperation<'a, H: HashFunction, const S: usize = 16> {
77    #[allow(dead_code)] // Kept for potential future use and API consistency
78    kdf: &'a Hkdf<H, S>,
79    ikm: Option<&'a [u8]>,
80    salt: Option<&'a [u8]>,
81    info: Option<&'a [u8]>,
82    length: usize,
83}
84
85impl<'a, H: HashFunction + Clone, const S: usize> KdfOperation<'a, HkdfAlgorithm<H>>
86    for HkdfOperation<'a, H, S>
87where
88    Salt<S>: HkdfCompatible,
89{
90    fn with_ikm(mut self, ikm: &'a [u8]) -> Self {
91        self.ikm = Some(ikm);
92        self
93    }
94
95    fn with_salt(mut self, salt: &'a [u8]) -> Self {
96        self.salt = Some(salt);
97        self
98    }
99
100    fn with_info(mut self, info: &'a [u8]) -> Self {
101        self.info = Some(info);
102        self
103    }
104
105    fn with_output_length(mut self, length: usize) -> Self {
106        self.length = length;
107        self
108    }
109
110    fn derive(self) -> Result<Vec<u8>> {
111        let ikm = self
112            .ikm
113            .ok_or_else(|| Error::param("ikm", "Input keying material is required"))?;
114
115        let salt_bytes = self.salt;
116        let info_bytes = self.info;
117
118        // Fix: Convert Zeroizing<Vec<u8>> to Vec<u8>
119        Hkdf::<H, S>::derive(salt_bytes, ikm, info_bytes, self.length).map(|result| result.to_vec())
120    }
121
122    fn derive_array<const N: usize>(self) -> Result<[u8; N]> {
123        // Ensure the requested size matches
124        validate::length("HKDF output", self.length, N)?;
125
126        let vec = self.derive()?;
127
128        // Convert to fixed-size array
129        let mut array = [0u8; N];
130        array.copy_from_slice(&vec);
131        Ok(array)
132    }
133}
134
135impl<H: HashFunction + Clone, const S: usize> Hkdf<H, S>
136where
137    Salt<S>: HkdfCompatible,
138{
139    /// HKDF-Extract
140    pub fn extract(salt: Option<&[u8]>, ikm: &[u8]) -> Result<Zeroizing<Vec<u8>>> {
141        // Convert salt to owned Vec to wrap in EphemeralSecret
142        let salt_vec = salt.unwrap_or(&[]).to_vec();
143        let secure_salt = EphemeralSecret::new(salt_vec);
144
145        // Use HMAC with secure salt
146        let result = Hmac::<H>::mac(&secure_salt, ikm)?;
147        Ok(Zeroizing::new(result))
148    }
149
150    /// HKDF-Expand
151    pub fn expand(prk: &[u8], info: Option<&[u8]>, length: usize) -> Result<Zeroizing<Vec<u8>>> {
152        let hash_len = H::output_size();
153        let max_len = 255 * hash_len;
154
155        // Specified max-length check (length is public)
156        validate::max_length("HKDF-Expand output", length, max_len)?;
157
158        // PRK length check (must be at least one hash block)
159        validate::min_length("PRK for HKDF-Expand", prk.len(), hash_len)?;
160
161        // Number of blocks needed - FIXED: Using div_ceil
162        let n = length.div_ceil(hash_len);
163
164        // Pre-allocate OKM buffer and temporary block buffer
165        let mut okm = Zeroizing::new(vec![0u8; n * hash_len]);
166        let mut t_buf = Zeroizing::new(vec![0u8; hash_len]);
167        let info_bytes = info.unwrap_or(&[]);
168
169        // Convert PRK to owned Vec to wrap in EphemeralSecret
170        let prk_vec = prk.to_vec();
171        let secure_prk = EphemeralSecret::new(prk_vec);
172
173        for i in 1..=n {
174            let mut hmac = Hmac::<H>::new(&secure_prk)?;
175            if i > 1 {
176                // feed previous block for iterations > 1
177                hmac.update(&t_buf)?;
178            }
179            hmac.update(info_bytes)?;
180            hmac.update(&[i as u8])?;
181            let block = hmac.finalize()?;
182            t_buf.copy_from_slice(&block);
183            let start = (i - 1) * hash_len;
184            okm[start..start + hash_len].copy_from_slice(&t_buf);
185        }
186
187        okm.truncate(length);
188        Ok(okm)
189    }
190
191    /// Full HKDF (Extract + Expand) with warm-up
192    pub fn derive(
193        salt: Option<&[u8]>,
194        ikm: &[u8],
195        info: Option<&[u8]>,
196        length: usize,
197    ) -> Result<Zeroizing<Vec<u8>>> {
198        let _ = Hmac::<H>::new(&[])?; // warm-up
199
200        // Extract phase - produces PRK
201        let prk = Self::extract(salt, ikm)?;
202
203        // Expand phase - uses PRK to generate OKM
204        Self::expand(&prk, info, length)
205    }
206}
207
208impl<H: HashFunction, const S: usize> ParamProvider for Hkdf<H, S>
209where
210    Salt<S>: HkdfCompatible,
211{
212    type Params = HkdfParams<S>;
213    fn with_params(params: Self::Params) -> Self {
214        Hkdf {
215            _hash_type: PhantomData,
216            params,
217        }
218    }
219    fn params(&self) -> &Self::Params {
220        &self.params
221    }
222    fn set_params(&mut self, params: Self::Params) {
223        self.params = params;
224    }
225}
226
227impl<H: HashFunction + Clone, const S: usize> KeyDerivationFunction for Hkdf<H, S>
228where
229    Salt<S>: HkdfCompatible,
230{
231    type Algorithm = HkdfAlgorithm<H>;
232    type Salt = Salt<S>;
233
234    fn new() -> Self {
235        Hkdf {
236            _hash_type: PhantomData,
237            params: HkdfParams::default(),
238        }
239    }
240
241    fn derive_key(
242        &self,
243        input: &[u8],
244        salt: Option<&[u8]>,
245        info: Option<&[u8]>,
246        length: usize,
247    ) -> Result<Vec<u8>> {
248        let effective_salt = salt.or_else(|| self.params.salt.as_ref().map(|s| s.as_ref()));
249        let effective_info = info.or_else(|| self.params.info.as_ref().map(|i| i.as_slice()));
250        let result = Self::derive(effective_salt, input, effective_info, length)?;
251        Ok(result.to_vec())
252    }
253
254    // FIXED: Elided lifetime
255    fn builder(&self) -> impl KdfOperation<'_, Self::Algorithm> {
256        HkdfOperation {
257            kdf: self,
258            ikm: None,
259            salt: None,
260            info: None,
261            length: Self::Algorithm::DEFAULT_OUTPUT_SIZE,
262        }
263    }
264
265    fn generate_salt<R: RngCore + CryptoRng>(rng: &mut R) -> Self::Salt {
266        Salt::random_with_size(rng, Self::Algorithm::MIN_SALT_SIZE).expect("Salt generation failed")
267    }
268
269    // Changed from instance method to static method
270    fn security_level() -> SecurityLevel {
271        match H::output_size() * 8 {
272            bits if bits >= 512 => SecurityLevel::L256,
273            bits if bits >= 384 => SecurityLevel::L192,
274            bits if bits >= 256 => SecurityLevel::L128,
275            bits => SecurityLevel::Custom(bits as u32 / 2),
276        }
277    }
278}
279
280impl<H: HashFunction + Clone, const S: usize> SecureZeroingType for Hkdf<H, S>
281where
282    Salt<S>: HkdfCompatible,
283{
284    fn zeroed() -> Self {
285        Self {
286            _hash_type: PhantomData,
287            params: HkdfParams::default(),
288        }
289    }
290
291    fn secure_clone(&self) -> Self {
292        self.clone()
293    }
294}
295
// Unit tests for this module, compiled only in test builds. The module body lives
// in a separate `tests.rs` (or `tests/mod.rs`) file next to this one.
#[cfg(test)]
mod tests;