#![cfg(any(target_os = "linux", target_os = "macos"))]
use super::super::{ReceivedPacket, TransportAddr, TransportId};
use super::PacketTx;
use super::connected_peer::ConnectedPeerSocket;
use std::io;
use std::net::SocketAddr;
use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tracing::{debug, trace, warn};
/// Handle to a dedicated OS thread that drains one connected peer socket.
///
/// Created by `PeerRecvDrain::spawn`. Dropping the handle sets the stop flag,
/// writes one byte to the stop pipe to wake the thread, joins the thread and
/// then closes the write end of the pipe.
#[derive(Debug)]
pub(crate) struct PeerRecvDrain {
// Write end of the stop pipe; a byte is written here on drop to wake the
// drain thread out of `poll`. Closed on drop.
stop_pipe_tx: RawFd,
// Stop flag shared with the drain thread; set to `true` on drop.
stop: Arc<AtomicBool>,
// Join handle of the drain thread; taken and joined on drop.
join: Option<std::thread::JoinHandle<()>>,
}
impl PeerRecvDrain {
    /// Starts a named OS thread that runs `drain_loop` for `socket`.
    ///
    /// A stop pipe is created first; its read end is handed to the thread
    /// (which closes it once `drain_loop` returns) and its write end is kept
    /// in the returned handle so that dropping the handle can wake the thread.
    ///
    /// # Errors
    ///
    /// Returns the error from `make_pipe` if the pipe cannot be created, or an
    /// `io::Error` describing the failure if the thread cannot be spawned. In
    /// the latter case both pipe ends are closed before returning.
    pub fn spawn(
        socket: Arc<ConnectedPeerSocket>,
        transport_id: TransportId,
        peer_addr: SocketAddr,
        packet_tx: PacketTx,
    ) -> io::Result<Self> {
        let (pipe_rx, pipe_tx) = make_pipe()?;
        let stop = Arc::new(AtomicBool::new(false));

        // Format the name before the socket handle moves into the closure.
        let thread_name = format!("fips-peer-drain-{}", socket.peer_addr());
        let worker_stop = Arc::clone(&stop);

        let spawned = std::thread::Builder::new()
            .name(thread_name)
            .spawn(move || {
                drain_loop(
                    socket,
                    transport_id,
                    peer_addr,
                    packet_tx,
                    pipe_rx,
                    worker_stop,
                );
                // The thread owns the read end of the stop pipe.
                unsafe { libc::close(pipe_rx) };
            });

        let handle = match spawned {
            Ok(handle) => handle,
            Err(e) => {
                // The closure never ran, so nobody else will close the pipe.
                unsafe {
                    libc::close(pipe_rx);
                    libc::close(pipe_tx);
                }
                return Err(io::Error::other(format!(
                    "failed to spawn peer drain thread: {e}"
                )));
            }
        };

        Ok(Self {
            stop_pipe_tx: pipe_tx,
            stop,
            join: Some(handle),
        })
    }
}
impl Drop for PeerRecvDrain {
    /// Stops the drain thread and releases the stop pipe's write end.
    ///
    /// Order matters: the stop flag is published first, then one byte is
    /// written to the pipe so a thread blocked in `poll` wakes up and sees the
    /// flag, then the thread is joined, and only afterwards is the write end
    /// closed. The results of the write and of the join are deliberately
    /// ignored.
    fn drop(&mut self) {
        self.stop.store(true, Ordering::Release);

        let wake = [1u8];
        let _ = unsafe {
            libc::write(
                self.stop_pipe_tx,
                wake.as_ptr().cast::<libc::c_void>(),
                wake.len(),
            )
        };

        if let Some(handle) = self.join.take() {
            let _ = handle.join();
        }

        unsafe {
            libc::close(self.stop_pipe_tx);
        }
    }
}
/// Body of the per-peer drain thread.
///
/// Blocks in `poll(2)` on the peer socket and on the read end of the stop
/// pipe, and forwards every non-empty datagram read from the socket to
/// `packet_tx` as a `ReceivedPacket` tagged with `transport_id` and the
/// transport address derived from `peer_addr`.
///
/// The loop ends when any of the following happens:
/// * `stop` is observed set (checked before each `poll` and again when the
///   stop pipe reports an event),
/// * `poll` fails with anything other than `EINTR`,
/// * the receive call fails with anything other than `WouldBlock`,
/// * sending on `packet_tx` fails because the channel is closed.
///
/// `stop_pipe_rx` is only polled here; it is not closed by this function.
fn drain_loop(
    socket: Arc<ConnectedPeerSocket>,
    transport_id: TransportId,
    peer_addr: SocketAddr,
    packet_tx: PacketTx,
    stop_pipe_rx: RawFd,
    stop: Arc<AtomicBool>,
) {
    // Maximum number of datagrams pulled from the socket per wake-up.
    const BATCH: usize = 32;
    // Size in bytes of each receive buffer.
    const BUF_SIZE: usize = 1600;
    // Socket events that warrant a receive attempt. An error or hang-up
    // condition does not necessarily come with POLLIN; if it were ignored it
    // would stay pending and `poll(-1)` would return immediately forever,
    // spinning this thread. Attempting the receive surfaces the condition as
    // an error, which the match below turns into a logged exit.
    const SOCKET_WAKE: libc::c_short =
        libc::POLLIN | libc::POLLERR | libc::POLLHUP | libc::POLLNVAL;

    let socket_fd = socket.as_raw_fd();
    trace!(
        transport_id = %transport_id,
        peer_addr = %peer_addr,
        "fips-peer-drain: starting"
    );

    let mut backing: Vec<Vec<u8>> = (0..BATCH).map(|_| vec![0u8; BUF_SIZE]).collect();
    let mut lens = [0usize; BATCH];
    let packet_addr = TransportAddr::from_socket_addr(peer_addr);

    loop {
        if stop.load(Ordering::Acquire) {
            break;
        }

        let mut pfds = [
            libc::pollfd {
                fd: socket_fd,
                events: libc::POLLIN,
                revents: 0,
            },
            libc::pollfd {
                fd: stop_pipe_rx,
                events: libc::POLLIN,
                revents: 0,
            },
        ];
        // SAFETY: `pfds` is a live array of exactly two `pollfd` entries.
        let r = unsafe { libc::poll(pfds.as_mut_ptr(), 2, -1) };
        if r < 0 {
            let err = io::Error::last_os_error();
            if err.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            warn!(error = %err, "fips-peer-drain: poll failed; exiting");
            break;
        }

        // Stop pipe woke us: leave if a stop was actually requested.
        if pfds[1].revents != 0 && stop.load(Ordering::Acquire) {
            break;
        }

        if pfds[0].revents & SOCKET_WAKE == 0 {
            continue;
        }

        let count = match drain_packets(socket_fd, &mut backing, &mut lens) {
            Ok(c) => c,
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
            Err(err) => {
                debug!(error = %err, "fips-peer-drain: recv failed; exiting");
                break;
            }
        };

        for (buf, &len) in backing.iter_mut().zip(lens.iter()).take(count) {
            if len == 0 {
                continue;
            }
            // Hand the filled buffer to the packet and leave a fresh one in
            // its slot for the next batch.
            let mut data = std::mem::replace(buf, vec![0u8; BUF_SIZE]);
            data.truncate(len);
            let packet = ReceivedPacket::new(transport_id, packet_addr.clone(), data);
            if packet_tx.send(packet).is_err() {
                trace!("fips-peer-drain: packet channel closed; exiting");
                return;
            }
        }
    }

    trace!(
        transport_id = %transport_id,
        peer_addr = %peer_addr,
        "fips-peer-drain: stopped"
    );
}
/// Creates a pipe whose two ends are both non-blocking and close-on-exec.
///
/// Returns `(read_end, write_end)`. On Linux the flags are applied atomically
/// by `pipe2`; elsewhere the pipe is created with `pipe` and each end is
/// configured afterwards via `set_nonblocking_cloexec`.
///
/// # Errors
///
/// Returns the OS error if the pipe cannot be created or configured. If
/// configuration fails, both descriptors are closed before returning.
fn make_pipe() -> io::Result<(RawFd, RawFd)> {
    let mut fds: [RawFd; 2] = [0; 2];

    #[cfg(target_os = "linux")]
    {
        let rc = unsafe { libc::pipe2(fds.as_mut_ptr(), libc::O_CLOEXEC | libc::O_NONBLOCK) };
        if rc < 0 {
            return Err(io::Error::last_os_error());
        }
    }

    #[cfg(not(target_os = "linux"))]
    {
        let rc = unsafe { libc::pipe(fds.as_mut_ptr()) };
        if rc < 0 {
            return Err(io::Error::last_os_error());
        }
        // Configure the read end first, then the write end; stop at the first
        // failure and release both descriptors.
        let configured =
            set_nonblocking_cloexec(fds[0]).and_then(|()| set_nonblocking_cloexec(fds[1]));
        if let Err(err) = configured {
            unsafe {
                libc::close(fds[0]);
                libc::close(fds[1]);
            }
            return Err(err);
        }
    }

    let [read_end, write_end] = fds;
    Ok((read_end, write_end))
}
/// Adds `O_NONBLOCK` to the file status flags and `FD_CLOEXEC` to the
/// descriptor flags of `fd`, preserving whatever flags were already set.
///
/// # Errors
///
/// Returns the OS error of the first `fcntl` call that fails.
#[cfg(not(target_os = "linux"))]
fn set_nonblocking_cloexec(fd: RawFd) -> io::Result<()> {
    // Maps a negative `fcntl` return to the current OS error.
    let check = |rc: libc::c_int| -> io::Result<libc::c_int> {
        if rc < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(rc)
        }
    };

    let status_flags = check(unsafe { libc::fcntl(fd, libc::F_GETFL) })?;
    check(unsafe { libc::fcntl(fd, libc::F_SETFL, status_flags | libc::O_NONBLOCK) })?;

    let descriptor_flags = check(unsafe { libc::fcntl(fd, libc::F_GETFD) })?;
    check(unsafe { libc::fcntl(fd, libc::F_SETFD, descriptor_flags | libc::FD_CLOEXEC) })?;

    Ok(())
}
/// Linux variant of the platform-dispatched batch receive.
///
/// Forwards to `recvmmsg_drain`, which fills leading entries of `backing` with
/// datagrams read from `fd`, stores each datagram's length in the matching
/// slot of `lens`, and returns how many entries were filled.
#[cfg(target_os = "linux")]
fn drain_packets(fd: RawFd, backing: &mut [Vec<u8>], lens: &mut [usize]) -> io::Result<usize> {
recvmmsg_drain(fd, backing, lens)
}
/// Non-Linux variant of the platform-dispatched batch receive.
///
/// Forwards to `recv_drain`, which fills leading entries of `backing` with
/// datagrams read from `fd`, stores each datagram's length in the matching
/// slot of `lens`, and returns how many entries were filled.
#[cfg(not(target_os = "linux"))]
fn drain_packets(fd: RawFd, backing: &mut [Vec<u8>], lens: &mut [usize]) -> io::Result<usize> {
recv_drain(fd, backing, lens)
}
/// Receives up to 32 datagrams from `fd` with a single non-blocking
/// `recvmmsg(2)` call.
///
/// Datagram `i` is written into `backing[i]` (using that vector's current
/// length as the buffer size) and its received length is stored in `lens[i]`.
/// The number of slots offered to the kernel is the smallest of
/// `backing.len()`, `lens.len()` and 32.
///
/// Returns the number of datagrams received, or `Ok(0)` without a syscall if
/// no slot is available.
///
/// # Errors
///
/// Returns the OS error reported by `recvmmsg`, including `WouldBlock` when
/// nothing is queued.
#[cfg(target_os = "linux")]
fn recvmmsg_drain(fd: RawFd, backing: &mut [Vec<u8>], lens: &mut [usize]) -> io::Result<usize> {
    const BATCH: usize = 32;

    let slots = BATCH.min(backing.len()).min(lens.len());
    if slots == 0 {
        return Ok(0);
    }

    let mut iovs: [libc::iovec; BATCH] = unsafe { std::mem::zeroed() };
    let mut storages: [libc::sockaddr_storage; BATCH] = unsafe { std::mem::zeroed() };
    let mut msgs: [libc::mmsghdr; BATCH] = unsafe { std::mem::zeroed() };

    // Wire message header `i` to its own iovec, address storage and buffer.
    let wiring = msgs
        .iter_mut()
        .zip(iovs.iter_mut())
        .zip(storages.iter_mut())
        .zip(backing.iter_mut())
        .take(slots);
    for (((msg, iov), storage), buf) in wiring {
        iov.iov_base = buf.as_mut_ptr().cast::<libc::c_void>();
        iov.iov_len = buf.len();
        msg.msg_hdr.msg_name = (storage as *mut libc::sockaddr_storage).cast::<libc::c_void>();
        msg.msg_hdr.msg_namelen =
            std::mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;
        msg.msg_hdr.msg_iov = iov as *mut libc::iovec;
        msg.msg_hdr.msg_iovlen = 1 as _;
        msg.msg_len = 0;
    }

    let received = unsafe {
        libc::recvmmsg(
            fd,
            msgs.as_mut_ptr(),
            slots as libc::c_uint,
            libc::MSG_DONTWAIT as _,
            std::ptr::null_mut(),
        )
    };
    if received < 0 {
        return Err(io::Error::last_os_error());
    }

    let count = received as usize;
    for (len, msg) in lens.iter_mut().zip(msgs.iter()).take(count) {
        *len = msg.msg_len as usize;
    }
    Ok(count)
}
/// Receives datagrams from `fd` one `recv(2)` call at a time until the
/// available slots are full or the socket has nothing more queued.
///
/// Datagram `i` is written into `backing[i]` (using that vector's current
/// length as the buffer size) and its received length is stored in `lens[i]`.
/// At most `min(backing.len(), lens.len())` datagrams are read.
///
/// Every call passes `MSG_DONTWAIT`, so the function never blocks regardless
/// of the socket's own blocking mode. Without it, a blocking socket would
/// stall on the call after the last queued datagram while holding the
/// datagrams already read, and the calling thread could not react to anything
/// else until more traffic arrived.
///
/// Returns the number of datagrams received.
///
/// # Errors
///
/// Returns `WouldBlock` if nothing at all was queued, or any other OS error
/// from `recv` except `EINTR`, which is retried. If `WouldBlock` occurs after
/// at least one datagram was read, the datagrams read so far are returned
/// instead.
#[cfg(not(target_os = "linux"))]
fn recv_drain(fd: RawFd, backing: &mut [Vec<u8>], lens: &mut [usize]) -> io::Result<usize> {
    let n = backing.len().min(lens.len());
    let mut count = 0usize;
    while count < n {
        let buf = &mut backing[count];
        // SAFETY: `buf` is a live, exclusively borrowed buffer and the length
        // passed is exactly its length.
        let r = unsafe {
            libc::recv(
                fd,
                buf.as_mut_ptr().cast::<libc::c_void>(),
                buf.len(),
                libc::MSG_DONTWAIT,
            )
        };
        if r < 0 {
            let err = io::Error::last_os_error();
            match err.kind() {
                io::ErrorKind::Interrupted => continue,
                // Queue ran dry mid-batch: report what was collected.
                io::ErrorKind::WouldBlock if count > 0 => return Ok(count),
                _ => return Err(err),
            }
        }
        lens[count] = r as usize;
        count += 1;
    }
    Ok(count)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::UdpSocket;
    use std::time::Duration;
    use tokio::sync::mpsc;

    /// Returns the local IPv4 address `fd` is bound to, as reported by
    /// `getsockname`. Panics if the call fails or the family is not `AF_INET`.
    fn bound_ipv4_addr(fd: RawFd) -> SocketAddr {
        let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
        let mut storage_len = std::mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;
        let rc = unsafe {
            libc::getsockname(
                fd,
                &mut storage as *mut _ as *mut libc::sockaddr,
                &mut storage_len,
            )
        };
        assert!(rc >= 0, "getsockname failed");
        assert_eq!(
            storage.ss_family as i32,
            libc::AF_INET,
            "test assumes IPv4 loopback"
        );
        let v4: &libc::sockaddr_in =
            unsafe { &*(&storage as *const libc::sockaddr_storage).cast::<libc::sockaddr_in>() };
        // Both fields are stored in network byte order.
        let ip = std::net::Ipv4Addr::from(u32::from_be(v4.sin_addr.s_addr));
        let port = u16::from_be(v4.sin_port);
        SocketAddr::from((ip, port))
    }

    /// Sends five datagrams from a plain UDP socket to a connected peer socket
    /// and checks that the drain thread forwards each one, in order, with the
    /// expected transport id and payload.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn drain_delivers_packets_to_packet_tx() {
        let remote = UdpSocket::bind("127.0.0.1:0").expect("bind peer");
        let remote_addr = remote.local_addr().expect("peer local_addr");

        let bind_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let connected = ConnectedPeerSocket::open(bind_addr, remote_addr, 1 << 20, 1 << 20)
            .expect("ConnectedPeerSocket::open");
        let connected = Arc::new(connected);

        let (packet_tx, mut packet_rx) = mpsc::unbounded_channel::<ReceivedPacket>();
        let transport_id = TransportId::new(42);

        // The socket was bound to port 0, so ask the kernel which port it got.
        let drain_target = bound_ipv4_addr(connected.as_raw_fd());

        let _drain =
            PeerRecvDrain::spawn(Arc::clone(&connected), transport_id, remote_addr, packet_tx)
                .expect("PeerRecvDrain::spawn");

        (0u8..5).for_each(|seq| {
            let payload = [seq, 0xAA, 0xBB, 0xCC];
            remote.send_to(&payload, drain_target).expect("peer sendto");
        });

        for i in 0u8..5 {
            let received = tokio::time::timeout(Duration::from_millis(500), packet_rx.recv())
                .await
                .unwrap_or_else(|_| panic!("timeout waiting for packet {i}"))
                .expect("packet channel closed");
            assert_eq!(received.transport_id, transport_id);
            assert_eq!(received.data.len(), 4);
            assert_eq!(received.data[0], i, "packet {i} payload mismatch");
        }
    }
}