use async_trait::async_trait;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Instant;
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use super::addr::{TransportAddr, TransportType};
use super::capabilities::TransportCapabilities;
use super::provider::{
InboundDatagram, LinkQuality, TransportError, TransportProvider, TransportStats,
};
pub struct UdpTransport {
socket: Arc<UdpSocket>,
capabilities: TransportCapabilities,
local_addr: SocketAddr,
online: AtomicBool,
delegated_to_quinn: AtomicBool,
stats: UdpTransportStats,
inbound_tx: mpsc::Sender<InboundDatagram>,
shutdown_tx: mpsc::Sender<()>,
}
struct UdpTransportStats {
datagrams_sent: AtomicU64,
datagrams_received: AtomicU64,
bytes_sent: AtomicU64,
bytes_received: AtomicU64,
send_errors: AtomicU64,
receive_errors: AtomicU64,
}
impl Default for UdpTransportStats {
fn default() -> Self {
Self {
datagrams_sent: AtomicU64::new(0),
datagrams_received: AtomicU64::new(0),
bytes_sent: AtomicU64::new(0),
bytes_received: AtomicU64::new(0),
send_errors: AtomicU64::new(0),
receive_errors: AtomicU64::new(0),
}
}
}
impl UdpTransport {
pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
let socket = UdpSocket::bind(addr).await?;
let local_addr = socket.local_addr()?;
let socket = Arc::new(socket);
let (inbound_tx, _) = mpsc::channel(1024);
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let transport = Self {
socket: socket.clone(),
capabilities: TransportCapabilities::broadband(),
local_addr,
online: AtomicBool::new(true),
delegated_to_quinn: AtomicBool::new(false),
stats: UdpTransportStats::default(),
inbound_tx,
shutdown_tx,
};
transport.spawn_recv_loop(socket, shutdown_rx);
Ok(transport)
}
pub async fn bind_for_quinn(addr: SocketAddr) -> io::Result<(Self, std::net::UdpSocket)> {
let std_socket = Self::create_socket_for_quinn(addr)?;
let local_addr = std_socket.local_addr()?;
let std_socket_for_transport = std_socket.try_clone()?;
let tokio_socket = UdpSocket::from_std(std_socket_for_transport)?;
let socket_arc = Arc::new(tokio_socket);
let (inbound_tx, _) = mpsc::channel(1024);
let (shutdown_tx, _shutdown_rx) = mpsc::channel(1);
let transport = Self {
socket: socket_arc,
capabilities: TransportCapabilities::broadband(),
local_addr,
online: AtomicBool::new(true),
delegated_to_quinn: AtomicBool::new(true), stats: UdpTransportStats::default(),
inbound_tx,
shutdown_tx,
};
Ok((transport, std_socket))
}
pub async fn bind_dual_stack_for_endpoint(
port: u16,
) -> io::Result<(
Self,
std::sync::Arc<crate::high_level::runtime::dual_stack::DualStackSocket>,
)> {
use crate::high_level::runtime::dual_stack;
let (v4_std, v6_std) = dual_stack::create_dual_stack_sockets(port)?;
let registry_socket = v4_std
.as_ref()
.or(v6_std.as_ref())
.ok_or_else(|| io::Error::other("no sockets created"))?;
let transport_clone = registry_socket.try_clone()?;
transport_clone.set_nonblocking(true)?;
let tokio_socket = UdpSocket::from_std(transport_clone)?;
let local_addr = tokio_socket.local_addr()?;
let (inbound_tx, _) = mpsc::channel(1024);
let (shutdown_tx, _) = mpsc::channel(1);
let transport = Self {
socket: Arc::new(tokio_socket),
capabilities: TransportCapabilities::broadband(),
local_addr,
online: AtomicBool::new(true),
delegated_to_quinn: AtomicBool::new(true),
stats: UdpTransportStats::default(),
inbound_tx,
shutdown_tx,
};
let dual = dual_stack::wrap_dual_stack(v4_std, v6_std)?;
Ok((transport, std::sync::Arc::new(dual)))
}
#[cfg(feature = "network-discovery")]
fn create_socket_for_quinn(addr: SocketAddr) -> io::Result<std::net::UdpSocket> {
use socket2::{Domain, Protocol, Socket, Type};
let domain = if addr.is_ipv6() {
Domain::IPV6
} else {
Domain::IPV4
};
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if addr.is_ipv6() {
if let Err(e) = socket.set_only_v6(false) {
tracing::debug!(%e, "unable to make socket dual-stack, IPv6-only mode");
}
}
socket.set_nonblocking(true)?;
let buffer_size = crate::config::buffer_defaults::PLATFORM_DEFAULT;
if let Err(e) = socket.set_send_buffer_size(buffer_size) {
tracing::debug!(%e, "unable to set send buffer size to {}", buffer_size);
}
if let Err(e) = socket.set_recv_buffer_size(buffer_size) {
tracing::debug!(%e, "unable to set recv buffer size to {}", buffer_size);
}
socket.bind(&addr.into())?;
Ok(socket.into())
}
#[cfg(not(feature = "network-discovery"))]
fn create_socket_for_quinn(addr: SocketAddr) -> io::Result<std::net::UdpSocket> {
let socket = std::net::UdpSocket::bind(addr)?;
socket.set_nonblocking(true)?;
Ok(socket)
}
pub fn from_socket(socket: Arc<UdpSocket>, local_addr: SocketAddr) -> Self {
let (inbound_tx, _) = mpsc::channel(1024);
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let transport = Self {
socket: socket.clone(),
capabilities: TransportCapabilities::broadband(),
local_addr,
online: AtomicBool::new(true),
delegated_to_quinn: AtomicBool::new(false),
stats: UdpTransportStats::default(),
inbound_tx,
shutdown_tx,
};
transport.spawn_recv_loop(socket, shutdown_rx);
transport
}
pub fn is_delegated_to_quinn(&self) -> bool {
self.delegated_to_quinn.load(Ordering::SeqCst)
}
fn spawn_recv_loop(&self, socket: Arc<UdpSocket>, mut shutdown_rx: mpsc::Receiver<()>) {
let inbound_tx = self.inbound_tx.clone();
let online = self.online.load(Ordering::SeqCst);
if !online {
return;
}
tokio::spawn(async move {
let mut buf = vec![0u8; 65535];
loop {
tokio::select! {
result = socket.recv_from(&mut buf) => {
match result {
Ok((len, source)) => {
let datagram = InboundDatagram {
data: buf[..len].to_vec(),
source: TransportAddr::Udp(source),
received_at: Instant::now(),
link_quality: None,
};
let _ = inbound_tx.try_send(datagram);
}
Err(_) => {
continue;
}
}
}
_ = shutdown_rx.recv() => {
break;
}
}
}
});
}
pub fn socket(&self) -> &Arc<UdpSocket> {
&self.socket
}
pub fn local_address(&self) -> SocketAddr {
self.local_addr
}
}
#[async_trait]
impl TransportProvider for UdpTransport {
fn name(&self) -> &str {
"UDP"
}
fn transport_type(&self) -> TransportType {
TransportType::Udp
}
fn capabilities(&self) -> &TransportCapabilities {
&self.capabilities
}
fn local_addr(&self) -> Option<TransportAddr> {
Some(TransportAddr::Udp(self.local_addr))
}
async fn send(&self, data: &[u8], dest: &TransportAddr) -> Result<(), TransportError> {
if !self.online.load(Ordering::SeqCst) {
return Err(TransportError::Offline);
}
let socket_addr = match dest {
TransportAddr::Udp(addr) => *addr,
_ => {
return Err(TransportError::AddressMismatch {
expected: TransportType::Udp,
actual: dest.transport_type(),
});
}
};
if data.len() > self.capabilities.mtu {
return Err(TransportError::MessageTooLarge {
size: data.len(),
mtu: self.capabilities.mtu,
});
}
match self.socket.send_to(data, socket_addr).await {
Ok(sent) => {
self.stats.datagrams_sent.fetch_add(1, Ordering::Relaxed);
self.stats
.bytes_sent
.fetch_add(sent as u64, Ordering::Relaxed);
Ok(())
}
Err(e) => {
self.stats.send_errors.fetch_add(1, Ordering::Relaxed);
Err(TransportError::SendFailed {
reason: e.to_string(),
})
}
}
}
fn inbound(&self) -> mpsc::Receiver<InboundDatagram> {
let (_, rx) = mpsc::channel(1024);
rx
}
fn is_online(&self) -> bool {
self.online.load(Ordering::SeqCst)
}
async fn shutdown(&self) -> Result<(), TransportError> {
self.online.store(false, Ordering::SeqCst);
let _ = self.shutdown_tx.send(()).await;
Ok(())
}
async fn broadcast(&self, data: &[u8]) -> Result<(), TransportError> {
if !self.capabilities.broadcast {
return Err(TransportError::BroadcastNotSupported);
}
let broadcast_addr = SocketAddr::new(
std::net::IpAddr::V4(std::net::Ipv4Addr::BROADCAST),
self.local_addr.port(),
);
self.send(data, &TransportAddr::Udp(broadcast_addr)).await
}
async fn link_quality(&self, _peer: &TransportAddr) -> Option<LinkQuality> {
None
}
fn stats(&self) -> TransportStats {
TransportStats {
datagrams_sent: self.stats.datagrams_sent.load(Ordering::Relaxed),
datagrams_received: self.stats.datagrams_received.load(Ordering::Relaxed),
bytes_sent: self.stats.bytes_sent.load(Ordering::Relaxed),
bytes_received: self.stats.bytes_received.load(Ordering::Relaxed),
send_errors: self.stats.send_errors.load(Ordering::Relaxed),
receive_errors: self.stats.receive_errors.load(Ordering::Relaxed),
current_rtt: None,
}
}
fn socket(&self) -> Option<&Arc<UdpSocket>> {
Some(&self.socket)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds a fresh transport to an OS-assigned port on IPv4 loopback.
    async fn loopback_transport() -> UdpTransport {
        UdpTransport::bind("127.0.0.1:0".parse().unwrap())
            .await
            .unwrap()
    }

    /// Binds a Quinn-style socket and its companion transport on `[::]` with an
    /// OS-assigned port.
    async fn any_v6_for_quinn() -> (UdpTransport, std::net::UdpSocket) {
        UdpTransport::bind_for_quinn("[::]:0".parse().unwrap())
            .await
            .unwrap()
    }

    /// Polls a non-blocking std socket for one datagram: up to 50 attempts, 10 ms apart.
    ///
    /// Returns `None` if nothing arrived in time; panics on any error other than `WouldBlock`.
    fn poll_recv(socket: &std::net::UdpSocket, buf: &mut [u8]) -> Option<(usize, SocketAddr)> {
        for _ in 0..50 {
            match socket.recv_from(buf) {
                Ok(datagram) => return Some(datagram),
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    std::thread::sleep(std::time::Duration::from_millis(10));
                }
                Err(e) => panic!("unexpected recv error: {e}"),
            }
        }
        None
    }

    #[tokio::test]
    async fn test_udp_transport_bind() {
        let transport = loopback_transport().await;
        assert!(transport.is_online());
        assert_eq!(transport.transport_type(), TransportType::Udp);
        assert!(transport.capabilities().supports_full_quic());

        let reported = transport.local_addr();
        assert!(reported.is_some());
        if let Some(TransportAddr::Udp(bound)) = reported {
            assert_eq!(bound.ip(), std::net::Ipv4Addr::LOCALHOST);
            assert_ne!(bound.port(), 0);
        }
    }

    #[tokio::test]
    async fn test_udp_transport_send() {
        let sender = loopback_transport().await;
        let receiver = loopback_transport().await;
        let target = receiver.local_addr().unwrap();

        assert!(sender.send(b"hello", &target).await.is_ok());

        let counters = sender.stats();
        assert_eq!(counters.datagrams_sent, 1);
        assert_eq!(counters.bytes_sent, 5);
    }

    #[tokio::test]
    async fn test_udp_transport_address_mismatch() {
        let transport = loopback_transport().await;
        let ble_addr = TransportAddr::ble([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], None);

        let Err(TransportError::AddressMismatch { expected, actual }) =
            transport.send(b"hello", &ble_addr).await
        else {
            panic!("expected AddressMismatch error");
        };
        assert_eq!(expected, TransportType::Udp);
        assert_eq!(actual, TransportType::Ble);
    }

    #[tokio::test]
    async fn test_udp_transport_shutdown() {
        let transport = loopback_transport().await;
        assert!(transport.is_online());

        transport.shutdown().await.unwrap();
        assert!(!transport.is_online());

        let dest = TransportAddr::Udp("127.0.0.1:9000".parse().unwrap());
        let outcome = transport.send(b"hello", &dest).await;
        assert!(matches!(outcome, Err(TransportError::Offline)));
    }

    #[test]
    fn test_udp_capabilities() {
        let caps = TransportCapabilities::broadband();
        assert!(caps.supports_full_quic());
        assert!(caps.broadcast);
        for (flag, label) in [
            (caps.half_duplex, "half_duplex"),
            (caps.metered, "metered"),
            (caps.power_constrained, "power_constrained"),
        ] {
            assert!(!flag, "broadband capabilities should not be {label}");
        }
    }

    #[tokio::test]
    async fn test_udp_transport_socket_accessor() {
        let transport = loopback_transport().await;
        assert!(transport.socket().local_addr().is_ok());

        let provider: &dyn TransportProvider = &transport;
        let via_trait = provider
            .socket()
            .expect("TransportProvider::socket should return Some");
        assert!(via_trait.local_addr().is_ok());
    }

    #[tokio::test]
    async fn test_bind_for_quinn_ipv6_dual_stack() {
        let (transport, quinn_socket) = any_v6_for_quinn().await;
        let local = quinn_socket.local_addr().unwrap();
        assert!(local.is_ipv6(), "expected IPv6 address, got {local}");
        assert_ne!(local.port(), 0, "port should be assigned by OS");

        let transport_addr = transport.local_address();
        assert!(transport_addr.is_ipv6());
        assert_eq!(transport_addr.port(), local.port());
    }

    #[tokio::test]
    async fn test_bind_for_quinn_ipv4_explicit() {
        let (transport, quinn_socket) =
            UdpTransport::bind_for_quinn("0.0.0.0:0".parse().unwrap())
                .await
                .unwrap();
        let local = quinn_socket.local_addr().unwrap();
        assert!(local.is_ipv4(), "expected IPv4 address, got {local}");
        assert_ne!(local.port(), 0);
        assert_eq!(transport.local_address().port(), local.port());
    }

    #[cfg(not(target_os = "windows"))]
    #[tokio::test]
    async fn test_dual_stack_socket_can_send_to_ipv4_mapped() {
        let receiver = std::net::UdpSocket::bind("[::]:0").unwrap();
        receiver.set_nonblocking(true).unwrap();
        let recv_port = receiver.local_addr().unwrap().port();

        let (transport, _quinn_socket) = any_v6_for_quinn().await;
        let ipv4_mapped: SocketAddr = format!("[::ffff:127.0.0.1]:{recv_port}").parse().unwrap();
        transport
            .send(b"dual-stack-test", &TransportAddr::Udp(ipv4_mapped))
            .await
            .unwrap();

        let mut buf = [0u8; 64];
        let (len, _src) = poll_recv(&receiver, &mut buf)
            .expect("receiver did not get the dual-stack datagram");
        assert_eq!(&buf[..len], b"dual-stack-test");
    }

    #[tokio::test]
    async fn test_dual_stack_socket_can_receive_from_ipv4_sender() {
        let (_transport, quinn_socket) = any_v6_for_quinn().await;
        let recv_port = quinn_socket.local_addr().unwrap().port();
        quinn_socket.set_nonblocking(true).unwrap();

        let sender = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
        let dest: SocketAddr = format!("127.0.0.1:{recv_port}").parse().unwrap();
        sender.send_to(b"from-ipv4", dest).unwrap();

        let mut buf = [0u8; 64];
        let (len, src) = poll_recv(&quinn_socket, &mut buf)
            .expect("dual-stack socket did not receive IPv4 datagram");
        assert_eq!(&buf[..len], b"from-ipv4");

        // A dual-stack socket may report an IPv4 peer natively or in IPv4-mapped form.
        let src_ip = src.ip();
        let is_loopback = match src_ip {
            std::net::IpAddr::V4(v4) => v4.is_loopback(),
            std::net::IpAddr::V6(v6) => v6.to_ipv4_mapped().is_some_and(|v4| v4.is_loopback()),
        };
        assert!(is_loopback, "source should be loopback, got {src_ip}");
    }

    #[tokio::test]
    async fn test_dual_stack_socket_can_communicate_ipv6() {
        let (_transport, quinn_socket) = any_v6_for_quinn().await;
        let recv_port = quinn_socket.local_addr().unwrap().port();
        quinn_socket.set_nonblocking(true).unwrap();

        let sender = std::net::UdpSocket::bind("[::1]:0").unwrap();
        let dest: SocketAddr = format!("[::1]:{recv_port}").parse().unwrap();
        sender.send_to(b"from-ipv6", dest).unwrap();

        let mut buf = [0u8; 64];
        let (len, src) = poll_recv(&quinn_socket, &mut buf)
            .expect("dual-stack socket did not receive IPv6 datagram");
        assert_eq!(&buf[..len], b"from-ipv6");

        let is_v6_loopback = matches!(
            src.ip(),
            std::net::IpAddr::V6(v6) if v6 == std::net::Ipv6Addr::LOCALHOST
        );
        assert!(is_v6_loopback, "source should be ::1, got {}", src.ip());
    }

    #[tokio::test]
    async fn test_bind_for_quinn_with_specific_port() {
        let (_, socket1) = any_v6_for_quinn().await;
        let port = socket1.local_addr().unwrap().port();
        assert!(port > 0);

        let specific: SocketAddr = format!("[::]:{port}").parse().unwrap();
        let second_bind = UdpTransport::bind_for_quinn(specific).await;
        assert!(second_bind.is_err(), "binding to same port should fail");
    }

    #[cfg(all(feature = "network-discovery", not(target_os = "windows")))]
    #[test]
    fn test_create_socket_for_quinn_dual_stack_flag() {
        use socket2::Socket;
        let addr: SocketAddr = "[::]:0".parse().unwrap();
        let std_socket = UdpTransport::create_socket_for_quinn(addr).unwrap();
        let only_v6 = Socket::from(std_socket).only_v6().unwrap();
        assert!(!only_v6, "IPV6_V6ONLY should be false (dual-stack enabled)");
    }
}