1use std::sync::Arc;
7
8use super::decrypt::{DecryptionContext, DecryptionError, Direction, TlsVersion};
9use super::kdf::{derive_tls12_keys, derive_tls13_keys, AeadAlgorithm, KeyDerivationError};
10use super::keylog::{KeyLog, KeyLogEntries};
11use thiserror::Error;
12
/// Errors reported by [`TlsSession`] while establishing keys or decrypting records.
#[derive(Debug, Error)]
pub enum SessionError {
    /// Key derivation failed; wraps the underlying [`KeyDerivationError`].
    #[error("Key derivation failed: {0}")]
    KeyDerivation(#[from] KeyDerivationError),

    /// Creating a decryption context or decrypting a record failed; wraps the
    /// underlying [`DecryptionError`].
    #[error("Decryption error: {0}")]
    Decryption(#[from] DecryptionError),

    /// The key log has no entry for the session's client random, or the entry lacks a
    /// secret that the negotiated protocol version requires.
    #[error("Missing key material for client_random")]
    MissingKeys,

    /// No AEAD algorithm is known for the negotiated cipher suite (carried as its
    /// 16-bit identifier).
    #[error("Unsupported cipher suite: 0x{0:04x}")]
    UnsupportedCipherSuite(u16),

    /// Decryption was requested but no usable decryption context exists yet.
    #[error("Session not initialized: handshake incomplete")]
    NotInitialized,

    /// Key establishment was attempted before a client random was recorded.
    #[error("Missing client_random from ClientHello")]
    MissingClientRandom,

    /// Key establishment was attempted before a server random was recorded.
    #[error("Missing server_random from ServerHello")]
    MissingServerRandom,

    /// Key establishment was attempted before a cipher suite was recorded.
    #[error("Missing cipher suite selection from ServerHello")]
    MissingCipherSuite,
}
40
/// Coarse lifecycle state of a [`TlsSession`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// Freshly constructed session; no hello message has been processed.
    Initial,
    /// Set by `TlsSession::process_client_hello` once the client random is recorded.
    ClientHelloReceived,
    /// Set by `TlsSession::process_server_hello` before key establishment is attempted;
    /// the session stays here if that attempt fails.
    ServerHelloReceived,
    /// TLS 1.3 keys were derived; `TlsSession::decrypt_record` uses the
    /// handshake-traffic contexts while in this state.
    Tls13HandshakeEncrypted,
    /// Entered after TLS 1.2 key derivation, or from `Tls13HandshakeEncrypted` via
    /// `TlsSession::transition_to_application_data`.
    KeysEstablished,
    /// Set by `TlsSession::close`.
    Closed,
}
57
/// Progress marker for the encrypted part of a TLS 1.3 handshake.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tls13HandshakePhase {
    /// Starting value; also (re)assigned whenever TLS 1.3 keys are established.
    Initial,
    /// Set by `TlsSession::mark_server_finished` when the phase was `Initial`.
    ServerFinished,
    /// Set when the session switches from handshake to application keys.
    Complete,
}
68
/// Handshake parameters collected from the hello messages of one connection.
///
/// Every field starts out as `None` (via `Default`) and is filled in as messages are
/// processed.
#[derive(Debug, Clone, Default)]
pub struct HandshakeData {
    /// 32-byte client random; also the key used to look up secrets in the key log.
    pub client_random: Option<[u8; 32]>,

    /// 32-byte server random.
    pub server_random: Option<[u8; 32]>,

    /// 16-bit identifier of the selected cipher suite.
    pub cipher_suite: Option<u16>,

    /// Protocol version recorded together with the server hello values.
    pub version: Option<TlsVersion>,

    /// Session identifier. NOTE(review): nothing in the code visible here assigns this
    /// field — confirm whether another module populates it.
    pub session_id: Option<Vec<u8>>,
}
87
88impl HandshakeData {
89 pub fn can_derive_keys(&self) -> bool {
91 self.client_random.is_some() && self.server_random.is_some() && self.cipher_suite.is_some()
92 }
93
94 pub fn effective_version(&self) -> Option<TlsVersion> {
96 self.version
97 }
98}
99
/// Per-connection TLS decryption state.
///
/// Tracks the handshake values seen so far, looks up the matching secrets in a shared
/// [`KeyLog`], and owns one [`DecryptionContext`] per direction (plus a separate pair
/// for TLS 1.3 handshake traffic).
pub struct TlsSession {
    /// Current lifecycle state.
    state: SessionState,

    /// Handshake values recorded so far.
    handshake: HandshakeData,

    /// Shared key log, queried by client random during key establishment.
    keylog: Arc<KeyLog>,

    /// Client-to-server context: TLS 1.2 keys, or TLS 1.3 application-traffic keys.
    client_decrypt: Option<DecryptionContext>,

    /// Server-to-client context: TLS 1.2 keys, or TLS 1.3 application-traffic keys.
    server_decrypt: Option<DecryptionContext>,

    /// Client-to-server context for TLS 1.3 handshake traffic; `None` when the key log
    /// entry lacks handshake secrets.
    client_hs_decrypt: Option<DecryptionContext>,

    /// Server-to-client context for TLS 1.3 handshake traffic; `None` when the key log
    /// entry lacks handshake secrets.
    server_hs_decrypt: Option<DecryptionContext>,

    /// Progress through the encrypted TLS 1.3 handshake.
    tls13_hs_phase: Tls13HandshakePhase,
}
133
134impl TlsSession {
135 pub fn new(keylog: Arc<KeyLog>) -> Self {
137 Self {
138 state: SessionState::Initial,
139 handshake: HandshakeData::default(),
140 keylog,
141 client_decrypt: None,
142 server_decrypt: None,
143 client_hs_decrypt: None,
144 server_hs_decrypt: None,
145 tls13_hs_phase: Tls13HandshakePhase::Initial,
146 }
147 }
148
149 pub fn state(&self) -> SessionState {
151 self.state
152 }
153
154 pub fn handshake(&self) -> &HandshakeData {
156 &self.handshake
157 }
158
159 pub fn process_client_hello(&mut self, client_random: [u8; 32]) {
163 self.handshake.client_random = Some(client_random);
164 self.state = SessionState::ClientHelloReceived;
165 }
166
167 pub fn process_server_hello(
171 &mut self,
172 server_random: [u8; 32],
173 cipher_suite: u16,
174 version: TlsVersion,
175 ) -> Result<(), SessionError> {
176 self.handshake.server_random = Some(server_random);
177 self.handshake.cipher_suite = Some(cipher_suite);
178 self.handshake.version = Some(version);
179 self.state = SessionState::ServerHelloReceived;
180
181 self.try_establish_keys()
183 }
184
185 pub fn try_establish_keys(&mut self) -> Result<(), SessionError> {
192 if self.state == SessionState::KeysEstablished
193 || self.state == SessionState::Tls13HandshakeEncrypted
194 {
195 return Ok(()); }
197
198 let client_random = self
199 .handshake
200 .client_random
201 .ok_or(SessionError::MissingClientRandom)?;
202
203 let server_random = self
204 .handshake
205 .server_random
206 .ok_or(SessionError::MissingServerRandom)?;
207
208 let cipher_suite = self
209 .handshake
210 .cipher_suite
211 .ok_or(SessionError::MissingCipherSuite)?;
212
213 let version = self.handshake.version.unwrap_or(TlsVersion::Tls12);
214
215 let key_entries = self
217 .keylog
218 .lookup(&client_random)
219 .ok_or(SessionError::MissingKeys)?
220 .clone();
221
222 let aead = AeadAlgorithm::from_cipher_suite(cipher_suite)
224 .ok_or(SessionError::UnsupportedCipherSuite(cipher_suite))?;
225
226 match version {
228 TlsVersion::Tls13 => {
229 self.establish_tls13_keys(&key_entries, cipher_suite, aead)?;
230 self.state = SessionState::Tls13HandshakeEncrypted;
232 }
233 _ => {
234 self.establish_tls12_keys(
235 &key_entries,
236 &client_random,
237 &server_random,
238 cipher_suite,
239 aead,
240 )?;
241 self.state = SessionState::KeysEstablished;
242 }
243 }
244
245 Ok(())
246 }
247
248 fn establish_tls12_keys(
250 &mut self,
251 key_entries: &KeyLogEntries,
252 client_random: &[u8; 32],
253 server_random: &[u8; 32],
254 cipher_suite: u16,
255 aead: AeadAlgorithm,
256 ) -> Result<(), SessionError> {
257 let master_secret = key_entries.master_secret.ok_or(SessionError::MissingKeys)?;
258
259 let keys = derive_tls12_keys(&master_secret, client_random, server_random, cipher_suite)?;
260
261 self.client_decrypt = Some(DecryptionContext::new_tls12(
262 &keys,
263 aead,
264 Direction::ClientToServer,
265 )?);
266 self.server_decrypt = Some(DecryptionContext::new_tls12(
267 &keys,
268 aead,
269 Direction::ServerToClient,
270 )?);
271
272 Ok(())
273 }
274
275 fn establish_tls13_keys(
281 &mut self,
282 key_entries: &KeyLogEntries,
283 cipher_suite: u16,
284 aead: AeadAlgorithm,
285 ) -> Result<(), SessionError> {
286 if let (Some(client_hs_secret), Some(server_hs_secret)) = (
288 key_entries.client_handshake_traffic_secret.as_ref(),
289 key_entries.server_handshake_traffic_secret.as_ref(),
290 ) {
291 let client_hs_keys = derive_tls13_keys(client_hs_secret, cipher_suite)?;
292 let server_hs_keys = derive_tls13_keys(server_hs_secret, cipher_suite)?;
293
294 self.client_hs_decrypt = Some(DecryptionContext::new_tls13(&client_hs_keys, aead)?);
295 self.server_hs_decrypt = Some(DecryptionContext::new_tls13(&server_hs_keys, aead)?);
296 }
297
298 let client_secret = key_entries
300 .client_traffic_secret_0
301 .as_ref()
302 .ok_or(SessionError::MissingKeys)?;
303
304 let server_secret = key_entries
305 .server_traffic_secret_0
306 .as_ref()
307 .ok_or(SessionError::MissingKeys)?;
308
309 let client_keys = derive_tls13_keys(client_secret, cipher_suite)?;
310 let server_keys = derive_tls13_keys(server_secret, cipher_suite)?;
311
312 self.client_decrypt = Some(DecryptionContext::new_tls13(&client_keys, aead)?);
313 self.server_decrypt = Some(DecryptionContext::new_tls13(&server_keys, aead)?);
314
315 self.tls13_hs_phase = Tls13HandshakePhase::Initial;
317
318 Ok(())
319 }
320
321 pub fn can_decrypt(&self) -> bool {
323 match self.state {
324 SessionState::KeysEstablished => {
325 self.client_decrypt.is_some() && self.server_decrypt.is_some()
326 }
327 SessionState::Tls13HandshakeEncrypted => {
328 self.client_hs_decrypt.is_some() && self.server_hs_decrypt.is_some()
330 }
331 _ => false,
332 }
333 }
334
335 pub fn is_tls13_handshake_phase(&self) -> bool {
337 self.state == SessionState::Tls13HandshakeEncrypted
338 }
339
340 pub fn tls13_handshake_phase(&self) -> Tls13HandshakePhase {
342 self.tls13_hs_phase
343 }
344
345 pub fn transition_to_application_data(&mut self) {
348 if self.state == SessionState::Tls13HandshakeEncrypted {
349 self.state = SessionState::KeysEstablished;
350 self.tls13_hs_phase = Tls13HandshakePhase::Complete;
351 }
352 }
353
354 pub fn mark_server_finished(&mut self) {
356 if self.tls13_hs_phase == Tls13HandshakePhase::Initial {
357 self.tls13_hs_phase = Tls13HandshakePhase::ServerFinished;
358 }
359 }
360
361 pub fn mark_client_finished(&mut self) {
364 self.transition_to_application_data();
365 }
366
367 pub fn decrypt_record(
372 &mut self,
373 direction: Direction,
374 record_type: u8,
375 ciphertext: &[u8],
376 ) -> Result<Vec<u8>, SessionError> {
377 if !self.can_decrypt() {
378 return Err(SessionError::NotInitialized);
379 }
380
381 let version = self.handshake.version.unwrap_or(TlsVersion::Tls12);
382 let protocol_version = version.to_wire();
383
384 let ctx = if self.state == SessionState::Tls13HandshakeEncrypted {
386 match direction {
388 Direction::ClientToServer => self.client_hs_decrypt.as_mut(),
389 Direction::ServerToClient => self.server_hs_decrypt.as_mut(),
390 }
391 } else {
392 match direction {
394 Direction::ClientToServer => self.client_decrypt.as_mut(),
395 Direction::ServerToClient => self.server_decrypt.as_mut(),
396 }
397 };
398
399 let ctx = ctx.ok_or(SessionError::NotInitialized)?;
400 let plaintext = ctx.decrypt_record(version, record_type, protocol_version, ciphertext)?;
401 Ok(plaintext)
402 }
403
404 pub fn decrypt_handshake_record(
407 &mut self,
408 direction: Direction,
409 record_type: u8,
410 ciphertext: &[u8],
411 ) -> Result<Vec<u8>, SessionError> {
412 let ctx = match direction {
413 Direction::ClientToServer => self.client_hs_decrypt.as_mut(),
414 Direction::ServerToClient => self.server_hs_decrypt.as_mut(),
415 };
416
417 let ctx = ctx.ok_or(SessionError::NotInitialized)?;
418 let version = TlsVersion::Tls13;
419 let protocol_version = version.to_wire();
420
421 let plaintext = ctx.decrypt_record(version, record_type, protocol_version, ciphertext)?;
422 Ok(plaintext)
423 }
424
425 pub fn decrypt_application_record(
428 &mut self,
429 direction: Direction,
430 record_type: u8,
431 ciphertext: &[u8],
432 ) -> Result<Vec<u8>, SessionError> {
433 let ctx = match direction {
434 Direction::ClientToServer => self.client_decrypt.as_mut(),
435 Direction::ServerToClient => self.server_decrypt.as_mut(),
436 };
437
438 let ctx = ctx.ok_or(SessionError::NotInitialized)?;
439 let version = TlsVersion::Tls13;
440 let protocol_version = version.to_wire();
441
442 let plaintext = ctx.decrypt_record(version, record_type, protocol_version, ciphertext)?;
443 Ok(plaintext)
444 }
445
446 pub fn cipher_suite_name(&self) -> Option<&'static str> {
448 self.handshake.cipher_suite.and_then(cipher_suite_name)
449 }
450
451 pub fn client_sequence(&self) -> Option<u64> {
453 self.client_decrypt.as_ref().map(|c| c.sequence_number())
454 }
455
456 pub fn server_sequence(&self) -> Option<u64> {
458 self.server_decrypt.as_ref().map(|c| c.sequence_number())
459 }
460
461 pub fn close(&mut self) {
463 self.state = SessionState::Closed;
464 }
465}
466
/// Looks up the name of a TLS cipher suite by its 16-bit identifier.
///
/// Only the suites listed in the table below are recognised; any other identifier
/// yields `None`.
fn cipher_suite_name(id: u16) -> Option<&'static str> {
    const SUITE_NAMES: &[(u16, &str)] = &[
        // TLS 1.3 suites
        (0x1301, "TLS_AES_128_GCM_SHA256"),
        (0x1302, "TLS_AES_256_GCM_SHA384"),
        (0x1303, "TLS_CHACHA20_POLY1305_SHA256"),
        // ECDHE-RSA
        (0xC02F, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"),
        (0xC030, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"),
        (0xCCA8, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256"),
        // ECDHE-ECDSA
        (0xC02B, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"),
        (0xC02C, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"),
        (0xCCA9, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"),
        // DHE-RSA
        (0x009E, "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256"),
        (0x009F, "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384"),
        (0xCCAA, "TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256"),
        // Static RSA
        (0x009C, "TLS_RSA_WITH_AES_128_GCM_SHA256"),
        (0x009D, "TLS_RSA_WITH_AES_256_GCM_SHA384"),
    ];

    SUITE_NAMES
        .iter()
        .find(|&&(suite, _)| suite == id)
        .map(|&(_, name)| name)
}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Key log holding a single TLS 1.2 `CLIENT_RANDOM` entry.
    fn create_test_keylog() -> Arc<KeyLog> {
        let content = "CLIENT_RANDOM 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef 000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f";
        Arc::new(KeyLog::parse(content).unwrap())
    }

    /// Key log holding the four TLS 1.3 traffic secrets for the same client random.
    fn create_test_keylog_tls13() -> Arc<KeyLog> {
        let content = r#"
CLIENT_HANDSHAKE_TRAFFIC_SECRET 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef deadbeef00112233445566778899aabbccddeeff00112233445566778899aabb
SERVER_HANDSHAKE_TRAFFIC_SECRET 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef cafebabe556677889900aabbccddeeff00112233445566778899aabbccddeeff
CLIENT_TRAFFIC_SECRET_0 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef aabbccdd00112233445566778899aabbccddeeff00112233445566778899aabb
SERVER_TRAFFIC_SECRET_0 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef 11223344556677889900aabbccddeeff00112233445566778899aabbccddeeff
"#;
        Arc::new(KeyLog::parse(content).unwrap())
    }

    /// The client random used by both test key logs: `01 23 45 67 89 ab cd ef`
    /// repeated four times.
    fn keylog_client_random() -> [u8; 32] {
        const PATTERN: [u8; 8] = [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef];
        let mut random = [0u8; 32];
        for chunk in random.chunks_exact_mut(PATTERN.len()) {
            chunk.copy_from_slice(&PATTERN);
        }
        random
    }

    /// Server random shared by the tests; its value is irrelevant to the key lookup.
    fn test_server_random() -> [u8; 32] {
        [0x43u8; 32]
    }

    #[test]
    fn test_session_initial_state() {
        let session = TlsSession::new(create_test_keylog());

        assert_eq!(session.state(), SessionState::Initial);
        assert!(!session.can_decrypt());
    }

    #[test]
    fn test_session_client_hello() {
        let mut session = TlsSession::new(create_test_keylog());
        let client_random = [0x42u8; 32];

        session.process_client_hello(client_random);

        assert_eq!(session.state(), SessionState::ClientHelloReceived);
        assert_eq!(session.handshake().client_random, Some(client_random));
    }

    #[test]
    fn test_session_server_hello_missing_keys() {
        let mut session = TlsSession::new(create_test_keylog());

        // This client random has no entry in the key log.
        session.process_client_hello([0x42u8; 32]);
        let result = session.process_server_hello(test_server_random(), 0xC02F, TlsVersion::Tls12);

        assert!(matches!(result, Err(SessionError::MissingKeys)));
    }

    #[test]
    fn test_session_tls12_key_establishment() {
        let mut session = TlsSession::new(create_test_keylog());

        session.process_client_hello(keylog_client_random());
        let result = session.process_server_hello(test_server_random(), 0xC02F, TlsVersion::Tls12);

        assert!(result.is_ok());
        assert_eq!(session.state(), SessionState::KeysEstablished);
        assert!(session.can_decrypt());
        assert_eq!(
            session.cipher_suite_name(),
            Some("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256")
        );
    }

    #[test]
    fn test_session_tls13_key_establishment() {
        let mut session = TlsSession::new(create_test_keylog_tls13());

        session.process_client_hello(keylog_client_random());
        let result = session.process_server_hello(test_server_random(), 0x1301, TlsVersion::Tls13);

        assert!(result.is_ok());
        // TLS 1.3 begins in the encrypted-handshake state, not KeysEstablished.
        assert_eq!(session.state(), SessionState::Tls13HandshakeEncrypted);
        assert!(session.can_decrypt());
        assert_eq!(session.cipher_suite_name(), Some("TLS_AES_128_GCM_SHA256"));

        assert!(session.is_tls13_handshake_phase());

        session.mark_server_finished();
        assert_eq!(
            session.tls13_handshake_phase(),
            Tls13HandshakePhase::ServerFinished
        );

        session.mark_client_finished();
        assert_eq!(session.state(), SessionState::KeysEstablished);
        assert!(!session.is_tls13_handshake_phase());
    }

    #[test]
    fn test_session_unsupported_cipher_suite() {
        let mut session = TlsSession::new(create_test_keylog());

        session.process_client_hello(keylog_client_random());
        let result = session.process_server_hello(test_server_random(), 0x0000, TlsVersion::Tls12);

        assert!(matches!(
            result,
            Err(SessionError::UnsupportedCipherSuite(0x0000))
        ));
    }

    #[test]
    fn test_session_close() {
        let mut session = TlsSession::new(create_test_keylog());

        session.close();

        assert_eq!(session.state(), SessionState::Closed);
    }

    #[test]
    fn test_decrypt_not_initialized() {
        let mut session = TlsSession::new(create_test_keylog());

        let result = session.decrypt_record(Direction::ClientToServer, 23, &[0u8; 32]);

        assert!(matches!(result, Err(SessionError::NotInitialized)));
    }

    #[test]
    fn test_handshake_data_can_derive_keys() {
        let mut data = HandshakeData::default();
        assert!(!data.can_derive_keys());

        data.client_random = Some([0u8; 32]);
        assert!(!data.can_derive_keys());

        data.server_random = Some([0u8; 32]);
        assert!(!data.can_derive_keys());

        // Only once all three inputs are present does derivation become possible.
        data.cipher_suite = Some(0xC02F);
        assert!(data.can_derive_keys());
    }

    #[test]
    fn test_cipher_suite_name() {
        assert_eq!(cipher_suite_name(0x1301), Some("TLS_AES_128_GCM_SHA256"));
        assert_eq!(cipher_suite_name(0x1302), Some("TLS_AES_256_GCM_SHA384"));
        assert_eq!(
            cipher_suite_name(0xC02F),
            Some("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256")
        );
        assert_eq!(cipher_suite_name(0x0000), None);
    }
}