// corevpn_crypto/kdf.rs

//! Key derivation functions.
//!
//! Uses HKDF-SHA256 (RFC 5869) to derive encryption and HMAC keys from
//! shared secrets, and provides a TLS-style P_SHA256 PRF for key expansion.

5use hkdf::Hkdf;
6use sha2::Sha256;
7use zeroize::{Zeroize, ZeroizeOnDrop};
8
9use crate::{CryptoError, Result, CipherSuite, DataChannelKey};
10
11/// OpenVPN-style key material derived from TLS session
12#[derive(ZeroizeOnDrop)]
13pub struct KeyMaterial {
14    /// Client -> Server encryption key
15    pub client_write_key: [u8; 32],
16    /// Server -> Client encryption key
17    pub server_write_key: [u8; 32],
18    /// Client -> Server HMAC key (for tls-auth)
19    pub client_hmac_key: [u8; 32],
20    /// Server -> Client HMAC key (for tls-auth)
21    pub server_hmac_key: [u8; 32],
22}
23
24/// Derive key material from a shared secret
25///
26/// # Arguments
27/// * `shared_secret` - The raw shared secret from DH exchange
28/// * `client_random` - Client's random value (from TLS handshake)
29/// * `server_random` - Server's random value (from TLS handshake)
30/// * `info` - Optional context info (e.g., "OpenVPN data channel")
31#[inline]
32pub fn derive_keys(
33    shared_secret: &[u8],
34    client_random: &[u8; 32],
35    server_random: &[u8; 32],
36    info: &[u8],
37) -> Result<KeyMaterial> {
38    // Combine randoms as salt
39    let mut salt = [0u8; 64];
40    salt[..32].copy_from_slice(client_random);
41    salt[32..].copy_from_slice(server_random);
42
43    let hkdf = Hkdf::<Sha256>::new(Some(&salt), shared_secret);
44
45    // Derive 128 bytes (4 x 32-byte keys)
46    let mut okm = [0u8; 128];
47    hkdf.expand(info, &mut okm)
48        .map_err(|_| CryptoError::KeyDerivationFailed("HKDF expansion failed"))?;
49
50    let mut material = KeyMaterial {
51        client_write_key: [0u8; 32],
52        server_write_key: [0u8; 32],
53        client_hmac_key: [0u8; 32],
54        server_hmac_key: [0u8; 32],
55    };
56
57    material.client_write_key.copy_from_slice(&okm[0..32]);
58    material.server_write_key.copy_from_slice(&okm[32..64]);
59    material.client_hmac_key.copy_from_slice(&okm[64..96]);
60    material.server_hmac_key.copy_from_slice(&okm[96..128]);
61
62    // Zeroize intermediate values
63    okm.zeroize();
64    salt.zeroize();
65
66    Ok(material)
67}
68
69impl KeyMaterial {
70    /// Create data channel keys for the client side
71    pub fn client_data_key(&self, suite: CipherSuite) -> DataChannelKey {
72        DataChannelKey::new(self.client_write_key, suite)
73    }
74
75    /// Create data channel keys for the server side
76    pub fn server_data_key(&self, suite: CipherSuite) -> DataChannelKey {
77        DataChannelKey::new(self.server_write_key, suite)
78    }
79}
80
81/// Derive a single key from input key material
82pub fn derive_single_key(
83    ikm: &[u8],
84    salt: &[u8],
85    info: &[u8],
86) -> Result<[u8; 32]> {
87    let hkdf = Hkdf::<Sha256>::new(Some(salt), ikm);
88    let mut okm = [0u8; 32];
89    hkdf.expand(info, &mut okm)
90        .map_err(|_| CryptoError::KeyDerivationFailed("HKDF expansion failed"))?;
91    Ok(okm)
92}
93
94/// PRF for OpenVPN TLS key expansion
95///
96/// Compatible with OpenVPN's PRF which uses:
97/// P_SHA256(secret, seed) = HMAC_SHA256(secret, A(1) + seed) +
98///                          HMAC_SHA256(secret, A(2) + seed) + ...
99/// where A(0) = seed, A(i) = HMAC_SHA256(secret, A(i-1))
100pub fn openvpn_prf(secret: &[u8], label: &[u8], seed: &[u8], output_len: usize) -> Result<Vec<u8>> {
101    use hmac::{Hmac, Mac};
102
103    type HmacSha256 = Hmac<Sha256>;
104
105    // Combine label and seed
106    let mut combined_seed = Vec::with_capacity(label.len() + seed.len());
107    combined_seed.extend_from_slice(label);
108    combined_seed.extend_from_slice(seed);
109
110    let mut output = Vec::with_capacity(output_len);
111    let mut a = combined_seed.clone();
112
113    while output.len() < output_len {
114        // A(i) = HMAC(secret, A(i-1))
115        let mut mac = HmacSha256::new_from_slice(secret)
116            .map_err(|_| CryptoError::KeyDerivationFailed("Invalid HMAC key"))?;
117        mac.update(&a);
118        a = mac.finalize().into_bytes().to_vec();
119
120        // P_hash = HMAC(secret, A(i) + seed)
121        let mut mac = HmacSha256::new_from_slice(secret)
122            .map_err(|_| CryptoError::KeyDerivationFailed("Invalid HMAC key"))?;
123        mac.update(&a);
124        mac.update(&combined_seed);
125        output.extend_from_slice(&mac.finalize().into_bytes());
126    }
127
128    output.truncate(output_len);
129    Ok(output)
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    const SHARED_SECRET: [u8; 32] = [0x42; 32];
    const CLIENT_RANDOM: [u8; 32] = [0x01; 32];
    const SERVER_RANDOM: [u8; 32] = [0x02; 32];

    /// Runs `derive_keys` on the shared fixtures with the given `info`.
    fn fixture_keys(info: &[u8]) -> KeyMaterial {
        derive_keys(&SHARED_SECRET, &CLIENT_RANDOM, &SERVER_RANDOM, info).unwrap()
    }

    #[test]
    fn test_derive_keys() {
        let keys = fixture_keys(b"test");

        // The four keys come from disjoint 32-byte slices of the HKDF output,
        // so every pair must differ.
        let all = [
            keys.client_write_key,
            keys.server_write_key,
            keys.client_hmac_key,
            keys.server_hmac_key,
        ];
        for (i, a) in all.iter().enumerate() {
            for b in &all[i + 1..] {
                assert_ne!(a, b);
            }
        }
    }

    #[test]
    fn test_derive_keys_deterministic() {
        let keys1 = fixture_keys(b"test");
        let keys2 = fixture_keys(b"test");

        // Every key, not just the first, must be reproducible.
        assert_eq!(keys1.client_write_key, keys2.client_write_key);
        assert_eq!(keys1.server_write_key, keys2.server_write_key);
        assert_eq!(keys1.client_hmac_key, keys2.client_hmac_key);
        assert_eq!(keys1.server_hmac_key, keys2.server_hmac_key);
    }

    #[test]
    fn test_derive_keys_depends_on_every_input() {
        let base = fixture_keys(b"test");
        let other_secret =
            derive_keys(&[0x43; 32], &CLIENT_RANDOM, &SERVER_RANDOM, b"test").unwrap();
        let swapped_randoms =
            derive_keys(&SHARED_SECRET, &SERVER_RANDOM, &CLIENT_RANDOM, b"test").unwrap();
        let other_info = fixture_keys(b"other");

        assert_ne!(base.client_write_key, other_secret.client_write_key);
        assert_ne!(base.client_write_key, swapped_randoms.client_write_key);
        assert_ne!(base.client_write_key, other_info.client_write_key);
    }

    #[test]
    fn test_derive_single_key_rfc5869_case1() {
        // RFC 5869, Appendix A.1 (SHA-256): first 32 bytes of the 42-byte OKM.
        let ikm = [0x0bu8; 22];
        let salt: Vec<u8> = (0x00u8..=0x0c).collect();
        let info: Vec<u8> = (0xf0u8..=0xf9).collect();
        let expected: [u8; 32] = [
            0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a, 0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36,
            0x2f, 0x2a, 0x2d, 0x2d, 0x0a, 0x90, 0xcf, 0x1a, 0x5a, 0x4c, 0x5d, 0xb0, 0x2d, 0x56,
            0xec, 0xc4, 0xc5, 0xbf,
        ];
        assert_eq!(derive_single_key(&ikm, &salt, &info).unwrap(), expected);
    }

    #[test]
    fn test_derive_single_key_matches_first_derived_key() {
        // HKDF-Expand output is prefix-stable. A 32-byte expansion with the
        // same salt and info therefore equals the first key carved out by
        // `derive_keys`.
        let mut salt = [0u8; 64];
        salt[..32].copy_from_slice(&CLIENT_RANDOM);
        salt[32..].copy_from_slice(&SERVER_RANDOM);

        let single = derive_single_key(&SHARED_SECRET, &salt, b"test").unwrap();
        assert_eq!(single, fixture_keys(b"test").client_write_key);
    }

    #[test]
    fn test_openvpn_prf() {
        let secret = b"test secret";
        let label = b"test label";
        let seed = b"test seed";

        let output = openvpn_prf(secret, label, seed, 64).unwrap();
        assert_eq!(output.len(), 64);

        // Should be deterministic
        let output2 = openvpn_prf(secret, label, seed, 64).unwrap();
        assert_eq!(output, output2);
    }

    #[test]
    fn test_openvpn_prf_lengths_are_prefixes() {
        let secret = b"test secret";
        let label = b"test label";
        let seed = b"test seed";

        // 100 is not a multiple of the 32-byte block size, so truncation is exercised.
        let long = openvpn_prf(secret, label, seed, 100).unwrap();
        assert_eq!(long.len(), 100);

        for &len in &[0usize, 1, 31, 32, 33, 64] {
            let short = openvpn_prf(secret, label, seed, len).unwrap();
            assert_eq!(short.len(), len);
            assert_eq!(&short[..], &long[..len]);
        }
    }

    #[test]
    fn test_openvpn_prf_depends_on_every_input() {
        let base = openvpn_prf(b"test secret", b"test label", b"test seed", 32).unwrap();
        assert_ne!(base, openvpn_prf(b"other secret", b"test label", b"test seed", 32).unwrap());
        assert_ne!(base, openvpn_prf(b"test secret", b"other label", b"test seed", 32).unwrap());
        assert_ne!(base, openvpn_prf(b"test secret", b"test label", b"other seed", 32).unwrap());
    }
}