use std::{collections::HashMap, io::Cursor};
use bytes::{Buf, Bytes};
/// One element of a one-byte-header extension block.
///
/// The wrapped buffer holds the raw element bytes: the header byte (element ID
/// in its high nibble) followed by the element data.
#[derive(Debug)]
pub struct OneByteHeaderExtension(Bytes);

impl OneByteHeaderExtension {
    /// Profile-defined extension type value (`0xBEDE`) that selects the one-byte form.
    pub const TYPE: u16 = 0xBEDE;

    /// Returns `true` when `ext_type` is exactly [`Self::TYPE`].
    pub fn type_matches(ext_type: u16) -> bool {
        Self::TYPE == ext_type
    }

    /// Returns the element ID, taken from the high nibble of the header byte.
    pub fn id(&self) -> u8 {
        // Shifting a `u8` right by 4 already discards the low nibble.
        let header = self.0[0];
        header >> 4
    }

    /// Returns everything after the header byte.
    pub fn data(&self) -> Bytes {
        const HEADER_LEN: usize = 1;
        self.0.slice(HEADER_LEN..)
    }
}
/// Splits the next element of a one-byte-header extension block off the front of `buf`.
///
/// Element layout (RFC 8285 §4.2): a header byte whose high nibble is the ID and
/// whose low nibble is `L`, followed by `L + 1` data bytes. Special IDs:
///
/// * ID 0 is a single padding byte. Only that byte is consumed, so elements that
///   follow inter-element padding are still parsed; the returned element has ID 0
///   and empty data.
/// * ID 15 is reserved and terminates processing of the block. Its length field is
///   ignored and every remaining byte is consumed.
///
/// # Panics
///
/// Panics if `buf` is empty, or if the length field announces more bytes than
/// `buf` holds.
pub fn read_one_byte_header_extension(buf: &mut Bytes) -> OneByteHeaderExtension {
    let header = buf[0];
    let element_len = match header >> 4 {
        // Padding: exactly one byte, no length field.
        0 => 1,
        // Reserved stop marker: ignore the (meaningless) length and swallow the rest.
        15 => buf.len(),
        // Header byte plus `L + 1` data bytes.
        _ => 2 + usize::from(header & 0x0F),
    };
    OneByteHeaderExtension(buf.split_to(element_len))
}
/// One element of a two-byte-header extension block.
///
/// The wrapped buffer holds the raw element bytes: an ID byte, a length byte and
/// then the data. A padding element (ID 0) is a single byte with no length byte.
#[derive(Debug)]
pub struct TwoByteHeaderExtension(Bytes);

impl TwoByteHeaderExtension {
    /// Selects the fixed upper 12 bits of the type field; the low 4 bits are
    /// application-defined ("appbits") and ignored when matching.
    const TYPE_MASK: u16 = 0xFFF0;
    /// Type-field value (after masking) that selects the two-byte form.
    pub const TYPE: u16 = 0x1000;

    /// Returns `true` when `ext_type` (ignoring its low 4 bits) is [`Self::TYPE`].
    pub fn type_matches(ext_type: u16) -> bool {
        (ext_type & Self::TYPE_MASK) == Self::TYPE
    }

    /// Returns the element ID (the first byte).
    pub fn id(&self) -> u8 {
        self.0[0]
    }

    /// Returns the element data, i.e. everything after the ID and length bytes.
    ///
    /// A padding element is only one byte long, so it yields empty data instead
    /// of panicking on an out-of-range slice.
    pub fn data(&self) -> Bytes {
        self.0.slice(self.0.len().min(2)..)
    }
}

/// Splits the next element of a two-byte-header extension block off the front of `buf`.
///
/// Element layout (RFC 8285 §4.3): an ID byte and a length byte, where the length
/// counts only the data bytes that follow (0..=255; zero means no data). A zero ID
/// byte is a single padding byte and only that byte is consumed.
///
/// # Panics
///
/// Panics if `buf` is empty, if a non-padding element has no length byte, or if
/// the length byte announces more data than `buf` holds.
pub fn read_two_byte_header_extension(buf: &mut Bytes) -> TwoByteHeaderExtension {
    let element_len = match buf[0] {
        // Padding: exactly one byte, no length field.
        0 => 1,
        // ID + length byte + data. Widen before adding so a length of 255 cannot
        // overflow `u8`.
        _ => 2 + usize::from(buf[1]),
    };
    TwoByteHeaderExtension(buf.split_to(element_len))
}
/// A parsed header-extension element, in either the one-byte or two-byte form.
#[derive(Debug)]
pub enum SomeHeaderExtension {
    /// Element read from a one-byte-header block.
    OneByteHeaderExtension(OneByteHeaderExtension),
    /// Element read from a two-byte-header block.
    TwoByteHeaderExtension(TwoByteHeaderExtension),
}

impl SomeHeaderExtension {
    /// Returns the element ID, delegating to the wrapped form.
    pub fn id(&self) -> u8 {
        match self {
            Self::OneByteHeaderExtension(one_byte) => one_byte.id(),
            Self::TwoByteHeaderExtension(two_byte) => two_byte.id(),
        }
    }

    /// Returns the element data (header bytes stripped), delegating to the wrapped form.
    pub fn data(&self) -> Bytes {
        match self {
            Self::OneByteHeaderExtension(one_byte) => one_byte.data(),
            Self::TwoByteHeaderExtension(two_byte) => two_byte.data(),
        }
    }
}
/// Parses a header-extension block into its elements, keyed by element ID.
///
/// `buf` must start with the 16-bit extension type field and the 16-bit length
/// field (a count of 32-bit words), both read big-endian, followed by the
/// extension data. Bytes beyond the declared length are ignored.
///
/// Padding is stored under ID 0. Later elements with an already-seen ID overwrite
/// earlier ones. In the one-byte form an element with ID 15 stops parsing, and
/// nothing from it onward is returned (RFC 8285 §4.2).
///
/// # Panics
///
/// Panics if `buf` is shorter than the 4-byte header plus the declared length, if
/// an element is truncated, or if the extension data is non-empty and the type
/// field matches neither the one-byte nor the two-byte form.
pub fn read_header_extensions(buf: Bytes) -> HashMap<u8, SomeHeaderExtension> {
    let mut cursor = Cursor::new(buf);
    let ext_type = cursor.get_u16();
    // The length field counts 32-bit words. Widen before multiplying so values of
    // 0x4000 and above do not overflow `u16`.
    let length_bytes = usize::from(cursor.get_u16()) * 4;
    let buf = cursor.into_inner();
    let mut header_extensions_bytes = buf.slice(4..).slice(..length_bytes);

    // The element form depends only on the type field, so decide it once.
    let is_two_byte = TwoByteHeaderExtension::type_matches(ext_type);
    let is_one_byte = OneByteHeaderExtension::type_matches(ext_type);

    let mut header_extensions: HashMap<u8, SomeHeaderExtension> = HashMap::new();
    while !header_extensions_bytes.is_empty() {
        let ext = if is_two_byte {
            SomeHeaderExtension::TwoByteHeaderExtension(read_two_byte_header_extension(
                &mut header_extensions_bytes,
            ))
        } else if is_one_byte {
            // ID 15 is reserved: processing of the whole block terminates here.
            if header_extensions_bytes[0] >> 4 == 15 {
                break;
            }
            SomeHeaderExtension::OneByteHeaderExtension(read_one_byte_header_extension(
                &mut header_extensions_bytes,
            ))
        } else {
            panic!("Invalid header extension type: {ext_type:x?}");
        };
        header_extensions.insert(ext.id(), ext);
    }
    header_extensions
}
#[cfg(test)]
mod test {
    use bytes::Bytes;

    use super::read_header_extensions;

    /// Two one-byte elements with no padding: ID 1 carries [0xAA], ID 2 carries [0xBB].
    #[test]
    fn test_one_byte_header_extensions() {
        #[rustfmt::skip]
        let data: Vec<u8> = vec![
            0xBE, 0xDE, 0x00, 0x01,
            0x10, 0xAA, 0x20, 0xBB,
        ];
        let bytes = Bytes::from(data);
        let he = read_header_extensions(bytes);
        assert_eq!(he.len(), 2);
        let ext_one = he
            .get(&1)
            .expect("should contain a header extension with ID 1");
        assert_eq!(ext_one.data(), Bytes::from_static(&[0xAA]));
        let ext_two = he
            .get(&2)
            .expect("should contain a header extension with ID 2");
        assert_eq!(ext_two.data(), Bytes::from_static(&[0xBB]));
    }

    /// One element (ID 5, two data bytes) followed by a single padding byte.
    #[test]
    fn test_one_byte_header_extensions_one_byte_padding() {
        #[rustfmt::skip]
        let data: Vec<u8> = vec![
            0xBE, 0xDE, 0x00, 0x01,
            0x51, 0x00, 0x01, 0x00
        ];
        let bytes = Bytes::from(data);
        let he = read_header_extensions(bytes);
        // The element plus the padding entry stored under ID 0.
        assert_eq!(he.len(), 2);
        let ext_five = he
            .get(&5)
            .expect("should contain a header extension with ID 5");
        assert_eq!(ext_five.data(), Bytes::from_static(&[0x00, 0x01]));
    }

    /// One element (ID 1, one data byte) followed by two padding bytes.
    #[test]
    fn test_one_byte_header_extensions_two_bytes_padding() {
        #[rustfmt::skip]
        let data: Vec<u8> = vec![
            0xBE, 0xDE, 0x00, 0x01,
            0x10, 0xFF, 0x00, 0x00
        ];
        let bytes = Bytes::from(data);
        let he = read_header_extensions(bytes);
        // The element plus the padding entry stored under ID 0.
        assert_eq!(he.len(), 2);
        let ext_one = he
            .get(&1)
            .expect("should contain a header extension with ID 1");
        assert_eq!(ext_one.data(), Bytes::from_static(&[0xFF]));
    }
}