#![cfg(feature = "tokio-async")]
use crate::error::Error as NetmapError;
use crate::ffi;
use crate::netmap::Netmap;
use std::io;
use std::os::unix::io::AsRawFd;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::unix::AsyncFd;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// Tokio-aware wrapper around a [`Netmap`] instance.
///
/// Registers the netmap file descriptor with the tokio reactor via
/// [`AsyncFd`]; ring handles produced by [`TokioNetmap::rx_ring`] and
/// [`TokioNetmap::tx_ring`] share the same registration.
#[derive(Debug)]
pub struct TokioNetmap {
// Shared with every ring handle so the fd registration (and the netmap
// mapping it guards) stays alive as long as any ring is in use.
async_fd_netmap: Arc<AsyncFd<Netmap>>,
}
impl TokioNetmap {
    /// Wraps `netmap`, registering its file descriptor with the tokio reactor.
    ///
    /// # Errors
    /// Returns the I/O error from [`AsyncFd::new`] if registration fails
    /// (e.g. no reactor is running).
    pub fn new(netmap: Netmap) -> io::Result<Self> {
        let registered = AsyncFd::new(netmap)?;
        Ok(Self {
            async_fd_netmap: Arc::new(registered),
        })
    }

    /// Returns an async handle to receive ring `ring_idx`.
    ///
    /// # Errors
    /// [`NetmapError::InvalidRingIndex`] if `ring_idx` is out of range.
    pub fn rx_ring(&self, ring_idx: usize) -> Result<AsyncNetmapRxRing, NetmapError> {
        let nm = self.async_fd_netmap.get_ref();
        if ring_idx >= nm.num_rx_rings() {
            return Err(NetmapError::InvalidRingIndex(ring_idx));
        }
        // SAFETY: ring_idx was bounds-checked against num_rx_rings above.
        let ring_ptr = unsafe { ffi::NETMAP_RXRING((*nm.desc).nifp, ring_idx as u32) };
        Ok(AsyncNetmapRxRing {
            shared_fd_netmap: Arc::clone(&self.async_fd_netmap),
            ring_ptr,
        })
    }

    /// Returns an async handle to transmit ring `ring_idx`.
    ///
    /// # Errors
    /// [`NetmapError::InvalidRingIndex`] if `ring_idx` is out of range.
    pub fn tx_ring(&self, ring_idx: usize) -> Result<AsyncNetmapTxRing, NetmapError> {
        let nm = self.async_fd_netmap.get_ref();
        if ring_idx >= nm.num_tx_rings() {
            return Err(NetmapError::InvalidRingIndex(ring_idx));
        }
        // SAFETY: ring_idx was bounds-checked against num_tx_rings above.
        let ring_ptr = unsafe { ffi::NETMAP_TXRING((*nm.desc).nifp, ring_idx as u32) };
        Ok(AsyncNetmapTxRing {
            shared_fd_netmap: Arc::clone(&self.async_fd_netmap),
            ring_ptr,
        })
    }
}
/// Async handle to a single netmap receive ring, usable via [`AsyncRead`].
#[derive(Debug)]
pub struct AsyncNetmapRxRing {
// Keeps the netmap fd registered with the reactor (and, through it, the
// underlying Netmap instance) alive for as long as this ring handle lives.
shared_fd_netmap: Arc<AsyncFd<Netmap>>,
// Raw pointer into the netmap ring obtained from NETMAP_RXRING.
// NOTE(review): presumably valid for the lifetime of the Netmap held via
// `shared_fd_netmap` — confirm the mapping is not unmapped earlier.
ring_ptr: *mut ffi::netmap_ring, }
// SAFETY(review): sending the raw ring pointer across threads is claimed
// sound because the Arc keeps the mapping alive; nothing here synchronizes
// concurrent access to the same ring, so each ring handle must only be
// used from one task at a time — confirm callers uphold this.
unsafe impl Send for AsyncNetmapRxRing {}
impl AsyncRead for AsyncNetmapRxRing {
    /// Reads one packet from the RX ring into `buf`.
    ///
    /// In the netmap protocol the kernel advances `tail` as packets arrive;
    /// userspace consumes slots in `[head, tail)` and releases them by
    /// advancing `head`/`cur`. The previous implementation read from the
    /// slot at `tail` and *wrote* `tail`, which reads the first slot still
    /// owned by the kernel and corrupts the kernel/user ring protocol.
    ///
    /// A packet longer than `buf.remaining()` is truncated; the slot is
    /// always consumed, matching the original behavior.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        loop {
            // Ask the kernel to sync the RX ring (publishes newly received
            // packets by advancing `tail`). NIOCRXSYNC does not block.
            let fd = this.shared_fd_netmap.get_ref().as_raw_fd();
            let ret = unsafe {
                libc::ioctl(fd, ffi::NIOCRXSYNC as libc::c_ulong, std::ptr::null_mut::<ffi::nmreq>())
            };
            if ret == -1 {
                return Poll::Ready(Err(io::Error::last_os_error()));
            }
            let ring = unsafe { &*this.ring_ptr };
            let head = ring.head;
            let tail = ring.tail;
            let num_slots = ring.num_slots;
            if head == tail {
                // Ring empty: wait until the fd becomes readable, then loop
                // back to re-sync and re-examine the ring.
                match this.shared_fd_netmap.poll_read_ready_mut(cx) {
                    Poll::Ready(Ok(mut ready_guard)) => {
                        ready_guard.clear_ready();
                        continue;
                    }
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    Poll::Pending => return Poll::Pending,
                }
            }
            // Consume the first available slot, at index `head`.
            let slot_idx = (head % num_slots) as usize;
            let slot = unsafe { &*ring.slot.add(slot_idx) };
            let packet_len = slot.len as usize;
            let len_to_copy = std::cmp::min(packet_len, buf.remaining());
            if len_to_copy > 0 {
                let packet_data =
                    unsafe { std::slice::from_raw_parts(slot.buf as *const u8, len_to_copy) };
                buf.put_slice(packet_data);
            }
            // Release the slot back to the kernel by advancing head/cur.
            // `tail` is owned by the kernel and must never be written here.
            unsafe {
                let mutable_ring = &mut *this.ring_ptr;
                let next = (head + 1) % num_slots;
                mutable_ring.head = next;
                mutable_ring.cur = next;
            }
            return Poll::Ready(Ok(()));
        }
    }
}
/// Async handle to a single netmap transmit ring, usable via [`AsyncWrite`].
#[derive(Debug)]
pub struct AsyncNetmapTxRing {
// Keeps the netmap fd registered with the reactor (and, through it, the
// underlying Netmap instance) alive for as long as this ring handle lives.
shared_fd_netmap: Arc<AsyncFd<Netmap>>,
// Raw pointer into the netmap ring obtained from NETMAP_TXRING.
// NOTE(review): presumably valid for the lifetime of the Netmap held via
// `shared_fd_netmap` — confirm the mapping is not unmapped earlier.
ring_ptr: *mut ffi::netmap_ring,
}
// SAFETY(review): sending the raw ring pointer across threads is claimed
// sound because the Arc keeps the mapping alive; nothing here synchronizes
// concurrent access to the same ring, so each ring handle must only be
// used from one task at a time — confirm callers uphold this.
unsafe impl Send for AsyncNetmapTxRing {}
impl AsyncWrite for AsyncNetmapTxRing {
    /// Writes one packet (`buf`) into the next free TX slot.
    ///
    /// In the netmap protocol userspace may fill slots in `[head, tail)`;
    /// the ring has no free slot exactly when `head == tail`. The previous
    /// check `(head + 1) % num_slots == tail` was a generic circular-buffer
    /// test: it wasted the last available slot and, when the ring was truly
    /// full (`head == tail`), it evaluated false and wrote into a slot still
    /// owned by the kernel, corrupting an in-flight packet.
    ///
    /// Packets are queued only; call `poll_flush` (NIOCTXSYNC) to transmit.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        loop {
            let ring = unsafe { &*this.ring_ptr };
            let head = ring.head;
            let tail = ring.tail;
            let num_slots = ring.num_slots;
            let max_payload = ring.nr_buf_size as usize;
            if head == tail {
                // No free slot: wait for writability, then re-examine the
                // ring (the kernel advances `tail` as slots are transmitted).
                match this.shared_fd_netmap.poll_write_ready_mut(cx) {
                    Poll::Ready(Ok(mut ready_guard)) => {
                        ready_guard.clear_ready();
                        continue;
                    }
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    Poll::Pending => return Poll::Pending,
                }
            }
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }
            if buf.len() > max_payload {
                // A packet cannot span multiple slots.
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    NetmapError::PacketTooLarge(buf.len()),
                )));
            }
            let slot_idx = (head % num_slots) as usize;
            let slot = unsafe { &mut *ring.slot.add(slot_idx) };
            // SAFETY: buf.len() <= nr_buf_size, so the copy stays within the
            // slot buffer. NOTE(review): `len` is u16 — assumes nr_buf_size
            // fits in u16; confirm against the ffi definition.
            let slot_buf =
                unsafe { std::slice::from_raw_parts_mut(slot.buf as *mut u8, buf.len()) };
            slot_buf.copy_from_slice(buf);
            slot.len = buf.len() as u16;
            slot.flags = 0;
            // Hand the slot to the kernel by advancing head/cur. `tail` is
            // owned by the kernel and is never written here.
            unsafe {
                let mutable_ring = &mut *this.ring_ptr;
                let next = (head + 1) % num_slots;
                mutable_ring.head = next;
                mutable_ring.cur = next;
            }
            return Poll::Ready(Ok(buf.len()));
        }
    }

    /// Asks the kernel to transmit queued slots via NIOCTXSYNC (non-blocking).
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let fd = this.shared_fd_netmap.get_ref().as_raw_fd();
        let ret = unsafe {
            libc::ioctl(fd, ffi::NIOCTXSYNC as libc::c_ulong, std::ptr::null_mut::<ffi::nmreq>())
        };
        if ret == -1 {
            Poll::Ready(Err(io::Error::last_os_error()))
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// A netmap ring has no shutdown handshake; flushing is all we can do.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.poll_flush(cx)
    }
}