1use std::collections::{BTreeMap, HashMap};
2
3use super::Direction;
4
/// One TCP payload handed to a [`StreamBuffer`] for reassembly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Segment {
    /// Sequence number of the first payload byte; the segment covers
    /// `seq .. seq + data.len()` in modulo-2^32 sequence space.
    pub seq: u32,
    /// Payload bytes carried by the segment.
    pub data: Vec<u8>,
    /// Frame number supplied by the caller (presumably the capture frame index —
    /// confirm against callers). Synthesized segments use 0.
    pub frame_number: u64,
    /// Timestamp supplied by the caller; units are defined by the caller.
    /// Synthesized segments use 0.
    pub timestamp: i64,
}
13
/// A range of sequence numbers recorded as missing via
/// [`StreamBuffer::record_gap`].
///
/// NOTE(review): whether `end_seq` is inclusive or exclusive is decided by the
/// callers of `record_gap` — confirm before relying on it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SequenceGap {
    /// First sequence number of the gap.
    pub start_seq: u32,
    /// End of the gap (see the inclusivity note above).
    pub end_seq: u32,
}
20
/// Reassembly state for one direction of one TCP connection.
///
/// Bytes that are contiguous from the stream start are appended to an internal
/// buffer. Segments that arrive ahead of the next expected byte are queued until
/// the hole in front of them is filled.
#[derive(Debug)]
pub struct StreamBuffer {
    /// Sequence number of the next byte that will be appended to `reassembled`.
    expected_seq: u32,
    /// Anchor of the stream: the SYN's sequence number when set through
    /// `set_initial_seq`, otherwise the first segment seen (or an earlier one it
    /// was re-based onto). `None` until the first segment or SYN arrives.
    initial_seq: Option<u32>,
    /// True when `initial_seq` came from `set_initial_seq` (a SYN). In that case
    /// the stream is never re-based onto segments that appear to precede it.
    initial_seq_from_syn: bool,
    /// Segments that start after `expected_seq`, keyed by their starting sequence number.
    pending: BTreeMap<u32, Segment>,
    /// Contiguous bytes that have not yet been removed by `consume`.
    reassembled: Vec<u8>,
    /// Gaps registered through `record_gap`; nothing in this type records gaps itself.
    gaps: Vec<SequenceGap>,
    /// Number of segments accepted via `add_inorder_data` (when it returns true)
    /// or passed to `add_segment`.
    pub segment_count: u32,
    /// Number of segments dropped because they lay entirely before `expected_seq`.
    pub retransmit_count: u32,
    /// Number of segments queued because they started after `expected_seq`.
    pub out_of_order_count: u32,
    /// Set by the owner (e.g. `TcpReassembler::mark_fin`) once a FIN was seen.
    pub fin_received: bool,
}
43
44impl StreamBuffer {
45 pub fn new() -> Self {
46 Self {
47 expected_seq: 0,
48 initial_seq: None,
49 initial_seq_from_syn: false,
50 pending: BTreeMap::new(),
51 reassembled: Vec::new(),
52 gaps: Vec::new(),
53 segment_count: 0,
54 retransmit_count: 0,
55 out_of_order_count: 0,
56 fin_received: false,
57 }
58 }
59
60 pub fn set_initial_seq(&mut self, seq: u32) {
62 self.initial_seq = Some(seq);
63 self.initial_seq_from_syn = true;
64 self.expected_seq = seq.wrapping_add(1); }
66
67 #[inline]
75 pub fn add_inorder_data(
76 &mut self,
77 seq: u32,
78 data: &[u8],
79 _frame_number: u64,
80 _timestamp: i64,
81 ) -> bool {
82 if self.initial_seq.is_some() && seq == self.expected_seq && self.pending.is_empty() {
85 self.segment_count += 1;
86 self.reassembled.extend_from_slice(data);
87 self.expected_seq = seq_add(seq, data.len());
88 true
89 } else {
90 false
91 }
92 }
93
94 pub fn add_segment(&mut self, segment: Segment) {
96 self.segment_count += 1;
97
98 if self.initial_seq.is_none() {
100 self.initial_seq = Some(segment.seq);
101 self.expected_seq = segment.seq;
102 }
103
104 let seg_end = seq_add(segment.seq, segment.data.len());
105
106 if seq_lt(segment.seq, self.expected_seq) {
108 if !self.initial_seq_from_syn && seq_lt(segment.seq, self.initial_seq.unwrap()) {
111 let old_initial = self.initial_seq.unwrap();
113 let old_data = std::mem::take(&mut self.reassembled);
114 if !old_data.is_empty() {
115 self.pending.insert(
116 old_initial,
117 Segment {
118 seq: old_initial,
119 data: old_data,
120 frame_number: 0,
121 timestamp: 0,
122 },
123 );
124 }
125 self.initial_seq = Some(segment.seq);
127 self.expected_seq = segment.seq;
128 self.add_segment_inner(segment);
129 return;
130 }
131
132 if seq_le(seg_end, self.expected_seq) {
134 self.retransmit_count += 1;
135 return;
136 }
137 let overlap = self.expected_seq.wrapping_sub(segment.seq) as usize;
139 if overlap < segment.data.len() {
140 let trimmed = Segment {
141 seq: self.expected_seq,
142 data: segment.data[overlap..].to_vec(),
143 frame_number: segment.frame_number,
144 timestamp: segment.timestamp,
145 };
146 self.add_segment_inner(trimmed);
147 }
148 return;
149 }
150
151 self.add_segment_inner(segment);
152 }
153
154 fn add_segment_inner(&mut self, segment: Segment) {
155 if segment.seq == self.expected_seq {
157 self.reassembled.extend_from_slice(&segment.data);
158 self.expected_seq = seq_add(segment.seq, segment.data.len());
159
160 self.flush_pending();
162 } else if seq_lt(self.expected_seq, segment.seq) {
163 self.out_of_order_count += 1;
165 self.pending.insert(segment.seq, segment);
166 }
167 }
168
169 fn flush_pending(&mut self) {
171 while let Some((&seq, _)) = self.pending.first_key_value() {
172 if seq == self.expected_seq {
173 let segment = self.pending.remove(&seq).unwrap();
174 self.reassembled.extend_from_slice(&segment.data);
175 self.expected_seq = seq_add(segment.seq, segment.data.len());
176 } else if seq_lt(seq, self.expected_seq) {
177 self.pending.remove(&seq);
179 } else {
180 break;
182 }
183 }
184 }
185
186 pub fn get_contiguous(&self) -> &[u8] {
188 &self.reassembled
189 }
190
191 pub fn consume(&mut self, bytes: usize) {
193 if bytes > 0 && bytes <= self.reassembled.len() {
194 self.reassembled.drain(..bytes);
195 }
196 }
197
198 pub fn is_complete(&self) -> bool {
200 self.fin_received && self.pending.is_empty()
201 }
202
203 pub fn gaps(&self) -> &[SequenceGap] {
205 &self.gaps
206 }
207
208 pub fn record_gap(&mut self, start: u32, end: u32) {
210 self.gaps.push(SequenceGap {
211 start_seq: start,
212 end_seq: end,
213 });
214 }
215
216 pub fn available(&self) -> usize {
218 self.reassembled.len()
219 }
220
221 pub fn gap_count(&self) -> u32 {
223 self.gaps.len() as u32
224 }
225
226 pub fn segment_count(&self) -> u32 {
228 self.segment_count
229 }
230
231 pub fn retransmit_count(&self) -> u32 {
233 self.retransmit_count
234 }
235
236 pub fn out_of_order_count(&self) -> u32 {
238 self.out_of_order_count
239 }
240}
241
242impl Default for StreamBuffer {
243 fn default() -> Self {
244 Self::new()
245 }
246}
247
/// Map key selecting one direction of one connection inside a `TcpReassembler`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct StreamKey {
    /// Caller-assigned connection identifier; `TcpReassembler::remove` drops
    /// every key sharing this id.
    pub connection_id: u64,
    /// Which side of the connection the stream carries.
    pub direction: Direction,
}
254
255pub struct TcpReassembler {
257 streams: HashMap<StreamKey, StreamBuffer>,
258}
259
260impl TcpReassembler {
261 pub fn new() -> Self {
262 Self {
263 streams: HashMap::new(),
264 }
265 }
266
267 pub fn get_or_create(&mut self, connection_id: u64, direction: Direction) -> &mut StreamBuffer {
269 let key = StreamKey {
270 connection_id,
271 direction,
272 };
273 self.streams.entry(key).or_default()
274 }
275
276 pub fn add_segment(
278 &mut self,
279 connection_id: u64,
280 direction: Direction,
281 seq: u32,
282 data: &[u8],
283 frame_number: u64,
284 timestamp: i64,
285 ) {
286 if data.is_empty() {
287 return; }
289
290 let buffer = self.get_or_create(connection_id, direction);
291
292 if !buffer.add_inorder_data(seq, data, frame_number, timestamp) {
294 buffer.add_segment(Segment {
296 seq,
297 data: data.to_vec(),
298 frame_number,
299 timestamp,
300 });
301 }
302 }
303
304 pub fn get_contiguous(&self, connection_id: u64, direction: Direction) -> &[u8] {
306 let key = StreamKey {
307 connection_id,
308 direction,
309 };
310 self.streams
311 .get(&key)
312 .map(|b| b.get_contiguous())
313 .unwrap_or(&[])
314 }
315
316 pub fn consume(&mut self, connection_id: u64, direction: Direction, bytes: usize) {
318 let key = StreamKey {
319 connection_id,
320 direction,
321 };
322 if let Some(buffer) = self.streams.get_mut(&key) {
323 buffer.consume(bytes);
324 }
325 }
326
327 pub fn mark_fin(&mut self, connection_id: u64, direction: Direction) {
329 let key = StreamKey {
330 connection_id,
331 direction,
332 };
333 if let Some(buffer) = self.streams.get_mut(&key) {
334 buffer.fin_received = true;
335 }
336 }
337
338 pub fn is_complete(&self, connection_id: u64, direction: Direction) -> bool {
340 let key = StreamKey {
341 connection_id,
342 direction,
343 };
344 self.streams
345 .get(&key)
346 .map(|b| b.is_complete())
347 .unwrap_or(false)
348 }
349
350 pub fn remove(&mut self, connection_id: u64) {
352 self.streams.retain(|k, _| k.connection_id != connection_id);
353 }
354
355 pub fn stats(&self, connection_id: u64, direction: Direction) -> Option<StreamStats> {
357 let key = StreamKey {
358 connection_id,
359 direction,
360 };
361 self.streams.get(&key).map(|b| StreamStats {
362 segment_count: b.segment_count,
363 retransmit_count: b.retransmit_count,
364 out_of_order_count: b.out_of_order_count,
365 gap_count: b.gaps.len() as u32,
366 bytes_available: b.available(),
367 })
368 }
369}
370
371impl Default for TcpReassembler {
372 fn default() -> Self {
373 Self::new()
374 }
375}
376
/// Point-in-time counters for one stream, as returned by `TcpReassembler::stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StreamStats {
    /// Segments accepted by the stream buffer.
    pub segment_count: u32,
    /// Segments dropped because all of their bytes were already delivered.
    pub retransmit_count: u32,
    /// Segments queued because they arrived ahead of the next expected byte.
    pub out_of_order_count: u32,
    /// Gaps registered on the stream.
    pub gap_count: u32,
    /// Contiguous bytes available and not yet consumed.
    pub bytes_available: usize,
}
386
/// Returns `true` when sequence number `a` comes before `b` in modulo-2^32
/// sequence space, i.e. `b` is one of the 2^31 numbers following `a`.
fn seq_lt(a: u32, b: u32) -> bool {
    const HALF_SPACE: u32 = 1 << 31;
    // The wrapped difference is "negative" exactly when its top bit is set.
    a.wrapping_sub(b) >= HALF_SPACE
}
391
/// Returns `true` when `a` equals `b` or comes before it in modulo-2^32
/// sequence space.
fn seq_le(a: u32, b: u32) -> bool {
    let diff = a.wrapping_sub(b);
    // Zero means equal; a set top bit means `a` is behind `b` (the `seq_lt` test).
    diff == 0 || diff >> 31 == 1
}
395
/// Advances sequence number `a` by `n` bytes in modulo-2^32 sequence space.
///
/// Truncating `n` to 32 bits is intentional: it is exactly reduction modulo
/// 2^32, which is the arithmetic sequence numbers follow.
fn seq_add(a: u32, n: usize) -> u32 {
    let step = n as u32;
    u32::wrapping_add(a, step)
}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection id shared by every `TcpReassembler` test below.
    const CONN: u64 = 1;

    /// Builds a segment for driving a `StreamBuffer` directly.
    fn seg(seq: u32, data: &[u8], frame_number: u64, timestamp: i64) -> Segment {
        Segment {
            seq,
            data: data.to_vec(),
            frame_number,
            timestamp,
        }
    }

    #[test]
    fn test_in_order_reassembly() {
        let mut r = TcpReassembler::new();
        r.add_segment(CONN, Direction::ToServer, 1000, b"Hello", 1, 0);
        r.add_segment(CONN, Direction::ToServer, 1005, b" World", 2, 1);
        assert_eq!(r.get_contiguous(CONN, Direction::ToServer), b"Hello World");
    }

    #[test]
    fn test_out_of_order_reordering() {
        let mut r = TcpReassembler::new();
        // The second half arrives first and must wait for the first half.
        r.add_segment(CONN, Direction::ToServer, 1005, b" World", 2, 1);
        r.add_segment(CONN, Direction::ToServer, 1000, b"Hello", 1, 0);
        assert_eq!(r.get_contiguous(CONN, Direction::ToServer), b"Hello World");
    }

    #[test]
    fn test_retransmission_detection() {
        let mut r = TcpReassembler::new();
        r.add_segment(CONN, Direction::ToServer, 1000, b"Hello", 1, 0);
        r.add_segment(CONN, Direction::ToServer, 1000, b"Hello", 2, 1);

        let retransmits = r.stats(CONN, Direction::ToServer).unwrap().retransmit_count;
        assert_eq!(retransmits, 1);
        // The duplicate must not be appended a second time.
        assert_eq!(r.get_contiguous(CONN, Direction::ToServer), b"Hello");
    }

    #[test]
    fn test_sequence_wraparound() {
        let mut r = TcpReassembler::new();
        let first = u32::MAX - 2;
        let second = first.wrapping_add(3);
        r.add_segment(CONN, Direction::ToServer, first, b"ABC", 1, 0);
        r.add_segment(CONN, Direction::ToServer, second, b"DEF", 2, 1);
        assert_eq!(r.get_contiguous(CONN, Direction::ToServer), b"ABCDEF");
    }

    #[test]
    fn test_gap_detection() {
        let mut r = TcpReassembler::new();
        r.add_segment(CONN, Direction::ToServer, 1000, b"Hello", 1, 0);
        // Bytes 1005..1010 never arrive, so "World" stays queued.
        r.add_segment(CONN, Direction::ToServer, 1010, b"World", 2, 1);

        assert_eq!(r.get_contiguous(CONN, Direction::ToServer), b"Hello");
        let queued = r.stats(CONN, Direction::ToServer).unwrap().out_of_order_count;
        assert_eq!(queued, 1);
    }

    #[test]
    fn test_overlapping_segments() {
        let mut r = TcpReassembler::new();
        r.add_segment(CONN, Direction::ToServer, 1000, b"Hello", 1, 0);
        // "lo" overlaps bytes already received; only "World" is new.
        r.add_segment(CONN, Direction::ToServer, 1003, b"loWorld", 2, 1);
        assert_eq!(r.get_contiguous(CONN, Direction::ToServer), b"HelloWorld");
    }

    #[test]
    fn test_zero_length_payload() {
        let mut r = TcpReassembler::new();
        r.add_segment(CONN, Direction::ToServer, 1000, b"Hello", 1, 0);
        r.add_segment(CONN, Direction::ToServer, 1005, b"", 2, 1);
        r.add_segment(CONN, Direction::ToServer, 1005, b"World", 3, 2);
        assert_eq!(r.get_contiguous(CONN, Direction::ToServer), b"HelloWorld");
    }

    #[test]
    fn test_consume() {
        let mut r = TcpReassembler::new();
        r.add_segment(CONN, Direction::ToServer, 1000, b"HelloWorld", 1, 0);
        r.consume(CONN, Direction::ToServer, 5);
        assert_eq!(r.get_contiguous(CONN, Direction::ToServer), b"World");
    }

    #[test]
    fn test_get_contiguous() {
        let mut r = TcpReassembler::new();
        // An unknown stream reads as empty.
        assert!(r.get_contiguous(CONN, Direction::ToServer).is_empty());

        r.add_segment(CONN, Direction::ToServer, 1000, b"Test", 1, 0);
        assert_eq!(r.get_contiguous(CONN, Direction::ToServer), b"Test");
    }

    #[test]
    fn test_multiple_streams() {
        let mut r = TcpReassembler::new();
        r.add_segment(CONN, Direction::ToServer, 1000, b"Request", 1, 0);
        r.add_segment(CONN, Direction::ToClient, 2000, b"Response", 2, 1);

        assert_eq!(r.get_contiguous(CONN, Direction::ToServer), b"Request");
        assert_eq!(r.get_contiguous(CONN, Direction::ToClient), b"Response");
    }

    #[test]
    fn test_stats() {
        let mut r = TcpReassembler::new();
        r.add_segment(CONN, Direction::ToServer, 1000, b"Hello", 1, 0);
        r.add_segment(CONN, Direction::ToServer, 1010, b"World", 2, 1);

        let stats = r.stats(CONN, Direction::ToServer).unwrap();
        assert_eq!(stats.segment_count, 2);
        assert_eq!(stats.out_of_order_count, 1);
        assert_eq!(stats.bytes_available, 5);
    }

    #[test]
    fn test_is_complete() {
        let mut r = TcpReassembler::new();
        r.add_segment(CONN, Direction::ToServer, 1000, b"Hello", 1, 0);
        assert!(!r.is_complete(CONN, Direction::ToServer));

        r.mark_fin(CONN, Direction::ToServer);
        assert!(r.is_complete(CONN, Direction::ToServer));
    }

    #[test]
    fn test_inorder_fast_path() {
        let mut buffer = StreamBuffer::new();

        // No anchor yet, so the fast path must decline.
        assert!(!buffer.add_inorder_data(1000, b"Hello", 1, 0));

        buffer.add_segment(seg(1000, b"Hello", 1, 0));
        assert_eq!(buffer.get_contiguous(), b"Hello");
        assert_eq!(buffer.segment_count, 1);

        assert!(buffer.add_inorder_data(1005, b" World", 2, 1));
        assert_eq!(buffer.get_contiguous(), b"Hello World");
        assert_eq!(buffer.segment_count, 2);

        assert!(buffer.add_inorder_data(1011, b"!", 3, 2));
        assert_eq!(buffer.get_contiguous(), b"Hello World!");
        assert_eq!(buffer.segment_count, 3);
    }

    #[test]
    fn test_inorder_fast_path_skipped_with_pending() {
        let mut buffer = StreamBuffer::new();
        buffer.add_segment(seg(1000, b"Hello", 1, 0));
        // Leaves a hole at 1005..1010, so a segment is queued.
        buffer.add_segment(seg(1010, b"World", 3, 2));

        // Even a segment at the expected seq must take the slow path now.
        assert!(!buffer.add_inorder_data(1005, b"_____", 2, 1));
    }
}