use std::{
cmp,
collections::{BTreeMap, BTreeSet, VecDeque},
mem,
net::IpAddr,
ops::{Bound, Index, IndexMut},
};
use rand::{CryptoRng, RngExt};
use rustc_hash::{FxHashMap, FxHashSet};
use sorted_index_buffer::SortedIndexBuffer;
use tracing::trace;
use super::PathId;
use crate::{
Dir, Duration, FourTuple, Instant, StreamId, TransportError, TransportErrorCode, VarInt,
connection::StreamsState,
frame::{self, AddAddress, RemoveAddress},
packet::SpaceId,
range_set::ArrayRangeSet,
shared::IssuedCid,
};
/// Per-encryption-level packet state, shared by every path using this level.
///
/// Frames queued in [`pending`](Self::pending) are not tied to any path, whereas packet
/// numbering and acknowledgement tracking live in one [`PacketNumberSpace`] per [`PathId`].
pub(super) struct PacketSpace {
    /// Frames waiting to be sent (or sent again) at this encryption level.
    pub(super) pending: Retransmits,
    /// One packet number space per path, ordered by [`PathId`].
    pub(super) number_spaces: BTreeMap<PathId, PacketNumberSpace>,
}

impl PacketSpace {
    /// Creates a space that initially only knows about [`PathId::ZERO`].
    pub(super) fn new(now: Instant, space: SpaceId, rng: &mut (impl CryptoRng + ?Sized)) -> Self {
        let mut number_spaces = BTreeMap::new();
        number_spaces.insert(PathId::ZERO, PacketNumberSpace::new(now, space, rng));
        Self {
            pending: Retransmits::default(),
            number_spaces,
        }
    }

    /// Like [`PacketSpace::new`], but packet number skipping is disabled (test only).
    #[cfg(test)]
    pub(super) fn new_deterministic(now: Instant, space: SpaceId) -> Self {
        let mut number_spaces = BTreeMap::new();
        number_spaces.insert(
            PathId::ZERO,
            PacketNumberSpace::new_deterministic(now, space),
        );
        Self {
            pending: Retransmits::default(),
            number_spaces,
        }
    }

    /// Returns the number space of `path_id`, if this space has one for it.
    pub(super) fn path_space(&self, path_id: PathId) -> Option<&PacketNumberSpace> {
        self.number_spaces.get(&path_id)
    }

    /// Mutable counterpart of [`PacketSpace::path_space`].
    pub(super) fn path_space_mut(&mut self, path_id: PathId) -> Option<&mut PacketNumberSpace> {
        self.number_spaces.get_mut(&path_id)
    }

    /// Returns the number space of `path`.
    ///
    /// # Panics
    ///
    /// If this space has no number space for `path`.
    pub(super) fn for_path(&mut self, path: PathId) -> &mut PacketNumberSpace {
        match self.number_spaces.get_mut(&path) {
            Some(number_space) => number_space,
            None => panic!("PacketNumberSpace missing for {path}"),
        }
    }

    /// Iterates mutably over the number spaces of all paths, in [`PathId`] order.
    pub(super) fn iter_paths_mut(&mut self) -> impl Iterator<Item = &mut PacketNumberSpace> {
        self.number_spaces.values_mut()
    }

    /// Makes sure a tail loss probe sent on `path_id` has something to carry.
    ///
    /// The options are tried in order:
    /// 1. Frames already queued in `pending`.
    /// 2. The retransmittable frames of the first tracked sent packet that still has any.
    ///    Paths are visited in [`PathId`] order. The frames are moved into `pending`.
    /// 3. A PING on `path_id`, unless an IMMEDIATE_ACK is already pending there.
    ///
    /// With `request_immediate_ack`, an immediate ACK is also requested on `path_id`.
    pub(super) fn queue_tail_loss_probe(
        &mut self,
        path_id: PathId,
        request_immediate_ack: bool,
        streams: &StreamsState,
    ) {
        if request_immediate_ack {
            self.for_path(path_id).immediate_ack_pending = true;
        }

        // Queued data makes a perfectly good probe.
        if !self.pending.is_empty(streams) {
            return;
        }

        let resendable = self
            .number_spaces
            .values_mut()
            .flat_map(|number_space| number_space.sent_packets.values_mut())
            .find(|sent| !sent.retransmits.is_empty(streams));
        if let Some(sent) = resendable {
            // Taking (not cloning) the frames stops them being queued a second time later.
            self.pending |= mem::take(&mut sent.retransmits);
            return;
        }

        let number_space = self.for_path(path_id);
        if !number_space.immediate_ack_pending {
            number_space.ping_pending = true;
        }
    }

    /// Summarises what this space could send right now on `path_id`.
    ///
    /// ACKs may be owed on any path, so every number space is consulted for `acks`. PING and
    /// IMMEDIATE_ACK requests are only considered for `path_id` itself.
    pub(super) fn can_send(&self, path_id: PathId, streams: &StreamsState) -> SendableFrames {
        let space_specific = match self.number_spaces.get(&path_id) {
            Some(number_space) => number_space.ping_pending || number_space.immediate_ack_pending,
            None => false,
        };
        SendableFrames {
            acks: self
                .number_spaces
                .values()
                .any(|number_space| number_space.pending_acks.can_send()),
            close: false,
            space_specific,
            other: !self.pending.is_empty(streams),
        }
    }
}
/// Allows the three per-level packet spaces to be indexed directly by [`SpaceId`].
impl Index<SpaceId> for [PacketSpace; 3] {
    type Output = PacketSpace;

    fn index(&self, space: SpaceId) -> &PacketSpace {
        let slot = space as usize;
        &self.as_slice()[slot]
    }
}

/// Mutable indexing of the per-level packet spaces by [`SpaceId`].
impl IndexMut<SpaceId> for [PacketSpace; 3] {
    fn index_mut(&mut self, space: SpaceId) -> &mut PacketSpace {
        let slot = space as usize;
        &mut self.as_mut_slice()[slot]
    }
}
/// Kind of packet number space, independent of any path.
///
/// The explicit discriminants are also used as indices into `[PacketSpace; 3]`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub(crate) enum SpaceKind {
    Initial = 0,
    Handshake = 1,
    Data = 2,
}

impl SpaceKind {
    /// Maps this space kind to its encryption level (Data maps to 1-RTT).
    pub(crate) fn encryption_level(self) -> super::EncryptionLevel {
        use super::EncryptionLevel as Level;

        match self {
            Self::Data => Level::OneRtt,
            Self::Handshake => Level::Handshake,
            Self::Initial => Level::Initial,
        }
    }
}
/// Allows the three per-level packet spaces to be indexed directly by [`SpaceKind`].
impl Index<SpaceKind> for [PacketSpace; 3] {
    type Output = PacketSpace;

    fn index(&self, space: SpaceKind) -> &PacketSpace {
        let slot = space as usize;
        &self.as_slice()[slot]
    }
}

/// Mutable indexing of the per-level packet spaces by [`SpaceKind`].
impl IndexMut<SpaceKind> for [PacketSpace; 3] {
    fn index_mut(&mut self, space: SpaceKind) -> &mut PacketSpace {
        let slot = space as usize;
        &mut self.as_mut_slice()[slot]
    }
}
/// Packet numbering, acknowledgement and loss state for one path within a [`PacketSpace`].
pub(super) struct PacketNumberSpace {
    /// Largest packet number received on this path, if any (maintained outside this module).
    pub(super) largest_received_packet_number: Option<u64>,
    /// Next packet number to allocate; see `get_tx_number` and `peek_tx_number`.
    pub(super) next_packet_number: u64,
    /// Largest of our packet numbers acknowledged by the peer, if any (maintained outside this
    /// module).
    pub(super) largest_acked_packet: Option<u64>,
    /// Send time paired with `largest_acked_packet`; initialised to the creation time.
    pub(super) largest_acked_packet_sent: Instant,
    /// Number of the most recently sent ack-eliciting packet (updated by `sent`).
    pub(super) largest_ack_eliciting_sent: u64,
    /// How many tracked non-ack-eliciting packets were sent after `largest_ack_eliciting_sent`.
    pub(super) unacked_non_ack_eliciting_tail: u64,
    /// Packets recorded by `sent` and not yet removed by `take`, keyed by packet number.
    pub(super) sent_packets: SortedIndexBuffer<SentPacket>,
    /// NOTE(review): presumably packets already declared lost, kept so late ACKs can be
    /// recognised — not used in this module, confirm with callers.
    pub(super) lost_packets: SortedIndexBuffer<LostPacket>,
    /// NOTE(review): ECN counters not touched in this module — presumably counts for received
    /// packets that we report back to the peer; confirm.
    pub(super) ecn_counters: frame::EcnCounts,
    /// ECN counts most recently accepted from the peer (see `detect_ecn`).
    pub(super) ecn_feedback: frame::EcnCounts,
    /// A PING should be sent on this path (set as a tail-loss-probe fallback).
    pub(super) ping_pending: bool,
    /// An IMMEDIATE_ACK should be sent on this path.
    pub(super) immediate_ack_pending: bool,
    /// Duplicate detection for received packet numbers.
    pub(super) dedup: Dedup,
    /// Acknowledgements we still owe the peer for this path.
    pub(super) pending_acks: PendingAcks,
    /// Maintained outside this module; presumably the send time of the latest ack-eliciting
    /// packet — confirm.
    pub(super) time_of_last_ack_eliciting_packet: Option<Instant>,
    /// Maintained outside this module; presumably the time-threshold loss deadline — confirm.
    pub(super) loss_time: Option<Instant>,
    /// Maintained outside this module; presumably the count of outstanding probe timeouts —
    /// confirm.
    pub(super) loss_probes: u32,
    /// Deliberate packet-number skipping; `Some` only for the Data space (see `new`).
    pn_filter: Option<PacketNumberFilter>,
}
impl PacketNumberSpace {
pub(super) fn new(now: Instant, space: SpaceId, rng: &mut (impl CryptoRng + ?Sized)) -> Self {
let pn_filter = match space {
SpaceId::Initial | SpaceId::Handshake => None,
SpaceId::Data => Some(PacketNumberFilter::new(rng)),
};
Self {
largest_received_packet_number: None,
next_packet_number: 0,
largest_acked_packet: None,
largest_acked_packet_sent: now,
largest_ack_eliciting_sent: 0,
unacked_non_ack_eliciting_tail: 0,
sent_packets: SortedIndexBuffer::new(),
lost_packets: SortedIndexBuffer::new(),
ecn_counters: frame::EcnCounts::ZERO,
ecn_feedback: frame::EcnCounts::ZERO,
ping_pending: false,
immediate_ack_pending: false,
dedup: Default::default(),
pending_acks: PendingAcks::new(),
time_of_last_ack_eliciting_packet: None,
loss_time: None,
loss_probes: 0,
pn_filter,
}
}
#[cfg(test)]
fn new_deterministic(now: Instant, space: SpaceId) -> Self {
let pn_filter = match space {
SpaceId::Initial | SpaceId::Handshake => None,
SpaceId::Data => Some(PacketNumberFilter::disabled()),
};
Self {
largest_received_packet_number: None,
next_packet_number: 0,
largest_acked_packet: None,
largest_acked_packet_sent: now,
largest_ack_eliciting_sent: 0,
unacked_non_ack_eliciting_tail: 0,
sent_packets: SortedIndexBuffer::new(),
lost_packets: SortedIndexBuffer::new(),
ecn_counters: frame::EcnCounts::ZERO,
ecn_feedback: frame::EcnCounts::ZERO,
ping_pending: false,
immediate_ack_pending: false,
dedup: Default::default(),
pending_acks: PendingAcks::new(),
time_of_last_ack_eliciting_packet: None,
loss_time: None,
loss_probes: 0,
pn_filter,
}
}
pub(super) fn get_tx_number(&mut self, rng: &mut (impl CryptoRng + ?Sized)) -> u64 {
assert!(self.next_packet_number < 2u64.pow(62));
let mut pn = self.next_packet_number;
self.next_packet_number += 1;
if let Some(ref mut filter) = self.pn_filter
&& filter.skip_pn(pn, rng)
{
pn = self.next_packet_number;
self.next_packet_number += 1;
}
pn
}
pub(super) fn peek_tx_number(&mut self) -> u64 {
let pn = self.next_packet_number;
if let Some(ref filter) = self.pn_filter
&& pn == filter.next_skipped_packet_number
{
return pn + 1;
}
pn
}
pub(super) fn check_ack(&self, range: std::ops::Range<u64>) -> Result<(), TransportError> {
if let Some(ref filter) = self.pn_filter
&& filter
.prev_skipped_packet_number
.is_some_and(|pn| range.contains(&pn))
{
return Err(TransportError::PROTOCOL_VIOLATION("unsent packet acked"));
}
Ok(())
}
pub(super) fn detect_ecn(
&mut self,
newly_acked: u64,
ecn: frame::EcnCounts,
) -> Result<bool, &'static str> {
let ect0_increase = ecn
.ect0
.checked_sub(self.ecn_feedback.ect0)
.ok_or("peer ECT(0) count regression")?;
let ect1_increase = ecn
.ect1
.checked_sub(self.ecn_feedback.ect1)
.ok_or("peer ECT(1) count regression")?;
let ce_increase = ecn
.ce
.checked_sub(self.ecn_feedback.ce)
.ok_or("peer CE count regression")?;
let total_increase = ect0_increase + ect1_increase + ce_increase;
if total_increase < newly_acked {
return Err("ECN bleaching");
}
if (ect0_increase + ce_increase) < newly_acked || ect1_increase != 0 {
return Err("ECN corruption");
}
self.ecn_feedback = ecn;
Ok(ce_increase != 0)
}
pub(super) fn take(&mut self, number: u64) -> Option<SentPacket> {
let packet = self.sent_packets.remove(number)?;
if !packet.ack_eliciting && number > self.largest_ack_eliciting_sent {
self.unacked_non_ack_eliciting_tail =
self.unacked_non_ack_eliciting_tail.checked_sub(1).unwrap();
}
Some(packet)
}
pub(super) fn sent(&mut self, number: u64, packet: SentPacket) -> Option<SentPacket> {
const MAX_UNACKED_NON_ACK_ELICTING_TAIL: u64 = 1_000;
let mut forgotten = None;
if packet.ack_eliciting {
self.unacked_non_ack_eliciting_tail = 0;
self.largest_ack_eliciting_sent = number;
} else if self.unacked_non_ack_eliciting_tail > MAX_UNACKED_NON_ACK_ELICTING_TAIL {
let oldest_after_ack_eliciting = self
.sent_packets
.keys_range((
Bound::Excluded(self.largest_ack_eliciting_sent),
Bound::Unbounded,
))
.next()
.unwrap();
let packet = self
.sent_packets
.remove(oldest_after_ack_eliciting)
.unwrap();
debug_assert!(!packet.ack_eliciting);
forgotten = Some(packet);
} else {
self.unacked_non_ack_eliciting_tail += 1;
}
self.sent_packets.insert(number, packet);
forgotten
}
pub(super) fn has_in_flight(&self) -> bool {
self.sent_packets.values().any(|x| x.size != 0)
}
}
/// Bookkeeping for a packet we sent and are still tracking.
#[derive(Debug, Clone)]
pub(super) struct SentPacket {
    /// NOTE(review): presumably identifies which incarnation of the path the packet was sent
    /// on — confirm against the path code.
    pub(super) path_generation: u64,
    /// When the packet was sent.
    pub(super) time_sent: Instant,
    /// Packet size (presumably bytes); `PacketNumberSpace::has_in_flight` treats any nonzero
    /// size as in flight.
    pub(super) size: u16,
    /// Whether the packet was ack-eliciting.
    pub(super) ack_eliciting: bool,
    /// Per-path packet numbers; presumably the largest acknowledged by ACK frames carried in
    /// this packet — verify against the packet builder.
    pub(super) largest_acked: FxHashMap<PathId, u64>,
    /// Retransmittable frames carried by this packet; `PacketSpace::queue_tail_loss_probe` may
    /// move them back into the pending queue.
    pub(super) retransmits: ThinRetransmits,
    /// Metadata for the stream frames carried by this packet.
    pub(super) stream_frames: frame::StreamMetaVec,
}

/// A packet presumably already declared lost; only its send time is kept.
#[derive(Debug)]
pub(super) struct LostPacket {
    /// When the packet was originally sent.
    pub(super) time_sent: Instant,
}
/// Frames waiting to be sent, or to be sent again after the packet carrying them is lost.
///
/// Sets are combined with `|=` (see the `BitOrAssign` impls for this type).
// NOTE(review): the reason for allowing `unreachable_pub` is not visible here — presumably the
// type leaks into a public signature elsewhere; confirm and document.
#[allow(unreachable_pub)] #[derive(Debug, Default, Clone)]
pub struct Retransmits {
    /// Send a connection-level MAX_DATA update.
    pub(super) max_data: bool,
    /// Send a stream-limit update per direction, indexed by `Dir as usize`.
    pub(super) max_stream_id: [bool; 2],
    /// Streams to reset; the `VarInt` is presumably the application error code.
    pub(super) reset_stream: Vec<(StreamId, VarInt)>,
    /// STOP_SENDING frames to send.
    pub(super) stop_sending: Vec<frame::StopSending>,
    /// Streams that may need a MAX_STREAM_DATA update; `is_empty` ignores entries for which
    /// `StreamsState::can_send_flow_control` is false.
    pub(super) max_stream_data: FxHashSet<StreamId>,
    /// CRYPTO frames, in send order; frames merged in from `rhs` are placed in front.
    pub(super) crypto: VecDeque<frame::Crypto>,
    /// Connection IDs to announce to the peer.
    pub(super) new_cids: PendingNewCids,
    /// `(path, sequence number)` of connection IDs to retire — presumably; confirm.
    pub(super) retire_cids: Vec<(PathId, u64)>,
    /// Send an ACK_FREQUENCY frame.
    pub(super) ack_frequency: bool,
    /// Send HANDSHAKE_DONE.
    pub(super) handshake_done: bool,
    /// Send an observed-address report.
    pub(super) observed_addr: bool,
    /// Send a MAX_PATH_ID update.
    pub(super) max_path_id: bool,
    /// Send PATHS_BLOCKED.
    pub(super) paths_blocked: bool,
    /// Network paths for which to issue NEW_TOKEN frames — presumably one per entry.
    pub(super) new_tokens: Vec<FourTuple>,
    /// Paths to abandon, with the error code to report.
    pub(super) path_abandon: BTreeMap<PathId, TransportErrorCode>,
    /// Paths whose status should be advertised.
    pub(super) path_status: BTreeSet<PathId>,
    /// Paths for which to signal that connection IDs are blocked.
    pub(super) path_cids_blocked: BTreeSet<PathId>,
    /// Address advertisements to send.
    pub(super) add_address: BTreeSet<AddAddress>,
    /// Address withdrawals to send.
    pub(super) remove_address: BTreeSet<RemoveAddress>,
    /// `(round, addresses)`. On merge, a higher round replaces a lower one, and equal rounds
    /// take the union of their address sets.
    pub(super) reach_out: Option<(VarInt, FxHashSet<(IpAddr, u16)>)>,
}
impl Retransmits {
    /// Returns `true` when nothing in this set still needs to be transmitted.
    ///
    /// Entries in `max_stream_data` only count while `streams` says the stream can still
    /// send a flow-control update.
    pub(super) fn is_empty(&self, streams: &StreamsState) -> bool {
        // Exhaustive destructuring: a newly added field fails to compile until handled here.
        let Self {
            max_data,
            max_stream_id,
            reset_stream,
            stop_sending,
            max_stream_data,
            crypto,
            new_cids,
            retire_cids,
            ack_frequency,
            handshake_done,
            observed_addr,
            max_path_id,
            paths_blocked,
            new_tokens,
            path_abandon,
            path_status,
            path_cids_blocked,
            add_address,
            remove_address,
            reach_out,
        } = self;

        let flag_raised = *max_data
            || max_stream_id.contains(&true)
            || *ack_frequency
            || *handshake_done
            || *observed_addr
            || *max_path_id
            || *paths_blocked;
        if flag_raised {
            return false;
        }

        let queues_empty = reset_stream.is_empty()
            && stop_sending.is_empty()
            && crypto.is_empty()
            && new_cids.is_empty()
            && retire_cids.is_empty()
            && new_tokens.is_empty()
            && path_abandon.is_empty()
            && path_status.is_empty()
            && path_cids_blocked.is_empty()
            && add_address.is_empty()
            && remove_address.is_empty()
            && reach_out.is_none();

        queues_empty
            && !max_stream_data
                .iter()
                .any(|&id| streams.can_send_flow_control(id))
    }
}
/// Merges `rhs` into `self`, for example when re-queuing the frames of a lost packet.
///
/// Flags are OR-ed and collections are unioned or appended. CRYPTO frames from `rhs` go in
/// front of those already queued, keeping their own order. For `reach_out`, the higher round
/// wins, and equal rounds take the union of their addresses.
impl ::std::ops::BitOrAssign for Retransmits {
    fn bitor_assign(&mut self, rhs: Self) {
        // Exhaustive destructuring: a newly added field fails to compile until merged here.
        let Self {
            max_data,
            max_stream_id,
            reset_stream,
            stop_sending,
            max_stream_data,
            crypto,
            new_cids,
            retire_cids,
            ack_frequency,
            handshake_done,
            observed_addr,
            max_path_id,
            paths_blocked,
            new_tokens,
            mut path_abandon,
            mut path_status,
            mut path_cids_blocked,
            add_address,
            remove_address,
            reach_out,
        } = rhs;
        self.max_data |= max_data;
        for dir in Dir::iter() {
            self.max_stream_id[dir as usize] |= max_stream_id[dir as usize];
        }
        // `rhs` is owned, so its elements are moved rather than copied or cloned.
        self.reset_stream.extend(reset_stream);
        self.stop_sending.extend(stop_sending);
        self.max_stream_data.extend(max_stream_data);
        // Pushing to the front in reverse keeps `rhs`'s frames first, in their original order.
        for frame in crypto.into_iter().rev() {
            self.crypto.push_front(frame);
        }
        self.new_cids.extend(&new_cids);
        self.retire_cids.extend(retire_cids);
        self.ack_frequency |= ack_frequency;
        self.handshake_done |= handshake_done;
        self.observed_addr |= observed_addr;
        self.max_path_id |= max_path_id;
        self.paths_blocked |= paths_blocked;
        self.new_tokens.extend(new_tokens);
        self.path_abandon.append(&mut path_abandon);
        self.path_status.append(&mut path_status);
        self.path_cids_blocked.append(&mut path_cids_blocked);
        self.add_address.extend(add_address);
        self.remove_address.extend(remove_address);
        if let Some((rhs_round, rhs_addrs)) = reach_out {
            match self.reach_out.as_mut() {
                None => self.reach_out = Some((rhs_round, rhs_addrs)),
                Some((lhs_round, _lhs_addrs)) if rhs_round > *lhs_round => {
                    self.reach_out = Some((rhs_round, rhs_addrs));
                }
                Some((lhs_round, lhs_addrs)) if rhs_round == *lhs_round => {
                    lhs_addrs.extend(rhs_addrs);
                }
                // `rhs` belongs to an older round: keep ours.
                Some(_) => {}
            }
        }
    }
}
impl ::std::ops::BitOrAssign<ThinRetransmits> for Retransmits {
fn bitor_assign(&mut self, rhs: ThinRetransmits) {
let ThinRetransmits { retransmits } = rhs;
if let Some(retransmits) = retransmits {
self.bitor_assign(*retransmits)
}
}
}
impl ::std::iter::FromIterator<Self> for Retransmits {
fn from_iter<T>(iter: T) -> Self
where
T: IntoIterator<Item = Self>,
{
let mut result = Self::default();
for packet in iter {
result |= packet;
}
result
}
}
/// Connection IDs waiting to be announced to the peer.
///
/// Sorting is deferred until the first [`pop`](Self::pop) after a change. Entries are then
/// handed out in ascending `(path_id, sequence)` order.
#[derive(Clone, Debug, Default)]
pub(super) struct PendingNewCids {
    cids: Vec<IssuedCid>,
    /// Whether `cids` is currently sorted in descending order, so `Vec::pop` yields the
    /// smallest entry.
    sorted: bool,
}

impl PendingNewCids {
    /// Queues `cid`. The order is restored on the next `pop`.
    pub(super) fn push(&mut self, cid: IssuedCid) {
        self.sorted = false;
        self.cids.push(cid);
    }

    /// Removes and returns the entry with the smallest `(path_id, sequence)`.
    pub(super) fn pop(&mut self) -> Option<IssuedCid> {
        if !self.sorted {
            // Stable descending sort, so the smallest key sits at the end.
            self.cids
                .sort_by_key(|cid| cmp::Reverse((cid.path_id, cid.sequence)));
            self.sorted = true;
        }
        self.cids.pop()
    }

    /// Whether no connection IDs are queued.
    pub(super) fn is_empty(&self) -> bool {
        self.cids.is_empty()
    }

    /// Appends copies of all of `other`'s entries.
    pub(super) fn extend(&mut self, other: &Self) {
        self.sorted = false;
        self.cids.extend(other.cids.iter().copied());
    }

    /// Keeps only the entries for which `f` returns `true`.
    ///
    /// Relative order is preserved, so an already sorted queue stays sorted.
    pub(super) fn retain<F>(&mut self, f: F)
    where
        F: FnMut(&IssuedCid) -> bool,
    {
        self.cids.retain(f);
    }
}
/// A lazily allocated [`Retransmits`].
///
/// Most sent packets carry no retransmittable frames. Boxing keeps this field pointer-sized
/// for them, and allocation only happens on first use.
#[derive(Debug, Default, Clone)]
pub(super) struct ThinRetransmits {
    retransmits: Option<Box<Retransmits>>,
}

impl ThinRetransmits {
    /// Whether there is nothing to retransmit; an unallocated set counts as empty.
    pub(super) fn is_empty(&self, streams: &StreamsState) -> bool {
        self.retransmits
            .as_deref()
            .is_none_or(|retransmits| retransmits.is_empty(streams))
    }

    /// Borrows the frames, if any were ever recorded.
    pub(super) fn get(&self) -> Option<&Retransmits> {
        self.retransmits.as_deref()
    }

    /// Mutably borrows the frames, if any were ever recorded.
    pub(super) fn get_mut(&mut self) -> Option<&mut Retransmits> {
        self.retransmits.as_deref_mut()
    }

    /// Returns the frames, allocating an empty set first if needed.
    pub(super) fn get_or_create(&mut self) -> &mut Retransmits {
        self.retransmits.get_or_insert_with(Box::default)
    }
}
/// Sliding-window detector for duplicate received packet numbers.
///
/// It tracks the largest packet number seen exactly, plus one bit for each of the
/// [`Window`]-width numbers just below it. Anything older is reported as a duplicate.
#[derive(Debug, Default)]
pub(super) struct Dedup {
    /// Bit `i` is set iff packet `highest() - 1 - i` has been seen.
    window: Window,
    /// One past the largest packet number inserted so far (zero before the first insert).
    next: u64,
}

/// Bitmap type backing [`Dedup`]'s window.
type Window = u128;
/// Packet numbers covered by a [`Dedup`]: the largest one plus one per window bit.
const WINDOW_SIZE: u64 = 1 + mem::size_of::<Window>() as u64 * 8;

impl Dedup {
    /// Creates an empty detector (same as `Default`).
    #[cfg(test)]
    pub(super) fn new() -> Self {
        Self { window: 0, next: 0 }
    }

    /// Largest packet number inserted so far.
    ///
    /// Only meaningful after the first insert: underflows while `next == 0`.
    fn highest(&self) -> u64 {
        self.next - 1
    }

    /// Records `packet`. Returns `true` if it was already seen or is too old to be tracked.
    pub(super) fn insert(&mut self, packet: u64) -> bool {
        if let Some(diff) = packet.checked_sub(self.next) {
            // New largest packet. Shift the previous largest into bit 0 (the `| 1`), then shift
            // by `diff` more to leave clear bits for the numbers that were skipped over. A jump
            // of 128 or more makes `checked_shl` fail, which clears the window entirely.
            // Before the first insert there is no previous largest. The bit set then stands for
            // no real packet and appears never to be consulted.
            self.window = ((self.window << 1) | 1)
                .checked_shl(cmp::min(diff, u64::from(u32::MAX)) as u32)
                .unwrap_or(0);
            self.next = packet + 1;
            false
        } else if self.highest() - packet < WINDOW_SIZE {
            // Within the window. A distance of 0 is the largest packet itself (always a
            // duplicate). Otherwise, test-and-set the packet's bit.
            if let Some(bit) = (self.highest() - packet).checked_sub(1) {
                let mask = 1 << bit;
                let duplicate = self.window & mask != 0;
                self.window |= mask;
                duplicate
            } else {
                true
            }
        } else {
            // Older than the window: conservatively treated as a duplicate.
            true
        }
    }

    /// Returns the smallest packet number strictly between `lower_bound` and `upper_bound`
    /// that has not been seen, considering only numbers still covered by the window.
    ///
    /// Numbers older than the window are never reported. Requires
    /// `lower_bound <= upper_bound <= highest()`; this is checked in debug builds only.
    fn smallest_missing_in_interval(&self, lower_bound: u64, upper_bound: u64) -> Option<u64> {
        debug_assert!(lower_bound <= upper_bound);
        debug_assert!(upper_bound <= self.highest());
        const BITFIELD_SIZE: u64 = (mem::size_of::<Window>() * 8) as u64;
        // The interval is exclusive on both ends; convert it to inclusive bounds.
        let lower_bound = lower_bound + 1;
        let upper_bound = upper_bound.saturating_sub(1);
        // Bit offset of `upper_bound`. `highest()` has no bit of its own, so clamp to 0.
        let start_offset = (self.highest() - upper_bound).max(1) - 1;
        if start_offset >= BITFIELD_SIZE {
            // The whole interval is older than the window.
            return None;
        }
        // One past the bit offset of `lower_bound`.
        let end_offset_exclusive = self.highest().saturating_sub(lower_bound);
        let range_len = end_offset_exclusive
            .saturating_sub(start_offset)
            .min(BITFIELD_SIZE);
        if range_len == 0 {
            return None;
        }
        // When the range reaches past the oldest bit, the full-width mask may also include bits
        // newer than `upper_bound`. Such gaps are filtered out by the final bounds check.
        let mask = if range_len == BITFIELD_SIZE {
            u128::MAX
        } else {
            ((1u128 << range_len) - 1) << start_offset
        };
        let gaps = !self.window & mask;
        // The most significant gap bit is the oldest (smallest) missing packet. With no gaps the
        // offset is 0, which maps to `highest()`, and that is rejected below.
        let smallest_missing_offset = 128 - gaps.leading_zeros() as u64;
        let smallest_missing_packet = self.highest() - smallest_missing_offset;
        if smallest_missing_packet <= upper_bound {
            Some(smallest_missing_packet)
        } else {
            None
        }
    }

    /// Whether any packet strictly between the bounds, within the window, is missing.
    fn missing_in_interval(&self, lower_bound: u64, upper_bound: u64) -> bool {
        self.smallest_missing_in_interval(lower_bound, upper_bound)
            .is_some()
    }
}
/// Summary of the kinds of frames a packet space could currently send.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(super) struct SendableFrames {
    /// Some number space has an ACK ready to send.
    pub(super) acks: bool,
    /// NOTE(review): presumably a connection close is to be sent — never set in this module.
    pub(super) close: bool,
    /// A PING or IMMEDIATE_ACK is pending for the specific path.
    pub(super) space_specific: bool,
    /// Other queued frames are waiting.
    pub(super) other: bool,
}

impl SendableFrames {
    /// A value with every flag cleared.
    pub(super) fn empty() -> Self {
        Self {
            acks: false,
            close: false,
            space_specific: false,
            other: false,
        }
    }

    /// Whether sending these frames would make the packet ack-eliciting.
    ///
    /// ACKs alone never count, and a pending `close` overrides everything else.
    pub(super) fn is_ack_eliciting(&self) -> bool {
        let Self {
            acks: _,
            close,
            space_specific,
            other,
        } = *self;
        !close && (space_specific || other)
    }

    /// Whether there is nothing at all to send.
    pub(super) fn is_empty(&self) -> bool {
        *self == Self::empty()
    }
}

/// Combines two summaries flag by flag.
impl ::std::ops::BitOrAssign for SendableFrames {
    fn bitor_assign(&mut self, rhs: Self) {
        *self = Self {
            acks: self.acks || rhs.acks,
            close: self.close || rhs.close,
            space_specific: self.space_specific || rhs.space_specific,
            other: self.other || rhs.other,
        };
    }
}
/// Tracks which received packets still need acknowledging, and when an ACK must be sent.
#[derive(Debug)]
pub(super) struct PendingAcks {
    /// An ACK must go out without waiting for the max-ack-delay timer (see `can_send`).
    immediate_ack_required: bool,
    /// Ack-eliciting packets received since the last `acks_sent`.
    ack_eliciting_since_last_ack_sent: u64,
    /// Non-ack-eliciting packets received since the last `acks_sent`.
    non_ack_eliciting_since_last_ack_sent: u64,
    /// An immediate ACK is required once more than this many ack-eliciting packets are
    /// unacknowledged. Starts at 1; may be changed by an ACK_FREQUENCY frame.
    ack_eliciting_threshold: u64,
    /// Reordering tolerance before an immediate ACK is required: 0 disables the check and 1
    /// reacts to any gap. Starts at 1; may be changed by an ACK_FREQUENCY frame.
    reordering_threshold: u64,
    /// When the max-ack-delay timer was started, i.e. the arrival of the first ack-eliciting
    /// packet since the last ACK that could not be acknowledged immediately.
    earliest_ack_eliciting_since_last_ack_sent: Option<Instant>,
    /// Packet numbers to acknowledge; trimmed from the low end beyond `MAX_ACK_BLOCKS`.
    ranges: ArrayRangeSet,
    /// Largest packet number inserted into `ranges`, with its receive time.
    largest_packet: Option<(u64, Instant)>,
    /// Largest ack-eliciting packet number received.
    largest_ack_eliciting_packet: Option<u64>,
    /// Value of `largest_ack_eliciting_packet` when ACKs were last sent.
    largest_acked: Option<u64>,
}

impl PendingAcks {
    /// Creates a tracker with nothing to acknowledge and both thresholds set to 1.
    fn new() -> Self {
        Self {
            immediate_ack_required: false,
            ack_eliciting_since_last_ack_sent: 0,
            non_ack_eliciting_since_last_ack_sent: 0,
            ack_eliciting_threshold: 1,
            reordering_threshold: 1,
            earliest_ack_eliciting_since_last_ack_sent: None,
            ranges: Default::default(),
            largest_packet: Default::default(),
            largest_ack_eliciting_packet: Default::default(),
            largest_acked: Default::default(),
        }
    }

    /// Applies the thresholds from a received ACK_FREQUENCY frame.
    pub(super) fn set_ack_frequency_params(&mut self, frame: &frame::AckFrequency) {
        self.ack_eliciting_threshold = frame.ack_eliciting_threshold.into_inner();
        self.reordering_threshold = frame.reordering_threshold.into_inner();
    }

    /// Forces the next opportunity to carry an ACK (e.g. on receipt of an IMMEDIATE_ACK).
    pub(super) fn set_immediate_ack_required(&mut self) {
        self.immediate_ack_required = true;
    }

    /// Max-ack-delay timer expiry: an ACK becomes due iff ack-eliciting packets are pending.
    pub(super) fn on_max_ack_delay_timeout(&mut self) {
        self.immediate_ack_required = self.ack_eliciting_since_last_ack_sent > 0;
    }

    /// Deadline of the max-ack-delay timer, if it is running.
    pub(super) fn max_ack_delay_timeout(&self, max_ack_delay: Duration) -> Option<Instant> {
        self.earliest_ack_eliciting_since_last_ack_sent
            .map(|earliest_unacked| earliest_unacked + max_ack_delay)
    }

    /// Whether an ACK is due and there is something to acknowledge.
    pub(super) fn can_send(&self) -> bool {
        self.immediate_ack_required && !self.ranges.is_empty()
    }

    /// Time elapsed since the largest inserted packet was received (zero if none).
    pub(super) fn ack_delay(&self, now: Instant) -> Duration {
        self.largest_packet
            .map_or_else(Duration::default, |(_, received)| now - received)
    }

    /// Updates the ACK scheduling state for a newly received packet.
    ///
    /// `dedup` should already contain `packet_number` (as the tests do): the reordering check
    /// relies on `Dedup`'s debug-asserted bounds. Returns `true` iff this call started the
    /// max-ack-delay timer.
    pub(super) fn packet_received(
        &mut self,
        now: Instant,
        packet_number: u64,
        ack_eliciting: bool,
        dedup: &Dedup,
    ) -> bool {
        if !ack_eliciting {
            // Never triggers an ACK on its own; see `maybe_ack_non_eliciting`.
            self.non_ack_eliciting_since_last_ack_sent += 1;
            return false;
        }
        let prev_largest_ack_eliciting = self.largest_ack_eliciting_packet.unwrap_or(0);
        self.largest_ack_eliciting_packet = self
            .largest_ack_eliciting_packet
            .map(|pn| pn.max(packet_number))
            .or(Some(packet_number));
        self.ack_eliciting_since_last_ack_sent += 1;
        self.immediate_ack_required |=
            self.ack_eliciting_since_last_ack_sent > self.ack_eliciting_threshold;
        self.immediate_ack_required |=
            self.is_out_of_order(packet_number, prev_largest_ack_eliciting, dedup);
        // Start the delay timer only if no ACK can be sent right away.
        if self.earliest_ack_eliciting_since_last_ack_sent.is_none() && !self.can_send() {
            self.earliest_ack_eliciting_since_last_ack_sent = Some(now);
            return true;
        }
        false
    }

    /// Whether `packet_number` arrived out of order enough to warrant an immediate ACK under
    /// the current `reordering_threshold` (as in the QUIC ACK frequency extension).
    fn is_out_of_order(
        &self,
        packet_number: u64,
        prev_largest_ack_eliciting: u64,
        dedup: &Dedup,
    ) -> bool {
        match self.reordering_threshold {
            // Reordering never triggers an immediate ACK.
            0 => false,
            // Any reordering counts: a packet below the previous largest, or a gap just below it.
            1 => {
                packet_number < prev_largest_ack_eliciting
                    || dedup.missing_in_interval(prev_largest_ack_eliciting, packet_number)
            }
            // Only react once a missing, not-yet-reported packet trails the largest received by
            // at least the threshold.
            _ => {
                let Some((largest_acked, largest_unacked)) =
                    self.largest_acked.zip(self.largest_ack_eliciting_packet)
                else {
                    return false;
                };
                if self.reordering_threshold > largest_acked {
                    return false;
                }
                let largest_reported = largest_acked - self.reordering_threshold + 1;
                let Some(smallest_missing_unreported) =
                    dedup.smallest_missing_in_interval(largest_reported, largest_unacked)
                else {
                    return false;
                };
                largest_unacked - smallest_missing_unreported >= self.reordering_threshold
            }
        }
    }

    /// Resets the scheduling state after an ACK frame has been sent.
    ///
    /// `ranges` is left as is. Presumably it is trimmed separately via `subtract_below`.
    pub(super) fn acks_sent(&mut self) {
        self.immediate_ack_required = false;
        self.ack_eliciting_since_last_ack_sent = 0;
        self.non_ack_eliciting_since_last_ack_sent = 0;
        self.earliest_ack_eliciting_since_last_ack_sent = None;
        self.largest_acked = self.largest_ack_eliciting_packet;
    }

    /// Adds `packet`, received at `now`, to the set to acknowledge.
    ///
    /// When the set exceeds `MAX_ACK_BLOCKS`, the lowest range is dropped.
    pub(super) fn insert_one(&mut self, packet: u64, now: Instant) {
        self.ranges.insert_one(packet);
        if self.largest_packet.is_none_or(|(pn, _)| packet > pn) {
            self.largest_packet = Some((packet, now));
        }
        if self.ranges.len() > MAX_ACK_BLOCKS {
            self.ranges.pop_min();
        }
    }

    /// Forgets every packet number `<= max`. `max + 1` overflows if `max == u64::MAX`.
    pub(super) fn subtract_below(&mut self, max: u64) {
        self.ranges.remove(0..(max + 1));
    }

    /// Packet numbers currently pending acknowledgement.
    pub(super) fn ranges(&self) -> &ArrayRangeSet {
        &self.ranges
    }

    /// Makes an ACK due once more than 10 non-ack-eliciting packets are pending, so they are
    /// eventually acknowledged too.
    pub(super) fn maybe_ack_non_eliciting(&mut self) {
        const LAZY_ACK_THRESHOLD: u64 = 10;
        if self.non_ack_eliciting_since_last_ack_sent > LAZY_ACK_THRESHOLD {
            self.immediate_ack_required = true;
        }
    }
}
/// Picks packet numbers to skip on purpose.
///
/// Acknowledging the most recently skipped number is then rejected by
/// `PacketNumberSpace::check_ack`. The first skipped number is drawn from `0..2^6`. Each later
/// one is drawn from `2^k..2^(k+1)`, with `k` growing by one per skip, so skips become
/// exponentially rarer.
pub(super) struct PacketNumberFilter {
    next_skipped_packet_number: u64,
    prev_skipped_packet_number: Option<u64>,
    exponent: u32,
}

impl PacketNumberFilter {
    /// Creates a filter whose first skipped packet number is drawn from `0..64`.
    pub(super) fn new(rng: &mut (impl CryptoRng + ?Sized)) -> Self {
        const INITIAL_EXPONENT: u32 = 6;
        let first_range = 0..2u64.saturating_pow(INITIAL_EXPONENT);
        Self {
            next_skipped_packet_number: rng.random_range(first_range),
            prev_skipped_packet_number: None,
            exponent: INITIAL_EXPONENT,
        }
    }

    /// A filter that never skips (test only).
    #[cfg(test)]
    pub(super) fn disabled() -> Self {
        Self {
            next_skipped_packet_number: u64::MAX,
            prev_skipped_packet_number: None,
            exponent: u32::MAX,
        }
    }

    /// Returns `true` if packet number `n` must be skipped.
    ///
    /// When it must, `n` is remembered as the latest skipped number and the next one to skip is
    /// drawn from `rng`.
    pub(super) fn skip_pn(&mut self, n: u64, rng: &mut (impl CryptoRng + ?Sized)) -> bool {
        let must_skip = n == self.next_skipped_packet_number;
        if must_skip {
            trace!("skipping pn {n}");
            self.prev_skipped_packet_number = Some(n);
            let next_exponent = self.exponent.saturating_add(1);
            let next_range =
                2u64.saturating_pow(self.exponent)..2u64.saturating_pow(next_exponent);
            self.next_skipped_packet_number = rng.random_range(next_range);
            self.exponent = next_exponent;
        }
        must_skip
    }
}
/// Maximum number of entries `PendingAcks::ranges` may hold. When `insert_one` exceeds it, the
/// lowest range is dropped.
const MAX_ACK_BLOCKS: usize = 64;
// Unit tests for duplicate detection, ACK scheduling and pending connection ID ordering.
#[cfg(test)]
mod test {
    use rand::Rng;
    use rand::seq::SliceRandom;
    use crate::token::ResetToken;
    use crate::{ConnectionIdGenerator, RandomConnectionIdGenerator};
    use super::*;

    // Walks `Dedup` through in-order, duplicate, gapped and late insertions, checking the raw
    // window bits after each step.
    #[test]
    fn sanity() {
        let mut dedup = Dedup::new();
        assert!(!dedup.insert(0));
        assert_eq!(dedup.next, 1);
        assert_eq!(dedup.window, 0b1);
        assert!(dedup.insert(0));
        assert_eq!(dedup.next, 1);
        assert_eq!(dedup.window, 0b1);
        assert!(!dedup.insert(1));
        assert_eq!(dedup.next, 2);
        assert_eq!(dedup.window, 0b11);
        assert!(!dedup.insert(2));
        assert_eq!(dedup.next, 3);
        assert_eq!(dedup.window, 0b111);
        // Skipping 3 leaves a clear bit for it.
        assert!(!dedup.insert(4));
        assert_eq!(dedup.next, 5);
        assert_eq!(dedup.window, 0b11110);
        assert!(!dedup.insert(7));
        assert_eq!(dedup.next, 8);
        assert_eq!(dedup.window, 0b1111_0100);
        // Late arrivals inside the window fill their bits; repeats are duplicates.
        assert!(dedup.insert(4));
        assert!(!dedup.insert(3));
        assert_eq!(dedup.next, 8);
        assert_eq!(dedup.window, 0b1111_1100);
        assert!(!dedup.insert(6));
        assert_eq!(dedup.next, 8);
        assert_eq!(dedup.window, 0b1111_1101);
        assert!(!dedup.insert(5));
        assert_eq!(dedup.next, 8);
        assert_eq!(dedup.window, 0b1111_1111);
    }

    // Every packet is new once, and every number up to it is a duplicate afterwards,
    // including numbers that have fallen out of the window.
    #[test]
    fn happypath() {
        let mut dedup = Dedup::new();
        for i in 0..(2 * WINDOW_SIZE) {
            assert!(!dedup.insert(i));
            for j in 0..=i {
                assert!(dedup.insert(j));
            }
        }
    }

    // A jump larger than the window clears it; packets older than the window are rejected.
    #[test]
    fn jump() {
        let mut dedup = Dedup::new();
        dedup.insert(2 * WINDOW_SIZE);
        assert!(dedup.insert(WINDOW_SIZE));
        assert_eq!(dedup.next, 2 * WINDOW_SIZE + 1);
        assert_eq!(dedup.window, 0);
        assert!(!dedup.insert(WINDOW_SIZE + 1));
        assert_eq!(dedup.next, 2 * WINDOW_SIZE + 1);
        assert_eq!(dedup.window, 1 << (WINDOW_SIZE - 2));
    }

    // `missing_in_interval` looks at the open interval between its bounds.
    #[test]
    fn dedup_has_missing() {
        let mut dedup = Dedup::new();
        dedup.insert(0);
        assert!(!dedup.missing_in_interval(0, 0));
        dedup.insert(1);
        assert!(!dedup.missing_in_interval(0, 1));
        dedup.insert(3);
        assert!(dedup.missing_in_interval(1, 3));
        dedup.insert(4);
        assert!(!dedup.missing_in_interval(3, 4));
        assert!(dedup.missing_in_interval(0, 4));
        dedup.insert(2);
        assert!(!dedup.missing_in_interval(0, 4));
    }

    // Packets older than the window are never reported as missing.
    #[test]
    fn dedup_outside_of_window_has_missing() {
        let mut dedup = Dedup::new();
        for i in 0..140 {
            dedup.insert(i);
        }
        assert!(!dedup.missing_in_interval(0, 4));
        dedup.insert(160);
        assert!(!dedup.missing_in_interval(0, 4));
        assert!(!dedup.missing_in_interval(0, 140));
        assert!(dedup.missing_in_interval(0, 160));
    }

    // The smallest gap is reported, clipped to the window (500 - 128 = 372 is the oldest
    // tracked number after the last insert).
    #[test]
    fn dedup_smallest_missing() {
        let mut dedup = Dedup::new();
        dedup.insert(0);
        assert_eq!(dedup.smallest_missing_in_interval(0, 0), None);
        dedup.insert(1);
        assert_eq!(dedup.smallest_missing_in_interval(0, 1), None);
        dedup.insert(5);
        dedup.insert(7);
        assert_eq!(dedup.smallest_missing_in_interval(0, 7), Some(2));
        assert_eq!(dedup.smallest_missing_in_interval(5, 7), Some(6));
        dedup.insert(2);
        assert_eq!(dedup.smallest_missing_in_interval(1, 7), Some(3));
        dedup.insert(170);
        dedup.insert(172);
        dedup.insert(300);
        assert_eq!(dedup.smallest_missing_in_interval(170, 172), None);
        dedup.insert(500);
        assert_eq!(dedup.smallest_missing_in_interval(0, 500), Some(372));
        assert_eq!(dedup.smallest_missing_in_interval(0, 373), Some(372));
        assert_eq!(dedup.smallest_missing_in_interval(0, 372), None);
    }

    // Packet 0 with nothing before it must not look reordered.
    #[test]
    fn pending_acks_first_packet_is_not_considered_reordered() {
        let mut acks = PendingAcks::new();
        let mut dedup = Dedup::new();
        dedup.insert(0);
        acks.packet_received(Instant::now(), 0, true, &dedup);
        assert!(!acks.immediate_ack_required);
    }

    // With ranges present, an ACK becomes sendable only once an immediate ACK is required.
    #[test]
    fn pending_acks_after_immediate_ack_set() {
        let mut acks = PendingAcks::new();
        let mut dedup = Dedup::new();
        dedup.insert(0);
        let now = Instant::now();
        acks.insert_one(0, now);
        acks.packet_received(now, 0, true, &dedup);
        assert!(!acks.ranges.is_empty());
        assert!(!acks.can_send());
        acks.set_immediate_ack_required();
        assert!(acks.can_send());
    }

    // The ACK delay is measured from the receive time of the largest packet. A later,
    // smaller packet (2 after 3) does not reset it.
    #[test]
    fn pending_acks_ack_delay() {
        let mut acks = PendingAcks::new();
        let mut dedup = Dedup::new();
        let t1 = Instant::now();
        let t2 = t1 + Duration::from_millis(2);
        let t3 = t2 + Duration::from_millis(5);
        assert_eq!(acks.ack_delay(t1), Duration::from_millis(0));
        assert_eq!(acks.ack_delay(t2), Duration::from_millis(0));
        assert_eq!(acks.ack_delay(t3), Duration::from_millis(0));
        dedup.insert(0);
        acks.insert_one(0, t1);
        acks.packet_received(t1, 0, true, &dedup);
        assert_eq!(acks.ack_delay(t1), Duration::from_millis(0));
        assert_eq!(acks.ack_delay(t2), Duration::from_millis(2));
        assert_eq!(acks.ack_delay(t3), Duration::from_millis(7));
        dedup.insert(3);
        acks.insert_one(3, t2);
        acks.packet_received(t2, 3, true, &dedup);
        assert_eq!(acks.ack_delay(t2), Duration::from_millis(0));
        assert_eq!(acks.ack_delay(t3), Duration::from_millis(5));
        dedup.insert(2);
        acks.insert_one(2, t3);
        acks.packet_received(t3, 2, true, &dedup);
        assert_eq!(acks.ack_delay(t3), Duration::from_millis(5));
    }

    // Guards against `SentPacket` growing beyond 128 bytes; it is stored for every sent packet.
    #[test]
    fn sent_packet_size() {
        assert!(std::mem::size_of::<SentPacket>() <= 128);
    }

    // Whatever the push order, `pop` yields entries in ascending `(path_id, sequence)` order.
    #[test]
    fn pending_new_cids() {
        #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
        use aws_lc_rs::hmac;
        #[cfg(feature = "ring")]
        use ring::hmac;
        let mut cid_generator = RandomConnectionIdGenerator::new(8);
        let mut reset_key = [0; 64];
        rand::rng().fill_bytes(&mut reset_key);
        let hmac = hmac::Key::new(hmac::HMAC_SHA256, &reset_key);
        let cid_a = cid_generator.generate_cid();
        let a = IssuedCid {
            path_id: PathId::ZERO,
            sequence: 1,
            id: cid_a,
            reset_token: ResetToken::new(&hmac, cid_a),
        };
        let cid_b = cid_generator.generate_cid();
        let b = IssuedCid {
            path_id: PathId::ZERO,
            sequence: 2,
            id: cid_b,
            reset_token: ResetToken::new(&hmac, cid_b),
        };
        let cid_c = cid_generator.generate_cid();
        let c = IssuedCid {
            path_id: PathId(1),
            sequence: 1,
            id: cid_c,
            reset_token: ResetToken::new(&hmac, cid_c),
        };
        let mut pending_cids = PendingNewCids::default();
        for _ in 0..9 {
            let mut input = vec![a, b, c];
            input.shuffle(&mut rand::rng());
            for cid in input {
                pending_cids.push(cid);
            }
            assert_eq!(pending_cids.pop().map(|i| i.id), Some(a.id));
            assert_eq!(pending_cids.pop().map(|i| i.id), Some(b.id));
            assert_eq!(pending_cids.pop().map(|i| i.id), Some(c.id));
            assert!(pending_cids.pop().is_none());
        }
    }
}