1use hkdf::Hkdf;
6use sha2::Sha256;
7use zeroize::{Zeroize, ZeroizeOnDrop};
8
9use crate::{CryptoError, Result, CipherSuite, DataChannelKey};
10
11#[derive(ZeroizeOnDrop)]
13pub struct KeyMaterial {
14 pub client_write_key: [u8; 32],
16 pub server_write_key: [u8; 32],
18 pub client_hmac_key: [u8; 32],
20 pub server_hmac_key: [u8; 32],
22}
23
24#[inline]
32pub fn derive_keys(
33 shared_secret: &[u8],
34 client_random: &[u8; 32],
35 server_random: &[u8; 32],
36 info: &[u8],
37) -> Result<KeyMaterial> {
38 let mut salt = [0u8; 64];
40 salt[..32].copy_from_slice(client_random);
41 salt[32..].copy_from_slice(server_random);
42
43 let hkdf = Hkdf::<Sha256>::new(Some(&salt), shared_secret);
44
45 let mut okm = [0u8; 128];
47 hkdf.expand(info, &mut okm)
48 .map_err(|_| CryptoError::KeyDerivationFailed("HKDF expansion failed"))?;
49
50 let mut material = KeyMaterial {
51 client_write_key: [0u8; 32],
52 server_write_key: [0u8; 32],
53 client_hmac_key: [0u8; 32],
54 server_hmac_key: [0u8; 32],
55 };
56
57 material.client_write_key.copy_from_slice(&okm[0..32]);
58 material.server_write_key.copy_from_slice(&okm[32..64]);
59 material.client_hmac_key.copy_from_slice(&okm[64..96]);
60 material.server_hmac_key.copy_from_slice(&okm[96..128]);
61
62 okm.zeroize();
64 salt.zeroize();
65
66 Ok(material)
67}
68
69impl KeyMaterial {
70 pub fn client_data_key(&self, suite: CipherSuite) -> DataChannelKey {
72 DataChannelKey::new(self.client_write_key, suite)
73 }
74
75 pub fn server_data_key(&self, suite: CipherSuite) -> DataChannelKey {
77 DataChannelKey::new(self.server_write_key, suite)
78 }
79}
80
81pub fn derive_single_key(
83 ikm: &[u8],
84 salt: &[u8],
85 info: &[u8],
86) -> Result<[u8; 32]> {
87 let hkdf = Hkdf::<Sha256>::new(Some(salt), ikm);
88 let mut okm = [0u8; 32];
89 hkdf.expand(info, &mut okm)
90 .map_err(|_| CryptoError::KeyDerivationFailed("HKDF expansion failed"))?;
91 Ok(okm)
92}
93
94pub fn openvpn_prf(secret: &[u8], label: &[u8], seed: &[u8], output_len: usize) -> Result<Vec<u8>> {
101 use hmac::{Hmac, Mac};
102
103 type HmacSha256 = Hmac<Sha256>;
104
105 let mut combined_seed = Vec::with_capacity(label.len() + seed.len());
107 combined_seed.extend_from_slice(label);
108 combined_seed.extend_from_slice(seed);
109
110 let mut output = Vec::with_capacity(output_len);
111 let mut a = combined_seed.clone();
112
113 while output.len() < output_len {
114 let mut mac = HmacSha256::new_from_slice(secret)
116 .map_err(|_| CryptoError::KeyDerivationFailed("Invalid HMAC key"))?;
117 mac.update(&a);
118 a = mac.finalize().into_bytes().to_vec();
119
120 let mut mac = HmacSha256::new_from_slice(secret)
122 .map_err(|_| CryptoError::KeyDerivationFailed("Invalid HMAC key"))?;
123 mac.update(&a);
124 mac.update(&combined_seed);
125 output.extend_from_slice(&mac.finalize().into_bytes());
126 }
127
128 output.truncate(output_len);
129 Ok(output)
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a hex string into bytes (test helper; panics on malformed input).
    fn hex(s: &str) -> Vec<u8> {
        assert!(s.len() % 2 == 0, "hex string must have even length");
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).expect("invalid hex digit"))
            .collect()
    }

    #[test]
    fn test_derive_keys() {
        let shared_secret = [0x42u8; 32];
        let client_random = [0x01u8; 32];
        let server_random = [0x02u8; 32];

        let keys = derive_keys(&shared_secret, &client_random, &server_random, b"test").unwrap();

        // Every pair of the four 32-byte HKDF output slices must differ.
        let all = [
            keys.client_write_key,
            keys.server_write_key,
            keys.client_hmac_key,
            keys.server_hmac_key,
        ];
        for i in 0..all.len() {
            for j in (i + 1)..all.len() {
                assert_ne!(all[i], all[j], "keys {} and {} collide", i, j);
            }
        }
    }

    #[test]
    fn test_derive_keys_deterministic() {
        let shared_secret = [0x42u8; 32];
        let client_random = [0x01u8; 32];
        let server_random = [0x02u8; 32];

        let keys1 = derive_keys(&shared_secret, &client_random, &server_random, b"test").unwrap();
        let keys2 = derive_keys(&shared_secret, &client_random, &server_random, b"test").unwrap();

        assert_eq!(keys1.client_write_key, keys2.client_write_key);
        assert_eq!(keys1.server_write_key, keys2.server_write_key);
        assert_eq!(keys1.client_hmac_key, keys2.client_hmac_key);
        assert_eq!(keys1.server_hmac_key, keys2.server_hmac_key);
    }

    #[test]
    fn test_derive_keys_depends_on_inputs() {
        let shared_secret = [0x42u8; 32];
        let client_random = [0x01u8; 32];
        let server_random = [0x02u8; 32];

        let base = derive_keys(&shared_secret, &client_random, &server_random, b"test").unwrap();
        let other_info =
            derive_keys(&shared_secret, &client_random, &server_random, b"other").unwrap();
        let swapped = derive_keys(&shared_secret, &server_random, &client_random, b"test").unwrap();
        let other_secret =
            derive_keys(&[0x43u8; 32], &client_random, &server_random, b"test").unwrap();

        assert_ne!(base.client_write_key, other_info.client_write_key);
        assert_ne!(base.client_write_key, swapped.client_write_key);
        assert_ne!(base.client_write_key, other_secret.client_write_key);
    }

    #[test]
    fn test_derive_single_key_rfc5869_case1() {
        // RFC 5869, Appendix A.1: the first 32 bytes of the 42-byte OKM.
        let ikm = [0x0bu8; 22];
        let salt: Vec<u8> = (0x00u8..=0x0c).collect();
        let info: Vec<u8> = (0xf0u8..=0xf9).collect();

        let key = derive_single_key(&ikm, &salt, &info).unwrap();
        assert_eq!(
            key.to_vec(),
            hex("3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf")
        );
    }

    #[test]
    fn test_openvpn_prf() {
        let secret = b"test secret";
        let label = b"test label";
        let seed = b"test seed";

        let output = openvpn_prf(secret, label, seed, 64).unwrap();
        assert_eq!(output.len(), 64);

        let output2 = openvpn_prf(secret, label, seed, 64).unwrap();
        assert_eq!(output, output2);
    }

    #[test]
    fn test_openvpn_prf_truncation_is_prefix() {
        let secret = b"test secret";
        let label = b"test label";
        let seed = b"test seed";

        // Shorter outputs, including non-block-multiple lengths and 0, must be
        // exact prefixes of a longer output.
        let long = openvpn_prf(secret, label, seed, 100).unwrap();
        for len in [0usize, 1, 31, 32, 33, 64, 99] {
            let short = openvpn_prf(secret, label, seed, len).unwrap();
            assert_eq!(short.len(), len);
            assert_eq!(&short[..], &long[..len]);
        }
    }

    #[test]
    fn test_openvpn_prf_first_block_matches_p_sha256() {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;
        type HmacSha256 = Hmac<Sha256>;

        let (secret, label, seed) = (b"k", b"lbl", b"sd");
        let mut a0 = label.to_vec();
        a0.extend_from_slice(seed);

        // A(1) = HMAC(secret, label || seed)
        let mut mac = HmacSha256::new_from_slice(secret).unwrap();
        mac.update(&a0);
        let a1 = mac.finalize().into_bytes();

        // First output block = HMAC(secret, A(1) || label || seed)
        let mut mac = HmacSha256::new_from_slice(secret).unwrap();
        mac.update(&a1);
        mac.update(&a0);
        let expected = mac.finalize().into_bytes();

        assert_eq!(openvpn_prf(secret, label, seed, 32).unwrap(), expected.to_vec());
    }

    #[test]
    fn test_openvpn_prf_depends_on_secret() {
        let a = openvpn_prf(b"secret one", b"label", b"seed", 48).unwrap();
        let b = openvpn_prf(b"secret two", b"label", b"seed", 48).unwrap();
        assert_ne!(a, b);
    }
}