use crate::packet::number::{PacketNumber, PacketNumberRange, PacketNumberSpace};
use alloc::{boxed::Box, vec::Vec};
use core::fmt;
/// A map from packet numbers to values, stored in a ring buffer of slots that covers the
/// contiguous window `start..=end` of packet numbers.
#[derive(Clone)]
pub struct Map<V> {
    // ring buffer of slots; packet number `pn` lives at `(index + (pn - start)) % values.len()`
    values: Box<[Option<V>]>,
    // lowest packet number in the window (left over from earlier use while the map is empty)
    start: PacketNumber,
    // highest packet number in the window (left over from earlier use while the map is empty)
    end: PacketNumber,
    // slot holding `start`; equal to `values.len()` when the map is empty
    index: usize,
}
impl<V: fmt::Debug> fmt::Debug for Map<V> {
    /// Formats the occupied entries as a `{packet_number: value, ..}` map, in the order
    /// produced by [`Map::iter`] (ascending packet number).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut map = f.debug_map();
        for (packet_number, value) in self.iter() {
            map.entry(&packet_number, value);
        }
        map.finish()
    }
}
/// Number of slots a freshly created [`Map`] starts with.
const DEFAULT_CAPACITY: usize = 8;

impl<V> Default for Map<V> {
    /// Creates an empty map with exactly [`DEFAULT_CAPACITY`] unoccupied slots.
    fn default() -> Self {
        // Placeholder bounds: the first insertion into an empty map overwrites both.
        let base = PacketNumberSpace::Initial.new_packet_number(0u8.into());

        // Size the slot table explicitly. `Vec::with_capacity` only guarantees *at least* the
        // requested capacity, so filling until `len == capacity()` could over-allocate, and for
        // a zero-sized `Option<V>` (capacity `usize::MAX`) it would effectively never finish.
        let mut values = Vec::with_capacity(DEFAULT_CAPACITY);
        values.resize_with(DEFAULT_CAPACITY, || None);
        let values = values.into_boxed_slice();

        // An `index` equal to the slot count marks the map as empty.
        let index = values.len();

        Self {
            values,
            start: base,
            end: base,
            index,
        }
    }
}
impl<V> Map<V> {
    /// Inserts `value` for `packet_number`.
    ///
    /// `packet_number` must be strictly greater than every packet number already in the map
    /// (checked with a `debug_assert!`). The slot table grows if the new entry does not fit in
    /// the current window.
    pub fn insert(&mut self, packet_number: PacketNumber, value: V) {
        if self.is_empty() {
            self.insert_first(packet_number, value);
            return;
        }

        debug_assert!(
            packet_number > self.start && packet_number > self.end,
            "packet numbers should be monotonic: {:?} > {:?} && {:?}",
            packet_number,
            self.start,
            self.end
        );

        let index = self.insert_index(packet_number);
        self.values[index] = Some(value);
        self.end = packet_number;
        self.invariants();
    }

    /// Inserts `value` for `packet_number`, or calls `update` on the existing value if that
    /// packet number is already occupied (in which case `value` is dropped).
    ///
    /// `packet_number` must not be lower than the map's current start (checked with a
    /// `debug_assert!`). Unlike [`Self::insert`], it may fall inside the window as well as
    /// past it.
    pub fn insert_or_update<F: FnOnce(&mut V)>(
        &mut self,
        packet_number: PacketNumber,
        value: V,
        update: F,
    ) {
        if self.is_empty() {
            self.insert_first(packet_number, value);
            return;
        }

        debug_assert!(
            packet_number >= self.start,
            "packet numbers should be monotonic: {:?} >= {:?}",
            packet_number,
            self.start,
        );

        let index = self.insert_index(packet_number);
        let entry = &mut self.values[index];
        if let Some(prev) = entry.as_mut() {
            update(prev);
        } else {
            *entry = Some(value);
        }
        self.end = self.end.max(packet_number);
        self.invariants();
    }

    /// Returns a reference to the value stored for `packet_number`, if any.
    #[inline]
    pub fn get(&self, packet_number: PacketNumber) -> Option<&V> {
        let index = self.pn_index(packet_number)?;
        self.values[index].as_ref()
    }

    /// Removes and returns the value stored for `packet_number`, if any.
    ///
    /// If the removed entry was at either bound, that bound moves inward to the nearest
    /// remaining entry. Removing the only entry empties the map.
    pub fn remove(&mut self, packet_number: PacketNumber) -> Option<V> {
        let index = self.pn_index(packet_number)?;
        let info = self.values[index].take()?;

        match (self.start == packet_number, self.end == packet_number) {
            (true, true) => {
                // the window was exactly this one packet number
                self.logical_clear();
            }
            (true, false) => {
                // `packet_number < self.end`, so a successor exists
                self.set_start(packet_number.next().unwrap());
            }
            (false, true) => {
                // `packet_number > self.start`, so a predecessor exists
                self.set_end(packet_number.prev().unwrap());
            }
            (false, false) => {
                // interior removal: neither bound changes
            }
        }

        self.invariants();
        Some(info)
    }

    /// Removes every entry whose packet number lies in `range`.
    ///
    /// Returns an iterator over the removed `(packet_number, value)` pairs. Entries not
    /// consumed from the iterator are removed when it is dropped.
    #[inline]
    pub fn remove_range(&mut self, range: PacketNumberRange) -> RemoveIter<'_, V> {
        RemoveIter::new(self, range)
    }

    /// Returns the inclusive range `start..=end` of tracked packet numbers.
    ///
    /// While the map is empty, the bounds are left over from earlier use. Check
    /// [`Self::is_empty`] first.
    #[inline]
    pub fn get_range(&self) -> PacketNumberRange {
        PacketNumberRange::new(self.start, self.end)
    }

    /// Iterates over the occupied entries in ascending packet number order.
    #[inline]
    pub fn iter(&self) -> Iter<'_, V> {
        Iter::new(self)
    }

    /// Mutably iterates over the occupied entries in ascending packet number order.
    #[inline]
    pub fn iter_mut(&mut self) -> IterMut<'_, V> {
        IterMut::new(self)
    }

    /// Returns `true` when the map holds no entries.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.index == self.values.len()
    }

    /// Drops every stored value and leaves the map empty.
    #[inline]
    pub fn clear(&mut self) {
        if self.is_empty() {
            return;
        }

        for pn in PacketNumberRange::new(self.start, self.end) {
            if let Some(index) = self.pn_index(pn) {
                self.values[index] = None;
            }
        }

        self.logical_clear();
        self.invariants();
    }

    /// Marks the map as empty without touching any slot.
    ///
    /// Callers must already have emptied the slots, or be about to empty them.
    fn logical_clear(&mut self) {
        self.index = self.values.len();
    }

    #[cfg(not(test))]
    #[inline(always)]
    fn invariants(&self) {}

    /// Test-only consistency checks on the slot table and bookkeeping.
    #[cfg(test)]
    fn invariants(&self) {
        if self.is_empty() {
            for (idx, slot) in self.values.iter().enumerate() {
                assert!(
                    slot.is_none(),
                    "map is_empty() but slot {} has a value; index={}, len={}",
                    idx,
                    self.index,
                    self.values.len()
                );
            }
        }

        assert!(
            self.index <= self.values.len(),
            "index out of bounds: index={}, len={}",
            self.index,
            self.values.len()
        );

        if !self.is_empty() {
            assert!(
                self.values[self.index].is_some(),
                "map not empty but index slot is None; index={}, start={:?}, end={:?}",
                self.index,
                self.start,
                self.end
            );
        }
    }

    /// Returns the slot index for `packet_number`.
    ///
    /// Returns `None` when the map is empty or `packet_number` lies outside `start..=end`.
    #[inline]
    fn pn_index(&self, packet_number: PacketNumber) -> Option<usize> {
        if self.is_empty() {
            return None;
        }

        if packet_number > self.end {
            return None;
        }

        let offset = packet_number.checked_distance(self.start)?;
        let index = self.index.checked_add(offset as usize)?;
        let index = index % self.values.len();
        Some(index)
    }

    /// Stores the first entry of an empty map.
    ///
    /// The window collapses to `packet_number` and its value goes in slot 0.
    fn insert_first(&mut self, packet_number: PacketNumber, value: V) {
        debug_assert!(self.is_empty());
        self.start = packet_number;
        self.end = packet_number;
        self.values[0] = Some(value);
        self.index = 0;
        self.invariants();
    }

    /// Returns the slot for `packet_number`, which must be `>= self.start`.
    ///
    /// Grows the table first if the packet number lies beyond the current capacity.
    fn insert_index(&mut self, packet_number: PacketNumber) -> usize {
        let distance = (packet_number.as_u64() - self.start.as_u64()) as usize;
        if distance >= self.values.len() {
            self.resize(distance);
            // `resize` rotates the window so that `start` now lives at slot 0
            distance
        } else {
            (self.index + distance) % self.values.len()
        }
    }

    /// Moves `start` to the first occupied packet number at or after `packet_number`.
    #[inline]
    fn set_start(&mut self, packet_number: PacketNumber) {
        debug_assert!(!self.is_empty());
        debug_assert!(packet_number >= self.start);
        debug_assert!(packet_number <= self.end);

        for packet_number in PacketNumberRange::new(packet_number, self.end) {
            // a single lookup resolves both "in bounds" and "occupied"
            let occupied = self
                .pn_index(packet_number)
                .filter(|&index| self.values[index].is_some());

            if let Some(index) = occupied {
                self.index = index;
                self.start = packet_number;
                debug_assert!(self.start <= self.end);
                debug_assert_eq!(self.pn_index(packet_number), Some(index));
                return;
            }
        }

        unreachable!("could not find an occupied entry; map should be empty");
    }

    /// Moves `end` to the last occupied packet number at or before `packet_number`.
    #[inline]
    fn set_end(&mut self, packet_number: PacketNumber) {
        debug_assert!(!self.is_empty());
        debug_assert!(packet_number >= self.start);
        debug_assert!(packet_number <= self.end);

        for packet_number in PacketNumberRange::new(self.start, packet_number).rev() {
            if self.get(packet_number).is_some() {
                self.end = packet_number;
                debug_assert!(self.start <= self.end);
                return;
            }
        }

        unreachable!("could not find an occupied entry; map should be empty");
    }

    /// Grows the slot table to the first doubling of its current length that exceeds `len`.
    ///
    /// The live window is rotated so that `start` ends up at slot 0.
    fn resize(&mut self, len: usize) {
        // `max(1)` keeps a zero-length table from doubling forever; `checked_mul` turns an
        // impossible-to-allocate size into a clear panic instead of wrapping to 0
        let mut new_len = self.values.len().max(1);
        loop {
            new_len = new_len
                .checked_mul(2)
                .expect("packet number map capacity overflow");
            if len < new_len {
                break;
            }
        }

        let mut values = Vec::with_capacity(new_len);
        // `[index..]` holds the oldest entries and `[..index]` the ones that wrapped around
        let (wrapped, oldest) = self.values.split_at_mut(self.index);
        values.extend(oldest.iter_mut().map(Option::take));
        values.extend(wrapped.iter_mut().map(Option::take));
        // pad to exactly `new_len`: `with_capacity` only guarantees *at least* `new_len`
        values.resize_with(new_len, || None);

        self.index = 0;
        self.values = values.into_boxed_slice();
    }
}
macro_rules! impl_iter {
    ($name:ident, [$($lt:tt)*], $split:ident) => {
        /// Iterator over the occupied entries of a [`Map`], in ascending packet number order.
        #[derive(Debug)]
        pub struct $name<'a, V> {
            iter: core::iter::Chain<core::slice::$name<'a, Option<V>>, core::slice::$name<'a, Option<V>>>,
            packet_number: Option<PacketNumber>,
            remaining: usize,
        }

        impl<'a, V> $name<'a, V> {
            #[inline]
            fn new(packets: $($lt)* Map<V>) -> Self {
                // read the window bookkeeping before `values` is borrowed for the split
                let first = packets.start;
                let last = packets.end;
                let offset = packets.index;
                let slots = packets.values.len();

                // number of packet numbers spanned by the window; nothing to walk when empty
                let remaining = if packets.is_empty() {
                    0
                } else {
                    (last.as_u64() - first.as_u64()) as usize + 1
                };
                debug_assert!(remaining <= slots);

                // the window starts at `offset` and wraps around to the front of the slice
                let (wrapped, oldest) = packets.values.$split(offset);

                Self {
                    iter: oldest.into_iter().chain(wrapped),
                    packet_number: Some(first),
                    remaining,
                }
            }
        }

        impl<'a, V> Iterator for $name<'a, V> {
            type Item = (PacketNumber, $($lt)* V);

            #[inline]
            fn next(&mut self) -> Option<Self::Item> {
                // advance one slot per packet number, skipping the unoccupied ones
                loop {
                    if self.remaining == 0 {
                        return None;
                    }
                    self.remaining -= 1;

                    let current = self.packet_number?;
                    self.packet_number = current.next();

                    if let Some(Some(value)) = self.iter.next() {
                        return Some((current, value));
                    }
                }
            }
        }
    };
}

impl_iter!(Iter, [&'a], split_at);
impl_iter!(IterMut, [&'a mut], split_at_mut);
/// Iterator returned by [`Map::remove_range`] that takes the values for a range of packet
/// numbers out of the map.
///
/// The map's bounds are adjusted when the iterator is created. The values are taken out of
/// their slots as the iterator advances, and any still present when it is dropped are taken
/// then.
#[derive(Debug)]
pub struct RemoveIter<'a, V> {
    // map whose slots are being drained
    packets: &'a mut Map<V>,
    // packet number of the slot at `index`; stays `None` when there is nothing to remove
    packet_number: Option<PacketNumber>,
    // slot index of the next packet number to visit
    index: usize,
    // packet numbers (occupied or not) still to visit
    remaining: usize,
}
impl<'a, V> RemoveIter<'a, V> {
    /// Prepares an iterator that removes every entry of `packets` whose packet number lies in
    /// `range`.
    ///
    /// The map's bounds are adjusted before this returns. The values themselves are taken out
    /// of their slots as the iterator advances.
    #[inline]
    fn new(packets: &'a mut Map<V>, range: PacketNumberRange) -> Self {
        use core::cmp::Ordering::{Equal, Greater, Less};

        let map_start = packets.start;
        let map_end = packets.end;
        let start_index = packets.index;

        let mut iter = Self {
            packets,
            packet_number: None,
            index: start_index,
            remaining: 0,
        };

        // Nothing to remove: the map is empty, or the range lies entirely outside
        // `map_start..=map_end`. Emptiness is checked first because the bounds are only
        // left-over values while the map is empty.
        if iter.packets.is_empty() || range.end() < map_start || range.start() > map_end {
            return iter;
        }

        let covers_start = matches!(range.start().cmp(&map_start), Less | Equal);
        let covers_end = matches!(range.end().cmp(&map_end), Equal | Greater);

        let (first, last) = match (covers_start, covers_end) {
            (true, true) => {
                // the range spans the whole window
                iter.packets.logical_clear();
                (map_start, map_end)
            }
            (true, false) => {
                // a prefix is removed; `last < map_end`, so a successor exists
                let last = range.end();
                iter.packets.set_start(last.next().unwrap());
                (map_start, last)
            }
            (false, true) => {
                // a suffix is removed; resolve the slot before `set_end` shrinks the window
                let first = range.start();
                iter.index = iter
                    .packets
                    .pn_index(first)
                    .expect("packet number bounds have already been checked");
                iter.packets.set_end(first.prev().unwrap());
                (first, map_end)
            }
            (false, false) => {
                // a strictly interior span is removed; the bounds are unaffected
                let first = range.start();
                iter.index = iter
                    .packets
                    .pn_index(first)
                    .expect("packet number bounds have already been checked");
                (first, range.end())
            }
        };

        iter.packet_number = Some(first);
        iter.remaining = (last.as_u64() - first.as_u64()) as usize + 1;
        debug_assert!(iter.remaining <= iter.packets.values.len());
        iter
    }
}
impl<V> Iterator for RemoveIter<'_, V> {
    type Item = (PacketNumber, V);

    /// Takes the value out of the next occupied slot in the range, skipping empty slots.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // the slot table cannot be resized while this iterator holds the map mutably
        let slot_count = self.packets.values.len();

        loop {
            if self.remaining == 0 {
                return None;
            }
            self.remaining -= 1;

            let packet_number = self.packet_number?;
            self.packet_number = packet_number.next();

            let slot = self.index;
            self.index = (slot + 1) % slot_count;

            if let Some(value) = self.packets.values[slot].take() {
                return Some((packet_number, value));
            }
        }
    }
}
impl<V> Drop for RemoveIter<'_, V> {
    /// Takes, and drops, every value still left in the range, so none stays behind in the map's
    /// slots.
    fn drop(&mut self) {
        self.by_ref().for_each(core::mem::drop);
    }
}
// Unit tests, plus a differential test that replays generated operations against `Map` and a
// `BTreeMap` oracle.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        packet::number::{PacketNumber, PacketNumberRange, PacketNumberSpace},
        varint::VarInt,
    };
    use alloc::collections::BTreeMap;
    use bolero::{check, generator::*};

    type TestMap = Map<u64>;

    // Inserted values can be read back by packet number, a packet number never inserted cannot,
    // and every pair yielded by `iter()` agrees with `get()`.
    #[test]
    fn insert_get_range() {
        let mut sent_packets = TestMap::default();
        let packet_number_1 = PacketNumberSpace::Initial.new_packet_number(VarInt::from_u8(1));
        let packet_number_2 = packet_number_1.next().unwrap();
        let packet_number_3 = packet_number_2.next().unwrap();
        sent_packets.insert(packet_number_1, 1);
        sent_packets.insert(packet_number_2, 2);
        assert!(sent_packets.get(packet_number_1).is_some());
        assert!(sent_packets.get(packet_number_2).is_some());
        assert!(sent_packets.get(packet_number_3).is_none());
        assert_eq!(sent_packets.get(packet_number_1).unwrap(), &1);
        assert_eq!(sent_packets.get(packet_number_2).unwrap(), &2);
        sent_packets.insert(packet_number_3, 3);
        assert!(sent_packets.get(packet_number_3).is_some());
        assert_eq!(sent_packets.get(packet_number_3).unwrap(), &3);
        for (packet_number, sent_packet_info) in sent_packets.iter() {
            assert_eq!(sent_packets.get(packet_number).unwrap(), sent_packet_info);
        }
    }

    // `remove` returns the stored value once and `None` afterwards.
    #[test]
    fn remove() {
        let mut sent_packets = TestMap::default();
        let packet_number = PacketNumberSpace::Initial.new_packet_number(VarInt::from_u8(1));
        sent_packets.insert(packet_number, 1);
        assert!(sent_packets.get(packet_number).is_some());
        assert_eq!(sent_packets.get(packet_number).unwrap(), &1);
        assert_eq!(Some(1), sent_packets.remove(packet_number));
        assert!(sent_packets.get(packet_number).is_none());
        assert_eq!(None, sent_packets.remove(packet_number));
    }

    // A default map is empty and stops being empty after one insert.
    #[test]
    fn empty() {
        let mut sent_packets = TestMap::default();
        assert!(sent_packets.is_empty());
        let packet_number = PacketNumberSpace::Initial.new_packet_number(VarInt::from_u8(1));
        sent_packets.insert(packet_number, 1);
        assert!(!sent_packets.is_empty());
    }

    // The `wrong_packet_space_*` tests use an ApplicationData packet number against a map
    // holding an Initial one and expect a panic. NOTE(review): presumably the panic comes
    // from the space check in `PacketNumber` comparison/distance; confirm in `packet::number`.
    #[test]
    #[should_panic]
    fn wrong_packet_space_on_insert() {
        let mut sent_packets = new_sent_packets(PacketNumberSpace::Initial);
        let packet_number =
            PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(1));
        sent_packets.insert(packet_number, 1);
    }

    #[test]
    #[should_panic]
    fn wrong_packet_space_on_get() {
        let sent_packets = new_sent_packets(PacketNumberSpace::Initial);
        let packet_number =
            PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(1));
        sent_packets.get(packet_number);
    }

    #[test]
    #[should_panic]
    fn wrong_packet_space_on_remove_range() {
        let mut sent_packets = new_sent_packets(PacketNumberSpace::Initial);
        let packet_number_start =
            PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(1));
        let packet_number_end =
            PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(2));
        sent_packets
            .remove_range(PacketNumberRange::new(
                packet_number_start,
                packet_number_end,
            ))
            .for_each(|_| ());
    }

    #[test]
    #[should_panic]
    fn wrong_packet_space_on_remove() {
        let mut sent_packets = new_sent_packets(PacketNumberSpace::Initial);
        let packet_number =
            PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(1));
        sent_packets.remove(packet_number);
    }

    // Builds a map holding one entry: packet number 0 of `space` with value 0.
    fn new_sent_packets(space: PacketNumberSpace) -> TestMap {
        let mut sent_packets = TestMap::default();
        let packet_number = space.new_packet_number(VarInt::from_u8(0));
        sent_packets.insert(packet_number, 0);
        sent_packets
    }

    // Operations generated for the differential test. `Insert` and `Skip` advance a running
    // packet number, with and without inserting it. `Remove` and `RemoveRange` target arbitrary
    // packet numbers. `Clear` empties the map.
    #[derive(Clone, Copy, Debug, TypeGenerator)]
    enum Operation {
        Insert,
        Skip,
        Remove(VarInt),
        RemoveRange(VarInt, VarInt),
        Clear,
    }

    // Differential test: applies `ops` to both `Map` (the subject) and a `BTreeMap` (the
    // oracle), and checks that they agree after every operation.
    fn model(ops: &[Operation]) {
        // next packet number used by `Insert` / `Skip`
        let mut current = PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(0));
        #[derive(Debug, Default)]
        struct Model {
            subject: TestMap,
            oracle: BTreeMap<PacketNumber, u64>,
        }
        impl Model {
            // Inserts into both maps, using the packet number itself as the value.
            pub fn insert(&mut self, packet_number: PacketNumber) {
                let value = packet_number.as_u64();
                self.subject.insert(packet_number, value);
                self.oracle.insert(packet_number, value);
                self.check_consistency();
            }
            pub fn remove(&mut self, packet_number: PacketNumber) {
                assert_eq!(
                    self.subject.remove(packet_number),
                    self.oracle.remove(&packet_number)
                );
                self.check_consistency();
            }
            pub fn remove_range(&mut self, range: PacketNumberRange) {
                // Clamp the range to the subject's current bounds, collapsing it to a single
                // packet number when nothing overlaps. NOTE(review): presumably this keeps the
                // oracle loop below from walking an arbitrarily large generated range.
                let range = if self.subject.is_empty() {
                    PacketNumberRange::new(range.start(), range.start())
                } else {
                    let start = range.start().max(self.subject.start);
                    let end = range.end().min(self.subject.end);
                    if start > end {
                        PacketNumberRange::new(start, start)
                    } else {
                        PacketNumberRange::new(start, end)
                    }
                };
                let actual: Vec<_> = self.subject.remove_range(range).collect();
                let mut expected = vec![];
                for pn in range {
                    if let Some(value) = self.oracle.remove(&pn) {
                        expected.push((pn, value));
                    }
                }
                assert_eq!(expected, actual);
                self.check_consistency();
            }
            pub fn clear(&mut self) {
                self.subject.clear();
                self.oracle.clear();
                self.check_consistency();
            }
            // Checks that both maps yield the same entries in the same order. Also checks that
            // the subject's slot table holds exactly as many values as the oracle has entries,
            // which catches values left behind in storage.
            fn check_consistency(&self) {
                let mut subject = self.subject.iter();
                let mut oracle = self.oracle.iter();
                loop {
                    match (subject.next(), oracle.next()) {
                        (Some(actual), Some((expected_pn, expected_info))) => {
                            assert_eq!((*expected_pn, expected_info), actual);
                        }
                        (None, None) => break,
                        (actual, expected) => {
                            panic!("expected: {expected:?}, actual: {actual:?}");
                        }
                    }
                }
                let actual_stored_count =
                    self.subject.values.iter().filter(|v| v.is_some()).count();
                let expected_count = self.oracle.len();
                assert_eq!(
                    actual_stored_count,
                    expected_count,
                    "Memory leak detected: {} values stored but only {} in oracle (leaked {} values)",
                    actual_stored_count,
                    expected_count,
                    actual_stored_count.saturating_sub(expected_count)
                );
            }
        }
        let mut model = Model::default();
        for op in ops.iter().copied() {
            match op {
                Operation::Insert => {
                    model.insert(current);
                    current = current.next().unwrap();
                }
                Operation::Skip => {
                    current = current.next().unwrap();
                }
                Operation::Remove(pn) => {
                    let pn = PacketNumberSpace::ApplicationData.new_packet_number(pn);
                    model.remove(pn);
                }
                Operation::RemoveRange(start, end) => {
                    // normalize so that start <= end
                    let (start, end) = if start > end {
                        (end, start)
                    } else {
                        (start, end)
                    };
                    let start = PacketNumberSpace::ApplicationData.new_packet_number(start);
                    let end = PacketNumberSpace::ApplicationData.new_packet_number(end);
                    let range = PacketNumberRange::new(start, end);
                    model.remove_range(range);
                }
                Operation::Clear => {
                    model.clear();
                }
            }
        }
    }

    // Runs `model` on bolero-generated operation sequences.
    #[test]
    fn differential_test() {
        check!()
            .with_type::<Vec<Operation>>()
            .for_each(|ops| model(ops))
    }

    // A single insert of an arbitrary packet number can be read back. Under `cfg(kani)` the
    // attribute below also registers this as a kani proof harness.
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(9), kani::solver(kissat))]
    fn insert_value() {
        check!().with_type().cloned().for_each(|pn| {
            let space = PacketNumberSpace::ApplicationData;
            let mut map = Map::default();
            assert!(map.is_empty());
            let pn = space.new_packet_number(pn);
            map.insert(pn, ());
            assert!(map.get(pn).is_some());
            assert!(!map.is_empty());
        });
    }

    // `clear` must drop the stored values themselves, not just mark the map empty. The map must
    // also stay usable afterwards.
    #[test]
    fn clear_actually_clears_values() {
        let space = PacketNumberSpace::ApplicationData;
        let mut map = Map::default();
        for i in 0u8..5u8 {
            let pn = space.new_packet_number(i.into());
            map.insert(pn, i);
        }
        assert!(!map.is_empty());
        assert_eq!(map.iter().count(), 5);
        map.clear();
        assert!(map.is_empty());
        assert_eq!(map.iter().count(), 0);
        // every slot in the backing storage must be vacated
        for (idx, slot) in map.values.iter().enumerate() {
            assert!(
                slot.is_none(),
                "After clear(), slot {} should be None but contains a value",
                idx
            );
        }
        let pn = space.new_packet_number(100u8.into());
        map.insert(pn, 100);
        assert!(!map.is_empty());
        assert_eq!(map.get(pn), Some(&100));
    }
}