use crate::error::Error;
use util::marshal::{Marshal, MarshalSize, Unmarshal};
use anyhow::Result;
use bytes::{Buf, BufMut};
/// RTCP packet type, as carried in the second octet of every RTCP header.
///
/// Each discriminant is the on-the-wire type code. `Unsupported` (0) is the
/// catch-all used when a received code is not one of the values below.
#[derive(Debug, Copy, Clone, PartialEq)]
#[repr(u8)]
pub enum PacketType {
    /// Any packet type code not otherwise listed.
    Unsupported = 0,
    /// Sender report (SR).
    SenderReport = 200,
    /// Receiver report (RR).
    ReceiverReport = 201,
    /// Source description (SDES).
    SourceDescription = 202,
    /// Goodbye (BYE).
    Goodbye = 203,
    /// Application-defined packet (APP).
    ApplicationDefined = 204,
    /// Transport-layer feedback (TSFB / RTPFB).
    TransportSpecificFeedback = 205,
    /// Payload-specific feedback (PSFB).
    PayloadSpecificFeedback = 206,
}
impl Default for PacketType {
fn default() -> Self {
PacketType::Unsupported
}
}
// Feedback message type (FMT) values. In feedback packets the 5-bit count
// field of the common header carries the FMT instead of a count.

// Payload-specific feedback (packet type 206).
/// FMT for Picture Loss Indication (RFC 4585).
pub const FORMAT_PLI: u8 = 1;
/// FMT for Slice Loss Indication (RFC 4585).
pub const FORMAT_SLI: u8 = 2;
/// FMT for Full Intra Request (RFC 5104).
pub const FORMAT_FIR: u8 = 4;
/// FMT for Receiver Estimated Maximum Bitrate (application-layer feedback, FMT 15).
pub const FORMAT_REMB: u8 = 15;

// Transport-layer feedback (packet type 205).
/// FMT for a transport-layer NACK (generic NACK, RFC 4585).
pub const FORMAT_TLN: u8 = 1;
/// FMT for Rapid Resynchronisation Request (RFC 6051).
pub const FORMAT_RRR: u8 = 5;
/// FMT for transport-wide congestion control feedback.
pub const FORMAT_TCC: u8 = 15;
impl std::fmt::Display for PacketType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
PacketType::Unsupported => "Unsupported",
PacketType::SenderReport => "SR",
PacketType::ReceiverReport => "RR",
PacketType::SourceDescription => "SDES",
PacketType::Goodbye => "BYE",
PacketType::ApplicationDefined => "APP",
PacketType::TransportSpecificFeedback => "TSFB",
PacketType::PayloadSpecificFeedback => "PSFB",
};
write!(f, "{}", s)
}
}
impl From<u8> for PacketType {
fn from(b: u8) -> Self {
match b {
200 => PacketType::SenderReport, 201 => PacketType::ReceiverReport, 202 => PacketType::SourceDescription, 203 => PacketType::Goodbye, 204 => PacketType::ApplicationDefined, 205 => PacketType::TransportSpecificFeedback, 206 => PacketType::PayloadSpecificFeedback, _ => PacketType::Unsupported,
}
}
}
/// Protocol version expected in (and written to) the version bits.
pub const RTP_VERSION: u8 = 2;

// First-octet bit layout: V(2) | P(1) | count(5).
/// Right-shift that brings the 2 version bits down to bit 0.
pub const VERSION_SHIFT: u8 = 6;
/// Mask for the version bits after shifting.
pub const VERSION_MASK: u8 = 0x3;
/// Right-shift that brings the padding flag down to bit 0.
pub const PADDING_SHIFT: u8 = 5;
/// Mask for the padding flag after shifting.
pub const PADDING_MASK: u8 = 0x1;
/// The count field already sits in the low bits, so no shift is needed.
pub const COUNT_SHIFT: u8 = 0;
/// Mask for the 5-bit count field.
pub const COUNT_MASK: u8 = 0x1f;

/// Size in bytes of the common RTCP header.
pub const HEADER_LENGTH: usize = 4;
/// Largest value representable in the 5-bit count field (31).
// Widening `u8 -> usize` cast is lossless; `From` is not usable in const context.
pub const COUNT_MAX: usize = COUNT_MASK as usize;
/// Size in bytes of an SSRC field.
pub const SSRC_LENGTH: usize = 4;
/// Largest octet count expressible in an 8-bit SDES length field (255).
pub const SDES_MAX_OCTET_COUNT: usize = u8::MAX as usize;
/// The 4-byte common header that starts every RTCP packet.
///
/// The version bits are not stored: they are always written as `RTP_VERSION`
/// and rejected on read if they differ.
#[derive(Debug, PartialEq, Default, Clone)]
pub struct Header {
    /// Padding (P) flag from the first octet.
    pub padding: bool,
    /// 5-bit count field from the first octet; marshalling rejects values above 31.
    pub count: u8,
    /// Packet type from the second octet.
    pub packet_type: PacketType,
    /// Raw 16-bit length field, stored exactly as read/written (no unit
    /// conversion is done here; per RFC 3550 it is the length in 32-bit words minus one).
    pub length: u16,
}
/// The common header has a fixed size regardless of its field values.
impl MarshalSize for Header {
    /// Always [`HEADER_LENGTH`] bytes.
    fn marshal_size(&self) -> usize {
        HEADER_LENGTH
    }
}
impl Marshal for Header {
fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> {
if self.count > 31 {
return Err(Error::InvalidHeader.into());
}
if buf.remaining_mut() < HEADER_LENGTH {
return Err(Error::BufferTooShort.into());
}
let b0 = (RTP_VERSION << VERSION_SHIFT)
| ((self.padding as u8) << PADDING_SHIFT)
| (self.count << COUNT_SHIFT);
buf.put_u8(b0);
buf.put_u8(self.packet_type as u8);
buf.put_u16(self.length);
Ok(HEADER_LENGTH)
}
}
impl Unmarshal for Header {
    /// Parses the 4-byte RTCP common header from the front of `raw_packet`.
    ///
    /// Bytes are consumed as they are read: if the version check fails only the
    /// first octet has been taken; on success exactly [`HEADER_LENGTH`] bytes are.
    ///
    /// # Errors
    /// - [`Error::PacketTooShort`] if fewer than [`HEADER_LENGTH`] bytes remain
    ///   (nothing is consumed in that case).
    /// - [`Error::BadVersion`] if the version bits differ from [`RTP_VERSION`].
    fn unmarshal<B>(raw_packet: &mut B) -> Result<Self>
    where
        Self: Sized,
        B: Buf,
    {
        if raw_packet.remaining() < HEADER_LENGTH {
            return Err(Error::PacketTooShort.into());
        }

        // First octet packs version, padding flag and the 5-bit count.
        let first = raw_packet.get_u8();
        if ((first >> VERSION_SHIFT) & VERSION_MASK) != RTP_VERSION {
            return Err(Error::BadVersion.into());
        }

        // Struct field initialisers are evaluated in source order, so the
        // packet-type octet is read before the 16-bit length, matching the wire layout.
        Ok(Header {
            padding: ((first >> PADDING_SHIFT) & PADDING_MASK) != 0,
            count: (first >> COUNT_SHIFT) & COUNT_MASK,
            packet_type: raw_packet.get_u8().into(),
            length: raw_packet.get_u16(),
        })
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use bytes::Bytes;

    /// Decodes fixed 4-byte inputs and checks either the parsed header or the
    /// expected error.
    #[test]
    fn test_header_unmarshal() {
        let tests = vec![
            (
                "valid",
                Bytes::from_static(&[
                    // V=2, P=0, count=1 | RR (201) | length=7
                    0x81u8, 0xc9, 0x00, 0x07,
                ]),
                Header {
                    padding: false,
                    count: 1,
                    packet_type: PacketType::ReceiverReport,
                    length: 7,
                },
                None,
            ),
            (
                "also valid",
                Bytes::from_static(&[
                    // V=2, P=1, count=1 | APP (204) | length=7
                    0xa1, 0xcc, 0x00, 0x07,
                ]),
                Header {
                    padding: true,
                    count: 1,
                    packet_type: PacketType::ApplicationDefined,
                    length: 7,
                },
                None,
            ),
            (
                "bad version",
                Bytes::from_static(&[
                    // V=0 in the first octet
                    0x00, 0xc9, 0x00, 0x04,
                ]),
                Header {
                    padding: false,
                    count: 0,
                    packet_type: PacketType::Unsupported,
                    length: 0,
                },
                Some(Error::BadVersion),
            ),
        ];

        for (name, data, want, want_error) in tests {
            // `data` is owned and not needed afterwards, so decode from it directly.
            let mut buf = data;
            let got = Header::unmarshal(&mut buf);

            assert_eq!(
                got.is_err(),
                want_error.is_some(),
                "Unmarshal {}: err = {:?}, want {:?}",
                name,
                got,
                want_error
            );

            if let Some(err) = want_error {
                let got_err = got.unwrap_err();
                assert!(
                    err.equal(&got_err),
                    "Unmarshal {}: err = {:?}, want {:?}",
                    name,
                    got_err,
                    err,
                );
            } else {
                let actual = got.unwrap();
                assert_eq!(
                    actual, want,
                    "Unmarshal {}: got {:?}, want {:?}",
                    name, actual, want
                );
            }
        }
    }

    /// Marshals headers and, when marshalling succeeds, checks that unmarshalling
    /// the output reproduces the original header.
    #[test]
    fn test_header_roundtrip() {
        let tests = vec![
            (
                "valid",
                Header {
                    padding: true,
                    count: 31,
                    packet_type: PacketType::SenderReport,
                    length: 4,
                },
                None,
            ),
            (
                "also valid",
                Header {
                    padding: false,
                    count: 28,
                    packet_type: PacketType::ReceiverReport,
                    length: 65535,
                },
                None,
            ),
            (
                "invalid count",
                Header {
                    padding: false,
                    count: 40,
                    packet_type: PacketType::Unsupported,
                    length: 0,
                },
                Some(Error::InvalidHeader),
            ),
        ];

        for (name, want, want_error) in tests {
            let got = want.marshal();

            assert_eq!(
                got.is_ok(),
                want_error.is_none(),
                "Marshal {}: err = {:?}, want {:?}",
                name,
                got,
                want_error
            );

            if let Some(err) = want_error {
                let got_err = got.unwrap_err();
                // This branch checks the *marshal* error, so label it accordingly.
                assert!(
                    err.equal(&got_err),
                    "Marshal {}: err = {:?}, want {:?}",
                    name,
                    got_err,
                    err,
                );
            } else {
                let mut data = got.unwrap();
                // Build the panic message lazily instead of formatting it on every call.
                let actual = Header::unmarshal(&mut data)
                    .unwrap_or_else(|e| panic!("Unmarshal {}: {:?}", name, e));
                assert_eq!(
                    actual, want,
                    "{} round trip: got {:?}, want {:?}",
                    name, actual, want
                )
            }
        }
    }
}