Skip to main content

hanzo_crypto/
kdf.rs

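//! Key Derivation Functions (KDF)
//! SP 800-56C compliant HKDF (over SHA-2 / SHA-3) and SP 800-108 compliant KDF.
//! The BLAKE3-based variant is provided for performance and is not a NIST-approved KDF.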
3
4use crate::{PqcError, Result};
5use hkdf::Hkdf;
6use serde::{Deserialize, Serialize};
7use sha2::{Sha256, Sha384, Sha512};
8use sha3::{Sha3_256, Sha3_384, Sha3_512};
9
10/// KDF algorithms (SP 800-56C and SP 800-108 compliant)
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum KdfAlgorithm {
13    /// HKDF with SHA-256 (128-bit security)
14    HkdfSha256,
15    /// HKDF with SHA-384 (192-bit security) - RECOMMENDED for ML-KEM-768
16    HkdfSha384,
17    /// HKDF with SHA-512 (256-bit security)
18    HkdfSha512,
19    /// HKDF with SHA3-256 (128-bit security)
20    HkdfSha3_256,
21    /// HKDF with SHA3-384 (192-bit security)
22    HkdfSha3_384,
23    /// HKDF with SHA3-512 (256-bit security)
24    HkdfSha3_512,
25    /// BLAKE3 KDF (256-bit security)
26    Blake3Kdf,
27}
28
29impl Default for KdfAlgorithm {
30    fn default() -> Self {
31        Self::HkdfSha384 // Matches ML-KEM-768 security level
32    }
33}
34
35/// KDF trait for key derivation operations
36pub trait Kdf {
37    /// Extract a pseudorandom key from input keying material
38    fn extract(&self, salt: Option<&[u8]>, ikm: &[u8]) -> Vec<u8>;
39
40    /// Expand a pseudorandom key to desired length
41    fn expand(&self, prk: &[u8], info: &[u8], okm_len: usize) -> Result<Vec<u8>>;
42
43    /// Combined extract-and-expand operation
44    fn derive(
45        &self,
46        salt: Option<&[u8]>,
47        ikm: &[u8],
48        info: &[u8],
49        okm_len: usize,
50    ) -> Result<Vec<u8>>;
51}
52
53/// Generic HKDF implementation
54pub struct HkdfKdf {
55    algorithm: KdfAlgorithm,
56}
57
58impl HkdfKdf {
59    pub fn new(algorithm: KdfAlgorithm) -> Self {
60        Self { algorithm }
61    }
62}
63
64impl Kdf for HkdfKdf {
65    fn extract(&self, salt: Option<&[u8]>, ikm: &[u8]) -> Vec<u8> {
66        match self.algorithm {
67            KdfAlgorithm::HkdfSha256 => {
68                let (prk, _) = Hkdf::<Sha256>::extract(salt, ikm);
69                prk.to_vec()
70            }
71            KdfAlgorithm::HkdfSha384 => {
72                let (prk, _) = Hkdf::<Sha384>::extract(salt, ikm);
73                prk.to_vec()
74            }
75            KdfAlgorithm::HkdfSha512 => {
76                let (prk, _) = Hkdf::<Sha512>::extract(salt, ikm);
77                prk.to_vec()
78            }
79            KdfAlgorithm::HkdfSha3_256 => {
80                let (prk, _) = Hkdf::<Sha3_256>::extract(salt, ikm);
81                prk.to_vec()
82            }
83            KdfAlgorithm::HkdfSha3_384 => {
84                let (prk, _) = Hkdf::<Sha3_384>::extract(salt, ikm);
85                prk.to_vec()
86            }
87            KdfAlgorithm::HkdfSha3_512 => {
88                let (prk, _) = Hkdf::<Sha3_512>::extract(salt, ikm);
89                prk.to_vec()
90            }
91            KdfAlgorithm::Blake3Kdf => {
92                // BLAKE3 has its own KDF mode
93                let key = blake3::derive_key(
94                    salt.map(|s| std::str::from_utf8(s).unwrap_or("hanzo-pqc"))
95                        .unwrap_or("hanzo-pqc"),
96                    ikm,
97                );
98                key.to_vec()
99            }
100        }
101    }
102
103    fn expand(&self, prk: &[u8], info: &[u8], okm_len: usize) -> Result<Vec<u8>> {
104        let mut okm = vec![0u8; okm_len];
105
106        match self.algorithm {
107            KdfAlgorithm::HkdfSha256 => {
108                let hk = Hkdf::<Sha256>::from_prk(prk)
109                    .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA256".into()))?;
110                hk.expand(info, &mut okm)
111                    .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
112            }
113            KdfAlgorithm::HkdfSha384 => {
114                let hk = Hkdf::<Sha384>::from_prk(prk)
115                    .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA384".into()))?;
116                hk.expand(info, &mut okm)
117                    .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
118            }
119            KdfAlgorithm::HkdfSha512 => {
120                let hk = Hkdf::<Sha512>::from_prk(prk)
121                    .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA512".into()))?;
122                hk.expand(info, &mut okm)
123                    .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
124            }
125            KdfAlgorithm::HkdfSha3_256 => {
126                let hk = Hkdf::<Sha3_256>::from_prk(prk)
127                    .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA3-256".into()))?;
128                hk.expand(info, &mut okm)
129                    .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
130            }
131            KdfAlgorithm::HkdfSha3_384 => {
132                let hk = Hkdf::<Sha3_384>::from_prk(prk)
133                    .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA3-384".into()))?;
134                hk.expand(info, &mut okm)
135                    .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
136            }
137            KdfAlgorithm::HkdfSha3_512 => {
138                let hk = Hkdf::<Sha3_512>::from_prk(prk)
139                    .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA3-512".into()))?;
140                hk.expand(info, &mut okm)
141                    .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
142            }
143            KdfAlgorithm::Blake3Kdf => {
144                // BLAKE3 XOF mode for expansion
145                let mut hasher = blake3::Hasher::new_keyed(
146                    &<[u8; 32]>::try_from(&prk[..32])
147                        .map_err(|_| PqcError::KdfError("BLAKE3 requires 32-byte key".into()))?,
148                );
149                hasher.update(info);
150                let mut output = hasher.finalize_xof();
151                output.fill(&mut okm);
152            }
153        }
154
155        Ok(okm)
156    }
157
158    fn derive(
159        &self,
160        salt: Option<&[u8]>,
161        ikm: &[u8],
162        info: &[u8],
163        okm_len: usize,
164    ) -> Result<Vec<u8>> {
165        let prk = self.extract(salt, ikm);
166        self.expand(&prk, info, okm_len)
167    }
168}
169
170/// Combine multiple shared secrets (for hybrid mode)
171/// Per SP 800-56C Rev 2, Section 5.9.3
172pub fn combine_shared_secrets(
173    kdf: &impl Kdf,
174    secrets: &[&[u8]],
175    context: &[u8],
176    output_len: usize,
177) -> Result<Vec<u8>> {
178    // Concatenate all secrets with length prefixes
179    let mut combined = Vec::new();
180    for secret in secrets {
181        combined.extend_from_slice(&(secret.len() as u32).to_be_bytes());
182        combined.extend_from_slice(secret);
183    }
184
185    // Derive final key material with context
186    kdf.derive(None, &combined, context, output_len)
187}
188
189/// Domain separation for different protocol contexts
190pub fn domain_separate(
191    kdf: &impl Kdf,
192    key_material: &[u8],
193    domain: &str,
194    output_len: usize,
195) -> Result<Vec<u8>> {
196    let info = format!("hanzo-pqc-v1|{}", domain);
197    kdf.expand(key_material, info.as_bytes(), output_len)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a lowercase/uppercase hex string (test vectors only; panics if malformed).
    fn hex(s: &str) -> Vec<u8> {
        assert!(s.len() % 2 == 0, "odd-length hex string");
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).expect("invalid hex digit"))
            .collect()
    }

    #[test]
    fn test_hkdf_sha384() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);

        let ikm = b"input keying material";
        let salt = b"salt";
        let info = b"hanzo-test-v1";

        let okm = kdf.derive(Some(salt), ikm, info, 64).unwrap();
        assert_eq!(okm.len(), 64);

        // Verify deterministic
        let okm2 = kdf.derive(Some(salt), ikm, info, 64).unwrap();
        assert_eq!(okm, okm2);
    }

    /// Known-answer test: RFC 5869, Appendix A.1 (Test Case 1, HKDF-SHA-256).
    #[test]
    fn test_hkdf_sha256_rfc5869_case1() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha256);
        let ikm = [0x0bu8; 22];
        let salt = hex("000102030405060708090a0b0c");
        let info = hex("f0f1f2f3f4f5f6f7f8f9");

        let prk = kdf.extract(Some(salt.as_slice()), &ikm);
        assert_eq!(
            prk,
            hex("077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5")
        );

        let okm = kdf.derive(Some(salt.as_slice()), &ikm, &info, 42).unwrap();
        assert_eq!(
            okm,
            hex(concat!(
                "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db0",
                "2d56ecc4c5bf34007208d5b887185865"
            ))
        );
    }

    #[test]
    fn test_expand_rejects_short_prk() {
        // An HKDF-SHA384 PRK must be at least 48 bytes long.
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);
        assert!(kdf.expand(&[0u8; 16], b"info", 32).is_err());
    }

    #[test]
    fn test_expand_output_length_limit() {
        // HKDF caps output at 255 * HashLen (255 * 32 for SHA-256).
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha256);
        let prk = kdf.extract(None, b"ikm");
        assert!(kdf.expand(&prk, b"info", 255 * 32).is_ok());
        assert!(kdf.expand(&prk, b"info", 255 * 32 + 1).is_err());
    }

    #[test]
    fn test_combine_secrets() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);

        let secret1 = vec![1u8; 32]; // ML-KEM shared secret
        let secret2 = vec![2u8; 32]; // X25519 shared secret

        let combined =
            combine_shared_secrets(&kdf, &[&secret1, &secret2], b"hanzo-hybrid-v1", 48).unwrap();

        assert_eq!(combined.len(), 48);
    }

    #[test]
    fn test_combine_secrets_is_order_sensitive() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);
        let secret1 = vec![1u8; 32];
        let secret2 = vec![2u8; 32];

        let ab = combine_shared_secrets(&kdf, &[&secret1, &secret2], b"ctx", 32).unwrap();
        let ba = combine_shared_secrets(&kdf, &[&secret2, &secret1], b"ctx", 32).unwrap();
        assert_ne!(ab, ba);
    }

    #[test]
    fn test_domain_separate_distinct_domains() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);
        let prk = kdf.extract(Some(b"salt"), b"ikm");

        let enc = domain_separate(&kdf, &prk, "encryption", 32).unwrap();
        let auth = domain_separate(&kdf, &prk, "authentication", 32).unwrap();
        assert_eq!(enc.len(), 32);
        assert_ne!(enc, auth);
    }
}