#![cfg(any(target_os = "linux", target_os = "macos"))]
use std::io;
use std::net::SocketAddr;
use std::os::unix::io::{AsRawFd, RawFd};
/// A UDP socket bound to a local address and `connect`ed to exactly one peer.
///
/// The value owns its raw descriptor; construct it with `ConnectedPeerSocket::open`.
#[derive(Debug)]
pub(crate) struct ConnectedPeerSocket {
    /// Raw socket descriptor owned by this value.
    fd: RawFd,
    /// Address the socket was `connect`ed to.
    peer_addr: SocketAddr,
    /// Local address as supplied by the caller of `open` (it may carry port 0
    /// or an unspecified IP; it is not re-read from the OS).
    local_addr: SocketAddr,
}
impl ConnectedPeerSocket {
    /// Opens a non-blocking, close-on-exec UDP socket bound to `local_addr`
    /// and `connect`ed to `peer_addr`.
    ///
    /// `SO_REUSEADDR` and `SO_REUSEPORT` are set before `bind`, so several
    /// connected sockets (and a listener that also sets both options) can
    /// share one local port. `recv_buf` / `send_buf` are applied best-effort:
    /// failures there are ignored, and values larger than `c_int::MAX` are
    /// saturated rather than wrapped.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if `local_addr` and `peer_addr` are of different
    ///   address families.
    /// * The OS error from `socket`, `fcntl` (non-Linux only), the reuse
    ///   `setsockopt` calls, `bind`, or `connect`.
    ///
    /// Once `socket` succeeds the descriptor is owned by `sock`, so every later
    /// error path closes it through `Drop`.
    pub fn open(
        local_addr: SocketAddr,
        peer_addr: SocketAddr,
        recv_buf: usize,
        send_buf: usize,
    ) -> io::Result<Self> {
        if local_addr.is_ipv4() != peer_addr.is_ipv4() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "ConnectedPeerSocket: local + peer address families differ",
            ));
        }
        let domain = if local_addr.is_ipv4() {
            libc::AF_INET
        } else {
            libc::AF_INET6
        };
        // Linux sets non-blocking + close-on-exec atomically at creation;
        // other targets apply them with fcntl right after (see below).
        #[cfg(target_os = "linux")]
        let typ = libc::SOCK_DGRAM | libc::SOCK_NONBLOCK | libc::SOCK_CLOEXEC;
        #[cfg(not(target_os = "linux"))]
        let typ = libc::SOCK_DGRAM;
        // SAFETY: plain FFI call with integer arguments; no pointers involved.
        let fd = unsafe { libc::socket(domain, typ, libc::IPPROTO_UDP) };
        if fd < 0 {
            return Err(io::Error::last_os_error());
        }
        // From here on `sock` owns `fd`; any `?` / early return closes it in Drop.
        let sock = ConnectedPeerSocket {
            fd,
            peer_addr,
            local_addr,
        };
        #[cfg(not(target_os = "linux"))]
        sock.set_nonblocking_cloexec()?;
        sock.set_sockopt_int(libc::SOL_SOCKET, libc::SO_REUSEADDR, 1)?;
        sock.set_sockopt_int(libc::SOL_SOCKET, libc::SO_REUSEPORT, 1)?;
        #[cfg(target_os = "macos")]
        crate::transport::udp::darwin_sockopts::apply_udp_socket_tuning(
            sock.fd,
            "connected-udp-peer",
        );
        #[cfg(target_os = "linux")]
        {
            sock.set_buf_size(libc::SO_RCVBUFFORCE, libc::SO_RCVBUF, recv_buf);
            sock.set_buf_size(libc::SO_SNDBUFFORCE, libc::SO_SNDBUF, send_buf);
        }
        #[cfg(not(target_os = "linux"))]
        {
            sock.set_buf_size(libc::SO_RCVBUF, recv_buf);
            sock.set_buf_size(libc::SO_SNDBUF, send_buf);
        }
        let local_sa: socket2::SockAddr = local_addr.into();
        // SAFETY: `local_sa` lives across the call, and `len()` is the length
        // of the address stored behind `as_ptr()`.
        let bind_r = unsafe {
            libc::bind(
                sock.fd,
                local_sa.as_ptr() as *const libc::sockaddr,
                local_sa.len(),
            )
        };
        if bind_r < 0 {
            return Err(io::Error::last_os_error());
        }
        let peer_sa: socket2::SockAddr = peer_addr.into();
        // SAFETY: same reasoning as for `bind` above, with `peer_sa`.
        let conn_r = unsafe {
            libc::connect(
                sock.fd,
                peer_sa.as_ptr() as *const libc::sockaddr,
                peer_sa.len(),
            )
        };
        if conn_r < 0 {
            return Err(io::Error::last_os_error());
        }
        Ok(sock)
    }

    /// Sets `O_NONBLOCK` and `FD_CLOEXEC` via `fcntl` on targets whose
    /// `socket()` call above did not set them atomically.
    #[cfg(not(target_os = "linux"))]
    fn set_nonblocking_cloexec(&self) -> io::Result<()> {
        // SAFETY: `self.fd` is an open descriptor owned by `self`; the fcntl
        // commands used here take only integer arguments.
        let flags = unsafe { libc::fcntl(self.fd, libc::F_GETFL) };
        if flags < 0 {
            return Err(io::Error::last_os_error());
        }
        // SAFETY: as above.
        if unsafe { libc::fcntl(self.fd, libc::F_SETFL, flags | libc::O_NONBLOCK) } < 0 {
            return Err(io::Error::last_os_error());
        }
        // SAFETY: as above.
        let fd_flags = unsafe { libc::fcntl(self.fd, libc::F_GETFD) };
        if fd_flags < 0 {
            return Err(io::Error::last_os_error());
        }
        // SAFETY: as above.
        if unsafe { libc::fcntl(self.fd, libc::F_SETFD, fd_flags | libc::FD_CLOEXEC) } < 0 {
            return Err(io::Error::last_os_error());
        }
        Ok(())
    }

    /// Sets an integer-valued socket option, returning the OS error on failure.
    fn set_sockopt_int(
        &self,
        level: libc::c_int,
        name: libc::c_int,
        value: libc::c_int,
    ) -> io::Result<()> {
        // SAFETY: `value` is a live `c_int` for the duration of the call and
        // the passed length is exactly `size_of::<c_int>()`.
        let r = unsafe {
            libc::setsockopt(
                self.fd,
                level,
                name,
                &value as *const _ as *const libc::c_void,
                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
            )
        };
        if r < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(())
        }
    }

    /// Converts a requested buffer size to the `c_int` that `setsockopt`
    /// expects, saturating at `c_int::MAX` instead of silently wrapping.
    fn clamp_buf_size(size: usize) -> libc::c_int {
        size.min(libc::c_int::MAX as usize) as libc::c_int
    }

    /// Best-effort buffer sizing: tries the `*FORCE` variant first and falls
    /// back to the regular option if that fails. Both failures are ignored.
    #[cfg(target_os = "linux")]
    fn set_buf_size(&self, force_name: libc::c_int, normal_name: libc::c_int, size: usize) {
        let value = Self::clamp_buf_size(size);
        if self.set_sockopt_int(libc::SOL_SOCKET, force_name, value).is_err() {
            let _ = self.set_sockopt_int(libc::SOL_SOCKET, normal_name, value);
        }
    }

    /// Best-effort buffer sizing via the regular option; failure is ignored.
    #[cfg(not(target_os = "linux"))]
    fn set_buf_size(&self, normal_name: libc::c_int, size: usize) {
        let _ = self.set_sockopt_int(libc::SOL_SOCKET, normal_name, Self::clamp_buf_size(size));
    }

    /// Returns the peer address this socket is connected to.
    pub fn peer_addr(&self) -> SocketAddr {
        self.peer_addr
    }

    /// Returns the local address that was passed to `open`.
    ///
    /// This value is not re-read from the OS, so it may still carry port 0 or
    /// an unspecified IP if that is what the caller requested.
    // `dead_code` is allowed because non-test builds may not call this accessor.
    #[allow(dead_code)]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }
}
impl AsRawFd for ConnectedPeerSocket {
    /// Lends out the underlying descriptor without transferring ownership.
    /// It remains valid only while this socket is alive, because `Drop`
    /// closes it.
    fn as_raw_fd(&self) -> RawFd {
        let Self { fd, .. } = self;
        *fd
    }
}
impl Drop for ConnectedPeerSocket {
    /// Closes the owned descriptor.
    ///
    /// The return value of `close` is discarded because `drop` has no way to
    /// report it.
    fn drop(&mut self) {
        let owned = self.fd;
        // SAFETY: `owned` was produced by a successful `socket()` call in
        // `open` (the only constructor in this file) and is closed only here.
        let _ = unsafe { libc::close(owned) };
    }
}
// Integration tests that exercise real loopback UDP sockets (they need the OS
// to allow binding 127.0.0.1 / 0.0.0.0 on ephemeral ports).
#[cfg(test)]
mod tests {
use super::*;
use std::net::UdpSocket;
/// Round trip: connected socket -> plain std peer -> back to connected socket.
#[test]
fn open_send_recv_loopback() {
let peer = UdpSocket::bind("127.0.0.1:0").expect("bind peer");
let peer_addr = peer.local_addr().expect("peer local_addr");
// Bound the peer's blocking recv so a lost datagram fails the test instead of hanging it.
peer.set_read_timeout(Some(std::time::Duration::from_millis(500)))
.expect("set_read_timeout");
// Port 0: the OS picks an ephemeral local port at bind time.
let local_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let sock = ConnectedPeerSocket::open(
local_addr,
peer_addr,
1 << 20,
1 << 20,
)
.expect("ConnectedPeerSocket::open");
let payload = b"hello-from-connected-socket";
// `send` (no destination) only works because the socket is connected.
let r = unsafe {
libc::send(
sock.as_raw_fd(),
payload.as_ptr() as *const libc::c_void,
payload.len(),
0,
)
};
assert!(r >= 0, "send failed: {}", std::io::Error::last_os_error());
assert_eq!(r as usize, payload.len());
let mut recv_buf = [0u8; 64];
let (len, from) = peer.recv_from(&mut recv_buf).expect("peer recv");
assert_eq!(len, payload.len());
assert_eq!(&recv_buf[..len], payload);
// Reply to the observed source address, i.e. the connected socket's real local port.
let reply = b"hello-back";
peer.send_to(reply, from).expect("peer send_to");
// `open` makes the socket non-blocking, so poll `recv` until the reply
// arrives, retrying on WouldBlock every 2 ms up to a 500 ms deadline.
let deadline = std::time::Instant::now() + std::time::Duration::from_millis(500);
loop {
let mut buf = [0u8; 64];
let r = unsafe {
libc::recv(
sock.as_raw_fd(),
buf.as_mut_ptr() as *mut libc::c_void,
buf.len(),
0,
)
};
if r >= 0 {
assert_eq!(r as usize, reply.len());
assert_eq!(&buf[..r as usize], reply);
break;
}
// Read errno immediately, before any other call can overwrite it.
let err = std::io::Error::last_os_error();
if err.kind() == std::io::ErrorKind::WouldBlock {
if std::time::Instant::now() >= deadline {
panic!("connected socket never received reply");
}
std::thread::sleep(std::time::Duration::from_millis(2));
continue;
}
panic!("recv failed: {err}");
}
}
/// Two connected sockets to different peers can bind the same local port
/// (relies on the SO_REUSEADDR/SO_REUSEPORT options set in `open`).
#[test]
fn two_connected_sockets_share_listen_port() {
let peer_a = UdpSocket::bind("127.0.0.1:0").expect("bind peer_a");
let peer_b = UdpSocket::bind("127.0.0.1:0").expect("bind peer_b");
let peer_a_addr = peer_a.local_addr().expect("peer_a local_addr");
let peer_b_addr = peer_b.local_addr().expect("peer_b local_addr");
// Discover a currently free port by binding and then releasing it.
// NOTE(review): another process could grab the port between `drop(anchor)`
// and `open`; that is presumably rare enough to accept in a test.
let anchor = UdpSocket::bind("127.0.0.1:0").expect("bind anchor");
let shared_port = anchor.local_addr().expect("anchor local_addr").port();
let shared_local: SocketAddr = format!("127.0.0.1:{shared_port}").parse().unwrap();
drop(anchor);
let sock_a = ConnectedPeerSocket::open(shared_local, peer_a_addr, 1 << 20, 1 << 20)
.expect("open sock_a");
let sock_b = ConnectedPeerSocket::open(shared_local, peer_b_addr, 1 << 20, 1 << 20)
.expect("open sock_b");
assert_eq!(sock_a.peer_addr(), peer_a_addr);
assert_eq!(sock_b.peer_addr(), peer_b_addr);
}
/// A connected socket can bind the exact wildcard address:port of a live
/// listener that also set SO_REUSEADDR and SO_REUSEPORT.
#[test]
fn connected_socket_shares_live_listener_port() {
let peer = UdpSocket::bind("127.0.0.1:0").expect("bind peer");
let peer_addr = peer.local_addr().expect("peer local_addr");
// Build the listener with socket2 so both reuse options can be set before bind.
let listener = socket2::Socket::new(
socket2::Domain::IPV4,
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)
.expect("create listener");
listener
.set_reuse_address(true)
.expect("listener reuseaddr");
listener.set_reuse_port(true).expect("listener reuseport");
listener
.bind(&"0.0.0.0:0".parse::<SocketAddr>().unwrap().into())
.expect("bind listener");
let local = listener
.local_addr()
.expect("listener local addr")
.as_socket()
.expect("ip socket");
// `listener` stays alive for the rest of the test, so the port really is shared.
let sock = ConnectedPeerSocket::open(local, peer_addr, 1 << 20, 1 << 20)
.expect("open connected sibling");
// `local_addr()` returns the address passed to `open`, not a value re-read from the OS.
assert_eq!(sock.local_addr(), local);
assert_eq!(sock.peer_addr(), peer_addr);
}
}