use std::{collections::VecDeque, ops::Range};
use bytes::{Buf, Bytes};
use crate::{range_set::RangeSet, VarInt};
#[derive(Default, Debug)]
pub struct SendBuffer {
unacked_segments: VecDeque<Bytes>,
unacked_len: usize,
offset: u64,
unsent: u64,
acks: RangeSet,
retransmits: RangeSet,
}
impl SendBuffer {
pub fn new() -> Self {
Self::default()
}
pub fn write(&mut self, data: Bytes) {
self.unacked_len += data.len();
self.offset += data.len() as u64;
self.unacked_segments.push_back(data);
}
pub fn ack(&mut self, mut range: Range<u64>) {
let base_offset = self.offset - self.unacked_len as u64;
range.start = base_offset.max(range.start);
range.end = base_offset.max(range.end);
self.acks.insert(range);
while self.acks.min() == Some(self.offset - self.unacked_len as u64) {
let prefix = self.acks.pop_min().unwrap();
let mut to_advance = (prefix.end - prefix.start) as usize;
self.unacked_len -= to_advance;
while to_advance > 0 {
let front = self
.unacked_segments
.front_mut()
.expect("Expected buffered data");
if front.len() <= to_advance {
to_advance -= front.len();
self.unacked_segments.pop_front();
if self.unacked_segments.len() * 4 < self.unacked_segments.capacity() {
self.unacked_segments.shrink_to_fit();
}
} else {
front.advance(to_advance);
to_advance = 0;
}
}
}
}
pub fn poll_transmit(&mut self, mut max_len: usize) -> (Range<u64>, bool) {
debug_assert!(max_len >= 8 + 8);
let mut encode_length = false;
if let Some(range) = self.retransmits.pop_min() {
if range.start != 0 {
max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(range.start) });
}
if range.end - range.start < max_len as u64 {
encode_length = true;
max_len -= 8;
}
let end = range.end.min((max_len as u64).saturating_add(range.start));
if end != range.end {
self.retransmits.insert(end..range.end);
}
return (range.start..end, encode_length);
}
if self.unsent != 0 {
max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(self.unsent) });
}
if self.offset - self.unsent < max_len as u64 {
encode_length = true;
max_len -= 8;
}
let end = self
.offset
.min((max_len as u64).saturating_add(self.unsent));
let result = self.unsent..end;
self.unsent = end;
(result, encode_length)
}
pub fn get(&self, offsets: Range<u64>) -> &[u8] {
let base_offset = self.offset - self.unacked_len as u64;
let mut segment_offset = base_offset;
for segment in self.unacked_segments.iter() {
if offsets.start >= segment_offset
&& offsets.start < segment_offset + segment.len() as u64
{
let start = (offsets.start - segment_offset) as usize;
let end = (offsets.end - segment_offset) as usize;
return &segment[start..end.min(segment.len())];
}
segment_offset += segment.len() as u64;
}
&[]
}
pub fn retransmit(&mut self, range: Range<u64>) {
debug_assert!(range.end <= self.unsent, "unsent data can't be lost");
self.retransmits.insert(range);
}
pub fn retransmit_all_for_0rtt(&mut self) {
debug_assert_eq!(self.offset, self.unacked_len as u64);
self.unsent = 0;
}
pub fn offset(&self) -> u64 {
self.offset
}
pub fn is_fully_acked(&self) -> bool {
self.unacked_len == 0
}
pub fn has_unsent_data(&self) -> bool {
self.unsent != self.offset || !self.retransmits.is_empty()
}
pub fn unacked(&self) -> u64 {
self.unacked_len as u64 - self.acks.iter().map(|x| x.end - x.start).sum::<u64>()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn fragment_with_length() {
let mut buf = SendBuffer::new();
const MSG: &[u8] = b"Hello, world!";
buf.write(MSG.into());
assert_eq!(buf.poll_transmit(19), (0..11, true));
assert_eq!(
buf.poll_transmit(MSG.len() + 16 - 11),
(11..MSG.len() as u64, true)
);
assert_eq!(
buf.poll_transmit(58),
(MSG.len() as u64..MSG.len() as u64, true)
);
}
#[test]
fn fragment_without_length() {
let mut buf = SendBuffer::new();
const MSG: &[u8] = b"Hello, world with some extra data!";
buf.write(MSG.into());
assert_eq!(buf.poll_transmit(19), (0..19, false));
assert_eq!(
buf.poll_transmit(MSG.len() - 19 + 1),
(19..MSG.len() as u64, false)
);
assert_eq!(
buf.poll_transmit(58),
(MSG.len() as u64..MSG.len() as u64, true)
);
}
#[test]
fn reserves_encoded_offset() {
let mut buf = SendBuffer::new();
let chunk: Bytes = Bytes::from_static(&[0; 1024 * 1024]);
for _ in 0..1025 {
buf.write(chunk.clone());
}
const SIZE1: u64 = 64;
const SIZE2: u64 = 16 * 1024;
const SIZE3: u64 = 1024 * 1024 * 1024;
assert_eq!(buf.poll_transmit(16), (0..16, false));
buf.retransmit(0..16);
assert_eq!(buf.poll_transmit(16), (0..16, false));
let mut transmitted = 16u64;
assert_eq!(
buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
(transmitted as u64..SIZE1, false)
);
buf.retransmit(transmitted as u64..SIZE1);
assert_eq!(
buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
(transmitted as u64..SIZE1, false)
);
transmitted = SIZE1;
assert_eq!(
buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
(transmitted as u64..SIZE2, false)
);
buf.retransmit(transmitted as u64..SIZE2);
assert_eq!(
buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
(transmitted as u64..SIZE2, false)
);
transmitted = SIZE2;
assert_eq!(
buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
(transmitted as u64..SIZE3, false)
);
buf.retransmit(transmitted as u64..SIZE3);
assert_eq!(
buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
(transmitted as u64..SIZE3, false)
);
transmitted = SIZE3;
assert_eq!(
buf.poll_transmit(chunk.len() + 8),
(transmitted as u64..transmitted + chunk.len() as u64, false)
);
buf.retransmit(transmitted as u64..transmitted + chunk.len() as u64);
assert_eq!(
buf.poll_transmit(chunk.len() + 8),
(transmitted as u64..transmitted + chunk.len() as u64, false)
);
}
#[test]
fn multiple_segments() {
let mut buf = SendBuffer::new();
const MSG: &[u8] = b"Hello, world!";
const MSG_LEN: u64 = MSG.len() as u64;
const SEG1: &[u8] = b"He";
buf.write(SEG1.into());
const SEG2: &[u8] = b"llo,";
buf.write(SEG2.into());
const SEG3: &[u8] = b" w";
buf.write(SEG3.into());
const SEG4: &[u8] = b"o";
buf.write(SEG4.into());
const SEG5: &[u8] = b"rld!";
buf.write(SEG5.into());
assert_eq!(aggregate_unacked(&buf), MSG);
assert_eq!(buf.poll_transmit(16), (0..8, true));
assert_eq!(buf.get(0..5), SEG1);
assert_eq!(buf.get(2..8), SEG2);
assert_eq!(buf.get(6..8), SEG3);
assert_eq!(buf.poll_transmit(16), (8..MSG_LEN, true));
assert_eq!(buf.get(8..MSG_LEN), SEG4);
assert_eq!(buf.get(9..MSG_LEN), SEG5);
assert_eq!(buf.poll_transmit(42), (MSG_LEN..MSG_LEN, true));
buf.ack(0..1);
assert_eq!(aggregate_unacked(&buf), &MSG[1..]);
buf.ack(0..3);
assert_eq!(aggregate_unacked(&buf), &MSG[3..]);
buf.ack(3..5);
assert_eq!(aggregate_unacked(&buf), &MSG[5..]);
buf.ack(7..9);
assert_eq!(aggregate_unacked(&buf), &MSG[5..]);
buf.ack(4..7);
assert_eq!(aggregate_unacked(&buf), &MSG[9..]);
buf.ack(0..MSG_LEN);
assert_eq!(aggregate_unacked(&buf), &[]);
}
#[test]
fn retransmit() {
let mut buf = SendBuffer::new();
const MSG: &[u8] = b"Hello, world with extra data!";
buf.write(MSG.into());
assert_eq!(buf.poll_transmit(16), (0..16, false));
assert_eq!(buf.poll_transmit(16), (16..23, true));
buf.retransmit(0..16);
assert_eq!(buf.poll_transmit(16), (0..16, false));
assert_eq!(buf.poll_transmit(16), (23..MSG.len() as u64, true));
buf.retransmit(16..23);
assert_eq!(buf.poll_transmit(16), (16..23, true));
}
#[test]
fn ack() {
let mut buf = SendBuffer::new();
const MSG: &[u8] = b"Hello, world!";
buf.write(MSG.into());
assert_eq!(buf.poll_transmit(16), (0..8, true));
buf.ack(0..8);
assert_eq!(aggregate_unacked(&buf), &MSG[8..]);
}
#[test]
fn reordered_ack() {
let mut buf = SendBuffer::new();
const MSG: &[u8] = b"Hello, world with extra data!";
buf.write(MSG.into());
assert_eq!(buf.poll_transmit(16), (0..16, false));
assert_eq!(buf.poll_transmit(16), (16..23, true));
buf.ack(16..23);
assert_eq!(aggregate_unacked(&buf), MSG);
buf.ack(0..16);
assert_eq!(aggregate_unacked(&buf), &MSG[23..]);
assert!(buf.acks.is_empty());
}
fn aggregate_unacked(buf: &SendBuffer) -> Vec<u8> {
let mut result = Vec::new();
for segment in buf.unacked_segments.iter() {
result.extend_from_slice(&segment[..]);
}
result
}
}