1use serde::{Deserialize, Serialize};
28use sha2::{Digest, Sha256};
29use thiserror::Error;
30
31#[derive(Debug, Error, Clone, PartialEq, Eq)]
33pub enum Tls13Error {
34 #[error("Invalid key length: expected {expected}, got {actual}")]
36 InvalidLength { expected: usize, actual: usize },
37
38 #[error("Key schedule not initialized")]
40 NotInitialized,
41
42 #[error("Invalid state: {0}")]
44 InvalidState(String),
45}
46
47pub type Tls13Result<T> = Result<T, Tls13Error>;
49
/// State of a TLS 1.3 key schedule (RFC 8446 §7.1) instantiated with SHA-256,
/// so every secret is 32 bytes.
///
/// NOTE(review): this type derives `Clone`, `Serialize` and `Deserialize` over raw secret
/// material, and no zeroize-on-drop is visible in this module — confirm that is intended.
/// The field order below is part of the serialized representation; keep it stable.
#[derive(Clone, Serialize, Deserialize)]
pub struct Tls13KeySchedule {
    // Early Secret: HKDF-Extract of a zero salt and zero IKM (no PSK), set by `new`.
    early_secret: [u8; 32],
    // Handshake Secret. `new` always sets it; `None` is only possible for values
    // produced some other way (e.g. deserialized).
    handshake_secret: Option<[u8; 32]>,
    // Master Secret. `None` until `derive_handshake_secrets` has run.
    master_secret: Option<[u8; 32]>,
}
62
63impl Tls13KeySchedule {
64 pub fn new(shared_secret: &[u8]) -> Self {
69 let zero_salt = [0u8; 32];
71 let early_secret = hkdf_extract(&zero_salt, &zero_salt);
72
73 let handshake_secret = derive_secret(&early_secret, b"derived", &[]);
75 let handshake_secret = hkdf_extract(&handshake_secret, shared_secret);
76
77 Self {
78 early_secret,
79 handshake_secret: Some(handshake_secret),
80 master_secret: None,
81 }
82 }
83
84 pub fn derive_handshake_secrets(
93 &mut self,
94 client_hello: &[u8],
95 server_hello: &[u8],
96 ) -> ([u8; 32], [u8; 32]) {
97 let handshake_secret = self
98 .handshake_secret
99 .expect("Handshake secret not initialized");
100
101 let mut hasher = Sha256::new();
103 hasher.update(client_hello);
104 hasher.update(server_hello);
105 let transcript_hash = hasher.finalize();
106
107 let client_hs_traffic_secret =
109 derive_secret(&handshake_secret, b"c hs traffic", &transcript_hash);
110
111 let server_hs_traffic_secret =
113 derive_secret(&handshake_secret, b"s hs traffic", &transcript_hash);
114
115 let derived = derive_secret(&handshake_secret, b"derived", &[]);
117 let master_secret = hkdf_extract(&derived, &[0u8; 32]);
118 self.master_secret = Some(master_secret);
119
120 (client_hs_traffic_secret, server_hs_traffic_secret)
121 }
122
123 pub fn derive_application_secrets(&self) -> Tls13Result<([u8; 32], [u8; 32])> {
128 let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
129
130 let empty_hash = Sha256::digest([]);
132
133 let client_app_traffic_secret = derive_secret(&master_secret, b"c ap traffic", &empty_hash);
135
136 let server_app_traffic_secret = derive_secret(&master_secret, b"s ap traffic", &empty_hash);
138
139 Ok((client_app_traffic_secret, server_app_traffic_secret))
140 }
141
142 pub fn derive_exporter_secret(&self) -> Tls13Result<[u8; 32]> {
146 let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
147
148 let empty_hash = Sha256::digest([]);
149 Ok(derive_secret(&master_secret, b"exp master", &empty_hash))
150 }
151
152 pub fn derive_resumption_secret(&self, transcript_hash: &[u8]) -> Tls13Result<[u8; 32]> {
156 let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
157
158 Ok(derive_secret(
159 &master_secret,
160 b"res master",
161 transcript_hash,
162 ))
163 }
164
165 pub fn update_traffic_secret(current_secret: &[u8; 32]) -> [u8; 32] {
169 derive_secret(current_secret, b"traffic upd", &[])
170 }
171}
172
173fn hkdf_extract(salt: &[u8], ikm: &[u8]) -> [u8; 32] {
175 use hmac::digest::KeyInit;
176 use hmac::{Hmac, Mac};
177 type HmacSha256 = Hmac<Sha256>;
178
179 let mut mac =
180 <HmacSha256 as KeyInit>::new_from_slice(salt).expect("HMAC can take key of any size");
181 mac.update(ikm);
182 let result = mac.finalize();
183 let bytes = result.into_bytes();
184
185 let mut output = [0u8; 32];
186 output.copy_from_slice(&bytes);
187 output
188}
189
190fn hkdf_expand_label(secret: &[u8], label: &[u8], context: &[u8], length: u16) -> Vec<u8> {
192 let mut hkdf_label = Vec::new();
200
201 hkdf_label.extend_from_slice(&length.to_be_bytes());
203
204 let full_label = [b"tls13 ", label].concat();
206 hkdf_label.push(full_label.len() as u8);
207 hkdf_label.extend_from_slice(&full_label);
208
209 hkdf_label.push(context.len() as u8);
211 hkdf_label.extend_from_slice(context);
212
213 hkdf_expand(secret, &hkdf_label, length as usize)
215}
216
217fn hkdf_expand(prk: &[u8], info: &[u8], length: usize) -> Vec<u8> {
219 use hmac::digest::KeyInit;
220 use hmac::{Hmac, Mac};
221 type HmacSha256 = Hmac<Sha256>;
222
223 let mut output = Vec::with_capacity(length);
224 let mut t = Vec::new();
225 let mut counter = 1u8;
226
227 while output.len() < length {
228 let mut mac =
229 <HmacSha256 as KeyInit>::new_from_slice(prk).expect("HMAC can take key of any size");
230 mac.update(&t);
231 mac.update(info);
232 mac.update(&[counter]);
233
234 t = mac.finalize().into_bytes().to_vec();
235 output.extend_from_slice(&t);
236 counter += 1;
237 }
238
239 output.truncate(length);
240 output
241}
242
243fn derive_secret(secret: &[u8], label: &[u8], messages: &[u8]) -> [u8; 32] {
245 let transcript_hash = if messages.is_empty() {
247 Sha256::digest([]).to_vec()
248 } else {
249 messages.to_vec()
250 };
251
252 let expanded = hkdf_expand_label(secret, label, &transcript_hash, 32);
253 let mut output = [0u8; 32];
254 output.copy_from_slice(&expanded[..32]);
255 output
256}
257
258pub fn derive_traffic_keys(traffic_secret: &[u8; 32]) -> ([u8; 32], [u8; 12]) {
263 let key_bytes = hkdf_expand_label(traffic_secret, b"key", &[], 32);
265 let mut key = [0u8; 32];
266 key.copy_from_slice(&key_bytes[..32]);
267
268 let iv_bytes = hkdf_expand_label(traffic_secret, b"iv", &[], 12);
270 let mut iv = [0u8; 12];
271 iv.copy_from_slice(&iv_bytes[..12]);
272
273 (key, iv)
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a lowercase/uppercase hex string with no separators (test helper).
    fn hex(s: &str) -> Vec<u8> {
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    #[test]
    fn test_key_schedule_creation() {
        let shared_secret = [0x42u8; 32];
        let schedule = Tls13KeySchedule::new(&shared_secret);

        assert!(schedule.handshake_secret.is_some());
        assert!(schedule.master_secret.is_none());
    }

    #[test]
    fn test_handshake_secrets_derivation() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);

        let client_hello = b"client hello message";
        let server_hello = b"server hello message";

        let (client_hs, server_hs) = schedule.derive_handshake_secrets(client_hello, server_hello);

        assert_ne!(client_hs, server_hs);

        assert!(schedule.master_secret.is_some());
    }

    #[test]
    fn test_application_secrets_derivation() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);

        let client_hello = b"client hello";
        let server_hello = b"server hello";
        schedule.derive_handshake_secrets(client_hello, server_hello);

        let result = schedule.derive_application_secrets();
        assert!(result.is_ok());

        let (client_app, server_app) = result.unwrap();
        assert_ne!(client_app, server_app);
    }

    #[test]
    fn test_application_secrets_before_handshake() {
        let shared_secret = [0x42u8; 32];
        let schedule = Tls13KeySchedule::new(&shared_secret);

        let result = schedule.derive_application_secrets();
        assert!(result.is_err());
    }

    #[test]
    fn test_exporter_secret() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);

        schedule.derive_handshake_secrets(b"client hello", b"server hello");

        let exporter_secret = schedule.derive_exporter_secret();
        assert!(exporter_secret.is_ok());
        assert_eq!(exporter_secret.unwrap().len(), 32);
    }

    #[test]
    fn test_resumption_secret() {
        let shared_secret = [0x42u8; 32];
        let mut schedule = Tls13KeySchedule::new(&shared_secret);

        schedule.derive_handshake_secrets(b"client hello", b"server hello");

        let transcript = Sha256::digest(b"full handshake transcript");
        let resumption_secret = schedule.derive_resumption_secret(&transcript);
        assert!(resumption_secret.is_ok());
        assert_eq!(resumption_secret.unwrap().len(), 32);
    }

    #[test]
    fn test_traffic_key_update() {
        let current_secret = [0x42u8; 32];
        let new_secret = Tls13KeySchedule::update_traffic_secret(&current_secret);

        assert_ne!(current_secret, new_secret);
    }

    #[test]
    fn test_derive_traffic_keys() {
        let traffic_secret = [0x42u8; 32];
        let (key, iv) = derive_traffic_keys(&traffic_secret);

        assert_eq!(key.len(), 32);
        assert_eq!(iv.len(), 12);
    }

    #[test]
    fn test_hkdf_extract() {
        let salt = [0x01u8; 32];
        let ikm = [0x02u8; 32];

        let prk = hkdf_extract(&salt, &ikm);
        assert_eq!(prk.len(), 32);

        let prk2 = hkdf_extract(&salt, &ikm);
        assert_eq!(prk, prk2);
    }

    #[test]
    fn test_hkdf_expand() {
        let prk = [0x42u8; 32];
        let info = b"test info";

        let okm = hkdf_expand(&prk, info, 64);
        assert_eq!(okm.len(), 64);

        let okm2 = hkdf_expand(&prk, info, 64);
        assert_eq!(okm, okm2);
    }

    /// Known-answer test: RFC 5869 Appendix A.1 (HMAC-SHA-256, basic test case).
    #[test]
    fn test_hkdf_rfc5869_case1() {
        let ikm = [0x0bu8; 22];
        let salt = hex("000102030405060708090a0b0c");
        let info = hex("f0f1f2f3f4f5f6f7f8f9");

        let prk = hkdf_extract(&salt, &ikm);
        assert_eq!(
            prk.to_vec(),
            hex("077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5")
        );

        let okm = hkdf_expand(&prk, &info, 42);
        assert_eq!(
            okm,
            hex("3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865")
        );
    }

    #[test]
    fn test_hkdf_expand_label() {
        let secret = [0x42u8; 32];
        let label = b"test label";
        let context = b"test context";

        let output = hkdf_expand_label(&secret, label, context, 32);
        assert_eq!(output.len(), 32);

        let output2 = hkdf_expand_label(&secret, label, context, 32);
        assert_eq!(output, output2);
    }

    #[test]
    fn test_derive_secret() {
        let secret = [0x42u8; 32];
        let label = b"test";
        let messages = b"messages";

        let derived = derive_secret(&secret, label, messages);
        assert_eq!(derived.len(), 32);

        let derived2 = derive_secret(&secret, label, messages);
        assert_eq!(derived, derived2);
    }

    /// Known-answer test: the no-PSK Early Secret and its "derived" secret from RFC 8448 §3.
    #[test]
    fn test_early_secret_matches_rfc8448() {
        let schedule = Tls13KeySchedule::new(&[0u8; 32]);
        assert_eq!(
            schedule.early_secret.to_vec(),
            hex("33ad0a1c607ec03b09e6cd9893680ce210adf300aa1f2660e1b22e10f170f92a")
        );

        let derived = derive_secret(&schedule.early_secret, b"derived", &[]);
        assert_eq!(
            derived.to_vec(),
            hex("6f2615a108c702c5678f54fc9dbab69716c076189c48250cebeac3576c3611ba")
        );
    }

    #[test]
    fn test_serialization() {
        let shared_secret = [0x42u8; 32];
        let schedule = Tls13KeySchedule::new(&shared_secret);

        let serialized = crate::codec::encode(&schedule).unwrap();
        let deserialized: Tls13KeySchedule = crate::codec::decode(&serialized).unwrap();

        assert_eq!(deserialized.early_secret, schedule.early_secret);
        assert_eq!(deserialized.handshake_secret, schedule.handshake_secret);
    }
}