1use std::collections::BinaryHeap;
2use std::ops::Range;
3
4use kinesin_rdt::common::ring_buffer::RingBufSlice;
5use kinesin_rdt::stream::inbound::{ReceiveSegmentResult, StreamInboundState};
6use tracing::{debug, trace, warn};
7
8use crate::PacketExtra;
9
/// Width of the accepted sequence-number window (1 GiB): `seq_window_end` is
/// always `seq_window_start + SEQ_WINDOW_SIZE`, with wrapping arithmetic.
pub const SEQ_WINDOW_SIZE: u32 = 1024 << 20;

/// Distance (512 MiB) a received number must exceed past `seq_window_start`
/// before the sequence window is moved forward.
pub const SEQ_WINDOW_ADVANCE_THRESHOLD: u32 = 512 << 20;

/// When the window advances, its new start is placed this far (256 MiB)
/// behind the number that triggered the advance.
pub const SEQ_WINDOW_ADVANCE_BY: u32 = 256 << 20;

/// Upper bound (128 MiB) on how far past the buffer start the window limit /
/// buffered data is allowed to reach.
pub const MAX_ALLOWED_BUFFER_SIZE: u64 = 128 << 20;

/// Maximum number of queued `SegmentInfo` entries (128 Ki); further entries
/// are counted as dropped instead of stored.
pub const MAX_SEGMENTS_INFO_COUNT: usize = 128 << 10;

/// A reset is only accepted if its offset is less than this far (16 MiB)
/// ahead of the highest acknowledged offset.
pub const RESET_MAX_LOOKAHEAD: u32 = 16 << 20;

/// A reset is only accepted if its offset is at most this far (256 KiB)
/// behind the highest acknowledged offset.
pub const RESET_MAX_LOOKBEHIND: u32 = 256 << 10;
24
/// Reassembly and bookkeeping state for a single byte stream identified by
/// 32-bit wire sequence numbers (TCP-style: handshake ISN, ACK, FIN, RST).
///
/// Wire sequence numbers are translated into 64-bit absolute stream offsets
/// via `seq_offset`; received bytes are stored in `state`.
pub struct Stream {
    /// Initial sequence number as recorded by `set_isn`.
    pub initial_sequence_number: u32,
    /// Translation from wire sequence numbers to absolute stream offsets.
    pub seq_offset: SeqOffset,
    /// Left shift applied to 16-bit window sizes (0..=14).
    pub window_scale: u8,
    /// `true` once `set_window_scale` has accepted an explicit value; while
    /// `false`, `handle_data_packet` may try to estimate the scale instead.
    pub got_window_scale: bool,
    /// Underlying inbound stream state (buffer, received ranges, window
    /// limit, final offset).
    pub state: StreamInboundState,
    /// First wire sequence number currently accepted (inclusive).
    pub seq_window_start: u32,
    /// End of the accepted wire sequence window (exclusive); may be
    /// numerically smaller than `seq_window_start` when the window wraps.
    pub seq_window_end: u32,
    /// Highest absolute offset seen in an acknowledgment for this stream.
    pub highest_acked: u64,
    /// Copied into every recorded `SegmentInfo`; never written by this type's
    /// own methods. NOTE(review): presumably the highest offset acknowledged
    /// in the opposite direction, maintained by the caller - confirm.
    pub reverse_acked: u64,

    /// Initialised to `false` and not modified by the methods in this file.
    /// NOTE(review): presumably set by the caller once a reset is processed.
    pub had_reset: bool,
    /// Set once an acknowledgment beyond the stream's final offset is seen.
    pub has_ended: bool,

    /// Total number of bytes reported as gaps by `read_gaps`.
    pub gaps_length: u64,
    /// Number of data segments the inbound state reported as duplicates.
    pub retransmit_count: usize,
    /// Pending per-packet records; the heap pops the lowest offset first.
    pub segments_info: BinaryHeap<SegmentInfo>,
    /// Number of records discarded because `segments_info` was full.
    pub segments_info_dropped: usize,
}
61
62impl Stream {
63 pub fn new() -> Self {
65 Stream {
66 initial_sequence_number: 0,
67 seq_offset: SeqOffset::Initial(0),
68 window_scale: 0,
69 got_window_scale: false,
70 state: StreamInboundState::new(0, true),
71 seq_window_start: 0,
72 seq_window_end: 0,
73 highest_acked: 0,
74 reverse_acked: 0,
75 had_reset: false,
76 has_ended: false,
77 gaps_length: 0,
78 retransmit_count: 0,
79 segments_info: BinaryHeap::new(),
80 segments_info_dropped: 0,
81 }
82 }
83
84 pub fn readable_buffered_length(&self) -> usize {
86 if let Some(highest_readable) = self.state.max_contiguous_offset() {
87 (highest_readable - self.state.buffer_offset) as usize
88 } else {
89 0
90 }
91 }
92
93 pub fn total_buffered_length(&self) -> usize {
96 self.state.buffer.len()
97 }
98
99 pub fn buffer_start(&self) -> u64 {
101 self.state.buffer_offset
102 }
103
104 pub fn set_window_scale(&mut self, window_scale: u8) -> bool {
106 if window_scale > 14 {
107 warn!("rejected oversized window_scale value: {window_scale}");
109 false
110 } else {
111 self.window_scale = window_scale;
112 self.got_window_scale = true;
113 true
114 }
115 }
116
117 pub fn estimate_window_scale(&mut self, fit_end_offset: u64) -> bool {
119 debug_assert!(fit_end_offset > self.state.window_limit);
120 let window_available = self.state.window_limit - self.highest_acked;
121 trace!("available window: {window_available}");
122 if window_available < 8 {
123 debug!("cannot estimate window scale (available window: {window_available})");
125 return false;
126 }
127 let mut try_scale = self.window_scale;
128 let unscaled = window_available >> self.window_scale;
129 if unscaled == 0 {
130 debug!("cannot estimate window scale: unscaled window size is 0");
131 return false;
132 }
133 let mut new_limit = self.highest_acked + (unscaled << try_scale);
134 loop {
135 if try_scale >= 14 {
136 debug!("cannot estimate window scale: scale is too large");
137 return false;
138 }
139 if new_limit < fit_end_offset {
140 try_scale += 1;
141 new_limit = self.highest_acked + (unscaled << try_scale);
142 } else {
143 debug!("estimating window scale to be {try_scale}");
144 self.window_scale = try_scale;
145 self.state.set_limit(new_limit);
146 return true;
147 }
148 }
149 }
150
151 pub fn set_isn(&mut self, isn: u32, window_size: u16) {
153 self.initial_sequence_number = isn;
154 self.seq_offset = SeqOffset::Initial(isn);
155 self.seq_window_start = isn;
157 self.seq_window_end = self.seq_window_start.wrapping_add(SEQ_WINDOW_SIZE);
158 let window_size = (window_size as u64) << self.window_scale as u64;
160 if window_size < MAX_ALLOWED_BUFFER_SIZE {
161 trace!("got initial window size from handshake: {window_size}");
162 self.state.set_limit(window_size);
163 } else {
164 warn!("received window size in handshake is too large: {window_size}");
165 self.state.set_limit(MAX_ALLOWED_BUFFER_SIZE);
166 }
167 }
168
169 pub fn update_offset(&mut self, number: u32, should_advance: bool) -> Option<u64> {
172 if self.seq_window_start < self.seq_window_end {
174 if !(number >= self.seq_window_start && number < self.seq_window_end) {
176 None
177 } else {
178 if should_advance && number - self.seq_window_start > SEQ_WINDOW_ADVANCE_THRESHOLD {
179 let old_start = self.seq_window_start;
181 self.seq_window_start = number - SEQ_WINDOW_ADVANCE_BY;
182 self.seq_window_end = self.seq_window_start.wrapping_add(SEQ_WINDOW_SIZE);
183 trace!(
184 "advance seq_window {} -> {} (received seq {})",
185 old_start,
186 self.seq_window_start,
187 number
188 );
189 }
190 Some(self.seq_offset.compute_absolute(number))
191 }
192 } else if number < self.seq_window_start && number >= self.seq_window_end {
193 None
195 } else if number >= self.seq_window_start {
196 if should_advance && number - self.seq_window_start > SEQ_WINDOW_ADVANCE_THRESHOLD {
198 let old_start = self.seq_window_start;
200 self.seq_window_start = number - SEQ_WINDOW_ADVANCE_BY;
201 self.seq_window_end = self.seq_window_start.wrapping_add(SEQ_WINDOW_SIZE);
202 trace!(
203 "advance seq_window {} -> {} (received seq {})",
204 old_start,
205 self.seq_window_start,
206 number
207 );
208 }
209 Some(self.seq_offset.compute_absolute(number))
210 } else {
211 let bytes_from_start = number.wrapping_sub(self.seq_window_start);
213 let rollover_offset = match self.seq_offset {
215 SeqOffset::Initial(isn) => SeqOffset::Subsequent((1 << 32) - isn as u64),
216 SeqOffset::Subsequent(off) => SeqOffset::Subsequent(off + (1 << 32)),
217 };
218 if should_advance && bytes_from_start > SEQ_WINDOW_ADVANCE_THRESHOLD {
219 let old_start = self.seq_window_start;
221 self.seq_window_start = number.wrapping_sub(SEQ_WINDOW_ADVANCE_BY);
222 self.seq_window_end = self.seq_window_start.wrapping_add(SEQ_WINDOW_SIZE);
223 trace!(
224 "advance seq_window {} -> {} (received seq {})",
225 old_start,
226 self.seq_window_start,
227 number
228 );
229
230 if self.seq_window_start < self.seq_window_end {
231 self.seq_offset = rollover_offset.clone();
233 trace!("seq_window rollover over, advance seq_offset");
234 }
235 }
236 let offset = rollover_offset.compute_absolute(number);
237 Some(offset)
238 }
239 }
240
241 pub fn handle_data_packet(
243 &mut self,
244 sequence_number: u32,
245 mut data: &[u8],
246 extra: &PacketExtra,
247 ) -> bool {
248 let Some(offset) = self.update_offset(sequence_number, true) else {
249 warn!(
250 "received seq number {} outside of window ({} - {})",
251 sequence_number, self.seq_window_start, self.seq_window_end
252 );
253 return false;
254 };
255
256 let packet_end_offset = offset + data.len() as u64;
257 if packet_end_offset > self.state.window_limit {
258 debug!(
260 "got packet exceeding the original receiver's window limit: \
261 seq: {}, offset: {}, len: {}, original window limit: {}",
262 sequence_number,
263 offset,
264 data.len(),
265 self.state.window_limit
266 );
267 if packet_end_offset - self.state.buffer_offset < MAX_ALLOWED_BUFFER_SIZE {
269 if !self.got_window_scale {
270 if self.estimate_window_scale(packet_end_offset) {
271 debug_assert!(self.state.window_limit >= packet_end_offset);
272 } else {
273 self.state.set_limit(packet_end_offset);
274 }
275 } else {
276 trace!("extending window limit due to out-of-window packet");
277 self.state.set_limit(packet_end_offset);
278 }
279 } else {
280 let max_offset = self.state.buffer_offset + MAX_ALLOWED_BUFFER_SIZE;
281 let max_len = max_offset.saturating_sub(offset) as usize;
282 if max_len > 0 {
283 warn!(
284 "packet exceeds max buffer, dropping {} bytes",
285 data.len() - max_len
286 );
287 data = &data[..max_len];
288 } else {
289 warn!("packet exceeds max buffer, dropping packet");
290 return false;
291 }
292 }
293 }
294
295 let mut is_retransmit = false;
297 match self.state.receive_segment(offset, data) {
298 ReceiveSegmentResult::Duplicate => {
299 self.retransmit_count += 1;
301 is_retransmit = true;
302 trace!(
303 "handle_data_packet: got retransmit of {} bytes at seq {}, offset {}",
304 data.len(),
305 sequence_number,
306 offset
307 );
308 }
309 ReceiveSegmentResult::ExceedsWindow => {
310 unreachable!();
312 }
313 ReceiveSegmentResult::Received => {
314 trace!(
316 "handle_data_packet: got {} bytes at seq {}, offset {}",
317 data.len(),
318 sequence_number,
319 offset
320 );
321 }
322 }
323
324 self.add_segment_info(SegmentInfo {
325 offset,
326 reverse_acked: self.reverse_acked,
327 extra: extra.clone(),
328 data: SegmentType::Data {
329 len: data.len(),
330 is_retransmit,
331 },
332 });
333
334 true
335 }
336
337 pub fn handle_ack_packet(
339 &mut self,
340 acknowledgment_number: u32,
341 window_size: u16,
342 extra: &PacketExtra,
343 ) -> bool {
344 let Some(offset) = self.update_offset(acknowledgment_number, true) else {
345 warn!(
346 "received ack number {} outside of window ({} - {})",
347 acknowledgment_number, self.seq_window_start, self.seq_window_end
348 );
349 return false;
350 };
351
352 if offset > self.highest_acked {
353 self.highest_acked = offset;
354 trace!("handle_ack_packet: highest ack is {offset}");
355 }
356
357 if let Some(final_seq) = self.state.final_offset {
358 if self.highest_acked > final_seq {
360 self.has_ended = true;
361 debug!("handle_ack_packet: fin (offset {final_seq}) got ack (offset {offset})");
362 }
363 }
364
365 let real_window = (window_size as u32) << (self.window_scale as u32);
367 let limit = offset + real_window as u64;
368 trace!(
369 "handle_ack_packet: ack: {}, offset {}, win {}",
370 acknowledgment_number,
371 offset,
372 real_window
373 );
374
375 if limit > self.state.window_limit {
376 let new_buffer_size = limit - self.state.buffer_offset;
377 if new_buffer_size > MAX_ALLOWED_BUFFER_SIZE {
378 warn!(
381 "received ack packet which would result in a buffer size \
382 exceeding the maximum allowed buffer size: \
383 ack: {}, win: {}, win scale: {}, absolute window limit: {}",
384 acknowledgment_number, window_size, self.window_scale, limit
385 );
386 self.state
387 .set_limit(self.state.buffer_offset + MAX_ALLOWED_BUFFER_SIZE);
388 } else {
389 trace!(
390 "received window increase: {} -> {} ({} bytes)",
391 offset,
392 limit,
393 real_window
394 );
395 self.state.set_limit(limit);
396 }
397 }
398
399 self.add_segment_info(SegmentInfo {
400 offset,
401 reverse_acked: self.reverse_acked,
402 extra: extra.clone(),
403 data: SegmentType::Ack {
404 window: real_window as usize,
405 },
406 });
407
408 true
409 }
410
411 pub fn handle_fin_packet(
413 &mut self,
414 sequence_number: u32,
415 data_len: usize,
416 extra: &PacketExtra,
417 ) -> bool {
418 let Some(offset) = self.update_offset(sequence_number, true) else {
419 warn!(
420 "received fin with seq number {} outside of window ({} - {})",
421 sequence_number, self.seq_window_start, self.seq_window_end
422 );
423 return false;
424 };
425 let fin_offset = offset + data_len as u64;
426
427 match self.state.final_offset {
428 None => {
429 self.state.set_final_offset(fin_offset);
430 debug!(
431 "handle_fin_packet: seq: {}, len: {}, final offset: {}",
432 sequence_number,
433 data_len,
434 fin_offset
435 );
436 }
437 Some(prev_fin) => {
438 if fin_offset != prev_fin {
439 warn!(
440 "received duplicate FIN different from previous: prev: {}, now: {}",
441 prev_fin, fin_offset
442 );
443 }
444 trace!("handle_fin_packet: detected retransmitted FIN");
445 }
447 }
448
449 self.add_segment_info(SegmentInfo {
450 offset,
451 reverse_acked: self.reverse_acked,
452 extra: extra.clone(),
453 data: SegmentType::Fin {
454 end_offset: fin_offset,
455 },
456 });
457 true
458 }
459
460 pub fn handle_rst_packet(&mut self, sequence_number: u32, extra: &PacketExtra) -> bool {
462 let Some(offset) = self.update_offset(sequence_number, false) else {
468 warn!(
469 "received reset with seq number {} outside of window ({} - {})",
470 sequence_number, self.seq_window_start, self.seq_window_end
471 );
472 return false;
473 };
474
475 if offset >= self.highest_acked.saturating_sub(RESET_MAX_LOOKBEHIND as u64)
476 && offset < self.highest_acked.saturating_add(RESET_MAX_LOOKAHEAD as u64)
477 {
478 debug!("handle_rst_packet: got reset at offset {offset}");
479 self.add_segment_info(SegmentInfo {
480 offset,
481 reverse_acked: self.reverse_acked,
482 extra: extra.clone(),
483 data: SegmentType::Rst,
484 });
485 true
486 } else {
487 warn!(
488 "got likely invalid reset packet at offset {} (highest acked {}, seq {})",
489 offset, self.highest_acked, sequence_number
490 );
491 false
492 }
493 }
494
495 pub fn add_segment_info(&mut self, info: SegmentInfo) -> bool {
497 if self.segments_info.len() < MAX_SEGMENTS_INFO_COUNT {
498 self.segments_info.push(info);
499 true
500 } else {
501 self.segments_info_dropped += 1;
502 false
503 }
504 }
505
506 pub fn read_segments_until(&mut self, end_offset: Option<u64>, in_segments: &mut Vec<SegmentInfo>) {
509 loop {
510 let Some(info_peek) = self.segments_info.peek() else {
511 break;
512 };
513 if let Some(end_offset) = end_offset {
514 if info_peek.offset >= end_offset {
515 break;
516 }
517 }
518
519 in_segments.push(self.segments_info.pop().unwrap());
520 }
521 }
522
523 pub fn read_gaps(&mut self, range: Range<u64>, in_gaps: &mut Vec<Range<u64>>) {
525 for gap in self.state.received.range_complement(range) {
526 trace!("read_gaps: gap: {} .. {}", gap.start, gap.end);
527 in_gaps.push(gap.clone());
528 self.gaps_length += gap.end - gap.start;
529 }
530 }
531
532 pub fn read_next<T>(
534 &mut self,
535 end_offset: u64,
536 in_segments: &mut Vec<SegmentInfo>,
537 in_gaps: &mut Vec<Range<u64>>,
538 read_fn: impl FnOnce(RingBufSlice<'_, u8>) -> T,
539 ) -> Option<T> {
540 let start_offset = self.state.buffer_offset;
541 if end_offset < start_offset {
542 warn!("requested read of range that no longer exists");
543 return None;
544 }
545 if end_offset == start_offset {
546 return None;
548 }
549 if (end_offset - start_offset) as usize > self.state.buffer.len() {
550 warn!("requested read of range past end of buffer");
551 return None;
552 }
553 self.read_segments_until(Some(end_offset), in_segments);
554 self.read_gaps(start_offset..end_offset, in_gaps);
555 self.state.received.insert_range(start_offset..end_offset);
557 let Some(slice) = self.state.read_segment(start_offset..end_offset) else {
559 panic!("InboundStreamState says range is not available");
560 };
561 let ret = read_fn(slice);
562 self.state.advance_buffer(end_offset);
564 Some(ret)
565 }
566}
567
568impl Default for Stream {
569 fn default() -> Self {
570 Self::new()
571 }
572}
573
/// Tests whether `value` lies in the inclusive range
/// `base - before ..= base + after`, evaluated modulo 2^32.
///
/// The range may wrap around zero and/or `u32::MAX`.
///
/// # Panics
///
/// Panics with "requested range too large" when `before + after` exceeds
/// `u32::MAX`, i.e. when the range would overlap itself. This is detected
/// for every such range, including those where only one end of the range
/// wraps.
pub fn in_range_wrapping(base: u32, before: u32, after: u32, value: u32) -> bool {
    // Width of the range minus one, computed without overflow.
    let span = u64::from(before) + u64::from(after);
    if span > u64::from(u32::MAX) {
        panic!("requested range too large");
    }

    // `value` is inside the range exactly when its forward (modular)
    // distance from the start of the range does not exceed the span.
    let begin = base.wrapping_sub(before);
    u64::from(value.wrapping_sub(begin)) <= span
}
588
/// Record describing one processed packet, queued in `Stream::segments_info`.
///
/// Ordering and equality consider only `offset` and `reverse_acked`.
#[derive(Clone)]
pub struct SegmentInfo {
    /// Absolute stream offset derived from the packet's sequence (or
    /// acknowledgment) number.
    pub offset: u64,
    /// Value of `Stream::reverse_acked` at the time the packet was recorded.
    pub reverse_acked: u64,
    /// Caller-supplied per-packet data, cloned from the handler's argument.
    pub extra: PacketExtra,
    /// What kind of packet this record describes.
    pub data: SegmentType,
}
601
/// Kind of packet described by a `SegmentInfo`.
#[derive(Clone)]
pub enum SegmentType {
    /// Data packet: `len` is the number of payload bytes handed to the
    /// inbound state (after any truncation); `is_retransmit` is set when the
    /// inbound state reported the segment as a duplicate.
    Data { len: usize, is_retransmit: bool },
    /// Acknowledgment: `window` is the advertised window after scaling.
    Ack { window: usize },
    /// End of stream: `end_offset` is the packet offset plus its data length.
    Fin { end_offset: u64 },
    /// Reset that passed the plausibility check in `handle_rst_packet`.
    Rst,
}
610
611impl Ord for SegmentInfo {
612 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
614 use std::cmp::Ordering;
615 match self.offset.cmp(&other.offset) {
616 Ordering::Less => Ordering::Greater,
617 Ordering::Equal => match self.reverse_acked.cmp(&other.reverse_acked) {
618 Ordering::Less => Ordering::Greater,
620 Ordering::Equal => Ordering::Equal,
621 Ordering::Greater => Ordering::Less,
622 },
623 Ordering::Greater => Ordering::Less,
624 }
625 }
626}
627
628impl PartialOrd for SegmentInfo {
629 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
630 Some(self.cmp(other))
631 }
632}
633
634impl PartialEq for SegmentInfo {
635 fn eq(&self, other: &Self) -> bool {
636 self.offset == other.offset && self.reverse_acked == other.reverse_acked
637 }
638}
639
// Equality compares two pairs of integers, so it is a full equivalence
// relation; `Eq` is also required by `Ord`.
impl Eq for SegmentInfo {}
641
/// Mapping from a 32-bit wire sequence number to a 64-bit absolute offset.
#[derive(Clone)]
pub enum SeqOffset {
    /// No rollover applied yet: holds the initial sequence number, which is
    /// subtracted from the wire number.
    Initial(u32),
    /// After at least one rollover: holds the value that is added to the
    /// wire number.
    Subsequent(u64),
}

impl SeqOffset {
    /// Converts the wire sequence `number` into an absolute offset.
    ///
    /// For `Initial(isn)` the result is `number - isn`; `number` must not be
    /// below `isn` (checked in debug builds only). For `Subsequent(base)`
    /// the result is `base + number`.
    pub fn compute_absolute(&self, number: u32) -> u64 {
        match *self {
            Self::Initial(isn) => {
                debug_assert!(number >= isn);
                u64::from(number - isn)
            }
            Self::Subsequent(base) => base + u64::from(number),
        }
    }
}