1use hkdf::Hkdf;
5use sha2::{Sha256, Sha384, Sha512};
6use sha3::{Sha3_256, Sha3_384, Sha3_512};
7use serde::{Deserialize, Serialize};
8use crate::{PqcError, Result};
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum KdfAlgorithm {
13 HkdfSha256,
15 HkdfSha384,
17 HkdfSha512,
19 HkdfSha3_256,
21 HkdfSha3_384,
23 HkdfSha3_512,
25 Blake3Kdf,
27}
28
29impl Default for KdfAlgorithm {
30 fn default() -> Self {
31 Self::HkdfSha384 }
33}
34
35pub trait Kdf {
37 fn extract(&self, salt: Option<&[u8]>, ikm: &[u8]) -> Vec<u8>;
39
40 fn expand(&self, prk: &[u8], info: &[u8], okm_len: usize) -> Result<Vec<u8>>;
42
43 fn derive(&self, salt: Option<&[u8]>, ikm: &[u8], info: &[u8], okm_len: usize) -> Result<Vec<u8>>;
45}
46
47pub struct HkdfKdf {
49 algorithm: KdfAlgorithm,
50}
51
52impl HkdfKdf {
53 pub fn new(algorithm: KdfAlgorithm) -> Self {
54 Self { algorithm }
55 }
56}
57
58impl Kdf for HkdfKdf {
59 fn extract(&self, salt: Option<&[u8]>, ikm: &[u8]) -> Vec<u8> {
60 match self.algorithm {
61 KdfAlgorithm::HkdfSha256 => {
62 let (prk, _) = Hkdf::<Sha256>::extract(salt, ikm);
63 prk.to_vec()
64 }
65 KdfAlgorithm::HkdfSha384 => {
66 let (prk, _) = Hkdf::<Sha384>::extract(salt, ikm);
67 prk.to_vec()
68 }
69 KdfAlgorithm::HkdfSha512 => {
70 let (prk, _) = Hkdf::<Sha512>::extract(salt, ikm);
71 prk.to_vec()
72 }
73 KdfAlgorithm::HkdfSha3_256 => {
74 let (prk, _) = Hkdf::<Sha3_256>::extract(salt, ikm);
75 prk.to_vec()
76 }
77 KdfAlgorithm::HkdfSha3_384 => {
78 let (prk, _) = Hkdf::<Sha3_384>::extract(salt, ikm);
79 prk.to_vec()
80 }
81 KdfAlgorithm::HkdfSha3_512 => {
82 let (prk, _) = Hkdf::<Sha3_512>::extract(salt, ikm);
83 prk.to_vec()
84 }
85 KdfAlgorithm::Blake3Kdf => {
86 let key = blake3::derive_key(
88 salt.map(|s| std::str::from_utf8(s).unwrap_or("hanzo-pqc")).unwrap_or("hanzo-pqc"),
89 ikm,
90 );
91 key.to_vec()
92 }
93 }
94 }
95
96 fn expand(&self, prk: &[u8], info: &[u8], okm_len: usize) -> Result<Vec<u8>> {
97 let mut okm = vec![0u8; okm_len];
98
99 match self.algorithm {
100 KdfAlgorithm::HkdfSha256 => {
101 let hk = Hkdf::<Sha256>::from_prk(prk)
102 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA256".into()))?;
103 hk.expand(info, &mut okm)
104 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
105 }
106 KdfAlgorithm::HkdfSha384 => {
107 let hk = Hkdf::<Sha384>::from_prk(prk)
108 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA384".into()))?;
109 hk.expand(info, &mut okm)
110 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
111 }
112 KdfAlgorithm::HkdfSha512 => {
113 let hk = Hkdf::<Sha512>::from_prk(prk)
114 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA512".into()))?;
115 hk.expand(info, &mut okm)
116 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
117 }
118 KdfAlgorithm::HkdfSha3_256 => {
119 let hk = Hkdf::<Sha3_256>::from_prk(prk)
120 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA3-256".into()))?;
121 hk.expand(info, &mut okm)
122 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
123 }
124 KdfAlgorithm::HkdfSha3_384 => {
125 let hk = Hkdf::<Sha3_384>::from_prk(prk)
126 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA3-384".into()))?;
127 hk.expand(info, &mut okm)
128 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
129 }
130 KdfAlgorithm::HkdfSha3_512 => {
131 let hk = Hkdf::<Sha3_512>::from_prk(prk)
132 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA3-512".into()))?;
133 hk.expand(info, &mut okm)
134 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
135 }
136 KdfAlgorithm::Blake3Kdf => {
137 let mut hasher = blake3::Hasher::new_keyed(
139 &<[u8; 32]>::try_from(&prk[..32])
140 .map_err(|_| PqcError::KdfError("BLAKE3 requires 32-byte key".into()))?
141 );
142 hasher.update(info);
143 let mut output = hasher.finalize_xof();
144 output.fill(&mut okm);
145 }
146 }
147
148 Ok(okm)
149 }
150
151 fn derive(&self, salt: Option<&[u8]>, ikm: &[u8], info: &[u8], okm_len: usize) -> Result<Vec<u8>> {
152 let prk = self.extract(salt, ikm);
153 self.expand(&prk, info, okm_len)
154 }
155}
156
157pub fn combine_shared_secrets(
160 kdf: &impl Kdf,
161 secrets: &[&[u8]],
162 context: &[u8],
163 output_len: usize,
164) -> Result<Vec<u8>> {
165 let mut combined = Vec::new();
167 for secret in secrets {
168 combined.extend_from_slice(&(secret.len() as u32).to_be_bytes());
169 combined.extend_from_slice(secret);
170 }
171
172 kdf.derive(None, &combined, context, output_len)
174}
175
176pub fn domain_separate(
178 kdf: &impl Kdf,
179 key_material: &[u8],
180 domain: &str,
181 output_len: usize,
182) -> Result<Vec<u8>> {
183 let info = format!("hanzo-pqc-v1|{domain}");
184 kdf.expand(key_material, info.as_bytes(), output_len)
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a lowercase/uppercase hex string into bytes (test helper; panics on bad input).
    fn hex(s: &str) -> Vec<u8> {
        assert!(s.len() % 2 == 0, "odd-length hex string");
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).expect("invalid hex digit"))
            .collect()
    }

    #[test]
    fn test_hkdf_sha384() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);

        let ikm = b"input keying material";
        let salt = b"salt";
        let info = b"hanzo-test-v1";

        let okm = kdf.derive(Some(salt), ikm, info, 64).unwrap();
        assert_eq!(okm.len(), 64);

        let okm2 = kdf.derive(Some(salt), ikm, info, 64).unwrap();
        assert_eq!(okm, okm2);
    }

    /// Known-answer test: RFC 5869, Appendix A.1 (Test Case 1, SHA-256).
    #[test]
    fn test_hkdf_sha256_rfc5869_case1() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha256);
        let ikm = vec![0x0bu8; 22];
        let salt: Vec<u8> = (0x00u8..=0x0c).collect();
        let info: Vec<u8> = (0xf0u8..=0xf9).collect();

        let prk = kdf.extract(Some(salt.as_slice()), &ikm);
        assert_eq!(
            prk,
            hex("077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5")
        );

        let expected_okm = hex(
            "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865",
        );
        assert_eq!(kdf.expand(&prk, &info, 42).unwrap(), expected_okm);
        assert_eq!(
            kdf.derive(Some(salt.as_slice()), &ikm, &info, 42).unwrap(),
            expected_okm
        );
    }

    #[test]
    fn test_hkdf_output_length_limit_and_short_prk() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);
        let prk = kdf.extract(None, b"ikm");

        // RFC 5869: L <= 255 * HashLen, and SHA-384's HashLen is 48.
        assert_eq!(kdf.expand(&prk, b"info", 255 * 48).unwrap().len(), 255 * 48);
        assert!(kdf.expand(&prk, b"info", 255 * 48 + 1).is_err());

        // A PRK shorter than the hash output is rejected.
        assert!(kdf.expand(&[0u8; 16], b"info", 32).is_err());
    }

    #[test]
    fn test_info_separates_outputs() {
        let algorithms = [
            KdfAlgorithm::HkdfSha256,
            KdfAlgorithm::HkdfSha384,
            KdfAlgorithm::HkdfSha512,
            KdfAlgorithm::HkdfSha3_256,
            KdfAlgorithm::HkdfSha3_384,
            KdfAlgorithm::HkdfSha3_512,
            KdfAlgorithm::Blake3Kdf,
        ];
        for algorithm in algorithms {
            let kdf = HkdfKdf::new(algorithm);
            let salt: &[u8] = b"salt";
            let a = kdf.derive(Some(salt), b"ikm", b"context-a", 32).unwrap();
            let b = kdf.derive(Some(salt), b"ikm", b"context-b", 32).unwrap();
            assert_eq!(a.len(), 32, "{algorithm:?}");
            assert_ne!(a, b, "{algorithm:?}");
        }
    }

    #[test]
    fn test_combine_secrets() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);

        let secret1 = vec![1u8; 32];
        let secret2 = vec![2u8; 32];
        let combined =
            combine_shared_secrets(&kdf, &[&secret1, &secret2], b"hanzo-hybrid-v1", 48).unwrap();

        assert_eq!(combined.len(), 48);
    }

    #[test]
    fn test_combine_secrets_is_split_sensitive() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);
        let (ab, c, a, bc): (&[u8], &[u8], &[u8], &[u8]) = (b"ab", b"c", b"a", b"bc");

        let ab_c = combine_shared_secrets(&kdf, &[ab, c], b"ctx", 32).unwrap();
        let a_bc = combine_shared_secrets(&kdf, &[a, bc], b"ctx", 32).unwrap();
        assert_ne!(ab_c, a_bc);
    }

    #[test]
    fn test_domain_separate() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);
        let prk = kdf.extract(None, b"ikm");

        let enc = domain_separate(&kdf, &prk, "encryption", 32).unwrap();
        let mac = domain_separate(&kdf, &prk, "mac", 32).unwrap();
        assert_eq!(enc.len(), 32);
        assert_ne!(enc, mac);
    }
}