use crate::error::{NetError, NetResult};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// RTP protocol version; it occupies the top two bits of the first header
/// byte and is the only version `RtpHeader::parse` accepts.
pub const RTP_VERSION: u8 = 2;
/// Maximum RTP payload size in bytes.
///
/// NOTE(review): not referenced in this part of the file; presumably chosen to
/// keep whole packets under a ~1500-byte network MTU — confirm with callers.
pub const MAX_RTP_PAYLOAD: usize = 1400;
/// Decoded RTP header: the 12-byte fixed part plus the optional CSRC list and
/// header extension.
///
/// `csrc_count` and `extension` mirror the wire flags; `RtpHeader::parse`
/// always sets them consistently with `csrcs.len()` and
/// `extension_data.is_some()`. Keep them consistent when building a header by hand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RtpHeader {
/// P bit: the packet carries trailing padding octets after the payload.
pub padding: bool,
/// X bit: a header extension follows the CSRC list.
pub extension: bool,
/// CC field (4 bits): number of CSRC identifiers.
pub csrc_count: u8,
/// M bit: profile-defined marker.
pub marker: bool,
/// PT field (7 bits).
pub payload_type: u8,
/// 16-bit sequence number (wraps at 65535).
pub sequence_number: u16,
/// Media timestamp, in units of the stream's clock rate.
pub timestamp: u32,
/// Synchronization source identifier of the sender.
pub ssrc: u32,
/// Contributing source identifiers (at most 15 fit in the CC field).
pub csrcs: Vec<u32>,
/// Header extension, present when the X bit is set.
pub extension_data: Option<RtpHeaderExtension>,
}
/// RTP header extension (RFC 3550 §5.3.1).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RtpHeaderExtension {
/// 16-bit "defined by profile" identifier that precedes the length word.
pub profile: u16,
/// Extension body, without the 4-byte profile/length prefix. When produced
/// by `RtpHeader::parse` its length is a multiple of 4 bytes.
pub data: Bytes,
}
/// A complete RTP packet: parsed header plus the bytes that follow it.
#[derive(Debug, Clone)]
pub struct RtpPacket {
/// Decoded header.
pub header: RtpHeader,
/// Everything after the header as sliced by `RtpPacket::parse`; padding
/// octets are not stripped, so they are included when `header.padding` is set.
pub payload: Bytes,
}
/// RTCP packet type codes carried in the PT byte of the RTCP common header
/// (RFC 3550 §12.1). The discriminants are the on-the-wire values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RtcpPacketType {
/// SR: sender report.
SenderReport = 200,
/// RR: receiver report.
ReceiverReport = 201,
/// SDES: source description.
SourceDescription = 202,
/// BYE: goodbye.
Goodbye = 203,
/// APP: application-defined.
ApplicationDefined = 204,
}
/// Body of an RTCP sender report: the sender's SSRC, its sender-info section
/// and zero or more reception report blocks.
#[derive(Debug, Clone)]
pub struct RtcpSenderReport {
/// SSRC of the sender of this report.
pub ssrc: u32,
/// Most significant word of the 64-bit NTP timestamp.
pub ntp_timestamp_msw: u32,
/// Least significant word of the 64-bit NTP timestamp.
pub ntp_timestamp_lsw: u32,
/// RTP timestamp corresponding to the NTP timestamp above.
pub rtp_timestamp: u32,
/// Sender's packet count.
pub sender_packet_count: u32,
/// Sender's octet count.
pub sender_octet_count: u32,
/// Reception report blocks that follow the sender info.
pub report_blocks: Vec<RtcpReportBlock>,
}
/// Body of an RTCP receiver report: the reporter's SSRC and its reception
/// report blocks.
#[derive(Debug, Clone)]
pub struct RtcpReceiverReport {
/// SSRC of the sender of this report.
pub ssrc: u32,
/// Reception report blocks.
pub report_blocks: Vec<RtcpReportBlock>,
}
/// One 24-byte reception report block describing a single source.
#[derive(Debug, Clone)]
pub struct RtcpReportBlock {
/// SSRC of the source this block reports on.
pub ssrc: u32,
/// Fraction lost since the previous report, as an 8-bit fixed-point value.
pub fraction_lost: u8,
/// Cumulative packets lost; carried as a signed 24-bit value on the wire.
pub cumulative_packets_lost: i32,
/// Extended highest sequence number received.
pub extended_highest_sequence: u32,
/// Interarrival jitter, in timestamp units.
pub interarrival_jitter: u32,
/// LSR: last sender-report timestamp.
pub last_sr_timestamp: u32,
/// DLSR: delay since the last sender report.
pub delay_since_last_sr: u32,
}
/// Wrap-aware tracker of received RTP sequence numbers for a single source.
///
/// Extends the 16-bit wire sequence number with a wrap-around cycle count and
/// keeps the counters needed for cumulative-loss and per-interval
/// fraction-lost figures (RFC 3550 §6.4.1, Appendix A.3).
#[derive(Debug, Clone)]
pub struct SequenceNumberTracker {
    /// Highest sequence number seen so far, in wrap-aware order.
    highest_seq: u16,
    /// Number of times the 16-bit sequence space has wrapped past 65535.
    cycles: u32,
    /// Sequence number the tracker was created with; loss is counted from here.
    base_seq: u16,
    /// Packets recorded by `update`, including duplicates and late arrivals.
    received: u64,
    /// Expected-packet count captured by the previous `fraction_lost` call.
    expected_prior: u64,
    /// Received-packet count captured by the previous `fraction_lost` call.
    received_prior: u64,
}

impl SequenceNumberTracker {
    /// Creates a tracker for a stream whose first sequence number is `initial_seq`.
    ///
    /// Nothing is counted as received yet. Pass the first packet to
    /// [`update`](Self::update) as well.
    #[must_use]
    pub fn new(initial_seq: u16) -> Self {
        Self {
            highest_seq: initial_seq,
            cycles: 0,
            base_seq: initial_seq,
            received: 0,
            expected_prior: 0,
            received_prior: 0,
        }
    }

    /// Records a packet carrying sequence number `seq`.
    ///
    /// A packet equal to, or less than half the sequence space (`0x8000`) ahead
    /// of, the current highest advances the highest sequence number. The cycle
    /// count is bumped when the raw value wrapped past 65535. Anything else is
    /// treated as late or reordered: it is counted as received but leaves the
    /// highest sequence number unchanged.
    pub fn update(&mut self, seq: u16) {
        let is_newer = seq.wrapping_sub(self.highest_seq) < 0x8000;
        if is_newer {
            // Moving forward while the raw value decreased means we crossed 65535 -> 0.
            if seq < self.highest_seq {
                self.cycles = self.cycles.wrapping_add(1);
            }
            self.highest_seq = seq;
        }
        self.received += 1;
    }

    /// 32-bit extended highest sequence number: the cycle count in the upper
    /// 16 bits and the highest 16-bit sequence number in the lower 16 bits.
    #[must_use]
    pub const fn extended_sequence(&self) -> u32 {
        (self.cycles << 16) | (self.highest_seq as u32)
    }

    /// Packets expected from `base_seq` through the extended highest sequence
    /// number, inclusive (RFC 3550 A.3: `extended_max - base_seq + 1`).
    fn expected_packets(&self) -> u64 {
        u64::from(self.extended_sequence().wrapping_sub(u32::from(self.base_seq))) + 1
    }

    /// Cumulative packets lost since the start of the stream (expected minus received).
    ///
    /// Can be negative when duplicates push the received count above the expected count.
    #[must_use]
    pub fn packets_lost(&self) -> i64 {
        // expected_packets() is at most 2^32, so the cast cannot wrap.
        self.expected_packets() as i64 - self.received as i64
    }

    /// Fraction of packets lost since the previous call, as an 8-bit
    /// fixed-point value (`lost * 256 / expected`, capped at 255).
    ///
    /// Each call snapshots the counters, so the next call covers only the
    /// packets that arrive afterwards. Returns 0 when nothing new was expected
    /// or when duplicates make received >= expected for the interval.
    #[must_use]
    pub fn fraction_lost(&mut self) -> u8 {
        // Must use the same inclusive count as `packets_lost`; without the +1
        // the first interval under-reports loss by one packet.
        let expected = self.expected_packets();
        let expected_interval = expected.saturating_sub(self.expected_prior);
        let received_interval = self.received.saturating_sub(self.received_prior);
        self.expected_prior = expected;
        self.received_prior = self.received;
        if expected_interval <= received_interval {
            0
        } else {
            let lost = expected_interval - received_interval;
            // lost <= expected_interval, so the quotient is at most 256; clamp the all-lost case.
            ((lost << 8) / expected_interval).min(255) as u8
        }
    }
}
/// Running interarrival-jitter estimate for one RTP source (RFC 3550 §6.4.1, A.8).
///
/// Jitter is kept in RTP timestamp units of the media clock given by
/// `clock_rate`, and is smoothed with a gain of 1/16.
#[derive(Debug, Clone)]
pub struct JitterCalculator {
    /// Current smoothed jitter, in timestamp units.
    jitter: f64,
    /// RTP timestamp of the previous packet, if any.
    prev_timestamp: Option<u32>,
    /// Arrival time of the previous packet, if any.
    prev_arrival: Option<Instant>,
    /// Media clock rate (timestamp units per second).
    clock_rate: u32,
}

impl JitterCalculator {
    /// Creates an estimator for a stream clocked at `clock_rate` units per second.
    #[must_use]
    pub const fn new(clock_rate: u32) -> Self {
        Self {
            jitter: 0.0,
            prev_timestamp: None,
            prev_arrival: None,
            clock_rate,
        }
    }

    /// Feeds one packet's RTP `timestamp` and its `arrival_time` into the estimate.
    ///
    /// The first call only records a reference point. Later calls compare the
    /// arrival spacing with the timestamp spacing against the previous packet.
    pub fn update(&mut self, timestamp: u32, arrival_time: Instant) {
        if let (Some(prev_ts), Some(prev_arr)) = (self.prev_timestamp, self.prev_arrival) {
            // Interpret the modular difference as signed so a reordered (earlier)
            // timestamp yields a small negative delta rather than ~2^32.
            let ts_diff = i64::from(timestamp.wrapping_sub(prev_ts) as i32);
            // Keep the sign if the caller hands us an arrival earlier than the previous one.
            let arr_secs = match arrival_time.checked_duration_since(prev_arr) {
                Some(elapsed) => elapsed.as_secs_f64(),
                None => -prev_arr.duration_since(arrival_time).as_secs_f64(),
            };
            let arr_diff_units = (arr_secs * f64::from(self.clock_rate)) as i64;
            // abs_diff cannot overflow even for extreme inputs.
            let d = arr_diff_units.abs_diff(ts_diff) as f64;
            self.jitter += (d - self.jitter) / 16.0;
        }
        self.prev_timestamp = Some(timestamp);
        self.prev_arrival = Some(arrival_time);
    }

    /// Current jitter in timestamp units, truncated to an integer.
    #[must_use]
    pub const fn jitter(&self) -> u32 {
        self.jitter as u32
    }

    /// Current jitter converted to wall-clock time.
    ///
    /// Returns `Duration::ZERO` when `clock_rate` is 0, which would otherwise
    /// produce a NaN/infinite duration and panic.
    #[must_use]
    pub fn jitter_duration(&self) -> Duration {
        if self.clock_rate == 0 {
            return Duration::ZERO;
        }
        Duration::from_secs_f64(self.jitter / f64::from(self.clock_rate))
    }
}
/// Receive-side statistics for one remote RTP source: sequence/loss
/// tracking, interarrival jitter, and packet/byte counters.
#[derive(Debug, Clone)]
pub struct RtpStatistics {
    /// Sequence-number and loss tracking.
    pub seq_tracker: SequenceNumberTracker,
    /// Interarrival jitter estimate.
    pub jitter: JitterCalculator,
    /// Packets folded in by `update`.
    pub packets_received: u64,
    /// Payload bytes (header excluded) folded in by `update`.
    pub bytes_received: u64,
    /// RTP timestamp of the most recent packet (0 before any packet).
    pub last_timestamp: u32,
    /// Arrival time of the most recent packet (creation time before any packet).
    pub last_arrival: Instant,
}

impl RtpStatistics {
    /// Creates statistics for a source whose first packet carries `initial_seq`
    /// and whose media clock runs at `clock_rate` units per second.
    #[must_use]
    pub fn new(initial_seq: u16, clock_rate: u32) -> Self {
        Self {
            seq_tracker: SequenceNumberTracker::new(initial_seq),
            jitter: JitterCalculator::new(clock_rate),
            packets_received: 0,
            bytes_received: 0,
            last_timestamp: 0,
            last_arrival: Instant::now(),
        }
    }

    /// Folds one received packet into the statistics, stamped with the current time.
    pub fn update(&mut self, packet: &RtpPacket) {
        // Sample the clock once so the jitter estimate and `last_arrival`
        // agree on when this packet arrived.
        let now = Instant::now();
        self.seq_tracker.update(packet.header.sequence_number);
        self.jitter.update(packet.header.timestamp, now);
        self.packets_received += 1;
        self.bytes_received += packet.payload.len() as u64;
        self.last_timestamp = packet.header.timestamp;
        self.last_arrival = now;
    }
}
/// Local endpoint of an RTP session.
///
/// Holds our own SSRC and outgoing counters, plus receive statistics for every
/// remote SSRC seen through [`process_packet`](Self::process_packet).
#[derive(Debug)]
pub struct RtpSession {
    /// Our synchronization source identifier.
    pub ssrc: u32,
    /// Receive statistics keyed by remote SSRC.
    pub remote_sources: HashMap<u32, RtpStatistics>,
    /// Sequence number to stamp on the next outgoing packet.
    pub next_seq: u16,
    /// Packets reported through `record_sent`.
    pub packets_sent: u64,
    /// Payload bytes reported through `record_sent`.
    pub bytes_sent: u64,
    /// Media clock rate given to each newly created per-source statistics entry.
    pub clock_rate: u32,
}

impl RtpSession {
    /// Creates a session for `ssrc` with a randomly chosen first outgoing
    /// sequence number and zeroed send counters.
    #[must_use]
    pub fn new(ssrc: u32, clock_rate: u32) -> Self {
        let initial_seq: u16 = rand::random();
        Self {
            ssrc,
            remote_sources: HashMap::new(),
            next_seq: initial_seq,
            packets_sent: 0,
            bytes_sent: 0,
            clock_rate,
        }
    }

    /// Updates the statistics of the packet's source SSRC. The entry is
    /// created, seeded with this packet's sequence number, on first sight.
    pub fn process_packet(&mut self, packet: &RtpPacket) {
        let clock_rate = self.clock_rate;
        let first_seq = packet.header.sequence_number;
        self.remote_sources
            .entry(packet.header.ssrc)
            .or_insert_with(|| RtpStatistics::new(first_seq, clock_rate))
            .update(packet);
    }

    /// Returns the sequence number for the next outgoing packet and advances
    /// the counter, wrapping from 65535 back to 0.
    pub fn next_sequence(&mut self) -> u16 {
        let following = self.next_seq.wrapping_add(1);
        std::mem::replace(&mut self.next_seq, following)
    }

    /// Accounts for one sent packet carrying `payload_size` payload bytes.
    pub fn record_sent(&mut self, payload_size: usize) {
        self.packets_sent += 1;
        self.bytes_sent += payload_size as u64;
    }
}
impl RtpHeader {
    /// Size in bytes of the fixed header (flags, payload type, sequence number,
    /// timestamp, SSRC), without CSRCs or an extension.
    pub const MIN_SIZE: usize = 12;

    /// The 4-bit CC field cannot describe more than this many CSRCs.
    const MAX_CSRCS: usize = 15;

    /// Parses an RTP header from the start of `data`.
    ///
    /// Returns the header together with the number of bytes it occupies, i.e.
    /// the offset at which the payload begins. Padding octets (P bit) are not
    /// inspected here.
    ///
    /// # Errors
    ///
    /// Returns a parse error when `data` is too short for the fixed header, the
    /// advertised CSRC list, or the advertised extension. Returns a protocol
    /// error when the version field is not [`RTP_VERSION`].
    pub fn parse(data: &[u8]) -> NetResult<(Self, usize)> {
        if data.len() < Self::MIN_SIZE {
            return Err(NetError::parse(0, "RTP header too short"));
        }
        let mut cursor = data;

        let byte0 = cursor.get_u8();
        let version = (byte0 >> 6) & 0x03;
        if version != RTP_VERSION {
            return Err(NetError::protocol(format!(
                "Invalid RTP version: {version}"
            )));
        }
        let padding = (byte0 & 0x20) != 0;
        let extension = (byte0 & 0x10) != 0;
        let csrc_count = byte0 & 0x0F;

        let byte1 = cursor.get_u8();
        let marker = (byte1 & 0x80) != 0;
        let payload_type = byte1 & 0x7F;

        let sequence_number = cursor.get_u16();
        let timestamp = cursor.get_u32();
        let ssrc = cursor.get_u32();

        let mut csrcs = Vec::with_capacity(usize::from(csrc_count));
        for _ in 0..csrc_count {
            if cursor.len() < 4 {
                return Err(NetError::parse(0, "Not enough data for CSRC"));
            }
            csrcs.push(cursor.get_u32());
        }

        let extension_data = if extension {
            if cursor.len() < 4 {
                return Err(NetError::parse(0, "Not enough data for extension"));
            }
            let profile = cursor.get_u16();
            // The length field counts 32-bit words, excluding the 4-byte
            // profile/length prefix itself.
            let length = usize::from(cursor.get_u16()) * 4;
            if cursor.len() < length {
                return Err(NetError::parse(0, "Not enough data for extension data"));
            }
            Some(RtpHeaderExtension {
                profile,
                data: cursor.copy_to_bytes(length),
            })
        } else {
            None
        };

        // Everything consumed so far belongs to the header.
        let header_size = data.len() - cursor.len();
        Ok((
            Self {
                padding,
                extension,
                csrc_count,
                marker,
                payload_type,
                sequence_number,
                timestamp,
                ssrc,
                csrcs,
                extension_data,
            },
            header_size,
        ))
    }

    /// Appends the wire encoding of this header to `buf`.
    ///
    /// The CC and X flags are derived from `csrcs` and `extension_data`, not
    /// from `csrc_count`/`extension`. This keeps the output self-consistent
    /// even if those redundant fields disagree.
    /// - At most 15 CSRCs are written, since that is all the CC field can describe.
    /// - Extension data is zero-padded to a 32-bit boundary, as the length
    ///   field counts whole words.
    pub fn serialize(&self, buf: &mut BytesMut) {
        let csrcs = &self.csrcs[..self.csrcs.len().min(Self::MAX_CSRCS)];
        let byte0 = (RTP_VERSION << 6)
            | (u8::from(self.padding) << 5)
            | (u8::from(self.extension_data.is_some()) << 4)
            | (csrcs.len() as u8);
        buf.put_u8(byte0);
        buf.put_u8((u8::from(self.marker) << 7) | (self.payload_type & 0x7F));
        buf.put_u16(self.sequence_number);
        buf.put_u32(self.timestamp);
        buf.put_u32(self.ssrc);
        for &csrc in csrcs {
            buf.put_u32(csrc);
        }
        if let Some(ext) = &self.extension_data {
            let words = (ext.data.len() + 3) / 4;
            buf.put_u16(ext.profile);
            // NOTE: bodies longer than 0xFFFF words cannot be represented.
            buf.put_u16(words as u16);
            buf.put_slice(&ext.data);
            buf.put_bytes(0, words * 4 - ext.data.len());
        }
    }

    /// Number of bytes [`serialize`](Self::serialize) writes for this header.
    #[must_use]
    pub fn size(&self) -> usize {
        let csrc_bytes = self.csrcs.len().min(Self::MAX_CSRCS) * 4;
        let ext_bytes = self
            .extension_data
            .as_ref()
            .map_or(0, |ext| 4 + (ext.data.len() + 3) / 4 * 4);
        Self::MIN_SIZE + csrc_bytes + ext_bytes
    }
}
impl RtpPacket {
    /// Splits a complete RTP packet into its header and payload.
    ///
    /// The payload is a reference-counted slice of `data` covering everything
    /// after the header. Padding octets are not stripped.
    ///
    /// # Errors
    ///
    /// Propagates header parse failures. Also errors if `data` is shorter than
    /// the header size reported by the parser.
    pub fn parse(data: Bytes) -> NetResult<Self> {
        let (header, header_len) = RtpHeader::parse(&data)?;
        if header_len > data.len() {
            return Err(NetError::parse(0, "Data shorter than header"));
        }
        let payload = data.slice(header_len..);
        Ok(Self { header, payload })
    }

    /// Encodes the header followed by the payload into one contiguous buffer.
    pub fn serialize(&self) -> Bytes {
        let capacity = self.header.size() + self.payload.len();
        let mut out = BytesMut::with_capacity(capacity);
        self.header.serialize(&mut out);
        out.extend_from_slice(&self.payload);
        out.freeze()
    }
}
impl RtcpSenderReport {
    /// Bytes taken by the sender SSRC plus the 20-byte sender-info section.
    const FIXED_LEN: usize = 24;
    /// Bytes taken by one encoded reception report block.
    const REPORT_BLOCK_LEN: usize = 24;

    /// Parses the body of an RTCP sender report.
    ///
    /// `data` must begin at the sender SSRC, which is the first 32-bit word
    /// read here. Presumably the caller has already consumed the 4-byte RTCP
    /// common header — confirm against callers. A report with no reception
    /// blocks is exactly 24 bytes, which was previously rejected by a 28-byte check.
    ///
    /// Reception report blocks are read while at least 24 bytes remain, and any
    /// shorter tail is ignored. NOTE(review): the block count is inferred from
    /// the length rather than the header's RC field, so profile-specific
    /// extensions after the blocks would be misread as extra blocks.
    ///
    /// # Errors
    ///
    /// Returns a parse error when `data` is shorter than the 24-byte fixed part.
    pub fn parse(data: &[u8]) -> NetResult<Self> {
        if data.len() < Self::FIXED_LEN {
            return Err(NetError::parse(0, "RTCP SR too short"));
        }
        let mut cursor = data;
        let ssrc = cursor.get_u32();
        let ntp_timestamp_msw = cursor.get_u32();
        let ntp_timestamp_lsw = cursor.get_u32();
        let rtp_timestamp = cursor.get_u32();
        let sender_packet_count = cursor.get_u32();
        let sender_octet_count = cursor.get_u32();

        let mut report_blocks = Vec::with_capacity(cursor.len() / Self::REPORT_BLOCK_LEN);
        while cursor.len() >= Self::REPORT_BLOCK_LEN {
            report_blocks.push(RtcpReportBlock::parse(&mut cursor)?);
        }

        Ok(Self {
            ssrc,
            ntp_timestamp_msw,
            ntp_timestamp_lsw,
            rtp_timestamp,
            sender_packet_count,
            sender_octet_count,
            report_blocks,
        })
    }

    /// Appends the report body to `buf`, in the same layout `parse` reads:
    /// SSRC, sender info, then each report block. The RTCP common header is not written.
    pub fn serialize(&self, buf: &mut BytesMut) {
        buf.put_u32(self.ssrc);
        buf.put_u32(self.ntp_timestamp_msw);
        buf.put_u32(self.ntp_timestamp_lsw);
        buf.put_u32(self.rtp_timestamp);
        buf.put_u32(self.sender_packet_count);
        buf.put_u32(self.sender_octet_count);
        for block in &self.report_blocks {
            block.serialize(buf);
        }
    }
}
impl RtcpReportBlock {
    /// Encoded size of one report block in bytes.
    const ENCODED_LEN: usize = 24;
    /// Largest value representable in the signed 24-bit cumulative-lost field.
    const CUMULATIVE_LOST_MAX: i32 = 0x7F_FFFF;
    /// Smallest value representable in the signed 24-bit cumulative-lost field.
    const CUMULATIVE_LOST_MIN: i32 = -0x80_0000;

    /// Reads one report block from the front of `cursor`, advancing it by 24 bytes.
    ///
    /// # Errors
    ///
    /// Returns a parse error when fewer than 24 bytes remain.
    fn parse(cursor: &mut &[u8]) -> NetResult<Self> {
        if cursor.len() < Self::ENCODED_LEN {
            return Err(NetError::parse(0, "RTCP report block too short"));
        }
        let ssrc = cursor.get_u32();
        let fraction_lost = cursor.get_u8();

        let hi = u32::from(cursor.get_u8());
        let mid = u32::from(cursor.get_u8());
        let lo = u32::from(cursor.get_u8());
        let raw = (hi << 16) | (mid << 8) | lo;
        // Two's-complement 24-bit value: the sign is bit 23 of this field
        // itself, not the high bit of the preceding fraction-lost byte.
        let cumulative_packets_lost = if raw & 0x80_0000 != 0 {
            raw as i32 - 0x100_0000
        } else {
            raw as i32
        };

        let extended_highest_sequence = cursor.get_u32();
        let interarrival_jitter = cursor.get_u32();
        let last_sr_timestamp = cursor.get_u32();
        let delay_since_last_sr = cursor.get_u32();
        Ok(Self {
            ssrc,
            fraction_lost,
            cumulative_packets_lost,
            extended_highest_sequence,
            interarrival_jitter,
            last_sr_timestamp,
            delay_since_last_sr,
        })
    }

    /// Appends the 24-byte wire encoding of this block to `buf`.
    ///
    /// Out-of-range cumulative loss is clamped to the signed 24-bit range
    /// rather than wrapped (RFC 3550 Appendix A.3).
    fn serialize(&self, buf: &mut BytesMut) {
        buf.put_u32(self.ssrc);
        buf.put_u8(self.fraction_lost);
        let clamped = self
            .cumulative_packets_lost
            .clamp(Self::CUMULATIVE_LOST_MIN, Self::CUMULATIVE_LOST_MAX);
        // Low 24 bits of the two's-complement form, written big-endian.
        let encoded = (clamped as u32) & 0x00FF_FFFF;
        buf.put_u8((encoded >> 16) as u8);
        buf.put_u8((encoded >> 8) as u8);
        buf.put_u8(encoded as u8);
        buf.put_u32(self.extended_highest_sequence);
        buf.put_u32(self.interarrival_jitter);
        buf.put_u32(self.last_sr_timestamp);
        buf.put_u32(self.delay_since_last_sr);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rtp_header_parse_basic() {
        // Minimal 12-byte header: V=2, P=0, X=0, CC=0, M=0, PT=96,
        // sequence 1, timestamp 100, SSRC 0x12345678.
        let wire = [
            0x80, 0x60, // version/flags, marker/payload type
            0x00, 0x01, // sequence number
            0x00, 0x00, 0x00, 0x64, // timestamp
            0x12, 0x34, 0x56, 0x78, // SSRC
        ];
        let (hdr, consumed) = RtpHeader::parse(&wire).expect("should succeed in test");
        assert_eq!(consumed, 12);
        assert!(!hdr.padding);
        assert!(!hdr.extension);
        assert_eq!(hdr.csrc_count, 0);
        assert!(!hdr.marker);
        assert_eq!(hdr.payload_type, 96);
        assert_eq!(hdr.sequence_number, 1);
        assert_eq!(hdr.timestamp, 100);
        assert_eq!(hdr.ssrc, 0x12345678);
    }

    #[test]
    fn test_sequence_tracker() {
        let mut seq_tracker = SequenceNumberTracker::new(100);
        for seq in [101, 102] {
            seq_tracker.update(seq);
        }
        assert_eq!(seq_tracker.extended_sequence(), 102);

        // Force the tracker to the top of the 16-bit range, then wrap to 0.
        seq_tracker.highest_seq = 65535;
        seq_tracker.update(0);
        assert_eq!(seq_tracker.cycles, 1);
    }

    #[test]
    fn test_jitter_calculator() {
        let mut calc = JitterCalculator::new(90000);
        let t0 = Instant::now();
        // (RTP timestamp, arrival offset in ms): uneven spacing must yield non-zero jitter.
        for (ts, offset_ms) in [(1000, 0), (2000, 15), (3000, 25)] {
            calc.update(ts, t0 + Duration::from_millis(offset_ms));
        }
        assert!(calc.jitter() > 0);
    }
}