1use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
9use bytes::BytesMut;
10use compact_str::CompactString;
11use tls_parser::{
12 parse_tls_extensions, parse_tls_plaintext, TlsExtension, TlsMessage, TlsMessageHandshake,
13};
14
15use crate::protocol::{FieldValue, OwnedFieldValue};
16use crate::schema::{DataKind, FieldDescriptor};
17use crate::stream::{Direction, ParsedMessage, StreamContext, StreamParseResult, StreamParser};
18use crate::tls::{
19 extract_tls13_inner_content_type, Direction as TlsDirection, KeyLog, TlsSession, TlsVersion,
20};
21
/// Ports on which a stream is presumed to carry TLS.
///
/// Most are IANA-registered TLS services:
/// - 443 https, 993 imaps, 995 pop3s, 465 submissions
/// - 636 ldaps, 853 domain-s (DNS over TLS), 5061 sips
///
/// 8443 is a common alternate HTTPS port. NOTE(review): confirm the intent behind 14433.
const TLS_PORTS: &[u16] = &[443, 8443, 993, 995, 465, 636, 853, 5061, 14433];

/// Returns `true` when `port` appears in [`TLS_PORTS`].
fn is_tls_port(port: u16) -> bool {
    TLS_PORTS.iter().any(|&known| known == port)
}
39
40struct ConnectionTlsState {
42 session: TlsSession,
44 buffer: BytesMut,
46 handshake_complete: bool,
48 change_cipher_spec_seen: bool,
50}
51
52impl ConnectionTlsState {
53 fn new(keylog: Arc<KeyLog>) -> Self {
54 Self {
55 session: TlsSession::new(keylog),
56 buffer: BytesMut::new(),
57 handshake_complete: false,
58 change_cipher_spec_seen: false,
59 }
60 }
61}
62
63pub struct DecryptingTlsStreamParser {
69 keylog: Arc<KeyLog>,
71 sessions: Arc<Mutex<HashMap<u64, ConnectionTlsState>>>,
73}
74
75impl DecryptingTlsStreamParser {
76 pub fn new(keylog: KeyLog) -> Self {
78 Self {
79 keylog: Arc::new(keylog),
80 sessions: Arc::new(Mutex::new(HashMap::new())),
81 }
82 }
83
84 pub fn with_keylog(keylog: Arc<KeyLog>) -> Self {
86 Self {
87 keylog,
88 sessions: Arc::new(Mutex::new(HashMap::new())),
89 }
90 }
91
92 #[allow(dead_code)]
94 fn get_or_create_state(&self, connection_id: u64) -> ConnectionTlsState {
95 let mut sessions = self.sessions.lock().unwrap();
96 sessions
97 .entry(connection_id)
98 .or_insert_with(|| ConnectionTlsState::new(Arc::clone(&self.keylog)))
99 .clone_state()
100 }
101
102 #[allow(dead_code)]
104 fn update_state(&self, connection_id: u64, state: ConnectionTlsState) {
105 let mut sessions = self.sessions.lock().unwrap();
106 sessions.insert(connection_id, state);
107 }
108
109 pub fn remove_connection(&self, connection_id: u64) {
111 let mut sessions = self.sessions.lock().unwrap();
112 sessions.remove(&connection_id);
113 }
114
115 fn to_tls_direction(direction: Direction) -> TlsDirection {
117 match direction {
118 Direction::ToServer => TlsDirection::ClientToServer,
119 Direction::ToClient => TlsDirection::ServerToClient,
120 }
121 }
122
123 fn process_handshake(
125 state: &mut ConnectionTlsState,
126 handshake: &TlsMessageHandshake,
127 _direction: Direction,
128 fields: &mut HashMap<&'static str, OwnedFieldValue>,
129 ) {
130 match handshake {
131 TlsMessageHandshake::ClientHello(ch) => {
132 let mut client_random = [0u8; 32];
134 client_random.copy_from_slice(ch.random);
135 state.session.process_client_hello(client_random);
136
137 fields.insert("handshake_type", FieldValue::Str("ClientHello"));
138
139 if let Some(ext_data) = &ch.ext {
141 if let Ok((_, extensions)) = parse_tls_extensions(ext_data) {
142 for ext in extensions {
143 match ext {
144 TlsExtension::SNI(sni_list) => {
145 for (name_type, name) in sni_list {
146 if name_type.0 == 0 {
147 if let Ok(sni) = std::str::from_utf8(name) {
149 fields.insert(
150 "sni",
151 FieldValue::OwnedString(CompactString::new(
152 sni,
153 )),
154 );
155 }
156 }
157 }
158 }
159 TlsExtension::ALPN(alpn_list) => {
160 let protocols: Vec<&str> = alpn_list
161 .iter()
162 .filter_map(|p| std::str::from_utf8(p).ok())
163 .collect();
164 if !protocols.is_empty() {
165 fields.insert(
166 "alpn",
167 FieldValue::OwnedString(CompactString::new(
168 protocols.join(","),
169 )),
170 );
171 }
172 }
173 _ => {}
174 }
175 }
176 }
177 }
178 }
179
180 TlsMessageHandshake::ServerHello(sh) => {
181 let mut server_random = [0u8; 32];
183 server_random.copy_from_slice(sh.random);
184
185 let cipher_suite = sh.cipher.0;
186
187 let version = if let Some(ext_data) = &sh.ext {
189 detect_tls13_from_extensions(ext_data).unwrap_or_else(|| {
190 TlsVersion::from_wire(sh.version.0).unwrap_or(TlsVersion::Tls12)
191 })
192 } else {
193 TlsVersion::from_wire(sh.version.0).unwrap_or(TlsVersion::Tls12)
194 };
195
196 let _ = state
198 .session
199 .process_server_hello(server_random, cipher_suite, version);
200
201 fields.insert("handshake_type", FieldValue::Str("ServerHello"));
202 fields.insert("cipher_suite", FieldValue::UInt16(cipher_suite));
203
204 if let Some(name) = state.session.cipher_suite_name() {
205 fields.insert(
206 "cipher_suite_name",
207 FieldValue::OwnedString(CompactString::new(name)),
208 );
209 }
210 }
211
212 _ => {
213 }
215 }
216 }
217
218 fn process_encrypted_handshake(
223 state: &mut ConnectionTlsState,
224 handshake_data: &[u8],
225 direction: Direction,
226 fields: &mut HashMap<&'static str, OwnedFieldValue>,
227 ) {
228 if handshake_data.len() < 4 {
230 return;
231 }
232
233 let hs_type = handshake_data[0];
234 let _hs_len =
235 u32::from_be_bytes([0, handshake_data[1], handshake_data[2], handshake_data[3]])
236 as usize;
237
238 let hs_type_name = match hs_type {
239 1 => "ClientHello",
240 2 => "ServerHello",
241 4 => "NewSessionTicket",
242 8 => "EncryptedExtensions",
243 11 => "Certificate",
244 13 => "CertificateRequest",
245 15 => "CertificateVerify",
246 20 => "Finished",
247 24 => "KeyUpdate",
248 _ => "Unknown",
249 };
250
251 fields.insert(
252 "encrypted_hs_type",
253 FieldValue::OwnedString(CompactString::new(hs_type_name)),
254 );
255
256 if hs_type == 20 {
258 match direction {
260 Direction::ToClient => {
261 state.session.mark_server_finished();
263 fields.insert("hs_finished", FieldValue::Str("server"));
264 }
265 Direction::ToServer => {
266 state.session.mark_client_finished();
268 state.handshake_complete = true;
269 fields.insert("hs_finished", FieldValue::Str("client"));
270 }
271 }
272 }
273 }
274}
275
276fn detect_tls13_from_extensions(ext_data: &[u8]) -> Option<TlsVersion> {
278 if let Ok((_, extensions)) = parse_tls_extensions(ext_data) {
279 for ext in extensions {
280 if let TlsExtension::SupportedVersions(versions) = ext {
281 for v in versions {
283 if v.0 == 0x0304 {
284 return Some(TlsVersion::Tls13);
285 }
286 }
287 }
288 }
289 }
290 None
291}
292
impl ConnectionTlsState {
    /// Unusable placeholder: `ConnectionTlsState` does not implement `Clone`, so this cannot
    /// produce a copy.
    ///
    /// # Panics
    ///
    /// Always. Access per-connection state through the parser's `Arc<Mutex<..>>` session map
    /// instead.
    #[allow(dead_code)] // not called on any live code path
    fn clone_state(&self) -> Self {
        panic!("clone_state should not be called - use Arc<Mutex<>> directly")
    }
}
303
304impl StreamParser for DecryptingTlsStreamParser {
305 fn name(&self) -> &'static str {
306 "tls_decrypt"
307 }
308
309 fn display_name(&self) -> &'static str {
310 "TLS (Decrypting)"
311 }
312
313 fn can_parse_stream(&self, context: &StreamContext) -> bool {
314 is_tls_port(context.dst_port) || is_tls_port(context.src_port)
315 }
316
317 fn parse_stream(&self, data: &[u8], context: &StreamContext) -> StreamParseResult {
318 let mut sessions = self.sessions.lock().unwrap();
319 let state = sessions
320 .entry(context.connection_id)
321 .or_insert_with(|| ConnectionTlsState::new(Arc::clone(&self.keylog)));
322
323 state.buffer.extend_from_slice(data);
325
326 let mut messages = Vec::new();
327 let mut decrypted_data = Vec::new();
328 let mut total_consumed = 0;
329
330 loop {
332 if state.buffer.len() < 5 {
333 break; }
335
336 let content_type = state.buffer[0];
338 let version = u16::from_be_bytes([state.buffer[1], state.buffer[2]]);
339 let length = u16::from_be_bytes([state.buffer[3], state.buffer[4]]) as usize;
340 let record_len = 5 + length;
341
342 if state.buffer.len() < record_len {
343 break; }
345
346 let record_data = state.buffer[..record_len].to_vec();
348 let payload = &record_data[5..];
349
350 let mut fields = HashMap::new();
351 fields.insert(
352 "version",
353 FieldValue::OwnedString(CompactString::new(version_name(version))),
354 );
355
356 match content_type {
357 22 => {
358 fields.insert("record_type", FieldValue::Str("Handshake"));
360
361 if let Ok((_, record)) = parse_tls_plaintext(&record_data) {
363 for msg in &record.msg {
364 if let TlsMessage::Handshake(hs) = msg {
365 Self::process_handshake(state, hs, context.direction, &mut fields);
366 }
367 }
368 }
369 }
370
371 23 => {
372 fields.insert("record_type", FieldValue::Str("ApplicationData"));
374
375 if state.session.can_decrypt() {
377 let tls_dir = Self::to_tls_direction(context.direction);
378 match state.session.decrypt_record(tls_dir, content_type, payload) {
379 Ok(plaintext) => {
380 if state.session.is_tls13_handshake_phase() {
382 if let Some((inner_type, inner_data)) =
384 extract_tls13_inner_content_type(&plaintext)
385 {
386 fields.insert(
387 "inner_content_type",
388 FieldValue::UInt8(inner_type),
389 );
390
391 if inner_type == 22 {
392 Self::process_encrypted_handshake(
394 state,
395 inner_data,
396 context.direction,
397 &mut fields,
398 );
399 } else if inner_type == 23 {
400 decrypted_data.extend_from_slice(inner_data);
402 }
403 }
404 } else {
405 let version = state.session.handshake().effective_version();
408 if version == Some(TlsVersion::Tls13) {
409 if let Some((inner_type, inner_data)) =
410 extract_tls13_inner_content_type(&plaintext)
411 {
412 if inner_type == 23 {
413 decrypted_data.extend_from_slice(inner_data);
414 }
415 else if inner_type == 22 {
417 fields.insert(
418 "inner_content_type",
419 FieldValue::UInt8(inner_type),
420 );
421 }
422 }
423 } else {
424 decrypted_data.extend_from_slice(&plaintext);
426 }
427 }
428
429 fields.insert(
430 "decrypted_length",
431 FieldValue::UInt64(plaintext.len() as u64),
432 );
433 }
434 Err(e) => {
435 fields.insert(
437 "decrypt_error",
438 FieldValue::OwnedString(CompactString::new(e.to_string())),
439 );
440 }
441 }
442 } else {
443 fields.insert("encrypted_length", FieldValue::UInt16(length as u16));
444 }
445 }
446
447 20 => {
448 fields.insert("record_type", FieldValue::Str("ChangeCipherSpec"));
450 state.change_cipher_spec_seen = true;
451 }
452
453 21 => {
454 fields.insert("record_type", FieldValue::Str("Alert"));
456 if payload.len() >= 2 {
457 fields.insert("alert_level", FieldValue::UInt8(payload[0]));
458 fields.insert("alert_description", FieldValue::UInt8(payload[1]));
459 }
460 }
461
462 _ => {
463 fields.insert(
465 "record_type",
466 FieldValue::OwnedString(CompactString::new(format!(
467 "Unknown({content_type})"
468 ))),
469 );
470 }
471 }
472
473 let message = ParsedMessage {
475 protocol: "tls",
476 connection_id: context.connection_id,
477 message_id: context.messages_parsed as u32 + messages.len() as u32,
478 direction: context.direction,
479 frame_number: 0, fields,
481 };
482 messages.push(message);
483
484 state.buffer = state.buffer.split_off(record_len);
486 total_consumed += record_len;
487 }
488
489 if !decrypted_data.is_empty() {
491 let child_protocol = "http2"; let metadata = if !messages.is_empty() {
497 Some(messages.remove(0))
498 } else {
499 None
500 };
501
502 StreamParseResult::Transform {
503 child_protocol,
504 child_data: decrypted_data,
505 bytes_consumed: total_consumed,
506 metadata,
507 }
508 } else if !messages.is_empty() {
509 StreamParseResult::Complete {
511 messages,
512 bytes_consumed: total_consumed,
513 }
514 } else if total_consumed == 0 {
515 StreamParseResult::NeedMore {
517 minimum_bytes: Some(5), }
519 } else {
520 StreamParseResult::Complete {
521 messages: vec![],
522 bytes_consumed: total_consumed,
523 }
524 }
525 }
526
527 fn message_schema(&self) -> Vec<FieldDescriptor> {
528 vec![
529 FieldDescriptor::new("connection_id", DataKind::UInt64),
530 FieldDescriptor::new("record_type", DataKind::String).set_nullable(true),
531 FieldDescriptor::new("version", DataKind::String).set_nullable(true),
532 FieldDescriptor::new("handshake_type", DataKind::String).set_nullable(true),
533 FieldDescriptor::new("sni", DataKind::String).set_nullable(true),
534 FieldDescriptor::new("alpn", DataKind::String).set_nullable(true),
535 FieldDescriptor::new("cipher_suite", DataKind::UInt16).set_nullable(true),
536 FieldDescriptor::new("cipher_suite_name", DataKind::String).set_nullable(true),
537 FieldDescriptor::new("decrypted_length", DataKind::UInt64).set_nullable(true),
538 FieldDescriptor::new("encrypted_length", DataKind::UInt16).set_nullable(true),
539 FieldDescriptor::new("decrypt_error", DataKind::String).set_nullable(true),
540 FieldDescriptor::new("alert_level", DataKind::UInt8).set_nullable(true),
541 FieldDescriptor::new("alert_description", DataKind::UInt8).set_nullable(true),
542 ]
543 }
544}
545
/// Maps a TLS wire version number to a display name; unrecognised values become `"Unknown"`.
fn version_name(version: u16) -> &'static str {
    const KNOWN_VERSIONS: [(u16, &str); 5] = [
        (0x0300, "SSL 3.0"),
        (0x0301, "TLS 1.0"),
        (0x0302, "TLS 1.1"),
        (0x0303, "TLS 1.2"),
        (0x0304, "TLS 1.3"),
    ];

    KNOWN_VERSIONS
        .iter()
        .find(|&&(wire, _)| wire == version)
        .map_or("Unknown", |&(_, name)| name)
}
557
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Client-to-server context for connection 1: 192.168.1.1:54321 -> 192.168.1.2:443.
    fn test_context() -> StreamContext {
        StreamContext {
            connection_id: 1,
            direction: Direction::ToServer,
            src_ip: std::net::IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            dst_ip: std::net::IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2)),
            src_port: 54321,
            dst_port: 443,
            bytes_parsed: 0,
            messages_parsed: 0,
            alpn: None,
        }
    }

    /// Key log from `KeyLog::new()`.
    ///
    /// Expected to hold no secrets, so nothing can be decrypted
    /// (`test_parse_application_data_without_keys` relies on this).
    fn empty_keylog() -> KeyLog {
        KeyLog::new()
    }

    #[test]
    fn test_is_tls_port() {
        assert!(is_tls_port(443));
        assert!(is_tls_port(8443));
        assert!(is_tls_port(993));
        assert!(!is_tls_port(80));
        assert!(!is_tls_port(22));
    }

    #[test]
    fn test_can_parse_stream() {
        let parser = DecryptingTlsStreamParser::new(empty_keylog());
        let ctx = test_context();
        assert!(parser.can_parse_stream(&ctx));

        // Neither endpoint on a TLS port -> not claimed.
        let mut ctx_http = ctx.clone();
        ctx_http.dst_port = 80;
        ctx_http.src_port = 54321;
        assert!(!parser.can_parse_stream(&ctx_http));
    }

    #[test]
    fn test_parse_incomplete_record() {
        let parser = DecryptingTlsStreamParser::new(empty_keylog());
        let ctx = test_context();

        // Only 3 of the 5 record-header bytes.
        let data = vec![22, 3, 3];
        let result = parser.parse_stream(&data, &ctx);

        assert!(matches!(result, StreamParseResult::NeedMore { .. }));
    }

    #[test]
    fn test_parse_handshake_record() {
        let parser = DecryptingTlsStreamParser::new(empty_keylog());
        let ctx = test_context();

        // Minimal ClientHello. NOTE(review): the body below is 41 bytes, but the declared
        // lengths (record 44 = 4 + 40, handshake 40) are one byte short. The final compression
        // byte therefore falls outside the record; `bytes_consumed` == 49 reflects the record
        // header (5 + 44), not the 50 bytes supplied.
        let mut record = vec![
            22, 3, 3, 0, 44, // record header: Handshake, 0x0303, length 44
            1, 0, 0, 40, // handshake header: ClientHello, length 40
            3, 3, // client_version 0x0303
        ];
        // 32-byte client random (all zero)
        record.extend_from_slice(&[0u8; 32]);
        // session_id length 0
        record.push(0);
        // cipher_suites: length 2, suite 0x00ff
        record.extend_from_slice(&[0, 2, 0, 0xff]);
        // compression_methods: length 1, method 0
        record.extend_from_slice(&[1, 0]);

        let result = parser.parse_stream(&record, &ctx);

        match result {
            StreamParseResult::Complete {
                messages,
                bytes_consumed,
            } => {
                assert_eq!(bytes_consumed, 49);
                assert_eq!(messages.len(), 1);
                assert_eq!(messages[0].protocol, "tls");
            }
            _ => panic!("Expected Complete, got {:?}", result),
        }
    }

    #[test]
    fn test_parse_application_data_without_keys() {
        let parser = DecryptingTlsStreamParser::new(empty_keylog());
        let ctx = test_context();

        let record = vec![
            23, 3, 3, 0, 10, // record header: ApplicationData, 0x0303, length 10
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // opaque payload
        ];

        let result = parser.parse_stream(&record, &ctx);

        match result {
            StreamParseResult::Complete {
                messages,
                bytes_consumed,
            } => {
                assert_eq!(bytes_consumed, 15);
                assert_eq!(messages.len(), 1);
                // Without keys the record is only measured, not decrypted.
                assert!(messages[0].fields.contains_key("encrypted_length"));
            }
            _ => panic!("Expected Complete, got {:?}", result),
        }
    }

    #[test]
    fn test_parse_alert_record() {
        let parser = DecryptingTlsStreamParser::new(empty_keylog());
        let ctx = test_context();

        let record = vec![
            21, 3, 3, 0, 2, // record header: Alert, 0x0303, length 2
            1, 0, // alert level 1, description 0
        ];

        let result = parser.parse_stream(&record, &ctx);

        match result {
            StreamParseResult::Complete { messages, .. } => {
                assert_eq!(messages.len(), 1);
                assert_eq!(
                    messages[0].fields.get("alert_level"),
                    Some(&FieldValue::UInt8(1))
                );
            }
            _ => panic!("Expected Complete"),
        }
    }

    #[test]
    fn test_multiple_records() {
        let parser = DecryptingTlsStreamParser::new(empty_keylog());
        let ctx = test_context();

        let data = vec![
            20, 3, 3, 0, 1, 1, // ChangeCipherSpec record, 1-byte body
            23, 3, 3, 0, 5, 1, 2, 3, 4, 5, // ApplicationData record, 5-byte body
        ];

        let result = parser.parse_stream(&data, &ctx);

        match result {
            StreamParseResult::Complete {
                messages,
                bytes_consumed,
            } => {
                // 6 + 10 bytes, one message per record.
                assert_eq!(bytes_consumed, 16);
                assert_eq!(messages.len(), 2);
            }
            _ => panic!("Expected Complete"),
        }
    }

    #[test]
    fn test_to_tls_direction() {
        assert_eq!(
            DecryptingTlsStreamParser::to_tls_direction(Direction::ToServer),
            TlsDirection::ClientToServer
        );
        assert_eq!(
            DecryptingTlsStreamParser::to_tls_direction(Direction::ToClient),
            TlsDirection::ServerToClient
        );
    }

    #[test]
    fn test_version_name() {
        assert_eq!(version_name(0x0301), "TLS 1.0");
        assert_eq!(version_name(0x0303), "TLS 1.2");
        assert_eq!(version_name(0x0304), "TLS 1.3");
        assert_eq!(version_name(0x0000), "Unknown");
    }
}