1use std::collections::HashMap;
2
3use compact_str::CompactString;
4
5use crate::protocol::{FieldValue, OwnedFieldValue};
6use crate::schema::{DataKind, FieldDescriptor};
7use crate::stream::{ParsedMessage, StreamContext, StreamParseResult, StreamParser};
8
/// TLS record-layer `ContentType` values (RFC 5246 section 6.2.1,
/// RFC 8446 section 5.1), i.e. the first byte of every record header.
mod content_type {
    /// `change_cipher_spec(20)`
    pub const CHANGE_CIPHER_SPEC: u8 = 0x14;
    /// `alert(21)`
    pub const ALERT: u8 = 0x15;
    /// `handshake(22)`
    pub const HANDSHAKE: u8 = 0x16;
    /// `application_data(23)`
    pub const APPLICATION_DATA: u8 = 0x17;
}
16
/// TLS `HandshakeType` values (RFC 5246 section 7.4): the first byte of a
/// handshake message inside a handshake record.
// These constants may not all be referenced by the parser, so dead-code
// warnings are silenced once for the whole table instead of per item.
#[allow(dead_code)]
mod handshake_type {
    /// `client_hello(1)`
    pub const CLIENT_HELLO: u8 = 0x01;
    /// `server_hello(2)`
    pub const SERVER_HELLO: u8 = 0x02;
    /// `certificate(11)`
    pub const CERTIFICATE: u8 = 0x0b;
    /// `server_key_exchange(12)`
    pub const SERVER_KEY_EXCHANGE: u8 = 0x0c;
    /// `certificate_request(13)`
    pub const CERTIFICATE_REQUEST: u8 = 0x0d;
    /// `server_hello_done(14)`
    pub const SERVER_HELLO_DONE: u8 = 0x0e;
    /// `certificate_verify(15)`
    pub const CERTIFICATE_VERIFY: u8 = 0x0f;
    /// `client_key_exchange(16)`
    pub const CLIENT_KEY_EXCHANGE: u8 = 0x10;
    /// `finished(20)`
    pub const FINISHED: u8 = 0x14;
}
36
/// Stateless TLS parser: it carries no fields, so every instance is
/// interchangeable and cheap to copy.
#[derive(Clone, Copy, Debug, Default)]
pub struct TlsStreamParser;
40
41impl TlsStreamParser {
42 pub fn new() -> Self {
43 Self
44 }
45
46 fn parse_record_header(data: &[u8]) -> Option<(u8, u16, u16)> {
48 if data.len() < 5 {
49 return None;
50 }
51 let content_type = data[0];
52 let version = u16::from_be_bytes([data[1], data[2]]);
53 let length = u16::from_be_bytes([data[3], data[4]]);
54 Some((content_type, version, length))
55 }
56
57 fn extract_sni(extensions: &[u8]) -> Option<String> {
59 let mut pos = 0;
60 while pos + 4 <= extensions.len() {
61 let ext_type = u16::from_be_bytes([extensions[pos], extensions[pos + 1]]);
62 let ext_len = u16::from_be_bytes([extensions[pos + 2], extensions[pos + 3]]) as usize;
63 pos += 4;
64
65 if pos + ext_len > extensions.len() {
66 break;
67 }
68
69 if ext_type == 0 {
70 let ext_data = &extensions[pos..pos + ext_len];
72 if ext_data.len() >= 5 {
73 let name_len = u16::from_be_bytes([ext_data[3], ext_data[4]]) as usize;
74 if ext_data.len() >= 5 + name_len {
75 if let Ok(sni) = std::str::from_utf8(&ext_data[5..5 + name_len]) {
76 return Some(sni.to_string());
77 }
78 }
79 }
80 }
81
82 pos += ext_len;
83 }
84 None
85 }
86
87 fn extract_alpn(extensions: &[u8]) -> Option<String> {
89 let mut pos = 0;
90 while pos + 4 <= extensions.len() {
91 let ext_type = u16::from_be_bytes([extensions[pos], extensions[pos + 1]]);
92 let ext_len = u16::from_be_bytes([extensions[pos + 2], extensions[pos + 3]]) as usize;
93 pos += 4;
94
95 if pos + ext_len > extensions.len() {
96 break;
97 }
98
99 if ext_type == 16 {
100 let ext_data = &extensions[pos..pos + ext_len];
102 if ext_data.len() >= 3 {
103 let proto_len = ext_data[2] as usize;
104 if ext_data.len() >= 3 + proto_len {
105 if let Ok(alpn) = std::str::from_utf8(&ext_data[3..3 + proto_len]) {
106 return Some(alpn.to_string());
107 }
108 }
109 }
110 }
111
112 pos += ext_len;
113 }
114 None
115 }
116
117 fn parse_client_hello(&self, data: &[u8]) -> HashMap<&'static str, OwnedFieldValue> {
119 let mut fields = HashMap::new();
120 fields.insert("handshake_type", FieldValue::Str("ClientHello"));
121
122 if data.len() < 38 {
123 return fields;
124 }
125
126 let version = u16::from_be_bytes([data[0], data[1]]);
128 fields.insert("client_version", FieldValue::UInt16(version));
129
130 let mut pos = 34;
132 if pos >= data.len() {
133 return fields;
134 }
135 let session_id_len = data[pos] as usize;
136 pos += 1 + session_id_len;
137
138 if pos + 2 > data.len() {
140 return fields;
141 }
142 let cipher_suites_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
143 pos += 2;
144
145 if pos + cipher_suites_len > data.len() {
146 return fields;
147 }
148 let cipher_count = cipher_suites_len / 2;
149 fields.insert(
150 "cipher_suite_count",
151 FieldValue::UInt16(cipher_count as u16),
152 );
153 pos += cipher_suites_len;
154
155 if pos >= data.len() {
157 return fields;
158 }
159 let comp_len = data[pos] as usize;
160 pos += 1 + comp_len;
161
162 if pos + 2 > data.len() {
164 return fields;
165 }
166 let ext_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
167 pos += 2;
168
169 if pos + ext_len <= data.len() {
170 let extensions = &data[pos..pos + ext_len];
171 if let Some(sni) = Self::extract_sni(extensions) {
172 fields.insert("sni", FieldValue::OwnedString(CompactString::new(sni)));
173 }
174 if let Some(alpn) = Self::extract_alpn(extensions) {
175 fields.insert("alpn", FieldValue::OwnedString(CompactString::new(alpn)));
176 }
177 }
178
179 fields
180 }
181
182 fn parse_server_hello(&self, data: &[u8]) -> HashMap<&'static str, OwnedFieldValue> {
184 let mut fields = HashMap::new();
185 fields.insert("handshake_type", FieldValue::Str("ServerHello"));
186
187 if data.len() < 38 {
188 return fields;
189 }
190
191 let version = u16::from_be_bytes([data[0], data[1]]);
193 fields.insert("server_version", FieldValue::UInt16(version));
194
195 let mut pos = 34;
197 if pos >= data.len() {
198 return fields;
199 }
200 let session_id_len = data[pos] as usize;
201 pos += 1 + session_id_len;
202
203 if pos + 2 <= data.len() {
205 let cipher = u16::from_be_bytes([data[pos], data[pos + 1]]);
206 fields.insert("cipher_suite", FieldValue::UInt16(cipher));
207 fields.insert(
208 "cipher_suite_name",
209 FieldValue::OwnedString(CompactString::new(cipher_suite_name(cipher))),
210 );
211 }
212
213 fields
214 }
215
216 fn version_name(version: u16) -> &'static str {
218 match version {
219 0x0300 => "SSL 3.0",
220 0x0301 => "TLS 1.0",
221 0x0302 => "TLS 1.1",
222 0x0303 => "TLS 1.2",
223 0x0304 => "TLS 1.3",
224 _ => "Unknown",
225 }
226 }
227}
228
/// Returns the IANA name of a TLS cipher suite id. An id that is not in
/// this table comes back as lowercase hex (e.g. `"0x0a0a"`).
///
/// The table covers the TLS 1.3 suites and the most widely deployed TLS 1.2
/// suites; it is intentionally not exhaustive.
fn cipher_suite_name(id: u16) -> String {
    let name = match id {
        // TLS 1.3 (RFC 8446 appendix B.4)
        0x1301 => "TLS_AES_128_GCM_SHA256",
        0x1302 => "TLS_AES_256_GCM_SHA384",
        0x1303 => "TLS_CHACHA20_POLY1305_SHA256",
        0x1304 => "TLS_AES_128_CCM_SHA256",
        0x1305 => "TLS_AES_128_CCM_8_SHA256",
        // TLS 1.2 ECDHE AEAD suites
        0xc02b => "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
        0xc02c => "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
        0xc02f => "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
        0xc030 => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
        0xcca8 => "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
        0xcca9 => "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
        // Legacy suites still commonly offered
        0xc013 => "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
        0xc014 => "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
        0x009c => "TLS_RSA_WITH_AES_128_GCM_SHA256",
        0x009d => "TLS_RSA_WITH_AES_256_GCM_SHA384",
        0x002f => "TLS_RSA_WITH_AES_128_CBC_SHA",
        0x0035 => "TLS_RSA_WITH_AES_256_CBC_SHA",
        _ => return format!("0x{id:04x}"),
    };
    name.to_string()
}
239
240impl StreamParser for TlsStreamParser {
241 fn name(&self) -> &'static str {
242 "tls"
243 }
244
245 fn display_name(&self) -> &'static str {
246 "TLS"
247 }
248
249 fn can_parse_stream(&self, context: &StreamContext) -> bool {
250 context.dst_port == 443 || context.src_port == 443
251 }
252
253 fn parse_stream(&self, data: &[u8], context: &StreamContext) -> StreamParseResult {
254 let (content_type, version, length) = match Self::parse_record_header(data) {
256 Some(header) => header,
257 None => {
258 return StreamParseResult::NeedMore {
259 minimum_bytes: Some(5),
260 }
261 }
262 };
263
264 let record_len = 5 + length as usize;
265 if data.len() < record_len {
266 return StreamParseResult::NeedMore {
267 minimum_bytes: Some(record_len),
268 };
269 }
270
271 let mut fields = HashMap::new();
272 fields.insert("version", FieldValue::Str(Self::version_name(version)));
273 fields.insert("version_raw", FieldValue::UInt16(version));
274
275 match content_type {
276 content_type::HANDSHAKE => {
277 let handshake_data = &data[5..record_len];
278 if handshake_data.len() >= 4 {
279 let hs_type = handshake_data[0];
280 let hs_len = ((handshake_data[1] as usize) << 16)
281 | ((handshake_data[2] as usize) << 8)
282 | (handshake_data[3] as usize);
283
284 if handshake_data.len() >= 4 + hs_len {
285 let hs_body = &handshake_data[4..4 + hs_len];
286
287 let hs_fields = match hs_type {
288 handshake_type::CLIENT_HELLO => self.parse_client_hello(hs_body),
289 handshake_type::SERVER_HELLO => self.parse_server_hello(hs_body),
290 _ => {
291 let mut f = HashMap::new();
292 f.insert("handshake_type_id", FieldValue::UInt8(hs_type));
293 f
294 }
295 };
296
297 fields.extend(hs_fields);
298 }
299 }
300
301 fields.insert("record_type", FieldValue::Str("Handshake"));
302 }
303
304 content_type::APPLICATION_DATA => {
305 fields.insert("record_type", FieldValue::Str("ApplicationData"));
306 fields.insert("encrypted_length", FieldValue::UInt16(length));
307 }
308
309 content_type::ALERT => {
310 fields.insert("record_type", FieldValue::Str("Alert"));
311 }
312
313 content_type::CHANGE_CIPHER_SPEC => {
314 fields.insert("record_type", FieldValue::Str("ChangeCipherSpec"));
315 }
316
317 _ => {
318 return StreamParseResult::NotThisProtocol;
319 }
320 }
321
322 let message = ParsedMessage {
323 protocol: "tls",
324 connection_id: context.connection_id,
325 message_id: context.messages_parsed as u32,
326 direction: context.direction,
327 frame_number: 0,
328 fields,
329 };
330
331 StreamParseResult::Complete {
332 messages: vec![message],
333 bytes_consumed: record_len,
334 }
335 }
336
337 fn message_schema(&self) -> Vec<FieldDescriptor> {
338 vec![
339 FieldDescriptor::new("connection_id", DataKind::UInt64),
340 FieldDescriptor::new("record_type", DataKind::String).set_nullable(true),
341 FieldDescriptor::new("version", DataKind::String).set_nullable(true),
342 FieldDescriptor::new("handshake_type", DataKind::String).set_nullable(true),
343 FieldDescriptor::new("sni", DataKind::String).set_nullable(true),
344 FieldDescriptor::new("alpn", DataKind::String).set_nullable(true),
345 FieldDescriptor::new("cipher_suite", DataKind::UInt16).set_nullable(true),
346 FieldDescriptor::new("cipher_suite_name", DataKind::String).set_nullable(true),
347 ]
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::Direction;
    use std::net::Ipv4Addr;

    fn test_context() -> StreamContext {
        StreamContext {
            connection_id: 1,
            direction: Direction::ToServer,
            src_ip: std::net::IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            dst_ip: std::net::IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2)),
            src_port: 54321,
            dst_port: 443,
            bytes_parsed: 0,
            messages_parsed: 0,
            alpn: None,
        }
    }

    /// Wraps `hs_body` in a 4-byte handshake header (type + 24-bit length)
    /// and a 5-byte TLS 1.2 handshake record header.
    fn handshake_record(hs_type: u8, hs_body: &[u8]) -> Vec<u8> {
        let hs_len = hs_body.len();
        let record_len = 4 + hs_len;
        let mut record = vec![
            22,
            3,
            3,
            (record_len >> 8) as u8,
            (record_len & 0xff) as u8,
            hs_type,
            (hs_len >> 16) as u8,
            (hs_len >> 8) as u8,
            (hs_len & 0xff) as u8,
        ];
        record.extend_from_slice(hs_body);
        record
    }

    /// Encodes one TLS extension: `type(2) | length(2) | body`.
    fn extension(ext_type: u16, body: &[u8]) -> Vec<u8> {
        let mut ext = ext_type.to_be_bytes().to_vec();
        ext.extend_from_slice(&(body.len() as u16).to_be_bytes());
        ext.extend_from_slice(body);
        ext
    }

    /// Builds a ClientHello body: TLS 1.2, empty session id, the single suite
    /// 0xc02f, null compression, then the given raw extensions block.
    fn client_hello_body(extensions: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend_from_slice(&[3, 3]); // legacy_version
        body.extend_from_slice(&[0u8; 32]); // random
        body.push(0); // session id length
        body.extend_from_slice(&[0, 2, 0xc0, 0x2f]); // one cipher suite
        body.extend_from_slice(&[1, 0]); // compression: null only
        body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
        body.extend_from_slice(extensions);
        body
    }

    #[test]
    fn test_record_header() {
        let header = TlsStreamParser::parse_record_header(&[22, 3, 3, 0, 5]);
        assert_eq!(header, Some((22, 0x0303, 5)));
    }

    #[test]
    fn test_header_only_needs_more() {
        let parser = TlsStreamParser::new();
        match parser.parse_stream(&[22, 3], &test_context()) {
            StreamParseResult::NeedMore { minimum_bytes } => assert_eq!(minimum_bytes, Some(5)),
            _ => panic!("Expected NeedMore"),
        }
    }

    #[test]
    fn test_client_hello_parsing() {
        let parser = TlsStreamParser::new();
        let record = handshake_record(1, &client_hello_body(&[]));

        match parser.parse_stream(&record, &test_context()) {
            StreamParseResult::Complete {
                messages,
                bytes_consumed,
            } => {
                assert_eq!(bytes_consumed, record.len());
                let fields = &messages[0].fields;
                assert_eq!(fields.get("record_type"), Some(&FieldValue::Str("Handshake")));
                assert_eq!(fields.get("handshake_type"), Some(&FieldValue::Str("ClientHello")));
                assert_eq!(fields.get("client_version"), Some(&FieldValue::UInt16(0x0303)));
                assert_eq!(fields.get("cipher_suite_count"), Some(&FieldValue::UInt16(1)));
                assert!(!fields.contains_key("sni"));
                assert!(!fields.contains_key("alpn"));
            }
            _ => panic!("Expected Complete"),
        }
    }

    #[test]
    fn test_client_hello_sni_and_alpn() {
        let parser = TlsStreamParser::new();

        let host = b"example.com";
        let mut sni = Vec::new();
        sni.extend_from_slice(&((host.len() + 3) as u16).to_be_bytes()); // server_name_list length
        sni.push(0); // name_type: host_name
        sni.extend_from_slice(&(host.len() as u16).to_be_bytes());
        sni.extend_from_slice(host);

        // protocol_name_list: "h2", "http/1.1"
        let alpn = [
            0, 12, 2, b'h', b'2', 8, b'h', b't', b't', b'p', b'/', b'1', b'.', b'1',
        ];

        // An unknown, empty extension first, so the walk has to skip an entry.
        let mut extensions = extension(0x0a0a, &[]);
        extensions.extend(extension(0, &sni));
        extensions.extend(extension(16, &alpn));

        let record = handshake_record(1, &client_hello_body(&extensions));
        match parser.parse_stream(&record, &test_context()) {
            StreamParseResult::Complete { messages, .. } => {
                let fields = &messages[0].fields;
                assert_eq!(
                    fields.get("sni"),
                    Some(&FieldValue::OwnedString(CompactString::new("example.com")))
                );
                assert_eq!(
                    fields.get("alpn"),
                    Some(&FieldValue::OwnedString(CompactString::new("h2")))
                );
            }
            _ => panic!("Expected Complete"),
        }
    }

    #[test]
    fn test_server_hello() {
        let parser = TlsStreamParser::new();

        let mut hs_body = Vec::new();
        hs_body.extend_from_slice(&[3, 3]); // legacy_version
        hs_body.extend_from_slice(&[0u8; 32]); // random
        hs_body.push(0); // session id length
        hs_body.extend_from_slice(&[0xc0, 0x2f]); // cipher suite
        hs_body.push(0); // compression
        let record = handshake_record(2, &hs_body);

        let mut ctx = test_context();
        ctx.direction = Direction::ToClient;

        match parser.parse_stream(&record, &ctx) {
            StreamParseResult::Complete {
                messages,
                bytes_consumed,
            } => {
                assert_eq!(bytes_consumed, record.len());
                let fields = &messages[0].fields;
                assert_eq!(fields.get("handshake_type"), Some(&FieldValue::Str("ServerHello")));
                assert_eq!(fields.get("server_version"), Some(&FieldValue::UInt16(0x0303)));
                assert_eq!(fields.get("cipher_suite"), Some(&FieldValue::UInt16(0xc02f)));
                assert_eq!(
                    fields.get("cipher_suite_name"),
                    Some(&FieldValue::OwnedString(CompactString::new(
                        "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
                    )))
                );
            }
            _ => panic!("Expected Complete"),
        }
    }

    #[test]
    fn test_certificate_record() {
        let parser = TlsStreamParser::new();

        // Handshake record holding an empty Certificate (type 11) message.
        let record = vec![22, 3, 3, 0, 4, 11, 0, 0, 0];

        match parser.parse_stream(&record, &test_context()) {
            StreamParseResult::Complete {
                messages,
                bytes_consumed,
            } => {
                assert_eq!(bytes_consumed, 9);
                let fields = &messages[0].fields;
                assert_eq!(fields.get("record_type"), Some(&FieldValue::Str("Handshake")));
                assert_eq!(fields.get("handshake_type_id"), Some(&FieldValue::UInt8(11)));
            }
            _ => panic!("Expected Complete"),
        }
    }

    #[test]
    fn test_incomplete_record() {
        let parser = TlsStreamParser::new();

        // Header declares 100 payload bytes but only 5 are present.
        let record = vec![22, 3, 3, 0, 100, 1, 2, 3, 4, 5];

        match parser.parse_stream(&record, &test_context()) {
            StreamParseResult::NeedMore { minimum_bytes } => {
                assert_eq!(minimum_bytes, Some(105));
            }
            _ => panic!("Expected NeedMore"),
        }
    }

    #[test]
    fn test_application_data() {
        let parser = TlsStreamParser::new();

        let record = vec![23, 3, 3, 0, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];

        match parser.parse_stream(&record, &test_context()) {
            StreamParseResult::Complete {
                messages,
                bytes_consumed,
            } => {
                assert_eq!(bytes_consumed, 15);
                let fields = &messages[0].fields;
                assert_eq!(fields.get("record_type"), Some(&FieldValue::Str("ApplicationData")));
                assert_eq!(fields.get("encrypted_length"), Some(&FieldValue::UInt16(10)));
            }
            _ => panic!("Expected Complete"),
        }
    }

    #[test]
    fn test_unknown_content_type_rejected() {
        let parser = TlsStreamParser::new();
        let record = [0x50, 3, 3, 0, 0];
        assert!(matches!(
            parser.parse_stream(&record, &test_context()),
            StreamParseResult::NotThisProtocol
        ));
    }
}