#![allow(dead_code)]
use alloc::collections::BTreeMap;
use crate::rng::RngCore;
use crate::tls::Error;
/// Largest connection ID, in bytes, that a [`ConnectionId`] can store (20, the QUIC v1 maximum).
const MAX_CID_LEN: usize = 20;

/// A connection ID of up to [`MAX_CID_LEN`] bytes, stored inline without allocation.
///
/// Only the first `len` bytes are meaningful. Every constructor zero-fills the rest of the
/// array. This keeps the derived `PartialEq`/`Eq`/`Hash`, which look at the whole array, in
/// agreement with comparing [`ConnectionId::as_slice`].
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub(crate) struct ConnectionId {
    /// Inline storage; bytes at index `len` and beyond are kept zero.
    bytes: [u8; MAX_CID_LEN],
    /// Number of meaningful bytes at the start of `bytes`.
    len: u8,
}

impl ConnectionId {
    /// Copies `bytes` into a new connection ID.
    ///
    /// Returns `None` if `bytes` is longer than [`MAX_CID_LEN`].
    pub(crate) fn from_slice(bytes: &[u8]) -> Option<Self> {
        let used = bytes.len();
        if used <= MAX_CID_LEN {
            let mut buf = [0u8; MAX_CID_LEN];
            buf[..used].copy_from_slice(bytes);
            // `used <= MAX_CID_LEN` (20), so the narrowing cast cannot truncate.
            Some(Self {
                len: used as u8,
                bytes: buf,
            })
        } else {
            None
        }
    }

    /// The zero-length connection ID.
    pub(crate) const fn empty() -> Self {
        Self {
            len: 0,
            bytes: [0u8; MAX_CID_LEN],
        }
    }

    /// Draws a `len`-byte connection ID from `rng`.
    ///
    /// `len` is expected to lie in `1..=MAX_CID_LEN`, but this is only asserted in debug
    /// builds. In release builds, a `len` above `MAX_CID_LEN` panics on the slice index, and
    /// `len == 0` yields an empty ID.
    pub(crate) fn random<R: RngCore>(rng: &mut R, len: usize) -> Self {
        debug_assert!((1..=MAX_CID_LEN).contains(&len));
        let mut buf = [0u8; MAX_CID_LEN];
        let head = &mut buf[..len];
        rng.fill_bytes(head);
        Self {
            len: len as u8,
            bytes: buf,
        }
    }

    /// The meaningful bytes of this connection ID.
    #[inline]
    pub(crate) fn as_slice(&self) -> &[u8] {
        let used = usize::from(self.len);
        &self.bytes[..used]
    }

    /// Length of the connection ID in bytes.
    #[inline]
    pub(crate) fn len(&self) -> usize {
        usize::from(self.len)
    }

    /// `true` for the zero-length connection ID.
    #[inline]
    pub(crate) fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// Formats as `ConnectionId(<lowercase hex>)`, e.g. `ConnectionId(8394)`.
impl core::fmt::Debug for ConnectionId {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("ConnectionId(")?;
        self.as_slice()
            .iter()
            .try_for_each(|byte| write!(f, "{byte:02x}"))?;
        f.write_str(")")
    }
}
/// The connection IDs for the two sides of a connection: `peer` and `local`.
#[derive(Clone, Debug)]
pub(crate) struct CidPair {
    /// Connection ID associated with the peer.
    pub peer: ConnectionId,
    /// Connection ID associated with this endpoint.
    pub local: ConnectionId,
}

impl CidPair {
    /// Bundles `peer` and `local` into a pair.
    pub(crate) fn new(peer: ConnectionId, local: ConnectionId) -> Self {
        CidPair { local, peer }
    }
}
/// One connection ID tracked by a `CidPool`, with its sequence number and optional
/// 16-byte reset token.
///
/// The derived equality compares all three fields. `CidPool::add` uses it to tell an
/// identical repeat of a sequence number apart from an inconsistent one.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct CidEntry {
/// The connection ID itself.
pub(crate) cid: ConnectionId,
/// Sequence number; `CidPool` also uses it as the entry's key.
pub(crate) sequence: u64,
/// Reset token associated with this ID, if one is known.
pub(crate) reset_token: Option<[u8; 16]>,
}
pub(crate) struct CidPool {
pub(crate) entries: BTreeMap<u64, CidEntry>,
pub(crate) active_seq: u64,
pub(crate) retire_prior_to: u64,
pub(crate) limit: u64,
pub(crate) pending_retire: alloc::vec::Vec<u64>,
}
impl CidPool {
pub(crate) fn new(initial: ConnectionId, initial_reset_token: Option<[u8; 16]>) -> Self {
let mut entries = BTreeMap::new();
entries.insert(
0,
CidEntry {
cid: initial,
sequence: 0,
reset_token: initial_reset_token,
},
);
Self {
entries,
active_seq: 0,
retire_prior_to: 0,
limit: 2,
pending_retire: alloc::vec::Vec::new(),
}
}
pub(crate) fn set_limit(&mut self, limit: u64) {
self.limit = limit.max(2);
}
pub(crate) fn add(&mut self, entry: CidEntry) -> Result<(), Error> {
if let Some(existing) = self.entries.get(&entry.sequence) {
if existing != &entry {
return Err(Error::IllegalParameter);
}
return Ok(());
}
if entry.sequence < self.retire_prior_to {
self.pending_retire.push(entry.sequence);
return Ok(());
}
let live = self
.entries
.iter()
.filter(|(seq, _)| **seq >= self.retire_prior_to)
.count() as u64;
if live >= self.limit {
return Err(Error::IllegalParameter);
}
self.entries.insert(entry.sequence, entry);
Ok(())
}
pub(crate) fn retire(&mut self, sequence: u64) -> Result<Option<CidEntry>, Error> {
if sequence == self.active_seq && self.entries.contains_key(&sequence) {
return Err(Error::IllegalParameter);
}
Ok(self.entries.remove(&sequence))
}
pub(crate) fn note_retire_prior_to(&mut self, new: u64) {
if new <= self.retire_prior_to {
return;
}
self.retire_prior_to = new;
let dropped: alloc::vec::Vec<u64> =
self.entries.keys().copied().filter(|s| *s < new).collect();
for s in dropped {
self.entries.remove(&s);
self.pending_retire.push(s);
}
}
pub(crate) fn active(&self) -> Option<&CidEntry> {
self.entries.get(&self.active_seq)
}
pub(crate) fn how_many_to_issue(&self) -> u64 {
let live = self
.entries
.iter()
.filter(|(seq, _)| **seq >= self.retire_prior_to)
.count() as u64;
self.limit.saturating_sub(live)
}
pub(crate) fn pop_pending_retire(&mut self) -> Option<u64> {
if self.pending_retire.is_empty() {
None
} else {
Some(self.pending_retire.remove(0))
}
}
pub(crate) fn max_sequence(&self) -> u64 {
self.entries.keys().next_back().copied().unwrap_or(0)
}
pub(crate) fn set_token(&mut self, sequence: u64, token: [u8; 16]) -> bool {
if let Some(e) = self.entries.get_mut(&sequence) {
e.reset_token = Some(token);
true
} else {
false
}
}
}
// Unit tests for `ConnectionId` and `CidPool`.
#[cfg(test)]
mod tests {
use super::*;
use crate::hash::Sha256;
use crate::rng::HmacDrbg;
// Slices longer than MAX_CID_LEN (20) are rejected; shorter ones round-trip.
#[test]
fn from_slice_caps_at_20() {
assert!(ConnectionId::from_slice(&[0u8; 21]).is_none());
let cid = ConnectionId::from_slice(&[1, 2, 3, 4]).unwrap();
assert_eq!(cid.as_slice(), &[1, 2, 3, 4]);
assert_eq!(cid.len(), 4);
assert!(!cid.is_empty());
}
// The empty connection ID has length 0 and an empty slice.
#[test]
fn empty_is_empty() {
let e = ConnectionId::empty();
assert!(e.is_empty());
assert_eq!(e.len(), 0);
assert_eq!(e.as_slice(), &[] as &[u8]);
}
// `random` honours the requested length. The `assert_ne!` against all-zero bytes checks
// that the RNG actually filled the prefix; it relies on this fixed seed not producing
// eight zero bytes.
#[test]
fn random_has_right_length() {
let mut rng = HmacDrbg::<Sha256>::new(b"cid-test", b"nonce", &[]);
let cid = ConnectionId::random(&mut rng, 8);
assert_eq!(cid.len(), 8);
assert_ne!(cid.as_slice(), &[0u8; 8]);
}
// Debug output renders the bytes as lowercase hex.
#[test]
fn debug_is_hex() {
let cid = ConnectionId::from_slice(&[0x83, 0x94]).unwrap();
let s = alloc::format!("{cid:?}");
assert!(s.contains("8394"));
}
// Builds a distinct 8-byte connection ID filled with `n`.
fn cid_n(n: u8) -> ConnectionId {
ConnectionId::from_slice(&[n; 8]).expect("8-byte cid")
}
// A fresh pool holds only sequence 0, which is active. The default limit of 2 leaves
// room for one more entry.
#[test]
fn cidpool_seeded_with_handshake_entry() {
let pool = CidPool::new(cid_n(0), Some([0u8; 16]));
assert_eq!(pool.active_seq, 0);
assert!(pool.active().is_some());
assert_eq!(pool.active().unwrap().cid, cid_n(0));
assert_eq!(pool.limit, 2);
assert_eq!(pool.how_many_to_issue(), 1);
}
// With limit 2 and sequence 0 already live, sequence 1 fits and sequence 2 is rejected.
// Raising the limit to 3 lets sequence 2 in.
#[test]
fn cidpool_add_respects_limit() {
let mut pool = CidPool::new(cid_n(0), None);
pool.set_limit(2);
let e1 = CidEntry {
cid: cid_n(1),
sequence: 1,
reset_token: Some([1u8; 16]),
};
assert!(pool.add(e1).is_ok());
let e2 = CidEntry {
cid: cid_n(2),
sequence: 2,
reset_token: Some([2u8; 16]),
};
assert!(matches!(pool.add(e2), Err(Error::IllegalParameter)));
pool.set_limit(3);
let e2 = CidEntry {
cid: cid_n(2),
sequence: 2,
reset_token: Some([2u8; 16]),
};
assert!(pool.add(e2).is_ok());
assert_eq!(pool.max_sequence(), 2);
}
// Re-adding an identical entry is accepted. Reusing the sequence number with a
// different connection ID is an error.
#[test]
fn cidpool_add_rejects_inconsistent_duplicate() {
let mut pool = CidPool::new(cid_n(0), None);
let e1 = CidEntry {
cid: cid_n(1),
sequence: 1,
reset_token: Some([7u8; 16]),
};
assert!(pool.add(e1.clone()).is_ok());
assert!(pool.add(e1.clone()).is_ok());
let e1_bad = CidEntry {
cid: cid_n(0xff),
sequence: 1,
reset_token: Some([7u8; 16]),
};
assert!(matches!(pool.add(e1_bad), Err(Error::IllegalParameter)));
}
// Raising retire_prior_to to 2 drops sequences 0 and 1 and queues both for retirement.
// active_seq is moved to 2 beforehand, so the active entry is not among the dropped ones.
#[test]
fn cidpool_retire_prior_to_pulls_retires_and_queues() {
let mut pool = CidPool::new(cid_n(0), None);
pool.set_limit(4);
for s in 1..=3 {
pool.add(CidEntry {
cid: cid_n(s as u8),
sequence: s,
reset_token: Some([s as u8; 16]),
})
.unwrap();
}
pool.active_seq = 2;
pool.note_retire_prior_to(2);
assert_eq!(pool.retire_prior_to, 2);
assert!(!pool.entries.contains_key(&0));
assert!(!pool.entries.contains_key(&1));
let mut got = alloc::vec::Vec::new();
while let Some(s) = pool.pop_pending_retire() {
got.push(s);
}
got.sort();
assert_eq!(got, alloc::vec![0u64, 1]);
}
// Sequence 0 is both active and stored, so retiring it is rejected.
#[test]
fn cidpool_retire_active_is_protocol_error() {
let mut pool = CidPool::new(cid_n(0), None);
assert!(matches!(pool.retire(0), Err(Error::IllegalParameter)));
}
// Retiring a sequence number that was never stored yields Ok(None).
#[test]
fn cidpool_retire_unknown_sequence_returns_none() {
let mut pool = CidPool::new(cid_n(0), None);
let r = pool.retire(42).expect("ok");
assert!(r.is_none());
}
// After retire_prior_to is raised to 3, adding sequence 2 stores nothing and queues 2
// for retirement.
#[test]
fn cidpool_add_below_retire_prior_to_immediately_retires() {
let mut pool = CidPool::new(cid_n(0), None);
pool.set_limit(4);
pool.active_seq = 5;
pool.note_retire_prior_to(3);
let e = CidEntry {
cid: cid_n(0xab),
sequence: 2,
reset_token: None,
};
assert!(pool.add(e).is_ok());
assert!(!pool.entries.contains_key(&2));
let mut got = alloc::vec::Vec::new();
while let Some(s) = pool.pop_pending_retire() {
got.push(s);
}
assert!(got.contains(&2));
}
}