use shared::{
error::{Error, Result},
marshal::{Marshal, MarshalSize, Unmarshal},
};
use bytes::{Buf, BufMut, Bytes};
/// Minimum number of bytes required before `Header::unmarshal` reads the first
/// header byte. The full 12-byte fixed header plus the CSRC list is validated
/// separately once the CSRC count is known.
pub const HEADER_LENGTH: usize = 4;
/// Bit position of the 2-bit RTP version field in the first header byte.
pub const VERSION_SHIFT: u8 = 6;
/// Mask applied to the version field after shifting by `VERSION_SHIFT`.
pub const VERSION_MASK: u8 = 0x3;
/// Bit position of the padding (P) flag in the first header byte.
pub const PADDING_SHIFT: u8 = 5;
/// Mask applied to the padding flag after shifting by `PADDING_SHIFT`.
pub const PADDING_MASK: u8 = 0x1;
/// Bit position of the extension (X) flag in the first header byte.
pub const EXTENSION_SHIFT: u8 = 4;
/// Mask applied to the extension flag after shifting by `EXTENSION_SHIFT`.
pub const EXTENSION_MASK: u8 = 0x1;
/// Extension profile value selecting the RFC 8285 one-byte element format
/// (4-bit id, 4-bit `length - 1`).
pub const EXTENSION_PROFILE_ONE_BYTE: u16 = 0xBEDE;
/// Extension profile value selecting the RFC 8285 two-byte element format
/// (8-bit id, 8-bit length).
pub const EXTENSION_PROFILE_TWO_BYTE: u16 = 0x1000;
/// One-byte element id at which `Header::unmarshal` stops parsing elements.
pub const EXTENSION_ID_RESERVED: u8 = 0xF;
/// Mask for the CSRC count held in the low four bits of the first header byte.
pub const CC_MASK: u8 = 0xF;
/// Bit position of the marker (M) flag in the second header byte.
pub const MARKER_SHIFT: u8 = 7;
/// Mask applied to the marker flag after shifting by `MARKER_SHIFT`.
pub const MARKER_MASK: u8 = 0x1;
/// Mask for the 7-bit payload type in the second header byte.
pub const PT_MASK: u8 = 0x7F;
/// Byte offset of the 16-bit sequence number within the header.
pub const SEQ_NUM_OFFSET: usize = 2;
/// Size in bytes of the sequence number.
pub const SEQ_NUM_LENGTH: usize = 2;
/// Byte offset of the 32-bit timestamp within the header.
pub const TIMESTAMP_OFFSET: usize = 4;
/// Size in bytes of the timestamp.
pub const TIMESTAMP_LENGTH: usize = 4;
/// Byte offset of the 32-bit SSRC within the header.
pub const SSRC_OFFSET: usize = 8;
/// Size in bytes of the SSRC.
pub const SSRC_LENGTH: usize = 4;
/// Byte offset of the first CSRC, i.e. the size of the fixed header part.
pub const CSRC_OFFSET: usize = 12;
/// Size in bytes of each CSRC entry.
pub const CSRC_LENGTH: usize = 4;
/// A single RTP header-extension element.
///
/// For the RFC 8285 one-byte and two-byte profiles, each element carries its
/// own `id`. For any other profile, `Header::unmarshal` stores the whole
/// extension data as a single element with `id` 0.
#[derive(Debug, Eq, PartialEq, Default, Clone)]
pub struct Extension {
/// Element identifier (local id for RFC 8285 profiles, 0 otherwise).
pub id: u8,
/// Element data, excluding the per-element id/length bytes.
pub payload: Bytes,
}
/// An RTP packet header: the fixed part, the CSRC list and an optional header
/// extension.
///
/// Parsed by the `Unmarshal` impl and serialized by the `Marshal` impl in this
/// file.
#[derive(Debug, Eq, PartialEq, Default, Clone)]
pub struct Header {
/// RTP version, taken from the top two bits of the first header byte.
pub version: u8,
/// Padding (P) flag. Only the flag is recorded here; the header code does not
/// consume trailing padding bytes.
pub padding: bool,
/// Extension (X) flag; when set, a header extension follows the CSRC list.
pub extension: bool,
/// Marker (M) flag, the top bit of the second header byte.
pub marker: bool,
/// Payload type, the low seven bits of the second header byte.
pub payload_type: u8,
/// 16-bit sequence number.
pub sequence_number: u16,
/// 32-bit timestamp.
pub timestamp: u32,
/// Synchronization source identifier.
pub ssrc: u32,
/// Contributing source identifiers. Their count is encoded in the 4-bit CC
/// field, so at most 15 can be represented on the wire.
pub csrc: Vec<u32>,
/// 16-bit profile field of the extension header (set to 0 by
/// `Header::unmarshal` when the X flag is clear).
pub extension_profile: u16,
/// Extension elements, interpreted according to `extension_profile`.
pub extensions: Vec<Extension>,
}
impl Unmarshal for Header {
fn unmarshal<B>(raw_packet: &mut B) -> Result<Self>
where
Self: Sized,
B: Buf,
{
let raw_packet_len = raw_packet.remaining();
if raw_packet_len < HEADER_LENGTH {
return Err(Error::ErrHeaderSizeInsufficient);
}
let b0 = raw_packet.get_u8();
let version = b0 >> VERSION_SHIFT & VERSION_MASK;
let padding = (b0 >> PADDING_SHIFT & PADDING_MASK) > 0;
let extension = (b0 >> EXTENSION_SHIFT & EXTENSION_MASK) > 0;
let cc = (b0 & CC_MASK) as usize;
let mut curr_offset = CSRC_OFFSET + (cc * CSRC_LENGTH);
if raw_packet_len < curr_offset {
return Err(Error::ErrHeaderSizeInsufficient);
}
let b1 = raw_packet.get_u8();
let marker = (b1 >> MARKER_SHIFT & MARKER_MASK) > 0;
let payload_type = b1 & PT_MASK;
let sequence_number = raw_packet.get_u16();
let timestamp = raw_packet.get_u32();
let ssrc = raw_packet.get_u32();
let mut csrc = Vec::with_capacity(cc);
for _ in 0..cc {
csrc.push(raw_packet.get_u32());
}
let (extension_profile, extensions) = if extension {
let expected = curr_offset + 4;
if raw_packet_len < expected {
return Err(Error::ErrHeaderSizeInsufficientForExtension);
}
let extension_profile = raw_packet.get_u16();
curr_offset += 2;
let extension_length = raw_packet.get_u16() as usize * 4;
curr_offset += 2;
let expected = curr_offset + extension_length;
if raw_packet_len < expected {
return Err(Error::ErrHeaderSizeInsufficientForExtension);
}
let mut extensions = vec![];
match extension_profile {
EXTENSION_PROFILE_ONE_BYTE => {
let end = curr_offset + extension_length;
while curr_offset < end {
let b = raw_packet.get_u8();
if b == 0x00 {
curr_offset += 1;
continue;
}
let extid = b >> 4;
let len = ((b & (0xFF ^ 0xF0)) + 1) as usize;
curr_offset += 1;
if extid == EXTENSION_ID_RESERVED {
break;
}
extensions.push(Extension {
id: extid,
payload: raw_packet.copy_to_bytes(len),
});
curr_offset += len;
}
}
EXTENSION_PROFILE_TWO_BYTE => {
let end = curr_offset + extension_length;
while curr_offset < end {
let b = raw_packet.get_u8();
if b == 0x00 {
curr_offset += 1;
continue;
}
let extid = b;
curr_offset += 1;
let len = raw_packet.get_u8() as usize;
curr_offset += 1;
extensions.push(Extension {
id: extid,
payload: raw_packet.copy_to_bytes(len),
});
curr_offset += len;
}
}
_ => {
if raw_packet_len < curr_offset + extension_length {
return Err(Error::ErrHeaderSizeInsufficientForExtension);
}
extensions.push(Extension {
id: 0,
payload: raw_packet.copy_to_bytes(extension_length),
});
}
};
(extension_profile, extensions)
} else {
(0, vec![])
};
Ok(Header {
version,
padding,
extension,
marker,
payload_type,
sequence_number,
timestamp,
ssrc,
csrc,
extension_profile,
extensions,
})
}
}
impl MarshalSize for Header {
    /// Number of bytes `marshal_to` writes for this header.
    ///
    /// The size is the 12-byte fixed part plus 4 bytes per CSRC. When the
    /// extension flag is set, it also includes the 4-byte extension header and
    /// the extension data rounded up to a whole number of 32-bit words.
    fn marshal_size(&self) -> usize {
        const FIXED_PART: usize = 12;

        let base = FIXED_PART + CSRC_LENGTH * self.csrc.len();
        if !self.extension {
            return base;
        }

        let data_words = (self.get_extension_payload_len() + 3) / 4;
        base + 4 + data_words * 4
    }
}
impl Marshal for Header {
fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> {
let remaining_before = buf.remaining_mut();
if remaining_before < self.marshal_size() {
return Err(Error::ErrBufferTooSmall);
}
let mut b0 = (self.version << VERSION_SHIFT) | self.csrc.len() as u8;
if self.padding {
b0 |= 1 << PADDING_SHIFT;
}
if self.extension {
b0 |= 1 << EXTENSION_SHIFT;
}
buf.put_u8(b0);
let mut b1 = self.payload_type;
if self.marker {
b1 |= 1 << MARKER_SHIFT;
}
buf.put_u8(b1);
buf.put_u16(self.sequence_number);
buf.put_u32(self.timestamp);
buf.put_u32(self.ssrc);
for csrc in &self.csrc {
buf.put_u32(*csrc);
}
if self.extension {
buf.put_u16(self.extension_profile);
let extension_payload_len = self.get_extension_payload_len();
if self.extension_profile != EXTENSION_PROFILE_ONE_BYTE
&& self.extension_profile != EXTENSION_PROFILE_TWO_BYTE
&& extension_payload_len % 4 != 0
{
return Err(Error::HeaderExtensionPayloadNot32BitWords);
}
let extension_payload_size = (extension_payload_len as u16 + 3) / 4;
buf.put_u16(extension_payload_size);
match self.extension_profile {
EXTENSION_PROFILE_ONE_BYTE => {
for extension in &self.extensions {
buf.put_u8((extension.id << 4) | (extension.payload.len() as u8 - 1));
buf.put(&*extension.payload);
}
}
EXTENSION_PROFILE_TWO_BYTE => {
for extension in &self.extensions {
buf.put_u8(extension.id);
buf.put_u8(extension.payload.len() as u8);
buf.put(&*extension.payload);
}
}
_ => {
if self.extensions.len() != 1 {
return Err(Error::ErrRfc3550headerIdrange);
}
if let Some(extension) = self.extensions.first() {
let ext_len = extension.payload.len();
if ext_len % 4 != 0 {
return Err(Error::HeaderExtensionPayloadNot32BitWords);
}
buf.put(&*extension.payload);
}
}
};
for _ in extension_payload_len..extension_payload_size as usize * 4 {
buf.put_u8(0);
}
}
let remaining_after = buf.remaining_mut();
Ok(remaining_before - remaining_after)
}
}
impl Header {
pub fn get_extension_payload_len(&self) -> usize {
let payload_len: usize = self
.extensions
.iter()
.map(|extension| extension.payload.len())
.sum();
let profile_len = self.extensions.len()
* match self.extension_profile {
EXTENSION_PROFILE_ONE_BYTE => 1,
EXTENSION_PROFILE_TWO_BYTE => 2,
_ => 0,
};
payload_len + profile_len
}
pub fn set_extension(&mut self, id: u8, payload: Bytes) -> Result<()> {
if self.extension {
match self.extension_profile {
EXTENSION_PROFILE_ONE_BYTE => {
if !(1..=14).contains(&id) {
return Err(Error::ErrRfc8285oneByteHeaderIdrange);
}
if payload.len() > 16 {
return Err(Error::ErrRfc8285oneByteHeaderSize);
}
}
EXTENSION_PROFILE_TWO_BYTE => {
if id < 1 {
return Err(Error::ErrRfc8285twoByteHeaderIdrange);
}
if payload.len() > 255 {
return Err(Error::ErrRfc8285twoByteHeaderSize);
}
}
_ => {
if id != 0 {
return Err(Error::ErrRfc3550headerIdrange);
}
}
};
if let Some(extension) = self
.extensions
.iter_mut()
.find(|extension| extension.id == id)
{
extension.payload = payload;
} else {
self.extensions.push(Extension { id, payload });
}
} else {
self.extension = true;
self.extension_profile = match payload.len() {
0..=16 => EXTENSION_PROFILE_ONE_BYTE,
17..=255 => EXTENSION_PROFILE_TWO_BYTE,
_ => self.extension_profile,
};
self.extensions.push(Extension { id, payload });
}
Ok(())
}
pub fn get_extension_ids(&self) -> Vec<u8> {
if self.extension {
self.extensions.iter().map(|e| e.id).collect()
} else {
vec![]
}
}
pub fn get_extension(&self, id: u8) -> Option<Bytes> {
if self.extension {
self.extensions
.iter()
.find(|extension| extension.id == id)
.map(|extension| extension.payload.clone())
} else {
None
}
}
pub fn del_extension(&mut self, id: u8) -> Result<()> {
if self.extension {
if let Some(index) = self
.extensions
.iter()
.position(|extension| extension.id == id)
{
self.extensions.remove(index);
Ok(())
} else {
Err(Error::ErrHeaderExtensionNotFound)
}
} else {
Err(Error::ErrHeaderExtensionsNotEnabled)
}
}
}