use errno::errno;
use libbpf_sys::{xsk_ring_cons, xsk_ring_prod, xsk_umem, xsk_umem_config, XDP_PACKET_HEADROOM};
use std::sync::Arc;
use std::{convert::TryInto, error::Error, fmt, io, marker::PhantomData, mem::MaybeUninit, ptr};
use crate::socket::{self, Fd};
use super::{config::UmemConfig, mmap::MmapArea};
/// Bookkeeping tag recording where a [`Frame`] currently is.
///
/// Frames are created as [`FrameStatus::Free`]; the code in this module never
/// changes the tag afterwards, so updating it is left to whoever owns the frame.
#[derive(Debug, Clone, PartialEq)]
pub enum FrameStatus {
    /// Not on any queue.
    Free,
    /// Marked as placed on a TX queue.
    OnTxQueue,
    /// Marked as placed on an RX-side queue.
    OnRxQueue,
}

impl FrameStatus {
    /// Returns `true` exactly when the status is [`FrameStatus::Free`].
    pub fn is_free(&self) -> bool {
        matches!(*self, FrameStatus::Free)
    }
}
/// Descriptor for one frame of the UMEM, together with the metadata
/// (current data length, options word) that accompanies it.
#[derive(Debug, Clone, PartialEq)]
pub struct Frame<'umem> {
    // Byte offset of the frame from the start of the mmap area
    // (`create_umem` assigns `i * frame_size` to frame `i`).
    addr: usize,
    // Number of data bytes currently recorded for the frame; starts at 0 and is
    // updated by `set_len` / the `write_to_umem*` methods.
    len: usize,
    // Descriptor options word; starts at 0.
    // NOTE(review): presumably mirrors the AF_XDP `xdp_desc.options` field — confirm.
    options: u32,
    // Largest payload accepted by `is_data_valid`: frame size minus
    // XDP_PACKET_HEADROOM and the configured frame headroom.
    mtu: usize,
    // Ties the frame to the `'umem` lifetime it was created with.
    _marker: PhantomData<&'umem ()>,
    // Shared handle to the mapped memory; all frames of one UMEM clone the same Arc.
    mmap_area: Arc<MmapArea>,
    /// Caller-maintained record of where this frame currently is.
    pub status: FrameStatus,
}
impl Frame<'_> {
    /// Byte offset of this frame from the start of the UMEM memory area.
    #[inline]
    pub fn addr(&self) -> usize {
        self.addr
    }

    /// Number of data bytes currently recorded for this frame.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Descriptor options word associated with this frame.
    #[inline]
    pub fn options(&self) -> u32 {
        self.options
    }

    /// Overwrites the frame's byte offset.
    #[inline]
    pub fn set_addr(&mut self, addr: usize) {
        self.addr = addr
    }

    /// Overwrites the recorded data length.
    #[inline]
    pub fn set_len(&mut self, len: usize) {
        self.len = len
    }

    /// Overwrites the descriptor options word.
    #[inline]
    pub fn set_options(&mut self, options: u32) {
        self.options = options
    }

    /// Checks that a read of `len` bytes stays within the frame's current data.
    ///
    /// # Errors
    /// [`AccessError::CrossesFrameBoundary`] if `len` exceeds [`Frame::len`].
    #[inline]
    pub fn is_access_valid(&self, len: usize) -> Result<(), AccessError> {
        if len > self.len {
            Err(AccessError::CrossesFrameBoundary {
                addr: self.addr,
                len,
            })
        } else {
            Ok(())
        }
    }

    /// Checks that `data` fits within the frame's MTU.
    ///
    /// # Errors
    /// [`DataError::SizeExceedsMtu`] if `data` is longer than the MTU.
    #[inline]
    pub fn is_data_valid(&self, data: &[u8]) -> Result<(), DataError> {
        if data.len() > self.mtu {
            Err(DataError::SizeExceedsMtu {
                data_len: data.len(),
                mtu: self.mtu,
            })
        } else {
            Ok(())
        }
    }

    /// Returns `len` bytes of UMEM memory starting at this frame's address.
    ///
    /// # Safety
    /// No checks are performed. NOTE(review): the caller presumably must keep
    /// `addr() + len` inside the UMEM mapping and avoid aliasing mutable access;
    /// the exact contract of `MmapArea::mem_range` is not visible here.
    #[inline]
    pub unsafe fn read_from_umem(&self, len: usize) -> &[u8] {
        self.mmap_area.mem_range(self.addr, len)
    }

    /// Like [`Frame::read_from_umem`], but first rejects reads longer than the
    /// frame's current data length.
    ///
    /// # Errors
    /// See [`Frame::is_access_valid`].
    ///
    /// # Safety
    /// Same aliasing requirements as [`Frame::read_from_umem`].
    #[inline]
    pub unsafe fn read_from_umem_checked(&self, len: usize) -> Result<&[u8], AccessError> {
        self.is_access_valid(len)?;
        Ok(self.mmap_area.mem_range(self.addr, len))
    }

    /// Copies `data` into the UMEM at this frame's address and records its length.
    ///
    /// # Safety
    /// No size checks are performed; see [`Frame::read_from_umem`].
    #[inline]
    pub unsafe fn write_to_umem(&mut self, data: &[u8]) {
        let data_len = data.len();
        if data_len > 0 {
            let umem_region = self.mmap_area.mem_range_mut(&self.addr, &data_len);
            umem_region[..data_len].copy_from_slice(data);
        }
        self.set_len(data_len);
    }

    /// Like [`Frame::write_to_umem`], but first rejects payloads larger than the MTU.
    ///
    /// # Errors
    /// [`WriteError::Data`] if `data` exceeds the MTU; the frame is left untouched.
    ///
    /// # Safety
    /// Same aliasing requirements as [`Frame::read_from_umem`].
    #[inline]
    pub unsafe fn write_to_umem_checked(&mut self, data: &[u8]) -> Result<(), WriteError> {
        let data_len = data.len();
        if data_len > 0 {
            self.is_data_valid(data).map_err(WriteError::Data)?;
            let umem_region = self.mmap_area.mem_range_mut(&self.addr, &data_len);
            umem_region[..data_len].copy_from_slice(data);
        }
        self.set_len(data_len);
        Ok(())
    }

    /// Mutable view of `len` bytes of UMEM memory at this frame's address.
    ///
    /// # Safety
    /// No checks are performed; see [`Frame::read_from_umem`].
    #[inline]
    pub unsafe fn umem_region_mut(&mut self, len: &usize) -> &mut [u8] {
        self.mmap_area.mem_range_mut(&self.addr, len)
    }

    /// Mutable view of `len` bytes at this frame's address, bounds-checked
    /// against the end of the UMEM memory area.
    ///
    /// # Errors
    /// [`AccessError::RegionOutOfBounds`] if `addr() + len` overflows or lies
    /// past the end of the mapped area.
    ///
    /// # Safety
    /// Bounds are checked, but aliasing is not: the caller must not hold other
    /// references into the same region.
    #[inline]
    pub unsafe fn umem_region_mut_checked(&mut self, len: usize) -> Result<&mut [u8], AccessError> {
        let umem_len = self.mmap_area.len();
        // `checked_add` so a bogus `addr` set via `set_addr` cannot wrap around
        // and slip past the bounds check.
        let in_bounds = self
            .addr
            .checked_add(len)
            .map_or(false, |end| end <= umem_len);
        if !in_bounds {
            return Err(AccessError::RegionOutOfBounds {
                addr: self.addr,
                len,
                umem_len,
            });
        }
        Ok(self.mmap_area.mem_range_mut(&self.addr, &len))
    }
}
/// First stage of UMEM construction: holds only the configuration.
/// [`UmemBuilder::create_mmap`] allocates the backing memory.
pub struct UmemBuilder {
    config: UmemConfig,
}
/// Second stage of UMEM construction: the configuration plus the allocated
/// memory area. [`UmemBuilderWithMmap::create_umem`] creates the UMEM over it.
pub struct UmemBuilderWithMmap {
    config: UmemConfig,
    mmap_area: MmapArea,
}
struct XskUmem(*mut xsk_umem);
unsafe impl Send for XskUmem {}
impl Drop for XskUmem {
fn drop(&mut self) {
log::debug!("deleting umem");
let err = unsafe { libbpf_sys::xsk_umem__delete(self.0) };
if err != 0 {
log::error!("xsk_umem__delete() failed: {}", errno());
}
}
}
/// A created UMEM: the libbpf handle plus the geometry it was built with.
pub struct Umem<'a> {
    config: UmemConfig,
    // Bytes per frame, taken from `config.frame_size()`.
    frame_size: usize,
    // Bytes covered by frames: `frame_count * frame_size`.
    umem_len: usize,
    // Frame size minus XDP_PACKET_HEADROOM and the configured frame headroom.
    mtu: usize,
    // Owns the libbpf `xsk_umem`; `XskUmem`'s Drop deletes it.
    inner: Box<XskUmem>,
    _marker: PhantomData<&'a ()>,
}
impl UmemBuilder {
pub fn create_mmap(self) -> io::Result<UmemBuilderWithMmap> {
let mmap_area = MmapArea::new(self.config.umem_len(), self.config.use_huge_pages())?;
Ok(UmemBuilderWithMmap {
config: self.config,
mmap_area,
})
}
}
impl<'a> UmemBuilderWithMmap {
    /// Creates a UMEM over the mapped memory area via `xsk_umem__create`.
    ///
    /// Returns the [`Umem`], its [`FillQueue`] and [`CompQueue`], and one
    /// [`Frame`] per configured frame. Frame `i` starts at byte `i * frame_size`,
    /// has length 0, and is marked [`FrameStatus::Free`].
    ///
    /// # Errors
    /// Returns the OS error reported by `xsk_umem__create` if creation fails.
    pub fn create_umem(
        mut self,
    ) -> io::Result<(Umem<'a>, FillQueue<'a>, CompQueue<'a>, Vec<Frame<'a>>)> {
        let umem_create_config = xsk_umem_config {
            fill_size: self.config.fill_queue_size(),
            comp_size: self.config.comp_queue_size(),
            frame_size: self.config.frame_size(),
            frame_headroom: self.config.frame_headroom(),
            flags: 0,
        };

        let mut umem_ptr: *mut xsk_umem = ptr::null_mut();
        let mut fq_ptr: MaybeUninit<xsk_ring_prod> = MaybeUninit::zeroed();
        let mut cq_ptr: MaybeUninit<xsk_ring_cons> = MaybeUninit::zeroed();

        let err = unsafe {
            libbpf_sys::xsk_umem__create(
                &mut umem_ptr,
                self.mmap_area.as_mut_ptr(),
                self.mmap_area.len() as u64,
                fq_ptr.as_mut_ptr(),
                cq_ptr.as_mut_ptr(),
                &umem_create_config,
            )
        };
        if err != 0 {
            // libbpf returns the failure as a negated errno value; several of its
            // error paths never touch `errno`, so decode the return code instead.
            return Err(io::Error::from_raw_os_error(err.wrapping_abs()));
        }

        let frame_size = self.config.frame_size() as usize;
        let frame_count = self.config.frame_count() as usize;
        let frame_headroom = self.config.frame_headroom() as usize;
        let mtu = frame_size - (XDP_PACKET_HEADROOM as usize + frame_headroom);

        // Every frame shares the same mapped area.
        let mmap = Arc::new(self.mmap_area);
        let frames: Vec<Frame<'a>> = (0..frame_count)
            .map(|i| Frame {
                addr: i * frame_size,
                len: 0,
                options: 0,
                mtu,
                _marker: PhantomData,
                mmap_area: Arc::clone(&mmap),
                status: FrameStatus::Free,
            })
            .collect();

        // SAFETY: `xsk_umem__create` returned 0, so it has initialised both rings.
        let fill_queue = FillQueue {
            size: self.config.fill_queue_size(),
            inner: unsafe { Box::new(fq_ptr.assume_init()) },
            _marker: PhantomData,
        };
        let comp_queue = CompQueue {
            size: self.config.comp_queue_size(),
            inner: unsafe { Box::new(cq_ptr.assume_init()) },
            _marker: PhantomData,
        };

        let umem = Umem {
            config: self.config,
            frame_size,
            umem_len: frame_count * frame_size,
            mtu,
            inner: Box::new(XskUmem(umem_ptr)),
            _marker: PhantomData,
        };

        Ok((umem, fill_queue, comp_queue, frames))
    }
}
impl Umem<'_> {
    /// Starts building a UMEM from `config`.
    pub fn builder(config: UmemConfig) -> UmemBuilder {
        UmemBuilder { config }
    }

    /// The configuration this UMEM was built from.
    pub fn config(&self) -> &UmemConfig {
        &self.config
    }

    /// Maximum payload per frame: frame size minus `XDP_PACKET_HEADROOM` and
    /// the configured frame headroom.
    #[inline]
    pub fn mtu(&self) -> usize {
        self.mtu
    }

    /// Raw `xsk_umem` pointer owned by this UMEM, for crate-internal use.
    ///
    /// # Panics
    /// Panics if the stored pointer is null.
    pub(crate) fn as_mut_ptr(&mut self) -> *mut xsk_umem {
        let raw = self.inner.0;
        if raw.is_null() {
            panic!("failed to get mut umem ptr");
        }
        raw
    }
}
/// The producer ring handed to `xsk_umem__create` as the fill ring;
/// `produce` writes frame addresses into it.
#[derive(Debug)]
pub struct FillQueue<'umem> {
    // Ring size as configured (`config.fill_queue_size()`).
    size: u32,
    // libbpf producer-ring state, initialised by `xsk_umem__create`.
    inner: Box<xsk_ring_prod>,
    _marker: PhantomData<&'umem ()>,
}
impl FillQueue<'_> {
    /// Reserves `descs.len()` slots in the ring, writes the addresses of the
    /// leading descriptors into the reserved slots, and submits them.
    ///
    /// Returns the number of slots `_xsk_ring_prod__reserve` granted (0 for an
    /// empty `descs`); only that many leading descriptors are submitted.
    ///
    /// # Safety
    /// NOTE(review): submitted frames are handed over through the ring; callers
    /// presumably must not touch their memory until they come back — confirm.
    #[inline]
    pub unsafe fn produce(&mut self, descs: &mut [Frame]) -> usize {
        let nb = descs.len() as u64;
        if nb == 0 {
            return 0;
        }
        let mut idx: u32 = 0;
        let cnt = libbpf_sys::_xsk_ring_prod__reserve(self.inner.as_mut(), nb, &mut idx);
        // `cnt` never exceeds `descs.len()`, so it always fits in a usize.
        let reserved: usize = cnt
            .try_into()
            .expect("reserved slot count exceeds usize");
        if reserved > 0 {
            for desc in descs.iter().take(reserved) {
                *libbpf_sys::_xsk_ring_prod__fill_addr(self.inner.as_mut(), idx) = desc.addr as u64;
                // Ring indices are free-running u32 counters that wrap around;
                // `idx += 1` would panic on that wrap in debug builds.
                idx = idx.wrapping_add(1);
            }
            libbpf_sys::_xsk_ring_prod__submit(self.inner.as_mut(), cnt);
        }
        reserved
    }

    /// [`FillQueue::produce`], followed by a wakeup when at least one frame was
    /// submitted and [`FillQueue::needs_wakeup`] reports the flag set.
    ///
    /// # Errors
    /// Propagates any error from [`FillQueue::wakeup`].
    ///
    /// # Safety
    /// Same as [`FillQueue::produce`].
    #[inline]
    pub unsafe fn produce_and_wakeup(
        &mut self,
        descs: &mut [Frame],
        socket_fd: &mut Fd,
        poll_timeout: i32,
    ) -> io::Result<usize> {
        let cnt = self.produce(descs);
        if cnt > 0 && self.needs_wakeup() {
            self.wakeup(socket_fd, poll_timeout)?;
        }
        Ok(cnt)
    }

    /// Polls `fd` for readability with `poll_timeout`, discarding the poll
    /// result except for errors.
    #[inline]
    pub fn wakeup(&self, fd: &mut Fd, poll_timeout: i32) -> io::Result<()> {
        socket::poll_read(fd, poll_timeout)?;
        Ok(())
    }

    /// Whether `_xsk_ring_prod__needs_wakeup` reports the flag set on this ring.
    #[inline]
    pub fn needs_wakeup(&self) -> bool {
        unsafe { libbpf_sys::_xsk_ring_prod__needs_wakeup(self.inner.as_ref()) != 0 }
    }
}
// SAFETY: NOTE(review) — asserts the libbpf ring state in `inner` (presumably
// raw pointers into shared ring memory) may be moved to another thread. Sound
// only if one thread at a time drives this producer ring; confirm.
unsafe impl Send for FillQueue<'_> {}
/// The consumer ring handed to `xsk_umem__create` as the completion ring;
/// `consume` reads frame addresses out of it.
#[derive(Debug)]
pub struct CompQueue<'umem> {
    // Ring size as configured (`config.comp_queue_size()`).
    size: u32,
    // libbpf consumer-ring state, initialised by `xsk_umem__create`.
    inner: Box<xsk_ring_cons>,
    _marker: PhantomData<&'umem ()>,
}
impl CompQueue<'_> {
    /// Peeks up to `n_frames` entries from the ring and returns their frame
    /// addresses in ring order.
    ///
    /// All peeked entries are released back to the ring before returning, so
    /// the returned `Vec` may be empty.
    #[inline]
    pub fn consume(&mut self, n_frames: u64) -> Vec<u64> {
        let mut idx: u32 = 0;
        let cnt =
            unsafe { libbpf_sys::_xsk_ring_cons__peek(self.inner.as_mut(), n_frames, &mut idx) };
        let free_frames: Vec<u64> = (0..cnt)
            .map(|i| {
                // Ring indices are free-running u32 counters that wrap around;
                // plain `idx + i` would panic on that wrap in debug builds.
                let slot = idx.wrapping_add(i as u32);
                unsafe { *libbpf_sys::_xsk_ring_cons__comp_addr(self.inner.as_mut(), slot) }
            })
            .collect();
        unsafe { libbpf_sys::_xsk_ring_cons__release(self.inner.as_mut(), cnt) };
        free_frames
    }
}
// SAFETY: NOTE(review) — asserts the libbpf ring state in `inner` (presumably
// raw pointers into shared ring memory) may be moved to another thread. Sound
// only if one thread at a time drives this consumer ring; confirm.
unsafe impl Send for CompQueue<'_> {}
/// Reasons an access to a region of UMEM memory can be rejected.
#[derive(Debug)]
pub enum AccessError {
    /// The requested region has zero length.
    NullRegion,
    /// The region `[addr, addr + len)` extends past the UMEM of `umem_len` bytes.
    RegionOutOfBounds {
        addr: usize,
        len: usize,
        umem_len: usize,
    },
    /// Raised by `Frame::is_access_valid` when the requested length exceeds
    /// the frame's current data length.
    CrossesFrameBoundary { addr: usize, len: usize },
}

impl AccessError {
    /// Inclusive last byte of `[addr, addr + len)`, saturating so that
    /// zero-length or near-`usize::MAX` regions format instead of panicking.
    fn last_byte(addr: usize, len: usize) -> usize {
        addr.saturating_add(len.saturating_sub(1))
    }
}

impl fmt::Display for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use AccessError::*;
        match self {
            NullRegion => write!(f, "region has zero length"),
            RegionOutOfBounds {
                addr,
                len,
                umem_len,
            } => write!(
                f,
                "UMEM region [{}, {}] is out of bounds (UMEM length is {})",
                addr,
                Self::last_byte(*addr, *len),
                umem_len
            ),
            CrossesFrameBoundary { addr, len } => write!(
                f,
                "UMEM region [{}, {}] intersects with more then one frame",
                addr,
                Self::last_byte(*addr, *len),
            ),
        }
    }
}

impl Error for AccessError {}
/// Errors from validating a payload against a frame.
#[derive(Debug)]
pub enum DataError {
    /// The payload of `data_len` bytes is longer than the frame's `mtu`.
    SizeExceedsMtu { data_len: usize, mtu: usize },
}

impl fmt::Display for DataError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so this pattern is irrefutable.
        let DataError::SizeExceedsMtu { data_len, mtu } = self;
        write!(
            f,
            "data length ({} bytes) cannot be greater than the MTU ({} bytes)",
            data_len, mtu
        )
    }
}

impl Error for DataError {}
/// Errors returned when writing a payload into a frame.
#[derive(Debug)]
pub enum WriteError {
    /// The target memory region was rejected.
    Access(AccessError),
    /// The payload itself was rejected.
    Data(DataError),
}

impl fmt::Display for WriteError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Both variants display exactly as the error they wrap.
        let inner: &dyn fmt::Display = match self {
            WriteError::Access(access_err) => access_err,
            WriteError::Data(data_err) => data_err,
        };
        write!(f, "{}", inner)
    }
}

impl Error for WriteError {}
#[cfg(test)]
mod tests {
    use rand;
    use std::num::NonZeroU32;

    use super::*;
    use crate::umem::UmemConfig;

    const FRAME_COUNT: u32 = 8;
    const FRAME_SIZE: u32 = 2048;

    /// Produces `len` random bytes for use as frame payloads.
    fn generate_random_bytes(len: u32) -> Vec<u8> {
        (0..len).map(|_| rand::random::<u8>()).collect()
    }

    /// Default test configuration: `FRAME_COUNT` frames of `FRAME_SIZE` bytes,
    /// fill and completion sizes of 4, no frame headroom, no huge pages.
    fn umem_config() -> UmemConfig {
        UmemConfig::new(FRAME_COUNT, FRAME_SIZE, 4, 4, 0, false).unwrap()
    }

    /// Builds a UMEM from [`umem_config`], panicking on any failure.
    fn umem<'a>() -> (Umem<'a>, FillQueue<'a>, CompQueue<'a>, Vec<Frame<'a>>) {
        let config = umem_config();
        Umem::builder(config)
            .create_mmap()
            .expect("Failed to create mmap region")
            .create_umem()
            .expect("Failed to create UMEM")
    }

    #[test]
    fn umem_create_succeeds_when_frame_count_is_one() {
        let config = UmemConfig::new(1, 4096, 4, 4, 0, false).unwrap();
        Umem::builder(config)
            .create_mmap()
            .expect("Failed to create mmap region")
            .create_umem()
            .expect("Failed to create UMEM");
    }

    #[test]
    fn umem_create_succeeds_when_fill_size_is_one() {
        let config = UmemConfig::new(16, 4096, 1, 4, 0, false).unwrap();
        Umem::builder(config)
            .create_mmap()
            .expect("Failed to create mmap region")
            .create_umem()
            .expect("Failed to create UMEM");
    }

    #[test]
    fn umem_create_succeeds_when_comp_size_is_one() {
        let config = UmemConfig::new(16, 4096, 4, 1, 0, false).unwrap();
        Umem::builder(config)
            .create_mmap()
            .expect("Failed to create mmap region")
            .create_umem()
            .expect("Failed to create UMEM");
    }

    #[test]
    #[should_panic]
    fn umem_create_fails_when_frame_size_is_lt_2048() {
        let config = UmemConfig::new(1, 2047, 4, 4, 0, false).unwrap();
        Umem::builder(config)
            .create_mmap()
            .expect("Failed to create mmap region")
            .create_umem()
            .expect("Failed to create UMEM");
    }

    #[test]
    fn mtu_is_correct() {
        let config = UmemConfig::new(1, 2048, 4, 4, 512, false).unwrap();
        let (umem, _fq, _cq, _frame_descs) = Umem::builder(config)
            .create_mmap()
            .expect("Failed to create mmap region")
            .create_umem()
            .expect("Failed to create UMEM");
        assert_eq!(umem.mtu(), (2048 - XDP_PACKET_HEADROOM - 512) as usize);
    }

    #[test]
    fn umem_access_checks_ok() {
        let (_umem, _fq, _cq, mut frames) = umem();

        // A fresh frame records no data, so only a zero-length access is valid.
        assert!(frames[0].is_access_valid(0).is_ok());
        assert!(matches!(
            frames[0].is_access_valid(1),
            Err(AccessError::CrossesFrameBoundary { .. })
        ));

        // Once a length is recorded, accesses up to and including it are valid.
        frames[0].set_len(64);
        assert!(frames[0].is_access_valid(64).is_ok());
        assert!(matches!(
            frames[0].is_access_valid(65),
            Err(AccessError::CrossesFrameBoundary { addr: 0, len: 65 })
        ));

        // The checked read enforces the same rule.
        unsafe {
            assert_eq!(frames[0].read_from_umem_checked(64).unwrap().len(), 64);
            assert!(frames[0].read_from_umem_checked(65).is_err());
        }
    }

    #[test]
    fn data_checks_ok() {
        let (_umem, _fq, _cq, frames) = umem();
        let empty_data: Vec<u8> = Vec::new();
        assert!(frames[0].is_data_valid(&empty_data).is_ok());
        // `umem_config` uses no frame headroom, so the MTU is just the frame
        // size minus the XDP packet headroom.
        let mtu = FRAME_SIZE - XDP_PACKET_HEADROOM;
        let data = generate_random_bytes(mtu - 1);
        assert!(frames[0].is_data_valid(&data).is_ok());
        let data = generate_random_bytes(mtu);
        assert!(frames[0].is_data_valid(&data).is_ok());
        let data = generate_random_bytes(mtu + 1);
        assert!(matches!(
            frames[0].is_data_valid(&data),
            Err(DataError::SizeExceedsMtu { .. })
        ));
    }

    #[test]
    fn write_no_data_to_umem() {
        let (mut _umem, _fq, _cq, mut frames) = umem();
        let data = [];
        unsafe {
            frames[0].write_to_umem_checked(&data[..]).unwrap();
        }
        assert_eq!(frames[0].len(), 0);
    }

    #[test]
    fn write_to_umem_frame_then_read_small_byte_array() {
        let (mut _umem, _fq, _cq, mut frames) = umem();
        let data = [b'H', b'e', b'l', b'l', b'o'];
        unsafe {
            frames[0].write_to_umem_checked(&data[..]).unwrap();
        }
        assert_eq!(frames[0].len(), 5);
        let umem_region = unsafe { frames[0].read_from_umem_checked(frames[0].len).unwrap() };
        assert_eq!(data, umem_region[..data.len()]);
    }

    #[test]
    fn write_max_bytes_to_neighbouring_umem_frames() {
        let (mut _umem, _fq, _cq, mut frames) = umem();
        let data_len = FRAME_SIZE;
        let fst_data = generate_random_bytes(data_len);
        let snd_data = generate_random_bytes(data_len);
        unsafe {
            let umem_region = frames[0]
                .umem_region_mut_checked(data_len as usize)
                .unwrap();
            umem_region.copy_from_slice(&fst_data[..]);
            frames[0].set_len(data_len as usize);
            let umem_region = frames[1]
                .umem_region_mut_checked(data_len as usize)
                .unwrap();
            umem_region.copy_from_slice(&snd_data[..]);
            frames[1].set_len(data_len as usize);
        }
        let fst_frame_ref = unsafe { frames[0].read_from_umem(frames[0].len()) };
        let snd_frame_ref = unsafe { frames[1].read_from_umem(frames[1].len()) };
        assert_eq!(fst_data[..], fst_frame_ref[..fst_data.len()]);
        assert_eq!(snd_data[..], snd_frame_ref[..snd_data.len()]);
        // The two frames must sit back-to-back in the mapped area.
        let mem_len = (FRAME_SIZE * 2) as usize;
        let mem_range = unsafe { frames[0].mmap_area.mem_range(0, mem_len) };
        let mut data_vec = Vec::with_capacity(mem_len);
        data_vec.extend_from_slice(&fst_data);
        data_vec.extend_from_slice(&snd_data);
        assert_eq!(&data_vec[..], mem_range);
    }
}