pcapsql_core/stream/manager.rs

1use std::net::IpAddr;
2use std::sync::Arc;
3
4use crate::error::Error;
5use crate::tls::KeyLog;
6
7use super::{
8    parsers::DecryptingTlsStreamParser, Connection, ConnectionTracker, Direction, ParsedMessage,
9    StreamContext, StreamParseResult, StreamRegistry, TcpFlags, TcpReassembler,
10};
11
/// Tuning knobs for [`StreamManager`].
///
/// The [`Default`] values are 16 MiB per connection, 1 GiB in total and a
/// five-minute connection timeout.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Maximum memory per connection (bytes).
    ///
    /// NOTE(review): `StreamManager` in this file never reads this value;
    /// confirm whether it is enforced elsewhere.
    pub max_connection_buffer: usize,
    /// Maximum total memory for all streams (bytes); compared against the
    /// running total by `StreamManager::memory_limit_exceeded`.
    pub max_total_memory: usize,
    /// Connection timeout (microseconds), handed to the connection tracker
    /// by `StreamManager::cleanup_timeout`.
    pub connection_timeout_us: i64,
}

impl Default for StreamConfig {
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        const MICROS_PER_SEC: i64 = 1_000_000;

        let max_connection_buffer = 16 * MIB; // 16 MB per connection
        let max_total_memory = 1024 * MIB; // 1 GB total
        let connection_timeout_us = 5 * 60 * MICROS_PER_SEC; // 5 minutes

        StreamConfig {
            max_connection_buffer,
            max_total_memory,
            connection_timeout_us,
        }
    }
}
32
33/// Central orchestrator for TCP stream processing.
34pub struct StreamManager {
35    connections: ConnectionTracker,
36    reassembler: TcpReassembler,
37    stream_registry: StreamRegistry,
38    config: StreamConfig,
39    /// Current total memory usage.
40    total_memory: usize,
41    /// Optional keylog for TLS decryption.
42    keylog: Option<Arc<KeyLog>>,
43}
44
45impl StreamManager {
46    pub fn new(config: StreamConfig) -> Self {
47        Self {
48            connections: ConnectionTracker::new(),
49            reassembler: TcpReassembler::new(),
50            stream_registry: StreamRegistry::new(),
51            config,
52            total_memory: 0,
53            keylog: None,
54        }
55    }
56
57    /// Create with default config and register default parsers.
58    pub fn with_defaults() -> Self {
59        Self::new(StreamConfig::default())
60    }
61
62    /// Enable TLS decryption with the provided keylog.
63    ///
64    /// This registers a `DecryptingTlsStreamParser` that will attempt to
65    /// decrypt TLS application data when matching keys are found in the keylog.
66    ///
67    /// The keylog should be in SSLKEYLOGFILE format, as used by Wireshark
68    /// and browsers.
69    ///
70    /// # Example
71    ///
72    /// ```rust,no_run
73    /// use pcapsql_core::stream::{StreamConfig, StreamManager};
74    /// use pcapsql_core::tls::KeyLog;
75    ///
76    /// let keylog = KeyLog::from_file("sslkeylog.txt").unwrap();
77    /// let manager = StreamManager::new(StreamConfig::default())
78    ///     .with_keylog(keylog);
79    /// ```
80    pub fn with_keylog(mut self, keylog: KeyLog) -> Self {
81        let keylog = Arc::new(keylog);
82        self.keylog = Some(Arc::clone(&keylog));
83
84        // Register the decrypting TLS parser (before any other TLS parser)
85        let parser = DecryptingTlsStreamParser::with_keylog(keylog);
86        self.stream_registry.register(parser);
87
88        self
89    }
90
91    /// Check if TLS decryption is enabled.
92    pub fn has_keylog(&self) -> bool {
93        self.keylog.is_some()
94    }
95
96    /// Get the keylog if available.
97    pub fn keylog(&self) -> Option<&KeyLog> {
98        self.keylog.as_ref().map(|k| k.as_ref())
99    }
100
101    /// Get mutable access to the stream registry for parser registration.
102    pub fn registry_mut(&mut self) -> &mut StreamRegistry {
103        &mut self.stream_registry
104    }
105
106    /// Process a TCP segment.
107    ///
108    /// Returns any parsed messages.
109    #[allow(clippy::too_many_arguments)]
110    pub fn process_segment(
111        &mut self,
112        src_ip: IpAddr,
113        dst_ip: IpAddr,
114        src_port: u16,
115        dst_port: u16,
116        seq: u32,
117        _ack: u32,
118        flags: TcpFlags,
119        payload: &[u8],
120        frame_number: u64,
121        timestamp: i64,
122    ) -> Result<Vec<ParsedMessage>, Error> {
123        let mut messages = Vec::new();
124
125        // 1. Get or create connection
126        let (conn, direction) = self.connections.get_or_create(
127            src_ip,
128            src_port,
129            dst_ip,
130            dst_port,
131            flags,
132            seq,
133            frame_number,
134            timestamp,
135        );
136        let connection_id = conn.id;
137
138        // 2. Update connection state
139        ConnectionTracker::update_state(conn, flags, direction, seq);
140
141        // 3. Handle SYN (initial sequence number)
142        if flags.syn {
143            let buffer = self.reassembler.get_or_create(connection_id, direction);
144            buffer.set_initial_seq(seq);
145        }
146
147        // 4. Add payload to reassembler
148        if !payload.is_empty() {
149            ConnectionTracker::add_bytes(conn, direction, payload.len());
150            self.reassembler.add_segment(
151                connection_id,
152                direction,
153                seq,
154                payload,
155                frame_number,
156                timestamp,
157            );
158            self.total_memory += payload.len();
159        }
160
161        // 5. Try to parse reassembled data
162        self.try_parse(connection_id, direction, frame_number, &mut messages)?;
163
164        // 6. Handle FIN
165        if flags.fin {
166            self.reassembler.mark_fin(connection_id, direction);
167        }
168
169        // 7. Handle connection termination
170        if flags.rst || (flags.fin && self.is_fully_closed(connection_id)) {
171            self.finalize_connection(connection_id, &mut messages)?;
172        }
173
174        Ok(messages)
175    }
176
177    /// Try to parse data from a stream.
178    fn try_parse(
179        &mut self,
180        connection_id: u64,
181        direction: Direction,
182        frame_number: u64,
183        messages: &mut Vec<ParsedMessage>,
184    ) -> Result<(), Error> {
185        loop {
186            let data = self.reassembler.get_contiguous(connection_id, direction);
187            if data.is_empty() {
188                break;
189            }
190
191            // Build context for parser
192            let context = self.build_context(connection_id, direction);
193
194            // Find parser
195            let parser = match self.stream_registry.find_parser(&context) {
196                Some(p) => p,
197                None => break, // No parser for this stream
198            };
199
200            // We need to copy data because we can't hold borrow across mutable ops
201            let data_copy = data.to_vec();
202
203            // Parse
204            match parser.parse_stream(&data_copy, &context) {
205                StreamParseResult::Complete {
206                    messages: msgs,
207                    bytes_consumed,
208                } => {
209                    for mut msg in msgs {
210                        msg.frame_number = frame_number;
211                        messages.push(msg);
212                    }
213                    self.reassembler
214                        .consume(connection_id, direction, bytes_consumed);
215                    self.total_memory = self.total_memory.saturating_sub(bytes_consumed);
216                    // Continue loop - might be more messages
217                }
218
219                StreamParseResult::Transform {
220                    child_protocol,
221                    child_data,
222                    bytes_consumed,
223                    metadata,
224                } => {
225                    if let Some(mut meta) = metadata {
226                        meta.frame_number = frame_number;
227                        messages.push(meta);
228                    }
229                    self.reassembler
230                        .consume(connection_id, direction, bytes_consumed);
231                    self.total_memory = self.total_memory.saturating_sub(bytes_consumed);
232
233                    // Recursively parse transformed data
234                    self.parse_transformed(
235                        connection_id,
236                        direction,
237                        child_protocol,
238                        &child_data,
239                        frame_number,
240                        messages,
241                    )?;
242                }
243
244                StreamParseResult::NeedMore { .. } => {
245                    break; // Wait for more data
246                }
247
248                StreamParseResult::NotThisProtocol => {
249                    break; // Can't parse this stream
250                }
251
252                StreamParseResult::Error { skip_bytes, .. } => {
253                    if let Some(skip) = skip_bytes {
254                        self.reassembler.consume(connection_id, direction, skip);
255                        self.total_memory = self.total_memory.saturating_sub(skip);
256                    } else {
257                        break;
258                    }
259                }
260            }
261        }
262
263        Ok(())
264    }
265
266    /// Parse transformed/decrypted data with a child parser.
267    fn parse_transformed(
268        &self,
269        connection_id: u64,
270        direction: Direction,
271        child_protocol: &str,
272        data: &[u8],
273        frame_number: u64,
274        messages: &mut Vec<ParsedMessage>,
275    ) -> Result<(), Error> {
276        let parser = match self.stream_registry.get_parser(child_protocol) {
277            Some(p) => p,
278            None => return Ok(()), // No parser for child protocol
279        };
280
281        let context = self.build_context(connection_id, direction);
282
283        if let StreamParseResult::Complete { messages: msgs, .. } =
284            parser.parse_stream(data, &context)
285        {
286            for mut msg in msgs {
287                msg.frame_number = frame_number;
288                messages.push(msg);
289            }
290        }
291
292        Ok(())
293    }
294
295    /// Build a StreamContext for the given connection.
296    fn build_context(&self, connection_id: u64, direction: Direction) -> StreamContext {
297        let conn = self
298            .connections
299            .connections()
300            .find(|c| c.id == connection_id);
301
302        if let Some(conn) = conn {
303            let (src_ip, dst_ip, src_port, dst_port) = match direction {
304                Direction::ToServer => (
305                    conn.client_ip(),
306                    conn.server_ip(),
307                    conn.client_port(),
308                    conn.server_port(),
309                ),
310                Direction::ToClient => (
311                    conn.server_ip(),
312                    conn.client_ip(),
313                    conn.server_port(),
314                    conn.client_port(),
315                ),
316            };
317
318            StreamContext {
319                connection_id,
320                direction,
321                src_ip,
322                dst_ip,
323                src_port,
324                dst_port,
325                bytes_parsed: 0, // Could track this
326                messages_parsed: 0,
327                alpn: None, // Set by TLS parser
328            }
329        } else {
330            // Fallback - shouldn't happen
331            StreamContext {
332                connection_id,
333                direction,
334                src_ip: std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
335                dst_ip: std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
336                src_port: 0,
337                dst_port: 0,
338                bytes_parsed: 0,
339                messages_parsed: 0,
340                alpn: None,
341            }
342        }
343    }
344
345    /// Check if connection is fully closed (both sides FIN'd).
346    fn is_fully_closed(&self, connection_id: u64) -> bool {
347        self.reassembler
348            .is_complete(connection_id, Direction::ToServer)
349            && self
350                .reassembler
351                .is_complete(connection_id, Direction::ToClient)
352    }
353
354    /// Finalize a closed connection.
355    #[allow(clippy::ptr_arg)]
356    fn finalize_connection(
357        &mut self,
358        connection_id: u64,
359        _messages: &mut Vec<ParsedMessage>,
360    ) -> Result<(), Error> {
361        // Clean up reassembly buffers
362        self.reassembler.remove(connection_id);
363        Ok(())
364    }
365
366    /// Cleanup timed-out connections.
367    pub fn cleanup_timeout(&mut self, current_time: i64) -> Vec<Connection> {
368        let removed = self
369            .connections
370            .cleanup_timeout(current_time, self.config.connection_timeout_us);
371
372        for conn in &removed {
373            self.reassembler.remove(conn.id);
374        }
375
376        removed
377    }
378
379    /// Get all tracked connections.
380    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
381        self.connections.connections()
382    }
383
384    /// Get total memory usage.
385    pub fn total_memory(&self) -> usize {
386        self.total_memory
387    }
388
389    /// Check if memory limit is exceeded.
390    pub fn memory_limit_exceeded(&self) -> bool {
391        self.total_memory > self.config.max_total_memory
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// An (address, port) pair.
    type Endpoint = (IpAddr, u16);

    /// Shorthand for an IPv4 `IpAddr`.
    fn ip(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        Ipv4Addr::new(a, b, c, d).into()
    }

    /// Client side of the default test connection.
    fn client() -> Endpoint {
        (ip(192, 168, 1, 1), 54321)
    }

    /// Server side of the default test connection (HTTP).
    fn server() -> Endpoint {
        (ip(192, 168, 1, 2), 80)
    }

    fn syn_flags() -> TcpFlags {
        TcpFlags {
            syn: true,
            ..Default::default()
        }
    }

    fn ack_flags() -> TcpFlags {
        TcpFlags {
            ack: true,
            ..Default::default()
        }
    }

    /// Feed one segment `from` -> `to` with an ack number of 0.
    // Mirrors `process_segment`'s flat argument list.
    #[allow(clippy::too_many_arguments)]
    fn send(
        manager: &mut StreamManager,
        from: Endpoint,
        to: Endpoint,
        seq: u32,
        flags: TcpFlags,
        payload: &[u8],
        frame: u64,
        ts: i64,
    ) -> Result<Vec<ParsedMessage>, Error> {
        let (src_ip, src_port) = from;
        let (dst_ip, dst_port) = to;
        manager.process_segment(
            src_ip, dst_ip, src_port, dst_port, seq, 0, flags, payload, frame, ts,
        )
    }

    // Test 1: Basic segment processing
    #[test]
    fn test_process_segment() {
        let mut manager = StreamManager::with_defaults();

        let outcome = send(&mut manager, client(), server(), 1000, syn_flags(), b"", 1, 0);

        assert!(outcome.is_ok());
        assert_eq!(manager.connections().count(), 1);
    }

    // Test 2: Connection creation on SYN
    #[test]
    fn test_connection_on_syn() {
        let mut manager = StreamManager::with_defaults();
        send(&mut manager, client(), server(), 1000, syn_flags(), b"", 1, 0).unwrap();

        let first = manager.connections().next().unwrap();
        assert_eq!(first.client_port(), 54321);
        assert_eq!(first.server_port(), 80);
    }

    // Test 3: Reassembly triggers parser (mock test - no real parser)
    #[test]
    fn test_reassembly_triggers_parse() {
        let mut manager = StreamManager::with_defaults();

        // No parser registered, so the request bytes must remain buffered.
        send(
            &mut manager,
            client(),
            server(),
            1000,
            ack_flags(),
            b"GET / HTTP/1.1\r\n",
            1,
            0,
        )
        .unwrap();

        assert!(manager.total_memory() > 0);
    }

    // Test 4: Parser NeedMore handling
    #[test]
    fn test_need_more_handling() {
        // This test verifies the loop exits on NeedMore
        // Would need mock parser for full test
        assert_eq!(StreamManager::with_defaults().connections().count(), 0);
    }

    // Test 5: Parser Complete handling
    #[test]
    fn test_complete_handling() {
        // Would need mock parser
        assert_eq!(StreamManager::with_defaults().total_memory(), 0);
    }

    // Test 6: Memory limit tracking
    #[test]
    fn test_memory_tracking() {
        let mut manager = StreamManager::new(StreamConfig {
            max_total_memory: 1000,
            ..Default::default()
        });

        send(&mut manager, client(), server(), 1000, ack_flags(), &[0u8; 500], 1, 0).unwrap();
        assert_eq!(manager.total_memory(), 500);
        assert!(!manager.memory_limit_exceeded());

        // 500 + 600 crosses the 1000-byte limit.
        send(&mut manager, client(), server(), 1500, ack_flags(), &[0u8; 600], 2, 1).unwrap();
        assert!(manager.memory_limit_exceeded());
    }

    // Test 7: Connection cleanup
    #[test]
    fn test_connection_cleanup() {
        let mut manager = StreamManager::new(StreamConfig {
            connection_timeout_us: 1000,
            ..Default::default()
        });

        send(&mut manager, client(), server(), 1000, syn_flags(), b"", 1, 0).unwrap();
        assert_eq!(manager.connections().count(), 1);

        // Well past the 1000 us timeout.
        let evicted = manager.cleanup_timeout(10000);
        assert_eq!(evicted.len(), 1);
        assert_eq!(manager.connections().count(), 0);
    }

    // Test 8: Multiple concurrent connections
    #[test]
    fn test_multiple_connections() {
        let mut manager = StreamManager::with_defaults();

        // Connection 1
        send(&mut manager, client(), server(), 1000, syn_flags(), b"", 1, 0).unwrap();

        // Connection 2
        let other_client = (ip(192, 168, 1, 3), 54322);
        let other_server = (ip(192, 168, 1, 4), 443);
        send(&mut manager, other_client, other_server, 2000, syn_flags(), b"", 2, 1).unwrap();

        assert_eq!(manager.connections().count(), 2);
    }

    // Test 9: StreamManager with keylog
    #[test]
    fn test_with_keylog() {
        let manager = StreamManager::new(StreamConfig::default()).with_keylog(KeyLog::new());

        assert!(manager.has_keylog());
        assert!(manager.keylog().is_some());

        // Should have the decrypting TLS parser registered
        let registered = manager
            .stream_registry
            .parser_names()
            .into_iter()
            .any(|name| name == "tls_decrypt");
        assert!(registered);
    }

    // Test 10: StreamManager without keylog
    #[test]
    fn test_without_keylog() {
        let manager = StreamManager::with_defaults();

        assert!(!manager.has_keylog());
        assert!(manager.keylog().is_none());
    }

    // Test 11: TLS decryption parser is registered and prioritized
    #[test]
    fn test_tls_parser_registered() {
        let manager = StreamManager::new(StreamConfig::default()).with_keylog(KeyLog::new());

        let (src_ip, src_port) = client();
        let ctx = StreamContext {
            connection_id: 1,
            direction: Direction::ToServer,
            src_ip,
            dst_ip: ip(192, 168, 1, 2),
            src_port,
            dst_port: 443,
            bytes_parsed: 0,
            messages_parsed: 0,
            alpn: None,
        };

        // A port-443 stream should resolve to the decrypting TLS parser.
        let parser = manager
            .stream_registry
            .find_parser(&ctx)
            .expect("a parser should match a port-443 stream");
        assert_eq!(parser.name(), "tls_decrypt");
    }
}