use std::{
collections::VecDeque,
ops::Deref,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};
use super::ConnectionId;
use crate::{
error::Error,
frame::{BeFrame, NewConnectionIdFrame, ReceiveFrame, RetireConnectionIdFrame, SendFrame},
token::ResetToken,
util::IndexDeque,
varint::{VarInt, VARINT_MAX},
};
/// Bookkeeping for the connection IDs issued by the peer, i.e. the CIDs this endpoint
/// uses as destination CIDs, and for the cells (`ArcCidCell`s) that consume them.
#[derive(Debug)]
struct RemoteCids<RETIRED: SendFrame<RetireConnectionIdFrame> + Clone> {
    /// Received CIDs indexed by sequence number, stored as `(sequence, cid, reset token)`.
    /// `None` presumably marks a sequence number not received yet (a gap) — confirm
    /// against `IndexDeque::insert`.
    cid_deque: IndexDeque<Option<(u64, ConnectionId, ResetToken)>, VARINT_MAX>,
    /// Cells that have been handed a CID, indexed by the sequence number of that CID.
    ready_cells: IndexDeque<ArcCidCell<RETIRED>, VARINT_MAX>,
    /// Cells still waiting for a CID, served in FIFO order.
    pending_cells: VecDeque<ArcCidCell<RETIRED>>,
    /// Upper bound compared against `sequence - retire_prior_to` of incoming frames.
    active_cid_limit: u64,
    /// Sequence number of the next CID to hand to a pending cell.
    cursor: u64,
    /// Sink for outgoing RETIRE_CONNECTION_ID frames.
    retired_cids: RETIRED,
}
impl<RETIRED> RemoteCids<RETIRED>
where
    RETIRED: SendFrame<RetireConnectionIdFrame> + Clone,
{
    /// Creates the set with `initial_dcid` stored as sequence number 0.
    ///
    /// The initial DCID is stored with `ResetToken::default()` as its reset token.
    /// `active_cid_limit` is the bound checked by `recv_new_cid_frame`, and `retired_cids`
    /// receives every RETIRE_CONNECTION_ID frame emitted by this set.
    fn new(initial_dcid: ConnectionId, active_cid_limit: u64, retired_cids: RETIRED) -> Self {
        let mut cid_deque = IndexDeque::default();
        cid_deque
            .push_back(Some((0, initial_dcid, ResetToken::default())))
            .expect("pushing the first element into an empty deque cannot fail");
        Self {
            active_cid_limit,
            cid_deque,
            ready_cells: Default::default(),
            pending_cells: Default::default(),
            cursor: 0,
            retired_cids,
        }
    }

    /// Replaces the CID stored at sequence number 0 with `initial_dcid` (reset token left
    /// at its default) and updates the cell that was handed sequence 0, if it is still
    /// tracked.
    ///
    /// If sequence 0 has already been drained by a Retire Prior To, there is nothing left
    /// to revise and the call is a no-op (previously this panicked).
    fn revise_initial_dcid(&mut self, initial_dcid: ConnectionId) {
        if let Some(first_dcid) = self.cid_deque.get_mut(0) {
            *first_dcid = Some((0, initial_dcid, ResetToken::default()));
        }
        if let Some(cell) = self.ready_cells.get_mut(0) {
            cell.revise(initial_dcid);
        }
    }

    /// Handles a NEW_CONNECTION_ID frame.
    ///
    /// Returns `Ok(Some(token))` with the frame's reset token when the CID was stored,
    /// `Ok(None)` when its sequence number has already been retired, and a
    /// `ConnectionIdLimit` error when `sequence - retire_prior_to` exceeds the limit.
    /// On success it also applies the frame's Retire Prior To and hands the new CID to a
    /// waiting cell if one exists.
    fn recv_new_cid_frame(
        &mut self,
        frame: &NewConnectionIdFrame,
    ) -> Result<Option<ResetToken>, Error> {
        let seq = frame.sequence.into_inner();
        let retire_prior_to = frame.retire_prior_to.into_inner();
        // NOTE(review): `seq - retire_prior_to` is a heuristic for the number of active CIDs.
        // The CIDs in `retire_prior_to..=seq` number one more than this, and CIDs this
        // endpoint already retired on its own are still counted — confirm this matches the
        // "active connection IDs" accounting of RFC 9000 §5.1.1.
        let active_len = seq.saturating_sub(retire_prior_to);
        if active_len > self.active_cid_limit {
            return Err(Error::new(
                crate::error::ErrorKind::ConnectionIdLimit,
                frame.frame_type(),
                format!(
                    "{active_len} exceed active_cid_limit {}",
                    self.active_cid_limit
                ),
            ));
        }
        // The sequence number was already covered by an earlier Retire Prior To, and the
        // retire frames for it were sent at that time.
        if seq < self.cid_deque.offset() {
            return Ok(None);
        }
        let token = frame.reset_token;
        // NOTE(review): if `IndexDeque::insert` fills the gap up to `seq` with `None`, a frame
        // with a large `seq` (and a correspondingly large Retire Prior To) allocates that whole
        // gap before `retire_prior_to` drains it — TODO confirm against IndexDeque.
        self.cid_deque
            .insert(seq, Some((seq, frame.id, token)))
            .unwrap();
        self.retire_prior_to(retire_prior_to);
        self.arrange_idle_cid();
        Ok(Some(token))
    }

    /// Hands unassigned CIDs (starting at `cursor`) to pending cells in FIFO order.
    ///
    /// Retired pending cells are discarded. The loop stops when no pending cell remains or
    /// when there is no CID stored at `cursor` (not received yet).
    #[doc(hidden)]
    fn arrange_idle_cid(&mut self) {
        while let Some(cell) = self.pending_cells.front() {
            let mut guard = cell.0.lock().unwrap();
            if guard.is_retired {
                drop(guard);
                self.pending_cells.pop_front();
                continue;
            }
            match self.cid_deque.get(self.cursor) {
                Some(Some((seq, cid, _))) => {
                    guard.assign(*seq, *cid);
                    // Release the cell's lock (and with it the borrow of `pending_cells`)
                    // before moving the cell out of the queue.
                    drop(guard);
                    let assigned = self
                        .pending_cells
                        .pop_front()
                        .expect("front() just returned Some");
                    self.ready_cells
                        .push_back(assigned)
                        .expect("Sequence of new connection ID should never exceed the limit");
                    self.cursor += 1;
                }
                _ => break,
            }
        }
    }

    /// Sends one RETIRE_CONNECTION_ID frame for each sequence number in `seqs`.
    fn send_retire_frames(&self, seqs: std::ops::Range<u64>) {
        self.retired_cids
            .send_frame(seqs.map(|seq| RetireConnectionIdFrame {
                sequence: VarInt::from_u64(seq)
                    .expect("Sequence of connection id is very hard to exceed VARINT_MAX"),
            }));
    }

    /// Retires every CID whose sequence number is below `tomb_seq`.
    ///
    /// CIDs that were never handed to a cell are retired right away with
    /// RETIRE_CONNECTION_ID frames. Live cells that held a CID below `tomb_seq` are moved
    /// back to `pending_cells`. They keep their old CID until `arrange_idle_cid` assigns a
    /// newer one; `CidCell::assign` and `CidCell::renew` then retire the old one.
    #[doc(hidden)]
    fn retire_prior_to(&mut self, tomb_seq: u64) {
        if tomb_seq <= self.ready_cells.offset() {
            return;
        }
        _ = self.cid_deque.drain_to(tomb_seq);
        self.cursor = self.cursor.max(tomb_seq);
        if self.ready_cells.is_empty() {
            // No cell holds anything in this range: retire all of it directly.
            self.send_retire_frames(self.ready_cells.offset()..tomb_seq);
            self.ready_cells.reset_offset(tomb_seq);
        } else {
            // `largest()` is used as one past the highest assigned sequence number —
            // presumably `offset + len`; confirm against IndexDeque.
            let actual_applied = self.ready_cells.largest();
            let need_reassigned = actual_applied.min(tomb_seq);
            for _ in self.ready_cells.offset()..need_reassigned {
                let (_, cell) = self
                    .ready_cells
                    .pop_front()
                    .expect("loop bound never exceeds the number of ready cells");
                if !cell.is_retired() {
                    self.pending_cells.push_back(cell);
                }
            }
            if actual_applied < tomb_seq {
                // The CIDs in actual_applied..tomb_seq were never handed to any cell.
                self.ready_cells.reset_offset(tomb_seq);
                self.send_retire_frames(actual_applied..tomb_seq);
            }
        }
    }

    /// Creates a new cell and queues it for a CID.
    ///
    /// The cell is assigned a CID immediately if one is available. Otherwise it stays
    /// pending until a NEW_CONNECTION_ID frame supplies one.
    fn apply_dcid(&mut self) -> ArcCidCell<RETIRED> {
        let cell = ArcCidCell::new(self.retired_cids.clone());
        self.pending_cells.push_back(cell.clone());
        self.arrange_idle_cid();
        cell
    }
}
/// Shared, mutex-protected handle to the set of CIDs issued by the peer.
///
/// Cloning the handle shares the same underlying state.
#[derive(Debug, Clone)]
pub struct ArcRemoteCids<RETIRED: SendFrame<RetireConnectionIdFrame> + Clone>(
    Arc<Mutex<RemoteCids<RETIRED>>>,
);
impl<RETIRED> ArcRemoteCids<RETIRED>
where
    RETIRED: SendFrame<RetireConnectionIdFrame> + Clone,
{
    /// Builds a shared remote-CID set whose sequence-0 entry is `initial_dcid`.
    ///
    /// `retired_cids` receives the RETIRE_CONNECTION_ID frames emitted by the set.
    pub fn new(initial_dcid: ConnectionId, active_cid_limit: u64, retired_cids: RETIRED) -> Self {
        let inner = RemoteCids::new(initial_dcid, active_cid_limit, retired_cids);
        Self(Arc::new(Mutex::new(inner)))
    }

    /// Replaces the CID at sequence number 0 and updates the cell holding it, if any.
    pub fn revise_initial_dcid(&self, initial_dcid: ConnectionId) {
        let mut inner = self.0.lock().unwrap();
        inner.revise_initial_dcid(initial_dcid);
    }

    /// Creates a new cell and queues it for a CID.
    ///
    /// The cell receives a CID immediately if an unassigned one is available.
    pub fn apply_dcid(&self) -> ArcCidCell<RETIRED> {
        let mut inner = self.0.lock().unwrap();
        inner.apply_dcid()
    }

    /// Returns the stored CID with the highest sequence number, or `None` if none is stored.
    pub fn latest_dcid(&self) -> Option<ConnectionId> {
        let inner = self.0.lock().unwrap();
        for slot in inner.cid_deque.iter().rev() {
            if let Some((_, cid, _)) = slot {
                return Some(*cid);
            }
        }
        None
    }
}
impl<RETIRED> ReceiveFrame<NewConnectionIdFrame> for ArcRemoteCids<RETIRED>
where
    RETIRED: SendFrame<RetireConnectionIdFrame> + Clone,
{
    type Output = Option<ResetToken>;

    /// Handles a NEW_CONNECTION_ID frame.
    ///
    /// Yields the frame's reset token when the CID was accepted, `None` when its sequence
    /// number was already retired, or a CONNECTION_ID_LIMIT error.
    fn recv_frame(&self, frame: &NewConnectionIdFrame) -> Result<Self::Output, Error> {
        let mut inner = self.0.lock().unwrap();
        inner.recv_new_cid_frame(frame)
    }
}
/// A consumer slot for one destination CID at a time.
///
/// The slot may briefly hold older CIDs while the current one is borrowed.
#[derive(Debug)]
struct CidCell<RETIRED: SendFrame<RetireConnectionIdFrame>> {
    /// Sink for RETIRE_CONNECTION_ID frames of CIDs this cell gives up.
    retired_cids: RETIRED,
    /// `(sequence, cid)` pairs, newest at the front.
    allocated_cids: VecDeque<(u64, ConnectionId)>,
    /// Waker of the task waiting in `poll_borrow_cid` for a first CID.
    waker: Option<Waker>,
    /// Set once the cell has been retired; it then yields no more CIDs.
    is_retired: bool,
    /// Set while the front CID is borrowed; older CIDs are then kept until `renew`.
    is_using: bool,
}
impl<RETIRED> CidCell<RETIRED>
where
    RETIRED: SendFrame<RetireConnectionIdFrame> + Clone,
{
    /// Sends a single RETIRE_CONNECTION_ID frame for sequence number `seq`.
    fn send_retire_frame(&self, seq: u64) {
        let sequence = VarInt::try_from(seq)
            .expect("Sequence of connection id is very hard to exceed VARINT_MAX");
        self.retired_cids
            .send_frame([RetireConnectionIdFrame { sequence }]);
    }

    /// Retires every held CID except the newest one (the front), oldest first.
    fn retire_all_but_newest(&mut self) {
        while self.allocated_cids.len() > 1 {
            let (seq, _) = self
                .allocated_cids
                .pop_back()
                .expect("length checked above");
            self.send_retire_frame(seq);
        }
    }

    /// Gives the cell a new CID and wakes a task waiting in `poll_borrow_cid`.
    ///
    /// When the current CID is not borrowed, the older CIDs are retired at once.
    /// Otherwise they are kept until `renew`.
    ///
    /// # Panics
    /// If the cell has been retired.
    fn assign(&mut self, seq: u64, cid: ConnectionId) {
        assert!(!self.is_retired);
        self.allocated_cids.push_front((seq, cid));
        if !self.is_using {
            self.retire_all_but_newest();
        }
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }

    /// Overwrites the value of the newest held CID, keeping its sequence number.
    ///
    /// A retired cell holds no CIDs, so for it this is a no-op. Previously it panicked,
    /// which a retired cell still sitting at `ready_cells[0]` could trigger through
    /// `RemoteCids::revise_initial_dcid`.
    ///
    /// # Panics
    /// If a live cell holds no CID (it must have been assigned one first).
    fn revise(&mut self, dcid: ConnectionId) {
        if self.is_retired {
            return;
        }
        let (_, newest) = self
            .allocated_cids
            .front_mut()
            .expect("a live cell is only revised after it has been assigned a CID");
        *newest = dcid;
    }

    /// Returns the newest CID and marks the cell as in use.
    ///
    /// Returns `Ready(None)` once the cell is retired. Returns `Pending`, registering the
    /// waker, while no CID has been assigned yet.
    fn poll_borrow_cid(&mut self, cx: &mut Context<'_>) -> Poll<Option<ConnectionId>> {
        if self.is_retired {
            return Poll::Ready(None);
        }
        match self.allocated_cids.front() {
            None => {
                self.waker = Some(cx.waker().clone());
                Poll::Pending
            }
            Some(&(_, cid)) => {
                self.is_using = true;
                Poll::Ready(Some(cid))
            }
        }
    }

    /// Ends a borrow: clears the in-use mark and retires all but the newest held CID.
    ///
    /// # Panics
    /// If the cell is not marked in use. This happens, for instance, when two
    /// `BorrowedCid`s of the same cell are dropped: the second drop panics.
    fn renew(&mut self) {
        assert!(self.is_using);
        self.is_using = false;
        self.retire_all_but_newest();
    }

    /// Retires the cell and sends RETIRE_CONNECTION_ID for every CID it holds (newest first).
    ///
    /// It also wakes a waiting task, which then observes `Ready(None)`. Calling this again
    /// has no effect.
    fn retire(&mut self) {
        if self.is_retired {
            return;
        }
        self.is_retired = true;
        while let Some((seq, _)) = self.allocated_cids.pop_front() {
            self.send_retire_frame(seq);
        }
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }
}
/// Shared handle to a `CidCell`.
///
/// Clones refer to the same cell.
#[derive(Debug, Clone)]
pub struct ArcCidCell<RETIRED: SendFrame<RetireConnectionIdFrame> + Clone>(
    Arc<Mutex<CidCell<RETIRED>>>,
);
impl<RETIRED> ArcCidCell<RETIRED>
where
    RETIRED: SendFrame<RetireConnectionIdFrame> + Clone,
{
    /// Creates an empty, live, unborrowed cell that reports retirements to `retired_cids`.
    #[doc(hidden)]
    fn new(retired_cids: RETIRED) -> Self {
        let cell = CidCell {
            retired_cids,
            allocated_cids: VecDeque::with_capacity(2),
            waker: None,
            is_retired: false,
            is_using: false,
        };
        Self(Arc::new(Mutex::new(cell)))
    }

    /// Whether the cell has been retired.
    fn is_retired(&self) -> bool {
        let cell = self.0.lock().unwrap();
        cell.is_retired
    }

    /// Overwrites the value of the newest CID held by the cell.
    fn revise(&self, dcid: ConnectionId) {
        let mut cell = self.0.lock().unwrap();
        cell.revise(dcid);
    }

    /// Polls for the cell's current CID.
    ///
    /// - `Ready(None)`: the cell is retired.
    /// - `Pending`: no CID has been assigned yet; the waker is registered.
    /// - `Ready(Some(_))`: a guard over the CID, which calls `renew` on the cell when dropped.
    pub fn poll_borrow_cid(&self, cx: &mut Context<'_>) -> Poll<Option<BorrowedCid<RETIRED>>> {
        let mut cell = self.0.lock().unwrap();
        match cell.poll_borrow_cid(cx) {
            Poll::Ready(Some(cid)) => Poll::Ready(Some(BorrowedCid {
                cid,
                cid_cell: &self.0,
            })),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Retires the cell, sending RETIRE_CONNECTION_ID for every CID it holds.
    pub fn retire(&self) {
        let mut cell = self.0.lock().unwrap();
        cell.retire();
    }
}
/// A CID borrowed from a cell.
///
/// It dereferences to the `ConnectionId` and calls `renew` on the cell when dropped.
pub struct BorrowedCid<'a, RETIRED: SendFrame<RetireConnectionIdFrame> + Clone> {
    /// Copy of the CID that was current when it was borrowed.
    cid: ConnectionId,
    /// The cell the CID was borrowed from.
    cid_cell: &'a Mutex<CidCell<RETIRED>>,
}
impl<RETIRED> Deref for BorrowedCid<'_, RETIRED>
where
RETIRED: SendFrame<RetireConnectionIdFrame> + Clone,
{
type Target = ConnectionId;
fn deref(&self) -> &Self::Target {
&self.cid
}
}
impl<RETIRED> Drop for BorrowedCid<'_, RETIRED>
where
    RETIRED: SendFrame<RetireConnectionIdFrame> + Clone,
{
    /// Ends the borrow: the cell may now retire the CIDs superseded while it was in use.
    fn drop(&mut self) {
        let mut cell = self.cid_cell.lock().unwrap();
        cell.renew();
    }
}
// Unit tests for CID assignment to cells and for Retire Prior To handling.
#[cfg(test)]
mod tests {
use deref_derive::Deref;
use super::*;
// Test double for the RETIRE_CONNECTION_ID sink: records every frame passed to send_frame.
#[derive(Debug, Clone, Default, Deref)]
struct RetiredCids(Arc<Mutex<Vec<RetireConnectionIdFrame>>>);
impl SendFrame<RetireConnectionIdFrame> for RetiredCids {
fn send_frame<I: IntoIterator<Item = RetireConnectionIdFrame>>(&self, iter: I) {
self.0.lock().unwrap().extend(iter);
}
}
// Cells applied before enough CIDs exist stay pending until a NEW_CONNECTION_ID frame
// supplies the next sequence number.
#[test]
fn test_remote_cids() {
let waker = futures::task::noop_waker();
let mut cx = std::task::Context::from_waker(&waker);
let initial_dcid = ConnectionId::random_gen(8);
let retired_cids = RetiredCids::default();
let mut remote_cids = RemoteCids::new(initial_dcid, 8, retired_cids);
// The first cell is assigned the initial DCID (sequence 0) immediately.
let cid_apply0 = remote_cids.apply_dcid();
assert!(matches!(
cid_apply0.poll_borrow_cid(&mut cx),
Poll::Ready(Some(cid)) if *cid == initial_dcid
));
// No CID with sequence 1 exists yet, so the second cell is pending.
let cid_apply1 = remote_cids.apply_dcid();
assert!(cid_apply1.poll_borrow_cid(&mut cx).is_pending());
// Delivering sequence 1 resolves the pending cell.
let cid = ConnectionId::random_gen(8);
let frame = NewConnectionIdFrame {
sequence: VarInt::from_u32(1),
retire_prior_to: VarInt::from_u32(0),
id: cid,
reset_token: ResetToken::random_gen(),
};
assert!(remote_cids.recv_new_cid_frame(&frame).is_ok());
assert_eq!(remote_cids.cid_deque.len(), 2);
assert!(matches!(
cid_apply1.poll_borrow_cid(&mut cx),
Poll::Ready(Some(r#ref)) if *r#ref == cid
));
// Both received CIDs are taken, so a third cell is pending again.
let cid_apply2 = remote_cids.apply_dcid();
assert!(cid_apply2.poll_borrow_cid(&mut cx).is_pending());
}
// Two cells hold sequences 0 and 1 when Retire Prior To 4 is applied:
// - the never-assigned sequences 2 and 3 are retired immediately;
// - the cells keep 0 and 1 until arrange_idle_cid hands them 4 and 5, which retires 0 and 1;
// - retiring a cell retires the CID it holds (5).
#[test]
fn test_retire_in_remote_cids() {
let waker = futures::task::noop_waker();
let mut cx = std::task::Context::from_waker(&waker);
let initial_dcid = ConnectionId::random_gen(8);
let retired_cids = RetiredCids::default();
let remote_cids = ArcRemoteCids::new(initial_dcid, 8, retired_cids);
// The RemoteCids lock is held for the whole test; each cell has its own mutex.
let mut guard = remote_cids.0.lock().unwrap();
let mut cids = vec![initial_dcid];
for seq in 1..8 {
let cid = ConnectionId::random_gen(8);
cids.push(cid);
let frame = NewConnectionIdFrame {
sequence: VarInt::from_u32(seq),
retire_prior_to: VarInt::from_u32(0),
id: cid,
reset_token: ResetToken::random_gen(),
};
_ = guard.recv_new_cid_frame(&frame);
}
let cid_apply1 = guard.apply_dcid();
let cid_apply2 = guard.apply_dcid();
assert_eq!(cid_apply1.0.lock().unwrap().allocated_cids[0].0, 0);
assert_eq!(cid_apply2.0.lock().unwrap().allocated_cids[0].0, 1);
// Each BorrowedCid is a temporary dropped right after matches!, which calls renew.
assert!(matches!(
cid_apply1.poll_borrow_cid(&mut cx),
Poll::Ready(Some(r#ref)) if *r#ref == cids[0]
));
assert!(matches!(
cid_apply2.poll_borrow_cid(&mut cx),
Poll::Ready(Some(r#ref)) if *r#ref == cids[1]
));
// Only the unassigned sequences 2 and 3 are retired here; the cells keep 0 and 1.
guard.retire_prior_to(4);
assert_eq!(guard.cid_deque.offset(), 4);
assert_eq!(guard.ready_cells.offset(), 4);
assert_eq!(guard.retired_cids.0.lock().unwrap().len(), 2);
assert_eq!(cid_apply1.0.lock().unwrap().allocated_cids[0].0, 0);
assert_eq!(cid_apply2.0.lock().unwrap().allocated_cids[0].0, 1);
assert!(matches!(
cid_apply1.poll_borrow_cid(&mut cx),
Poll::Ready(Some(r#ref)) if *r#ref == cids[0]
));
assert!(matches!(
cid_apply2.poll_borrow_cid(&mut cx),
Poll::Ready(Some(r#ref)) if *r#ref == cids[1]
));
// Reassigning 4 and 5 makes the cells retire 0 and 1; recorded order is [2, 3, 0, 1].
guard.arrange_idle_cid();
assert_eq!(guard.retired_cids.0.lock().unwrap().len(), 4);
let retired_cids = [1, 0, 3, 2];
for seq in retired_cids {
assert_eq!(
guard.retired_cids.0.lock().unwrap().pop(),
Some(RetireConnectionIdFrame {
sequence: VarInt::from_u32(seq),
})
);
}
assert!(matches!(
cid_apply1.poll_borrow_cid(&mut cx),
Poll::Ready(Some(r#ref)) if *r#ref == cids[4]
));
assert!(matches!(
cid_apply2.poll_borrow_cid(&mut cx),
Poll::Ready(Some(r#ref)) if *r#ref == cids[5]
));
// Retiring a cell retires the CID it currently holds.
cid_apply2.retire();
assert_eq!(guard.retired_cids.lock().unwrap().len(), 1);
assert_eq!(
guard.retired_cids.0.lock().unwrap().pop(),
Some(RetireConnectionIdFrame {
sequence: VarInt::from_u32(5),
})
);
}
// Retire Prior To 4 with no cells applied retires sequences 0..4 at once, and the cells
// applied afterwards start at sequence 4.
#[test]
fn test_retire_without_apply() {
let waker = futures::task::noop_waker();
let mut cx = std::task::Context::from_waker(&waker);
let initial_dcid = ConnectionId::random_gen(8);
let retired_cids = RetiredCids::default();
let remote_cids = ArcRemoteCids::new(initial_dcid, 8, retired_cids);
let mut guard = remote_cids.0.lock().unwrap();
let mut cids = vec![initial_dcid];
for seq in 1..8 {
let cid = ConnectionId::random_gen(8);
cids.push(cid);
let frame = NewConnectionIdFrame::new(cid, VarInt::from_u32(seq), VarInt::from_u32(0));
_ = guard.recv_new_cid_frame(&frame);
}
guard.retire_prior_to(4);
assert_eq!(guard.cid_deque.offset(), 4);
assert_eq!(guard.ready_cells.offset(), 4);
assert_eq!(guard.retired_cids.0.lock().unwrap().len(), 4);
let cid_apply1 = guard.apply_dcid();
let cid_apply2 = guard.apply_dcid();
assert_eq!(cid_apply1.0.lock().unwrap().allocated_cids[0].0, 4);
assert_eq!(cid_apply2.0.lock().unwrap().allocated_cids[0].0, 5);
assert!(matches!(
cid_apply1.poll_borrow_cid(&mut cx),
Poll::Ready(Some(r#ref)) if *r#ref == cids[4]
));
assert!(matches!(
cid_apply2.poll_borrow_cid(&mut cx),
Poll::Ready(Some(r#ref)) if *r#ref == cids[5]
));
}
}