use nix::sys::socket::{MsgFlags, MultiHeaders, SockaddrStorage};
#[cfg(target_os = "linux")]
use nix::sys::socket::{setsockopt, sockopt};
use std::{
io::{self, IoSlice},
net::SocketAddr,
os::fd::AsRawFd,
};
use tokio::io::Interest;
use crate::{
packet::Packet,
udp::{UdpSend, check_send_max_number_of_packets, socket::UdpSocket},
};
/// Upper bound on the number of packets handled by one batched call.
///
/// It sizes the per-call iovec arrays used for `sendmmsg`/`recvmmsg`, is the limit
/// passed to `check_send_max_number_of_packets`, and is the value reported by
/// `max_number_of_packets_to_send`.
const MAX_PACKET_COUNT: usize = 100;
/// Scratch state reused between `send_many_to` calls.
#[derive(Default)]
pub struct SendmmsgBuf {
// Destination address of every queued packet; cleared and refilled on each
// `send_many_to` call, then handed to `sendmmsg` as the address list.
targets: Vec<Option<SockaddrStorage>>,
}
impl UdpSend for super::UdpSocket {
type SendManyBuf = SendmmsgBuf;
async fn send_to(&self, packet: Packet, target: SocketAddr) -> io::Result<()> {
tokio::net::UdpSocket::send_to(&self.inner, &packet, target).await?;
Ok(())
}
async fn send_many_to(
&self,
buf: &mut SendmmsgBuf,
packets: &mut Vec<(Packet, SocketAddr)>,
) -> io::Result<()> {
check_send_max_number_of_packets(MAX_PACKET_COUNT, packets)?;
let fd = self.inner.as_raw_fd();
buf.targets.clear();
let mut packets_buf = [[IoSlice::new(&[])]; MAX_PACKET_COUNT];
for ((packet, target), packets_buf) in packets.iter().zip(&mut packets_buf) {
buf.targets.push(Some(SockaddrStorage::from(*target)));
*packets_buf = [IoSlice::new(&packet[..])];
}
let len = buf.targets.len();
let pkts = &packets_buf[..len];
let mut packet_buf_start = 0;
while packet_buf_start < len {
let result = self
.inner
.async_io(Interest::WRITABLE, || {
let mut multiheaders =
MultiHeaders::preallocate(pkts[packet_buf_start..].len(), None);
let multiresult = nix::sys::socket::sendmmsg(
fd,
&mut multiheaders,
&pkts[packet_buf_start..],
&buf.targets[packet_buf_start..],
[],
MsgFlags::MSG_DONTWAIT,
)?;
let n = multiresult.count();
Ok(n)
})
.await;
let n = result?;
packet_buf_start += n;
}
debug_assert!(packet_buf_start == len, "all packets should be sent");
packets.clear();
Ok(())
}
fn max_number_of_packets_to_send(&self) -> usize {
MAX_PACKET_COUNT
}
fn local_addr(&self) -> io::Result<Option<SocketAddr>> {
UdpSocket::local_addr(self).map(Some)
}
#[cfg(target_os = "linux")]
fn set_fwmark(&self, mark: u32) -> io::Result<()> {
setsockopt(&self.inner, sockopt::Mark, &mark)?;
Ok(())
}
}
#[cfg(target_os = "linux")]
mod gro {
// NOTE(review): presumably the largest number of coalesced segments expected in a
// single GRO read — confirm against the kernel limit.
const MAX_SEGMENTS: usize = 64;
/// Size in bytes of each receive buffer in `RecvManyBuf`
/// (`MAX_SEGMENTS` segments of 4096 bytes each).
const MAX_GRO_SIZE: usize = MAX_SEGMENTS * 4096;
use super::MAX_PACKET_COUNT;
use crate::packet::{Packet, PacketBufPool};
use crate::udp::{UdpRecv, socket::UdpSocket};
use bytes::BytesMut;
use nix::cmsg_space;
use nix::sys::socket::{ControlMessageOwned, MsgFlags, MultiHeaders, SockaddrStorage};
use std::io::{self, IoSliceMut};
use std::net::SocketAddr;
use std::os::fd::AsRawFd;
use tokio::io::Interest;
/// Reusable receive buffers for `recv_many_from`: one buffer per `recvmmsg` message slot.
pub struct RecvManyBuf {
// Each buffer is handed to `recvmmsg` as the single iovec of one message.
pub(crate) gro_bufs: Box<[BytesMut; MAX_PACKET_COUNT]>,
}
// SAFETY: NOTE(review): no justification is given in the source. The only field is
// `Box<[BytesMut; MAX_PACKET_COUNT]>`; if `BytesMut` is already `Send` this manual
// impl is redundant and could be removed — confirm.
unsafe impl Send for RecvManyBuf {}
impl Default for RecvManyBuf {
fn default() -> Self {
let mut gro_buf = BytesMut::zeroed(MAX_PACKET_COUNT * MAX_GRO_SIZE);
let gro_bufs = [(); MAX_PACKET_COUNT];
let gro_bufs = gro_bufs.map(|()| gro_buf.split_to(MAX_GRO_SIZE));
let gro_bufs = Box::new(gro_bufs);
Self { gro_bufs }
}
}
impl UdpRecv for UdpSocket {
    type RecvManyBuf = RecvManyBuf;

    /// Receive a single datagram into a buffer taken from `pool`.
    ///
    /// Returns the packet, truncated to the number of bytes received, together
    /// with the sender's address.
    async fn recv_from(
        &mut self,
        pool: &mut PacketBufPool,
    ) -> io::Result<(Packet, SocketAddr)> {
        let mut buf = pool.get();
        let (n, src) = self.inner.recv_from(&mut buf).await?;
        buf.truncate(n);
        Ok((buf, src))
    }

    /// Receive up to `MAX_PACKET_COUNT` messages with one `recvmmsg` call and
    /// append the resulting packets to `packets`.
    ///
    /// * `recv_many_bufs` - scratch buffers the kernel writes into; one per message.
    /// * `pool` - source of the packet buffers that are pushed to `packets`.
    /// * `packets` - output; every received packet is appended with its source address.
    ///
    /// If a message carries a `UdpGroSegments` control message with a positive
    /// size, its payload is split into chunks of that size and each chunk becomes
    /// its own packet. Otherwise the message becomes a single packet. Messages
    /// whose source address is neither IPv4 nor IPv6 are skipped.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by `recvmmsg`.
    async fn recv_many_from(
        &mut self,
        recv_many_bufs: &mut Self::RecvManyBuf,
        pool: &mut PacketBufPool,
        packets: &mut Vec<(Packet, SocketAddr)>,
    ) -> io::Result<()> {
        let fd = self.inner.as_raw_fd();

        self.inner
            .async_io(Interest::READABLE, move || {
                // Control-message space for one `i32` per message: the GRO segment size.
                let headers = &mut MultiHeaders::<SockaddrStorage>::preallocate(
                    MAX_PACKET_COUNT,
                    Some(cmsg_space!(i32)),
                );

                let mut io_slices: [[IoSliceMut; 1]; MAX_PACKET_COUNT] =
                    std::array::from_fn(|_| [IoSliceMut::new(&mut [])]);
                for (i, buf) in recv_many_bufs.gro_bufs.iter_mut().enumerate() {
                    io_slices[i] = [IoSliceMut::new(&mut buf[..])];
                }

                let results = nix::sys::socket::recvmmsg(
                    fd,
                    headers,
                    &mut io_slices[..MAX_PACKET_COUNT],
                    MsgFlags::MSG_DONTWAIT,
                    None,
                )?;

                for result in results {
                    // `iovs()` only yields non-empty filled slices, so a zero-length
                    // datagram produces no item at all. Treat that as an empty
                    // payload instead of panicking on input any peer can send.
                    let iov = result.iovs().next().unwrap_or(&[]);

                    let Some(source_addr) = result.address.as_ref().and_then(|addr| {
                        addr.as_sockaddr_in()
                            .map(|addr| (*addr).into())
                            .or_else(|| addr.as_sockaddr_in6().map(|addr| (*addr).into()))
                    }) else {
                        if cfg!(debug_assertions) {
                            log::debug!("recvmmsg returned packet without source");
                        }
                        continue;
                    };

                    let mut cmsgs = result.cmsgs().expect("we have allocated enough memory");
                    let gro_size = cmsgs
                        .find_map(|cmsg| match cmsg {
                            ControlMessageOwned::UdpGroSegments(gro_size) => Some(gro_size),
                            _ => None,
                        })
                        .and_then(|gro_size| usize::try_from(gro_size).ok())
                        // `chunks(0)` would panic, so a non-positive size means "no GRO".
                        .filter(|&gro_size| gro_size > 0);

                    if let Some(gro_size) = gro_size {
                        // The kernel coalesced several datagrams; split them apart again.
                        for gro_segment in iov.chunks(gro_size) {
                            packets.push((copy_to_packet(pool, gro_segment), source_addr));
                        }
                    } else {
                        let size = result.bytes.min(iov.len());
                        packets.push((copy_to_packet(pool, &iov[..size]), source_addr));
                    }
                }

                Ok(())
            })
            .await?;

        Ok(())
    }

    /// Turn on the `UdpGroSegment` socket option for this socket.
    fn enable_udp_gro(&self) -> io::Result<()> {
        use std::os::fd::AsFd;
        nix::sys::socket::setsockopt(
            &self.inner.as_fd(),
            nix::sys::socket::sockopt::UdpGroSegment,
            &true,
        )?;
        Ok(())
    }
}

/// Copy `data` into a fresh buffer from `pool` and return it truncated to the copied length.
///
/// The receive buffers are far larger than a pool buffer, so `data` may not fit.
/// The excess is dropped rather than panicking on an out-of-range slice, in the
/// same way a single `recv_from` can only fill the buffer it was given.
fn copy_to_packet(pool: &mut PacketBufPool, data: &[u8]) -> Packet {
    let mut buf = pool.get();
    let len = data.len().min(buf.len());
    buf[..len].copy_from_slice(&data[..len]);
    buf.truncate(len);
    buf
}
#[cfg(test)]
mod tests {
    use super::UdpSocket;
    use crate::packet::PacketBufPool;
    use crate::udp::{UdpRecv, UdpSend};
    use std::net::Ipv6Addr;
    use std::time::Duration;

    /// A datagram sent over IPv6 loopback must be reported with the sender's address.
    #[tokio::test]
    async fn recv_many_from_preserves_ipv6_source_addr() {
        let mut rx_socket = UdpSocket::bind((Ipv6Addr::LOCALHOST, 0).into()).unwrap();
        // The result of enabling GRO is deliberately ignored.
        rx_socket.enable_udp_gro().ok();
        let tx_socket = UdpSocket::bind((Ipv6Addr::LOCALHOST, 0).into()).unwrap();

        let rx_addr = rx_socket.local_addr().unwrap();
        let tx_addr = tx_socket.local_addr().unwrap();

        // Send a five-byte payload from the sender to the receiver.
        let tx_pool = PacketBufPool::<4096>::new(1);
        let mut payload = tx_pool.get();
        payload.truncate(5);
        payload.copy_from_slice(b"hello");
        tx_socket.send_to(payload, rx_addr).await.unwrap();

        let mut rx_pool = PacketBufPool::<4096>::new(1);
        let mut rx_bufs = <UdpSocket as UdpRecv>::RecvManyBuf::default();
        let mut received = Vec::new();

        // Bound the wait so a lost datagram fails the test instead of hanging it.
        let recv = rx_socket.recv_many_from(&mut rx_bufs, &mut rx_pool, &mut received);
        tokio::time::timeout(Duration::from_secs(1), recv)
            .await
            .unwrap()
            .unwrap();

        assert!(!received.is_empty(), "expected at least one IPv6 packet");
        assert_eq!(received[0].1, tx_addr);
    }
}
}
#[cfg(target_os = "android")]
mod android {
use crate::packet::{Packet, PacketBufPool};
use crate::udp::UdpRecv;
use std::io;
use std::net::SocketAddr;
impl UdpRecv for super::UdpSocket {
    /// This implementation needs no scratch state for batched receives.
    type RecvManyBuf = ();

    /// Receive a single datagram into a buffer taken from `pool`.
    ///
    /// Returns the packet, truncated to the number of bytes received, together
    /// with the sender's address.
    async fn recv_from(
        &mut self,
        pool: &mut PacketBufPool,
    ) -> io::Result<(Packet, SocketAddr)> {
        let mut packet = pool.get();
        let (received_len, sender) = self.inner.recv_from(&mut packet).await?;
        packet.truncate(received_len);
        Ok((packet, sender))
    }
}
}