mod fd;
pub use fd::{Fd, XdpStatistics};
mod rx_queue;
pub use rx_queue::RxQueue;
mod tx_queue;
pub use tx_queue::TxQueue;
use libbpf_sys::xsk_socket;
use std::{
borrow::Borrow,
error::Error,
fmt, io,
ptr::{self, NonNull},
sync::{Arc, Mutex},
};
use crate::{
config::{Interface, SocketConfig},
ring::{XskRingCons, XskRingProd},
umem::{CompQueue, FillQueue, Umem},
};
/// Owning handle to a libbpf `xsk_socket`.
///
/// The wrapped socket is passed to `xsk_socket__delete` when this value is
/// dropped, so there must never be two `XskSocket`s for the same pointer.
#[derive(Debug)]
struct XskSocket(NonNull<xsk_socket>);

impl XskSocket {
    /// Takes ownership of a socket pointer produced by libbpf.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live socket created by libbpf that nothing else
    /// owns. The returned value deletes the socket on drop, so no other code
    /// may delete or use the pointer after that.
    unsafe fn new(ptr: NonNull<xsk_socket>) -> Self {
        XskSocket(ptr)
    }
}

impl Drop for XskSocket {
    fn drop(&mut self) {
        let raw = self.0.as_ptr();
        // SAFETY: by the contract of `XskSocket::new` this value is the sole
        // owner of the socket, and `drop` runs at most once.
        unsafe { libbpf_sys::xsk_socket__delete(raw) }
    }
}

// SAFETY: NOTE(review): this asserts that the libbpf socket may be moved to,
// and deleted from, another thread — confirm against libbpf. `XskSocket` is
// intentionally not `Sync`.
unsafe impl Send for XskSocket {}
/// State shared by every clone of a `Socket`: the libbpf socket handle plus a
/// clone of the `Umem` it was created with.
///
/// Neither field is read. They are held only to keep both alive while any
/// clone of the socket exists. Field order matters: struct fields drop in
/// declaration order, so the socket is deleted before this umem clone is
/// released.
#[derive(Debug)]
struct SocketInner {
    _ptr: XskSocket,
    _umem: Umem,
}

impl SocketInner {
    /// Bundles `socket` together with the `umem` backing it.
    fn new(socket: XskSocket, umem: Umem) -> Self {
        SocketInner {
            _ptr: socket,
            _umem: umem,
        }
    }
}
/// Handle to an AF_XDP socket.
///
/// `Socket::new` does not return this type directly. It builds one and hands
/// it to the `TxQueue` and `RxQueue` it creates. Cloning shares the inner
/// state through the `Arc` rather than creating a new socket.
#[derive(Debug)]
pub struct Socket {
    /// File descriptor of the socket, as returned by `xsk_socket__fd`.
    fd: Fd,
    /// Keeps the libbpf socket and its `Umem` alive for as long as any clone
    /// of this `Socket` exists. The `Mutex` is not locked in the code visible
    /// here. NOTE(review): it appears to exist so that `Socket` is `Sync` even
    /// though `XskSocket` only implements `Send` — confirm.
    _inner: Arc<Mutex<SocketInner>>,
}
impl Socket {
    /// Creates an AF_XDP socket bound to queue `queue_id` of `if_name`, using
    /// `umem` as its packet buffer.
    ///
    /// Returns the socket's tx and rx queues. The fill and completion queues
    /// are returned as `Some` only when libbpf initialised both rings during
    /// this call. When both rings come back null, `None` is returned instead.
    ///
    /// If the underlying create call fails, any fill/completion queues that
    /// were taken from the umem's saved slot are put back. A failed attempt
    /// therefore leaves the umem as it was.
    ///
    /// # Errors
    ///
    /// Returns a [`SocketCreateError`] in any of these cases:
    /// - `xsk_socket__create_shared` returns a non-zero code; the source is
    ///   the corresponding OS error.
    /// - The call succeeds but leaves the socket pointer, tx ring or rx ring
    ///   null.
    /// - Exactly one of the fill/completion rings is null.
    /// - The socket's file descriptor cannot be retrieved.
    // `new` hands back the socket's queues rather than `Self`, and that tuple
    // is the public return type, hence both lint allowances.
    #[allow(clippy::new_ret_no_self)]
    #[allow(clippy::type_complexity)]
    pub fn new(
        config: SocketConfig,
        umem: &Umem,
        if_name: &Interface,
        queue_id: u32,
    ) -> Result<(TxQueue, RxQueue, Option<(FillQueue, CompQueue)>), SocketCreateError> {
        let mut socket_ptr = ptr::null_mut();
        let mut tx_q = XskRingProd::default();
        let mut rx_q = XskRingCons::default();

        // SAFETY: NOTE(review): the preconditions of
        // `Umem::with_ptr_and_saved_queues` live in the umem module — confirm
        // they hold here. Every out-pointer passed to libbpf refers to a local
        // or boxed value that outlives the call.
        let created = unsafe {
            umem.with_ptr_and_saved_queues(|xsk_umem, saved_fq_and_cq| {
                let saved = saved_fq_and_cq.take();
                let restore_on_failure = saved.is_some();
                let (mut fq, mut cq) =
                    saved.unwrap_or_else(|| (Box::default(), Box::default()));

                let err = libbpf_sys::xsk_socket__create_shared(
                    &mut socket_ptr,
                    if_name.as_cstr().as_ptr(),
                    queue_id,
                    xsk_umem,
                    rx_q.as_mut(),
                    tx_q.as_mut(),
                    fq.as_mut().as_mut(), // double deref: Box, then ring wrapper
                    cq.as_mut().as_mut(),
                    &config.into(),
                );

                if err == 0 {
                    Ok((fq, cq))
                } else {
                    // A failed create must not consume the umem's saved queues.
                    // Put them back so the umem, and any later attempt, still
                    // owns them instead of having them freed here.
                    if restore_on_failure {
                        *saved_fq_and_cq = Some((fq, cq));
                    }
                    Err(err)
                }
            })
        };

        // NOTE: `fq`/`cq` are bound before `socket` below, so on every early
        // return the socket is dropped (and deleted) before these rings.
        // NOTE(review): libbpf may reference the rings while deleting the
        // socket, so keep this ordering.
        let (fq, cq) = match created {
            Ok(queues) => queues,
            Err(err) => {
                return Err(SocketCreateError {
                    reason: "non-zero error code returned when creating AF_XDP socket",
                    err: io::Error::from_raw_os_error(-err),
                });
            }
        };

        let socket_ptr = match NonNull::new(socket_ptr) {
            // SAFETY: libbpf just created this socket and nothing else owns
            // the pointer, so `XskSocket` may delete it on drop.
            Some(init_xsk) => unsafe { XskSocket::new(init_xsk) },
            None => {
                return Err(Self::null_output_error("returned socket pointer was null"));
            }
        };

        // SAFETY: `socket_ptr` wraps the live, non-null socket created above.
        let fd = unsafe { libbpf_sys::xsk_socket__fd(socket_ptr.0.as_ref()) };
        if fd < 0 {
            return Err(SocketCreateError {
                reason: "failed to retrieve AF_XDP socket file descriptor",
                err: io::Error::from_raw_os_error(-fd),
            });
        }

        let socket = Socket {
            fd: Fd::new(fd),
            _inner: Arc::new(Mutex::new(SocketInner::new(socket_ptr, umem.clone()))),
        };

        if tx_q.is_ring_null() {
            return Err(Self::null_output_error("returned tx queue ring is null"));
        }
        let tx_q = TxQueue::new(tx_q, socket.clone());

        if rx_q.is_ring_null() {
            return Err(Self::null_output_error("returned rx queue ring is null"));
        }
        let rx_q = RxQueue::new(rx_q, socket);

        let fq_and_cq = match (fq.is_ring_null(), cq.is_ring_null()) {
            (true, true) => None,
            (false, false) => {
                let fq = FillQueue::new(*fq, umem.clone());
                let cq = CompQueue::new(*cq, umem.clone());
                Some((fq, cq))
            }
            (true, false) | (false, true) => {
                return Err(Self::null_output_error(
                    "fill queue xor comp queue ring is null, either both or neither should be non-null",
                ));
            }
        };

        Ok((tx_q, rx_q, fq_and_cq))
    }

    /// Builds the error for the case where libbpf reported success (return
    /// code 0) but left one of its outputs null.
    ///
    /// There is no OS error code to report here. Wrapping the 0 return code
    /// as a raw OS error would produce a misleading "success" source, so a
    /// synthetic `io::Error` is used instead.
    fn null_output_error(reason: &'static str) -> SocketCreateError {
        SocketCreateError {
            reason,
            err: io::Error::new(
                io::ErrorKind::Other,
                "libbpf reported success but left an output null",
            ),
        }
    }
}
impl Clone for Socket {
fn clone(&self) -> Self {
Self {
fd: self.fd.clone(),
_inner: self._inner.clone(),
}
}
}
/// Error returned when `Socket::new` fails.
///
/// `Display` prints only the static description of the step that failed. The
/// underlying [`io::Error`] is reachable through [`Error::source`].
#[derive(Debug)]
pub struct SocketCreateError {
    /// Description of which step of socket creation failed.
    reason: &'static str,
    /// Lower-level error recorded for the failure.
    err: io::Error,
}

impl fmt::Display for SocketCreateError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.reason)
    }
}

impl Error for SocketCreateError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let inner: &io::Error = self.err.borrow();
        Some(inner)
    }
}