1use ring::hkdf::{self, KeyType, Prk, HKDF_SHA256, HKDF_SHA384};
8use ring::hmac::{self, Algorithm as HmacAlgorithm, Key as HmacKey};
9use thiserror::Error;
10
11#[derive(Debug, Error)]
13pub enum KeyDerivationError {
14 #[error("Unsupported cipher suite: 0x{0:04x}")]
15 UnsupportedCipherSuite(u16),
16
17 #[error("Invalid key material length: expected {expected}, got {actual}")]
18 InvalidKeyLength { expected: usize, actual: usize },
19
20 #[error("Key derivation failed: {0}")]
21 DerivationFailed(String),
22}
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq)]
26pub enum HashAlgorithm {
27 Sha256,
28 Sha384,
29}
30
31impl HashAlgorithm {
32 pub fn output_len(&self) -> usize {
34 match self {
35 HashAlgorithm::Sha256 => 32,
36 HashAlgorithm::Sha384 => 48,
37 }
38 }
39
40 fn hmac_algorithm(&self) -> HmacAlgorithm {
42 match self {
43 HashAlgorithm::Sha256 => hmac::HMAC_SHA256,
44 HashAlgorithm::Sha384 => hmac::HMAC_SHA384,
45 }
46 }
47
48 fn hkdf_algorithm(&self) -> hkdf::Algorithm {
50 match self {
51 HashAlgorithm::Sha256 => HKDF_SHA256,
52 HashAlgorithm::Sha384 => HKDF_SHA384,
53 }
54 }
55}
56
/// AEAD cipher used for TLS record protection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AeadAlgorithm {
    /// AES-128 in Galois/Counter Mode.
    Aes128Gcm,
    /// AES-256 in Galois/Counter Mode.
    Aes256Gcm,
    /// ChaCha20 stream cipher with the Poly1305 authenticator.
    Chacha20Poly1305,
}

impl AeadAlgorithm {
    /// Encryption key length in bytes.
    pub fn key_len(&self) -> usize {
        match self {
            Self::Aes128Gcm => 16,
            Self::Aes256Gcm | Self::Chacha20Poly1305 => 32,
        }
    }

    /// Full per-record AEAD nonce length in bytes (12 for every supported cipher).
    pub fn iv_len(&self) -> usize {
        match self {
            Self::Aes128Gcm | Self::Aes256Gcm | Self::Chacha20Poly1305 => 12,
        }
    }

    /// Authentication tag length in bytes (16 for every supported cipher).
    pub fn tag_len(&self) -> usize {
        16
    }

    /// Maps an IANA TLS cipher suite identifier to its AEAD cipher.
    ///
    /// Returns `None` for any identifier outside the supported TLS 1.3 and
    /// TLS 1.2 AEAD suites.
    pub fn from_cipher_suite(suite_id: u16) -> Option<Self> {
        let aead = match suite_id {
            // TLS_AES_128_GCM_SHA256 and the TLS 1.2 *_WITH_AES_128_GCM_SHA256 suites.
            0x1301 | 0xC02F | 0xC02B | 0x009E | 0x009C => Self::Aes128Gcm,
            // TLS_AES_256_GCM_SHA384 and the TLS 1.2 *_WITH_AES_256_GCM_SHA384 suites.
            0x1302 | 0xC030 | 0xC02C | 0x009F | 0x009D => Self::Aes256Gcm,
            // TLS_CHACHA20_POLY1305_SHA256 and the TLS 1.2 *_WITH_CHACHA20_POLY1305_SHA256 suites.
            0x1303 | 0xCCA8 | 0xCCA9 | 0xCCAA => Self::Chacha20Poly1305,
            _ => return None,
        };
        Some(aead)
    }
}
120
121pub fn hash_for_cipher_suite(suite_id: u16) -> Option<HashAlgorithm> {
123 match suite_id {
124 0x1302 | 0xC030 | 0xC02C | 0x009F | 0x009D => Some(HashAlgorithm::Sha384),
131
132 0x1301 | 0x1303 | 0xC02F | 0xCCA8 | 0xC02B | 0xCCA9 | 0x009E | 0xCCAA | 0x009C => Some(HashAlgorithm::Sha256),
143
144 _ => None,
145 }
146}
147
/// Traffic keys for a TLS 1.2 connection, as carved out of the key block.
///
/// `Debug` is implemented by hand. Formatting this value (e.g. in a log line)
/// therefore reports only the length of each field, never the secret bytes.
#[derive(Clone)]
pub struct Tls12KeyMaterial {
    /// Client write MAC key (left empty by this module's AEAD-only derivation).
    pub client_write_mac_key: Vec<u8>,
    /// Server write MAC key (left empty by this module's AEAD-only derivation).
    pub server_write_mac_key: Vec<u8>,
    /// Client write encryption key.
    pub client_write_key: Vec<u8>,
    /// Server write encryption key.
    pub server_write_key: Vec<u8>,
    /// Client write IV (the fixed/implicit IV bytes taken from the key block).
    pub client_write_iv: Vec<u8>,
    /// Server write IV (the fixed/implicit IV bytes taken from the key block).
    pub server_write_iv: Vec<u8>,
}

impl std::fmt::Debug for Tls12KeyMaterial {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Stand-in that prints only the length of a secret buffer.
        struct Redacted(usize);

        impl std::fmt::Debug for Redacted {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "<{} bytes redacted>", self.0)
            }
        }

        f.debug_struct("Tls12KeyMaterial")
            .field("client_write_mac_key", &Redacted(self.client_write_mac_key.len()))
            .field("server_write_mac_key", &Redacted(self.server_write_mac_key.len()))
            .field("client_write_key", &Redacted(self.client_write_key.len()))
            .field("server_write_key", &Redacted(self.server_write_key.len()))
            .field("client_write_iv", &Redacted(self.client_write_iv.len()))
            .field("server_write_iv", &Redacted(self.server_write_iv.len()))
            .finish()
    }
}
164
/// Record-protection key and IV for one direction of a TLS 1.3 connection.
///
/// `Debug` is implemented by hand. Formatting this value (e.g. in a log line)
/// therefore reports only the length of each field, never the secret bytes.
#[derive(Clone)]
pub struct Tls13KeyMaterial {
    /// AEAD encryption key.
    pub key: Vec<u8>,
    /// Static IV that is XORed with the record sequence number to form the nonce.
    pub iv: Vec<u8>,
}

impl std::fmt::Debug for Tls13KeyMaterial {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Stand-in that prints only the length of a secret buffer.
        struct Redacted(usize);

        impl std::fmt::Debug for Redacted {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "<{} bytes redacted>", self.0)
            }
        }

        f.debug_struct("Tls13KeyMaterial")
            .field("key", &Redacted(self.key.len()))
            .field("iv", &Redacted(self.iv.len()))
            .finish()
    }
}
173
174pub fn tls12_prf(
188 secret: &[u8],
189 label: &[u8],
190 seed: &[u8],
191 output_len: usize,
192 hash_algo: HashAlgorithm,
193) -> Vec<u8> {
194 let mut label_seed = Vec::with_capacity(label.len() + seed.len());
196 label_seed.extend_from_slice(label);
197 label_seed.extend_from_slice(seed);
198
199 p_hash(secret, &label_seed, output_len, hash_algo)
200}
201
202fn p_hash(secret: &[u8], seed: &[u8], output_len: usize, hash_algo: HashAlgorithm) -> Vec<u8> {
204 let hmac_algo = hash_algo.hmac_algorithm();
205 let key = HmacKey::new(hmac_algo, secret);
206 let hash_len = hash_algo.output_len();
207
208 let mut result = Vec::with_capacity(output_len);
209 let mut a = hmac::sign(&key, seed); while result.len() < output_len {
212 let mut ctx = hmac::Context::with_key(&key);
214 ctx.update(a.as_ref());
215 ctx.update(seed);
216 let p_block = ctx.sign();
217
218 let remaining = output_len - result.len();
220 let take = remaining.min(hash_len);
221 result.extend_from_slice(&p_block.as_ref()[..take]);
222
223 a = hmac::sign(&key, a.as_ref());
225 }
226
227 result
228}
229
230pub fn derive_tls12_keys(
243 master_secret: &[u8; 48],
244 client_random: &[u8; 32],
245 server_random: &[u8; 32],
246 cipher_suite_id: u16,
247) -> Result<Tls12KeyMaterial, KeyDerivationError> {
248 let aead = AeadAlgorithm::from_cipher_suite(cipher_suite_id)
249 .ok_or(KeyDerivationError::UnsupportedCipherSuite(cipher_suite_id))?;
250
251 let hash_algo = hash_for_cipher_suite(cipher_suite_id)
252 .ok_or(KeyDerivationError::UnsupportedCipherSuite(cipher_suite_id))?;
253
254 let mac_key_len = 0;
256 let enc_key_len = aead.key_len();
257 let iv_len = 4; let key_block_len = 2 * mac_key_len + 2 * enc_key_len + 2 * iv_len;
261
262 let mut seed = Vec::with_capacity(64);
264 seed.extend_from_slice(server_random);
265 seed.extend_from_slice(client_random);
266
267 let key_block = tls12_prf(
268 master_secret,
269 b"key expansion",
270 &seed,
271 key_block_len,
272 hash_algo,
273 );
274
275 let mut offset = 0;
277
278 let client_write_mac_key = if mac_key_len > 0 {
279 let k = key_block[offset..offset + mac_key_len].to_vec();
280 offset += mac_key_len;
281 k
282 } else {
283 Vec::new()
284 };
285
286 let server_write_mac_key = if mac_key_len > 0 {
287 let k = key_block[offset..offset + mac_key_len].to_vec();
288 offset += mac_key_len;
289 k
290 } else {
291 Vec::new()
292 };
293
294 let client_write_key = key_block[offset..offset + enc_key_len].to_vec();
295 offset += enc_key_len;
296
297 let server_write_key = key_block[offset..offset + enc_key_len].to_vec();
298 offset += enc_key_len;
299
300 let client_write_iv = key_block[offset..offset + iv_len].to_vec();
301 offset += iv_len;
302
303 let server_write_iv = key_block[offset..offset + iv_len].to_vec();
304
305 Ok(Tls12KeyMaterial {
306 client_write_mac_key,
307 server_write_mac_key,
308 client_write_key,
309 server_write_key,
310 client_write_iv,
311 server_write_iv,
312 })
313}
314
315fn hkdf_expand_label(
330 prk: &Prk,
331 label: &[u8],
332 context: &[u8],
333 output_len: usize,
334) -> Result<Vec<u8>, KeyDerivationError> {
335 let tls13_label = {
337 let mut l = Vec::with_capacity(6 + label.len());
338 l.extend_from_slice(b"tls13 ");
339 l.extend_from_slice(label);
340 l
341 };
342
343 let mut hkdf_label = Vec::with_capacity(2 + 1 + tls13_label.len() + 1 + context.len());
345 hkdf_label.push((output_len >> 8) as u8);
346 hkdf_label.push(output_len as u8);
347 hkdf_label.push(tls13_label.len() as u8);
348 hkdf_label.extend_from_slice(&tls13_label);
349 hkdf_label.push(context.len() as u8);
350 hkdf_label.extend_from_slice(context);
351
352 struct ExpandLen(usize);
354 impl KeyType for ExpandLen {
355 fn len(&self) -> usize {
356 self.0
357 }
358 }
359
360 let info = [hkdf_label.as_slice()];
361 let okm = prk
362 .expand(&info, ExpandLen(output_len))
363 .map_err(|_| KeyDerivationError::DerivationFailed("HKDF expand failed".to_string()))?;
364
365 let mut output = vec![0u8; output_len];
366 okm.fill(&mut output)
367 .map_err(|_| KeyDerivationError::DerivationFailed("HKDF fill failed".to_string()))?;
368
369 Ok(output)
370}
371
372pub fn derive_tls13_keys(
381 traffic_secret: &[u8],
382 cipher_suite_id: u16,
383) -> Result<Tls13KeyMaterial, KeyDerivationError> {
384 let aead = AeadAlgorithm::from_cipher_suite(cipher_suite_id)
385 .ok_or(KeyDerivationError::UnsupportedCipherSuite(cipher_suite_id))?;
386
387 let hash_algo = hash_for_cipher_suite(cipher_suite_id)
388 .ok_or(KeyDerivationError::UnsupportedCipherSuite(cipher_suite_id))?;
389
390 let hkdf_algo = hash_algo.hkdf_algorithm();
391
392 let prk = Prk::new_less_safe(hkdf_algo, traffic_secret);
395
396 let key_len = aead.key_len();
397 let iv_len = aead.iv_len();
398
399 let key = hkdf_expand_label(&prk, b"key", &[], key_len)?;
400 let iv = hkdf_expand_label(&prk, b"iv", &[], iv_len)?;
401
402 Ok(Tls13KeyMaterial { key, iv })
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed inputs shared by the TLS 1.2 key-expansion tests.
    const MASTER_SECRET: [u8; 48] = [0x42; 48];
    const CLIENT_RANDOM: [u8; 32] = [0x01; 32];
    const SERVER_RANDOM: [u8; 32] = [0x02; 32];

    /// Runs TLS 1.2 key expansion on the shared fixture for `suite`.
    fn tls12_keys_for(suite: u16) -> Result<Tls12KeyMaterial, KeyDerivationError> {
        derive_tls12_keys(&MASTER_SECRET, &CLIENT_RANDOM, &SERVER_RANDOM, suite)
    }

    /// Derives TLS 1.3 keys and checks the key length and the 12-byte IV length.
    fn assert_tls13_lengths(secret: &[u8], suite: u16, key_len: usize) {
        let keys = derive_tls13_keys(secret, suite).expect("key derivation should succeed");
        assert_eq!(keys.key.len(), key_len);
        assert_eq!(keys.iv.len(), 12);
    }

    #[test]
    fn test_aead_from_cipher_suite() {
        let expectations = [
            // TLS 1.3
            (0x1301, Some(AeadAlgorithm::Aes128Gcm)),
            (0x1302, Some(AeadAlgorithm::Aes256Gcm)),
            (0x1303, Some(AeadAlgorithm::Chacha20Poly1305)),
            // TLS 1.2 ECDHE-RSA
            (0xC02F, Some(AeadAlgorithm::Aes128Gcm)),
            (0xC030, Some(AeadAlgorithm::Aes256Gcm)),
            (0xCCA8, Some(AeadAlgorithm::Chacha20Poly1305)),
            // Unknown
            (0x0000, None),
        ];
        for &(suite, want) in &expectations {
            assert_eq!(
                AeadAlgorithm::from_cipher_suite(suite),
                want,
                "suite 0x{:04x}",
                suite
            );
        }
    }

    #[test]
    fn test_hash_for_cipher_suite() {
        let expectations = [
            (0x1301, Some(HashAlgorithm::Sha256)),
            (0xC02F, Some(HashAlgorithm::Sha256)),
            (0x1302, Some(HashAlgorithm::Sha384)),
            (0xC030, Some(HashAlgorithm::Sha384)),
            (0x0000, None),
        ];
        for &(suite, want) in &expectations {
            assert_eq!(hash_for_cipher_suite(suite), want, "suite 0x{:04x}", suite);
        }
    }

    #[test]
    fn test_tls12_prf_basic() {
        let secret = [0x42u8; 48];
        let seed = [0x01u8; 32];
        let prf = |label: &[u8]| tls12_prf(&secret, label, &seed, 32, HashAlgorithm::Sha256);

        let first = prf(b"test label");
        let again = prf(b"test label");
        let other = prf(b"other label");

        // Output has the requested length, is deterministic, and depends on the label.
        assert_eq!(first.len(), 32);
        assert_eq!(first, again);
        assert_ne!(first, other);
    }

    #[test]
    fn test_tls12_prf_sha384() {
        let secret = [0x42u8; 48];
        let seed = [0x01u8; 32];
        let via = |hash| tls12_prf(&secret, b"test label", &seed, 48, hash);

        assert_ne!(via(HashAlgorithm::Sha256), via(HashAlgorithm::Sha384));
    }

    #[test]
    fn test_derive_tls12_keys() {
        let keys = tls12_keys_for(0xC02F).expect("key derivation should succeed");

        let lengths = [
            keys.client_write_key.len(),
            keys.server_write_key.len(),
            keys.client_write_iv.len(),
            keys.server_write_iv.len(),
        ];
        assert_eq!(lengths, [16, 16, 4, 4]);
        assert!(keys.client_write_mac_key.is_empty());
        assert!(keys.server_write_mac_key.is_empty());

        // The two directions must never share key material.
        assert_ne!(keys.client_write_key, keys.server_write_key);
        assert_ne!(keys.client_write_iv, keys.server_write_iv);
    }

    #[test]
    fn test_derive_tls12_keys_aes256() {
        let keys = tls12_keys_for(0xC030).expect("key derivation should succeed");

        assert_eq!(keys.client_write_key.len(), 32);
        assert_eq!(keys.server_write_key.len(), 32);
    }

    #[test]
    fn test_derive_tls12_keys_unsupported() {
        assert!(matches!(
            tls12_keys_for(0x0000),
            Err(KeyDerivationError::UnsupportedCipherSuite(0x0000))
        ));
    }

    #[test]
    fn test_derive_tls13_keys() {
        assert_tls13_lengths(&[0x42u8; 32], 0x1301, 16);
    }

    #[test]
    fn test_derive_tls13_keys_aes256() {
        assert_tls13_lengths(&[0x42u8; 48], 0x1302, 32);
    }

    #[test]
    fn test_derive_tls13_keys_chacha20() {
        assert_tls13_lengths(&[0x42u8; 32], 0x1303, 32);
    }

    #[test]
    fn test_derive_tls13_keys_consistency() {
        let derive = |byte: u8| derive_tls13_keys(&[byte; 32], 0x1301).unwrap();

        let first = derive(0x42);
        let repeat = derive(0x42);
        let other = derive(0x43);

        // Same secret -> same keys; different secret -> different keys.
        assert_eq!(first.key, repeat.key);
        assert_eq!(first.iv, repeat.iv);
        assert_ne!(first.key, other.key);
        assert_ne!(first.iv, other.iv);
    }

    #[test]
    fn test_aead_key_lengths() {
        let table = [
            (AeadAlgorithm::Aes128Gcm, 16),
            (AeadAlgorithm::Aes256Gcm, 32),
            (AeadAlgorithm::Chacha20Poly1305, 32),
        ];
        for &(aead, key_len) in &table {
            assert_eq!(aead.key_len(), key_len);
            assert_eq!(aead.iv_len(), 12);
            assert_eq!(aead.tag_len(), 16);
        }
    }
}