1use serde::{Deserialize, Serialize};
28use sha2::{Digest, Sha256};
29use thiserror::Error;
30
/// Errors returned by the TLS 1.3 key-schedule operations in this module.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum Tls13Error {
    /// A key had an unexpected length (`expected` vs. `actual`, in bytes as
    /// presumably used by callers). NOTE(review): not constructed by any
    /// function in this section — confirm where it is produced.
    #[error("Invalid key length: expected {expected}, got {actual}")]
    InvalidLength { expected: usize, actual: usize },

    /// A secret needed by the requested derivation has not been computed yet;
    /// e.g. the master secret is still `None` because the handshake secrets
    /// were never derived.
    #[error("Key schedule not initialized")]
    NotInitialized,

    /// The schedule is in a state that does not allow the requested operation;
    /// the `String` carries a human-readable description.
    #[error("Invalid state: {0}")]
    InvalidState(String),
}
46
/// Convenience result type for key-schedule operations that can fail with [`Tls13Error`].
pub type Tls13Result<T> = Result<T, Tls13Error>;
49
/// Secrets of a TLS 1.3 key schedule computed with HKDF-SHA256 (32-byte secrets).
///
/// NOTE(review): the `Serialize`/`Deserialize` derives write these secrets out
/// in the clear, and no zeroization on drop is visible here — treat any
/// serialized form as key material. Field names/order define the serialized
/// format, so do not rename or reorder them casually.
#[derive(Clone, Serialize, Deserialize)]
pub struct Tls13KeySchedule {
    // Early Secret: HKDF-Extract with an all-zero salt and all-zero IKM.
    early_secret: [u8; 32],
    // Handshake Secret; set by the constructor. Optional so a deserialized
    // instance may lack it.
    handshake_secret: Option<[u8; 32]>,
    // Master Secret; `None` until the handshake secrets have been derived.
    master_secret: Option<[u8; 32]>,
}
62
63impl Tls13KeySchedule {
64 pub fn new(shared_secret: &[u8]) -> Self {
69 let zero_salt = [0u8; 32];
71 let early_secret = hkdf_extract(&zero_salt, &zero_salt);
72
73 let handshake_secret = derive_secret(&early_secret, b"derived", &[]);
75 let handshake_secret = hkdf_extract(&handshake_secret, shared_secret);
76
77 Self {
78 early_secret,
79 handshake_secret: Some(handshake_secret),
80 master_secret: None,
81 }
82 }
83
84 pub fn derive_handshake_secrets(
93 &mut self,
94 client_hello: &[u8],
95 server_hello: &[u8],
96 ) -> ([u8; 32], [u8; 32]) {
97 let handshake_secret = self
98 .handshake_secret
99 .expect("Handshake secret not initialized");
100
101 let mut hasher = Sha256::new();
103 hasher.update(client_hello);
104 hasher.update(server_hello);
105 let transcript_hash = hasher.finalize();
106
107 let client_hs_traffic_secret =
109 derive_secret(&handshake_secret, b"c hs traffic", &transcript_hash);
110
111 let server_hs_traffic_secret =
113 derive_secret(&handshake_secret, b"s hs traffic", &transcript_hash);
114
115 let derived = derive_secret(&handshake_secret, b"derived", &[]);
117 let master_secret = hkdf_extract(&derived, &[0u8; 32]);
118 self.master_secret = Some(master_secret);
119
120 (client_hs_traffic_secret, server_hs_traffic_secret)
121 }
122
123 pub fn derive_application_secrets(&self) -> Tls13Result<([u8; 32], [u8; 32])> {
128 let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
129
130 let empty_hash = Sha256::digest([]);
132
133 let client_app_traffic_secret = derive_secret(&master_secret, b"c ap traffic", &empty_hash);
135
136 let server_app_traffic_secret = derive_secret(&master_secret, b"s ap traffic", &empty_hash);
138
139 Ok((client_app_traffic_secret, server_app_traffic_secret))
140 }
141
142 pub fn derive_exporter_secret(&self) -> Tls13Result<[u8; 32]> {
146 let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
147
148 let empty_hash = Sha256::digest([]);
149 Ok(derive_secret(&master_secret, b"exp master", &empty_hash))
150 }
151
152 pub fn derive_resumption_secret(&self, transcript_hash: &[u8]) -> Tls13Result<[u8; 32]> {
156 let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
157
158 Ok(derive_secret(
159 &master_secret,
160 b"res master",
161 transcript_hash,
162 ))
163 }
164
165 pub fn update_traffic_secret(current_secret: &[u8; 32]) -> [u8; 32] {
169 derive_secret(current_secret, b"traffic upd", &[])
170 }
171}
172
173fn hkdf_extract(salt: &[u8], ikm: &[u8]) -> [u8; 32] {
175 use hmac::{Hmac, Mac};
176 type HmacSha256 = Hmac<Sha256>;
177
178 let mut mac = HmacSha256::new_from_slice(salt).expect("HMAC can take key of any size");
179 mac.update(ikm);
180 let result = mac.finalize();
181 let bytes = result.into_bytes();
182
183 let mut output = [0u8; 32];
184 output.copy_from_slice(&bytes);
185 output
186}
187
188fn hkdf_expand_label(secret: &[u8], label: &[u8], context: &[u8], length: u16) -> Vec<u8> {
190 let mut hkdf_label = Vec::new();
198
199 hkdf_label.extend_from_slice(&length.to_be_bytes());
201
202 let full_label = [b"tls13 ", label].concat();
204 hkdf_label.push(full_label.len() as u8);
205 hkdf_label.extend_from_slice(&full_label);
206
207 hkdf_label.push(context.len() as u8);
209 hkdf_label.extend_from_slice(context);
210
211 hkdf_expand(secret, &hkdf_label, length as usize)
213}
214
215fn hkdf_expand(prk: &[u8], info: &[u8], length: usize) -> Vec<u8> {
217 use hmac::{Hmac, Mac};
218 type HmacSha256 = Hmac<Sha256>;
219
220 let mut output = Vec::with_capacity(length);
221 let mut t = Vec::new();
222 let mut counter = 1u8;
223
224 while output.len() < length {
225 let mut mac = HmacSha256::new_from_slice(prk).expect("HMAC can take key of any size");
226 mac.update(&t);
227 mac.update(info);
228 mac.update(&[counter]);
229
230 t = mac.finalize().into_bytes().to_vec();
231 output.extend_from_slice(&t);
232 counter += 1;
233 }
234
235 output.truncate(length);
236 output
237}
238
239fn derive_secret(secret: &[u8], label: &[u8], messages: &[u8]) -> [u8; 32] {
241 let transcript_hash = if messages.is_empty() {
243 Sha256::digest([]).to_vec()
244 } else {
245 messages.to_vec()
246 };
247
248 let expanded = hkdf_expand_label(secret, label, &transcript_hash, 32);
249 let mut output = [0u8; 32];
250 output.copy_from_slice(&expanded[..32]);
251 output
252}
253
254pub fn derive_traffic_keys(traffic_secret: &[u8; 32]) -> ([u8; 32], [u8; 12]) {
259 let key_bytes = hkdf_expand_label(traffic_secret, b"key", &[], 32);
261 let mut key = [0u8; 32];
262 key.copy_from_slice(&key_bytes[..32]);
263
264 let iv_bytes = hkdf_expand_label(traffic_secret, b"iv", &[], 12);
266 let mut iv = [0u8; 12];
267 iv.copy_from_slice(&iv_bytes[..12]);
268
269 (key, iv)
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a lowercase/uppercase hex string (even length) into bytes.
    fn hex(s: &str) -> Vec<u8> {
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).expect("valid hex"))
            .collect()
    }

    #[test]
    fn test_key_schedule_creation() {
        let shared_secret = [0x42u8; 32];
        let schedule = Tls13KeySchedule::new(&shared_secret);

        assert!(schedule.handshake_secret.is_some());
        assert!(schedule.master_secret.is_none());
    }

    #[test]
    fn test_handshake_secrets_derivation() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);

        let client_hello = b"client hello message";
        let server_hello = b"server hello message";

        let (client_hs, server_hs) = schedule.derive_handshake_secrets(client_hello, server_hello);

        assert_ne!(client_hs, server_hs);
        assert!(schedule.master_secret.is_some());
    }

    #[test]
    fn test_application_secrets_derivation() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);

        let client_hello = b"client hello";
        let server_hello = b"server hello";
        schedule.derive_handshake_secrets(client_hello, server_hello);

        let result = schedule.derive_application_secrets();
        assert!(result.is_ok());

        let (client_app, server_app) = result.unwrap();
        assert_ne!(client_app, server_app);
    }

    #[test]
    fn test_application_secrets_before_handshake() {
        let shared_secret = [0x42u8; 32];
        let schedule = Tls13KeySchedule::new(&shared_secret);

        let result = schedule.derive_application_secrets();
        assert!(result.is_err());
    }

    #[test]
    fn test_exporter_secret() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);

        schedule.derive_handshake_secrets(b"client hello", b"server hello");

        let exporter_secret = schedule.derive_exporter_secret();
        assert!(exporter_secret.is_ok());
        assert_eq!(exporter_secret.unwrap().len(), 32);
    }

    #[test]
    fn test_resumption_secret() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);

        schedule.derive_handshake_secrets(b"client hello", b"server hello");

        let transcript = Sha256::digest(b"full handshake transcript");
        let resumption_secret = schedule.derive_resumption_secret(&transcript);
        assert!(resumption_secret.is_ok());
        assert_eq!(resumption_secret.unwrap().len(), 32);
    }

    #[test]
    fn test_traffic_key_update() {
        let current_secret = [0x42u8; 32];
        let new_secret = Tls13KeySchedule::update_traffic_secret(&current_secret);

        assert_ne!(current_secret, new_secret);
    }

    #[test]
    fn test_derive_traffic_keys() {
        let traffic_secret = [0x42u8; 32];
        let (key, iv) = derive_traffic_keys(&traffic_secret);

        assert_eq!(key.len(), 32);
        assert_eq!(iv.len(), 12);
    }

    #[test]
    fn test_hkdf_extract() {
        let salt = [0x01u8; 32];
        let ikm = [0x02u8; 32];

        let prk = hkdf_extract(&salt, &ikm);
        assert_eq!(prk.len(), 32);

        let prk2 = hkdf_extract(&salt, &ikm);
        assert_eq!(prk, prk2);
    }

    #[test]
    fn test_hkdf_expand() {
        let prk = [0x42u8; 32];
        let info = b"test info";

        let okm = hkdf_expand(&prk, info, 64);
        assert_eq!(okm.len(), 64);

        let okm2 = hkdf_expand(&prk, info, 64);
        assert_eq!(okm, okm2);
    }

    /// RFC 5869 Appendix A.1 (Test Case 1): HKDF-SHA256 known-answer vector.
    #[test]
    fn test_hkdf_rfc5869_case1() {
        let ikm = [0x0bu8; 22];
        let salt = hex("000102030405060708090a0b0c");
        let info = hex("f0f1f2f3f4f5f6f7f8f9");

        let prk = hkdf_extract(&salt, &ikm);
        assert_eq!(
            prk.to_vec(),
            hex("077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5")
        );

        let okm = hkdf_expand(&prk, &info, 42);
        assert_eq!(
            okm,
            hex("3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865")
        );
    }

    #[test]
    fn test_hkdf_expand_label() {
        let secret = [0x42u8; 32];
        let label = b"test label";
        let context = b"test context";

        let output = hkdf_expand_label(&secret, label, context, 32);
        assert_eq!(output.len(), 32);

        let output2 = hkdf_expand_label(&secret, label, context, 32);
        assert_eq!(output, output2);
    }

    /// The HkdfLabel wire encoding must be: u16 length, u8-prefixed
    /// "tls13 " + label, u8-prefixed context (RFC 8446 §7.1).
    #[test]
    fn test_hkdf_expand_label_encoding() {
        let secret = [0x42u8; 32];

        let mut info = vec![0x00, 0x20, 11];
        info.extend_from_slice(b"tls13 label");
        info.push(3);
        info.extend_from_slice(b"ctx");

        assert_eq!(
            hkdf_expand_label(&secret, b"label", b"ctx", 32),
            hkdf_expand(&secret, &info, 32)
        );
    }

    #[test]
    fn test_derive_secret() {
        let secret = [0x42u8; 32];
        let label = b"test";
        let messages = b"messages";

        let derived = derive_secret(&secret, label, messages);
        assert_eq!(derived.len(), 32);

        let derived2 = derive_secret(&secret, label, messages);
        assert_eq!(derived, derived2);
    }

    #[test]
    fn test_serialization() {
        let shared_secret = [0x42u8; 32];
        let schedule = Tls13KeySchedule::new(&shared_secret);

        let serialized = crate::codec::encode(&schedule).unwrap();
        let deserialized: Tls13KeySchedule = crate::codec::decode(&serialized).unwrap();

        assert_eq!(deserialized.early_secret, schedule.early_secret);
        assert_eq!(deserialized.handshake_secret, schedule.handshake_secret);
    }
}