use std::sync::{Arc, Mutex};
use super::{ConnectionId, GenUniqueCid, RetireCid};
use crate::{
error::{Error, ErrorKind, QuicError},
frame::{
FrameType, GetFrameType, NewConnectionIdFrame, RetireConnectionIdFrame,
io::{ReceiveFrame, SendFrame},
},
token::ResetToken,
util::IndexDeque,
varint::{VARINT_MAX, VarInt},
};
/// Connection IDs issued by this endpoint, indexed by their sequence number.
#[derive(Debug)]
struct LocalCids<ISSUED>
where
    ISSUED: GenUniqueCid + RetireCid + SendFrame<NewConnectionIdFrame>,
{
    // Slot `n` holds the connection ID (with its reset token) whose sequence
    // number is `n`; a slot is set to `None` when the peer retires that ID, and
    // leading `None` slots are later dropped by advancing the deque.
    cid_deque: IndexDeque<Option<(ConnectionId, ResetToken)>, VARINT_MAX>,
    // Generates unique connection IDs, sends NEW_CONNECTION_ID frames and is
    // told about connection IDs that are no longer in use.
    issued_cids: ISSUED,
    // The active connection ID limit recorded by `set_limit`; `None` until then.
    active_cid_limit: Option<u64>,
}
impl<ISSUED> LocalCids<ISSUED>
where
    ISSUED: GenUniqueCid + RetireCid + SendFrame<NewConnectionIdFrame>,
{
    /// Creates the set with `scid` as sequence number 0 and immediately issues
    /// one more connection ID with sequence number 1.
    ///
    /// `scid` is stored with a default [`ResetToken`]. The second ID is produced by
    /// `issued_cids`, announced through a `NEW_CONNECTION_ID` frame whose
    /// `retire_prior_to` is 0, and recorded together with the frame's reset token.
    fn new(scid: ConnectionId, issued_cids: ISSUED) -> Self {
        let mut cid_deque = IndexDeque::default();
        cid_deque
            .push_back(Some((scid, ResetToken::default())))
            .unwrap();
        let new_cid = issued_cids.gen_unique_cid();
        let new_cid_frame =
            NewConnectionIdFrame::new(new_cid, VarInt::from_u32(1), VarInt::from_u32(0));
        issued_cids.send_frame([new_cid_frame]);
        cid_deque
            .push_back(Some((
                *new_cid_frame.connection_id(),
                *new_cid_frame.reset_token(),
            )))
            .unwrap();
        Self {
            cid_deque,
            issued_cids,
            active_cid_limit: None,
        }
    }

    /// Returns the connection ID with sequence number 0, or `None` once it has
    /// been retired.
    fn initial_scid(&self) -> Option<ConnectionId> {
        self.cid_deque.get(0)?.map(|(cid, _)| cid)
    }

    /// Records the active connection ID limit and issues new connection IDs until
    /// that many are active.
    ///
    /// # Errors
    ///
    /// Returns a `TRANSPORT_PARAMETER` error if `active_cid_limit` is below 2.
    fn set_limit(&mut self, active_cid_limit: u64) -> Result<(), Error> {
        debug_assert!(self.active_cid_limit.is_none());
        if active_cid_limit < 2 {
            return Err(QuicError::new(
                ErrorKind::TransportParameter,
                FrameType::Crypto.into(),
                format!("active connection id limit {active_cid_limit} < 2"),
            )
            .into());
        }
        // Count the IDs that are still active rather than using the highest
        // sequence number: IDs the peer already retired no longer count against
        // its limit, so `largest()` would under-issue after any retirement.
        let active = self
            .cid_deque
            .iter()
            .filter(|slot| slot.is_some())
            .count() as u64;
        // NOTE(review): this loop is bounded only by `active_cid_limit`; if that
        // value comes straight from the peer's transport parameters, a very large
        // limit makes us generate that many IDs. Consider capping it.
        for _ in active..active_cid_limit {
            self.issue_new_cid();
        }
        self.active_cid_limit = Some(active_cid_limit);
        Ok(())
    }

    /// Generates a connection ID with the next sequence number, announces it with
    /// a `NEW_CONNECTION_ID` frame and records it in the deque.
    ///
    /// `retire_prior_to` is the deque's current offset, i.e. the lowest sequence
    /// number still stored.
    ///
    /// # Panics
    ///
    /// Panics if the next sequence number no longer fits the deque / `VarInt`
    /// range bounded by `VARINT_MAX`.
    fn issue_new_cid(&mut self) {
        let seq = VarInt::from_u64(self.cid_deque.largest()).unwrap();
        let retire_prior_to = VarInt::from_u64(self.cid_deque.offset()).unwrap();
        let new_cid = self.issued_cids.gen_unique_cid();
        let new_cid_frame = NewConnectionIdFrame::new(new_cid, seq, retire_prior_to);
        self.issued_cids.send_frame([new_cid_frame]);
        self.cid_deque
            .push_back(Some((*new_cid_frame.connection_id(), *new_cid_frame.reset_token())))
            .expect("it's very very hard to issue a new connection ID whose sequence excceeds VARINT_MAX");
    }

    /// Handles a `RETIRE_CONNECTION_ID` frame from the peer.
    ///
    /// Retiring a sequence number that is already retired is a no-op. Otherwise
    /// the slot is cleared, retired slots at the front of the deque are dropped,
    /// a replacement connection ID is issued, and the retired ID is handed back to
    /// `issued_cids`.
    ///
    /// # Errors
    ///
    /// Returns a `PROTOCOL_VIOLATION` error if the frame names a sequence number
    /// that has never been issued (RFC 9000, Section 19.16).
    fn recv_retire_cid_frame(&mut self, frame: &RetireConnectionIdFrame) -> Result<(), Error> {
        let seq = frame.sequence();
        if seq >= self.cid_deque.largest() {
            // RFC 9000 §19.16: a sequence number greater than any previously sent
            // MUST be treated as a connection error of type PROTOCOL_VIOLATION.
            return Err(QuicError::new(
                ErrorKind::ProtocolViolation,
                frame.frame_type().into(),
                format!(
                    "Sequence({seq}) in RetireConnectionIdFrame exceeds the largest one({}) issued by us",
                    self.cid_deque.largest().saturating_sub(1)
                ),
            )
            .into());
        }
        // `None` here means the slot is already empty or no longer stored:
        // a duplicate retirement, which is ignored.
        if let Some((cid, _)) = self.cid_deque.get_mut(seq).and_then(|slot| slot.take()) {
            // Drop the run of retired slots at the front of the deque.
            let retired_prefix = self
                .cid_deque
                .iter()
                .take_while(|slot| slot.is_none())
                .count();
            self.cid_deque.advance(retired_prefix);
            // Replace the retired ID so the number of active IDs stays the same.
            self.issue_new_cid();
            self.issued_cids.retire_cid(cid);
        }
        Ok(())
    }

    /// Removes every stored connection ID and passes each one that has not been
    /// retired yet to [`RetireCid::retire_cid`].
    fn clear(&mut self) {
        for (cid, _reset_token) in self.cid_deque.drain_to(self.cid_deque.largest()).flatten() {
            self.issued_cids.retire_cid(cid);
        }
    }
}
impl<ISSUED> Drop for LocalCids<ISSUED>
where
    ISSUED: GenUniqueCid + RetireCid + SendFrame<NewConnectionIdFrame>,
{
    /// Delegates to `clear`, so any connection IDs still held are handed back
    /// to `issued_cids` when the set is dropped.
    fn drop(&mut self) {
        Self::clear(self)
    }
}
/// Shared handle to the locally issued connection IDs, guarded by a [`Mutex`].
///
/// Cloning the handle only bumps the [`Arc`] reference count; every clone
/// operates on the same underlying state.
#[derive(Debug)]
pub struct ArcLocalCids<ISSUED>(Arc<Mutex<LocalCids<ISSUED>>>)
where
    ISSUED: GenUniqueCid + RetireCid + SendFrame<NewConnectionIdFrame>;

// Implemented by hand: `#[derive(Clone)]` would require `ISSUED: Clone`, even
// though cloning the handle only needs to clone the `Arc`.
impl<ISSUED> Clone for ArcLocalCids<ISSUED>
where
    ISSUED: GenUniqueCid + RetireCid + SendFrame<NewConnectionIdFrame>,
{
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}
impl<ISSUED> ArcLocalCids<ISSUED>
where
ISSUED: GenUniqueCid + RetireCid + SendFrame<NewConnectionIdFrame>,
{
pub fn new(scid: ConnectionId, issued_cids: ISSUED) -> Self {
let raw_local_cids = LocalCids::new(scid, issued_cids);
Self(Arc::new(Mutex::new(raw_local_cids)))
}
pub fn initial_scid(&self) -> Option<ConnectionId> {
self.0.lock().unwrap().initial_scid()
}
pub fn clear(&self) {
self.0.lock().unwrap().clear();
}
pub fn set_limit(&self, active_cid_limit: u64) -> Result<(), Error> {
self.0.lock().unwrap().set_limit(active_cid_limit)
}
}
impl<ISSUED> ReceiveFrame<RetireConnectionIdFrame> for ArcLocalCids<ISSUED>
where
ISSUED: GenUniqueCid + RetireCid + SendFrame<NewConnectionIdFrame>,
{
type Output = ();
fn recv_frame(
&self,
frame: &RetireConnectionIdFrame,
) -> Result<Self::Output, crate::error::Error> {
self.0.lock().unwrap().recv_retire_cid_frame(frame)
}
}
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::MutexGuard};

    use super::*;

    /// In-memory stand-in for the `ISSUED` collaborator: records every
    /// `NEW_CONNECTION_ID` frame it is asked to send and tracks which generated
    /// connection IDs have not been retired yet.
    #[derive(Default)]
    struct IssuedCids {
        frames: Mutex<Vec<NewConnectionIdFrame>>,
        active_cids: Mutex<HashMap<ConnectionId, ResetToken>>,
    }

    impl IssuedCids {
        fn frames(&self) -> MutexGuard<'_, Vec<NewConnectionIdFrame>> {
            self.frames.lock().unwrap()
        }

        fn active_cids(&self) -> MutexGuard<'_, HashMap<ConnectionId, ResetToken>> {
            self.active_cids.lock().unwrap()
        }
    }

    impl GenUniqueCid for IssuedCids {
        fn gen_unique_cid(&self) -> ConnectionId {
            let mut local_cids = self.active_cids.lock().unwrap();
            let unique_cid =
                core::iter::from_fn(|| Some(ConnectionId::random_gen_with_mark(8, 0x80, 0x7F)))
                    .find(|cid| !local_cids.contains_key(cid))
                    .unwrap();
            local_cids.insert(unique_cid, ResetToken::default());
            unique_cid
        }
    }

    impl RetireCid for IssuedCids {
        fn retire_cid(&self, cid: ConnectionId) {
            self.active_cids.lock().unwrap().remove(&cid);
        }
    }

    impl SendFrame<NewConnectionIdFrame> for IssuedCids {
        fn send_frame<I: IntoIterator<Item = NewConnectionIdFrame>>(&self, iter: I) {
            self.frames.lock().unwrap().extend(iter);
        }
    }

    #[test]
    fn test_issue_cid() {
        let initial_scid = ConnectionId::random_gen(8);
        let local_cids = ArcLocalCids::new(initial_scid, IssuedCids::default());
        let mut local_cids = local_cids.0.lock().unwrap();
        assert_eq!(local_cids.cid_deque.len(), 2);
        local_cids.set_limit(3).unwrap();
        assert_eq!(local_cids.cid_deque.len(), 3);
    }

    #[test]
    fn test_recv_retire_cid_frame() {
        let initial_scid = ConnectionId::random_gen(8);
        let mut local_cids = LocalCids::new(initial_scid, IssuedCids::default());
        assert_eq!(local_cids.cid_deque.len(), 2);
        assert_eq!(local_cids.issued_cids.frames().len(), 1);
        let issued_cid2 = *local_cids.issued_cids.frames()[0].connection_id();
        let retire_frame = RetireConnectionIdFrame::new(VarInt::from_u32(1));
        let cid2 = local_cids.recv_retire_cid_frame(&retire_frame);
        assert!(cid2.is_ok());
        assert!(
            !local_cids
                .issued_cids
                .active_cids()
                .contains_key(&issued_cid2)
        );
        assert_eq!(local_cids.cid_deque.get(1), Some(&None));
        assert_eq!(local_cids.cid_deque.len(), 3);
        assert_eq!(local_cids.issued_cids.frames().len(), 2);
        let retire_frame = RetireConnectionIdFrame::new(VarInt::from_u32(0));
        let cid1 = local_cids.recv_retire_cid_frame(&retire_frame);
        assert!(cid1.is_ok());
        assert!(
            !local_cids
                .issued_cids
                .active_cids()
                .contains_key(&initial_scid)
        );
        assert_eq!(local_cids.cid_deque.get(0), None);
        assert_eq!(local_cids.cid_deque.len(), 2);
        assert_eq!(local_cids.issued_cids.frames().len(), 3);
        let retire_frame = RetireConnectionIdFrame::new(VarInt::from_u32(2));
        let cid3 = local_cids.recv_retire_cid_frame(&retire_frame);
        assert!(cid3.is_ok());
    }

    /// A RETIRE_CONNECTION_ID frame naming a sequence number that was never
    /// issued is rejected and leaves the state untouched.
    #[test]
    fn test_retire_unissued_cid_is_rejected() {
        let initial_scid = ConnectionId::random_gen(8);
        let mut local_cids = LocalCids::new(initial_scid, IssuedCids::default());
        // Only sequence numbers 0 and 1 have been issued at this point.
        let retire_frame = RetireConnectionIdFrame::new(VarInt::from_u32(2));
        assert!(local_cids.recv_retire_cid_frame(&retire_frame).is_err());
        assert_eq!(local_cids.cid_deque.len(), 2);
        assert_eq!(local_cids.issued_cids.frames().len(), 1);
        assert_eq!(local_cids.issued_cids.active_cids().len(), 1);
    }

    /// Retiring the same sequence number twice neither issues a second
    /// replacement nor fails.
    #[test]
    fn test_retire_same_cid_twice_is_noop() {
        let initial_scid = ConnectionId::random_gen(8);
        let mut local_cids = LocalCids::new(initial_scid, IssuedCids::default());
        let retire_frame = RetireConnectionIdFrame::new(VarInt::from_u32(1));
        assert!(local_cids.recv_retire_cid_frame(&retire_frame).is_ok());
        assert_eq!(local_cids.cid_deque.len(), 3);
        assert_eq!(local_cids.issued_cids.frames().len(), 2);
        assert!(local_cids.recv_retire_cid_frame(&retire_frame).is_ok());
        assert_eq!(local_cids.cid_deque.len(), 3);
        assert_eq!(local_cids.issued_cids.frames().len(), 2);
    }
}