use super::{
DiscoveredPeer, PacketTx, ReceivedPacket, Transport, TransportAddr, TransportError,
TransportId, TransportState, TransportType,
};
#[cfg(any(target_os = "linux", target_os = "macos"))]
pub(crate) mod connected_peer;
#[cfg(target_os = "macos")]
pub(crate) mod darwin_sockopts;
#[cfg(any(target_os = "linux", target_os = "macos"))]
pub(crate) mod peer_drain;
pub(crate) mod socket;
mod stats;
use super::resolve_socket_addr;
use crate::config::UdpConfig;
use crate::discovery::is_punch_packet;
use socket::{AsyncUdpSocket, UdpRawSocket};
use stats::UdpStats;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex as StdMutex};
use std::time::{Duration, Instant};
use tokio::task::JoinHandle;
use tracing::{debug, info, trace, warn};
const DNS_CACHE_TTL: Duration = Duration::from_secs(60);
pub struct UdpTransport {
transport_id: TransportId,
name: Option<String>,
config: UdpConfig,
state: TransportState,
socket: Option<AsyncUdpSocket>,
packet_tx: PacketTx,
recv_task: Option<JoinHandle<()>>,
local_addr: Option<SocketAddr>,
stats: Arc<UdpStats>,
dns_cache: StdMutex<HashMap<TransportAddr, (SocketAddr, Instant)>>,
}
impl UdpTransport {
pub fn new(
transport_id: TransportId,
name: Option<String>,
config: UdpConfig,
packet_tx: PacketTx,
) -> Self {
Self {
transport_id,
name,
config,
state: TransportState::Configured,
socket: None,
packet_tx,
recv_task: None,
local_addr: None,
stats: Arc::new(UdpStats::new()),
dns_cache: StdMutex::new(HashMap::new()),
}
}
pub fn name(&self) -> Option<&str> {
self.name.as_deref()
}
pub fn local_addr(&self) -> Option<SocketAddr> {
self.local_addr
}
pub fn recv_buf_size(&self) -> usize {
self.config.recv_buf_size()
}
pub fn send_buf_size(&self) -> usize {
self.config.send_buf_size()
}
pub fn clone_packet_tx(&self) -> PacketTx {
self.packet_tx.clone()
}
pub fn stats(&self) -> &Arc<UdpStats> {
&self.stats
}
pub async fn resolve_for_off_task(
&self,
addr: &TransportAddr,
) -> Result<SocketAddr, TransportError> {
self.resolve_cached(addr).await
}
pub fn async_socket(&self) -> Option<AsyncUdpSocket> {
self.socket.clone()
}
async fn resolve_cached(&self, addr: &TransportAddr) -> Result<SocketAddr, TransportError> {
if let Some(s) = addr.as_str()
&& let Ok(sock_addr) = s.parse::<SocketAddr>()
{
return Ok(sock_addr);
}
{
let cache = self.dns_cache.lock().unwrap();
if let Some((resolved, cached_at)) = cache.get(addr)
&& cached_at.elapsed() < DNS_CACHE_TTL
{
return Ok(*resolved);
}
}
let resolved = resolve_socket_addr(addr).await?;
{
let mut cache = self.dns_cache.lock().unwrap();
cache.insert(addr.clone(), (resolved, Instant::now()));
}
Ok(resolved)
}
pub fn congestion(&self) -> super::TransportCongestion {
super::TransportCongestion {
recv_drops: Some(
self.stats
.kernel_drops
.load(std::sync::atomic::Ordering::Relaxed),
),
}
}
pub async fn start_async(&mut self) -> Result<(), TransportError> {
if !self.state.can_start() {
return Err(TransportError::AlreadyStarted);
}
self.state = TransportState::Starting;
if self.config.outbound_only() && self.config.bind_addr.is_some() {
warn!(
configured_bind_addr = ?self.config.bind_addr,
"udp.outbound_only = true; configured bind_addr is ignored, binding to 0.0.0.0:0"
);
}
let bind_addr: SocketAddr = self
.config
.bind_addr()
.parse()
.map_err(|e| TransportError::StartFailed(format!("invalid bind address: {}", e)))?;
let raw_socket = UdpRawSocket::open(
bind_addr,
self.config.recv_buf_size(),
self.config.send_buf_size(),
)?;
let actual_recv = raw_socket.recv_buffer_size()?;
let actual_send = raw_socket.send_buffer_size()?;
self.local_addr = Some(raw_socket.local_addr());
let async_socket = raw_socket.into_async()?;
self.socket = Some(async_socket.clone());
let transport_id = self.transport_id;
let packet_tx = self.packet_tx.clone();
let mtu = self.config.mtu();
let stats = self.stats.clone();
let recv_task = tokio::spawn(async move {
udp_receive_loop(async_socket, transport_id, packet_tx, mtu, stats).await;
});
self.recv_task = Some(recv_task);
self.state = TransportState::Up;
if let Some(ref name) = self.name {
info!(
name = %name,
local_addr = %self.local_addr.unwrap(),
recv_buf = actual_recv,
send_buf = actual_send,
"UDP transport started"
);
} else {
info!(
local_addr = %self.local_addr.unwrap(),
recv_buf = actual_recv,
send_buf = actual_send,
"UDP transport started"
);
}
Ok(())
}
pub async fn adopt_socket_async(
&mut self,
socket: std::net::UdpSocket,
) -> Result<(), TransportError> {
if !self.state.can_start() {
return Err(TransportError::AlreadyStarted);
}
self.state = TransportState::Starting;
let raw_socket = UdpRawSocket::adopt(
socket,
self.config.recv_buf_size(),
self.config.send_buf_size(),
)?;
let actual_recv = raw_socket.recv_buffer_size()?;
let actual_send = raw_socket.send_buffer_size()?;
self.local_addr = Some(raw_socket.local_addr());
let async_socket = raw_socket.into_async()?;
self.socket = Some(async_socket.clone());
let transport_id = self.transport_id;
let packet_tx = self.packet_tx.clone();
let mtu = self.config.mtu();
let stats = self.stats.clone();
let recv_task = tokio::spawn(async move {
udp_receive_loop(async_socket, transport_id, packet_tx, mtu, stats).await;
});
self.recv_task = Some(recv_task);
self.state = TransportState::Up;
if let Some(ref name) = self.name {
info!(
name = %name,
local_addr = %self.local_addr.unwrap(),
recv_buf = actual_recv,
send_buf = actual_send,
"UDP transport adopted existing socket"
);
} else {
info!(
local_addr = %self.local_addr.unwrap(),
recv_buf = actual_recv,
send_buf = actual_send,
"UDP transport adopted existing socket"
);
}
Ok(())
}
pub async fn stop_async(&mut self) -> Result<(), TransportError> {
if !self.state.is_operational() {
return Err(TransportError::NotStarted);
}
if let Some(task) = self.recv_task.take() {
task.abort();
let _ = task.await; }
self.socket.take();
self.local_addr = None;
self.state = TransportState::Down;
info!(
transport_id = %self.transport_id,
"UDP transport stopped"
);
Ok(())
}
pub async fn send_async(
&self,
addr: &TransportAddr,
data: &[u8],
) -> Result<usize, TransportError> {
if !self.state.is_operational() {
return Err(TransportError::NotStarted);
}
if data.len() > self.config.mtu() as usize {
self.stats.record_mtu_exceeded();
return Err(TransportError::MtuExceeded {
packet_size: data.len(),
mtu: self.config.mtu(),
});
}
let socket_addr = self.resolve_cached(addr).await?;
let socket = self.socket.as_ref().ok_or(TransportError::NotStarted)?;
match socket.send_to(data, &socket_addr).await {
Ok(bytes_sent) => {
self.stats.record_send(bytes_sent);
trace!(
transport_id = %self.transport_id,
remote_addr = %socket_addr,
bytes = bytes_sent,
"UDP packet sent"
);
Ok(bytes_sent)
}
Err(e) => {
self.stats.record_send_error();
Err(e)
}
}
}
pub async fn flush_pending_send(&self) {}
}
impl Transport for UdpTransport {
fn transport_id(&self) -> TransportId {
self.transport_id
}
fn transport_type(&self) -> &TransportType {
&TransportType::UDP
}
fn state(&self) -> TransportState {
self.state
}
fn mtu(&self) -> u16 {
self.config.mtu()
}
fn start(&mut self) -> Result<(), TransportError> {
Err(TransportError::NotSupported(
"use start_async() for UDP transport".into(),
))
}
fn stop(&mut self) -> Result<(), TransportError> {
Err(TransportError::NotSupported(
"use stop_async() for UDP transport".into(),
))
}
fn send(&self, _addr: &TransportAddr, _data: &[u8]) -> Result<(), TransportError> {
Err(TransportError::NotSupported(
"use send_async() for UDP transport".into(),
))
}
fn discover(&self) -> Result<Vec<DiscoveredPeer>, TransportError> {
Ok(Vec::new())
}
fn accept_connections(&self) -> bool {
if self.config.outbound_only() {
false
} else {
self.config.accept_connections()
}
}
}
impl Drop for UdpTransport {
fn drop(&mut self) {
let had_task = self.recv_task.is_some();
let had_socket = self.socket.is_some();
if had_task || had_socket {
debug!(
transport_id = %self.transport_id,
state = ?self.state,
had_recv_task = had_task,
had_socket = had_socket,
"UdpTransport dropped without stop_async(); cleaning up",
);
}
if let Some(task) = self.recv_task.take() {
task.abort();
}
self.socket.take();
self.local_addr = None;
}
}
async fn udp_receive_loop(
socket: AsyncUdpSocket,
transport_id: TransportId,
packet_tx: PacketTx,
mtu: u16,
stats: Arc<UdpStats>,
) {
debug!(transport_id = %transport_id, "UDP receive loop starting");
#[cfg(target_os = "linux")]
{
const BATCH: usize = 32;
let buf_size = mtu as usize + 100;
let mut backing: Vec<Vec<u8>> = (0..BATCH).map(|_| vec![0u8; buf_size]).collect();
let mut addrs: [Option<std::net::SocketAddr>; BATCH] = std::array::from_fn(|_| None);
let mut lens: [usize; BATCH] = [0; BATCH];
loop {
let mut bufs: [&mut [u8]; BATCH] = {
let mut iter = backing.iter_mut();
std::array::from_fn(|_| iter.next().unwrap().as_mut_slice())
};
let recv_result = {
let _t = crate::perf_profile::Timer::start(crate::perf_profile::Stage::UdpRecv);
socket.recv_batch(&mut bufs, &mut addrs, &mut lens).await
};
match recv_result {
Ok((count, kernel_drops)) => {
stats.set_kernel_drops(kernel_drops as u64);
for i in 0..count {
let len = lens[i];
let Some(remote_addr) = addrs[i] else {
continue;
};
stats.record_recv(len);
if is_punch_packet(&backing[i][..len]) {
trace!(
transport_id = %transport_id,
remote_addr = %remote_addr,
bytes = len,
"Dropping stray punch probe/ack on UDP transport"
);
continue;
}
let mut data = std::mem::replace(&mut backing[i], vec![0u8; buf_size]);
data.truncate(len);
let addr = TransportAddr::from_socket_addr(remote_addr);
let packet = ReceivedPacket::new(transport_id, addr, data);
trace!(
transport_id = %transport_id,
remote_addr = %remote_addr,
bytes = len,
"UDP packet received"
);
if packet_tx.send(packet).is_err() {
debug!(
transport_id = %transport_id,
"Packet channel closed, stopping receive loop"
);
return;
}
}
}
Err(e) => {
stats.record_recv_error();
warn!(
transport_id = %transport_id,
error = %e,
"UDP receive error"
);
}
}
}
}
#[cfg(not(target_os = "linux"))]
{
let mut buf = vec![0u8; mtu as usize + 100];
loop {
match socket.recv_from(&mut buf).await {
Ok((len, remote_addr, kernel_drops)) => {
stats.record_recv(len);
stats.set_kernel_drops(kernel_drops as u64);
if is_punch_packet(&buf[..len]) {
trace!(
transport_id = %transport_id,
remote_addr = %remote_addr,
bytes = len,
"Dropping stray punch probe/ack on UDP transport"
);
continue;
}
let data = buf[..len].to_vec();
let addr = TransportAddr::from_socket_addr(remote_addr);
let packet = ReceivedPacket::new(transport_id, addr, data);
trace!(
transport_id = %transport_id,
remote_addr = %remote_addr,
bytes = len,
"UDP packet received"
);
if packet_tx.send(packet).is_err() {
debug!(
transport_id = %transport_id,
"Packet channel closed, stopping receive loop"
);
break;
}
}
Err(e) => {
stats.record_recv_error();
warn!(
transport_id = %transport_id,
error = %e,
"UDP receive error"
);
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::packet_channel;
use tokio::time::{Duration, timeout};
fn make_config(port: u16) -> UdpConfig {
UdpConfig {
bind_addr: Some(format!("127.0.0.1:{}", port)),
mtu: Some(1280),
..Default::default()
}
}
#[tokio::test]
async fn test_start_stop() {
let (tx, _rx) = packet_channel(100);
let mut transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);
assert_eq!(transport.state(), TransportState::Configured);
transport.start_async().await.unwrap();
assert_eq!(transport.state(), TransportState::Up);
assert!(transport.local_addr().is_some());
transport.stop_async().await.unwrap();
assert_eq!(transport.state(), TransportState::Down);
}
#[tokio::test]
async fn test_double_start_fails() {
let (tx, _rx) = packet_channel(100);
let mut transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);
transport.start_async().await.unwrap();
let result = transport.start_async().await;
assert!(matches!(result, Err(TransportError::AlreadyStarted)));
transport.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_stop_not_started_fails() {
let (tx, _rx) = packet_channel(100);
let mut transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);
let result = transport.stop_async().await;
assert!(matches!(result, Err(TransportError::NotStarted)));
}
#[tokio::test]
async fn test_send_recv() {
let (tx1, _rx1) = packet_channel(100);
let (tx2, mut rx2) = packet_channel(100);
let mut t1 = UdpTransport::new(TransportId::new(1), None, make_config(0), tx1);
let mut t2 = UdpTransport::new(TransportId::new(2), None, make_config(0), tx2);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let addr1 = t1.local_addr().unwrap();
let addr2 = t2.local_addr().unwrap();
let data = b"hello world";
let bytes_sent = t1
.send_async(&TransportAddr::from_string(&addr2.to_string()), data)
.await
.unwrap();
assert_eq!(bytes_sent, data.len());
let packet = timeout(Duration::from_secs(1), rx2.recv())
.await
.expect("timeout")
.expect("channel closed");
assert_eq!(packet.data, data);
assert_eq!(
packet.remote_addr.as_str(),
Some(addr1.to_string().as_str())
);
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_bidirectional() {
let (tx1, mut rx1) = packet_channel(100);
let (tx2, mut rx2) = packet_channel(100);
let mut t1 = UdpTransport::new(TransportId::new(1), None, make_config(0), tx1);
let mut t2 = UdpTransport::new(TransportId::new(2), None, make_config(0), tx2);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let addr1 = TransportAddr::from_string(&t1.local_addr().unwrap().to_string());
let addr2 = TransportAddr::from_string(&t2.local_addr().unwrap().to_string());
t1.send_async(&addr2, b"ping").await.unwrap();
let packet = timeout(Duration::from_secs(1), rx2.recv())
.await
.expect("timeout")
.expect("channel closed");
assert_eq!(packet.data, b"ping");
t2.send_async(&addr1, b"pong").await.unwrap();
let packet = timeout(Duration::from_secs(1), rx1.recv())
.await
.expect("timeout")
.expect("channel closed");
assert_eq!(packet.data, b"pong");
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_mtu_exceeded() {
let (tx, _rx) = packet_channel(100);
let mut transport = UdpTransport::new(
TransportId::new(1),
None,
UdpConfig {
mtu: Some(100),
..make_config(0)
},
tx,
);
transport.start_async().await.unwrap();
let oversized = vec![0u8; 200];
let result = transport
.send_async(&TransportAddr::from_string("127.0.0.1:9999"), &oversized)
.await;
assert!(matches!(result, Err(TransportError::MtuExceeded { .. })));
transport.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_send_not_started() {
let (tx, _rx) = packet_channel(100);
let transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);
let result = transport
.send_async(&TransportAddr::from_string("127.0.0.1:9999"), b"test")
.await;
assert!(matches!(result, Err(TransportError::NotStarted)));
}
#[tokio::test]
async fn test_discover_returns_empty() {
let (tx, _rx) = packet_channel(100);
let transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);
let peers = transport.discover().unwrap();
assert!(peers.is_empty());
}
#[test]
fn test_transport_type() {
let (tx, _rx) = packet_channel(100);
let transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);
assert_eq!(transport.transport_type().name, "udp");
assert!(!transport.transport_type().connection_oriented);
assert!(!transport.transport_type().reliable);
}
#[test]
fn test_sync_methods_return_not_supported() {
let (tx, _rx) = packet_channel(100);
let mut transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);
assert!(matches!(
transport.start(),
Err(TransportError::NotSupported(_))
));
assert!(matches!(
transport.stop(),
Err(TransportError::NotSupported(_))
));
assert!(matches!(
transport.send(&TransportAddr::from_string("test"), b"data"),
Err(TransportError::NotSupported(_))
));
}
#[tokio::test]
async fn test_resolve_socket_addr_ip() {
let addr = TransportAddr::from_string("192.168.1.1:2121");
let result = resolve_socket_addr(&addr).await.unwrap();
assert_eq!(result.to_string(), "192.168.1.1:2121");
}
#[tokio::test]
async fn test_resolve_socket_addr_invalid() {
let invalid = TransportAddr::from_string("nonexistent.invalid:2121");
assert!(resolve_socket_addr(&invalid).await.is_err());
let binary = TransportAddr::new(vec![0xff, 0x80]);
assert!(resolve_socket_addr(&binary).await.is_err());
}
#[tokio::test]
async fn test_resolve_socket_addr_hostname() {
let addr = TransportAddr::from_string("localhost:2121");
let result = resolve_socket_addr(&addr).await.unwrap();
assert!(result.ip().is_loopback());
assert_eq!(result.port(), 2121);
}
#[tokio::test]
async fn test_congestion_reports_kernel_drops() {
let (tx, _rx) = packet_channel(100);
let transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);
let cong = transport.congestion();
assert_eq!(cong.recv_drops, Some(0));
}
#[test]
fn test_accept_connections_default_true() {
let (tx, _rx) = packet_channel(100);
let transport = UdpTransport::new(TransportId::new(1), None, make_config(0), tx);
assert!(transport.accept_connections());
}
#[test]
fn test_accept_connections_false_when_configured() {
let (tx, _rx) = packet_channel(100);
let transport = UdpTransport::new(
TransportId::new(1),
None,
UdpConfig {
bind_addr: Some("127.0.0.1:0".to_string()),
accept_connections: Some(false),
..Default::default()
},
tx,
);
assert!(!transport.accept_connections());
}
#[test]
fn test_accept_connections_forced_false_in_outbound_only() {
let (tx, _rx) = packet_channel(100);
let transport = UdpTransport::new(
TransportId::new(1),
None,
UdpConfig {
outbound_only: Some(true),
accept_connections: Some(true), ..Default::default()
},
tx,
);
assert!(!transport.accept_connections());
}
#[tokio::test]
async fn test_outbound_only_binds_ephemeral() {
let (tx, _rx) = packet_channel(100);
let mut transport = UdpTransport::new(
TransportId::new(1),
None,
UdpConfig {
bind_addr: Some("127.0.0.1:65535".to_string()),
outbound_only: Some(true),
..Default::default()
},
tx,
);
transport.start_async().await.unwrap();
let local = transport.local_addr().unwrap();
assert_ne!(local.port(), 65535);
assert!(local.port() > 0);
assert!(local.ip().is_unspecified());
transport.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_punch_probe_dropped() {
let (tx_recv, mut rx_recv) = packet_channel(100);
let (tx_send, _rx_send) = packet_channel(100);
let mut t_recv = UdpTransport::new(TransportId::new(1), None, make_config(0), tx_recv);
let mut t_send = UdpTransport::new(TransportId::new(2), None, make_config(0), tx_send);
t_recv.start_async().await.unwrap();
t_send.start_async().await.unwrap();
let recv_addr = t_recv.local_addr().unwrap();
let recv_addr_str = TransportAddr::from_string(&recv_addr.to_string());
let mut probe = vec![0u8; 16];
probe[..4].copy_from_slice(&0x4E505443u32.to_be_bytes());
t_send.send_async(&recv_addr_str, &probe).await.unwrap();
let mut ack = vec![0u8; 16];
ack[..4].copy_from_slice(&0x4E505441u32.to_be_bytes());
t_send.send_async(&recv_addr_str, &ack).await.unwrap();
let real = b"valid-fmp-frame";
t_send.send_async(&recv_addr_str, real).await.unwrap();
let packet = timeout(Duration::from_secs(1), rx_recv.recv())
.await
.expect("timeout waiting for real packet")
.expect("channel closed");
assert_eq!(packet.data, real);
let no_more = timeout(Duration::from_millis(200), rx_recv.recv()).await;
assert!(no_more.is_err(), "punch probe/ack leaked through filter");
t_recv.stop_async().await.unwrap();
t_send.stop_async().await.unwrap();
}
#[test]
fn test_is_punch_packet_helper() {
use crate::discovery::is_punch_packet;
assert!(is_punch_packet(&[0x4E, 0x50, 0x54, 0x43, 0xAA, 0xBB]));
assert!(is_punch_packet(&[0x4E, 0x50, 0x54, 0x41]));
assert!(!is_punch_packet(&[0x01, 0x02, 0x03, 0x04]));
assert!(!is_punch_packet(&[0x4E, 0x50, 0x54]));
assert!(!is_punch_packet(&[]));
}
#[tokio::test]
async fn test_send_recv_ip_string() {
let (tx1, _rx1) = packet_channel(100);
let (tx2, mut rx2) = packet_channel(100);
let mut t1 = UdpTransport::new(TransportId::new(1), None, make_config(0), tx1);
let mut t2 = UdpTransport::new(TransportId::new(2), None, make_config(0), tx2);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let port2 = t2.local_addr().unwrap().port();
let data = b"hello via ip string";
let bytes_sent = t1
.send_async(
&TransportAddr::from_string(&format!("127.0.0.1:{}", port2)),
data,
)
.await
.unwrap();
assert_eq!(bytes_sent, data.len());
let packet = timeout(Duration::from_secs(1), rx2.recv())
.await
.expect("timeout")
.expect("channel closed");
assert_eq!(packet.data, data);
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
}