use std::hash::Hasher;
use rand::{Rng, RngCore};
use crate::Duration;
use crate::MAX_CID_SIZE;
use crate::shared::ConnectionId;
/// Source of connection IDs, with an optional check that a given CID came from it.
///
/// Implementors must be `Send + Sync`.
pub trait ConnectionIdGenerator: Send + Sync {
    /// Produces a new connection ID.
    fn generate_cid(&mut self) -> ConnectionId;

    /// Checks whether `cid` could have been produced by this generator.
    ///
    /// The provided default performs no check at all and accepts every CID.
    ///
    /// # Errors
    ///
    /// Implementations return [`InvalidCid`] for CIDs they reject.
    fn validate(&self, cid: &ConnectionId) -> Result<(), InvalidCid> {
        // The default has nothing to verify against, so the argument is intentionally unused.
        let _ = cid;
        Ok(())
    }

    /// Length in bytes of the CIDs this generator produces.
    fn cid_len(&self) -> usize;

    /// Lifetime associated with generated CIDs, or `None` if the generator sets none.
    fn cid_lifetime(&self) -> Option<Duration>;
}
/// Error returned by `ConnectionIdGenerator::validate` when a CID is rejected.
///
/// It is a unit marker type and carries no further detail.
#[derive(Clone, Copy, Debug)]
pub struct InvalidCid;
/// Generator that produces CIDs made entirely of random bytes.
///
/// It performs no validation (the trait's default `validate` is used).
#[derive(Clone, Copy, Debug)]
pub struct RandomConnectionIdGenerator {
    /// Number of random bytes in each generated CID.
    cid_len: usize,
    /// Optional lifetime reported through `cid_lifetime`.
    lifetime: Option<Duration>,
}
impl Default for RandomConnectionIdGenerator {
    /// 8-byte random CIDs with no lifetime.
    fn default() -> Self {
        let cid_len = 8;
        Self {
            cid_len,
            lifetime: None,
        }
    }
}
impl RandomConnectionIdGenerator {
    /// Creates a generator producing random CIDs of `cid_len` bytes, with no lifetime.
    ///
    /// # Panics
    ///
    /// Panics if `cid_len` exceeds [`MAX_CID_SIZE`]. The check is enforced in release builds
    /// as well. An oversized length would otherwise only fail later, as an opaque
    /// out-of-bounds slice panic inside `generate_cid`, which slices a `MAX_CID_SIZE` buffer.
    pub fn new(cid_len: usize) -> Self {
        assert!(
            cid_len <= MAX_CID_SIZE,
            "cid_len {cid_len} exceeds MAX_CID_SIZE ({MAX_CID_SIZE})"
        );
        Self {
            cid_len,
            ..Self::default()
        }
    }

    /// Sets the lifetime reported for generated CIDs; returns `self` for chaining.
    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
        self.lifetime = Some(d);
        self
    }
}
impl ConnectionIdGenerator for RandomConnectionIdGenerator {
    /// Fills `cid_len` bytes from `rand::thread_rng()` and wraps them as a CID.
    fn generate_cid(&mut self) -> ConnectionId {
        let len = self.cid_len;
        // Stack buffer sized for the largest permitted CID; only the first `len` bytes are used.
        let mut buf = [0u8; MAX_CID_SIZE];
        let cid_bytes = &mut buf[..len];
        rand::thread_rng().fill_bytes(cid_bytes);
        ConnectionId::new(cid_bytes)
    }

    /// The configured CID length.
    fn cid_len(&self) -> usize {
        self.cid_len
    }

    /// The configured lifetime, if one was set.
    fn cid_lifetime(&self) -> Option<Duration> {
        self.lifetime
    }
}
/// Generator whose 8-byte CIDs embed a keyed hash of a random nonce.
///
/// This lets `validate` recognize CIDs produced under the same key.
///
/// Intentionally not `Debug`, so the key does not end up in logs.
pub struct HashedConnectionIdGenerator {
    /// Key mixed into every CID signature.
    key: u64,
    /// Optional lifetime reported through `cid_lifetime`.
    lifetime: Option<Duration>,
}
impl HashedConnectionIdGenerator {
    /// Creates a generator keyed with a random `u64` drawn from `rand::thread_rng()`.
    pub fn new() -> Self {
        let key: u64 = rand::thread_rng().r#gen();
        Self::from_key(key)
    }

    /// Creates a generator with an explicit key and no lifetime.
    ///
    /// Generators built with the same key accept each other's CIDs.
    pub fn from_key(key: u64) -> Self {
        Self {
            lifetime: None,
            key,
        }
    }

    /// Sets the lifetime reported for generated CIDs; returns `self` for chaining.
    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
        let lifetime = Some(d);
        self.lifetime = lifetime;
        self
    }
}
impl Default for HashedConnectionIdGenerator {
fn default() -> Self {
Self::new()
}
}
impl ConnectionIdGenerator for HashedConnectionIdGenerator {
    /// Produces an 8-byte CID laid out as `nonce || signature`.
    ///
    /// - `nonce`: `NONCE_LEN` random bytes from `rand::thread_rng()`.
    /// - `signature`: the first `SIGNATURE_LEN` little-endian bytes of
    ///   `FxHasher(key, nonce)`.
    fn generate_cid(&mut self) -> ConnectionId {
        let mut bytes_arr = [0; NONCE_LEN + SIGNATURE_LEN];
        rand::thread_rng().fill_bytes(&mut bytes_arr[..NONCE_LEN]);
        let mut hasher = rustc_hash::FxHasher::default();
        hasher.write_u64(self.key);
        hasher.write(&bytes_arr[..NONCE_LEN]);
        bytes_arr[NONCE_LEN..].copy_from_slice(&hasher.finish().to_le_bytes()[..SIGNATURE_LEN]);
        ConnectionId::new(&bytes_arr)
    }

    /// Accepts `cid` only if its length and signature match what `generate_cid` would produce.
    ///
    /// NOTE(review): `FxHasher` is a fast, non-cryptographic hash. Treat the signature as a
    /// cheap filter for stray CIDs, not as authentication.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidCid`] if `cid` has the wrong length or its signature does not match.
    fn validate(&self, cid: &ConnectionId) -> Result<(), InvalidCid> {
        // Reject wrong-length CIDs up front. Without this, `split_at` below panics on any CID
        // shorter than `NONCE_LEN` instead of reporting it as invalid.
        if cid.len() != NONCE_LEN + SIGNATURE_LEN {
            return Err(InvalidCid);
        }
        let (nonce, signature) = cid.split_at(NONCE_LEN);
        let mut hasher = rustc_hash::FxHasher::default();
        hasher.write_u64(self.key);
        hasher.write(nonce);
        let expected = hasher.finish().to_le_bytes();
        if expected[..SIGNATURE_LEN] == signature[..] {
            Ok(())
        } else {
            Err(InvalidCid)
        }
    }

    /// Always `NONCE_LEN + SIGNATURE_LEN` (8) bytes.
    fn cid_len(&self) -> usize {
        NONCE_LEN + SIGNATURE_LEN
    }

    /// The configured lifetime, if one was set.
    fn cid_lifetime(&self) -> Option<Duration> {
        self.lifetime
    }
}
/// Number of random nonce bytes at the start of a hashed CID.
const NONCE_LEN: usize = 3;
/// Number of hash-signature bytes following the nonce; the whole CID is 8 bytes.
const SIGNATURE_LEN: usize = 8 - NONCE_LEN;
// Unit tests for `RandomConnectionIdGenerator` and `HashedConnectionIdGenerator`.
// Several assertions depend on randomness (distinct CIDs, non-zero bytes, rejection under a
// different key); they can fail only with negligible probability.
#[cfg(test)]
mod tests {
use super::*;
// --- RandomConnectionIdGenerator ---
// Default generator: 8-byte CIDs.
#[test]
fn random_default_len() {
let mut g = RandomConnectionIdGenerator::default();
assert_eq!(g.cid_len(), 8);
let cid = g.generate_cid();
assert_eq!(cid.len(), 8);
}
// A custom length passed to `new` is reported and honored.
#[test]
fn random_custom_len() {
let mut g = RandomConnectionIdGenerator::new(4);
assert_eq!(g.cid_len(), 4);
let cid = g.generate_cid();
assert_eq!(cid.len(), 4);
}
// Upper boundary: exactly MAX_CID_SIZE is allowed.
#[test]
fn random_max_len() {
let mut g = RandomConnectionIdGenerator::new(MAX_CID_SIZE);
assert_eq!(g.cid_len(), MAX_CID_SIZE);
let cid = g.generate_cid();
assert_eq!(cid.len(), MAX_CID_SIZE);
}
// Lower boundary: zero-length CIDs are permitted and yield an empty CID.
#[test]
fn random_min_zero_len() {
let mut g = RandomConnectionIdGenerator::new(0);
assert_eq!(g.cid_len(), 0);
let cid = g.generate_cid();
assert!(cid.is_empty());
}
// Probabilistic: two independent 8-byte random CIDs collide with negligible probability.
#[test]
fn random_cids_differ() {
let mut g = RandomConnectionIdGenerator::new(8);
let a = g.generate_cid();
let b = g.generate_cid();
assert_ne!(a, b, "random CIDs should almost never collide");
}
#[test]
fn random_no_lifetime_by_default() {
let g = RandomConnectionIdGenerator::default();
assert!(g.cid_lifetime().is_none());
}
#[test]
fn random_with_lifetime() {
let mut g = RandomConnectionIdGenerator::new(8);
g.set_lifetime(Duration::from_secs(60));
assert_eq!(g.cid_lifetime(), Some(Duration::from_secs(60)));
}
// Probabilistic: guards against the buffer never being filled with random bytes.
#[test]
fn random_not_all_zeros() {
let mut g = RandomConnectionIdGenerator::new(8);
let cid = g.generate_cid();
assert!(cid.iter().any(|&b| b != 0), "CID should not be all zeros");
}
// The random generator relies on the trait's default `validate`, which accepts anything.
#[test]
fn random_validate_always_ok() {
let g = RandomConnectionIdGenerator::default();
let arbitrary = ConnectionId::new(&[0xAB, 0xCD]);
assert!(g.validate(&arbitrary).is_ok());
}
// --- HashedConnectionIdGenerator ---
#[test]
fn hashed_default_len() {
let mut g = HashedConnectionIdGenerator::new();
assert_eq!(g.cid_len(), 8);
let cid = g.generate_cid();
assert_eq!(cid.len(), 8);
}
// Round trip: a generator accepts the CIDs it produced.
#[test]
fn hashed_validate_own_cid() {
let mut g = HashedConnectionIdGenerator::new();
let cid = g.generate_cid();
assert!(g.validate(&cid).is_ok());
}
#[test]
fn hashed_validates_from_key() {
let mut g = HashedConnectionIdGenerator::from_key(0xDEAD_BEEF_CAFE_BABE);
let cid = g.generate_cid();
assert!(g.validate(&cid).is_ok());
}
// Probabilistic: a CID signed under key 1 matches key 2's 5-byte signature only by chance.
#[test]
fn hashed_rejects_other_key() {
let mut g1 = HashedConnectionIdGenerator::from_key(1);
let cid_from_1 = g1.generate_cid();
let g2 = HashedConnectionIdGenerator::from_key(2);
assert!(g2.validate(&cid_from_1).is_err());
}
// A fixed key still yields distinct CIDs, because the nonce is random (3-byte nonce, so
// a collision is unlikely but not impossible).
#[test]
fn hashed_deterministic_key_still_random_nonce() {
let mut g = HashedConnectionIdGenerator::from_key(42);
let cid1 = g.generate_cid();
let cid2 = g.generate_cid();
assert_ne!(cid1, cid2);
assert!(g.validate(&cid1).is_ok());
assert!(g.validate(&cid2).is_ok());
}
// Deterministic: fixed key and fixed CID bytes, so the outcome is the same on every run.
#[test]
fn hashed_rejects_random_cid() {
let g = HashedConnectionIdGenerator::from_key(42);
let random_cid = ConnectionId::new(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
assert!(g.validate(&random_cid).is_err());
}
// A CID shorter than the 8-byte hashed length must be rejected.
#[test]
fn hashed_rejects_wrong_length() {
let g = HashedConnectionIdGenerator::from_key(42);
let short_cid = ConnectionId::new(&[0x01, 0x02, 0x03]);
assert!(g.validate(&short_cid).is_err());
}
#[test]
fn hashed_lifetime() {
let mut g = HashedConnectionIdGenerator::new();
assert!(g.cid_lifetime().is_none());
g.set_lifetime(Duration::from_secs(300));
assert_eq!(g.cid_lifetime(), Some(Duration::from_secs(300)));
}
// `validate` takes `&self` and must give the same answer when repeated.
#[test]
fn hashed_validate_twice() {
let mut g = HashedConnectionIdGenerator::from_key(123);
let cid = g.generate_cid();
assert!(g.validate(&cid).is_ok());
assert!(g.validate(&cid).is_ok());
}
#[test]
fn hashed_cid_always_8_bytes() {
let mut g = HashedConnectionIdGenerator::new();
for _ in 0..10 {
let cid = g.generate_cid();
assert_eq!(cid.len(), 8);
}
}
#[test]
fn hashed_default_is_new() {
let g = HashedConnectionIdGenerator::default();
assert_eq!(g.cid_len(), 8);
}
// --- Shared types / trait objects ---
// `InvalidCid` is `Copy` and `Debug`.
#[test]
fn invalid_cid_debug_and_copy() {
let a = InvalidCid;
let b = a;
assert_eq!(format!("{a:?}"), format!("{b:?}"));
}
// Both generators are usable as `Box<dyn ConnectionIdGenerator>`.
#[test]
fn trait_random_generator() {
let mut g: Box<dyn ConnectionIdGenerator> = Box::new(RandomConnectionIdGenerator::new(4));
let cid = g.generate_cid();
assert_eq!(cid.len(), 4);
assert_eq!(g.cid_len(), 4);
}
#[test]
fn trait_hashed_generator() {
let mut g: Box<dyn ConnectionIdGenerator> = Box::new(HashedConnectionIdGenerator::new());
let cid = g.generate_cid();
assert_eq!(cid.len(), 8);
assert!(g.validate(&cid).is_ok());
}
}