pcapsql_core/stream/parsers/
tls.rs

1use std::collections::HashMap;
2
3use compact_str::CompactString;
4
5use crate::protocol::{FieldValue, OwnedFieldValue};
6use crate::schema::{DataKind, FieldDescriptor};
7use crate::stream::{ParsedMessage, StreamContext, StreamParseResult, StreamParser};
8
/// Values of the one-byte content-type field at the start of a TLS record.
mod content_type {
    /// ChangeCipherSpec record.
    pub const CHANGE_CIPHER_SPEC: u8 = 0x14;
    /// Alert record.
    pub const ALERT: u8 = 0x15;
    /// Handshake record.
    pub const HANDSHAKE: u8 = 0x16;
    /// Application-data record.
    pub const APPLICATION_DATA: u8 = 0x17;
}
16
/// Values of the one-byte message-type field at the start of a TLS handshake message.
///
/// Only `CLIENT_HELLO` and `SERVER_HELLO` are dispatched on in this file; the
/// remaining RFC 5246 values are kept for reference.
#[allow(dead_code)] // RFC 5246 constants that are not all referenced yet
mod handshake_type {
    pub const CLIENT_HELLO: u8 = 0x01;
    pub const SERVER_HELLO: u8 = 0x02;
    pub const CERTIFICATE: u8 = 0x0b;
    pub const SERVER_KEY_EXCHANGE: u8 = 0x0c;
    pub const CERTIFICATE_REQUEST: u8 = 0x0d;
    pub const SERVER_HELLO_DONE: u8 = 0x0e;
    pub const CERTIFICATE_VERIFY: u8 = 0x0f;
    pub const CLIENT_KEY_EXCHANGE: u8 = 0x10;
    pub const FINISHED: u8 = 0x14;
}
36
/// Stream parser that extracts metadata from TLS records; it never decrypts payloads.
///
/// The type is a field-less unit struct, so it carries no per-connection state.
#[derive(Clone, Copy, Debug, Default)]
pub struct TlsStreamParser;
40
41impl TlsStreamParser {
42    pub fn new() -> Self {
43        Self
44    }
45
46    /// Parse a TLS record header.
47    fn parse_record_header(data: &[u8]) -> Option<(u8, u16, u16)> {
48        if data.len() < 5 {
49            return None;
50        }
51        let content_type = data[0];
52        let version = u16::from_be_bytes([data[1], data[2]]);
53        let length = u16::from_be_bytes([data[3], data[4]]);
54        Some((content_type, version, length))
55    }
56
57    /// Extract SNI from ClientHello extension.
58    fn extract_sni(extensions: &[u8]) -> Option<String> {
59        let mut pos = 0;
60        while pos + 4 <= extensions.len() {
61            let ext_type = u16::from_be_bytes([extensions[pos], extensions[pos + 1]]);
62            let ext_len = u16::from_be_bytes([extensions[pos + 2], extensions[pos + 3]]) as usize;
63            pos += 4;
64
65            if pos + ext_len > extensions.len() {
66                break;
67            }
68
69            if ext_type == 0 {
70                // SNI extension
71                let ext_data = &extensions[pos..pos + ext_len];
72                if ext_data.len() >= 5 {
73                    let name_len = u16::from_be_bytes([ext_data[3], ext_data[4]]) as usize;
74                    if ext_data.len() >= 5 + name_len {
75                        if let Ok(sni) = std::str::from_utf8(&ext_data[5..5 + name_len]) {
76                            return Some(sni.to_string());
77                        }
78                    }
79                }
80            }
81
82            pos += ext_len;
83        }
84        None
85    }
86
87    /// Extract ALPN from extensions.
88    fn extract_alpn(extensions: &[u8]) -> Option<String> {
89        let mut pos = 0;
90        while pos + 4 <= extensions.len() {
91            let ext_type = u16::from_be_bytes([extensions[pos], extensions[pos + 1]]);
92            let ext_len = u16::from_be_bytes([extensions[pos + 2], extensions[pos + 3]]) as usize;
93            pos += 4;
94
95            if pos + ext_len > extensions.len() {
96                break;
97            }
98
99            if ext_type == 16 {
100                // ALPN extension
101                let ext_data = &extensions[pos..pos + ext_len];
102                if ext_data.len() >= 3 {
103                    let proto_len = ext_data[2] as usize;
104                    if ext_data.len() >= 3 + proto_len {
105                        if let Ok(alpn) = std::str::from_utf8(&ext_data[3..3 + proto_len]) {
106                            return Some(alpn.to_string());
107                        }
108                    }
109                }
110            }
111
112            pos += ext_len;
113        }
114        None
115    }
116
117    /// Parse ClientHello message.
118    fn parse_client_hello(&self, data: &[u8]) -> HashMap<&'static str, OwnedFieldValue> {
119        let mut fields = HashMap::new();
120        fields.insert("handshake_type", FieldValue::Str("ClientHello"));
121
122        if data.len() < 38 {
123            return fields;
124        }
125
126        // Client version (2 bytes)
127        let version = u16::from_be_bytes([data[0], data[1]]);
128        fields.insert("client_version", FieldValue::UInt16(version));
129
130        // Skip random (32 bytes) and session ID
131        let mut pos = 34;
132        if pos >= data.len() {
133            return fields;
134        }
135        let session_id_len = data[pos] as usize;
136        pos += 1 + session_id_len;
137
138        // Cipher suites
139        if pos + 2 > data.len() {
140            return fields;
141        }
142        let cipher_suites_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
143        pos += 2;
144
145        if pos + cipher_suites_len > data.len() {
146            return fields;
147        }
148        let cipher_count = cipher_suites_len / 2;
149        fields.insert(
150            "cipher_suite_count",
151            FieldValue::UInt16(cipher_count as u16),
152        );
153        pos += cipher_suites_len;
154
155        // Skip compression methods
156        if pos >= data.len() {
157            return fields;
158        }
159        let comp_len = data[pos] as usize;
160        pos += 1 + comp_len;
161
162        // Extensions
163        if pos + 2 > data.len() {
164            return fields;
165        }
166        let ext_len = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
167        pos += 2;
168
169        if pos + ext_len <= data.len() {
170            let extensions = &data[pos..pos + ext_len];
171            if let Some(sni) = Self::extract_sni(extensions) {
172                fields.insert("sni", FieldValue::OwnedString(CompactString::new(sni)));
173            }
174            if let Some(alpn) = Self::extract_alpn(extensions) {
175                fields.insert("alpn", FieldValue::OwnedString(CompactString::new(alpn)));
176            }
177        }
178
179        fields
180    }
181
182    /// Parse ServerHello message.
183    fn parse_server_hello(&self, data: &[u8]) -> HashMap<&'static str, OwnedFieldValue> {
184        let mut fields = HashMap::new();
185        fields.insert("handshake_type", FieldValue::Str("ServerHello"));
186
187        if data.len() < 38 {
188            return fields;
189        }
190
191        // Server version
192        let version = u16::from_be_bytes([data[0], data[1]]);
193        fields.insert("server_version", FieldValue::UInt16(version));
194
195        // Skip random (32 bytes) and session ID
196        let mut pos = 34;
197        if pos >= data.len() {
198            return fields;
199        }
200        let session_id_len = data[pos] as usize;
201        pos += 1 + session_id_len;
202
203        // Selected cipher suite
204        if pos + 2 <= data.len() {
205            let cipher = u16::from_be_bytes([data[pos], data[pos + 1]]);
206            fields.insert("cipher_suite", FieldValue::UInt16(cipher));
207            fields.insert(
208                "cipher_suite_name",
209                FieldValue::OwnedString(CompactString::new(cipher_suite_name(cipher))),
210            );
211        }
212
213        fields
214    }
215
216    /// Get TLS version name.
217    fn version_name(version: u16) -> &'static str {
218        match version {
219            0x0300 => "SSL 3.0",
220            0x0301 => "TLS 1.0",
221            0x0302 => "TLS 1.1",
222            0x0303 => "TLS 1.2",
223            0x0304 => "TLS 1.3",
224            _ => "Unknown",
225        }
226    }
227}
228
/// Returns the registered name of a cipher suite id.
///
/// Ids that are not in the small built-in table are rendered as a four-digit
/// lowercase hex literal such as `0x0035`.
fn cipher_suite_name(id: u16) -> String {
    let known = match id {
        0x1301 => Some("TLS_AES_128_GCM_SHA256"),
        0x1302 => Some("TLS_AES_256_GCM_SHA384"),
        0x1303 => Some("TLS_CHACHA20_POLY1305_SHA256"),
        0xc02f => Some("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"),
        0xc030 => Some("TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"),
        _ => None,
    };

    known.map_or_else(|| format!("0x{id:04x}"), String::from)
}
239
240impl StreamParser for TlsStreamParser {
241    fn name(&self) -> &'static str {
242        "tls"
243    }
244
245    fn display_name(&self) -> &'static str {
246        "TLS"
247    }
248
249    fn can_parse_stream(&self, context: &StreamContext) -> bool {
250        context.dst_port == 443 || context.src_port == 443
251    }
252
253    fn parse_stream(&self, data: &[u8], context: &StreamContext) -> StreamParseResult {
254        // Parse TLS record header
255        let (content_type, version, length) = match Self::parse_record_header(data) {
256            Some(header) => header,
257            None => {
258                return StreamParseResult::NeedMore {
259                    minimum_bytes: Some(5),
260                }
261            }
262        };
263
264        let record_len = 5 + length as usize;
265        if data.len() < record_len {
266            return StreamParseResult::NeedMore {
267                minimum_bytes: Some(record_len),
268            };
269        }
270
271        let mut fields = HashMap::new();
272        fields.insert("version", FieldValue::Str(Self::version_name(version)));
273        fields.insert("version_raw", FieldValue::UInt16(version));
274
275        match content_type {
276            content_type::HANDSHAKE => {
277                let handshake_data = &data[5..record_len];
278                if handshake_data.len() >= 4 {
279                    let hs_type = handshake_data[0];
280                    let hs_len = ((handshake_data[1] as usize) << 16)
281                        | ((handshake_data[2] as usize) << 8)
282                        | (handshake_data[3] as usize);
283
284                    if handshake_data.len() >= 4 + hs_len {
285                        let hs_body = &handshake_data[4..4 + hs_len];
286
287                        let hs_fields = match hs_type {
288                            handshake_type::CLIENT_HELLO => self.parse_client_hello(hs_body),
289                            handshake_type::SERVER_HELLO => self.parse_server_hello(hs_body),
290                            _ => {
291                                let mut f = HashMap::new();
292                                f.insert("handshake_type_id", FieldValue::UInt8(hs_type));
293                                f
294                            }
295                        };
296
297                        fields.extend(hs_fields);
298                    }
299                }
300
301                fields.insert("record_type", FieldValue::Str("Handshake"));
302            }
303
304            content_type::APPLICATION_DATA => {
305                fields.insert("record_type", FieldValue::Str("ApplicationData"));
306                fields.insert("encrypted_length", FieldValue::UInt16(length));
307            }
308
309            content_type::ALERT => {
310                fields.insert("record_type", FieldValue::Str("Alert"));
311            }
312
313            content_type::CHANGE_CIPHER_SPEC => {
314                fields.insert("record_type", FieldValue::Str("ChangeCipherSpec"));
315            }
316
317            _ => {
318                return StreamParseResult::NotThisProtocol;
319            }
320        }
321
322        let message = ParsedMessage {
323            protocol: "tls",
324            connection_id: context.connection_id,
325            message_id: context.messages_parsed as u32,
326            direction: context.direction,
327            frame_number: 0,
328            fields,
329        };
330
331        StreamParseResult::Complete {
332            messages: vec![message],
333            bytes_consumed: record_len,
334        }
335    }
336
337    fn message_schema(&self) -> Vec<FieldDescriptor> {
338        vec![
339            FieldDescriptor::new("connection_id", DataKind::UInt64),
340            FieldDescriptor::new("record_type", DataKind::String).set_nullable(true),
341            FieldDescriptor::new("version", DataKind::String).set_nullable(true),
342            FieldDescriptor::new("handshake_type", DataKind::String).set_nullable(true),
343            FieldDescriptor::new("sni", DataKind::String).set_nullable(true),
344            FieldDescriptor::new("alpn", DataKind::String).set_nullable(true),
345            FieldDescriptor::new("cipher_suite", DataKind::UInt16).set_nullable(true),
346            FieldDescriptor::new("cipher_suite_name", DataKind::String).set_nullable(true),
347        ]
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::Direction;
    use std::net::{IpAddr, Ipv4Addr};

    /// Client-to-server context on port 443 with nothing parsed yet.
    fn test_context() -> StreamContext {
        StreamContext {
            connection_id: 1,
            direction: Direction::ToServer,
            src_ip: IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            dst_ip: IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2)),
            src_port: 54321,
            dst_port: 443,
            bytes_parsed: 0,
            messages_parsed: 0,
            alpn: None,
        }
    }

    /// Wraps `body` in a handshake header of type `hs_type` and a TLS 1.2
    /// handshake record header, computing both length fields.
    fn handshake_record(hs_type: u8, body: &[u8]) -> Vec<u8> {
        let hs_len = body.len();
        let record_len = 4 + hs_len; // handshake type + 24-bit length + body

        let mut record = vec![
            22, // Handshake
            3,
            3, // TLS 1.2
            (record_len >> 8) as u8,
            (record_len & 0xff) as u8,
            hs_type,
            (hs_len >> 16) as u8,
            (hs_len >> 8) as u8,
            (hs_len & 0xff) as u8,
        ];
        record.extend_from_slice(body);
        record
    }

    /// ClientHello body with an empty session id, one cipher suite, null
    /// compression and the given raw extension block.
    fn client_hello_body(extensions: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend_from_slice(&[3, 3]); // Version
        body.extend_from_slice(&[0u8; 32]); // Random
        body.push(0); // Session ID length
        body.extend_from_slice(&[0, 2, 0, 0]); // Cipher suites length (2) + 1 suite
        body.push(1); // Compression methods length
        body.push(0); // null compression
        body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
        body.extend_from_slice(extensions);
        body
    }

    /// Parses `record` and returns the single message plus the bytes consumed,
    /// panicking on any other outcome.
    fn parse_one(record: &[u8], ctx: &StreamContext) -> (ParsedMessage, usize) {
        match TlsStreamParser::new().parse_stream(record, ctx) {
            StreamParseResult::Complete {
                mut messages,
                bytes_consumed,
            } => {
                assert_eq!(messages.len(), 1);
                (messages.remove(0), bytes_consumed)
            }
            _ => panic!("Expected Complete"),
        }
    }

    // Test 1: TLS record parsing
    #[test]
    fn test_record_header() {
        let header = TlsStreamParser::parse_record_header(&[22, 3, 3, 0, 5]);
        assert_eq!(header, Some((22, 0x0303, 5)));
        assert_eq!(TlsStreamParser::parse_record_header(&[22, 3, 3, 0]), None);
    }

    // Test 2: ClientHello parsing (no extensions)
    #[test]
    fn test_client_hello_parsing() {
        let record = handshake_record(1, &client_hello_body(&[]));
        let (message, bytes_consumed) = parse_one(&record, &test_context());

        assert_eq!(bytes_consumed, record.len());
        assert_eq!(
            message.fields.get("record_type"),
            Some(&FieldValue::Str("Handshake"))
        );
        assert_eq!(
            message.fields.get("version"),
            Some(&FieldValue::Str("TLS 1.2"))
        );
        assert_eq!(
            message.fields.get("handshake_type"),
            Some(&FieldValue::Str("ClientHello"))
        );
        assert_eq!(
            message.fields.get("client_version"),
            Some(&FieldValue::UInt16(0x0303))
        );
        assert_eq!(
            message.fields.get("cipher_suite_count"),
            Some(&FieldValue::UInt16(1))
        );
        assert!(!message.fields.contains_key("sni"));
        assert!(!message.fields.contains_key("alpn"));
    }

    // Test 3: ClientHello carrying SNI and ALPN extensions
    #[test]
    fn test_client_hello_sni_and_alpn() {
        let host = b"example.com";

        let mut extensions = Vec::new();
        // server_name extension: list length, name_type 0, name length, name
        extensions.extend_from_slice(&[0, 0]);
        extensions.extend_from_slice(&((host.len() + 5) as u16).to_be_bytes());
        extensions.extend_from_slice(&((host.len() + 3) as u16).to_be_bytes());
        extensions.push(0);
        extensions.extend_from_slice(&(host.len() as u16).to_be_bytes());
        extensions.extend_from_slice(host);
        // ALPN extension offering "h2" then "http/1.1"
        extensions.extend_from_slice(&[0, 16, 0, 14, 0, 12, 2, b'h', b'2', 8]);
        extensions.extend_from_slice(b"http/1.1");

        let record = handshake_record(1, &client_hello_body(&extensions));
        let (message, _) = parse_one(&record, &test_context());

        assert_eq!(
            message.fields.get("sni"),
            Some(&FieldValue::OwnedString(CompactString::new("example.com")))
        );
        assert_eq!(
            message.fields.get("alpn"),
            Some(&FieldValue::OwnedString(CompactString::new("h2")))
        );
    }

    // Test 4: ServerHello parsing
    #[test]
    fn test_server_hello() {
        let mut hs_body = Vec::new();
        hs_body.extend_from_slice(&[3, 3]); // Version
        hs_body.extend_from_slice(&[0u8; 32]); // Random
        hs_body.push(0); // Session ID length
        hs_body.extend_from_slice(&[0xc0, 0x2f]); // Cipher suite
        hs_body.push(0); // Compression

        let record = handshake_record(2, &hs_body);

        let mut ctx = test_context();
        ctx.direction = Direction::ToClient;

        let (message, _) = parse_one(&record, &ctx);
        assert_eq!(
            message.fields.get("handshake_type"),
            Some(&FieldValue::Str("ServerHello"))
        );
        assert_eq!(
            message.fields.get("server_version"),
            Some(&FieldValue::UInt16(0x0303))
        );
        assert_eq!(
            message.fields.get("cipher_suite"),
            Some(&FieldValue::UInt16(0xc02f))
        );
        assert_eq!(
            message.fields.get("cipher_suite_name"),
            Some(&FieldValue::OwnedString(CompactString::new(
                "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
            )))
        );
    }

    // Test 5: Certificate record (handshake type without a dedicated parser)
    #[test]
    fn test_certificate_record() {
        let record = handshake_record(11, &[]);
        assert_eq!(record, vec![22, 3, 3, 0, 4, 11, 0, 0, 0]);

        let (message, bytes_consumed) = parse_one(&record, &test_context());
        assert_eq!(bytes_consumed, 9);
        assert_eq!(
            message.fields.get("handshake_type_id"),
            Some(&FieldValue::UInt8(11))
        );
        assert!(!message.fields.contains_key("handshake_type"));
    }

    // Test 6: Incomplete record (NeedMore)
    #[test]
    fn test_incomplete_record() {
        let parser = TlsStreamParser::new();

        // Record says 100 bytes but we only have 10
        let record = vec![22, 3, 3, 0, 100, 1, 2, 3, 4, 5];

        match parser.parse_stream(&record, &test_context()) {
            StreamParseResult::NeedMore { minimum_bytes } => {
                assert_eq!(minimum_bytes, Some(105)); // 5 header + 100 payload
            }
            _ => panic!("Expected NeedMore"),
        }
    }

    // Test 7: Incomplete header (NeedMore for the 5-byte header)
    #[test]
    fn test_incomplete_header() {
        let parser = TlsStreamParser::new();

        match parser.parse_stream(&[22, 3, 3], &test_context()) {
            StreamParseResult::NeedMore { minimum_bytes } => {
                assert_eq!(minimum_bytes, Some(5));
            }
            _ => panic!("Expected NeedMore"),
        }
    }

    // Test 8: Application data record
    #[test]
    fn test_application_data() {
        let record = vec![
            23, // ApplicationData
            3, 3, // TLS 1.2
            0, 10, // Length
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // Encrypted data
        ];

        let (message, bytes_consumed) = parse_one(&record, &test_context());
        assert_eq!(bytes_consumed, 15);
        assert_eq!(
            message.fields.get("record_type"),
            Some(&FieldValue::Str("ApplicationData"))
        );
        assert_eq!(
            message.fields.get("encrypted_length"),
            Some(&FieldValue::UInt16(10))
        );
    }

    // Test 9: Complete record with an unknown content type
    #[test]
    fn test_unknown_content_type() {
        let parser = TlsStreamParser::new();

        match parser.parse_stream(&[99, 3, 3, 0, 0], &test_context()) {
            StreamParseResult::NotThisProtocol => {}
            _ => panic!("Expected NotThisProtocol"),
        }
    }
}