1use crate::{PqcError, Result};
5use hkdf::Hkdf;
6use serde::{Deserialize, Serialize};
7use sha2::{Sha256, Sha384, Sha512};
8use sha3::{Sha3_256, Sha3_384, Sha3_512};
9
/// Key-derivation algorithm selector used by [`HkdfKdf`].
///
/// The value is `Copy` and serde-serializable, so it can be stored in
/// configuration or sent alongside other negotiated parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum KdfAlgorithm {
    /// HKDF instantiated with SHA-256.
    HkdfSha256,
    /// HKDF instantiated with SHA-384 (the `Default` value).
    HkdfSha384,
    /// HKDF instantiated with SHA-512.
    HkdfSha512,
    /// HKDF instantiated with SHA3-256.
    HkdfSha3_256,
    /// HKDF instantiated with SHA3-384.
    HkdfSha3_384,
    /// HKDF instantiated with SHA3-512.
    HkdfSha3_512,
    /// BLAKE3-based derivation: `blake3::derive_key` for the extract step and
    /// a keyed BLAKE3 hasher in XOF mode for the expand step.
    Blake3Kdf,
}
28
29impl Default for KdfAlgorithm {
30 fn default() -> Self {
31 Self::HkdfSha384 }
33}
34
/// Extract-and-expand key-derivation interface.
pub trait Kdf {
    /// Extract step: condenses the input keying material `ikm`, together with
    /// an optional `salt`, into a pseudorandom key (PRK) returned as bytes.
    fn extract(&self, salt: Option<&[u8]>, ikm: &[u8]) -> Vec<u8>;

    /// Expand step: stretches `prk` into `okm_len` bytes of output keying
    /// material bound to the context string `info`.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation cannot expand `prk` to the
    /// requested length.
    fn expand(&self, prk: &[u8], info: &[u8], okm_len: usize) -> Result<Vec<u8>>;

    /// One-call derivation of `okm_len` bytes from `salt`, `ikm` and `info`.
    ///
    /// # Errors
    ///
    /// Returns an error when the derivation fails.
    fn derive(
        &self,
        salt: Option<&[u8]>,
        ikm: &[u8],
        info: &[u8],
        okm_len: usize,
    ) -> Result<Vec<u8>>;
}
52
53pub struct HkdfKdf {
55 algorithm: KdfAlgorithm,
56}
57
58impl HkdfKdf {
59 pub fn new(algorithm: KdfAlgorithm) -> Self {
60 Self { algorithm }
61 }
62}
63
64impl Kdf for HkdfKdf {
65 fn extract(&self, salt: Option<&[u8]>, ikm: &[u8]) -> Vec<u8> {
66 match self.algorithm {
67 KdfAlgorithm::HkdfSha256 => {
68 let (prk, _) = Hkdf::<Sha256>::extract(salt, ikm);
69 prk.to_vec()
70 }
71 KdfAlgorithm::HkdfSha384 => {
72 let (prk, _) = Hkdf::<Sha384>::extract(salt, ikm);
73 prk.to_vec()
74 }
75 KdfAlgorithm::HkdfSha512 => {
76 let (prk, _) = Hkdf::<Sha512>::extract(salt, ikm);
77 prk.to_vec()
78 }
79 KdfAlgorithm::HkdfSha3_256 => {
80 let (prk, _) = Hkdf::<Sha3_256>::extract(salt, ikm);
81 prk.to_vec()
82 }
83 KdfAlgorithm::HkdfSha3_384 => {
84 let (prk, _) = Hkdf::<Sha3_384>::extract(salt, ikm);
85 prk.to_vec()
86 }
87 KdfAlgorithm::HkdfSha3_512 => {
88 let (prk, _) = Hkdf::<Sha3_512>::extract(salt, ikm);
89 prk.to_vec()
90 }
91 KdfAlgorithm::Blake3Kdf => {
92 let key = blake3::derive_key(
94 salt.map(|s| std::str::from_utf8(s).unwrap_or("hanzo-pqc"))
95 .unwrap_or("hanzo-pqc"),
96 ikm,
97 );
98 key.to_vec()
99 }
100 }
101 }
102
103 fn expand(&self, prk: &[u8], info: &[u8], okm_len: usize) -> Result<Vec<u8>> {
104 let mut okm = vec![0u8; okm_len];
105
106 match self.algorithm {
107 KdfAlgorithm::HkdfSha256 => {
108 let hk = Hkdf::<Sha256>::from_prk(prk)
109 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA256".into()))?;
110 hk.expand(info, &mut okm)
111 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
112 }
113 KdfAlgorithm::HkdfSha384 => {
114 let hk = Hkdf::<Sha384>::from_prk(prk)
115 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA384".into()))?;
116 hk.expand(info, &mut okm)
117 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
118 }
119 KdfAlgorithm::HkdfSha512 => {
120 let hk = Hkdf::<Sha512>::from_prk(prk)
121 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA512".into()))?;
122 hk.expand(info, &mut okm)
123 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
124 }
125 KdfAlgorithm::HkdfSha3_256 => {
126 let hk = Hkdf::<Sha3_256>::from_prk(prk)
127 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA3-256".into()))?;
128 hk.expand(info, &mut okm)
129 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
130 }
131 KdfAlgorithm::HkdfSha3_384 => {
132 let hk = Hkdf::<Sha3_384>::from_prk(prk)
133 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA3-384".into()))?;
134 hk.expand(info, &mut okm)
135 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
136 }
137 KdfAlgorithm::HkdfSha3_512 => {
138 let hk = Hkdf::<Sha3_512>::from_prk(prk)
139 .map_err(|_| PqcError::KdfError("Invalid PRK length for SHA3-512".into()))?;
140 hk.expand(info, &mut okm)
141 .map_err(|_| PqcError::KdfError("HKDF expand failed".into()))?;
142 }
143 KdfAlgorithm::Blake3Kdf => {
144 let mut hasher = blake3::Hasher::new_keyed(
146 &<[u8; 32]>::try_from(&prk[..32])
147 .map_err(|_| PqcError::KdfError("BLAKE3 requires 32-byte key".into()))?,
148 );
149 hasher.update(info);
150 let mut output = hasher.finalize_xof();
151 output.fill(&mut okm);
152 }
153 }
154
155 Ok(okm)
156 }
157
158 fn derive(
159 &self,
160 salt: Option<&[u8]>,
161 ikm: &[u8],
162 info: &[u8],
163 okm_len: usize,
164 ) -> Result<Vec<u8>> {
165 let prk = self.extract(salt, ikm);
166 self.expand(&prk, info, okm_len)
167 }
168}
169
170pub fn combine_shared_secrets(
173 kdf: &impl Kdf,
174 secrets: &[&[u8]],
175 context: &[u8],
176 output_len: usize,
177) -> Result<Vec<u8>> {
178 let mut combined = Vec::new();
180 for secret in secrets {
181 combined.extend_from_slice(&(secret.len() as u32).to_be_bytes());
182 combined.extend_from_slice(secret);
183 }
184
185 kdf.derive(None, &combined, context, output_len)
187}
188
189pub fn domain_separate(
191 kdf: &impl Kdf,
192 key_material: &[u8],
193 domain: &str,
194 output_len: usize,
195) -> Result<Vec<u8>> {
196 let info = format!("hanzo-pqc-v1|{}", domain);
197 kdf.expand(key_material, info.as_bytes(), output_len)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// HKDF-SHA384 returns the requested number of bytes and is deterministic
    /// for fixed inputs.
    #[test]
    fn test_hkdf_sha384() {
        let salt: &[u8] = b"salt";
        let ikm: &[u8] = b"input keying material";
        let info: &[u8] = b"hanzo-test-v1";

        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);

        let first = kdf.derive(Some(salt), ikm, info, 64).unwrap();
        assert_eq!(first.len(), 64);

        let second = kdf.derive(Some(salt), ikm, info, 64).unwrap();
        assert_eq!(first, second);
    }

    /// Combining two 32-byte secrets yields a key of the requested length.
    #[test]
    fn test_combine_secrets() {
        let kdf = HkdfKdf::new(KdfAlgorithm::HkdfSha384);

        let secret_a = [1u8; 32];
        let secret_b = [2u8; 32];
        let parts: [&[u8]; 2] = [&secret_a, &secret_b];

        let okm = combine_shared_secrets(&kdf, &parts, b"hanzo-hybrid-v1", 48).unwrap();

        assert_eq!(okm.len(), 48);
    }
}