pcapsql_core/stream/reassembly.rs

use std::collections::{BTreeMap, HashMap};

use super::Direction;

/// A TCP segment waiting to be reassembled.
#[derive(Debug, Clone)]
pub struct Segment {
    pub seq: u32,
    pub data: Vec<u8>,
    pub frame_number: u64,
    pub timestamp: i64,
}

/// A gap in the sequence space (missing data).
#[derive(Debug, Clone)]
pub struct SequenceGap {
    pub start_seq: u32,
    pub end_seq: u32,
}

/// Buffer for one direction of a TCP stream.
#[derive(Debug)]
pub struct StreamBuffer {
    /// Next expected sequence number.
    expected_seq: u32,
    /// Initial sequence number (from SYN).
    initial_seq: Option<u32>,
    /// Whether initial_seq was set from a SYN (definitive) vs inferred.
    initial_seq_from_syn: bool,
    /// Out-of-order segments waiting to be reassembled.
    pending: BTreeMap<u32, Segment>,
    /// Contiguous reassembled data ready for parsing.
    reassembled: Vec<u8>,
    /// Detected gaps (missing segments).
    gaps: Vec<SequenceGap>,
    /// Statistics.
    pub segment_count: u32,
    pub retransmit_count: u32,
    pub out_of_order_count: u32,
    /// FIN received.
    pub fin_received: bool,
}

impl StreamBuffer {
    pub fn new() -> Self {
        Self {
            expected_seq: 0,
            initial_seq: None,
            initial_seq_from_syn: false,
            pending: BTreeMap::new(),
            reassembled: Vec::new(),
            gaps: Vec::new(),
            segment_count: 0,
            retransmit_count: 0,
            out_of_order_count: 0,
            fin_received: false,
        }
    }

    /// Set the initial sequence number (from SYN).
    pub fn set_initial_seq(&mut self, seq: u32) {
        self.initial_seq = Some(seq);
        self.initial_seq_from_syn = true;
        self.expected_seq = seq.wrapping_add(1); // SYN consumes one seq
    }

    /// Fast path for in-order segment - avoids intermediate Vec allocation.
    /// Returns true if the segment was handled (in-order with no pending segments).
    /// Returns false if the segment needs to be handled by the slow path.
    ///
    /// This copies data directly into the reassembled buffer, avoiding the
    /// intermediate `Segment { data: data.to_vec(), ... }` allocation that
    /// the slow path requires.
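    ///
    /// A minimal sketch of the intended contract (illustrative only, not a
    /// compiled doctest; assumes the initial sequence number came from a SYN):
    ///
    /// ```ignore
    /// let mut buf = StreamBuffer::new();
    /// buf.set_initial_seq(999); // SYN at 999, so data starts at 1000
    /// assert!(buf.add_inorder_data(1000, b"GET ", 1, 0));  // in order: handled here
    /// assert!(!buf.add_inorder_data(2000, b"late", 2, 1)); // gap: caller falls back to add_segment
    /// ```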
    #[inline]
    pub fn add_inorder_data(
        &mut self,
        seq: u32,
        data: &[u8],
        _frame_number: u64,
        _timestamp: i64,
    ) -> bool {
        // Fast path: segment is in-order and no pending segments exist
        // This avoids allocating an intermediate Vec for the Segment struct
        if self.initial_seq.is_some() && seq == self.expected_seq && self.pending.is_empty() {
            self.segment_count += 1;
            self.reassembled.extend_from_slice(data);
            self.expected_seq = seq_add(seq, data.len());
            true
        } else {
            false
        }
    }

    /// Add a segment to the buffer (slow path - takes ownership).
    pub fn add_segment(&mut self, segment: Segment) {
        self.segment_count += 1;

        // If we haven't seen a SYN, use first segment's seq
        if self.initial_seq.is_none() {
            self.initial_seq = Some(segment.seq);
            self.expected_seq = segment.seq;
        }

        let seg_end = seq_add(segment.seq, segment.data.len());

        // Check for retransmission (segment starts before expected)
        if seq_lt(segment.seq, self.expected_seq) {
            // Special case: mid-stream mode and we received an earlier segment
            // This means we started in the middle and now got an earlier packet
            if !self.initial_seq_from_syn && seq_lt(segment.seq, self.initial_seq.unwrap()) {
                // Move current reassembled data to pending and reset
                let old_initial = self.initial_seq.unwrap();
                let old_data = std::mem::take(&mut self.reassembled);
                if !old_data.is_empty() {
                    self.pending.insert(
                        old_initial,
                        Segment {
                            seq: old_initial,
                            data: old_data,
                            frame_number: 0,
                            timestamp: 0,
                        },
                    );
                }
                // Set new initial and process this segment
                self.initial_seq = Some(segment.seq);
                self.expected_seq = segment.seq;
                self.add_segment_inner(segment);
                return;
            }

            // Check if it's fully before expected (pure retransmit)
            if seq_le(seg_end, self.expected_seq) {
                self.retransmit_count += 1;
                return;
            }
            // Partial overlap - trim the beginning
            let overlap = self.expected_seq.wrapping_sub(segment.seq) as usize;
            if overlap < segment.data.len() {
                let trimmed = Segment {
                    seq: self.expected_seq,
                    data: segment.data[overlap..].to_vec(),
                    frame_number: segment.frame_number,
                    timestamp: segment.timestamp,
                };
                self.add_segment_inner(trimmed);
            }
            return;
        }

        self.add_segment_inner(segment);
    }

    fn add_segment_inner(&mut self, segment: Segment) {
        // In-order segment
        if segment.seq == self.expected_seq {
            self.reassembled.extend_from_slice(&segment.data);
            self.expected_seq = seq_add(segment.seq, segment.data.len());

            // Check if pending segments can now be added
            self.flush_pending();
        } else if seq_lt(self.expected_seq, segment.seq) {
            // Out of order - buffer it
            self.out_of_order_count += 1;
            self.pending.insert(segment.seq, segment);
        }
    }

    /// Try to flush pending segments that are now in order.
    fn flush_pending(&mut self) {
        while let Some((&seq, _)) = self.pending.first_key_value() {
            if seq == self.expected_seq {
                let segment = self.pending.remove(&seq).unwrap();
                self.reassembled.extend_from_slice(&segment.data);
                self.expected_seq = seq_add(segment.seq, segment.data.len());
            } else if seq_lt(seq, self.expected_seq) {
                // Pending segment starts before expected: drop the portion we
                // already have, but keep any bytes that extend past expected_seq
                // (otherwise a partially overlapping segment would lose data).
                let segment = self.pending.remove(&seq).unwrap();
                let seg_end = seq_add(segment.seq, segment.data.len());
                if seq_lt(self.expected_seq, seg_end) {
                    let overlap = self.expected_seq.wrapping_sub(segment.seq) as usize;
                    self.reassembled.extend_from_slice(&segment.data[overlap..]);
                    self.expected_seq = seg_end;
                }
            } else {
                // Gap - can't continue
                break;
            }
        }
    }

    /// Get contiguous reassembled data.
    pub fn get_contiguous(&self) -> &[u8] {
        &self.reassembled
    }

    /// Consume bytes from the reassembled buffer (after successful parse).
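    ///
    /// A hypothetical parse loop over the contiguous bytes (sketch only;
    /// `try_parse_message` is not part of this module and stands in for a
    /// protocol parser that reports how many bytes it consumed):
    ///
    /// ```ignore
    /// while let Some(consumed) = try_parse_message(buffer.get_contiguous()) {
    ///     buffer.consume(consumed);
    /// }
    /// ```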
    pub fn consume(&mut self, bytes: usize) {
        if bytes > 0 && bytes <= self.reassembled.len() {
            self.reassembled.drain(..bytes);
        }
    }

    /// Check if the stream is complete (FIN received and no segments still pending).
    pub fn is_complete(&self) -> bool {
        self.fin_received && self.pending.is_empty()
    }

    /// Get current gaps in the stream.
    pub fn gaps(&self) -> &[SequenceGap] {
        &self.gaps
    }

    /// Record a gap when we detect missing data.
    pub fn record_gap(&mut self, start: u32, end: u32) {
        self.gaps.push(SequenceGap {
            start_seq: start,
            end_seq: end,
        });
    }

    /// Get number of bytes available for parsing.
    pub fn available(&self) -> usize {
        self.reassembled.len()
    }

    /// Get the number of gaps in the stream.
    pub fn gap_count(&self) -> u32 {
        self.gaps.len() as u32
    }

    /// Get the segment count.
    pub fn segment_count(&self) -> u32 {
        self.segment_count
    }

    /// Get the retransmission count.
    pub fn retransmit_count(&self) -> u32 {
        self.retransmit_count
    }

    /// Get the out-of-order segment count.
    pub fn out_of_order_count(&self) -> u32 {
        self.out_of_order_count
    }
}

impl Default for StreamBuffer {
    fn default() -> Self {
        Self::new()
    }
}

/// Key for stream buffer lookup.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct StreamKey {
    pub connection_id: u64,
    pub direction: Direction,
}

/// TCP stream reassembler.
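///
/// A minimal usage sketch (illustrative only, not a compiled doctest; the
/// `Direction` variants follow the ones used in the tests below):
///
/// ```ignore
/// let mut reassembler = TcpReassembler::new();
/// reassembler.add_segment(1, Direction::ToServer, 1000, b"Hello", 1, 0);
/// reassembler.add_segment(1, Direction::ToServer, 1005, b" World", 2, 1);
/// assert_eq!(reassembler.get_contiguous(1, Direction::ToServer), b"Hello World");
///
/// // Once a parser has handled the first five bytes, drop them:
/// reassembler.consume(1, Direction::ToServer, 5);
/// assert_eq!(reassembler.get_contiguous(1, Direction::ToServer), b" World");
/// ```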
pub struct TcpReassembler {
    streams: HashMap<StreamKey, StreamBuffer>,
}

impl TcpReassembler {
    pub fn new() -> Self {
        Self {
            streams: HashMap::new(),
        }
    }

    /// Get or create a stream buffer.
    pub fn get_or_create(&mut self, connection_id: u64, direction: Direction) -> &mut StreamBuffer {
        let key = StreamKey {
            connection_id,
            direction,
        };
        self.streams.entry(key).or_default()
    }

    /// Add a segment to the appropriate stream.
    pub fn add_segment(
        &mut self,
        connection_id: u64,
        direction: Direction,
        seq: u32,
        data: &[u8],
        frame_number: u64,
        timestamp: i64,
    ) {
        if data.is_empty() {
            return; // No payload
        }

        let buffer = self.get_or_create(connection_id, direction);

        // Try fast path first (no allocation for in-order segments)
        if !buffer.add_inorder_data(seq, data, frame_number, timestamp) {
            // Fall back to slow path with copy
            buffer.add_segment(Segment {
                seq,
                data: data.to_vec(),
                frame_number,
                timestamp,
            });
        }
    }

    /// Get contiguous data for a stream.
    pub fn get_contiguous(&self, connection_id: u64, direction: Direction) -> &[u8] {
        let key = StreamKey {
            connection_id,
            direction,
        };
        self.streams
            .get(&key)
            .map(|b| b.get_contiguous())
            .unwrap_or(&[])
    }

    /// Consume bytes from a stream.
    pub fn consume(&mut self, connection_id: u64, direction: Direction, bytes: usize) {
        let key = StreamKey {
            connection_id,
            direction,
        };
        if let Some(buffer) = self.streams.get_mut(&key) {
            buffer.consume(bytes);
        }
    }

    /// Mark FIN received for a stream.
    pub fn mark_fin(&mut self, connection_id: u64, direction: Direction) {
        let key = StreamKey {
            connection_id,
            direction,
        };
        if let Some(buffer) = self.streams.get_mut(&key) {
            buffer.fin_received = true;
        }
    }

    /// Check if a stream is complete.
    pub fn is_complete(&self, connection_id: u64, direction: Direction) -> bool {
        let key = StreamKey {
            connection_id,
            direction,
        };
        self.streams
            .get(&key)
            .map(|b| b.is_complete())
            .unwrap_or(false)
    }

    /// Remove a stream (connection closed).
    pub fn remove(&mut self, connection_id: u64) {
        self.streams.retain(|k, _| k.connection_id != connection_id);
    }

    /// Get stream statistics.
    pub fn stats(&self, connection_id: u64, direction: Direction) -> Option<StreamStats> {
        let key = StreamKey {
            connection_id,
            direction,
        };
        self.streams.get(&key).map(|b| StreamStats {
            segment_count: b.segment_count,
            retransmit_count: b.retransmit_count,
            out_of_order_count: b.out_of_order_count,
            gap_count: b.gaps.len() as u32,
            bytes_available: b.available(),
        })
    }
}

impl Default for TcpReassembler {
    fn default() -> Self {
        Self::new()
    }
}

/// Stream statistics.
#[derive(Debug, Clone)]
pub struct StreamStats {
    pub segment_count: u32,
    pub retransmit_count: u32,
    pub out_of_order_count: u32,
    pub gap_count: u32,
    pub bytes_available: usize,
}

// Sequence number comparison helpers. TCP sequence numbers live in a wrapping
// 32-bit space, so "less than" means "earlier in the sequence space", computed
// with wrapping subtraction and a signed comparison.

/// True if `a` comes before `b` in the wrapping sequence space.
fn seq_lt(a: u32, b: u32) -> bool {
    (a.wrapping_sub(b) as i32) < 0
}

/// True if `a` equals `b` or comes before it in the wrapping sequence space.
fn seq_le(a: u32, b: u32) -> bool {
    a == b || seq_lt(a, b)
}

/// Advance a sequence number by `n` bytes, wrapping around u32::MAX.
fn seq_add(a: u32, n: usize) -> u32 {
    a.wrapping_add(n as u32)
}

#[cfg(test)]
mod tests {
    use super::*;

    // Test 1: In-order segment reassembly
    #[test]
    fn test_in_order_reassembly() {
        let mut reassembler = TcpReassembler::new();

        reassembler.add_segment(1, Direction::ToServer, 1000, b"Hello", 1, 0);
        reassembler.add_segment(1, Direction::ToServer, 1005, b" World", 2, 1);

        let data = reassembler.get_contiguous(1, Direction::ToServer);
        assert_eq!(data, b"Hello World");
    }

    // Test 2: Out-of-order segment reordering
    #[test]
    fn test_out_of_order_reordering() {
        let mut reassembler = TcpReassembler::new();

        // Arrive out of order
        reassembler.add_segment(1, Direction::ToServer, 1005, b" World", 2, 1);
        reassembler.add_segment(1, Direction::ToServer, 1000, b"Hello", 1, 0);

        let data = reassembler.get_contiguous(1, Direction::ToServer);
        assert_eq!(data, b"Hello World");
    }

    // Test 3: Retransmission detection
    #[test]
    fn test_retransmission_detection() {
        let mut reassembler = TcpReassembler::new();

        reassembler.add_segment(1, Direction::ToServer, 1000, b"Hello", 1, 0);
        reassembler.add_segment(1, Direction::ToServer, 1000, b"Hello", 2, 1); // Retransmit

        let stats = reassembler.stats(1, Direction::ToServer).unwrap();
        assert_eq!(stats.retransmit_count, 1);

        // Data should appear only once
        let data = reassembler.get_contiguous(1, Direction::ToServer);
        assert_eq!(data, b"Hello");
    }

    // Test 4: Sequence number wraparound
    #[test]
    fn test_sequence_wraparound() {
        let mut reassembler = TcpReassembler::new();

        // Near max u32
        let near_max = u32::MAX - 2;
        reassembler.add_segment(1, Direction::ToServer, near_max, b"ABC", 1, 0);
        reassembler.add_segment(
            1,
            Direction::ToServer,
            near_max.wrapping_add(3),
            b"DEF",
            2,
            1,
        );

        let data = reassembler.get_contiguous(1, Direction::ToServer);
        assert_eq!(data, b"ABCDEF");
    }

    // Test 5: Gap detection
    #[test]
    fn test_gap_detection() {
        let mut reassembler = TcpReassembler::new();

        reassembler.add_segment(1, Direction::ToServer, 1000, b"Hello", 1, 0);
        // Skip 1005-1009, add segment starting at 1010
        reassembler.add_segment(1, Direction::ToServer, 1010, b"World", 2, 1);

        // Only "Hello" should be available (gap before "World")
        let data = reassembler.get_contiguous(1, Direction::ToServer);
        assert_eq!(data, b"Hello");

        let stats = reassembler.stats(1, Direction::ToServer).unwrap();
        assert_eq!(stats.out_of_order_count, 1);
    }

    // Test 6: Overlapping segments (partial retransmit)
    #[test]
    fn test_overlapping_segments() {
        let mut reassembler = TcpReassembler::new();

        reassembler.add_segment(1, Direction::ToServer, 1000, b"Hello", 1, 0);
        // Overlapping: starts at 1003, overlaps "lo"
        reassembler.add_segment(1, Direction::ToServer, 1003, b"loWorld", 2, 1);

        let data = reassembler.get_contiguous(1, Direction::ToServer);
        assert_eq!(data, b"HelloWorld");
    }

    // Test 7: Zero-length payload
    #[test]
    fn test_zero_length_payload() {
        let mut reassembler = TcpReassembler::new();

        reassembler.add_segment(1, Direction::ToServer, 1000, b"Hello", 1, 0);
        reassembler.add_segment(1, Direction::ToServer, 1005, b"", 2, 1); // Empty
        reassembler.add_segment(1, Direction::ToServer, 1005, b"World", 3, 2);

        let data = reassembler.get_contiguous(1, Direction::ToServer);
        assert_eq!(data, b"HelloWorld");
    }

    // Test 8: Consume advances buffer
    #[test]
    fn test_consume() {
        let mut reassembler = TcpReassembler::new();

        reassembler.add_segment(1, Direction::ToServer, 1000, b"HelloWorld", 1, 0);

        // Consume "Hello"
        reassembler.consume(1, Direction::ToServer, 5);

        let data = reassembler.get_contiguous(1, Direction::ToServer);
        assert_eq!(data, b"World");
    }

    // Test 9: Get contiguous returns available data
    #[test]
    fn test_get_contiguous() {
        let mut reassembler = TcpReassembler::new();

        // No data yet
        let data = reassembler.get_contiguous(1, Direction::ToServer);
        assert!(data.is_empty());

        reassembler.add_segment(1, Direction::ToServer, 1000, b"Test", 1, 0);

        let data = reassembler.get_contiguous(1, Direction::ToServer);
        assert_eq!(data, b"Test");
    }

    // Test 10: Multiple streams per connection
    #[test]
    fn test_multiple_streams() {
        let mut reassembler = TcpReassembler::new();

        reassembler.add_segment(1, Direction::ToServer, 1000, b"Request", 1, 0);
        reassembler.add_segment(1, Direction::ToClient, 2000, b"Response", 2, 1);

        assert_eq!(
            reassembler.get_contiguous(1, Direction::ToServer),
            b"Request"
        );
        assert_eq!(
            reassembler.get_contiguous(1, Direction::ToClient),
            b"Response"
        );
    }

    // Test 11: Stream statistics (segment, out-of-order, and byte counts)
    #[test]
    fn test_stats() {
        let mut reassembler = TcpReassembler::new();

        reassembler.add_segment(1, Direction::ToServer, 1000, b"Hello", 1, 0);
        reassembler.add_segment(1, Direction::ToServer, 1010, b"World", 2, 1); // Gap

        let stats = reassembler.stats(1, Direction::ToServer).unwrap();
        assert_eq!(stats.segment_count, 2);
        assert_eq!(stats.out_of_order_count, 1);
        assert_eq!(stats.bytes_available, 5); // Only "Hello"
    }

    // Test 12: is_complete detection
    #[test]
    fn test_is_complete() {
        let mut reassembler = TcpReassembler::new();

        reassembler.add_segment(1, Direction::ToServer, 1000, b"Hello", 1, 0);
        assert!(!reassembler.is_complete(1, Direction::ToServer));

        reassembler.mark_fin(1, Direction::ToServer);
        assert!(reassembler.is_complete(1, Direction::ToServer));
    }

    // Test 13: Fast path (add_inorder_data) works correctly
    #[test]
    fn test_inorder_fast_path() {
        let mut buffer = StreamBuffer::new();

        // First segment - no initial_seq yet, fast path returns false
        assert!(!buffer.add_inorder_data(1000, b"Hello", 1, 0));

        // Set up initial state via slow path
        buffer.add_segment(Segment {
            seq: 1000,
            data: b"Hello".to_vec(),
            frame_number: 1,
            timestamp: 0,
        });
        assert_eq!(buffer.get_contiguous(), b"Hello");
        assert_eq!(buffer.segment_count, 1);

        // Second segment - should use fast path
        assert!(buffer.add_inorder_data(1005, b" World", 2, 1));
        assert_eq!(buffer.get_contiguous(), b"Hello World");
        assert_eq!(buffer.segment_count, 2);

        // Third segment - also fast path
        assert!(buffer.add_inorder_data(1011, b"!", 3, 2));
        assert_eq!(buffer.get_contiguous(), b"Hello World!");
        assert_eq!(buffer.segment_count, 3);
    }

    // Test 14: Fast path skipped when pending segments exist
    #[test]
    fn test_inorder_fast_path_skipped_with_pending() {
        let mut buffer = StreamBuffer::new();

        // Set up initial state
        buffer.add_segment(Segment {
            seq: 1000,
            data: b"Hello".to_vec(),
            frame_number: 1,
            timestamp: 0,
        });

        // Add out-of-order segment (creates pending)
        buffer.add_segment(Segment {
            seq: 1010,
            data: b"World".to_vec(),
            frame_number: 3,
            timestamp: 2,
        });

        // Now even if we have in-order data, fast path should return false
        // because there are pending segments
        assert!(!buffer.add_inorder_data(1005, b"_____", 2, 1));
    }
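
    // Test 15 (added illustration): the sequence-number helpers use wrapping
    // 32-bit arithmetic, so values just below u32::MAX compare as "earlier"
    // than small values on the far side of the wrap.
    #[test]
    fn test_seq_helper_wraparound() {
        assert!(seq_lt(0xFFFF_FFF0, 0x10)); // earlier across the wrap
        assert!(!seq_lt(0x10, 0xFFFF_FFF0));
        assert!(seq_le(42, 42));
        assert_eq!(seq_add(u32::MAX, 1), 0);
    }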
}