use crate::sys::{self, Fd};
use std::io;
use std::net::SocketAddr;
#[cfg(unix)]
use std::os::fd::AsRawFd;
#[cfg(all(unix, not(feature = "tokio")))]
use std::os::fd::IntoRawFd;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;
#[cfg(all(windows, not(feature = "tokio")))]
use std::os::windows::io::IntoRawSocket;
use crate::batch::{RecvBatchRaw, SendBatchRaw};
#[cfg(feature = "metrics")]
use crate::metrics::BingerMetrics;
use crate::sockaddr;
/// Construction options for [`BingerUdp::from_std`], built with the `with_*` methods.
#[derive(Debug, Clone)]
pub struct Config {
    /// Initial batch size; `from_std` seeds the adaptive controllers with it
    /// (they clamp it to `1..=1024`).
    pub(crate) batch_size: usize,
    /// `SO_SNDBUF` value to request, or `None` to keep the OS default.
    pub(crate) send_buf_size: Option<usize>,
    /// `SO_RCVBUF` value to request, or `None` to keep the OS default.
    pub(crate) recv_os_buf_size: Option<usize>,
    /// Whether `from_std` creates send/receive adaptive batch-size controllers.
    pub(crate) adaptive_batching: bool,
    /// Whether `from_std` attaches a `BingerMetrics` collector.
    #[cfg(feature = "metrics")]
    pub(crate) metrics_enabled: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
batch_size: 32,
send_buf_size: None,
recv_os_buf_size: None,
adaptive_batching: false,
#[cfg(feature = "metrics")]
metrics_enabled: false,
}
}
}
impl Config {
    /// Returns the default configuration (same as `Config::default()`).
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets `batch_size`, the value used to seed adaptive batching.
    #[must_use]
    pub fn with_batch_size(self, n: usize) -> Self {
        Self {
            batch_size: n,
            ..self
        }
    }

    /// Requests an OS send buffer (`SO_SNDBUF`) of `n` bytes.
    #[must_use]
    pub fn with_send_buf_size(self, n: usize) -> Self {
        Self {
            send_buf_size: Some(n),
            ..self
        }
    }

    /// Requests an OS receive buffer (`SO_RCVBUF`) of `n` bytes.
    #[must_use]
    pub fn with_recv_os_buf_size(self, n: usize) -> Self {
        Self {
            recv_os_buf_size: Some(n),
            ..self
        }
    }

    /// Turns adaptive batch sizing on or off.
    #[must_use]
    pub fn with_adaptive_batching(self, enabled: bool) -> Self {
        Self {
            adaptive_batching: enabled,
            ..self
        }
    }

    /// Turns the metrics collector on or off.
    #[cfg(feature = "metrics")]
    #[must_use]
    pub fn with_metrics(self, enabled: bool) -> Self {
        Self {
            metrics_enabled: enabled,
            ..self
        }
    }
}
/// Compile-time summary of the batching features available in this build.
///
/// Produced by [`platform_capabilities`]. Every flag reflects the build target and
/// the enabled Cargo features; none of them is a runtime probe of the kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// One independent flag per capability, mirroring the OS feature list.
#[allow(clippy::struct_excessive_bools)]
pub struct PlatformCaps {
    /// Batched sends via `sendmmsg` (Linux builds).
    pub supports_sendmmsg: bool,
    /// Batched receives via `recvmmsg` (Linux builds).
    pub supports_recvmmsg: bool,
    /// Batched sends via `sendmsg_x` (macOS builds only).
    #[cfg(target_os = "macos")]
    pub supports_sendmsg_x: bool,
    /// Batched receives via `recvmsg_x` (macOS builds only).
    #[cfg(target_os = "macos")]
    pub supports_recvmsg_x: bool,
    /// `WSASendMsg` availability (Windows builds only).
    #[cfg(target_os = "windows")]
    pub supports_wsa_send_msg: bool,
    /// `WSARecvMsg` availability (Windows builds only).
    #[cfg(target_os = "windows")]
    pub supports_wsa_recv_msg: bool,
    /// UDP segmentation offload compiled in (Linux + `gso` feature).
    pub supports_gso: bool,
    /// UDP receive offload compiled in (Linux + `gro` feature).
    pub supports_gro: bool,
    /// `SO_BUSY_POLL` support compiled in (Linux + `busy-poll` feature).
    pub supports_busy_poll: bool,
    /// `SO_MAX_PACING_RATE` support compiled in (Linux + `pacing` feature).
    pub supports_pacing: bool,
    /// Receive timestamping compiled in (Linux + `timestamping` feature).
    #[cfg(feature = "timestamping")]
    pub supports_timestamping: bool,
    /// Packet-info ancillary data compiled in (Linux + `pktinfo` feature).
    #[cfg(feature = "pktinfo")]
    pub supports_pktinfo: bool,
    /// Maximum batch size reported for this target (1024 on Linux, 32 elsewhere).
    pub max_batch_size: usize,
    /// Human-readable description of the batching backend.
    pub backend_name: &'static str,
}
/// Reports what the current build target and enabled features support.
///
/// Every value is fixed at compile time through `cfg!`; nothing is probed at
/// runtime, so e.g. `supports_gso` may be `true` on a kernel that lacks it.
#[must_use]
pub fn platform_capabilities() -> PlatformCaps {
    let on_linux = cfg!(target_os = "linux");
    let max_batch_size = if on_linux { 1024 } else { 32 };

    PlatformCaps {
        supports_sendmmsg: on_linux,
        supports_recvmmsg: on_linux,
        #[cfg(target_os = "macos")]
        supports_sendmsg_x: true,
        #[cfg(target_os = "macos")]
        supports_recvmsg_x: true,
        #[cfg(target_os = "windows")]
        supports_wsa_send_msg: true,
        #[cfg(target_os = "windows")]
        supports_wsa_recv_msg: true,
        supports_gso: on_linux && cfg!(feature = "gso"),
        supports_gro: on_linux && cfg!(feature = "gro"),
        supports_busy_poll: on_linux && cfg!(feature = "busy-poll"),
        supports_pacing: on_linux && cfg!(feature = "pacing"),
        #[cfg(feature = "timestamping")]
        supports_timestamping: on_linux && cfg!(feature = "timestamping"),
        #[cfg(feature = "pktinfo")]
        supports_pktinfo: on_linux && cfg!(feature = "pktinfo"),
        max_batch_size,
        backend_name: backends(),
    }
}
/// Human-readable name of the batching backend compiled for this target.
///
/// Evaluated entirely at compile time; targets other than Linux, macOS and
/// Windows report the per-datagram fallback.
const fn backends() -> &'static str {
    if cfg!(target_os = "linux") {
        "sendmmsg/recvmmsg (Linux)"
    } else if cfg!(target_os = "macos") {
        "sendmsg_x/recvmsg_x (macOS)"
    } else if cfg!(target_os = "windows") {
        "WSASendMsg/WSARecvMsg (Windows)"
    } else {
        "fallback (loop sendto/recvfrom)"
    }
}
/// Batch-size controller driven by the observed `WouldBlock` rate.
///
/// One instance is used for the send direction and one for the receive direction.
struct AdaptiveState {
    /// Batch size currently recommended.
    target_size: usize,
    /// `WouldBlock` outcomes recorded in the current window.
    would_block_count: u64,
    /// All recorded outcomes in the current window (successes plus `WouldBlock`s).
    /// Despite the name it is also used by the receive-side controller.
    total_send_count: u64,
    /// Start of the current adjustment window.
    last_adjustment: std::time::Instant,
}
impl AdaptiveState {
    /// Smallest batch size the controller will recommend.
    const MIN_BATCH: usize = 1;
    /// Largest batch size the controller will recommend.
    const MAX_BATCH: usize = 1024;
    /// Minimum time between two adjustments of `target_size`.
    const ADJUSTMENT_INTERVAL: std::time::Duration = std::time::Duration::from_millis(100);

    /// Creates a controller targeting `initial` (clamped to `MIN_BATCH..=MAX_BATCH`),
    /// with empty counters and a window that starts now.
    fn new(initial: usize) -> Self {
        Self {
            target_size: initial.clamp(Self::MIN_BATCH, Self::MAX_BATCH),
            would_block_count: 0,
            total_send_count: 0,
            last_adjustment: std::time::Instant::now(),
        }
    }

    /// Records an operation that ended in `WouldBlock`; it also counts toward the total.
    fn record_would_block(&mut self) {
        self.would_block_count += 1;
        self.total_send_count += 1;
    }

    /// Records an operation that completed without `WouldBlock`.
    fn record_event(&mut self) {
        self.total_send_count += 1;
    }

    /// Re-evaluates `target_size` from the `WouldBlock` rate of the current window.
    ///
    /// Does nothing until `ADJUSTMENT_INTERVAL` has elapsed and at least one outcome
    /// was recorded. A rate above 30% halves the target (never below `MIN_BATCH`).
    /// A rate below 10% grows it by about 1.5x, and by at least one, capped at
    /// `MAX_BATCH`. After an adjustment the counters reset and a new window starts.
    // The u64 -> f64 casts only lose precision beyond 2^53 events per window.
    #[allow(clippy::cast_precision_loss)]
    fn maybe_adjust(&mut self) {
        if self.last_adjustment.elapsed() < Self::ADJUSTMENT_INTERVAL || self.total_send_count == 0
        {
            return;
        }
        let wb_rate = self.would_block_count as f64 / self.total_send_count as f64;
        if wb_rate > 0.3 {
            self.target_size = (self.target_size / 2).max(Self::MIN_BATCH);
        } else if wb_rate < 0.1 && self.target_size < Self::MAX_BATCH {
            // `n * 3 / 2` alone maps 1 to 1, so a target halved down to MIN_BATCH could
            // never grow again; always step up by at least one.
            let grown = (self.target_size * 3 / 2).max(self.target_size + 1);
            self.target_size = grown.min(Self::MAX_BATCH);
        }
        self.would_block_count = 0;
        self.total_send_count = 0;
        self.last_adjustment = std::time::Instant::now();
    }

    /// Returns the batch size currently recommended.
    fn recommended_size(&self) -> usize {
        self.target_size
    }
}
/// A non-blocking UDP socket with batched send/receive.
///
/// Without the `tokio` feature, `fd` is owned directly and closed by the `Drop`
/// impl. With `tokio`, `tokio_sock` owns the descriptor and `fd` aliases it.
pub struct BingerUdp {
    /// Raw OS socket handle passed to every syscall.
    fd: Fd,
    /// Tokio registration of the same socket, used to wait for readiness.
    #[cfg(feature = "tokio")]
    tokio_sock: tokio::net::UdpSocket,
    /// Optional counters, present when `Config` enabled metrics.
    #[cfg(feature = "metrics")]
    metrics: Option<BingerMetrics>,
    /// Send-side batch-size controller, present when adaptive batching is on.
    adaptive_send: Option<std::sync::Mutex<AdaptiveState>>,
    /// Receive-side batch-size controller, present when adaptive batching is on.
    adaptive_recv: Option<std::sync::Mutex<AdaptiveState>>,
}
impl BingerUdp {
fn raw_fd_std(socket: &std::net::UdpSocket) -> Fd {
#[cfg(unix)]
{
socket.as_raw_fd()
}
#[cfg(windows)]
{
socket.as_raw_socket() as Fd
}
}
#[cfg(not(feature = "tokio"))]
fn raw_fd_std_owned(socket: std::net::UdpSocket) -> Fd {
#[cfg(unix)]
{
socket.into_raw_fd()
}
#[cfg(windows)]
{
socket.into_raw_socket() as Fd
}
}
#[allow(clippy::needless_pass_by_value)]
pub fn from_std(socket: std::net::UdpSocket, config: Config) -> io::Result<Self> {
socket.set_nonblocking(true)?;
if let Some(size) = config.send_buf_size {
sockaddr::raw_setsockopt(
Self::raw_fd_std(&socket),
sys::SOL_SOCKET,
sys::SO_SNDBUF,
size as libc::c_int,
)?;
}
if let Some(size) = config.recv_os_buf_size {
sockaddr::raw_setsockopt(
Self::raw_fd_std(&socket),
sys::SOL_SOCKET,
sys::SO_RCVBUF,
size as libc::c_int,
)?;
}
#[cfg(feature = "tokio")]
let fd = Self::raw_fd_std(&socket);
#[cfg(feature = "tokio")]
let tokio_sock = tokio::net::UdpSocket::from_std(socket)?;
#[cfg(not(feature = "tokio"))]
let fd = Self::raw_fd_std_owned(socket);
Ok(Self {
fd,
#[cfg(feature = "tokio")]
tokio_sock,
#[cfg(feature = "metrics")]
metrics: if config.metrics_enabled {
Some(BingerMetrics::default())
} else {
None
},
adaptive_send: if config.adaptive_batching {
Some(std::sync::Mutex::new(AdaptiveState::new(config.batch_size)))
} else {
None
},
adaptive_recv: if config.adaptive_batching {
Some(std::sync::Mutex::new(AdaptiveState::new(config.batch_size)))
} else {
None
},
})
}
pub async fn send_batch(&self, batch: &mut SendBatchRaw) -> io::Result<usize> {
loop {
match self.try_send_batch(batch) {
Ok(n) => return Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Some(ref state) = self.adaptive_send {
if let Ok(mut s) = state.lock() {
s.record_would_block();
s.maybe_adjust();
}
}
self.wait_writable().await?;
}
Err(e) => return Err(e),
}
}
}
pub async fn recv_batch(&self, batch: &mut RecvBatchRaw) -> io::Result<usize> {
loop {
match self.try_recv_batch(batch) {
Ok(0) => {
if let Some(ref state) = self.adaptive_recv {
if let Ok(mut s) = state.lock() {
s.record_would_block();
s.maybe_adjust();
}
}
self.wait_readable().await?;
}
Ok(n) => return Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Some(ref state) = self.adaptive_recv {
if let Ok(mut s) = state.lock() {
s.record_would_block();
s.maybe_adjust();
}
}
self.wait_readable().await?;
}
Err(e) => return Err(e),
}
}
}
pub fn try_send_batch(&self, batch: &mut SendBatchRaw) -> io::Result<usize> {
#[cfg(not(feature = "metrics"))]
let sent = crate::platform::try_send_batch(self.fd, batch)?;
#[cfg(feature = "metrics")]
let sent = crate::platform::try_send_batch(self.fd, batch).map_err(|e| {
if let Some(ref m) = self.metrics {
m.inc_send_errors();
if e.kind() == io::ErrorKind::WouldBlock {
m.inc_send_would_block();
}
}
e
})?;
if let Some(ref state) = self.adaptive_send {
if let Ok(mut s) = state.lock() {
s.record_event();
}
}
#[cfg(feature = "metrics")]
if let Some(ref m) = self.metrics {
m.inc_packets_sent(sent as u64);
m.inc_batches_sent();
m.inc_send_syscalls();
}
Ok(sent)
}
pub fn try_recv_batch(&self, batch: &mut RecvBatchRaw) -> io::Result<usize> {
#[cfg(not(feature = "metrics"))]
let received = crate::platform::try_recv_batch(self.fd, batch)?;
#[cfg(feature = "metrics")]
let received = crate::platform::try_recv_batch(self.fd, batch).map_err(|e| {
if let Some(ref m) = self.metrics {
m.inc_recv_errors();
if e.kind() == io::ErrorKind::WouldBlock {
m.inc_recv_would_block();
}
}
e
})?;
if let Some(ref state) = self.adaptive_recv {
if let Ok(mut s) = state.lock() {
s.record_event();
}
}
#[cfg(feature = "metrics")]
if let Some(ref m) = self.metrics {
m.inc_packets_received(received as u64);
m.inc_batches_received();
m.inc_recv_syscalls();
}
Ok(received)
}
pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> io::Result<usize> {
let mut batch = SendBatchRaw::with_capacity(1);
batch.push(buf, Some(addr)).expect("capacity 1");
self.send_batch(&mut batch).await?;
Ok(buf.len())
}
pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
let mut batch = RecvBatchRaw::with_capacity(1, buf.len());
self.recv_batch(&mut batch).await?;
let data = batch.data(0);
let addr = batch.addr(0);
let len = data.len().min(buf.len());
buf[..len].copy_from_slice(&data[..len]);
Ok((len, addr))
}
pub fn try_send_to(&self, buf: &[u8], addr: SocketAddr) -> io::Result<usize> {
let mut batch = SendBatchRaw::with_capacity(1);
batch.push(buf, Some(addr)).expect("capacity 1");
self.try_send_batch(&mut batch)?;
Ok(buf.len())
}
pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
sockaddr::raw_connect(self.fd, addr)
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
sockaddr::raw_getsockname(self.fd)
}
pub fn ttl(&self) -> io::Result<u32> {
sockaddr::raw_getsockopt(self.fd, sys::IPPROTO_IP, sys::IP_TTL).map(|v| v as u32)
}
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
sockaddr::raw_setsockopt(self.fd, sys::IPPROTO_IP, sys::IP_TTL, ttl as libc::c_int)
}
#[must_use]
pub fn as_raw_fd(&self) -> Fd {
self.fd
}
#[cfg(all(target_os = "linux", feature = "gso"))]
pub fn set_gso(&self, enabled: bool) -> io::Result<()> {
sockaddr::raw_setsockopt(
self.fd,
libc::IPPROTO_UDP,
libc::UDP_SEGMENT,
i32::from(enabled),
)
}
#[cfg(all(target_os = "linux", feature = "gro"))]
pub fn set_gro(&self, enabled: bool) -> io::Result<()> {
sockaddr::raw_setsockopt(
self.fd,
libc::IPPROTO_UDP,
libc::UDP_GRO,
i32::from(enabled),
)
}
#[cfg(all(target_os = "linux", feature = "gso", not(feature = "miri-safe")))]
pub fn try_send_gso(&self, data: &[u8], segment_size: u16) -> io::Result<usize> {
crate::platform::try_send_gso(self.fd, data, segment_size)
}
#[cfg(all(target_os = "linux", feature = "gso", not(feature = "miri-safe")))]
pub async fn send_gso(&self, data: &[u8], segment_size: u16) -> io::Result<usize> {
loop {
match self.try_send_gso(data, segment_size) {
Ok(n) => return Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.wait_writable().await?;
}
Err(e) => return Err(e),
}
}
}
#[cfg(all(target_os = "linux", feature = "pacing"))]
pub fn set_pacing_rate(&self, bytes_per_sec: u32) -> io::Result<()> {
sockaddr::raw_setsockopt_u32(
self.fd,
libc::SOL_SOCKET,
crate::sys::SO_MAX_PACING_RATE,
bytes_per_sec,
)
}
#[cfg(all(target_os = "linux", feature = "busy-poll"))]
pub fn set_busy_poll(&self, usecs: u32) -> io::Result<()> {
sockaddr::raw_setsockopt(
self.fd,
libc::SOL_SOCKET,
crate::sys::SO_BUSY_POLL,
usecs as libc::c_int,
)
}
#[cfg(all(target_os = "linux", feature = "timestamping"))]
pub fn enable_timestamping(&self, enabled: bool) -> io::Result<()> {
sockaddr::raw_setsockopt(
self.fd,
libc::SOL_SOCKET,
crate::sys::SO_TIMESTAMPNS,
i32::from(enabled),
)
}
#[cfg(all(target_os = "linux", feature = "pktinfo"))]
pub fn enable_pktinfo(&self, enabled: bool) -> io::Result<()> {
sockaddr::raw_setsockopt(
self.fd,
libc::IPPROTO_IP,
libc::IP_PKTINFO,
i32::from(enabled),
)?;
sockaddr::raw_setsockopt(
self.fd,
libc::IPPROTO_IPV6,
libc::IPV6_RECVPKTINFO,
i32::from(enabled),
)
}
#[must_use]
pub fn capabilities(&self) -> PlatformCaps {
platform_capabilities()
}
#[must_use]
pub fn recommended_batch_size(&self) -> usize {
self.adaptive_send
.as_ref()
.and_then(|s| s.lock().ok())
.map_or(32, |g| g.recommended_size())
}
#[cfg(feature = "metrics")]
#[must_use]
pub fn metrics(&self) -> Option<&BingerMetrics> {
self.metrics.as_ref()
}
#[cfg(feature = "tokio")]
pub async fn readable(&self) -> io::Result<()> {
self.tokio_sock.readable().await
}
#[cfg(feature = "tokio")]
pub async fn writable(&self) -> io::Result<()> {
self.tokio_sock.writable().await
}
#[cfg(feature = "tokio")]
async fn wait_writable(&self) -> io::Result<()> {
self.tokio_sock.writable().await?;
Ok(())
}
#[cfg(feature = "tokio")]
async fn wait_readable(&self) -> io::Result<()> {
self.tokio_sock.readable().await?;
Ok(())
}
#[cfg(not(feature = "tokio"))]
async fn wait_writable(&self) -> io::Result<()> {
Err(io::Error::new(
io::ErrorKind::Other,
"tokio feature disabled",
))
}
#[cfg(not(feature = "tokio"))]
async fn wait_readable(&self) -> io::Result<()> {
Err(io::Error::new(
io::ErrorKind::Other,
"tokio feature disabled",
))
}
}
#[cfg(not(feature = "tokio"))]
impl Drop for BingerUdp {
    /// Closes the descriptor `BingerUdp` took over in `from_std`.
    ///
    /// This impl only exists without the `tokio` feature. With it, the tokio socket
    /// owns the descriptor and closes it instead.
    fn drop(&mut self) {
        let owned = self.fd;
        sys::close_fd(owned);
    }
}
// SAFETY — NOTE(review), to be confirmed. Within this file, `fd` is only passed
// by value to socket syscalls, and the adaptive state is reached only through its
// `std::sync::Mutex`. Thread-safety of `Fd` and of `BingerMetrics` (`metrics`
// feature) is not visible here. If every field is already `Send + Sync`, these
// manual impls are redundant; removing them would let the compiler keep checking
// this as fields change.
unsafe impl Send for BingerUdp {}
unsafe impl Sync for BingerUdp {}