use std::sync::{Arc, Mutex};
use super::{ConnectionId, GenUniqueCid};
use crate::{
error::{Error, ErrorKind},
frame::{
BeFrame, FrameType, NewConnectionIdFrame, ReceiveFrame, RetireConnectionIdFrame, SendFrame,
},
token::ResetToken,
util::IndexDeque,
varint::{VarInt, VARINT_MAX},
};
/// The connection IDs this endpoint has issued as its own source CIDs,
/// indexed by sequence number.
#[derive(Debug)]
struct LocalCids<ISSUED>
where
    ISSUED: GenUniqueCid + SendFrame<NewConnectionIdFrame>,
{
    // Issued CIDs with their stateless reset tokens, indexed by sequence number.
    // An entry becomes `None` once the peer retires it; leading `None` entries
    // are dropped (via `advance`) in `recv_retire_cid_frame`.
    cid_deque: IndexDeque<Option<(ConnectionId, ResetToken)>, VARINT_MAX>,
    // Source of fresh CIDs (`gen_unique_cid`) and sink for the resulting
    // NEW_CONNECTION_ID frames (`send_frame`).
    issued_cids: ISSUED,
    // The limit recorded by `set_limit` (presumably the peer's
    // `active_connection_id_limit` transport parameter); `None` until then.
    active_cid_limit: Option<u64>,
}
impl<ISSUED> LocalCids<ISSUED>
where
    ISSUED: GenUniqueCid + SendFrame<NewConnectionIdFrame>,
{
    /// Local cap on how many connection IDs are kept active at the same time.
    ///
    /// The peer's `active_connection_id_limit` is untrusted input and may be as
    /// large as 2^62-1; issuing that many IDs would exhaust CPU and memory.
    /// RFC 9000 only forbids *exceeding* the peer's limit, so issuing fewer is
    /// compliant.
    const MAX_ACTIVE_CIDS: u64 = 8;

    /// Creates the local CID set for a connection whose initial source CID is `scid`.
    ///
    /// `scid` is stored at sequence number 0 with a default reset token. A second
    /// CID (sequence 1, `retire_prior_to` 0) is generated and sent right away via
    /// `issued_cids`. This never exceeds the peer's limit because `set_limit`
    /// rejects any limit below 2.
    fn new(scid: ConnectionId, issued_cids: ISSUED) -> Self {
        let mut cid_deque = IndexDeque::default();
        cid_deque
            .push_back(Some((scid, ResetToken::default())))
            .unwrap();

        let new_cid = issued_cids.gen_unique_cid();
        let new_cid_frame =
            NewConnectionIdFrame::new(new_cid, VarInt::from_u32(1), VarInt::from_u32(0));
        issued_cids.send_frame([new_cid_frame]);
        cid_deque
            .push_back(Some((new_cid_frame.id, new_cid_frame.reset_token)))
            .unwrap();

        Self {
            cid_deque,
            issued_cids,
            active_cid_limit: None,
        }
    }

    /// Returns the CID with sequence number 0, or `None` once it has been retired.
    fn initial_scid(&self) -> Option<ConnectionId> {
        self.cid_deque.get(0)?.map(|(cid, _)| cid)
    }

    /// Records `active_cid_limit` and issues new CIDs until
    /// `min(active_cid_limit, MAX_ACTIVE_CIDS)` of them are active.
    ///
    /// # Errors
    ///
    /// Returns a `TransportParameter` error if `active_cid_limit < 2`.
    fn set_limit(&mut self, active_cid_limit: u64) -> Result<(), Error> {
        debug_assert!(self.active_cid_limit.is_none());
        if active_cid_limit < 2 {
            return Err(Error::new(
                ErrorKind::TransportParameter,
                FrameType::Crypto,
                format!("{} < 2", active_cid_limit),
            ));
        }
        // Count the IDs that are still active. The next sequence number
        // (`largest()`) would also count retired IDs, both the holes and the
        // entries already advanced past. usize -> u64 is lossless on supported targets.
        let active = self.cid_deque.iter().filter(|v| v.is_some()).count() as u64;
        // Never let the peer-provided value drive an unbounded issuing loop.
        let target = active_cid_limit.min(Self::MAX_ACTIVE_CIDS);
        for _ in active..target {
            self.issue_new_cid();
        }
        self.active_cid_limit = Some(active_cid_limit);
        Ok(())
    }

    /// Generates a fresh CID with the next sequence number, sends it in a
    /// NEW_CONNECTION_ID frame and records it as active.
    ///
    /// `retire_prior_to` is set to the deque's `offset()`, i.e. the oldest
    /// sequence number still tracked.
    fn issue_new_cid(&mut self) {
        let seq = VarInt::from_u64(self.cid_deque.largest()).unwrap();
        let retire_prior_to = VarInt::from_u64(self.cid_deque.offset()).unwrap();
        let new_cid = self.issued_cids.gen_unique_cid();
        let new_cid_frame = NewConnectionIdFrame::new(new_cid, seq, retire_prior_to);
        self.issued_cids.send_frame([new_cid_frame]);
        self.cid_deque.push_back(Some((new_cid_frame.id, new_cid_frame.reset_token)))
            .expect("it's very very hard to issue a new connection ID whose sequence excceeds VARINT_MAX");
    }

    /// Handles a RETIRE_CONNECTION_ID frame.
    ///
    /// Returns `Ok(Some(cid))` when `cid` was active and is now retired. In that
    /// case a replacement CID is issued, so the number of active IDs stays constant.
    /// Returns `Ok(None)` when the sequence number had already been retired.
    ///
    /// # Errors
    ///
    /// Returns a `ProtocolViolation` error when the sequence number was never
    /// issued (RFC 9000 §19.16).
    fn recv_retire_cid_frame(
        &mut self,
        frame: &RetireConnectionIdFrame,
    ) -> Result<Option<ConnectionId>, Error> {
        let seq = frame.sequence.into_inner();
        // `largest()` is the next sequence number to be issued, so anything at or
        // above it was never sent to the peer.
        if seq >= self.cid_deque.largest() {
            return Err(Error::new(
                ErrorKind::ProtocolViolation,
                frame.frame_type(),
                format!(
                    "Sequence({seq}) in RetireConnectionIdFrame exceeds the largest one({}) issued by us",
                    self.cid_deque.largest().saturating_sub(1)
                ),
            ));
        }
        if let Some(value) = self.cid_deque.get_mut(seq) {
            if let Some((cid, _)) = value.take() {
                // Drop the run of retired entries at the front so `offset()`
                // (used as `retire_prior_to`) moves forward.
                let n = self.cid_deque.iter().take_while(|v| v.is_none()).count();
                self.cid_deque.advance(n);
                self.issue_new_cid();
                return Ok(Some(cid));
            }
        }
        // Already retired (still a hole, or already advanced past): nothing to do.
        Ok(None)
    }
}
/// Thread-safe, shareable handle to the connection IDs issued by this endpoint.
///
/// Clones share the same underlying state through an `Arc<Mutex<_>>`.
#[derive(Debug)]
pub struct ArcLocalCids<ISSUED>(Arc<Mutex<LocalCids<ISSUED>>>)
where
    ISSUED: GenUniqueCid + SendFrame<NewConnectionIdFrame>;

// Written by hand: `#[derive(Clone)]` would require `ISSUED: Clone`, although
// only the `Arc` is cloned here.
impl<ISSUED> Clone for ArcLocalCids<ISSUED>
where
    ISSUED: GenUniqueCid + SendFrame<NewConnectionIdFrame>,
{
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}
impl<ISSUED> ArcLocalCids<ISSUED>
where
    ISSUED: GenUniqueCid + SendFrame<NewConnectionIdFrame>,
{
    /// Builds the shared state for a connection whose initial source CID is `scid`.
    ///
    /// Construction goes through `LocalCids::new`, which also generates and sends
    /// one additional CID through `issued_cids`.
    pub fn new(scid: ConnectionId, issued_cids: ISSUED) -> Self {
        Self(Arc::new(Mutex::new(LocalCids::new(scid, issued_cids))))
    }

    /// Returns the CID with sequence number 0, or `None` once it has been retired.
    pub fn initial_scid(&self) -> Option<ConnectionId> {
        let local_cids = self.0.lock().unwrap();
        local_cids.initial_scid()
    }

    /// Snapshot of every issued CID that has not been retired, in sequence order.
    pub fn active_cids(&self) -> Vec<ConnectionId> {
        let local_cids = self.0.lock().unwrap();
        local_cids
            .cid_deque
            .iter()
            .filter_map(|entry| entry.as_ref().map(|(cid, _)| *cid))
            .collect()
    }

    /// Records the active CID limit and issues CIDs up to it; see `LocalCids::set_limit`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `LocalCids::set_limit` (limit below 2).
    pub fn set_limit(&self, active_cid_limit: u64) -> Result<(), Error> {
        let mut local_cids = self.0.lock().unwrap();
        local_cids.set_limit(active_cid_limit)
    }
}
/// Routes RETIRE_CONNECTION_ID frames into the shared local CID set.
impl<ISSUED> ReceiveFrame<RetireConnectionIdFrame> for ArcLocalCids<ISSUED>
where
    ISSUED: GenUniqueCid + SendFrame<NewConnectionIdFrame>,
{
    /// The CID that was retired, or `None` if the sequence number was already retired.
    type Output = Option<ConnectionId>;

    /// Locks the state and delegates to `LocalCids::recv_retire_cid_frame`.
    fn recv_frame(
        &self,
        frame: &RetireConnectionIdFrame,
    ) -> Result<Self::Output, crate::error::Error> {
        let mut local_cids = self.0.lock().unwrap();
        local_cids.recv_retire_cid_frame(frame)
    }
}
#[cfg(test)]
mod tests {
    use deref_derive::Deref;

    use super::*;

    /// Test double that records every NEW_CONNECTION_ID frame passed to `send_frame`.
    #[derive(Debug, Deref, Default)]
    struct IssuedCids(Arc<Mutex<Vec<NewConnectionIdFrame>>>);

    impl IssuedCids {
        fn lock_guard(&self) -> std::sync::MutexGuard<'_, Vec<NewConnectionIdFrame>> {
            self.0.lock().unwrap()
        }
    }

    impl GenUniqueCid for IssuedCids {
        fn gen_unique_cid(&self) -> ConnectionId {
            ConnectionId::random_gen_with_mark(8, 0x80, 0x7F)
        }
    }

    impl SendFrame<NewConnectionIdFrame> for IssuedCids {
        fn send_frame<I: IntoIterator<Item = NewConnectionIdFrame>>(&self, iter: I) {
            self.0.lock().unwrap().extend(iter);
        }
    }

    #[test]
    fn test_issue_cid() {
        let initial_scid = ConnectionId::random_gen(8);
        let local_cids = ArcLocalCids::new(initial_scid, IssuedCids::default());
        let mut guard = local_cids.0.lock().unwrap();
        assert_eq!(guard.cid_deque.len(), 2);
        guard.set_limit(3).unwrap();
        assert_eq!(guard.cid_deque.len(), 3);
    }

    #[test]
    fn test_set_limit_bounds() {
        // Limits below 2 are rejected.
        let local_cids = ArcLocalCids::new(ConnectionId::random_gen(8), IssuedCids::default());
        assert!(local_cids.set_limit(1).is_err());

        // A limit of 2 is already satisfied by the two IDs issued at construction.
        let local_cids = ArcLocalCids::new(ConnectionId::random_gen(8), IssuedCids::default());
        assert!(local_cids.set_limit(2).is_ok());
        assert_eq!(local_cids.active_cids().len(), 2);
    }

    #[test]
    fn test_recv_retire_cid_frame() {
        let initial_scid = ConnectionId::random_gen(8);
        let mut local_cids = LocalCids::new(initial_scid, IssuedCids::default());
        assert_eq!(local_cids.cid_deque.len(), 2);
        assert_eq!(local_cids.issued_cids.lock_guard().len(), 1);
        let issued_cid2 = local_cids.issued_cids.lock_guard()[0].id;

        let retire_frame = RetireConnectionIdFrame {
            sequence: VarInt::from_u32(1),
        };
        let cid2 = local_cids.recv_retire_cid_frame(&retire_frame);
        assert_eq!(cid2, Ok(Some(issued_cid2)));
        assert_eq!(local_cids.cid_deque.get(1), Some(&None));
        assert_eq!(local_cids.cid_deque.len(), 3);
        assert_eq!(local_cids.issued_cids.lock_guard().len(), 2);

        let retire_frame = RetireConnectionIdFrame {
            sequence: VarInt::from_u32(0),
        };
        let cid1 = local_cids.recv_retire_cid_frame(&retire_frame);
        assert_eq!(cid1, Ok(Some(initial_scid)));
        assert_eq!(local_cids.cid_deque.get(0), None);
        assert_eq!(local_cids.cid_deque.len(), 2);
        assert_eq!(local_cids.issued_cids.lock_guard().len(), 3);

        // Sequence 2 was issued in response to the first retirement.
        let issued_cid3 = local_cids.issued_cids.lock_guard()[1].id;
        let retire_frame = RetireConnectionIdFrame {
            sequence: VarInt::from_u32(2),
        };
        let cid3 = local_cids.recv_retire_cid_frame(&retire_frame);
        assert_eq!(cid3, Ok(Some(issued_cid3)));
        assert_eq!(local_cids.cid_deque.get(2), None);
        assert_eq!(local_cids.cid_deque.len(), 2);
        assert_eq!(local_cids.issued_cids.lock_guard().len(), 4);
    }

    #[test]
    fn test_retire_unissued_or_repeated_seq() {
        let initial_scid = ConnectionId::random_gen(8);
        let mut local_cids = LocalCids::new(initial_scid, IssuedCids::default());

        // Only sequence numbers 0 and 1 have been issued: anything larger is an error
        // and must not trigger issuing a new CID.
        for seq in [2u32, 100] {
            let retire_frame = RetireConnectionIdFrame {
                sequence: VarInt::from_u32(seq),
            };
            assert!(local_cids.recv_retire_cid_frame(&retire_frame).is_err());
        }
        assert_eq!(local_cids.issued_cids.lock_guard().len(), 1);

        let retire_frame = RetireConnectionIdFrame {
            sequence: VarInt::from_u32(1),
        };
        assert!(matches!(
            local_cids.recv_retire_cid_frame(&retire_frame),
            Ok(Some(_))
        ));
        // Retiring the same sequence again is a no-op: nothing returned, nothing issued.
        assert_eq!(local_cids.recv_retire_cid_frame(&retire_frame), Ok(None));
        assert_eq!(local_cids.issued_cids.lock_guard().len(), 2);
    }
}