#![expect(
clippy::undocumented_unsafe_blocks,
reason = "module-wide libc-syscall + POD-zero-init contract documented in the # Safety section above"
)]
#![expect(
clippy::multiple_unsafe_ops_per_block,
reason = "Linux syscall wrappers compose pointer arithmetic + libc calls in single semantic operations (one batched sendmmsg / recvmmsg / configure_socket call)"
)]
use bytes::{Bytes, BytesMut};
use std::io;
use std::net::SocketAddr;
use std::os::unix::io::RawFd;
use super::protocol::MAX_PACKET_SIZE;
/// Maximum number of datagrams passed to one `sendmmsg` / `recvmmsg` call.
///
/// It is also the number of per-message slots every `BatchedTransport`
/// pre-allocates: larger sends are split into chunks of this size and
/// receive requests are clamped to it.
pub const MAX_BATCH_SIZE: usize = 64;
/// Reusable scratch state for batched UDP I/O on a raw socket descriptor.
///
/// `iovecs`, `msgs` and `addrs` are parallel arrays with one entry per batch
/// slot. Before each syscall the send/recv methods rewrite the entries they
/// use, pointing `msgs[i]` at `iovecs[i]` and `addrs[i]`.
pub struct BatchedTransport {
// Descriptor handed to every `sendmmsg` / `recvmmsg` call. Stored as a bare
// `RawFd`; the tests in this file keep the owning `UdpSocket` alive next to
// the transport, so the descriptor is presumably borrowed, not owned.
socket_fd: RawFd,
// One scatter/gather entry per slot, describing that slot's payload bytes.
iovecs: Vec<libc::iovec>,
// One message header per slot; this is the array passed to the kernel.
msgs: Vec<libc::mmsghdr>,
// One IPv4 socket address per slot: the destination when sending, the
// source filled in by the kernel when receiving.
addrs: Vec<libc::sockaddr_in>,
// One receive buffer per slot, or empty for transports built with
// `new_send_only`. The receive methods treat "empty" as "receive unsupported".
recv_buffers: Vec<BytesMut>,
}
// `libc::iovec` and `libc::mmsghdr` contain raw pointers, so the compiler does
// not derive `Send` automatically. The send/recv methods rewrite every pointer
// they pass to the kernel at the start of the call, so pointers left over from
// an earlier call are never read by this code.
// NOTE(review): those stale pointers do stay stored between calls — confirm no
// other code in the crate dereferences them.
unsafe impl Send for BatchedTransport {}
// Compile-time guard: the closure is never invoked, but type-checking it fails
// the build if `BatchedTransport` ever stops being `Send`.
const _: fn() = || {
fn assert_send<T: Send>() {}
assert_send::<BatchedTransport>();
};
impl BatchedTransport {
/// Creates a transport that can both send and receive on `socket_fd`.
///
/// One receive buffer is allocated per batch slot up front; use
/// `new_send_only` to skip that allocation when the transport never receives.
pub fn new(socket_fd: RawFd) -> Self {
    let with_recv_buffers = true;
    Self::new_inner(socket_fd, with_recv_buffers)
}
/// Creates a transport that can only send on `socket_fd`.
///
/// No receive buffers are allocated, so `recv_batch` and
/// `recv_batch_blocking` return `io::ErrorKind::Unsupported` on the result.
pub fn new_send_only(socket_fd: RawFd) -> Self {
    let with_recv_buffers = false;
    Self::new_inner(socket_fd, with_recv_buffers)
}
fn new_inner(socket_fd: RawFd, with_recv_buffers: bool) -> Self {
let mut iovecs = Vec::with_capacity(MAX_BATCH_SIZE);
let mut msgs = Vec::with_capacity(MAX_BATCH_SIZE);
let mut addrs = Vec::with_capacity(MAX_BATCH_SIZE);
let mut recv_buffers = if with_recv_buffers {
Vec::with_capacity(MAX_BATCH_SIZE)
} else {
Vec::new()
};
for _ in 0..MAX_BATCH_SIZE {
iovecs.push(libc::iovec {
iov_base: std::ptr::null_mut(),
iov_len: 0,
});
addrs.push(unsafe { std::mem::zeroed() });
msgs.push(unsafe { std::mem::zeroed() });
if with_recv_buffers {
recv_buffers.push(BytesMut::with_capacity(MAX_PACKET_SIZE));
}
}
Self {
socket_fd,
iovecs,
msgs,
addrs,
recv_buffers,
}
}
/// Sends `packets` to `target`, batching up to `MAX_BATCH_SIZE` datagrams
/// per `sendmmsg` call.
///
/// Returns the number of leading packets that were handed to the kernel.
/// The count can be smaller than `packets.len()`: sending stops at the first
/// chunk that makes only partial progress, and the caller is expected to
/// retry with the unsent tail.
///
/// # Errors
///
/// * `io::ErrorKind::Unsupported` if `target` is an IPv6 address.
/// * The OS error from `sendmmsg` if it fails before *any* packet of this
///   call was sent. Once at least one packet has gone out, a later failure
///   is reported as a short count instead, so the caller never re-sends
///   packets that were already transmitted. A persistent failure is expected
///   to resurface when the caller retries the remainder.
pub fn send_batch(&mut self, packets: &[Bytes], target: SocketAddr) -> io::Result<usize> {
    if packets.is_empty() {
        return Ok(0);
    }
    let target_addr = match target {
        SocketAddr::V4(addr) => {
            // All-zero bytes are a valid `sockaddr_in` (plain integers only).
            let mut sockaddr: libc::sockaddr_in = unsafe { std::mem::zeroed() };
            sockaddr.sin_family = libc::AF_INET as libc::sa_family_t;
            // Port and address are stored in network byte order.
            sockaddr.sin_port = addr.port().to_be();
            sockaddr.sin_addr.s_addr = u32::from_ne_bytes(addr.ip().octets());
            sockaddr
        }
        SocketAddr::V6(_) => {
            return Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "IPv6 not yet supported for batched I/O",
            ));
        }
    };

    let mut total_sent: usize = 0;
    for chunk in packets.chunks(MAX_BATCH_SIZE) {
        match self.send_batch_chunk(chunk, &target_addr) {
            Ok(chunk_sent) => {
                total_sent += chunk_sent;
                // A short chunk means the kernel stopped accepting packets;
                // later chunks would go out of order, so stop here.
                if chunk_sent < chunk.len() {
                    break;
                }
            }
            // Earlier chunks already went out: report that progress rather
            // than an error that would hide it (mirrors the partial-progress
            // rule inside `send_batch_chunk`).
            Err(_) if total_sent > 0 => break,
            Err(err) => return Err(err),
        }
    }
    Ok(total_sent)
}
/// Sends one chunk of at most `MAX_BATCH_SIZE` packets to `target_addr`.
///
/// Fills one slot per packet, then calls `sendmmsg` repeatedly until every
/// packet is accepted, the kernel accepts none, or an error occurs.
/// Interrupted calls are retried. An error after partial progress is
/// reported as `Ok(count_sent_so_far)`; an error before any progress is
/// returned as `Err`.
///
/// Panics if `packets.len()` exceeds the number of pre-allocated slots.
fn send_batch_chunk(
    &mut self,
    packets: &[Bytes],
    target_addr: &libc::sockaddr_in,
) -> io::Result<usize> {
    debug_assert!(packets.len() <= MAX_BATCH_SIZE);
    let total = packets.len();
    if total == 0 {
        return Ok(0);
    }

    let name_len = std::mem::size_of::<libc::sockaddr_in>() as u32;
    let slots = self.msgs[..total]
        .iter_mut()
        .zip(self.iovecs[..total].iter_mut())
        .zip(self.addrs[..total].iter_mut())
        .zip(packets);
    for (((msg, iov), addr), packet) in slots {
        // The kernel only reads through this pointer when sending; the
        // `*mut` cast just satisfies the C struct's field type.
        *iov = libc::iovec {
            iov_base: packet.as_ptr() as *mut _,
            iov_len: packet.len(),
        };
        *addr = *target_addr;
        // All-zero bytes are a valid `msghdr` (integers and null pointers).
        msg.msg_hdr = unsafe { std::mem::zeroed() };
        msg.msg_hdr.msg_name = addr as *mut libc::sockaddr_in as *mut _;
        msg.msg_hdr.msg_namelen = name_len;
        msg.msg_hdr.msg_iov = iov;
        msg.msg_hdr.msg_iovlen = 1;
        msg.msg_len = 0;
    }

    let mut done: usize = 0;
    while done < total {
        let pending = (total - done) as u32;
        // In bounds: `done < total <= self.msgs.len()`, and the `pending`
        // headers starting there were all initialised by the loop above.
        let rc = unsafe {
            libc::sendmmsg(
                self.socket_fd,
                self.msgs.as_mut_ptr().add(done),
                pending,
                0,
            )
        };
        match rc {
            failed if failed < 0 => {
                let err = io::Error::last_os_error();
                if err.kind() == io::ErrorKind::Interrupted {
                    continue;
                }
                // Prefer reporting progress over the error once anything
                // has been sent.
                return if done > 0 { Ok(done) } else { Err(err) };
            }
            0 => break,
            accepted => done += accepted as usize,
        }
    }
    Ok(done)
}
/// Receives up to `max_count` datagrams without blocking.
///
/// `max_count` is clamped to `MAX_BATCH_SIZE`. Returns each payload with
/// the IPv4 source address the kernel reported for it. When no datagram is
/// ready (`WouldBlock` from `recvmmsg`) the result is an empty vector, not
/// an error.
///
/// # Errors
///
/// * `io::ErrorKind::Unsupported` if the transport was built with
///   `new_send_only` and `max_count` is non-zero.
/// * Any other OS error reported by `recvmmsg`.
pub fn recv_batch(&mut self, max_count: usize) -> io::Result<Vec<(Bytes, SocketAddr)>> {
    let received = match self.recvmmsg_into_slots(max_count, libc::MSG_DONTWAIT) {
        Ok(received) => received,
        Err(err) if err.kind() == io::ErrorKind::WouldBlock => return Ok(Vec::new()),
        Err(err) => return Err(err),
    };
    self.take_received(received)
}

/// Receives up to `max_count` datagrams, passing no flags to `recvmmsg`
/// (so the call blocks according to the socket's own mode and timeouts).
///
/// `max_count` is clamped to `MAX_BATCH_SIZE`.
///
/// # Errors
///
/// * `io::ErrorKind::Unsupported` if the transport was built with
///   `new_send_only` and `max_count` is non-zero.
/// * Any OS error reported by `recvmmsg`, including `WouldBlock` and
///   `Interrupted`, which are passed through unchanged.
#[allow(dead_code)]
pub fn recv_batch_blocking(
    &mut self,
    max_count: usize,
) -> io::Result<Vec<(Bytes, SocketAddr)>> {
    let received = self.recvmmsg_into_slots(max_count, 0)?;
    self.take_received(received)
}

/// Prepares `min(max_count, MAX_BATCH_SIZE)` receive slots and runs one
/// `recvmmsg` call with `flags`, returning how many slots were filled.
///
/// Shared by `recv_batch` and `recv_batch_blocking`, which differ only in
/// the flags they pass and in how they treat `WouldBlock`.
fn recvmmsg_into_slots(&mut self, max_count: usize, flags: libc::c_int) -> io::Result<usize> {
    let count = max_count.min(MAX_BATCH_SIZE);
    if count == 0 {
        return Ok(0);
    }
    if self.recv_buffers.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "BatchedTransport constructed via `new_send_only` cannot \
             receive packets — use `new` if recv is needed",
        ));
    }

    let name_len = std::mem::size_of::<libc::sockaddr_in>() as u32;
    let slots = self.msgs[..count]
        .iter_mut()
        .zip(self.iovecs[..count].iter_mut())
        .zip(self.addrs[..count].iter_mut())
        .zip(self.recv_buffers[..count].iter_mut());
    for (((msg, iov), addr), buf) in slots {
        buf.clear();
        buf.reserve(MAX_PACKET_SIZE);
        // Claim the reserved capacity without zero-filling it: the kernel
        // overwrites the bytes, and `take_received` truncates each buffer to
        // the received length before handing it out. A unit test in this
        // file asserts that this exact call stays in the source.
        unsafe {
            buf.set_len(MAX_PACKET_SIZE);
        }
        *iov = libc::iovec {
            iov_base: buf.as_mut_ptr() as *mut _,
            iov_len: MAX_PACKET_SIZE,
        };
        // All-zero bytes are valid for both plain C structs below.
        *addr = unsafe { std::mem::zeroed() };
        msg.msg_hdr = unsafe { std::mem::zeroed() };
        msg.msg_hdr.msg_name = addr as *mut libc::sockaddr_in as *mut _;
        msg.msg_hdr.msg_namelen = name_len;
        msg.msg_hdr.msg_iov = iov;
        msg.msg_hdr.msg_iovlen = 1;
        msg.msg_len = 0;
    }

    // The first `count` headers were initialised above and point at live
    // buffers owned by `self`; a null timeout means "no timeout".
    let received = unsafe {
        libc::recvmmsg(
            self.socket_fd,
            self.msgs.as_mut_ptr(),
            count as u32,
            flags as _,
            std::ptr::null_mut(),
        )
    };
    if received < 0 {
        return Err(io::Error::last_os_error());
    }
    Ok(received as usize)
}

/// Moves the first `received` filled slots out as `(payload, source)`
/// pairs, leaving a fresh `MAX_PACKET_SIZE`-capacity buffer in each slot.
fn take_received(&mut self, received: usize) -> io::Result<Vec<(Bytes, SocketAddr)>> {
    let mut results = Vec::with_capacity(received);
    for i in 0..received {
        let len = self.msgs[i].msg_len as usize;
        let mut buffer = std::mem::replace(
            &mut self.recv_buffers[i],
            BytesMut::with_capacity(MAX_PACKET_SIZE),
        );
        // Keep only the bytes the kernel actually wrote.
        buffer.truncate(len);
        let addr = sockaddr_to_socket_addr(&self.addrs[i])?;
        results.push((buffer.freeze(), addr));
    }
    Ok(results)
}
}
/// Prints only the descriptor and the batch limit; the per-slot scratch
/// arrays are omitted from the output.
impl std::fmt::Debug for BatchedTransport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("BatchedTransport");
        out.field("socket_fd", &self.socket_fd);
        out.field("max_batch_size", &MAX_BATCH_SIZE);
        out.finish()
    }
}
/// Converts a kernel-filled `sockaddr_in` into a `SocketAddr`.
///
/// Both `sin_addr.s_addr` and `sin_port` are read as network-byte-order
/// values. The function currently always returns `Ok`.
/// NOTE(review): `sin_family` is not inspected, so a non-IPv4 address written
/// into this struct would be misread as IPv4 — confirm callers only use
/// AF_INET sockets.
fn sockaddr_to_socket_addr(addr: &libc::sockaddr_in) -> io::Result<SocketAddr> {
    // The in-memory bytes of `s_addr` are already the four octets in order.
    let octets = addr.sin_addr.s_addr.to_ne_bytes();
    let port = u16::from_be(addr.sin_port);
    let v4 = std::net::SocketAddrV4::new(std::net::Ipv4Addr::from(octets), port);
    Ok(SocketAddr::V4(v4))
}
/// Applies throughput-oriented socket options to `fd`.
///
/// Requests 64 MiB receive and send buffers (`SO_RCVBUF`, `SO_SNDBUF`); a
/// failure of either is returned to the caller. Then makes two best-effort
/// requests whose results are ignored: `SO_BUSY_POLL` with a value of 50 and
/// `IP_MTU_DISCOVER` set to `IP_PMTUDISC_DO`.
///
/// # Errors
///
/// Returns the OS error if setting `SO_RCVBUF` or `SO_SNDBUF` fails.
pub fn configure_socket_for_throughput(fd: RawFd) -> io::Result<()> {
    /// Sets one `i32`-valued socket option, mapping a negative return code
    /// to the current OS error.
    fn set_int_option(
        fd: RawFd,
        level: libc::c_int,
        name: libc::c_int,
        value: i32,
    ) -> io::Result<()> {
        // The pointer refers to a live local `i32` and the length passed
        // is exactly its size.
        let rc = unsafe {
            libc::setsockopt(
                fd,
                level,
                name,
                &value as *const i32 as *const libc::c_void,
                std::mem::size_of::<i32>() as u32,
            )
        };
        if rc < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(())
        }
    }

    const SOCKET_BUFFER_BYTES: i32 = 64 * 1024 * 1024;
    const BUSY_POLL_VALUE: i32 = 50;

    // Required: surface the error if either buffer size cannot be set.
    set_int_option(fd, libc::SOL_SOCKET, libc::SO_RCVBUF, SOCKET_BUFFER_BYTES)?;
    set_int_option(fd, libc::SOL_SOCKET, libc::SO_SNDBUF, SOCKET_BUFFER_BYTES)?;

    // Best effort: failures here are deliberately ignored.
    let _ = set_int_option(fd, libc::SOL_SOCKET, libc::SO_BUSY_POLL, BUSY_POLL_VALUE);
    let _ = set_int_option(
        fd,
        libc::IPPROTO_IP,
        libc::IP_MTU_DISCOVER,
        libc::IP_PMTUDISC_DO,
    );
    Ok(())
}
/// Turns on the `SO_TIMESTAMPNS` socket option for `fd`.
///
/// # Errors
///
/// Returns the OS error if `setsockopt` reports failure.
#[allow(dead_code)]
pub fn enable_timestamps(fd: RawFd) -> io::Result<()> {
    let flag: i32 = 1;
    // The pointer refers to a live local `i32` and the length passed is
    // exactly its size.
    let rc = unsafe {
        libc::setsockopt(
            fd,
            libc::SOL_SOCKET,
            libc::SO_TIMESTAMPNS,
            &flag as *const i32 as *const libc::c_void,
            std::mem::size_of::<i32>() as u32,
        )
    };
    if rc < 0 {
        return Err(io::Error::last_os_error());
    }
    Ok(())
}
// Unit tests. Apart from the first one, they exercise the transport against
// real UDP sockets bound to ephemeral ports on 127.0.0.1.
#[cfg(test)]
mod tests {
use super::*;
use std::net::UdpSocket;
use std::os::unix::io::AsRawFd;
// Source-level guard: scans this file's own text to make sure the receive
// path claims buffer capacity via `set_len` and never zero-fills it first.
#[test]
fn batched_recv_must_use_set_len_not_resize_zero() {
let src = include_str!("linux.rs");
// Drop whole-line comments so prose that mentions either needle cannot
// affect the result. Trailing comments on code lines are NOT removed.
let src_no_comments: String = src
.lines()
.filter(|l| !l.trim_start().starts_with("//"))
.collect::<Vec<_>>()
.join("\n");
// Both needles are assembled with `format!` so that their literal text does
// not appear in this test's own source and match itself.
let bad_needle = format!("resize({}, 0)", "MAX_PACKET_SIZE");
assert!(
!src_no_comments.contains(&bad_needle),
"regression: recvmmsg batched recv must NOT pre-zero \
slot buffers per crypto-session perf #131A; pre-fix \
this memset ~512 KiB per batch call only for the \
kernel to overwrite the bytes immediately."
);
let good_needle = format!("{}({})", "set_len", "MAX_PACKET_SIZE");
assert!(
src_no_comments.contains(&good_needle),
"regression: batched recv setup must claim slot \
capacity via set_len so recvmmsg writes the kernel-\
supplied bytes without a pre-zero pass."
);
}
// A freshly built transport has exactly MAX_BATCH_SIZE iovec and header slots.
#[test]
fn test_batched_transport_creation() {
let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
let fd = socket.as_raw_fd();
let transport = BatchedTransport::new(fd);
assert!(transport.iovecs.len() == MAX_BATCH_SIZE);
assert!(transport.msgs.len() == MAX_BATCH_SIZE);
}
// Round trip: three packets sent from socket2 arrive on socket1 in order,
// each tagged with socket2's address as the source.
#[test]
fn test_send_recv_batch() {
let socket1 = UdpSocket::bind("127.0.0.1:0").unwrap();
let socket2 = UdpSocket::bind("127.0.0.1:0").unwrap();
socket1.set_nonblocking(true).unwrap();
socket2.set_nonblocking(true).unwrap();
let addr1 = socket1.local_addr().unwrap();
let addr2 = socket2.local_addr().unwrap();
let mut transport1 = BatchedTransport::new(socket1.as_raw_fd());
let mut transport2 = BatchedTransport::new(socket2.as_raw_fd());
let packets = vec![
Bytes::from_static(b"packet1"),
Bytes::from_static(b"packet2"),
Bytes::from_static(b"packet3"),
];
let sent = transport2.send_batch(&packets, addr1).unwrap();
assert_eq!(sent, 3);
// Give loopback delivery a moment before the non-blocking receive.
std::thread::sleep(std::time::Duration::from_millis(10));
let received = transport1.recv_batch(10).unwrap();
assert_eq!(received.len(), 3);
assert_eq!(&received[0].0[..], b"packet1");
assert_eq!(&received[1].0[..], b"packet2");
assert_eq!(&received[2].0[..], b"packet3");
for (_, source) in &received {
assert_eq!(*source, addr2);
}
}
// Smoke test: configuring a fresh UDP socket succeeds.
#[test]
fn test_configure_socket() {
let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
let fd = socket.as_raw_fd();
configure_socket_for_throughput(fd).unwrap();
}
// Both receive entry points reject a send-only transport with `Unsupported`,
// and a zero-count receive on a full transport returns an empty result.
#[test]
fn recv_batch_returns_unsupported_for_send_only_transport() {
let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
let fd = socket.as_raw_fd();
let mut transport = BatchedTransport::new_send_only(fd);
let err = transport
.recv_batch(8)
.expect_err("send-only recv must surface Unsupported, not panic");
assert_eq!(err.kind(), io::ErrorKind::Unsupported);
let err_blocking = transport
.recv_batch_blocking(8)
.expect_err("send-only recv_batch_blocking must also surface Unsupported");
assert_eq!(err_blocking.kind(), io::ErrorKind::Unsupported);
let mut recv_transport = BatchedTransport::new(fd);
let zero = recv_transport.recv_batch(0).unwrap();
assert!(zero.is_empty());
}
// An empty packet slice returns Ok(0).
#[test]
fn send_batch_with_empty_input_returns_ok_zero() {
let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
let mut transport = BatchedTransport::new_send_only(socket.as_raw_fd());
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let sent = transport.send_batch(&[], target).unwrap();
assert_eq!(sent, 0, "empty input must short-circuit to Ok(0)");
}
// IPv6 targets are rejected with `Unsupported`.
#[test]
fn send_batch_rejects_ipv6_target_with_unsupported_kind() {
let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
let mut transport = BatchedTransport::new_send_only(socket.as_raw_fd());
let target: SocketAddr = "[::1]:9999".parse().unwrap();
let packets = vec![Bytes::from_static(b"x")];
let err = transport
.send_batch(&packets, target)
.expect_err("IPv6 target must surface Unsupported");
assert_eq!(err.kind(), io::ErrorKind::Unsupported);
}
// One packet more than MAX_BATCH_SIZE forces a second chunk; the reported
// count must exceed MAX_BATCH_SIZE to prove that second chunk was sent.
#[test]
fn send_batch_chunks_inputs_larger_than_max_batch_size() {
let socket1 = UdpSocket::bind("127.0.0.1:0").unwrap();
let socket2 = UdpSocket::bind("127.0.0.1:0").unwrap();
socket1.set_nonblocking(true).unwrap();
socket2.set_nonblocking(true).unwrap();
let addr1 = socket1.local_addr().unwrap();
let mut transport2 = BatchedTransport::new_send_only(socket2.as_raw_fd());
let total = MAX_BATCH_SIZE + 1;
let packets: Vec<Bytes> = (0..total)
.map(|i| Bytes::copy_from_slice(format!("chunk-{i:03}").as_bytes()))
.collect();
let sent = transport2
.send_batch(&packets, addr1)
.expect("chunked send_batch");
assert!(
sent > MAX_BATCH_SIZE,
"send_batch with {total} packets reported only {sent}; \
chunking past MAX_BATCH_SIZE = {MAX_BATCH_SIZE} did not run"
);
}
// A single transport alternates between two destinations; each receiver must
// see only its own packets, so a slot's previous target is not reused.
#[test]
fn send_batch_reuses_one_transport_across_targets() {
let recv_a = UdpSocket::bind("127.0.0.1:0").unwrap();
let recv_b = UdpSocket::bind("127.0.0.1:0").unwrap();
let send_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
let timeout = std::time::Duration::from_secs(2);
recv_a.set_read_timeout(Some(timeout)).unwrap();
recv_b.set_read_timeout(Some(timeout)).unwrap();
let addr_a = recv_a.local_addr().unwrap();
let addr_b = recv_b.local_addr().unwrap();
let mut transport = BatchedTransport::new_send_only(send_sock.as_raw_fd());
assert_eq!(
transport
.send_batch(
&[Bytes::from_static(b"a1"), Bytes::from_static(b"a2")],
addr_a
)
.unwrap(),
2
);
assert_eq!(
transport
.send_batch(&[Bytes::from_static(b"b1")], addr_b)
.unwrap(),
1
);
assert_eq!(
transport
.send_batch(&[Bytes::from_static(b"a3")], addr_a)
.unwrap(),
1
);
let mut buf = [0u8; 16];
// Reads one datagram with the plain std API and returns its payload.
let recv = |s: &UdpSocket, b: &mut [u8]| {
let n = s.recv(b).expect("recv within timeout");
b[..n].to_vec()
};
assert_eq!(&recv(&recv_a, &mut buf)[..], &b"a1"[..]);
assert_eq!(&recv(&recv_a, &mut buf)[..], &b"a2"[..]);
assert_eq!(&recv(&recv_a, &mut buf)[..], &b"a3"[..]);
assert_eq!(&recv(&recv_b, &mut buf)[..], &b"b1"[..]);
}
// The blocking receive returns at least one of three loopback datagrams; the
// read timeout keeps the test from hanging if none arrive.
#[test]
fn recv_batch_blocking_delivers_loopback_packets() {
let recv_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
let send_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
let recv_addr = recv_sock.local_addr().unwrap();
recv_sock
.set_read_timeout(Some(std::time::Duration::from_secs(2)))
.unwrap();
let mut transport = BatchedTransport::new(recv_sock.as_raw_fd());
for i in 0u8..3 {
send_sock
.send_to(&[0xCC, i, 0xDD], recv_addr)
.expect("send loopback");
}
let received = transport
.recv_batch_blocking(8)
.expect("recv_batch_blocking");
assert!(
!received.is_empty(),
"recv_batch_blocking returned 0 packets after 3 loopback sends"
);
}
// Smoke test: enabling timestamps on a fresh UDP socket succeeds.
#[test]
fn enable_timestamps_succeeds_on_fresh_socket() {
let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
enable_timestamps(socket.as_raw_fd())
.expect("SO_TIMESTAMPNS must accept on a fresh DGRAM socket");
}
}