use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tokio::sync::watch;
use tokio::time::interval;
use tracing::{debug, warn};
use crate::protocol::UdpStats;
use crate::stats::StreamStats;
/// Size in bytes of every UDP test datagram, header included.
pub const UDP_PAYLOAD_SIZE: usize = 1400;
/// Bytes occupied by an encoded [`UdpPacketHeader`]: a big-endian `u64`
/// sequence number followed by a big-endian `u64` timestamp.
const UDP_HEADER_SIZE: usize = 2 * std::mem::size_of::<u64>();
/// The receiver stops after this long without any incoming datagram.
const UDP_INACTIVITY_TIMEOUT: Duration = Duration::from_secs(30);
/// Fixed-size header written at the front of every UDP test datagram.
///
/// Wire layout, all big-endian: bytes `0..8` hold `sequence`, bytes `8..16`
/// hold `timestamp_us`.
#[derive(Debug, Clone, Copy)]
pub struct UdpPacketHeader {
    /// Per-stream packet counter assigned by the sender.
    pub sequence: u64,
    /// Sender-side timestamp in microseconds. The senders in this module fill
    /// it with the time elapsed since their send loop began.
    pub timestamp_us: u64,
}

impl UdpPacketHeader {
    /// Serializes the header into the first `UDP_HEADER_SIZE` bytes of `buffer`.
    ///
    /// Returns `false` and leaves `buffer` untouched when it is shorter than
    /// `UDP_HEADER_SIZE`. Bytes after the header are never modified.
    pub fn encode(&self, buffer: &mut [u8]) -> bool {
        match buffer.get_mut(..UDP_HEADER_SIZE) {
            Some(dst) => {
                let (seq_bytes, ts_bytes) = dst.split_at_mut(8);
                seq_bytes.copy_from_slice(&self.sequence.to_be_bytes());
                ts_bytes.copy_from_slice(&self.timestamp_us.to_be_bytes());
                true
            }
            None => false,
        }
    }

    /// Parses a header from the start of `buffer`.
    ///
    /// Returns `None` when `buffer` holds fewer than `UDP_HEADER_SIZE` bytes;
    /// any trailing payload is ignored.
    pub fn decode(buffer: &[u8]) -> Option<Self> {
        let (seq_bytes, ts_bytes) = buffer.get(..UDP_HEADER_SIZE)?.split_at(8);
        Some(Self {
            sequence: u64::from_be_bytes(seq_bytes.try_into().ok()?),
            timestamp_us: u64::from_be_bytes(ts_bytes.try_into().ok()?),
        })
    }
}
/// Totals reported by a UDP sender when its send loop finishes.
#[derive(Debug, Clone)]
pub struct UdpSendStats {
    /// Number of datagrams that were sent successfully.
    pub packets_sent: u64,
    /// Bytes attributed to those successfully sent datagrams.
    pub bytes_sent: u64,
}
/// Running interarrival-jitter estimator using the RFC 3550 smoothing formula.
///
/// For each pair of consecutive packets it takes
/// `D = (R_i - R_{i-1}) - (S_i - S_{i-1})` — the receive-clock delta minus the
/// send-clock delta, both in microseconds. It then folds `|D|` into the
/// estimate with gain 1/16: `J += (|D| - J) / 16`.
pub struct JitterCalculator {
    /// Sender timestamp (µs) of the previously observed packet.
    last_send_time: Option<u64>,
    /// Local arrival instant of the previously observed packet.
    last_recv_time: Option<Instant>,
    /// Current smoothed jitter estimate, in microseconds.
    jitter: f64,
}

impl JitterCalculator {
    /// Creates an estimator with no history and a jitter of zero.
    pub fn new() -> Self {
        Self {
            jitter: 0.0,
            last_send_time: None,
            last_recv_time: None,
        }
    }

    /// Feeds one packet (sender timestamp in µs, local arrival instant) into the
    /// estimator and returns the updated jitter in microseconds.
    ///
    /// The very first packet only seeds the history and leaves the estimate unchanged.
    pub fn update(&mut self, send_time_us: u64, recv_time: Instant) -> f64 {
        let previous = self.last_send_time.zip(self.last_recv_time);
        self.last_send_time = Some(send_time_us);
        self.last_recv_time = Some(recv_time);

        if let Some((prev_send, prev_recv)) = previous {
            let arrival_gap = recv_time.duration_since(prev_recv).as_micros() as i64;
            let departure_gap = (send_time_us as i64) - (prev_send as i64);
            let deviation = (arrival_gap - departure_gap).abs() as f64;
            self.jitter += (deviation - self.jitter) / 16.0;
        }
        self.jitter
    }

    /// Current jitter estimate converted to milliseconds.
    pub fn jitter_ms(&self) -> f64 {
        let micros = self.jitter;
        micros / 1000.0
    }
}

impl Default for JitterCalculator {
    fn default() -> Self {
        JitterCalculator::new()
    }
}
/// Tracks the sequence numbers of received packets to derive loss and reordering.
///
/// Loss is inferred from gaps. Receiving a sequence number above the next
/// expected one counts every skipped number as lost. When a packet with a
/// sequence number below the next expected one arrives later, it is counted as
/// out-of-order and taken back off the loss count. Without this, a reordered
/// packet would be counted both as received and as lost.
///
/// A duplicate of an already-received packet is indistinguishable here from a
/// late one, so it is also counted as out-of-order and also decrements the loss
/// count (saturating at zero).
pub struct PacketTracker {
    /// Next sequence number expected in order.
    expected_sequence: u64,
    /// Total number of calls to `record`, including late and duplicate packets.
    received: u64,
    /// Packets currently believed lost. Kept atomic because code in this
    /// module reads it with `load(Ordering::Relaxed)`.
    lost: AtomicU64,
    /// Packets that arrived with a sequence number below `expected_sequence`.
    out_of_order: AtomicU64,
    /// Highest sequence number seen so far.
    highest_seen: u64,
}

impl PacketTracker {
    /// Creates a tracker that expects sequence number 0 first.
    pub fn new() -> Self {
        Self {
            expected_sequence: 0,
            received: 0,
            lost: AtomicU64::new(0),
            out_of_order: AtomicU64::new(0),
            highest_seen: 0,
        }
    }

    /// Records the arrival of the packet numbered `sequence`.
    ///
    /// `sequence` comes off the network, so every arithmetic step here is
    /// overflow-safe even for `u64::MAX`.
    pub fn record(&mut self, sequence: u64) {
        self.received += 1;
        if sequence < self.expected_sequence {
            // A number we already advanced past: it was counted as lost when its
            // gap opened, so un-count it now that it has arrived.
            *self.out_of_order.get_mut() += 1;
            let lost = self.lost.get_mut();
            *lost = lost.saturating_sub(1);
        } else {
            // In order (gap == 0) or ahead: every number skipped over is missing.
            let gap = sequence - self.expected_sequence;
            *self.lost.get_mut() += gap;
            self.expected_sequence = sequence.saturating_add(1);
        }
        self.highest_seen = self.highest_seen.max(sequence);
    }

    /// Returns `(lost, out_of_order, loss_percent)`.
    ///
    /// `loss_percent` is `lost` as a percentage of `packets_sent`, or `0.0`
    /// when `packets_sent` is zero.
    pub fn stats(&self, packets_sent: u64) -> (u64, u64, f64) {
        let lost = self.lost.load(Ordering::Relaxed);
        let ooo = self.out_of_order.load(Ordering::Relaxed);
        let loss_percent = if packets_sent > 0 {
            (lost as f64 / packets_sent as f64) * 100.0
        } else {
            0.0
        };
        (lost, ooo, loss_percent)
    }
}

impl Default for PacketTracker {
    fn default() -> Self {
        Self::new()
    }
}
/// Packet rate (packets/sec) above which paced sending stops releasing one
/// packet per timer tick and releases bursts of [`BURST_SIZE`] packets instead.
const HIGH_PPS_THRESHOLD: f64 = 1.0e5;
/// Send attempts per burst: used per tick by high-rate pacing and between
/// yields by the unlimited-rate sender.
const BURST_SIZE: u64 = 100;
/// Sends fixed-size UDP test datagrams at approximately `target_bitrate` bits/sec.
///
/// Every datagram is `UDP_PAYLOAD_SIZE` bytes. It begins with a
/// [`UdpPacketHeader`] holding a sequence number and the microseconds elapsed
/// since the send loop started. A `target_bitrate` of `0` delegates to the
/// unlimited-rate sender.
///
/// Packets are released on a `tokio::time::interval` tick. Normally that is one
/// packet per tick. Above `HIGH_PPS_THRESHOLD` packets/sec it is instead a burst
/// of `BURST_SIZE` packets per tick, with a proportionally longer period.
///
/// With `duration == Duration::ZERO` the sender runs until cancelled. Otherwise
/// it stops once `duration` has elapsed since the start; time spent paused counts
/// toward that. Send errors are logged and skipped without consuming a sequence
/// number, so this function currently always returns `Ok`.
pub async fn send_udp_paced(
    socket: Arc<UdpSocket>,
    target: Option<SocketAddr>,
    target_bitrate: u64,
    duration: Duration,
    stats: Arc<StreamStats>,
    mut cancel: watch::Receiver<bool>,
    mut pause: watch::Receiver<bool>,
) -> anyhow::Result<UdpSendStats> {
    let packet_size = UDP_PAYLOAD_SIZE;
    if target_bitrate == 0 {
        return send_udp_unlimited(socket, target, duration, stats, cancel, pause).await;
    }
    let bits_per_packet = (packet_size * 8) as u64;
    let packets_per_sec_f64 = target_bitrate as f64 / bits_per_packet as f64;
    let (pacing_interval, packets_per_tick) = if packets_per_sec_f64 > HIGH_PPS_THRESHOLD {
        let period = Duration::from_secs_f64(BURST_SIZE as f64 / packets_per_sec_f64);
        (period, BURST_SIZE)
    } else {
        let period = Duration::from_secs_f64(1.0 / packets_per_sec_f64);
        (period, 1)
    };
    // `tokio::time::interval` panics on a zero period. An extreme bitrate can
    // make the computed period round down to 0 ns, so clamp it.
    let pacing_interval = pacing_interval.max(Duration::from_nanos(1));
    debug!(
        "UDP pacing: {:.0} packets/sec, interval {:?}, {} packets/tick",
        packets_per_sec_f64, pacing_interval, packets_per_tick
    );
    let mut sequence: u64 = 0;
    let mut ticker = interval(pacing_interval);
    let start = Instant::now();
    let deadline = start + duration;
    let is_infinite = duration == Duration::ZERO;
    // Wakes the loop at the deadline even when the next tick lies further away
    // than the remaining test time; low bitrates give multi-second tick periods.
    let deadline_timer = tokio::time::sleep_until(tokio::time::Instant::from_std(deadline));
    tokio::pin!(deadline_timer);
    let mut packet = vec![0u8; packet_size];
    loop {
        if *cancel.borrow() {
            debug!("UDP send cancelled");
            break;
        }
        if *pause.borrow() {
            if crate::pause::wait_while_paused(&mut pause, &mut cancel).await {
                break;
            }
            // The ticker was not polled while paused. With the default
            // `MissedTickBehavior::Burst` it would now fire every missed tick back
            // to back, sending the whole paused interval's packets at once.
            // Restart the schedule from now instead.
            ticker.reset();
            continue;
        }
        tokio::select! {
            biased;
            // Only `Ok` is matched. Once a sender is dropped, `changed()` resolves
            // with `Err` immediately on every call, and accepting that would spin
            // this loop forever without reaching the ticker or the deadline.
            Ok(()) = cancel.changed() => {
                if *cancel.borrow() {
                    break;
                }
                continue;
            }
            Ok(()) = pause.changed() => {
                continue;
            }
            _ = ticker.tick() => {}
            _ = &mut deadline_timer, if !is_infinite => {
                break;
            }
        }
        if !is_infinite && Instant::now() >= deadline {
            break;
        }
        for _ in 0..packets_per_tick {
            if *cancel.borrow() || *pause.borrow() {
                break;
            }
            if !is_infinite && Instant::now() >= deadline {
                break;
            }
            let now_us = start.elapsed().as_micros() as u64;
            let header = UdpPacketHeader {
                sequence,
                timestamp_us: now_us,
            };
            header.encode(&mut packet);
            let result = match target {
                Some(addr) => socket.send_to(&packet, addr).await,
                None => socket.send(&packet).await,
            };
            match result {
                Ok(n) => {
                    stats.add_bytes_sent(n as u64);
                    sequence += 1;
                }
                Err(e) => {
                    warn!("UDP send error: {}", e);
                }
            }
        }
    }
    Ok(UdpSendStats {
        packets_sent: sequence,
        bytes_sent: sequence * packet_size as u64,
    })
}
async fn send_udp_unlimited(
socket: Arc<UdpSocket>,
target: Option<SocketAddr>,
duration: Duration,
stats: Arc<StreamStats>,
mut cancel: watch::Receiver<bool>,
mut pause: watch::Receiver<bool>,
) -> anyhow::Result<UdpSendStats> {
let packet_size = UDP_PAYLOAD_SIZE;
let mut sequence: u64 = 0;
let start = Instant::now();
let deadline = start + duration;
let is_infinite = duration == Duration::ZERO;
let mut packet = vec![0u8; packet_size];
debug!("UDP unlimited mode: sending as fast as possible");
loop {
if *cancel.borrow() {
debug!("UDP send cancelled");
break;
}
if *pause.borrow() {
if crate::pause::wait_while_paused(&mut pause, &mut cancel).await {
break;
}
continue;
}
if !is_infinite && Instant::now() >= deadline {
break;
}
for _ in 0..BURST_SIZE {
if *cancel.borrow() || *pause.borrow() {
break;
}
if !is_infinite && Instant::now() >= deadline {
break;
}
let now_us = start.elapsed().as_micros() as u64;
let header = UdpPacketHeader {
sequence,
timestamp_us: now_us,
};
header.encode(&mut packet);
let result = match target {
Some(addr) => socket.send_to(&packet, addr).await,
None => socket.send(&packet).await,
};
match result {
Ok(n) => {
stats.add_bytes_sent(n as u64);
sequence += 1;
}
Err(e) => {
warn!("UDP send error: {}", e);
}
}
}
tokio::task::yield_now().await;
}
Ok(UdpSendStats {
packets_sent: sequence,
bytes_sent: sequence * packet_size as u64,
})
}
pub async fn receive_udp(
socket: Arc<UdpSocket>,
stats: Arc<StreamStats>,
mut cancel: watch::Receiver<bool>,
mut pause: watch::Receiver<bool>,
) -> anyhow::Result<(UdpStats, u64)> {
let mut buffer = vec![0u8; UDP_PAYLOAD_SIZE + 100];
let mut jitter_calc = JitterCalculator::new();
let mut packet_tracker = PacketTracker::new();
let mut packets_received: u64 = 0;
let mut last_recv = Instant::now();
loop {
if *cancel.borrow() {
debug!("UDP receive cancelled");
break;
}
if *pause.borrow() {
if crate::pause::wait_while_paused(&mut pause, &mut cancel).await {
break;
}
last_recv = Instant::now();
continue;
}
if last_recv.elapsed() > UDP_INACTIVITY_TIMEOUT {
debug!(
"UDP receive timeout: no packets for {:?}",
UDP_INACTIVITY_TIMEOUT
);
break;
}
let recv_future = socket.recv_from(&mut buffer);
let timeout_future = tokio::time::sleep(Duration::from_millis(100));
tokio::select! {
result = recv_future => {
match result {
Ok((n, _addr)) => {
let recv_time = Instant::now();
last_recv = recv_time;
stats.add_bytes_received(n as u64);
packets_received += 1;
if let Some(header) = UdpPacketHeader::decode(&buffer[..n]) {
let old_lost = packet_tracker.lost.load(Ordering::Relaxed);
packet_tracker.record(header.sequence);
let new_lost = packet_tracker.lost.load(Ordering::Relaxed);
let jitter_us = jitter_calc.update(header.timestamp_us, recv_time);
stats.set_udp_jitter_us(jitter_us as u64);
if new_lost > old_lost {
stats.add_udp_lost(new_lost - old_lost);
}
}
}
Err(e) => {
warn!("UDP receive error: {}", e);
}
}
}
_ = timeout_future => {
}
}
}
let (lost, out_of_order, _) =
packet_tracker.stats(packets_received + packet_tracker.lost.load(Ordering::Relaxed));
let packets_sent = packets_received + lost;
let loss_percent = if packets_sent > 0 {
(lost as f64 / packets_sent as f64) * 100.0
} else {
0.0
};
Ok((
UdpStats {
packets_sent,
packets_received,
lost,
lost_percent: loss_percent,
jitter_ms: jitter_calc.jitter_ms(),
out_of_order,
},
packets_sent,
))
}
pub async fn wait_for_client(socket: &UdpSocket, timeout: Duration) -> anyhow::Result<SocketAddr> {
let mut buffer = [0u8; 64];
tokio::select! {
result = socket.recv_from(&mut buffer) => {
match result {
Ok((_, addr)) => {
debug!("UDP client connected from {}", addr);
Ok(addr)
}
Err(e) => Err(anyhow::anyhow!("Failed to receive from client: {}", e)),
}
}
_ = tokio::time::sleep(timeout) => {
Err(anyhow::anyhow!("Timeout waiting for UDP client"))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding a header must reproduce both fields.
    #[test]
    fn test_packet_header_roundtrip() {
        let original = UdpPacketHeader {
            sequence: 12345,
            timestamp_us: 67890,
        };
        let mut wire = [0u8; 16];
        assert!(original.encode(&mut wire));
        let parsed = UdpPacketHeader::decode(&wire).expect("a 16-byte buffer must decode");
        assert_eq!(parsed.sequence, 12345);
        assert_eq!(parsed.timestamp_us, 67890);
    }

    /// The first sample seeds the estimator; an evenly spaced second sample adds no jitter.
    #[test]
    fn test_jitter_calculator() {
        let t0 = Instant::now();
        let mut calc = JitterCalculator::new();
        calc.update(0, t0);
        assert_eq!(calc.jitter_ms(), 0.0);
        calc.update(1000, t0 + Duration::from_micros(1000));
        assert!(calc.jitter_ms() < 1.0);
    }

    /// Gaps count as loss; a late arrival counts as out-of-order.
    #[test]
    fn test_packet_tracker() {
        let lost = |t: &PacketTracker| t.lost.load(Ordering::Relaxed);
        let mut tracker = PacketTracker::new();
        (0..=2).for_each(|seq| tracker.record(seq));
        assert_eq!(lost(&tracker), 0);
        tracker.record(4);
        assert_eq!(lost(&tracker), 1);
        tracker.record(3);
        assert_eq!(tracker.out_of_order.load(Ordering::Relaxed), 1);
    }
}