use crate::{packet::number::PacketNumber, varint::VarInt};
use core::mem;
/// Duplicate-detection state for packet numbers.
///
/// `right_edge` is the largest packet number inserted so far (`None` before the first insert).
/// Bit `i` of `window` records whether packet number `right_edge - (i + 1)` has been inserted,
/// so together they cover the right edge plus the packet numbers directly below it.
#[derive(Debug, Clone, Default)]
pub struct SlidingWindow {
    /// Bitmap of inserted packet numbers strictly below `right_edge`.
    window: Window,
    /// Largest packet number inserted so far, if any.
    right_edge: Option<PacketNumber>,
}
/// Reasons a packet number is rejected by [`SlidingWindow`].
///
/// The enum is fieldless, so equality is a total equivalence and `Eq` is derived alongside
/// `PartialEq`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SlidingWindowError {
    /// The packet number has already been inserted.
    Duplicate,
    /// The packet number is too far below the right edge for the window to tell whether it
    /// is a duplicate.
    TooOld,
}
/// Bitmap storage: one bit per packet number tracked below the right edge.
type Window = u128;
/// How many packet numbers the window accounts for: the right edge itself plus one packet
/// number per bit of `Window`.
const WINDOW_WIDTH: u64 = (mem::size_of::<Window>() * 8) as u64 + 1;
/// Where a packet number falls relative to a sliding window's current right edge.
#[derive(PartialEq, Eq, Debug)]
enum WindowPosition {
    /// Nothing has been inserted yet, so there is no right edge.
    Empty,
    /// `WINDOW_WIDTH` or more below the right edge; too old to track.
    Left,
    /// Below the right edge by the given distance (1 through `WINDOW_WIDTH - 1`).
    Within(u64),
    /// Equal to the right edge.
    RightEdge,
    /// Above the right edge by the given distance.
    Right(u64),
}
/// Packet numbers that dropped out of a sliding window without ever having been inserted.
///
/// Produced by `SlidingWindow::insert_with_evicted`. Iterating yields the packet numbers in
/// ascending order.
#[derive(Clone, Default)]
pub struct EvictedSet {
    /// Bit `i` is set when packet number `right_edge - (i + 1)` is still to be reported.
    window: Window,
    /// Reference point for `window`; moved forward as the set is iterated.
    right_edge: PacketNumber,
}
impl core::fmt::Debug for EvictedSet {
    /// Formats the packet numbers still to be yielded as a set, leaving `self` untouched.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let remaining = self.clone();
        let mut set = f.debug_set();
        set.entries(remaining);
        set.finish()
    }
}
impl PartialEq for EvictedSet {
    /// Two sets are equal when they would yield the same packet numbers in the same order.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.clone(), other.clone());
        Iterator::eq(lhs, rhs)
    }
}
impl Iterator for EvictedSet {
    type Item = PacketNumber;

    /// Yields the next evicted packet number, in ascending order.
    ///
    /// Each step moves `right_edge` forward to the most significant set bit of `window`. It
    /// then reports the packet number exactly `WINDOW_WIDTH` below the new edge. Set bits whose
    /// packet number would be below zero are consumed without being reported.
    fn next(&mut self) -> Option<PacketNumber> {
        while self.window != 0 {
            // Distance up to and including the most significant set bit.
            let step = self.window.leading_zeros() + 1;
            let space = self.right_edge.space();
            let advanced = PacketNumber::as_varint(self.right_edge) + VarInt::from_u32(step);
            self.right_edge = PacketNumber::from_varint(advanced, space);

            // Shifting by the full bit width would overflow, so that case clears the window.
            self.window = if step == Window::BITS {
                0
            } else {
                self.window << step
            };

            let candidate = PacketNumber::as_varint(self.right_edge)
                .checked_sub(VarInt::from_u32(WINDOW_WIDTH as u32));
            if let Some(evicted) = candidate {
                return Some(PacketNumber::from_varint(evicted, self.right_edge.space()));
            }
        }
        None
    }
}
// Insertion, lookup and position classification for `SlidingWindow`.
impl SlidingWindow {
/// Records `packet_number` as received.
///
/// # Errors
///
/// Returns [`SlidingWindowError::Duplicate`] if `packet_number` was already inserted.
/// Returns [`SlidingWindowError::TooOld`] if it is `WINDOW_WIDTH` or more below the current
/// right edge, where the window can no longer tell whether it is a duplicate. On error the
/// window is left unchanged.
pub fn insert(&mut self, packet_number: PacketNumber) -> Result<(), SlidingWindowError> {
self.insert_with_evicted(packet_number).map(|_| ())
}
/// Like [`Self::insert`], but on success also returns the packet numbers that were never
/// inserted and have just dropped out of the window because the right edge advanced.
///
/// The returned set can only be non-empty when `packet_number` is above the previous right
/// edge.
///
/// # Errors
///
/// Same as [`Self::insert`].
pub fn insert_with_evicted(
&mut self,
packet_number: PacketNumber,
) -> Result<EvictedSet, SlidingWindowError> {
// Debug builds snapshot the state so the result can be cross-checked afterwards.
#[cfg(debug_assertions)]
let initial = self.clone();
let res = self.insert_with_evicted_inner(packet_number);
#[cfg(debug_assertions)]
self.check_insert_result(packet_number, initial, &res);
res
}
/// Performs the insert for [`Self::insert_with_evicted`], without the debug cross-check.
fn insert_with_evicted_inner(
&mut self,
packet_number: PacketNumber,
) -> Result<EvictedSet, SlidingWindowError> {
match self.window_position(packet_number) {
WindowPosition::Left => Err(SlidingWindowError::TooOld),
WindowPosition::RightEdge => Err(SlidingWindowError::Duplicate),
// New right edge `delta` above the old one: slide the bitmap up by `delta`.
WindowPosition::Right(delta) => {
let removed = if delta < WINDOW_WIDTH {
// The top `delta` bits are about to be shifted out. `wrapping_shr` takes the
// shift amount modulo the bit width, so a 128-bit shift would be a no-op;
// the full-width case is therefore special-cased.
let removed_mask = if delta == 128 {
u128::MAX
} else {
!u128::MAX.wrapping_shr(delta as u32)
};
// Unset bits leaving the window are packet numbers that were never inserted.
let removed = !self.window & removed_mask;
// `checked_shl` returns `None` for a 128-bit shift, in which case nothing survives.
self.window = self.window.checked_shl(delta as u32).unwrap_or(0);
// The previous right edge is now `delta` below the new one; mark it received.
self.window |= 1 << (delta - 1);
removed
} else {
// Jump of `WINDOW_WIDTH` or more: every tracked slot leaves the window, and
// all of its unset bits are evicted.
let removed = self.window;
self.window = 0;
!removed
};
// `removed` is laid out relative to the previous right edge, which `EvictedSet`
// uses as its starting reference point.
if let Some(prev_right_edge) = self.right_edge.replace(packet_number) {
Ok(EvictedSet {
window: removed,
right_edge: prev_right_edge,
})
} else {
// Not expected to be reached: `window_position` returns `Empty`, not
// `Right`, while there is no right edge.
assert!(removed == 0);
Ok(EvictedSet::default())
}
}
// Below the right edge but still tracked: bit `delta - 1` records it.
WindowPosition::Within(delta) => {
let mask = 1 << (delta - 1); let duplicate = self.window & mask != 0;
// Setting an already-set bit is a no-op, so a duplicate leaves the window unchanged.
self.window |= mask;
if duplicate {
Err(SlidingWindowError::Duplicate)
} else {
Ok(EvictedSet::default())
}
}
// First insert: it becomes the right edge; the bitmap is not touched.
WindowPosition::Empty => {
self.right_edge = Some(packet_number);
Ok(EvictedSet::default())
}
}
}
/// Debug-build cross-check of an insert against [`Self::check`].
///
/// * On error, the window and right edge must be unchanged from `initial`.
/// * Every evicted packet number must have been accepted by `initial` and be `TooOld` now.
/// * Consider the packet numbers up to `WINDOW_WIDTH` below the previous right edge (or
///   `0..WINDOW_WIDTH` when there was none). Among these, a packet number may appear in the
///   evicted set only if it is now `TooOld`. Every one that became `TooOld` after being
///   accepted by a non-empty `initial` window must appear in it.
#[cfg_attr(not(debug_assertions), allow(dead_code))]
fn check_insert_result(
&self,
packet_number: PacketNumber,
initial: Self,
res: &Result<EvictedSet, SlidingWindowError>,
) {
let evicted = match res {
Ok(evicted) => evicted,
Err(_) => {
assert_eq!(self.window, initial.window);
assert_eq!(self.right_edge, initial.right_edge);
return;
}
};
{
for pn in evicted.clone() {
assert_eq!(initial.check(pn), Ok(()), "{pn:?}");
assert_eq!(self.check(pn), Err(SlidingWindowError::TooOld), "{pn:?}");
}
}
// Scan the range directly below the previous right edge.
for pn in initial
.right_edge
.map_or(0, |e| e.as_u64())
.saturating_sub(WINDOW_WIDTH)
..initial.right_edge.map_or(WINDOW_WIDTH, |e| e.as_u64())
{
let pn = PacketNumber::from_varint(VarInt::new(pn).unwrap(), packet_number.space());
match self.check(pn) {
Ok(()) => assert!(evicted.clone().all(|e| e != pn)),
Err(SlidingWindowError::TooOld) => {
if initial.check(pn).is_ok()
&& initial.window_position(pn) != WindowPosition::Empty
{
assert!(
evicted.clone().any(|e| e == pn),
"{pn:?} expected in evicted after insert of {packet_number:?}"
);
}
}
Err(SlidingWindowError::Duplicate) => {
assert!(evicted.clone().all(|e| e != pn))
}
}
}
}
/// Reports whether [`Self::insert`] would accept `packet_number`, without modifying the
/// window.
///
/// # Errors
///
/// Returns the same error [`Self::insert`] would.
pub fn check(&self, packet_number: PacketNumber) -> Result<(), SlidingWindowError> {
match self.window_position(packet_number) {
WindowPosition::Left => Err(SlidingWindowError::TooOld),
WindowPosition::RightEdge => Err(SlidingWindowError::Duplicate),
WindowPosition::Right(_) | WindowPosition::Empty => Ok(()),
WindowPosition::Within(delta) => {
// Same bit layout as `insert_with_evicted_inner`: bit `delta - 1`.
let mask = 1 << (delta - 1); if self.window & mask != 0 {
Err(SlidingWindowError::Duplicate)
} else {
Ok(())
}
}
}
}
/// Classifies `packet_number` against the current right edge.
///
/// `Within` and `Right` carry the distance from the right edge.
fn window_position(&self, packet_number: PacketNumber) -> WindowPosition {
// NOTE(review): relies on `a.checked_distance(b)` returning `a - b` when `a >= b` and
// `None` otherwise; confirm against `PacketNumber`.
if let Some(right_edge) = self.right_edge {
match right_edge.checked_distance(packet_number) {
Some(0) => WindowPosition::RightEdge,
Some(delta) if delta >= WINDOW_WIDTH => WindowPosition::Left,
Some(delta) => WindowPosition::Within(delta),
None => WindowPosition::Right(
packet_number
.checked_distance(right_edge)
.expect("packet_number must be greater than right_edge"),
),
}
} else {
WindowPosition::Empty
}
}
}
// Unit and property tests for `SlidingWindow` and `EvictedSet`.
#[cfg(test)]
mod test {
use super::*;
use crate::{packet::number::PacketNumberSpace, varint::VarInt};
use bolero::{check, generator::*};
use SlidingWindowError::*;
// Asserts that `check` and `insert` agree on the expected result for `$to_insert`, then
// verifies the resulting bitmap and right edge.
macro_rules! assert_window {
(
$window:expr, $to_insert:expr, $duplicate:expr, $expected_window:expr, $right_edge:expr
) => {{
assert_eq!($window.check($to_insert), $duplicate);
assert_eq!($window.insert($to_insert), $duplicate);
assert_eq!(
$window.window, $expected_window,
"Expected: {:b}, Actual: {:b}",
$expected_window, $window.window
);
assert_eq!($window.right_edge.unwrap(), $right_edge);
}};
}
// Walks a hand-computed sequence of in-order, out-of-order, duplicate and too-old inserts,
// checking the bitmap after each step.
#[test]
#[allow(clippy::cognitive_complexity)] fn insert() {
let space = PacketNumberSpace::ApplicationData;
let mut window = SlidingWindow::default();
let zero = space.new_packet_number(VarInt::from_u8(0));
let one = space.new_packet_number(VarInt::from_u8(1));
let two = space.new_packet_number(VarInt::from_u8(2));
let three = space.new_packet_number(VarInt::from_u8(3));
let four = space.new_packet_number(VarInt::from_u8(4));
let five = space.new_packet_number(VarInt::from_u8(5));
let six = space.new_packet_number(VarInt::from_u8(6));
let seven = space.new_packet_number(VarInt::from_u8(7));
let eight = space.new_packet_number(VarInt::from_u8(8));
let large = space.new_packet_number(VarInt::MAX);
assert_eq!(window.window, Window::default());
assert_eq!(window.right_edge, None);
assert_window!(window, zero, Ok(()), Window::default(), zero);
assert_window!(window, zero, Err(Duplicate), Window::default(), zero);
assert_window!(window, one, Ok(()), 0b1, one);
assert_window!(window, one, Err(Duplicate), 0b1, one);
assert_window!(window, two, Ok(()), 0b11, two);
assert_window!(window, five, Ok(()), 0b11100, five);
assert_window!(window, eight, Ok(()), 0b1110_0100, eight);
assert_window!(window, seven, Ok(()), 0b1110_0101, eight);
assert_window!(window, three, Ok(()), 0b1111_0101, eight);
assert_window!(window, six, Ok(()), 0b1111_0111, eight);
assert_window!(window, four, Ok(()), 0b1111_1111, eight);
assert_window!(window, seven, Err(Duplicate), 0b1111_1111, eight);
assert_window!(window, two, Err(Duplicate), 0b1111_1111, eight);
assert_window!(window, eight, Err(Duplicate), 0b1111_1111, eight);
// A jump far beyond WINDOW_WIDTH clears the bitmap entirely.
assert_window!(window, large, Ok(()), Window::default(), large);
assert_window!(window, five, Err(TooOld), Window::default(), large);
}
// Inserts 0..1000 in order. After each insert, every earlier packet number (and the new one)
// is a Duplicate while within WINDOW_WIDTH of the right edge and TooOld beyond it.
#[test]
#[cfg_attr(miri, ignore)] fn incremental_insert() {
let mut window = SlidingWindow::default();
let space = PacketNumberSpace::ApplicationData;
for right_edge in 0..1000 {
let pn = space.new_packet_number(VarInt::from_u32(right_edge));
assert_eq!(window.check(pn), Ok(()));
assert_eq!(window.insert(pn), Ok(()));
assert_eq!(window.right_edge.unwrap(), pn);
for dup in 0..=right_edge {
let expected_error = if right_edge - dup < WINDOW_WIDTH as u32 {
Err(Duplicate)
} else {
Err(TooOld)
};
let dup_pn = space.new_packet_number(VarInt::from_u32(dup));
assert_eq!(window.check(dup_pn), expected_error);
assert_eq!(window.insert(dup_pn), expected_error);
}
}
}
// Exercises the full-width shift boundary. A jump of WINDOW_WIDTH - 1 leaves the old right
// edge in the top bit; a jump of WINDOW_WIDTH clears the window.
#[test]
#[allow(clippy::cognitive_complexity)] fn insert_at_edge() {
let mut window = SlidingWindow::default();
let space = PacketNumberSpace::ApplicationData;
let zero = space.new_packet_number(VarInt::from_u8(0));
let window_width_minus_1 = space.new_packet_number(VarInt::new(WINDOW_WIDTH - 1).unwrap());
let window_width = window_width_minus_1.next().unwrap();
assert_window!(window, zero, Ok(()), Window::default(), zero);
assert_window!(
window,
window_width_minus_1,
Ok(()),
(1_u128) << 127,
window_width_minus_1
);
assert_window!(
window,
window_width_minus_1,
Err(Duplicate),
(1_u128) << 127,
window_width_minus_1
);
assert_window!(window, window_width, Ok(()), 0b1, window_width);
window = SlidingWindow::default();
assert_window!(window, zero, Ok(()), Window::default(), zero);
assert_window!(
window,
window_width,
Ok(()),
Window::default(),
window_width
);
assert_window!(
window,
window_width,
Err(Duplicate),
Window::default(),
window_width
);
}
// Inserts a packet number more than 2^32 above the right edge and checks that the window
// ends up cleared.
#[test]
fn delta_larger_than_32_bits() {
let mut window = SlidingWindow::default();
let space = PacketNumberSpace::ApplicationData;
let zero = space.new_packet_number(VarInt::from_u8(0));
let large = space.new_packet_number(VarInt::new((1 << 32) + 1).unwrap());
assert_eq!(window.check(zero), Ok(()));
assert_eq!(window.insert(zero), Ok(()));
assert_eq!(window.check(large), Ok(()));
assert_eq!(window.insert(large), Ok(()));
assert_eq!(window.check(large), Err(Duplicate));
assert_eq!(window.insert(large), Err(Duplicate));
assert_eq!(window.window, 0b0);
}
// The first insert into an empty window need not start at zero.
#[test]
fn insert_into_empty() {
let pn = VarInt::from_u32(256);
let mut window = SlidingWindow::default();
let space = PacketNumberSpace::ApplicationData;
let packet_number = space.new_packet_number(pn);
assert!(window.insert(packet_number).is_ok());
}
// Property test (also a kani proof harness under `cfg(kani)`). After one insert into an
// empty window, that packet number is a Duplicate and any other distinct one is not.
#[test]
#[cfg_attr(kani, kani::proof, kani::unwind(130), kani::solver(kissat))]
#[cfg_attr(miri, ignore)] fn insert_test() {
let gen = produce::<(VarInt, VarInt)>().filter_gen(|(a, b)| a != b);
check!()
.with_generator(gen)
.cloned()
.for_each(|(pn, other_pn)| {
let mut window = SlidingWindow::default();
let space = PacketNumberSpace::ApplicationData;
let packet_number = space.new_packet_number(pn);
let other_packet_number = space.new_packet_number(other_pn);
assert!(window.insert(packet_number).is_ok());
assert_eq!(Err(Duplicate), window.check(packet_number));
assert_ne!(Err(Duplicate), window.check(other_packet_number));
});
}
// The jump to VarInt::MAX evicts exactly the never-inserted packet numbers 1, 3 and 6, in
// ascending order.
#[test]
#[allow(clippy::cognitive_complexity)] fn insert_evicted() {
let space = PacketNumberSpace::ApplicationData;
let mut window = SlidingWindow::default();
let zero = space.new_packet_number(VarInt::from_u8(0));
let one = space.new_packet_number(VarInt::from_u8(1));
let two = space.new_packet_number(VarInt::from_u8(2));
let three = space.new_packet_number(VarInt::from_u8(3));
let four = space.new_packet_number(VarInt::from_u8(4));
let five = space.new_packet_number(VarInt::from_u8(5));
let six = space.new_packet_number(VarInt::from_u8(6));
let seven = space.new_packet_number(VarInt::from_u8(7));
let eight = space.new_packet_number(VarInt::from_u8(8));
let large = space.new_packet_number(VarInt::MAX);
assert!(window.insert(zero).is_ok());
assert!(window.insert(two).is_ok());
assert!(window.insert(four).is_ok());
assert!(window.insert(five).is_ok());
assert!(window.insert(seven).is_ok());
assert!(window.insert(eight).is_ok());
let mut evicted = window.insert_with_evicted(large).unwrap();
assert_eq!(evicted.next(), Some(one));
assert_eq!(evicted.next(), Some(three));
assert_eq!(evicted.next(), Some(six));
assert_eq!(evicted.next(), None);
}
}