use std::collections::HashMap;
use parsely_rs::*;
/// One element of a header extension block in the one-byte form (profile `0xBEDE`).
///
/// On the wire the element is a 4-bit `id`, a 4-bit `length - 1` nibble, and then the data
/// bytes. An `id` of 0 is read as a single padding byte with no data.
#[derive(Debug, PartialEq)]
pub struct OneByteHeaderExtension {
    /// 4-bit local identifier of this element.
    id: u4,
    /// The element's payload.
    data: Bits,
}

impl OneByteHeaderExtension {
    /// Extension-block profile value that selects the one-byte form.
    pub const TYPE: u16 = 0xBEDE;

    /// Returns `true` when `ext_type` is exactly [`Self::TYPE`].
    pub fn type_matches(ext_type: u16) -> bool {
        Self::TYPE == ext_type
    }

    /// Builds an element from its id and payload.
    ///
    /// No validation is performed. NOTE(review): the writer stores `len - 1` in a 4-bit field,
    /// so only payloads of 1 to 16 bytes are encodable.
    pub fn new(id: u4, data: Bits) -> Self {
        Self { id, data }
    }

    /// Returns the element's 4-bit id.
    pub fn id(&self) -> u4 {
        self.id
    }

    /// Returns the element's payload bytes.
    pub fn data(&self) -> &[u8] {
        self.data.chunk_bytes()
    }
}
impl From<OneByteHeaderExtension> for SomeHeaderExtension {
fn from(value: OneByteHeaderExtension) -> Self {
SomeHeaderExtension::OneByteHeaderExtension(value)
}
}
impl<B: BitBuf> ParselyRead<B> for OneByteHeaderExtension {
    type Ctx = ();

    /// Reads one element: a 4-bit id, then (unless the id is 0) a 4-bit `length - 1` nibble
    /// followed by `length` data bytes.
    ///
    /// An id of 0 is padding. Its second nibble is consumed (any error there is ignored) and the
    /// element carries no data.
    ///
    /// # Errors
    /// Fails if the id or length nibble cannot be read, or if fewer bytes remain than the length
    /// nibble announces.
    fn read<T: ByteOrder>(buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<Self> {
        let id = buf.get_u4().context("id")?;
        let data_length_bytes: usize = if id == 0 {
            // Padding byte: drop its low nibble; any error here is ignored.
            let _ = buf.get_u4();
            0
        } else {
            // The nibble encodes `length - 1`, so 0..=15 maps to 1..=16 bytes.
            let encoded: usize = buf.get_u4().context("length")?.into();
            encoded + 1
        };
        if buf.remaining_bytes() < data_length_bytes {
            bail!(
                "Header extension length was {data_length_bytes} but buffer only has {} bytes remaining",
                buf.remaining_bytes()
            );
        }
        let payload = &buf.chunk_bytes()[..data_length_bytes];
        let data = Bits::copy_from_bytes(payload);
        buf.advance_bytes(data_length_bytes);
        Ok(Self { id, data })
    }
}
impl<B: BitBufMut> ParselyWrite<B> for OneByteHeaderExtension {
type Ctx = ();
fn write<T: ByteOrder>(&self, buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<()> {
buf.put_u4(self.id).context("Writing field 'id'")?;
let data_length_bytes = self.data.len_bytes();
let length_field = u4::try_from(data_length_bytes - 1).context("fitting length in u4")?;
buf.put_u4(length_field).context("Writing field 'length'")?;
buf.try_put_slice_bytes(self.data())
.context("Writing field 'data'")?;
Ok(())
}
}
impl_stateless_sync!(OneByteHeaderExtension);
#[derive(Debug, PartialEq)]
pub struct TwoByteHeaderExtension {
id: u8,
data: Bits,
}
impl TwoByteHeaderExtension {
const TYPE_MASK: u16 = 0xFFF0;
pub const TYPE: u16 = 0x1000;
pub fn type_matches(ext_type: u16) -> bool {
(ext_type & Self::TYPE_MASK) == Self::TYPE
}
pub fn id(&self) -> u8 {
self.id
}
pub fn data(&self) -> &[u8] {
self.data.chunk_bytes()
}
}
impl From<TwoByteHeaderExtension> for SomeHeaderExtension {
fn from(value: TwoByteHeaderExtension) -> Self {
SomeHeaderExtension::TwoByteHeaderExtension(value)
}
}
impl<B: BitBuf> ParselyRead<B> for TwoByteHeaderExtension {
    type Ctx = ();

    /// Reads one element: an 8-bit id, then (unless the id is 0) an 8-bit length followed by
    /// that many data bytes.
    ///
    /// An id of 0 is a lone padding byte. No length byte is read and the element has no data.
    ///
    /// # Errors
    /// Fails if the id or length byte cannot be read, or if fewer bytes remain than the length
    /// announces.
    fn read<T: ByteOrder>(buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<Self> {
        let id = buf.get_u8().context("id")?;
        let data_length_bytes = if id == 0 {
            0
        } else {
            usize::from(buf.get_u8().context("length")?)
        };
        if buf.remaining_bytes() < data_length_bytes {
            bail!(
                "Header extension length was {data_length_bytes} but buffer only has {} bytes remaining",
                buf.remaining_bytes()
            );
        }
        let payload = &buf.chunk_bytes()[..data_length_bytes];
        let data = Bits::copy_from_bytes(payload);
        buf.advance_bytes(data_length_bytes);
        Ok(Self { id, data })
    }
}
impl<B: BitBufMut> ParselyWrite<B> for TwoByteHeaderExtension {
type Ctx = ();
fn write<T: ByteOrder>(&self, buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<()> {
buf.put_u8(self.id()).context("Writing field 'id'")?;
let data_length_bytes = self.data().len();
buf.put_u8(data_length_bytes as u8)
.context("Writing field 'length'")?;
buf.try_put_slice_bytes(self.data())
.context("Writing field 'data'")?;
Ok(())
}
}
impl_stateless_sync!(TwoByteHeaderExtension);
/// Either form of header extension element, so both can be stored in one [`HeaderExtensions`].
#[derive(Debug, PartialEq)]
pub enum SomeHeaderExtension {
    /// An element in the one-byte form.
    OneByteHeaderExtension(OneByteHeaderExtension),
    /// An element in the two-byte form.
    TwoByteHeaderExtension(TwoByteHeaderExtension),
}

impl SomeHeaderExtension {
    /// Returns the element's id. One-byte ids are widened from 4 bits to `u8`.
    pub fn id(&self) -> u8 {
        match self {
            Self::OneByteHeaderExtension(inner) => inner.id().into(),
            Self::TwoByteHeaderExtension(inner) => inner.id(),
        }
    }

    /// Returns the element's payload bytes, whichever form it is in.
    pub fn data(&self) -> &[u8] {
        match self {
            Self::OneByteHeaderExtension(inner) => inner.data(),
            Self::TwoByteHeaderExtension(inner) => inner.data(),
        }
    }
}
impl<B: BitBufMut> ParselyWrite<B> for SomeHeaderExtension {
    type Ctx = ();

    /// Writes the wrapped element using its own form's writer.
    fn write<T: ByteOrder>(&self, buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<()> {
        match self {
            Self::OneByteHeaderExtension(inner) => inner.write::<T>(buf, ()),
            Self::TwoByteHeaderExtension(inner) => inner.write::<T>(buf, ()),
        }
    }
}
impl_stateless_sync!(SomeHeaderExtension);
/// A collection of header extension elements keyed by element id.
///
/// Holds at most one element per id. Padding elements (id 0) are skipped when reading.
#[derive(Debug, Default, PartialEq)]
pub struct HeaderExtensions(HashMap<u8, SomeHeaderExtension>);

impl HeaderExtensions {
    /// Number of stored elements.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` when no elements are stored.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns `true` if any stored element is in the one-byte form.
    pub fn has_one_byte(&self) -> bool {
        self.0
            .values()
            .any(|ext| matches!(ext, SomeHeaderExtension::OneByteHeaderExtension(_)))
    }

    /// Returns `true` if any stored element is in the two-byte form.
    pub fn has_two_byte(&self) -> bool {
        self.0
            .values()
            .any(|ext| matches!(ext, SomeHeaderExtension::TwoByteHeaderExtension(_)))
    }

    /// Stores `ext` under its id.
    ///
    /// Returns the element previously stored under that id, if there was one.
    pub fn add_extension<T: Into<SomeHeaderExtension>>(
        &mut self,
        ext: T,
    ) -> Option<SomeHeaderExtension> {
        let extension: SomeHeaderExtension = ext.into();
        let key = extension.id();
        self.0.insert(key, extension)
    }

    /// Removes and returns the element stored under `id`, if any.
    pub fn remove_extension_by_id(&mut self, id: u8) -> Option<SomeHeaderExtension> {
        self.0.remove(&id)
    }

    /// Borrows the element stored under `id`, if any.
    pub fn get_by_id(&self, id: u8) -> Option<&SomeHeaderExtension> {
        self.0.get(&id)
    }
}
impl<'a> IntoIterator for &'a HeaderExtensions {
type Item = (&'a u8, &'a SomeHeaderExtension);
type IntoIter = std::collections::hash_map::Iter<'a, u8, SomeHeaderExtension>;
fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}
impl<B: BitBuf> ParselyRead<B> for HeaderExtensions {
    type Ctx = ();

    /// Reads a complete header extension block.
    ///
    /// The block is a 16-bit profile, a 16-bit length counted in 32-bit words (both in network
    /// order), and then that many words of elements.
    ///
    /// With the one-byte profile ([`OneByteHeaderExtension::TYPE`]), an element whose id nibble
    /// is `0xF` is treated as a prefix: that byte is skipped and a two-byte element is read after
    /// it. Padding elements (id 0) are dropped. A later element with a repeated id replaces the
    /// earlier one.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the header fields cannot be read;
    /// - the buffer holds fewer bytes than the length field announces;
    /// - the block contains elements but the profile is neither form;
    /// - any element fails to parse.
    fn read<T: ByteOrder>(buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<Self> {
        let mut header_extensions = HashMap::new();
        let ext_type = buf
            .get_u16::<NetworkOrder>()
            .context("Reading header extensions profile")?;
        let ext_length = buf
            .get_u16::<NetworkOrder>()
            .context("Reading header extensions length")?;
        // The length counts 32-bit words. Widen before multiplying: computing `ext_length * 4`
        // in u16 overflows for lengths >= 0x4000 (panic in debug builds, wrap in release).
        let ext_length_bytes = usize::from(ext_length) * 4;
        if buf.remaining_bytes() < ext_length_bytes {
            bail!(
                "Header extensions length was {ext_length_bytes} but buffer only has {} bytes remaining",
                buf.remaining_bytes()
            );
        }
        let mut extensions_buf = buf.take_bytes(ext_length_bytes);
        while extensions_buf.has_remaining_bytes() {
            let extension = if OneByteHeaderExtension::type_matches(ext_type) {
                // Peek at the id nibble without consuming it.
                let id = (&extensions_buf.chunk_bits()[..4]).as_u4();
                // NOTE(review): RFC 8285 reserves one-byte id 15 and says parsing should stop
                // there. This code instead treats it as a prefix for an embedded two-byte element.
                // Confirm this non-standard mixed encoding is intended.
                if id == 0xF {
                    let _ = extensions_buf.get_u8();
                    let he = TwoByteHeaderExtension::read::<T>(&mut extensions_buf, ())
                        .context("Two-byte header extension")?;
                    SomeHeaderExtension::TwoByteHeaderExtension(he)
                } else {
                    let he = OneByteHeaderExtension::read::<T>(&mut extensions_buf, ())
                        .context("One-byte header extension")?;
                    SomeHeaderExtension::OneByteHeaderExtension(he)
                }
            } else if TwoByteHeaderExtension::type_matches(ext_type) {
                let he = TwoByteHeaderExtension::read::<T>(&mut extensions_buf, ())
                    .context("Two-byte header extension")?;
                SomeHeaderExtension::TwoByteHeaderExtension(he)
            } else {
                bail!("Encountered invalid header extension block type: {ext_type:x}");
            };
            // Id 0 marks padding, which is not stored.
            if extension.id() != 0 {
                header_extensions.insert(extension.id(), extension);
            }
        }
        Ok(HeaderExtensions(header_extensions))
    }
}
impl<B: BitBufMut> ParselyWrite<B> for HeaderExtensions {
type Ctx = ();
fn write<T: ByteOrder>(&self, buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<()> {
let len_start = buf.remaining_mut_bytes();
self.0
.values()
.map(|he| he.write::<T>(buf, ()))
.collect::<ParselyResult<Vec<_>>>()
.context("Writing header extensions")?;
while (len_start - buf.remaining_mut_bytes()) % 4 != 0 {
buf.put_u8(0).context("Padding")?;
}
Ok(())
}
}
impl_stateless_sync!(HeaderExtensions);
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a whole header extension block (profile, length, elements) from `bytes`,
    /// panicking on any parse error.
    fn parse_block(bytes: &'static [u8]) -> HeaderExtensions {
        let mut bits = Bits::from_static_bytes(bytes);
        HeaderExtensions::read::<NetworkOrder>(&mut bits, ()).unwrap()
    }

    #[test]
    fn test_one_byte_header_extension_parse() {
        // id 1, length nibble 0 (one data byte), data 0xFF, then unrelated trailing bytes.
        #[rustfmt::skip]
        let mut input = Bits::from_static_bytes(&[
            0x10, 0xFF, 0x00, 0x00
        ]);
        let ext = OneByteHeaderExtension::read::<NetworkOrder>(&mut input, ()).unwrap();
        assert_eq!(ext.id(), 1);
        assert_eq!(ext.data(), &[0xFF]);
    }

    #[test]
    fn test_two_byte_header_extension_parse() {
        // id 1, length 1, data 0xFF, then unrelated trailing bytes.
        #[rustfmt::skip]
        let mut input = Bits::from_static_bytes(&[
            0x01, 0x01, 0xFF, 0x00, 0x00
        ]);
        let ext = TwoByteHeaderExtension::read::<NetworkOrder>(&mut input, ()).unwrap();
        assert_eq!(ext.id(), 1);
        assert_eq!(ext.data(), &[0xFF]);
    }

    #[test]
    fn test_header_extensions_parse_all_one_byte() {
        #[rustfmt::skip]
        let exts = parse_block(&[
            0xBE, 0xDE, 0x00, 0x02, // one-byte profile, 2 words
            0x10, 0xFF, 0x00, 0x00, // id 1 (1 byte), then two padding bytes
            0x21, 0xDE, 0xAD, 0x00, // id 2 (2 bytes), then one padding byte
        ]);
        assert_eq!(exts.len(), 2);
        assert_eq!(exts.get_by_id(1).unwrap().data(), &[0xFF]);
        assert_eq!(exts.get_by_id(2).unwrap().data(), &[0xDE, 0xAD]);
    }

    #[test]
    fn test_header_extensions_parse_all_two_byte() {
        #[rustfmt::skip]
        let exts = parse_block(&[
            0x10, 0x00, 0x00, 0x03, // two-byte profile, 3 words
            0x07, 0x04, 0xDE, 0xAD, // id 7 with 4 data bytes...
            0xBE, 0xEF, 0x04, 0x01, // ...end of id 7; id 4 with 1 data byte...
            0x42, 0x00, 0x00, 0x00, // ...end of id 4; three padding bytes
        ]);
        assert_eq!(exts.len(), 2);
        assert_eq!(exts.get_by_id(7).unwrap().data(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(exts.get_by_id(4).unwrap().data(), &[0x42]);
    }

    #[test]
    fn test_header_extensions_parse_mixed() {
        #[rustfmt::skip]
        let exts = parse_block(&[
            0xBE, 0xDE, 0x00, 0x04, // one-byte profile, 4 words
            0x10, 0xFF, 0x00, 0x00, // id 1 (1 byte), then two padding bytes
            0xF0, 0x07, 0x04, 0xDE, // 0xF0 prefix, then two-byte id 7 with 4 data bytes...
            0xAD, 0xBE, 0xEF, 0xF0, // ...end of id 7; 0xF0 prefix...
            0x04, 0x01, 0x42, 0x00, // ...two-byte id 4 with 1 data byte; one padding byte
        ]);
        assert_eq!(exts.len(), 3);
        assert_eq!(exts.get_by_id(1).unwrap().data(), &[0xFF]);
        assert_eq!(exts.get_by_id(7).unwrap().data(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(exts.get_by_id(4).unwrap().data(), &[0x42]);
    }
}