1use std::net::IpAddr;
2use std::sync::Arc;
3
4use crate::error::Error;
5use crate::tls::KeyLog;
6
7use super::{
8 parsers::DecryptingTlsStreamParser, Connection, ConnectionTracker, Direction, ParsedMessage,
9 StreamContext, StreamParseResult, StreamRegistry, TcpFlags, TcpReassembler,
10};
11
/// Tunable limits used by [`StreamManager`].
#[derive(Clone, Debug)]
pub struct StreamConfig {
    /// Per-connection buffer limit, presumably in bytes.
    ///
    /// NOTE(review): this value is not read anywhere in this file; confirm
    /// where (or whether) it is enforced.
    pub max_connection_buffer: usize,
    /// Threshold that `StreamManager::memory_limit_exceeded` compares the
    /// manager's running byte counter against.
    pub max_total_memory: usize,
    /// Idle timeout handed to the connection tracker by
    /// `StreamManager::cleanup_timeout`; presumably microseconds, going by
    /// the `_us` suffix.
    pub connection_timeout_us: i64,
}
22
23impl Default for StreamConfig {
24 fn default() -> Self {
25 Self {
26 max_connection_buffer: 16 * 1024 * 1024, max_total_memory: 1024 * 1024 * 1024, connection_timeout_us: 300_000_000, }
30 }
31}
32
33pub struct StreamManager {
35 connections: ConnectionTracker,
36 reassembler: TcpReassembler,
37 stream_registry: StreamRegistry,
38 config: StreamConfig,
39 total_memory: usize,
41 keylog: Option<Arc<KeyLog>>,
43}
44
45impl StreamManager {
46 pub fn new(config: StreamConfig) -> Self {
47 Self {
48 connections: ConnectionTracker::new(),
49 reassembler: TcpReassembler::new(),
50 stream_registry: StreamRegistry::new(),
51 config,
52 total_memory: 0,
53 keylog: None,
54 }
55 }
56
57 pub fn with_defaults() -> Self {
59 Self::new(StreamConfig::default())
60 }
61
62 pub fn with_keylog(mut self, keylog: KeyLog) -> Self {
81 let keylog = Arc::new(keylog);
82 self.keylog = Some(Arc::clone(&keylog));
83
84 let parser = DecryptingTlsStreamParser::with_keylog(keylog);
86 self.stream_registry.register(parser);
87
88 self
89 }
90
91 pub fn has_keylog(&self) -> bool {
93 self.keylog.is_some()
94 }
95
96 pub fn keylog(&self) -> Option<&KeyLog> {
98 self.keylog.as_ref().map(|k| k.as_ref())
99 }
100
101 pub fn registry_mut(&mut self) -> &mut StreamRegistry {
103 &mut self.stream_registry
104 }
105
106 #[allow(clippy::too_many_arguments)]
110 pub fn process_segment(
111 &mut self,
112 src_ip: IpAddr,
113 dst_ip: IpAddr,
114 src_port: u16,
115 dst_port: u16,
116 seq: u32,
117 _ack: u32,
118 flags: TcpFlags,
119 payload: &[u8],
120 frame_number: u64,
121 timestamp: i64,
122 ) -> Result<Vec<ParsedMessage>, Error> {
123 let mut messages = Vec::new();
124
125 let (conn, direction) = self.connections.get_or_create(
127 src_ip,
128 src_port,
129 dst_ip,
130 dst_port,
131 flags,
132 seq,
133 frame_number,
134 timestamp,
135 );
136 let connection_id = conn.id;
137
138 ConnectionTracker::update_state(conn, flags, direction, seq);
140
141 if flags.syn {
143 let buffer = self.reassembler.get_or_create(connection_id, direction);
144 buffer.set_initial_seq(seq);
145 }
146
147 if !payload.is_empty() {
149 ConnectionTracker::add_bytes(conn, direction, payload.len());
150 self.reassembler.add_segment(
151 connection_id,
152 direction,
153 seq,
154 payload,
155 frame_number,
156 timestamp,
157 );
158 self.total_memory += payload.len();
159 }
160
161 self.try_parse(connection_id, direction, frame_number, &mut messages)?;
163
164 if flags.fin {
166 self.reassembler.mark_fin(connection_id, direction);
167 }
168
169 if flags.rst || (flags.fin && self.is_fully_closed(connection_id)) {
171 self.finalize_connection(connection_id, &mut messages)?;
172 }
173
174 Ok(messages)
175 }
176
177 fn try_parse(
179 &mut self,
180 connection_id: u64,
181 direction: Direction,
182 frame_number: u64,
183 messages: &mut Vec<ParsedMessage>,
184 ) -> Result<(), Error> {
185 loop {
186 let data = self.reassembler.get_contiguous(connection_id, direction);
187 if data.is_empty() {
188 break;
189 }
190
191 let context = self.build_context(connection_id, direction);
193
194 let parser = match self.stream_registry.find_parser(&context) {
196 Some(p) => p,
197 None => break, };
199
200 let data_copy = data.to_vec();
202
203 match parser.parse_stream(&data_copy, &context) {
205 StreamParseResult::Complete {
206 messages: msgs,
207 bytes_consumed,
208 } => {
209 for mut msg in msgs {
210 msg.frame_number = frame_number;
211 messages.push(msg);
212 }
213 self.reassembler
214 .consume(connection_id, direction, bytes_consumed);
215 self.total_memory = self.total_memory.saturating_sub(bytes_consumed);
216 }
218
219 StreamParseResult::Transform {
220 child_protocol,
221 child_data,
222 bytes_consumed,
223 metadata,
224 } => {
225 if let Some(mut meta) = metadata {
226 meta.frame_number = frame_number;
227 messages.push(meta);
228 }
229 self.reassembler
230 .consume(connection_id, direction, bytes_consumed);
231 self.total_memory = self.total_memory.saturating_sub(bytes_consumed);
232
233 self.parse_transformed(
235 connection_id,
236 direction,
237 child_protocol,
238 &child_data,
239 frame_number,
240 messages,
241 )?;
242 }
243
244 StreamParseResult::NeedMore { .. } => {
245 break; }
247
248 StreamParseResult::NotThisProtocol => {
249 break; }
251
252 StreamParseResult::Error { skip_bytes, .. } => {
253 if let Some(skip) = skip_bytes {
254 self.reassembler.consume(connection_id, direction, skip);
255 self.total_memory = self.total_memory.saturating_sub(skip);
256 } else {
257 break;
258 }
259 }
260 }
261 }
262
263 Ok(())
264 }
265
266 fn parse_transformed(
268 &self,
269 connection_id: u64,
270 direction: Direction,
271 child_protocol: &str,
272 data: &[u8],
273 frame_number: u64,
274 messages: &mut Vec<ParsedMessage>,
275 ) -> Result<(), Error> {
276 let parser = match self.stream_registry.get_parser(child_protocol) {
277 Some(p) => p,
278 None => return Ok(()), };
280
281 let context = self.build_context(connection_id, direction);
282
283 if let StreamParseResult::Complete { messages: msgs, .. } =
284 parser.parse_stream(data, &context)
285 {
286 for mut msg in msgs {
287 msg.frame_number = frame_number;
288 messages.push(msg);
289 }
290 }
291
292 Ok(())
293 }
294
295 fn build_context(&self, connection_id: u64, direction: Direction) -> StreamContext {
297 let conn = self
298 .connections
299 .connections()
300 .find(|c| c.id == connection_id);
301
302 if let Some(conn) = conn {
303 let (src_ip, dst_ip, src_port, dst_port) = match direction {
304 Direction::ToServer => (
305 conn.client_ip(),
306 conn.server_ip(),
307 conn.client_port(),
308 conn.server_port(),
309 ),
310 Direction::ToClient => (
311 conn.server_ip(),
312 conn.client_ip(),
313 conn.server_port(),
314 conn.client_port(),
315 ),
316 };
317
318 StreamContext {
319 connection_id,
320 direction,
321 src_ip,
322 dst_ip,
323 src_port,
324 dst_port,
325 bytes_parsed: 0, messages_parsed: 0,
327 alpn: None, }
329 } else {
330 StreamContext {
332 connection_id,
333 direction,
334 src_ip: std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
335 dst_ip: std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
336 src_port: 0,
337 dst_port: 0,
338 bytes_parsed: 0,
339 messages_parsed: 0,
340 alpn: None,
341 }
342 }
343 }
344
345 fn is_fully_closed(&self, connection_id: u64) -> bool {
347 self.reassembler
348 .is_complete(connection_id, Direction::ToServer)
349 && self
350 .reassembler
351 .is_complete(connection_id, Direction::ToClient)
352 }
353
354 #[allow(clippy::ptr_arg)]
356 fn finalize_connection(
357 &mut self,
358 connection_id: u64,
359 _messages: &mut Vec<ParsedMessage>,
360 ) -> Result<(), Error> {
361 self.reassembler.remove(connection_id);
363 Ok(())
364 }
365
366 pub fn cleanup_timeout(&mut self, current_time: i64) -> Vec<Connection> {
368 let removed = self
369 .connections
370 .cleanup_timeout(current_time, self.config.connection_timeout_us);
371
372 for conn in &removed {
373 self.reassembler.remove(conn.id);
374 }
375
376 removed
377 }
378
379 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
381 self.connections.connections()
382 }
383
384 pub fn total_memory(&self) -> usize {
386 self.total_memory
387 }
388
389 pub fn memory_limit_exceeded(&self) -> bool {
391 self.total_memory > self.config.max_total_memory
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds an IPv4 `IpAddr` from its four octets.
    fn ip(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        Ipv4Addr::new(a, b, c, d).into()
    }

    /// Flags with only SYN set.
    fn syn_only() -> TcpFlags {
        TcpFlags {
            syn: true,
            ..Default::default()
        }
    }

    /// Flags with only ACK set.
    fn ack_only() -> TcpFlags {
        TcpFlags {
            ack: true,
            ..Default::default()
        }
    }

    /// Sends one segment from 192.168.1.1:54321 to 192.168.1.2:80 with an
    /// ack number of zero, panicking if processing fails.
    fn send_client_segment(
        manager: &mut StreamManager,
        seq: u32,
        flags: TcpFlags,
        payload: &[u8],
        frame_number: u64,
        timestamp: i64,
    ) {
        manager
            .process_segment(
                ip(192, 168, 1, 1),
                ip(192, 168, 1, 2),
                54321,
                80,
                seq,
                0,
                flags,
                payload,
                frame_number,
                timestamp,
            )
            .unwrap();
    }

    /// A bare SYN is accepted and creates exactly one connection.
    #[test]
    fn test_process_segment() {
        let mut manager = StreamManager::with_defaults();

        let outcome = manager.process_segment(
            ip(192, 168, 1, 1),
            ip(192, 168, 1, 2),
            54321,
            80,
            1000,
            0,
            syn_only(),
            b"",
            1,
            0,
        );

        assert!(outcome.is_ok());
        assert_eq!(manager.connections().count(), 1);
    }

    /// The SYN sender is recorded as the client, the receiver as the server.
    #[test]
    fn test_connection_on_syn() {
        let mut manager = StreamManager::with_defaults();

        send_client_segment(&mut manager, 1000, syn_only(), b"", 1, 0);

        let tracked = manager.connections().next().unwrap();
        assert_eq!(tracked.client_port(), 54321);
        assert_eq!(tracked.server_port(), 80);
    }

    /// Payload that no parser consumes stays counted in `total_memory`.
    #[test]
    fn test_reassembly_triggers_parse() {
        let mut manager = StreamManager::with_defaults();

        send_client_segment(
            &mut manager,
            1000,
            ack_only(),
            b"GET / HTTP/1.1\r\n",
            1,
            0,
        );

        assert!(manager.total_memory() > 0);
    }

    /// Only checks that a fresh manager tracks no connections.
    #[test]
    fn test_need_more_handling() {
        let manager = StreamManager::with_defaults();
        assert_eq!(manager.connections().count(), 0);
    }

    /// Only checks that a fresh manager has a zero memory counter.
    #[test]
    fn test_complete_handling() {
        let manager = StreamManager::with_defaults();
        assert!(manager.total_memory() == 0);
    }

    /// The memory counter follows the payload sizes and trips the limit.
    #[test]
    fn test_memory_tracking() {
        let mut manager = StreamManager::new(StreamConfig {
            max_total_memory: 1000,
            ..Default::default()
        });

        // 500 bytes: below the 1000-byte limit.
        send_client_segment(&mut manager, 1000, ack_only(), &[0u8; 500], 1, 0);

        assert_eq!(manager.total_memory(), 500);
        assert!(!manager.memory_limit_exceeded());

        // 600 more bytes: 1100 in total, above the limit.
        send_client_segment(&mut manager, 1500, ack_only(), &[0u8; 600], 2, 1);

        assert!(manager.memory_limit_exceeded());
    }

    /// A connection idle past the configured timeout is removed and returned.
    #[test]
    fn test_connection_cleanup() {
        let mut manager = StreamManager::new(StreamConfig {
            connection_timeout_us: 1000,
            ..Default::default()
        });

        send_client_segment(&mut manager, 1000, syn_only(), b"", 1, 0);

        assert_eq!(manager.connections().count(), 1);

        let expired = manager.cleanup_timeout(10000);
        assert_eq!(expired.len(), 1);
        assert_eq!(manager.connections().count(), 0);
    }

    /// Two distinct endpoint pairs produce two tracked connections.
    #[test]
    fn test_multiple_connections() {
        let mut manager = StreamManager::with_defaults();

        send_client_segment(&mut manager, 1000, syn_only(), b"", 1, 0);

        manager
            .process_segment(
                ip(192, 168, 1, 3),
                ip(192, 168, 1, 4),
                54322,
                443,
                2000,
                0,
                syn_only(),
                b"",
                2,
                1,
            )
            .unwrap();

        assert_eq!(manager.connections().count(), 2);
    }

    /// Attaching a key log exposes it and registers the decrypting parser.
    #[test]
    fn test_with_keylog() {
        let manager = StreamManager::new(StreamConfig::default()).with_keylog(KeyLog::new());

        assert!(manager.has_keylog());
        assert!(manager.keylog().is_some());

        let registered: Vec<_> = manager.stream_registry.parser_names().into_iter().collect();
        assert!(registered.contains(&"tls_decrypt"));
    }

    /// Without a key log the accessors report its absence.
    #[test]
    fn test_without_keylog() {
        let manager = StreamManager::with_defaults();

        assert!(!manager.has_keylog());
        assert!(manager.keylog().is_none());
    }

    /// With a key log attached, a port-443 context resolves to `tls_decrypt`.
    #[test]
    fn test_tls_parser_registered() {
        let manager = StreamManager::new(StreamConfig::default()).with_keylog(KeyLog::new());

        let https_context = StreamContext {
            connection_id: 1,
            direction: Direction::ToServer,
            src_ip: ip(192, 168, 1, 1),
            dst_ip: ip(192, 168, 1, 2),
            src_port: 54321,
            dst_port: 443,
            bytes_parsed: 0,
            messages_parsed: 0,
            alpn: None,
        };

        let found = manager.stream_registry.find_parser(&https_context);
        assert!(found.is_some());
        assert_eq!(found.unwrap().name(), "tls_decrypt");
    }
}