use deref_derive::Deref;
use super::{error::Error, KeyPhaseBit, PacketNumber};
pub mod long;
pub mod short;
/// Header Form bit of the first byte: clear for a short header, set for a long header.
const HEADER_FORM_MASK: u8 = 0b1000_0000;
/// Fixed Bit of the first byte (RFC 9000, section 17).
const FIXED_BIT: u8 = 0b0100_0000;
/// Bits of a long-header first byte that must be zero; checked by `pn_len`.
pub const LONG_RESERVED_MASK: u8 = 0b0000_1100;
/// Bits of a short-header first byte that must be zero; checked by `pn_len`.
pub const SHORT_RESERVED_MASK: u8 = 0b0001_1000;
/// Type-specific bits of a packet header's first byte.
///
/// `R` is the mask of reserved bits that must be zero. The low two bits carry
/// the packet number length minus one (see `GetPacketNumberLength::pn_len`).
#[derive(Clone, Copy, Debug, Deref)]
pub struct SpecificBits<const R: u8>(pub(super) u8);

/// Specific bits of a long header, whose reserved bits are `LONG_RESERVED_MASK`.
pub type LongSpecificBits = SpecificBits<LONG_RESERVED_MASK>;
/// Specific bits of a short header, whose reserved bits are `SHORT_RESERVED_MASK`.
pub type ShortSpecificBits = SpecificBits<SHORT_RESERVED_MASK>;
impl<const R: u8> SpecificBits<R> {
    /// Builds specific bits whose low bits hold `pn.size() - 1`, the encoded
    /// packet number length. All other bits are left zero.
    pub fn from_pn(pn: &PacketNumber) -> Self {
        // NOTE(review): assumes `PacketNumber::size()` always returns 1..=4;
        // a value of 0 would underflow here. Confirm against its definition.
        Self(pn.size() as u8 - 1)
    }

    /// Builds specific bits encoding a packet number length of `pn_size` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `pn_size` is not within `1..=4`, the only lengths the
    /// two-bit length field can represent.
    pub fn with_pn_len(pn_size: usize) -> Self {
        // A real `assert!` rather than `debug_assert!`. In release builds an
        // out-of-range size would otherwise wrap (0 -> 0xFF) or spill into
        // the reserved / key-phase bits (5 -> 0x04), silently producing a
        // malformed header instead of failing loudly.
        assert!(
            (1..=4).contains(&pn_size),
            "packet number length must be within 1..=4, got {pn_size}"
        );
        // Subtract in `usize` after validation so the cast cannot truncate.
        Self((pn_size - 1) as u8)
    }
}
impl ShortSpecificBits {
pub fn set_key_phase(&mut self, key_phase_bit: KeyPhaseBit) {
key_phase_bit.imply(&mut self.0);
}
pub fn key_phase(&self) -> KeyPhaseBit {
KeyPhaseBit::from(self.0)
}
}
impl<const R: u8> From<u8> for SpecificBits<R> {
fn from(byte: u8) -> Self {
Self(byte)
}
}
/// Decodes the packet number length stored in a header's first byte.
pub trait GetPacketNumberLength {
    /// Mask selecting the two-bit packet number length field.
    const PN_LEN_MASK: u8 = 0b0000_0011;

    /// Returns the packet number length in bytes.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when the header byte cannot be accepted.
    fn pn_len(&self) -> Result<u8, Error>;
}
impl<const R: u8> GetPacketNumberLength for SpecificBits<R> {
fn pn_len(&self) -> Result<u8, Error> {
let reserved_bit = self.0 & R;
if reserved_bit == 0 {
Ok((self.0 & Self::PN_LEN_MASK) + 1)
} else {
Err(Error::InvalidReservedBits(reserved_bit, R))
}
}
}
/// Packet type decoded from, or to be encoded at, the start of a header.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Type {
    /// Long-header packet type (Header Form bit set).
    Long(long::Type),
    /// Short-header 1-RTT packet type (Header Form bit clear).
    Short(short::OneRtt),
}

impl Type {
    /// Returns the encoded size of this type in bytes: 5 for long-header
    /// types and 1 for short-header types.
    ///
    /// NOTE(review): presumably the first byte plus, for long headers, the
    /// 4-byte version. Confirm against `long::io::parse_long_type`.
    #[inline]
    pub fn encoding_size(&self) -> usize {
        match *self {
            Self::Long(_) => 5,
            Self::Short(_) => 1,
        }
    }
}
/// Parsing and serialization of the packet [`Type`].
pub mod io {
    use bytes::BufMut;

    use super::{long::io::WriteLongType, short::WriteShortType, *};

    /// Parses the packet type at the start of `input`.
    ///
    /// Reads one byte. If its Header Form bit (`HEADER_FORM_MASK`) is set, the
    /// long type is parsed by `long::io::parse_long_type` from the remaining
    /// input. Otherwise the byte itself becomes a short (1-RTT) type.
    pub fn be_packet_type(input: &[u8]) -> nom::IResult<&[u8], Type, Error> {
        let (rest, first_byte) = nom::number::streaming::be_u8(input)?;
        if (first_byte & HEADER_FORM_MASK) != 0 {
            let (rest, long_ty) = long::io::parse_long_type(first_byte)(rest)?;
            return Ok((rest, Type::Long(long_ty)));
        }
        Ok((rest, Type::Short(short::OneRtt::from(first_byte))))
    }

    /// [`BufMut`] extension for writing a packet [`Type`].
    pub trait WritePacketType: BufMut {
        /// Writes `ty` using the long- or short-type writer that matches it.
        fn put_packet_type(&mut self, ty: &Type);
    }

    impl<B: BufMut> WritePacketType for B {
        fn put_packet_type(&mut self, ty: &Type) {
            match *ty {
                Type::Long(ref long_type) => self.put_long_type(long_type),
                Type::Short(ref one_rtt) => self.put_short_type(one_rtt),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_long_clear_bits() {
        // Any set reserved bit (mask 0x0C) is reported together with the mask.
        for (raw, reserved) in [(0x0C, 0x0C), (0x04, 0x04), (0x08, 0x08)] {
            assert_eq!(
                SpecificBits::<0x0C>(raw).pn_len(),
                Err(Error::InvalidReservedBits(reserved, 0x0C))
            );
        }
        // Every representable length round-trips through the two-bit field.
        for len in (1..=4).rev() {
            assert_eq!(LongSpecificBits::with_pn_len(len).pn_len().unwrap(), len as u8);
        }
    }

    #[test]
    fn test_short_specific_bits() {
        // Only the reserved bits (mask 0x18) of the raw byte appear in the error.
        for (raw, reserved) in [(0x18, 0x18), (0x11, 0x10), (0x0A, 0x08)] {
            assert_eq!(
                SpecificBits::<0x18>(raw).pn_len(),
                Err(Error::InvalidReservedBits(reserved, 0x18))
            );
        }
        for len in (1..=4).rev() {
            assert_eq!(ShortSpecificBits::with_pn_len(len).pn_len().unwrap(), len as u8);
        }
    }

    #[test]
    fn test_set_key_phase_bit() {
        let mut bits = ShortSpecificBits::with_pn_len(4);
        assert_eq!(bits.0, 0x03);

        bits.set_key_phase(KeyPhaseBit::One);
        assert_eq!((bits.0, bits.key_phase()), (0x07, KeyPhaseBit::One));

        bits.set_key_phase(KeyPhaseBit::Zero);
        assert_eq!((bits.0, bits.key_phase()), (0x03, KeyPhaseBit::Zero));
    }
}