use std::ops::Range;
use crate::{frame::NewConnectionId, ConnectionId, ResetToken};
/// Entry stored in a `CidQueue` slot: a connection ID together with its reset token, which is
/// `None` for a CID that was installed without one (see `CidQueue::new`).
type CidData = (ConnectionId, Option<ResetToken>);
/// Fixed-capacity queue of connection IDs, stored in a ring buffer of `Self::LEN` slots.
///
/// Slots may be empty (`None`); the slot at `cursor` holds the currently active CID.
#[derive(Debug)]
pub struct CidQueue {
/// Ring buffer of optional entries, addressed modulo `Self::LEN` starting from `cursor`
buffer: [Option<CidData>; Self::LEN],
/// Index into `buffer` of the active CID
cursor: usize,
/// Sequence number of the entry at `buffer[cursor]`, i.e. of the active CID
offset: u64,
}
impl CidQueue {
    /// Number of slots in the ring buffer
    pub const LEN: usize = 5;

    /// Builds a queue whose only entry is `cid`, stored without a reset token at sequence
    /// number 0.
    pub fn new(cid: ConnectionId) -> Self {
        let mut queue = Self {
            buffer: [None; Self::LEN],
            cursor: 0,
            offset: 0,
        };
        queue.buffer[0] = Some((cid, None));
        queue
    }

    /// Records the CID carried by a `NewConnectionId` frame and applies its `retire_prior_to`.
    ///
    /// Returns `Ok(None)` when the active CID is unchanged. When the frame retires the active
    /// CID, returns `Ok(Some((retired, token)))`, where `retired` is the range of sequence
    /// numbers being retired (at most `Self::LEN` of them) and `token` is the reset token of
    /// the CID that became active.
    ///
    /// # Errors
    ///
    /// * `InsertError::Retired` if `cid.sequence` is below the active sequence number.
    /// * `InsertError::ExceedsLimit` if `cid.sequence` lies `Self::LEN` or more positions past
    ///   the first sequence number that remains after retirement.
    ///
    /// # Panics
    ///
    /// Panics if no CID remains after retirement, or if the CID that becomes active has no
    /// reset token.
    pub fn insert(
        &mut self,
        cid: NewConnectionId,
    ) -> Result<Option<(Range<u64>, ResetToken)>, InsertError> {
        let capacity = Self::LEN as u64;

        // Distance of the new CID from the active one, in sequence numbers
        let distance = cid
            .sequence
            .checked_sub(self.offset)
            .ok_or(InsertError::Retired)?;

        // How many sequence numbers, counted from the active one, this frame retires
        let newly_retired = cid.retire_prior_to.saturating_sub(self.offset);
        if distance >= capacity + newly_retired {
            return Err(InsertError::ExceedsLimit);
        }

        // A repeat of the active CID's sequence number leaves the queue untouched
        if distance == 0 && self.buffer[self.cursor].is_some() {
            return Ok(None);
        }

        // Empty every slot covered by the retirement; one full lap clears the whole buffer
        let base = self.cursor;
        let cleared = newly_retired.min(capacity) as usize;
        for step in 0..cleared {
            self.buffer[(base + step) % Self::LEN] = None;
        }

        // Store the new CID in the slot matching its distance from the old cursor
        let slot = ((base as u64 + distance) % capacity) as usize;
        self.buffer[slot] = Some((cid.id, Some(cid.reset_token)));

        if newly_retired == 0 {
            return Ok(None);
        }

        // The active CID was retired: jump past the retired span, then advance to the first
        // occupied slot, which becomes the new active CID.
        self.cursor = ((base as u64 + newly_retired) % capacity) as usize;
        let (skipped, (_, token)) = self
            .iter()
            .next()
            .expect("it is impossible to retire a CID without supplying a new one");
        self.cursor = (self.cursor + skipped) % Self::LEN;

        let previous_offset = self.offset;
        self.offset = cid.retire_prior_to + skipped as u64;

        // The reported range is capped at `Self::LEN` sequence numbers past the old offset
        let retired_end = self.offset.min(previous_offset + capacity);
        Ok(Some((
            previous_offset..retired_end,
            token.expect("non-initial CID missing reset token"),
        )))
    }

    /// Retires the active CID and switches to the next stored one.
    ///
    /// Returns `None`, leaving the queue unchanged, when fewer than two CIDs are stored.
    /// Otherwise returns the reset token of the newly active CID and the range of sequence
    /// numbers from the previously active CID up to, but excluding, the new one.
    ///
    /// # Panics
    ///
    /// Panics if the CID that becomes active has no reset token.
    pub fn next(&mut self) -> Option<(ResetToken, Range<u64>)> {
        // The first yielded entry is the active CID itself; the second is its successor
        let (advance, (_, token)) = self.iter().nth(1)?;
        let previous_offset = self.offset;

        self.buffer[self.cursor] = None;
        self.offset = previous_offset + advance as u64;
        self.cursor = (self.cursor + advance) % Self::LEN;

        Some((token.unwrap(), previous_offset..self.offset))
    }

    /// Yields every occupied slot in ring order starting at `cursor`, each paired with its
    /// distance (in slots) from `cursor`.
    fn iter(&self) -> impl Iterator<Item = (usize, CidData)> + '_ {
        let start = self.cursor;
        (0..Self::LEN).filter_map(move |step| {
            let entry = self.buffer[(start + step) % Self::LEN]?;
            Some((step, entry))
        })
    }

    /// Overwrites the active slot with `cid` and no reset token.
    ///
    /// Debug builds assert that the active sequence number is still 0.
    pub fn update_initial_cid(&mut self, cid: ConnectionId) {
        debug_assert_eq!(self.offset, 0);
        let active_slot = &mut self.buffer[self.cursor];
        *active_slot = Some((cid, None));
    }

    /// Returns the active connection ID.
    ///
    /// # Panics
    ///
    /// Panics if the slot at `cursor` is empty.
    pub fn active(&self) -> ConnectionId {
        let (id, _) = self.buffer[self.cursor].unwrap();
        id
    }

    /// Returns the sequence number of the active connection ID.
    pub fn active_seq(&self) -> u64 {
        self.offset
    }
}
/// Reasons `CidQueue::insert` can reject a connection ID
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum InsertError {
/// The CID's sequence number is lower than that of the active CID
Retired,
/// The CID's sequence number is too far ahead of the active CID to fit in the queue
ExceedsLimit,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NewConnectionId` frame with fixed ID and reset-token bytes and the given
    /// sequence number and `retire_prior_to`.
    fn cid(sequence: u64, retire_prior_to: u64) -> NewConnectionId {
        NewConnectionId {
            sequence,
            id: ConnectionId::new(&[0xAB; 8]),
            reset_token: ResetToken::from([0xCD; crate::RESET_TOKEN_SIZE]),
            retire_prior_to,
        }
    }

    /// The CID every test queue starts with.
    fn initial_cid() -> ConnectionId {
        ConnectionId::new(&[0xFF; 8])
    }

    /// With every slot filled, `next` advances one sequence number at a time.
    #[test]
    fn next_dense() {
        let mut q = CidQueue::new(initial_cid());
        assert!(q.next().is_none());
        assert!(q.next().is_none());

        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in 1..CidQueue::LEN as u64 {
            let (_, retire) = q.next().unwrap();
            assert_eq!(q.active_seq(), i);
            assert_eq!(retire.end - retire.start, 1);
        }
        assert!(q.next().is_none());
    }

    /// With only even sequence numbers present, `next` skips the gaps and reports them as
    /// part of the retired range.
    #[test]
    fn next_sparse() {
        let mut q = CidQueue::new(initial_cid());
        let seqs = (1..CidQueue::LEN as u64).filter(|x| x % 2 == 0);
        for i in seqs.clone() {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in seqs {
            let (_, retire) = q.next().unwrap();
            assert_eq!(q.active_seq(), i);
            assert_eq!(retire, (q.active_seq().saturating_sub(2))..q.active_seq());
        }
        assert!(q.next().is_none());
    }

    /// Inserting past the end of the buffer after some CIDs were consumed wraps around.
    #[test]
    fn wrap() {
        let mut q = CidQueue::new(initial_cid());

        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        for _ in 1..(CidQueue::LEN as u64 - 1) {
            q.next().unwrap();
        }
        for i in CidQueue::LEN as u64..(CidQueue::LEN as u64 + 3) {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in (CidQueue::LEN as u64 - 1)..(CidQueue::LEN as u64 + 3) {
            q.next().unwrap();
            assert_eq!(q.active_seq(), i);
        }
        assert!(q.next().is_none());
    }

    /// Retiring a prefix of a full queue reports the retired range once; repeating the same
    /// frame afterwards is a no-op.
    #[test]
    fn retire_dense() {
        let mut q = CidQueue::new(initial_cid());

        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        assert_eq!(q.active_seq(), 0);

        assert_eq!(q.insert(cid(4, 2)).unwrap().unwrap().0, 0..2);
        assert_eq!(q.active_seq(), 2);
        assert_eq!(q.insert(cid(4, 2)), Ok(None));

        for i in 2..(CidQueue::LEN as u64 - 1) {
            let _ = q.next().unwrap();
            assert_eq!(q.active_seq(), i + 1);
            assert_eq!(q.insert(cid(i + 1, i + 1)), Ok(None));
        }

        assert!(q.next().is_none());
    }

    /// Retirement moves the active CID to the first stored CID at or past `retire_prior_to`,
    /// skipping sequence numbers that were never received.
    #[test]
    fn retire_sparse() {
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(2, 0)).unwrap();
        assert_eq!(q.insert(cid(3, 1)).unwrap().unwrap().0, 0..2);
        assert_eq!(q.active_seq(), 2);
    }

    /// A very large `retire_prior_to` reports at most `LEN` retired sequence numbers.
    #[test]
    fn retire_many() {
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(2, 0)).unwrap();
        assert_eq!(
            q.insert(cid(1_000_000, 1_000_000)).unwrap().unwrap().0,
            0..CidQueue::LEN as u64,
        );
        assert_eq!(q.active_seq(), 1_000_000);
    }

    /// Sequence numbers `LEN` or more past the active one are rejected.
    #[test]
    fn insert_limit() {
        let mut q = CidQueue::new(initial_cid());
        assert_eq!(q.insert(cid(CidQueue::LEN as u64 - 1, 0)), Ok(None));
        assert_eq!(
            q.insert(cid(CidQueue::LEN as u64, 0)),
            Err(InsertError::ExceedsLimit)
        );
    }

    /// Inserting the same frame twice succeeds both times.
    #[test]
    fn insert_duplicate() {
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(0, 0)).unwrap();
        q.insert(cid(0, 0)).unwrap();
    }

    /// The active sequence number can be re-inserted, but not once it has been retired.
    #[test]
    fn insert_retired() {
        let mut q = CidQueue::new(initial_cid());
        assert_eq!(
            q.insert(cid(0, 0)),
            Ok(None),
            "reinserting active CID succeeds"
        );
        assert!(q.next().is_none(), "active CID isn't requeued");
        q.insert(cid(1, 0)).unwrap();
        q.next().unwrap();
        assert_eq!(
            q.insert(cid(0, 0)),
            Err(InsertError::Retired),
            "previous active CID is already retired"
        );
    }

    /// Consuming one CID with `next` makes room for exactly one more sequence number.
    #[test]
    fn retire_then_insert_next() {
        let mut q = CidQueue::new(initial_cid());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        q.next().unwrap();
        q.insert(cid(CidQueue::LEN as u64, 0)).unwrap();
        assert_eq!(
            q.insert(cid(CidQueue::LEN as u64 + 1, 0)),
            Err(InsertError::ExceedsLimit)
        );
    }

    /// A failed `next` on a fresh queue leaves the initial CID active.
    #[test]
    fn always_valid() {
        let mut q = CidQueue::new(initial_cid());
        assert!(q.next().is_none());
        assert_eq!(q.active(), initial_cid());
        assert_eq!(q.active_seq(), 0);
    }
}