use super::{
consts::*, E2eeExt, ExtensionTag, Extensions, FrameMarker, Handle, HandleError, Header, Packet,
Timestamp, UserTimestampExt,
};
use bytes::{Buf, Bytes};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum DeserializeError {
#[error("too short to contain a valid header")]
TooShort,
#[error("header exceeds total packet length")]
HeaderOverrun,
#[error("extension word indicator is missing")]
MissingExtWords,
#[error("unsupported version {0}")]
UnsupportedVersion(u8),
#[error("invalid track handle: {0}")]
InvalidHandle(#[from] HandleError),
#[error("extension with tag {0} is malformed")]
MalformedExt(ExtensionTag),
}
impl Packet {
    /// Parses a complete packet from `raw`.
    ///
    /// The header, including its extension block when present, is parsed from
    /// the front of the buffer. Every byte after it becomes the payload.
    ///
    /// # Errors
    /// Propagates any [`DeserializeError`] raised while parsing the header.
    pub fn deserialize(mut raw: Bytes) -> Result<Self, DeserializeError> {
        let header = Header::deserialize(&mut raw)?;
        // `Header::deserialize` advanced `raw` past the header, so what is left
        // is exactly the payload. Hand the remaining `Bytes` over as-is.
        Ok(Self {
            header,
            payload: raw,
        })
    }
}
impl Header {
    /// Parses a [`Header`] from the front of `raw` and advances `raw` past it.
    ///
    /// Fields are consumed in this order. Multi-byte fields use `Buf`'s
    /// big-endian getters.
    /// - 1 byte: version, frame marker and extension flag, packed as bit fields
    /// - 1 byte: skipped without being interpreted
    /// - 2 bytes: track handle
    /// - 2 bytes: sequence number
    /// - 2 bytes: frame number
    /// - 4 bytes: timestamp ticks
    /// - only when the extension flag is set: a 2-byte extension-words count `w`,
    ///   then an extension block of `4 * (w + 1) - EXT_WORDS_INDICATOR_SIZE` bytes
    ///
    /// # Errors
    /// - [`DeserializeError::TooShort`] if fewer than `BASE_HEADER_LEN` bytes are available
    /// - [`DeserializeError::UnsupportedVersion`] if the version exceeds `SUPPORTED_VERSION`
    /// - [`DeserializeError::InvalidHandle`] (or whatever the conversion error maps to)
    ///   if the raw handle cannot be converted into a [`Handle`]
    /// - [`DeserializeError::MissingExtWords`] if the flag is set but the count is absent
    /// - [`DeserializeError::HeaderOverrun`] if the extension block runs past the buffer
    /// - [`DeserializeError::MalformedExt`] if a record inside the extension block is malformed
    fn deserialize(raw: &mut impl Buf) -> Result<Self, DeserializeError> {
        // NOTE(review): the reads below consume 12 bytes without further length
        // checks, so this relies on BASE_HEADER_LEN >= 12.
        if raw.remaining() < BASE_HEADER_LEN {
            return Err(DeserializeError::TooShort);
        }

        // First byte: packed bit fields.
        let initial = raw.get_u8();
        let version = (initial >> VERSION_SHIFT) & VERSION_MASK;
        if version > SUPPORTED_VERSION {
            return Err(DeserializeError::UnsupportedVersion(version));
        }
        let marker = match (initial >> FRAME_MARKER_SHIFT) & FRAME_MARKER_MASK {
            FRAME_MARKER_START => FrameMarker::Start,
            FRAME_MARKER_FINAL => FrameMarker::Final,
            FRAME_MARKER_SINGLE => FrameMarker::Single,
            _ => FrameMarker::Inter,
        };
        let has_extensions = ((initial >> EXT_FLAG_SHIFT) & EXT_FLAG_MASK) != 0;

        // Second byte is not interpreted by this parser.
        raw.advance(1);

        let track_handle: Handle = raw.get_u16().try_into()?;
        let sequence = raw.get_u16();
        let frame_number = raw.get_u16();
        let timestamp = Timestamp::from_ticks(raw.get_u32());

        let extensions = if has_extensions {
            if raw.remaining() < 2 {
                return Err(DeserializeError::MissingExtWords);
            }
            // The block spans `4 * (w + 1)` bytes, minus the
            // `EXT_WORDS_INDICATOR_SIZE` bytes of the count that was just read.
            let ext_words = usize::from(raw.get_u16());
            let ext_len = 4 * (ext_words + 1) - EXT_WORDS_INDICATOR_SIZE;
            if raw.remaining() < ext_len {
                return Err(DeserializeError::HeaderOverrun);
            }
            Extensions::deserialize(raw.copy_to_bytes(ext_len))?
        } else {
            Extensions::default()
        };

        Ok(Self {
            marker,
            track_handle,
            sequence,
            frame_number,
            timestamp,
            extensions,
        })
    }
}
// Decodes one fixed-size extension of type `$ext_type` from `$raw`.
//
// `$len` is the data length declared on the wire. The call sites in
// `Extensions::deserialize` only invoke this when `$len >= <$ext_type>::LEN`,
// so computing the surplus cannot underflow there. The first `LEN` bytes are
// decoded and any surplus bytes are skipped, so a longer encoding of a known
// extension can still be read.
//
// Evaluates to `Some(ext)`. If fewer than `$len` bytes remain, it returns
// `Err(DeserializeError::MalformedExt(TAG))` from the enclosing function.
macro_rules! deserialize_ext {
    ($ext_type:ty, $raw:expr, $len:expr) => {{
        let declared: usize = $len;
        if declared > $raw.remaining() {
            return Err(DeserializeError::MalformedExt(<$ext_type>::TAG));
        }
        let mut fixed = [0u8; <$ext_type>::LEN];
        $raw.copy_to_slice(&mut fixed);
        let surplus = declared - <$ext_type>::LEN;
        if surplus != 0 {
            $raw.advance(surplus);
        }
        Some(<$ext_type>::deserialize(fixed))
    }};
}
impl Extensions {
    /// Parses a run of extension records from an extension block.
    ///
    /// Each record is a 1-byte tag and a 1-byte length `len`, followed by `len`
    /// bytes of data. Records are handled as follows:
    /// - `EXT_TAG_PADDING`: only the tag and length bytes are consumed.
    /// - [`E2eeExt`] / [`UserTimestampExt`] with `len` at least the type's `LEN`:
    ///   decoded, with any surplus bytes skipped.
    /// - Anything else (unknown tags, or known tags with a too-short `len`): the
    ///   data is skipped.
    ///
    /// If a tag appears more than once, the last occurrence wins. Parsing stops
    /// once fewer than two bytes remain, so a single trailing byte is ignored.
    ///
    /// # Errors
    /// [`DeserializeError::MalformedExt`] if a non-padding record declares more
    /// data than remains in the block.
    fn deserialize(mut raw: impl Buf) -> Result<Self, DeserializeError> {
        // Tag byte + length byte.
        const RECORD_HEADER_LEN: usize = 2;

        let mut parsed = Self::default();
        while raw.remaining() >= RECORD_HEADER_LEN {
            let tag = raw.get_u8();
            let len = usize::from(raw.get_u8());
            match tag {
                EXT_TAG_PADDING => {
                    // NOTE(review): a padding tag consumes its length byte but
                    // does not skip `len` data bytes. That is equivalent for
                    // all-zero padding. Confirm against the wire spec whether
                    // padding is single-byte (RFC 8285 style) or a length-prefixed
                    // record.
                }
                E2eeExt::TAG if len >= E2eeExt::LEN => {
                    parsed.e2ee = deserialize_ext!(E2eeExt, raw, len);
                }
                UserTimestampExt::TAG if len >= UserTimestampExt::LEN => {
                    parsed.user_timestamp = deserialize_ext!(UserTimestampExt, raw, len);
                }
                _ => {
                    // Unknown tag, or a known tag too short to decode: skip its data.
                    if raw.remaining() < len {
                        return Err(DeserializeError::MalformedExt(tag));
                    }
                    raw.advance(len);
                }
            }
        }
        Ok(parsed)
    }
}
impl UserTimestampExt {
    /// Decodes the extension data as a big-endian `u64` timestamp.
    fn deserialize(raw: [u8; Self::LEN]) -> Self {
        Self(u64::from_be_bytes(raw))
    }
}
impl E2eeExt {
    /// Decodes the E2EE extension data: byte 0 is the key index, and the next
    /// 12 bytes are copied into the IV.
    fn deserialize(raw: [u8; Self::LEN]) -> Self {
        let mut iv = [0u8; 12];
        iv.copy_from_slice(&raw[1..1 + iv.len()]);
        Self {
            key_index: raw[0],
            iv,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::{BufMut, BytesMut};
    use test_case::test_matrix;

    /// Builds a 12-byte base header (no extensions) that deserializes successfully.
    ///
    /// All bytes are zero except byte 3, the low byte of the big-endian track
    /// handle, which is set to 1 (presumably because a zero handle is rejected).
    fn valid_packet() -> BytesMut {
        let mut raw = BytesMut::zeroed(12);
        raw[3] = 1;
        raw
    }

    #[test]
    fn test_short_buffer() {
        let mut raw = valid_packet();
        raw.truncate(11);
        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(packet, Err(DeserializeError::TooShort)));
    }

    #[test]
    fn test_missing_ext_words() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT;
        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(packet, Err(DeserializeError::MissingExtWords)));
    }

    #[test]
    fn test_header_overrun() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT;
        // Declares a 6-byte extension block, but nothing follows.
        raw.put_u16(1);
        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(packet, Err(DeserializeError::HeaderOverrun)));
    }

    #[test]
    fn test_unsupported_version() {
        let mut raw = valid_packet();
        raw[0] = 0x20;
        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(packet, Err(DeserializeError::UnsupportedVersion(1))));
    }

    #[test]
    fn test_base_header() {
        let mut raw = BytesMut::new();
        raw.put_u8(0x8); // initial byte: FrameMarker::Final, no extensions
        raw.put_u8(0x0); // uninterpreted byte
        raw.put_slice(&[0x88, 0x11]); // track handle
        raw.put_slice(&[0x44, 0x22]); // sequence
        raw.put_slice(&[0x44, 0x11]); // frame number
        raw.put_slice(&[0x44, 0x22, 0x11, 0x88]); // timestamp
        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert_eq!(packet.header.marker, FrameMarker::Final);
        assert_eq!(packet.header.track_handle, 0x8811u32.try_into().unwrap());
        assert_eq!(packet.header.sequence, 0x4422);
        assert_eq!(packet.header.frame_number, 0x4411);
        assert_eq!(packet.header.timestamp, Timestamp::from_ticks(0x44221188));
        assert_eq!(packet.header.extensions.user_timestamp, None);
        assert_eq!(packet.header.extensions.e2ee, None);
    }

    #[test]
    fn test_payload_follows_header() {
        let mut raw = valid_packet();
        raw.put_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert_eq!(&packet.payload[..], &[0xDEu8, 0xAD, 0xBE, 0xEF][..]);
    }

    #[test_matrix([0, 1, 24])]
    fn test_ext_skips_padding(ext_words: usize) {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT;
        raw.put_u16(ext_words as u16);
        let data_len = (ext_words + 1) * 4 - EXT_WORDS_INDICATOR_SIZE;
        raw.put_bytes(0, data_len);
        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert_eq!(packet.payload.len(), 0);
    }

    #[test]
    fn test_ext_e2ee() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT;
        raw.put_u16(4); // 18-byte extension block
        raw.put_u8(1); // tag
        raw.put_u8(13); // len
        raw.put_u8(0xFA); // key index
        raw.put_bytes(0x3C, 12); // IV
        raw.put_bytes(0, 3); // padding
        let packet = Packet::deserialize(raw.freeze()).unwrap();
        let e2ee = packet.header.extensions.e2ee.unwrap();
        assert_eq!(e2ee.key_index, 0xFA);
        assert_eq!(e2ee.iv, [0x3C; 12]);
    }

    #[test]
    fn test_ext_user_timestamp() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT;
        raw.put_u16(2); // 10-byte extension block
        raw.put_u8(2); // tag
        raw.put_u8(8); // len
        raw.put_slice(&[0x44, 0x11, 0x22, 0x11, 0x11, 0x11, 0x88, 0x11]);
        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert_eq!(
            packet.header.extensions.user_timestamp,
            UserTimestampExt(0x4411221111118811).into()
        );
    }

    #[test]
    fn test_ext_forward_compat_longer_length() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT;
        raw.put_u16(3); // 14-byte extension block
        raw.put_u8(2); // tag
        raw.put_u8(12); // len: longer than the 8 bytes this parser decodes
        raw.put_slice(&[0x44, 0x11, 0x22, 0x11, 0x11, 0x11, 0x88, 0x11]);
        raw.put_bytes(0xFF, 4); // surplus bytes that must be skipped
        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert_eq!(
            packet.header.extensions.user_timestamp,
            UserTimestampExt(0x4411221111118811).into()
        );
    }

    #[test]
    fn test_ext_shorter_than_known_length_skipped() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT;
        raw.put_u16(1); // 6-byte extension block
        raw.put_u8(2); // user timestamp tag...
        raw.put_u8(4); // ...but too short to decode
        raw.put_bytes(0x3C, 4);
        let packet = Packet::deserialize(raw.freeze()).unwrap();
        assert!(packet.header.extensions.user_timestamp.is_none());
    }

    #[test]
    fn test_ext_unknown() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT;
        raw.put_u16(1); // 6-byte extension block
        raw.put_u8(8); // unknown tag
        raw.put_u8(0); // len
        raw.put_bytes(0, 4);
        Packet::deserialize(raw.freeze()).expect("Should skip unknown extension");
    }

    #[test]
    fn test_ext_unknown_overrunning_block() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT;
        raw.put_u16(1); // 6-byte extension block
        raw.put_u8(8); // unknown tag
        raw.put_u8(10); // claims 10 data bytes, only 4 remain
        raw.put_bytes(0, 4);
        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(packet, Err(DeserializeError::MalformedExt(8))));
    }

    #[test]
    fn test_ext_known_overrunning_block() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT;
        raw.put_u16(2); // 10-byte extension block
        raw.put_u8(UserTimestampExt::TAG);
        raw.put_u8(12); // claims 12 data bytes, only 8 remain
        raw.put_bytes(0x11, 8);
        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(
            packet,
            Err(DeserializeError::MalformedExt(tag)) if tag == UserTimestampExt::TAG
        ));
    }

    #[test]
    fn test_ext_required_word_alignment() {
        let mut raw = valid_packet();
        raw[0] |= 1 << EXT_FLAG_SHIFT;
        // Zero extension words still implies a 2-byte extension block; only 1 byte follows.
        raw.put_u16(0);
        raw.put_bytes(0, 1);
        let packet = Packet::deserialize(raw.freeze());
        assert!(matches!(packet, Err(DeserializeError::HeaderOverrun)));
    }
}