use std::{
hash::{Hash, Hasher},
ops::Deref,
};
use nom::{IResult, bytes::streaming::take, number::streaming::be_u8};
use rand::RngExt;
/// Upper bound, in bytes, on the length of a [`ConnectionId`].
///
/// NOTE(review): presumably mirrors the 20-byte connection ID limit of QUIC
/// version 1 (RFC 9000) — confirm against the protocol version in use.
pub const MAX_CID_SIZE: usize = 20;

/// A variable-length connection identifier kept in a fixed-size inline buffer.
///
/// Only the first `len` bytes of `bytes` carry the identifier; the tail of the
/// buffer is not part of the value and is ignored by equality, hashing,
/// formatting and [`Deref`].
#[derive(Default, Eq, Copy, Clone)]
pub struct ConnectionId {
    /// Count of meaningful bytes at the front of `bytes`; expected to be at
    /// most [`MAX_CID_SIZE`].
    pub(crate) len: u8,
    /// Inline backing storage; entries at index `len` and beyond are unused.
    pub(crate) bytes: [u8; MAX_CID_SIZE],
}
impl core::fmt::LowerHex for ConnectionId {
    /// Writes every meaningful byte as exactly two lowercase hex digits, with
    /// no separators and no `0x` prefix. Formatter flags (width, `#`, …) are
    /// not honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `iter()` resolves through `Deref`, so only the `..len` prefix is visited.
        self.iter()
            .try_for_each(|byte| write!(f, "{byte:02x}"))
    }
}
impl core::fmt::Display for ConnectionId {
    /// Renders the identifier as a contiguous lowercase hex string, exactly
    /// like the `LowerHex` (`{:x}`) representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <Self as core::fmt::LowerHex>::fmt(self, f)
    }
}
impl core::fmt::Debug for ConnectionId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
core::fmt::LowerHex::fmt(self, f)
}
}
/// Parses a length-prefixed connection ID: one length byte followed by that
/// many identifier bytes.
///
/// # Errors
/// Propagates errors from the streaming `be_u8` read (e.g. `Incomplete` on
/// empty input) and from [`be_connection_id_with_len`], including
/// `ErrorKind::TooLarge` when the length byte exceeds [`MAX_CID_SIZE`].
pub fn be_connection_id(input: &[u8]) -> IResult<&[u8], ConnectionId> {
    let (rest, cid_len) = be_u8(input)?;
    be_connection_id_with_len(rest, usize::from(cid_len))
}
pub fn be_connection_id_with_len(input: &[u8], len: usize) -> IResult<&[u8], ConnectionId> {
if len > MAX_CID_SIZE {
return Err(nom::Err::Error(nom::error::make_error(
input,
nom::error::ErrorKind::TooLarge,
)));
}
let (remain, bytes) = take(len)(input)?;
Ok((remain, ConnectionId::from_slice(bytes)))
}
/// Extension for [`bytes::BufMut`] sinks that serializes a [`ConnectionId`]
/// in length-prefixed form (the layout read back by `be_connection_id`).
pub trait WriteConnectionId: bytes::BufMut {
    /// Appends one length byte followed by the identifier's meaningful bytes.
    fn put_connection_id(&mut self, cid: &ConnectionId);
}

impl<B> WriteConnectionId for B
where
    B: bytes::BufMut,
{
    fn put_connection_id(&mut self, cid: &ConnectionId) {
        // Deref coercion yields only the `..len` prefix of the inline buffer.
        let id: &[u8] = cid;
        self.put_u8(cid.len);
        self.put_slice(id);
    }
}
impl ConnectionId {
    /// Builds a connection ID by copying `bytes` into inline storage.
    ///
    /// # Panics
    /// Debug builds assert `bytes.len() <= MAX_CID_SIZE`; in release builds an
    /// oversized slice panics on the out-of-range copy instead.
    pub fn from_slice(bytes: &[u8]) -> Self {
        debug_assert!(bytes.len() <= MAX_CID_SIZE);
        let len = bytes.len();
        let mut storage = [0u8; MAX_CID_SIZE];
        storage[..len].copy_from_slice(bytes);
        Self {
            len: len as u8,
            bytes: storage,
        }
    }

    /// Returns a zeroed buffer whose first `len` bytes are filled from the
    /// thread RNG.
    fn random_storage(len: usize) -> [u8; MAX_CID_SIZE] {
        let mut storage = [0u8; MAX_CID_SIZE];
        rand::rng().fill(&mut storage[..len]);
        storage
    }

    /// Generates a connection ID of `len` random bytes.
    ///
    /// # Panics
    /// Debug builds assert `len <= MAX_CID_SIZE`; in release builds a larger
    /// `len` panics when slicing the buffer.
    pub fn random_gen(len: usize) -> Self {
        debug_assert!(len <= MAX_CID_SIZE);
        Self {
            len: len as u8,
            bytes: Self::random_storage(len),
        }
    }

    /// Generates a random connection ID of `len` bytes whose first byte keeps
    /// only the random bits selected by `mask` and then has `mark` OR-ed in.
    ///
    /// # Panics
    /// Debug builds assert `0 < len <= MAX_CID_SIZE`; in release builds a `len`
    /// above `MAX_CID_SIZE` panics when slicing the buffer.
    pub fn random_gen_with_mark(len: usize, mark: u8, mask: u8) -> Self {
        debug_assert!(len > 0 && len <= MAX_CID_SIZE);
        let mut storage = Self::random_storage(len);
        storage[0] = mark | (storage[0] & mask);
        Self {
            len: len as u8,
            bytes: storage,
        }
    }

    /// Number of bytes `put_connection_id` emits: the length byte plus the
    /// identifier bytes.
    pub fn encoding_size(&self) -> usize {
        usize::from(self.len) + 1
    }
}
impl Deref for ConnectionId {
    type Target = [u8];

    /// Exposes only the meaningful `..len` prefix of the inline buffer.
    fn deref(&self) -> &[u8] {
        let used = self.len as usize;
        &self.bytes[..used]
    }
}
impl PartialEq<ConnectionId> for ConnectionId {
    /// Two identifiers are equal when their lengths match and their
    /// meaningful bytes match; storage past `len` is never compared.
    fn eq(&self, other: &ConnectionId) -> bool {
        if self.len != other.len {
            return false;
        }
        let used = usize::from(self.len);
        self.bytes[..used] == other.bytes[..used]
    }
}
impl Hash for ConnectionId {
    /// Feeds the length and then the meaningful bytes, so the tail of the
    /// inline buffer never influences the hash (consistent with `eq`).
    fn hash<H: Hasher>(&self, state: &mut H) {
        let used = usize::from(self.len);
        Hash::hash(&self.len, state);
        Hash::hash(&self.bytes[..used], state);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 4-byte length-prefixed ID parses completely; a length byte above
    /// `MAX_CID_SIZE` is rejected with `TooLarge` before the body is read.
    #[test]
    fn test_read_connection_id() {
        let encoded: [u8; 5] = [0x04, 0x01, 0x02, 0x03, 0x04];
        let (rest, cid) = be_connection_id(&encoded).unwrap();
        assert!(rest.is_empty());
        assert_eq!(*cid, [0x01, 0x02, 0x03, 0x04]);

        let oversized: [u8; 5] = [21, 0x01, 0x02, 0x03, 0x04];
        let expected: nom::IResult<&[u8], ConnectionId> = Err(nom::Err::Error(
            nom::error::make_error(&oversized[1..], nom::error::ErrorKind::TooLarge),
        ));
        assert_eq!(be_connection_id(&oversized), expected);
    }

    /// Building an ID from a slice longer than `MAX_CID_SIZE` must panic.
    #[test]
    #[should_panic]
    fn test_cid_from_large_slice() {
        let too_long = [0u8; MAX_CID_SIZE + 1];
        let _ = ConnectionId::from_slice(&too_long);
    }

    /// Serialization writes the length byte followed by the ID bytes.
    #[test]
    fn test_write_connection_id() {
        use bytes::{Bytes, BytesMut};

        let cid = ConnectionId::from_slice(&[0x01, 0x02, 0x03, 0x04]);
        let mut out = BytesMut::new();
        out.put_connection_id(&cid);
        let expected = Bytes::from_static(&[0x04, 0x01, 0x02, 0x03, 0x04]);
        assert_eq!(out.freeze(), expected);
    }
}