use crate::{
libc::{self, InternalXdpFlags, socket, xdp},
rings,
};
use std::{fmt, io::Error, os::fd::AsRawFd as _};
/// Errors produced while creating, configuring, or binding an `AF_XDP` socket.
///
/// Every variant wraps the underlying [`std::io::Error`].
#[derive(Debug)]
pub enum SocketError {
/// Creating the socket failed.
SocketCreation(Error),
/// A `setsockopt` call on the socket failed.
SetSockOpt {
/// The OS error reported for the call.
inner: Error,
/// The option that was being set.
option: OptName,
},
/// A `getsockopt` call on the socket failed, or returned a value of an
/// unexpected size.
GetSockOpt {
/// The OS error, or an `InvalidData` error describing the size mismatch.
inner: Error,
/// The option that was being queried.
option: OptName,
},
/// An error tied to a specific ring. Not constructed in this file;
/// presumably raised when mapping a ring fails (TODO confirm in `rings`).
RingMap {
/// The underlying I/O error.
inner: Error,
/// The ring the error relates to.
ring: rings::Ring,
},
/// Binding the socket failed.
Bind(Error),
}
impl std::error::Error for SocketError {
    /// Returns the wrapped I/O error; every variant carries one, so this is
    /// never `None`.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &std::io::Error = match self {
            Self::SocketCreation(err) => err,
            Self::Bind(err) => err,
            Self::SetSockOpt { inner, .. } => inner,
            Self::GetSockOpt { inner, .. } => inner,
            Self::RingMap { inner, .. } => inner,
        };
        Some(cause)
    }
}
impl fmt::Display for SocketError {
    /// Writes the error using its `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Goes through `write_fmt` (exactly what `write!` expands to) so the
        // caller's formatter flags are not forwarded to the `Debug` impl.
        f.write_fmt(format_args!("{self:?}"))
    }
}
/// Builder that owns an `AF_XDP` socket while it is being configured.
///
/// Created with `XdpSocketBuilder::new`, configured with `build_rings` or
/// `build_wakable_rings`, and consumed by `bind` to produce an `XdpSocket`.
pub struct XdpSocketBuilder {
/// The owned socket file descriptor (closed when dropped).
sock: std::os::fd::OwnedFd,
}
/// The `SOL_XDP` socket options used by this module.
///
/// Each discriminant is the raw option value from `libc::xdp::SockOpts`, so a
/// variant can be cast to `i32` and passed as the option name of
/// `setsockopt`/`getsockopt`.
#[derive(Copy, Clone, Debug)]
#[repr(i32)]
pub enum OptName {
/// `XDP_UMEM_REG`: set with an `XdpUmemReg` describing the UMEM.
UmemRegion = libc::xdp::SockOpts::XDP_UMEM_REG,
/// `XDP_UMEM_FILL_RING`: set from `RingConfig::fill_count`.
UmemFillRing = libc::xdp::SockOpts::XDP_UMEM_FILL_RING,
/// `XDP_UMEM_COMPLETION_RING`: set from `RingConfig::completion_count`.
UmemCompletionRing = libc::xdp::SockOpts::XDP_UMEM_COMPLETION_RING,
/// `XDP_RX_RING`: set from `RingConfig::rx_count` when it is non-zero.
RxRing = libc::xdp::SockOpts::XDP_RX_RING,
/// `XDP_TX_RING`: set from `RingConfig::tx_count` when it is non-zero.
TxRing = libc::xdp::SockOpts::XDP_TX_RING,
/// `XDP_MMAP_OFFSETS`: queried to obtain the `xdp_mmap_offsets` structure.
XdpMmapOffsets = libc::xdp::SockOpts::XDP_MMAP_OFFSETS,
}
/// Flags placed in `sxdp_flags` of the `sockaddr_xdp` used when binding an
/// XDP socket.
#[derive(Copy, Clone)]
pub struct BindFlags(xdp::BindFlags::Enum);
impl BindFlags {
    /// Creates a flag set with no bits set.
    fn new() -> Self {
        Self(0)
    }

    /// Sets `XDP_ZEROCOPY` and clears `XDP_COPY`.
    #[inline]
    pub fn force_zerocopy(&mut self) {
        let with_zerocopy = self.0 | xdp::BindFlags::XDP_ZEROCOPY;
        self.0 = with_zerocopy & !xdp::BindFlags::XDP_COPY;
    }

    /// Sets `XDP_COPY` and clears `XDP_ZEROCOPY`.
    #[inline]
    pub fn force_copy(&mut self) {
        let with_copy = self.0 | xdp::BindFlags::XDP_COPY;
        self.0 = with_copy & !xdp::BindFlags::XDP_ZEROCOPY;
    }

    /// Sets `XDP_USE_NEED_WAKEUP`.
    #[inline]
    fn needs_wakeup(&mut self) {
        let wakeup_bit = xdp::BindFlags::XDP_USE_NEED_WAKEUP;
        self.0 |= wakeup_bit;
    }
}
impl XdpSocketBuilder {
pub fn new() -> Result<Self, SocketError> {
use std::os::fd::FromRawFd;
let socket = socket::socket(
socket::AddressFamily::AF_XDP,
socket::Kind::SOCK_RAW | socket::Kind::SOCK_CLOEXEC,
socket::Protocol::NONE,
);
if socket < 0 {
return Err(SocketError::SocketCreation(Error::last_os_error()));
}
Ok(Self {
sock: unsafe { std::os::fd::OwnedFd::from_raw_fd(socket) },
})
}
/// Registers `umem` and the ring sizes in `cfg` with the socket, then creates
/// the rings.
///
/// The fill and completion rings are always created; the rx and tx rings are
/// created only when `cfg.rx_count` / `cfg.tx_count` is greater than zero.
/// The returned [`BindFlags`] has no flags set.
///
/// # Errors
///
/// Propagates any error from registering the UMEM/ring sizes or from
/// constructing an individual ring.
pub fn build_rings(
    &mut self,
    umem: &crate::Umem,
    cfg: rings::RingConfig,
) -> Result<(rings::Rings, BindFlags), SocketError> {
    let offsets = self.build_rings_inner(umem, &cfg)?;
    let fd = self.sock.as_raw_fd();

    // Rings are created in the same order as before: fill, rx, completion, tx.
    let fill_ring = rings::FillRing::new(fd, &cfg, &offsets)?;
    let rx_ring = (cfg.rx_count > 0)
        .then(|| rings::RxRing::new(fd, &cfg, &offsets))
        .transpose()?;
    let completion_ring = rings::CompletionRing::new(fd, &cfg, &offsets)?;
    let tx_ring = (cfg.tx_count > 0)
        .then(|| rings::TxRing::new(fd, &cfg, &offsets))
        .transpose()?;

    let built = rings::Rings {
        fill_ring,
        rx_ring,
        completion_ring,
        tx_ring,
    };
    Ok((built, BindFlags::new()))
}
/// Like `build_rings`, but creates the wakable variants of the fill and tx
/// rings and returns [`BindFlags`] with `XDP_USE_NEED_WAKEUP` set.
///
/// The fill and completion rings are always created; the rx and tx rings are
/// created only when `cfg.rx_count` / `cfg.tx_count` is greater than zero.
///
/// # Errors
///
/// Propagates any error from registering the UMEM/ring sizes or from
/// constructing an individual ring.
pub fn build_wakable_rings(
    &mut self,
    umem: &crate::Umem,
    cfg: rings::RingConfig,
) -> Result<(rings::WakableRings, BindFlags), SocketError> {
    let offsets = self.build_rings_inner(umem, &cfg)?;
    let fd = self.sock.as_raw_fd();

    // Rings are created in the same order as before: fill, rx, completion, tx.
    let fill_ring = rings::WakableFillRing::new(fd, &cfg, &offsets)?;
    let rx_ring = (cfg.rx_count > 0)
        .then(|| rings::RxRing::new(fd, &cfg, &offsets))
        .transpose()?;
    let completion_ring = rings::CompletionRing::new(fd, &cfg, &offsets)?;
    let tx_ring = (cfg.tx_count > 0)
        .then(|| rings::WakableTxRing::new(fd, &cfg, &offsets))
        .transpose()?;

    let built = rings::WakableRings {
        fill_ring,
        rx_ring,
        completion_ring,
        tx_ring,
    };

    let mut wake_flags = BindFlags::new();
    wake_flags.needs_wakeup();
    Ok((built, wake_flags))
}
/// Registers `umem` with the socket, sets the ring sizes from `cfg`, and
/// returns the ring mmap offsets reported for the socket.
///
/// The fill and completion ring options are always set; the rx and tx ring
/// options are set only when the corresponding count is greater than zero.
///
/// # Errors
///
/// * [`SocketError::SetSockOpt`] if any `setsockopt` call fails.
/// * [`SocketError::GetSockOpt`] if querying the mmap offsets fails, or if
///   the size written back differs from the size of `xdp_mmap_offsets`.
fn build_rings_inner(
    &mut self,
    umem: &crate::Umem,
    cfg: &rings::RingConfig,
) -> Result<libc::rings::xdp_mmap_offsets, SocketError> {
    let chunk_size = umem.frame_size as u32 + xdp::XDP_PACKET_HEADROOM as u32;
    // Any non-zero option value turns on the tx metadata area.
    let wants_tx_metadata = umem.options != 0;

    let mut flags = 0;
    if !chunk_size.is_power_of_two() {
        flags |= xdp::UmemFlags::XDP_UMEM_UNALIGNED_CHUNK_FLAG;
    }
    if wants_tx_metadata {
        flags |= xdp::UmemFlags::XDP_UMEM_TX_METADATA_LEN;
        if umem.options & InternalXdpFlags::USE_SOFTWARE_OFFLOAD != 0 {
            flags |= xdp::UmemFlags::XDP_UMEM_TX_SW_CSUM;
        }
    }

    let umem_reg = xdp::XdpUmemReg {
        addr: umem.mmap.ptr as _,
        len: umem.mmap.len() as _,
        chunk_size,
        headroom: umem.head_room as _,
        flags,
        tx_metadata_len: if wants_tx_metadata {
            std::mem::size_of::<libc::xdp::xsk_tx_metadata>() as _
        } else {
            0
        },
    };

    self.set_sockopt(OptName::UmemRegion, &umem_reg)?;
    self.set_sockopt(OptName::UmemFillRing, &cfg.fill_count)?;
    self.set_sockopt(OptName::UmemCompletionRing, &cfg.completion_count)?;
    if cfg.rx_count > 0 {
        self.set_sockopt(OptName::RxRing, &cfg.rx_count)?;
    }
    if cfg.tx_count > 0 {
        self.set_sockopt(OptName::TxRing, &cfg.tx_count)?;
    }

    let expected_size = std::mem::size_of::<libc::rings::xdp_mmap_offsets>() as u32;
    let mut size = expected_size;
    // NOTE(review): assumes `xdp_mmap_offsets` is plain data for which the
    // all-zero bit pattern is a valid value — confirm against its definition.
    let mut offsets = unsafe { std::mem::zeroed::<libc::rings::xdp_mmap_offsets>() };

    // SAFETY: `offsets` and `size` are live locals, and `size` holds exactly
    // the byte size of `offsets`.
    let rc = unsafe {
        libc::socket::getsockopt(
            self.sock.as_raw_fd(),
            libc::socket::Level::SOL_XDP,
            OptName::XdpMmapOffsets as _,
            (&mut offsets as *mut libc::rings::xdp_mmap_offsets).cast(),
            &mut size,
        )
    };
    if rc != 0 {
        return Err(SocketError::GetSockOpt {
            inner: std::io::Error::last_os_error(),
            option: OptName::XdpMmapOffsets,
        });
    }

    // A different size means the structure layout is not the one expected
    // here, so the returned offsets cannot be trusted.
    if size == expected_size {
        Ok(offsets)
    } else {
        Err(SocketError::GetSockOpt {
            inner: std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("expected size {expected_size} but size returned was {size}"),
            ),
            option: OptName::XdpMmapOffsets,
        })
    }
}
pub fn bind(
self,
interface_index: crate::nic::NicIndex,
queue_id: u32,
bind_flags: BindFlags,
) -> Result<XdpSocket, SocketError> {
let xdp_sockaddr = xdp::sockaddr_xdp {
sxdp_family: socket::AddressFamily::AF_XDP as _,
sxdp_flags: bind_flags.0,
sxdp_ifindex: interface_index.0,
sxdp_queue_id: queue_id,
sxdp_shared_umem_fd: 0,
};
if unsafe {
socket::bind(
self.sock.as_raw_fd(),
(&xdp_sockaddr as *const xdp::sockaddr_xdp).cast(),
std::mem::size_of_val(&xdp_sockaddr) as _,
)
} != 0
{
return Err(SocketError::Bind(std::io::Error::last_os_error()));
}
Ok(XdpSocket { sock: self.sock })
}
/// Sets the `SOL_XDP` option `name` on the socket, passing `val` as the
/// option value and `size_of::<T>()` as its length.
///
/// # Errors
///
/// Returns [`SocketError::SetSockOpt`] carrying the last OS error and `name`
/// when the call returns non-zero.
#[inline]
fn set_sockopt<T>(&mut self, name: OptName, val: &T) -> Result<(), SocketError> {
    let fd = self.sock.as_raw_fd();
    let value_ptr = val as *const T;

    // SAFETY: `value_ptr` comes from a live reference and the length passed
    // is exactly the size of `T`.
    let rc = unsafe {
        libc::socket::setsockopt(
            fd,
            socket::Level::SOL_XDP,
            name as i32,
            value_ptr.cast(),
            std::mem::size_of::<T>() as _,
        )
    };

    match rc {
        0 => Ok(()),
        _ => Err(SocketError::SetSockOpt {
            inner: std::io::Error::last_os_error(),
            option: name,
        }),
    }
}
}
impl std::os::fd::AsRawFd for XdpSocketBuilder {
    /// Returns the raw descriptor of the socket owned by the builder;
    /// ownership is not transferred.
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        std::os::fd::AsRawFd::as_raw_fd(&self.sock)
    }
}
/// A bound `AF_XDP` socket, produced by `XdpSocketBuilder::bind`.
pub struct XdpSocket {
/// The owned socket file descriptor (closed when dropped).
sock: std::os::fd::OwnedFd,
}
/// A poll timeout stored as whole milliseconds, with `-1` meaning that no
/// duration was given.
#[derive(Copy, Clone)]
pub struct PollTimeout(i32);

impl PollTimeout {
    /// Converts an optional duration into a millisecond timeout.
    ///
    /// `None` becomes `-1`; `Some(d)` becomes `d` truncated to whole
    /// milliseconds.
    ///
    /// # Panics
    ///
    /// Panics (at compile time when evaluated in a const context) if the
    /// duration is longer than `i32::MAX` milliseconds.
    pub const fn new(duration: Option<std::time::Duration>) -> Self {
        const LIMIT: u128 = i32::MAX as u128;

        match duration {
            None => Self(-1),
            Some(wait) => {
                let millis = wait.as_millis();
                if millis > LIMIT {
                    panic!("timeout cannot exceed i32::MAX milliseconds");
                }
                Self(millis as i32)
            }
        }
    }
}
impl XdpSocket {
    /// Polls the socket for `POLLIN | POLLOUT`.
    ///
    /// Returns `Ok(true)` when the poll call reports a non-zero count and
    /// `Ok(false)` when it reports zero or is interrupted.
    ///
    /// # Errors
    ///
    /// Returns the OS error for any failure other than an interruption.
    #[inline]
    pub fn poll(&self, timeout: PollTimeout) -> std::io::Result<bool> {
        let interest = socket::PollEvents::POLLIN | socket::PollEvents::POLLOUT;
        self.poll_inner(interest, timeout)
    }

    /// Polls the socket for `POLLIN` only; see [`Self::poll`] for the result
    /// and error semantics.
    #[inline]
    pub fn poll_read(&self, timeout: PollTimeout) -> std::io::Result<bool> {
        self.poll_inner(socket::PollEvents::POLLIN, timeout)
    }

    /// Polls the socket for `POLLOUT` only; see [`Self::poll`] for the result
    /// and error semantics.
    #[inline]
    pub fn poll_write(&self, timeout: PollTimeout) -> std::io::Result<bool> {
        self.poll_inner(socket::PollEvents::POLLOUT, timeout)
    }

    /// Issues a single poll call on the socket for `events`, waiting for the
    /// number of milliseconds held in `timeout`.
    #[inline]
    fn poll_inner(&self, events: i16, timeout: PollTimeout) -> std::io::Result<bool> {
        let mut pfd = socket::pollfd {
            fd: self.sock.as_raw_fd(),
            events,
            revents: 0,
        };

        // SAFETY: `pfd` is a live local and the count passed (1) matches the
        // single entry it provides.
        let ready = unsafe { socket::poll(&mut pfd, 1, timeout.0) };
        if ready >= 0 {
            return Ok(ready != 0);
        }

        let err = std::io::Error::last_os_error();
        match err.kind() {
            // An interrupted call is reported as "not ready", not as an error.
            std::io::ErrorKind::Interrupted => Ok(false),
            _ => Err(err),
        }
    }

    /// Returns the raw descriptor of the socket; ownership is not transferred.
    #[inline]
    pub fn raw_fd(&self) -> std::os::fd::RawFd {
        self.sock.as_raw_fd()
    }
}
impl std::os::fd::AsRawFd for XdpSocket {
    /// Returns the raw descriptor of the bound socket; ownership is not
    /// transferred.
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        std::os::fd::AsRawFd::as_raw_fd(&self.sock)
    }
}