parse_tcp/
stream.rs

1use std::collections::BinaryHeap;
2use std::ops::Range;
3
4use kinesin_rdt::common::ring_buffer::RingBufSlice;
5use kinesin_rdt::stream::inbound::{ReceiveSegmentResult, StreamInboundState};
6use tracing::{debug, trace, warn};
7
8use crate::PacketExtra;
9
/// size of the sequence number sliding window, in bytes (1 GiB)
pub const SEQ_WINDOW_SIZE: u32 = 1024 << 20;
/// once a received sequence number lies more than this many bytes past the
/// start of the sequence window, the window is advanced
pub const SEQ_WINDOW_ADVANCE_THRESHOLD: u32 = 512 << 20;
/// when the window is advanced, its new start is placed this many bytes
/// before the sequence number that triggered the advance
pub const SEQ_WINDOW_ADVANCE_BY: u32 = 256 << 20;
/// max allowed size of the stream buffer, in bytes
pub const MAX_ALLOWED_BUFFER_SIZE: u64 = 128 << 20;
/// max number of elements kept in `segments_info`
pub const MAX_SEGMENTS_INFO_COUNT: usize = 128 << 10;
/// how far past the highest acked offset a reset packet may lie and still be accepted
pub const RESET_MAX_LOOKAHEAD: u32 = 16 << 20;
/// how far before the highest acked offset a reset packet may lie and still be accepted
pub const RESET_MAX_LOOKBEHIND: u32 = 256 << 10;

// Advancing the window must move its start forward (ADVANCE_BY < ADVANCE_THRESHOLD),
// and an in-window sequence number must be able to trigger an advance at all
// (ADVANCE_THRESHOLD < SEQ_WINDOW_SIZE).
const _: () = assert!(SEQ_WINDOW_ADVANCE_BY < SEQ_WINDOW_ADVANCE_THRESHOLD);
const _: () = assert!(SEQ_WINDOW_ADVANCE_THRESHOLD < SEQ_WINDOW_SIZE);
24
25// TODO: track segments so we can have metadata in a heap or something
26/// unidirectional stream of a connection
27pub struct Stream {
28    /// initial sequence number
29    pub initial_sequence_number: u32,
30    /// offset from packet sequence number to absolute stream offset
31    pub seq_offset: SeqOffset,
32    /// window scale
33    pub window_scale: u8,
34    /// if the window scale was captured (if not, try to estimate)
35    pub got_window_scale: bool,
36    /// stream state
37    pub state: StreamInboundState,
38    /// lowest acceptable TCP sequence number (used to disambiguate absolute offset)
39    pub seq_window_start: u32,
40    /// highest acceptable TCP sequence number plus one
41    pub seq_window_end: u32,
42    /// highest offset at which we have received an ack
43    pub highest_acked: u64,
44    /// highest acked offset of opposite stream
45    pub reverse_acked: u64,
46
47    /// whether a reset happened in this direction
48    pub had_reset: bool,
49    /// true if the FIN for this stream was acked
50    pub has_ended: bool,
51
52    /// count of bytes skipped due to gaps
53    pub gaps_length: u64,
54    /// detected retransmission count
55    pub retransmit_count: usize,
56    /// segment metadata
57    pub segments_info: BinaryHeap<SegmentInfo>,
58    /// number of packets not written to segments_info because it was full
59    pub segments_info_dropped: usize,
60}
61
62impl Stream {
63    /// create new instance
64    pub fn new() -> Self {
65        Stream {
66            initial_sequence_number: 0,
67            seq_offset: SeqOffset::Initial(0),
68            window_scale: 0,
69            got_window_scale: false,
70            state: StreamInboundState::new(0, true),
71            seq_window_start: 0,
72            seq_window_end: 0,
73            highest_acked: 0,
74            reverse_acked: 0,
75            had_reset: false,
76            has_ended: false,
77            gaps_length: 0,
78            retransmit_count: 0,
79            segments_info: BinaryHeap::new(),
80            segments_info_dropped: 0,
81        }
82    }
83
84    /// return the number of bytes currently buffered and readable
85    pub fn readable_buffered_length(&self) -> usize {
86        if let Some(highest_readable) = self.state.max_contiguous_offset() {
87            (highest_readable - self.state.buffer_offset) as usize
88        } else {
89            0
90        }
91    }
92
93    /// return the total length of the buffer, including segments not yet
94    /// readable
95    pub fn total_buffered_length(&self) -> usize {
96        self.state.buffer.len()
97    }
98
99    /// get offset of head of internal buffer
100    pub fn buffer_start(&self) -> u64 {
101        self.state.buffer_offset
102    }
103
104    /// set the window scale option
105    pub fn set_window_scale(&mut self, window_scale: u8) -> bool {
106        if window_scale > 14 {
107            // max value is 14
108            warn!("rejected oversized window_scale value: {window_scale}");
109            false
110        } else {
111            self.window_scale = window_scale;
112            self.got_window_scale = true;
113            true
114        }
115    }
116
117    /// if window scale was not received, try to estimate it
118    pub fn estimate_window_scale(&mut self, fit_end_offset: u64) -> bool {
119        debug_assert!(fit_end_offset > self.state.window_limit);
120        let window_available = self.state.window_limit - self.highest_acked;
121        trace!("available window: {window_available}");
122        if window_available < 8 {
123            // not enough space to estimate
124            debug!("cannot estimate window scale (available window: {window_available})");
125            return false;
126        }
127        let mut try_scale = self.window_scale;
128        let unscaled = window_available >> self.window_scale;
129        if unscaled == 0 {
130            debug!("cannot estimate window scale: unscaled window size is 0");
131            return false;
132        }
133        let mut new_limit = self.highest_acked + (unscaled << try_scale);
134        loop {
135            if try_scale >= 14 {
136                debug!("cannot estimate window scale: scale is too large");
137                return false;
138            }
139            if new_limit < fit_end_offset {
140                try_scale += 1;
141                new_limit = self.highest_acked + (unscaled << try_scale);
142            } else {
143                debug!("estimating window scale to be {try_scale}");
144                self.window_scale = try_scale;
145                self.state.set_limit(new_limit);
146                return true;
147            }
148        }
149    }
150
151    /// set initial sequence number
152    pub fn set_isn(&mut self, isn: u32, window_size: u16) {
153        self.initial_sequence_number = isn;
154        self.seq_offset = SeqOffset::Initial(isn);
155        // set seq window to sane initial values
156        self.seq_window_start = isn;
157        self.seq_window_end = self.seq_window_start.wrapping_add(SEQ_WINDOW_SIZE);
158        // update expected receive window
159        let window_size = (window_size as u64) << self.window_scale as u64;
160        if window_size < MAX_ALLOWED_BUFFER_SIZE {
161            trace!("got initial window size from handshake: {window_size}");
162            self.state.set_limit(window_size);
163        } else {
164            warn!("received window size in handshake is too large: {window_size}");
165            self.state.set_limit(MAX_ALLOWED_BUFFER_SIZE);
166        }
167    }
168
169    /// update seq_window and seq_offset based on current window, return whether
170    /// the value was in the current window and the absolute stream offset
171    pub fn update_offset(&mut self, number: u32, should_advance: bool) -> Option<u64> {
172        // ensure in range
173        if self.seq_window_start < self.seq_window_end {
174            // does not wrap
175            if !(number >= self.seq_window_start && number < self.seq_window_end) {
176                None
177            } else {
178                if should_advance && number - self.seq_window_start > SEQ_WINDOW_ADVANCE_THRESHOLD {
179                    // advance window
180                    let old_start = self.seq_window_start;
181                    self.seq_window_start = number - SEQ_WINDOW_ADVANCE_BY;
182                    self.seq_window_end = self.seq_window_start.wrapping_add(SEQ_WINDOW_SIZE);
183                    trace!(
184                        "advance seq_window {} -> {} (received seq {})",
185                        old_start,
186                        self.seq_window_start,
187                        number
188                    );
189                }
190                Some(self.seq_offset.compute_absolute(number))
191            }
192        } else if number < self.seq_window_start && number >= self.seq_window_end {
193            // does wrap, out of range
194            None
195        } else if number >= self.seq_window_start {
196            // at high section of window
197            if should_advance && number - self.seq_window_start > SEQ_WINDOW_ADVANCE_THRESHOLD {
198                // advance window
199                let old_start = self.seq_window_start;
200                self.seq_window_start = number - SEQ_WINDOW_ADVANCE_BY;
201                self.seq_window_end = self.seq_window_start.wrapping_add(SEQ_WINDOW_SIZE);
202                trace!(
203                    "advance seq_window {} -> {} (received seq {})",
204                    old_start,
205                    self.seq_window_start,
206                    number
207                );
208            }
209            Some(self.seq_offset.compute_absolute(number))
210        } else {
211            // at low section of window (sequence number has rolled over)
212            let bytes_from_start = number.wrapping_sub(self.seq_window_start);
213            // offset object to use for rolled over values
214            let rollover_offset = match self.seq_offset {
215                SeqOffset::Initial(isn) => SeqOffset::Subsequent((1 << 32) - isn as u64),
216                SeqOffset::Subsequent(off) => SeqOffset::Subsequent(off + (1 << 32)),
217            };
218            if should_advance && bytes_from_start > SEQ_WINDOW_ADVANCE_THRESHOLD {
219                // advance window
220                let old_start = self.seq_window_start;
221                self.seq_window_start = number.wrapping_sub(SEQ_WINDOW_ADVANCE_BY);
222                self.seq_window_end = self.seq_window_start.wrapping_add(SEQ_WINDOW_SIZE);
223                trace!(
224                    "advance seq_window {} -> {} (received seq {})",
225                    old_start,
226                    self.seq_window_start,
227                    number
228                );
229
230                if self.seq_window_start < self.seq_window_end {
231                    // seq_window rollover done, update seq_offset
232                    self.seq_offset = rollover_offset.clone();
233                    trace!("seq_window rollover over, advance seq_offset");
234                }
235            }
236            let offset = rollover_offset.compute_absolute(number);
237            Some(offset)
238        }
239    }
240
241    /// handle data packet in the forward direction
242    pub fn handle_data_packet(
243        &mut self,
244        sequence_number: u32,
245        mut data: &[u8],
246        extra: &PacketExtra,
247    ) -> bool {
248        let Some(offset) = self.update_offset(sequence_number, true) else {
249            warn!(
250                "received seq number {} outside of window ({} - {})",
251                sequence_number, self.seq_window_start, self.seq_window_end
252            );
253            return false;
254        };
255
256        let packet_end_offset = offset + data.len() as u64;
257        if packet_end_offset > self.state.window_limit {
258            // might have lost a packet or never got window_scale
259            debug!(
260                "got packet exceeding the original receiver's window limit: \
261                    seq: {}, offset: {}, len: {}, original window limit: {}",
262                sequence_number,
263                offset,
264                data.len(),
265                self.state.window_limit
266            );
267            // try to extend the window limit
268            if packet_end_offset - self.state.buffer_offset < MAX_ALLOWED_BUFFER_SIZE {
269                if !self.got_window_scale {
270                    if self.estimate_window_scale(packet_end_offset) {
271                        debug_assert!(self.state.window_limit >= packet_end_offset);
272                    } else {
273                        self.state.set_limit(packet_end_offset);
274                    }
275                } else {
276                    trace!("extending window limit due to out-of-window packet");
277                    self.state.set_limit(packet_end_offset);
278                }
279            } else {
280                let max_offset = self.state.buffer_offset + MAX_ALLOWED_BUFFER_SIZE;
281                let max_len = max_offset.saturating_sub(offset) as usize;
282                if max_len > 0 {
283                    warn!(
284                        "packet exceeds max buffer, dropping {} bytes",
285                        data.len() - max_len
286                    );
287                    data = &data[..max_len];
288                } else {
289                    warn!("packet exceeds max buffer, dropping packet");
290                    return false;
291                }
292            }
293        }
294
295        // read in the packet
296        let mut is_retransmit = false;
297        match self.state.receive_segment(offset, data) {
298            ReceiveSegmentResult::Duplicate => {
299                // probably a retransmit
300                self.retransmit_count += 1;
301                is_retransmit = true;
302                trace!(
303                    "handle_data_packet: got retransmit of {} bytes at seq {}, offset {}",
304                    data.len(),
305                    sequence_number,
306                    offset
307                );
308            }
309            ReceiveSegmentResult::ExceedsWindow => {
310                // should not happen, window limit is guarded
311                unreachable!();
312            }
313            ReceiveSegmentResult::Received => {
314                // all is well, probably
315                trace!(
316                    "handle_data_packet: got {} bytes at seq {}, offset {}",
317                    data.len(),
318                    sequence_number,
319                    offset
320                );
321            }
322        }
323
324        self.add_segment_info(SegmentInfo {
325            offset,
326            reverse_acked: self.reverse_acked,
327            extra: extra.clone(),
328            data: SegmentType::Data {
329                len: data.len(),
330                is_retransmit,
331            },
332        });
333
334        true
335    }
336
337    /// handle ack packet in the reverse direction
338    pub fn handle_ack_packet(
339        &mut self,
340        acknowledgment_number: u32,
341        window_size: u16,
342        extra: &PacketExtra,
343    ) -> bool {
344        let Some(offset) = self.update_offset(acknowledgment_number, true) else {
345            warn!(
346                "received ack number {} outside of window ({} - {})",
347                acknowledgment_number, self.seq_window_start, self.seq_window_end
348            );
349            return false;
350        };
351
352        if offset > self.highest_acked {
353            self.highest_acked = offset;
354            trace!("handle_ack_packet: highest ack is {offset}");
355        }
356
357        if let Some(final_seq) = self.state.final_offset {
358            // check if final data packet was acked
359            if self.highest_acked > final_seq {
360                self.has_ended = true;
361                debug!("handle_ack_packet: fin (offset {final_seq}) got ack (offset {offset})");
362            }
363        }
364
365        // set expected window limit
366        let real_window = (window_size as u32) << (self.window_scale as u32);
367        let limit = offset + real_window as u64;
368        trace!(
369            "handle_ack_packet: ack: {}, offset {}, win {}",
370            acknowledgment_number,
371            offset,
372            real_window
373        );
374
375        if limit > self.state.window_limit {
376            let new_buffer_size = limit - self.state.buffer_offset;
377            if new_buffer_size > MAX_ALLOWED_BUFFER_SIZE {
378                // would make buffer too large, either window too large (DoS?)
379                // or the buffer is not getting drained properly
380                warn!(
381                    "received ack packet which would result in a buffer size \
382                        exceeding the maximum allowed buffer size: \
383                        ack: {}, win: {}, win scale: {}, absolute window limit: {}",
384                    acknowledgment_number, window_size, self.window_scale, limit
385                );
386                self.state
387                    .set_limit(self.state.buffer_offset + MAX_ALLOWED_BUFFER_SIZE);
388            } else {
389                trace!(
390                    "received window increase: {} -> {} ({} bytes)",
391                    offset,
392                    limit,
393                    real_window
394                );
395                self.state.set_limit(limit);
396            }
397        }
398
399        self.add_segment_info(SegmentInfo {
400            offset,
401            reverse_acked: self.reverse_acked,
402            extra: extra.clone(),
403            data: SegmentType::Ack {
404                window: real_window as usize,
405            },
406        });
407
408        true
409    }
410
411    /// handle FIN packet
412    pub fn handle_fin_packet(
413        &mut self,
414        sequence_number: u32,
415        data_len: usize,
416        extra: &PacketExtra,
417    ) -> bool {
418        let Some(offset) = self.update_offset(sequence_number, true) else {
419            warn!(
420                "received fin with seq number {} outside of window ({} - {})",
421                sequence_number, self.seq_window_start, self.seq_window_end
422            );
423            return false;
424        };
425        let fin_offset = offset + data_len as u64;
426
427        match self.state.final_offset {
428            None => {
429                self.state.set_final_offset(fin_offset);
430                debug!(
431                    "handle_fin_packet: seq: {}, len: {}, final offset: {}",
432                    sequence_number,
433                    data_len,
434                    fin_offset
435                );
436            }
437            Some(prev_fin) => {
438                if fin_offset != prev_fin {
439                    warn!(
440                        "received duplicate FIN different from previous: prev: {}, now: {}",
441                        prev_fin, fin_offset
442                    );
443                }
444                trace!("handle_fin_packet: detected retransmitted FIN");
445                // otherwise it is just retransmit
446            }
447        }
448
449        self.add_segment_info(SegmentInfo {
450            offset,
451            reverse_acked: self.reverse_acked,
452            extra: extra.clone(),
453            data: SegmentType::Fin {
454                end_offset: fin_offset,
455            },
456        });
457        true
458    }
459
460    /// handle reset packet in established state
461    pub fn handle_rst_packet(&mut self, sequence_number: u32, extra: &PacketExtra) -> bool {
462        // we send reset packets to the aligned stream (i.e. if the packet is sent in
463        // the forward direction, then it is sent to the forward stream).
464        // to validate, compare sequence number of reset to highest_acked.
465        // do not update seq_window, as some middleboxes will generate reset packets
466        // with incorrect sequence numbers.
467        let Some(offset) = self.update_offset(sequence_number, false) else {
468            warn!(
469                "received reset with seq number {} outside of window ({} - {})",
470                sequence_number, self.seq_window_start, self.seq_window_end
471            );
472            return false;
473        };
474
475        if offset >= self.highest_acked.saturating_sub(RESET_MAX_LOOKBEHIND as u64)
476            && offset < self.highest_acked.saturating_add(RESET_MAX_LOOKAHEAD as u64)
477        {
478            debug!("handle_rst_packet: got reset at offset {offset}");
479            self.add_segment_info(SegmentInfo {
480                offset,
481                reverse_acked: self.reverse_acked,
482                extra: extra.clone(),
483                data: SegmentType::Rst,
484            });
485            true
486        } else {
487            warn!(
488                "got likely invalid reset packet at offset {} (highest acked {}, seq {})",
489                offset, self.highest_acked, sequence_number
490            );
491            false
492        }
493    }
494
495    /// add an info object to segments_info
496    pub fn add_segment_info(&mut self, info: SegmentInfo) -> bool {
497        if self.segments_info.len() < MAX_SEGMENTS_INFO_COUNT {
498            self.segments_info.push(info);
499            true
500        } else {
501            self.segments_info_dropped += 1;
502            false
503        }
504    }
505
506    /// pop and read segment info until offset, adding to vec.
507    /// if `end_offset` is None, read everything
508    pub fn read_segments_until(&mut self, end_offset: Option<u64>, in_segments: &mut Vec<SegmentInfo>) {
509        loop {
510            let Some(info_peek) = self.segments_info.peek() else {
511                break;
512            };
513            if let Some(end_offset) = end_offset {
514                if info_peek.offset >= end_offset {
515                    break;
516                }
517            }
518
519            in_segments.push(self.segments_info.pop().unwrap());
520        }
521    }
522
523    /// read gaps in buffer in a given range, adding to vec and accounting in gaps_length
524    pub fn read_gaps(&mut self, range: Range<u64>, in_gaps: &mut Vec<Range<u64>>) {
525        for gap in self.state.received.range_complement(range) {
526            trace!("read_gaps: gap: {} .. {}", gap.start, gap.end);
527            in_gaps.push(gap.clone());
528            self.gaps_length += gap.end - gap.start;
529        }
530    }
531
532    /// read state until offset
533    pub fn read_next<T>(
534        &mut self,
535        end_offset: u64,
536        in_segments: &mut Vec<SegmentInfo>,
537        in_gaps: &mut Vec<Range<u64>>,
538        read_fn: impl FnOnce(RingBufSlice<'_, u8>) -> T,
539    ) -> Option<T> {
540        let start_offset = self.state.buffer_offset;
541        if end_offset < start_offset {
542            warn!("requested read of range that no longer exists");
543            return None;
544        }
545        if end_offset == start_offset {
546            // don't return zero-length reads
547            return None;
548        }
549        if (end_offset - start_offset) as usize > self.state.buffer.len() {
550            warn!("requested read of range past end of buffer");
551            return None;
552        }
553        self.read_segments_until(Some(end_offset), in_segments);
554        self.read_gaps(start_offset..end_offset, in_gaps);
555        // assume gaps don't exist
556        self.state.received.insert_range(start_offset..end_offset);
557        // acquire slice
558        let Some(slice) = self.state.read_segment(start_offset..end_offset) else {
559            panic!("InboundStreamState says range is not available");
560        };
561        let ret = read_fn(slice);
562        // advance backing buffer
563        self.state.advance_buffer(end_offset);
564        Some(ret)
565    }
566}
567
568impl Default for Stream {
569    fn default() -> Self {
570        Self::new()
571    }
572}
573
/// Determine whether `value` lies in the inclusive range
/// `[base - before, base + after]`, with arithmetic modulo 2^32 (the range
/// may wrap around from `u32::MAX` back to 0).
///
/// # Panics
///
/// Panics if the range would cover more than 2^32 values
/// (`before + after >= 2^32`), i.e. it would overlap itself.
pub fn in_range_wrapping(base: u32, before: u32, after: u32, value: u32) -> bool {
    // checking only "both bounds wrap" misses oversized ranges where a single
    // bound wraps past the other, which would silently give wrong answers
    if u64::from(before) + u64::from(after) >= 1 << 32 {
        panic!("requested range too large");
    }

    let begin = base.wrapping_sub(before);
    let end = base.wrapping_add(after);
    if begin <= end {
        (begin..=end).contains(&value)
    } else {
        // range wraps past u32::MAX
        value >= begin || value <= end
    }
}
588
589/// information on each segment received
590#[derive(Clone)]
591pub struct SegmentInfo {
592    /// offset into stream of this segment
593    pub offset: u64,
594    /// highest acked offset of opposite stream
595    pub reverse_acked: u64,
596    /// extra metadata from packet
597    pub extra: PacketExtra,
598    /// segment type and type-specific info
599    pub data: SegmentType,
600}
601
/// Type-specific information for each recorded segment.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SegmentType {
    /// payload data
    Data {
        /// number of payload bytes recorded (after any truncation to the buffer cap)
        len: usize,
        /// whether the payload duplicated data that was already received
        is_retransmit: bool,
    },
    /// acknowledgment carrying a receive window
    Ack {
        /// advertised window, already shifted by the window scale
        window: usize,
    },
    /// end of stream
    Fin {
        /// absolute offset at which the stream ends
        end_offset: u64,
    },
    /// connection reset
    Rst,
}
610
611impl Ord for SegmentInfo {
612    /// reversed compare of offset (we want pop to get the smallest offset)
613    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
614        use std::cmp::Ordering;
615        match self.offset.cmp(&other.offset) {
616            Ordering::Less => Ordering::Greater,
617            Ordering::Equal => match self.reverse_acked.cmp(&other.reverse_acked) {
618                // sort by reverse_acked if equal
619                Ordering::Less => Ordering::Greater,
620                Ordering::Equal => Ordering::Equal,
621                Ordering::Greater => Ordering::Less,
622            },
623            Ordering::Greater => Ordering::Less,
624        }
625    }
626}
627
628impl PartialOrd for SegmentInfo {
629    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
630        Some(self.cmp(other))
631    }
632}
633
634impl PartialEq for SegmentInfo {
635    fn eq(&self, other: &Self) -> bool {
636        self.offset == other.offset && self.reverse_acked == other.reverse_acked
637    }
638}
639
640impl Eq for SegmentInfo {}
641
/// Mapping from a raw 32-bit sequence number to an absolute stream offset.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SeqOffset {
    /// before any rollover: negative offset by the initial sequence number
    /// (`offset = seq - isn`)
    Initial(u32),
    /// after a rollover: positive offset added to the sequence number
    /// (`offset = seq + value`)
    Subsequent(u64),
}
650
651impl SeqOffset {
652    pub fn compute_absolute(&self, number: u32) -> u64 {
653        match self {
654            SeqOffset::Initial(isn) => {
655                debug_assert!(number >= *isn);
656                (number - isn) as u64
657            }
658            SeqOffset::Subsequent(offset) => number as u64 + offset,
659        }
660    }
661}