use std::collections::BinaryHeap;
use std::ops::Range;
use kinesin_rdt::common::ring_buffer::RingBufSlice;
use kinesin_rdt::stream::inbound::{ReceiveSegmentResult, StreamInboundState};
use tracing::{debug, trace, warn};
use crate::PacketExtra;
/// Width (1 GiB) of the range of sequence numbers accepted by `Stream::update_offset`.
pub const SEQ_WINDOW_SIZE: u32 = 1 << 30;
/// When a sequence number lands more than this many bytes (512 MiB) past the
/// start of the accepted range, the range is moved forward.
pub const SEQ_WINDOW_ADVANCE_THRESHOLD: u32 = 1 << 29;
/// After a move, the accepted range starts this many bytes (256 MiB) behind the
/// sequence number that triggered it.
pub const SEQ_WINDOW_ADVANCE_BY: u32 = 1 << 28;
/// Cap (128 MiB) on how far past the start of the buffer data may be accepted.
pub const MAX_ALLOWED_BUFFER_SIZE: u64 = 1 << 27;
/// Maximum number (131072) of queued `SegmentInfo` records; further records are dropped.
pub const MAX_SEGMENTS_INFO_COUNT: usize = 1 << 17;
/// RST offsets at or beyond `highest_acked` plus this many bytes (16 MiB) are rejected.
pub const RESET_MAX_LOOKAHEAD: u32 = 1 << 24;
/// RST offsets more than this many bytes (256 KiB) below `highest_acked` are rejected.
pub const RESET_MAX_LOOKBEHIND: u32 = 1 << 18;
/// Reassembly and bookkeeping state for one direction of a TCP connection.
///
/// 32-bit sequence/acknowledgment numbers are translated to 64-bit absolute
/// stream offsets (`update_offset`), payload bytes are buffered in `state`, and
/// per-segment metadata is queued in `segments_info`.
pub struct Stream {
    /// Initial sequence number recorded by `set_isn`.
    pub initial_sequence_number: u32,
    /// Current translation from sequence numbers to absolute offsets; replaced
    /// when the accepted sequence range rolls over past 2^32.
    pub seq_offset: SeqOffset,
    /// Shift applied to advertised window sizes; never set above 14.
    pub window_scale: u8,
    /// Set by `set_window_scale`; while false, `handle_data_packet` may try to
    /// estimate the scale from out-of-window data.
    pub got_window_scale: bool,
    /// Buffered payload, received-range bookkeeping and the window limit.
    pub state: StreamInboundState,
    /// First (inclusive) sequence number currently accepted.
    pub seq_window_start: u32,
    /// End (exclusive) of the accepted sequence range; numerically below
    /// `seq_window_start` while the range wraps past 2^32.
    pub seq_window_end: u32,
    /// Highest absolute offset acknowledged so far.
    pub highest_acked: u64,
    /// Copied into every recorded `SegmentInfo`; the methods of `Stream` itself
    /// never modify it. NOTE(review): presumably maintained by the caller from
    /// the opposite direction's acks — confirm.
    pub reverse_acked: u64,
    /// Not read or written by the methods of `Stream` itself.
    /// NOTE(review): presumably set by the caller once a reset is accepted — confirm.
    pub had_reset: bool,
    /// Set once an acknowledgment beyond the FIN offset has been observed.
    pub has_ended: bool,
    /// Total length of all gaps reported through `read_gaps`.
    pub gaps_length: u64,
    /// Number of data segments the buffer reported as duplicates.
    pub retransmit_count: usize,
    /// Pending per-segment metadata, popped in ascending offset order.
    pub segments_info: BinaryHeap<SegmentInfo>,
    /// Number of metadata records discarded because `segments_info` was full.
    pub segments_info_dropped: usize,
}
impl Stream {
    /// Creates an empty stream.
    ///
    /// Until [`Stream::set_isn`] is called both `seq_window_start` and
    /// `seq_window_end` are 0, a state in which [`Stream::update_offset`]
    /// rejects no sequence number and computes offsets relative to 0.
    pub fn new() -> Self {
        Stream {
            initial_sequence_number: 0,
            seq_offset: SeqOffset::Initial(0),
            window_scale: 0,
            got_window_scale: false,
            state: StreamInboundState::new(0, true),
            seq_window_start: 0,
            seq_window_end: 0,
            highest_acked: 0,
            reverse_acked: 0,
            had_reset: false,
            has_ended: false,
            gaps_length: 0,
            retransmit_count: 0,
            segments_info: BinaryHeap::new(),
            segments_info_dropped: 0,
        }
    }

    /// Bytes between the start of the buffer and `state.max_contiguous_offset()`,
    /// or 0 when that returns `None`.
    pub fn readable_buffered_length(&self) -> usize {
        self.state
            .max_contiguous_offset()
            .map_or(0, |highest_readable| {
                (highest_readable - self.state.buffer_offset) as usize
            })
    }

    /// Length of the underlying receive buffer (`state.buffer.len()`).
    pub fn total_buffered_length(&self) -> usize {
        self.state.buffer.len()
    }

    /// Absolute offset of the first byte still held in the buffer.
    pub fn buffer_start(&self) -> u64 {
        self.state.buffer_offset
    }

    /// Records the negotiated window scale shift.
    ///
    /// Shifts above 14 are rejected with a warning and `false` is returned;
    /// otherwise the shift is stored, `got_window_scale` is set, and `true`
    /// is returned.
    pub fn set_window_scale(&mut self, window_scale: u8) -> bool {
        if window_scale > 14 {
            warn!("rejected oversized window_scale value: {window_scale}");
            false
        } else {
            self.window_scale = window_scale;
            self.got_window_scale = true;
            true
        }
    }

    /// Tries to infer the window scale when a segment ending at `fit_end_offset`
    /// exceeds the current window limit.
    ///
    /// The currently available window (`window_limit - highest_acked`) is
    /// un-scaled with the current shift, then re-scaled with increasing shifts
    /// (up to and including 14) until `highest_acked + window` reaches
    /// `fit_end_offset`. On success the shift is stored, the window limit is
    /// raised to that value, and `true` is returned.
    ///
    /// Returns `false` if the available window is under 8 bytes, the un-scaled
    /// window is 0, or no shift up to 14 is large enough.
    pub fn estimate_window_scale(&mut self, fit_end_offset: u64) -> bool {
        debug_assert!(fit_end_offset > self.state.window_limit);
        // `window_limit` can trail `highest_acked` when the ack path clamps the
        // limit at the buffer cap, so saturate rather than underflow.
        let window_available = self.state.window_limit.saturating_sub(self.highest_acked);
        trace!("available window: {window_available}");
        if window_available < 8 {
            debug!("cannot estimate window scale (available window: {window_available})");
            return false;
        }
        let unscaled = window_available >> self.window_scale;
        if unscaled == 0 {
            debug!("cannot estimate window scale: unscaled window size is 0");
            return false;
        }
        let mut try_scale = self.window_scale;
        loop {
            let new_limit = self.highest_acked + (unscaled << try_scale);
            // Test the candidate before giving up so shift 14 (the RFC 7323
            // maximum) can still be chosen.
            if new_limit >= fit_end_offset {
                debug!("estimating window scale to be {try_scale}");
                self.window_scale = try_scale;
                self.state.set_limit(new_limit);
                return true;
            }
            if try_scale >= 14 {
                debug!("cannot estimate window scale: scale is too large");
                return false;
            }
            try_scale += 1;
        }
    }

    /// Records the initial sequence number and the initial window.
    ///
    /// The accepted sequence range becomes `[isn, isn + SEQ_WINDOW_SIZE)`
    /// (wrapping), and the window limit is set to `window_size << window_scale`,
    /// capped at `MAX_ALLOWED_BUFFER_SIZE`.
    ///
    /// NOTE(review): RFC 7323 states the window field of SYN segments is never
    /// scaled; if `window_size` comes from a SYN/SYN-ACK while `window_scale` is
    /// already set, this over-estimates the window — confirm with the caller.
    pub fn set_isn(&mut self, isn: u32, window_size: u16) {
        self.initial_sequence_number = isn;
        self.seq_offset = SeqOffset::Initial(isn);
        self.seq_window_start = isn;
        self.seq_window_end = self.seq_window_start.wrapping_add(SEQ_WINDOW_SIZE);
        let window_size = (window_size as u64) << self.window_scale as u64;
        if window_size < MAX_ALLOWED_BUFFER_SIZE {
            trace!("got initial window size from handshake: {window_size}");
            self.state.set_limit(window_size);
        } else {
            warn!("received window size in handshake is too large: {window_size}");
            self.state.set_limit(MAX_ALLOWED_BUFFER_SIZE);
        }
    }

    /// Translates a 32-bit sequence (or acknowledgment) number into an absolute
    /// stream offset.
    ///
    /// Returns `None` when `number` lies outside
    /// `[seq_window_start, seq_window_end)`, a range that may wrap past 2^32.
    /// When `should_advance` is set and `number` is more than
    /// `SEQ_WINDOW_ADVANCE_THRESHOLD` past the range start, the range is moved
    /// forward; if that move takes the range out of its wrapped state,
    /// `seq_offset` switches to the post-rollover mapping.
    pub fn update_offset(&mut self, number: u32, should_advance: bool) -> Option<u64> {
        let start = self.seq_window_start;
        let end = self.seq_window_end;
        let outside = if start < end {
            number < start || number >= end
        } else {
            // Wrapped (or not yet initialised, start == end == 0) range: only
            // the hole between `end` and `start` is rejected.
            number < start && number >= end
        };
        if outside {
            return None;
        }

        if number >= start {
            // Same 2^32 epoch as the range start: the current mapping applies.
            self.maybe_advance_seq_window(number, should_advance);
            return Some(self.seq_offset.compute_absolute(number));
        }

        // `number` has wrapped past 2^32 relative to the range start.
        let rollover_offset = match self.seq_offset {
            SeqOffset::Initial(isn) => SeqOffset::Subsequent((1 << 32) - isn as u64),
            SeqOffset::Subsequent(off) => SeqOffset::Subsequent(off + (1 << 32)),
        };
        if self.maybe_advance_seq_window(number, should_advance)
            && self.seq_window_start < self.seq_window_end
        {
            self.seq_offset = rollover_offset.clone();
            trace!("seq_window rollover over, advance seq_offset");
        }
        Some(rollover_offset.compute_absolute(number))
    }

    /// Moves the accepted sequence range so it starts `SEQ_WINDOW_ADVANCE_BY`
    /// behind `number`, if `should_advance` is set and `number` is more than
    /// `SEQ_WINDOW_ADVANCE_THRESHOLD` past the current start. Returns whether
    /// the range moved.
    fn maybe_advance_seq_window(&mut self, number: u32, should_advance: bool) -> bool {
        let bytes_from_start = number.wrapping_sub(self.seq_window_start);
        if !should_advance || bytes_from_start <= SEQ_WINDOW_ADVANCE_THRESHOLD {
            return false;
        }
        let old_start = self.seq_window_start;
        self.seq_window_start = number.wrapping_sub(SEQ_WINDOW_ADVANCE_BY);
        self.seq_window_end = self.seq_window_start.wrapping_add(SEQ_WINDOW_SIZE);
        trace!(
            "advance seq_window {} -> {} (received seq {})",
            old_start,
            self.seq_window_start,
            number
        );
        true
    }

    /// Processes a data segment and buffers its payload.
    ///
    /// If the segment ends past the current window limit, the limit is grown
    /// (estimating the window scale first when none was negotiated) as long as
    /// the buffered range stays within `MAX_ALLOWED_BUFFER_SIZE`. Otherwise the
    /// payload is truncated at that cap and the limit is extended to the cap.
    /// A `SegmentInfo` record is queued for the (possibly truncated) segment.
    ///
    /// Returns `false` if the sequence number is outside the accepted range or
    /// the whole payload lies beyond the buffer cap.
    pub fn handle_data_packet(
        &mut self,
        sequence_number: u32,
        mut data: &[u8],
        extra: &PacketExtra,
    ) -> bool {
        let Some(offset) = self.update_offset(sequence_number, true) else {
            warn!(
                "received seq number {} outside of window ({} - {})",
                sequence_number, self.seq_window_start, self.seq_window_end
            );
            return false;
        };

        let packet_end_offset = offset + data.len() as u64;
        if packet_end_offset > self.state.window_limit {
            debug!(
                "got packet exceeding the original receiver's window limit: \
                seq: {}, offset: {}, len: {}, original window limit: {}",
                sequence_number,
                offset,
                data.len(),
                self.state.window_limit
            );
            // `<=` matches `handle_ack_packet`, which allows a buffer of
            // exactly MAX_ALLOWED_BUFFER_SIZE.
            if packet_end_offset - self.state.buffer_offset <= MAX_ALLOWED_BUFFER_SIZE {
                if self.got_window_scale {
                    trace!("extending window limit due to out-of-window packet");
                    self.state.set_limit(packet_end_offset);
                } else if self.estimate_window_scale(packet_end_offset) {
                    debug_assert!(self.state.window_limit >= packet_end_offset);
                } else {
                    self.state.set_limit(packet_end_offset);
                }
            } else {
                let max_offset = self.state.buffer_offset + MAX_ALLOWED_BUFFER_SIZE;
                let max_len = max_offset.saturating_sub(offset) as usize;
                if max_len == 0 {
                    warn!("packet exceeds max buffer, dropping packet");
                    return false;
                }
                warn!(
                    "packet exceeds max buffer, dropping {} bytes",
                    data.len() - max_len
                );
                data = &data[..max_len];
                // The truncated segment ends at `max_offset`, which may still be
                // past the window limit; without this, receive_segment reports
                // ExceedsWindow and the match below panics.
                if max_offset > self.state.window_limit {
                    self.state.set_limit(max_offset);
                }
            }
        }

        let is_retransmit = match self.state.receive_segment(offset, data) {
            ReceiveSegmentResult::Duplicate => {
                self.retransmit_count += 1;
                trace!(
                    "handle_data_packet: got retransmit of {} bytes at seq {}, offset {}",
                    data.len(),
                    sequence_number,
                    offset
                );
                true
            }
            ReceiveSegmentResult::ExceedsWindow => {
                // Every path above ensures the window limit covers the segment.
                unreachable!();
            }
            ReceiveSegmentResult::Received => {
                trace!(
                    "handle_data_packet: got {} bytes at seq {}, offset {}",
                    data.len(),
                    sequence_number,
                    offset
                );
                false
            }
        };

        self.add_segment_info(SegmentInfo {
            offset,
            reverse_acked: self.reverse_acked,
            extra: extra.clone(),
            data: SegmentType::Data {
                len: data.len(),
                is_retransmit,
            },
        });
        true
    }

    /// Processes an acknowledgment for this stream's data.
    ///
    /// Updates `highest_acked`, sets `has_ended` once the ack passes the FIN
    /// offset, and raises the window limit to `ack offset + scaled window`. The
    /// raise is capped at `buffer_offset + MAX_ALLOWED_BUFFER_SIZE`, and the
    /// limit is never lowered. A `SegmentInfo` record is queued.
    ///
    /// Returns `false` if the ack number is outside the accepted range.
    pub fn handle_ack_packet(
        &mut self,
        acknowledgment_number: u32,
        window_size: u16,
        extra: &PacketExtra,
    ) -> bool {
        let Some(offset) = self.update_offset(acknowledgment_number, true) else {
            warn!(
                "received ack number {} outside of window ({} - {})",
                acknowledgment_number, self.seq_window_start, self.seq_window_end
            );
            return false;
        };

        if offset > self.highest_acked {
            self.highest_acked = offset;
            trace!("handle_ack_packet: highest ack is {offset}");
        }

        if let Some(final_seq) = self.state.final_offset {
            // The FIN consumes one sequence number, so its ack is final_seq + 1.
            if self.highest_acked > final_seq {
                self.has_ended = true;
                debug!("handle_ack_packet: fin (offset {final_seq}) got ack (offset {offset})");
            }
        }

        // window_size (u16) << scale (<= 14) fits in u32.
        let real_window = (window_size as u32) << (self.window_scale as u32);
        let limit = offset + real_window as u64;
        trace!(
            "handle_ack_packet: ack: {}, offset {}, win {}",
            acknowledgment_number,
            offset,
            real_window
        );

        if limit > self.state.window_limit {
            let new_buffer_size = limit - self.state.buffer_offset;
            if new_buffer_size > MAX_ALLOWED_BUFFER_SIZE {
                warn!(
                    "received ack packet which would result in a buffer size \
                    exceeding the maximum allowed buffer size: \
                    ack: {}, win: {}, win scale: {}, absolute window limit: {}",
                    acknowledgment_number, window_size, self.window_scale, limit
                );
                let capped = self.state.buffer_offset + MAX_ALLOWED_BUFFER_SIZE;
                // Window scale estimation may already have put the limit past
                // the cap; clamping must not shrink it.
                if capped > self.state.window_limit {
                    self.state.set_limit(capped);
                }
            } else {
                trace!(
                    "received window increase: {} -> {} ({} bytes)",
                    offset,
                    limit,
                    real_window
                );
                self.state.set_limit(limit);
            }
        }

        self.add_segment_info(SegmentInfo {
            offset,
            reverse_acked: self.reverse_acked,
            extra: extra.clone(),
            data: SegmentType::Ack {
                window: real_window as usize,
            },
        });
        true
    }

    /// Processes a FIN segment carrying `data_len` payload bytes.
    ///
    /// The first FIN fixes the final offset (`segment offset + data_len`);
    /// later FINs are logged as retransmits, with a warning if their end offset
    /// differs. A `SegmentInfo` record is queued either way.
    ///
    /// Returns `false` if the sequence number is outside the accepted range.
    pub fn handle_fin_packet(
        &mut self,
        sequence_number: u32,
        data_len: usize,
        extra: &PacketExtra,
    ) -> bool {
        let Some(offset) = self.update_offset(sequence_number, true) else {
            warn!(
                "received fin with seq number {} outside of window ({} - {})",
                sequence_number, self.seq_window_start, self.seq_window_end
            );
            return false;
        };

        let fin_offset = offset + data_len as u64;
        match self.state.final_offset {
            None => {
                self.state.set_final_offset(fin_offset);
                debug!(
                    "handle_fin_packet: seq: {}, len: {}, final offset: {}",
                    sequence_number,
                    data_len,
                    fin_offset
                );
            }
            Some(prev_fin) => {
                if fin_offset != prev_fin {
                    warn!(
                        "received duplicate FIN different from previous: prev: {}, now: {}",
                        prev_fin, fin_offset
                    );
                }
                trace!("handle_fin_packet: detected retransmitted FIN");
            }
        }

        self.add_segment_info(SegmentInfo {
            offset,
            reverse_acked: self.reverse_acked,
            extra: extra.clone(),
            data: SegmentType::Fin {
                end_offset: fin_offset,
            },
        });
        true
    }

    /// Processes a RST segment.
    ///
    /// The sequence range is not advanced for resets. A reset is accepted
    /// (queued as a `SegmentInfo` record, returning `true`) only if its offset
    /// lies in
    /// `[highest_acked - RESET_MAX_LOOKBEHIND, highest_acked + RESET_MAX_LOOKAHEAD)`
    /// (saturating); otherwise it is logged and `false` is returned.
    pub fn handle_rst_packet(&mut self, sequence_number: u32, extra: &PacketExtra) -> bool {
        let Some(offset) = self.update_offset(sequence_number, false) else {
            warn!(
                "received reset with seq number {} outside of window ({} - {})",
                sequence_number, self.seq_window_start, self.seq_window_end
            );
            return false;
        };

        let lowest_plausible = self
            .highest_acked
            .saturating_sub(RESET_MAX_LOOKBEHIND as u64);
        let highest_plausible = self
            .highest_acked
            .saturating_add(RESET_MAX_LOOKAHEAD as u64);
        if !(lowest_plausible..highest_plausible).contains(&offset) {
            warn!(
                "got likely invalid reset packet at offset {} (highest acked {}, seq {})",
                offset, self.highest_acked, sequence_number
            );
            return false;
        }

        debug!("handle_rst_packet: got reset at offset {offset}");
        self.add_segment_info(SegmentInfo {
            offset,
            reverse_acked: self.reverse_acked,
            extra: extra.clone(),
            data: SegmentType::Rst,
        });
        true
    }

    /// Queues `info` unless `MAX_SEGMENTS_INFO_COUNT` records are already
    /// pending, in which case `segments_info_dropped` is incremented instead.
    /// Returns whether the record was queued.
    pub fn add_segment_info(&mut self, info: SegmentInfo) -> bool {
        if self.segments_info.len() < MAX_SEGMENTS_INFO_COUNT {
            self.segments_info.push(info);
            true
        } else {
            self.segments_info_dropped += 1;
            false
        }
    }

    /// Moves queued segment records into `in_segments` in ascending offset
    /// order. Stops at the first record whose offset is `>= end_offset`, or
    /// drains everything when `end_offset` is `None`.
    pub fn read_segments_until(&mut self, end_offset: Option<u64>, in_segments: &mut Vec<SegmentInfo>) {
        while let Some(next) = self.segments_info.peek() {
            if matches!(end_offset, Some(end) if next.offset >= end) {
                break;
            }
            in_segments.push(
                self.segments_info
                    .pop()
                    .expect("peeked segment must still be in the heap"),
            );
        }
    }

    /// Appends every sub-range of `range` reported by
    /// `state.received.range_complement` (i.e. not marked received) to
    /// `in_gaps`, and adds their lengths to `gaps_length`.
    pub fn read_gaps(&mut self, range: Range<u64>, in_gaps: &mut Vec<Range<u64>>) {
        for gap in self.state.received.range_complement(range) {
            trace!("read_gaps: gap: {} .. {}", gap.start, gap.end);
            in_gaps.push(gap.clone());
            self.gaps_length += gap.end - gap.start;
        }
    }

    /// Consumes buffered bytes in `[buffer_offset, end_offset)`.
    ///
    /// Steps, in order:
    /// 1. Drains segment records with offsets below `end_offset` into `in_segments`.
    /// 2. Appends gaps in the range to `in_gaps`.
    /// 3. Marks the whole range as received.
    /// 4. Passes the buffer slice to `read_fn`.
    /// 5. Advances the buffer to `end_offset`.
    ///
    /// Returns `None` (without side effects) if `end_offset` is behind or equal
    /// to the buffer start, or past the end of the buffer; otherwise returns
    /// `read_fn`'s result.
    ///
    /// # Panics
    /// If the inbound state refuses to return the range after it was marked received.
    pub fn read_next<T>(
        &mut self,
        end_offset: u64,
        in_segments: &mut Vec<SegmentInfo>,
        in_gaps: &mut Vec<Range<u64>>,
        read_fn: impl FnOnce(RingBufSlice<'_, u8>) -> T,
    ) -> Option<T> {
        let start_offset = self.state.buffer_offset;
        if end_offset < start_offset {
            warn!("requested read of range that no longer exists");
            return None;
        }
        if end_offset == start_offset {
            return None;
        }
        if (end_offset - start_offset) as usize > self.state.buffer.len() {
            warn!("requested read of range past end of buffer");
            return None;
        }

        self.read_segments_until(Some(end_offset), in_segments);
        self.read_gaps(start_offset..end_offset, in_gaps);
        // Gaps were recorded above; mark them received so the slice can be read.
        self.state.received.insert_range(start_offset..end_offset);

        let Some(slice) = self.state.read_segment(start_offset..end_offset) else {
            panic!("InboundStreamState says range is not available");
        };
        let ret = read_fn(slice);
        self.state.advance_buffer(end_offset);
        Some(ret)
    }
}
impl Default for Stream {
fn default() -> Self {
Self::new()
}
}
/// Returns whether `value` lies in the inclusive, wrapping range
/// `[base - before, base + after]` of the `u32` ring.
///
/// # Panics
/// If the range would span more than 2^32 values (`before + after >= 2^32`).
/// Such a range overlaps itself, and the begin/end comparison cannot
/// represent it.
pub fn in_range_wrapping(base: u32, before: u32, after: u32, value: u32) -> bool {
    // The old guard (both ends wrapping) missed oversized ranges where only
    // one end wraps, which then silently returned wrong answers.
    if u64::from(before) + u64::from(after) > u64::from(u32::MAX) {
        panic!("requested range too large");
    }
    let begin = base.wrapping_sub(before);
    let end = base.wrapping_add(after);
    if begin <= end {
        (begin..=end).contains(&value)
    } else {
        // Range straddles the 0 / u32::MAX boundary.
        value >= begin || value <= end
    }
}
/// Metadata recorded for one observed segment (data, ACK, FIN or RST).
///
/// Ordering and equality look only at `offset` and `reverse_acked`.
#[derive(Clone)]
pub struct SegmentInfo {
    /// Absolute stream offset of the segment: the translated sequence number
    /// for data/FIN/RST, the translated acknowledgment number for ACKs.
    pub offset: u64,
    /// Snapshot of `Stream::reverse_acked` taken when the record was created.
    pub reverse_acked: u64,
    /// Caller-supplied per-packet metadata, cloned from the handler argument.
    pub extra: PacketExtra,
    /// Kind-specific details of the segment.
    pub data: SegmentType,
}
/// Kind-specific payload of a `SegmentInfo` record.
#[derive(Clone)]
pub enum SegmentType {
    /// Payload segment: `len` bytes handed to the reassembly buffer (after any
    /// truncation); `is_retransmit` is true when the buffer reported a duplicate.
    Data { len: usize, is_retransmit: bool },
    /// Acknowledgment; `window` is the advertised window after applying the
    /// window scale shift.
    Ack { window: usize },
    /// FIN; `end_offset` is the segment's offset plus its payload length.
    Fin { end_offset: u64 },
    /// Accepted reset.
    Rst,
}
impl Ord for SegmentInfo {
    /// Reversed ordering by `(offset, reverse_acked)`.
    ///
    /// `BinaryHeap` is a max-heap, so reversing the comparison makes it yield
    /// the lowest offset first, with ties going to the lowest `reverse_acked`.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other
            .offset
            .cmp(&self.offset)
            .then_with(|| other.reverse_acked.cmp(&self.reverse_acked))
    }
}
impl PartialOrd for SegmentInfo {
    /// Always `Some`: delegates to the total order defined by [`Ord`].
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl PartialEq for SegmentInfo {
    /// Consistent with [`Ord`]: compares only `offset` and `reverse_acked`,
    /// ignoring `extra` and `data`.
    fn eq(&self, other: &Self) -> bool {
        (self.offset, self.reverse_acked) == (other.offset, other.reverse_acked)
    }
}

/// The key fields are plain integers, so equality is a full equivalence relation.
impl Eq for SegmentInfo {}
/// Translation from 32-bit TCP sequence numbers to 64-bit absolute offsets.
#[derive(Clone)]
pub enum SeqOffset {
    /// No rollover yet: the held initial sequence number is subtracted.
    Initial(u32),
    /// After one or more rollovers: the held value is added.
    Subsequent(u64),
}

impl SeqOffset {
    /// Converts `number` to an absolute offset under this mapping.
    ///
    /// For `Initial(isn)`, `number` must not be below `isn` (checked only by a
    /// debug assertion).
    pub fn compute_absolute(&self, number: u32) -> u64 {
        match *self {
            SeqOffset::Initial(isn) => {
                debug_assert!(number >= isn);
                u64::from(number - isn)
            }
            SeqOffset::Subsequent(base) => base + u64::from(number),
        }
    }
}