use std::os::fd::{AsFd, AsRawFd};
use tokio::io::unix::AsyncFd;
use crate::error::Error;
use crate::packet::{OwnedPacket, PacketBatch};
use crate::traits::PacketSource;
/// A packet source registered with Tokio's reactor so that batches can be awaited.
pub struct AsyncCapture<S>
where
    S: PacketSource + AsRawFd,
{
    /// The wrapped source; `AsyncFd` tracks its read readiness.
    inner: AsyncFd<S>,
}
impl<S: PacketSource + AsRawFd> AsyncCapture<S> {
pub fn new(source: S) -> Result<Self, Error> {
let fd = AsyncFd::new(source).map_err(Error::Io)?;
Ok(Self { inner: fd })
}
pub async fn readable(&mut self) -> Result<ReadableGuard<'_, S>, Error> {
let guard = self.inner.readable_mut().await.map_err(Error::Io)?;
Ok(ReadableGuard { guard })
}
pub async fn try_recv_batch(&mut self) -> Result<PacketBatch<'_>, Error> {
loop {
let self_ptr: *mut Self = self;
let guard = unsafe { (*self_ptr).inner.readable_mut() }
.await
.map_err(Error::Io)?;
let mut guard = guard;
if let Some(batch) = guard.get_inner_mut().next_batch() {
let batch: PacketBatch<'_> = unsafe { std::mem::transmute(batch) };
return Ok(batch);
}
guard.clear_ready();
}
}
pub async fn recv(&mut self) -> Result<Vec<OwnedPacket>, Error> {
loop {
{
let mut guard = self.inner.readable_mut().await.map_err(Error::Io)?;
if let Some(batch) = guard.get_inner_mut().next_batch() {
let packets: Vec<OwnedPacket> = batch.iter().map(|p| p.to_owned()).collect();
return Ok(packets);
}
guard.clear_ready();
}
}
}
pub fn get_ref(&self) -> &S {
self.inner.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.inner.get_mut()
}
pub fn into_inner(self) -> S {
self.inner.into_inner()
}
pub fn into_stream(self) -> PacketStream<S> {
PacketStream::new(self)
}
pub fn stats(&self) -> Result<crate::stats::CaptureStats, Error> {
self.inner.get_ref().stats()
}
pub fn cumulative_stats(&self) -> Result<crate::stats::CaptureStats, Error> {
self.inner.get_ref().cumulative_stats()
}
}
impl<S: PacketSource + AsRawFd> AsFd for AsyncCapture<S> {
    /// Borrows the descriptor that was registered with the reactor.
    ///
    /// Delegates to `AsyncFd`'s own `AsFd` implementation, which needs only
    /// `S: AsRawFd` (the bound this impl declares) rather than requiring the
    /// source itself to implement `AsFd`.
    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
        self.inner.as_fd()
    }
}
impl<S> crate::traits::AsyncPacketSource for AsyncCapture<S>
where
    S: PacketSource + AsRawFd + Send,
{
    /// Trait entry point; forwards to [`AsyncCapture::try_recv_batch`].
    fn next_batch(
        &mut self,
    ) -> impl std::future::Future<Output = Result<crate::packet::PacketBatch<'_>, Error>> + Send
    {
        Self::try_recv_batch(self)
    }
}
/// Adapter exposing an [`AsyncCapture`] as a stream of owned packet vectors.
pub struct PacketStream<S>
where
    S: PacketSource + AsRawFd,
{
    /// The capture being polled.
    cap: AsyncCapture<S>,
}
impl<S: PacketSource + AsRawFd> PacketStream<S> {
pub fn new(cap: AsyncCapture<S>) -> Self {
Self { cap }
}
pub fn into_inner(self) -> AsyncCapture<S> {
self.cap
}
}
impl<S: PacketSource + AsRawFd + Unpin> futures_core::Stream for PacketStream<S> {
type Item = Result<Vec<OwnedPacket>, Error>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.get_mut();
loop {
let mut ready = match this.cap.inner.poll_read_ready_mut(cx) {
std::task::Poll::Ready(Ok(g)) => g,
std::task::Poll::Ready(Err(e)) => {
return std::task::Poll::Ready(Some(Err(Error::Io(e))));
}
std::task::Poll::Pending => return std::task::Poll::Pending,
};
if let Some(batch) = ready.get_inner_mut().next_batch() {
let pkts: Vec<OwnedPacket> = batch.iter().map(|p| p.to_owned()).collect();
return std::task::Poll::Ready(Some(Ok(pkts)));
}
ready.clear_ready();
}
}
}
pub struct ReadableGuard<'a, S: PacketSource + AsRawFd> {
guard: tokio::io::unix::AsyncFdReadyMutGuard<'a, S>,
}
impl<'a, S: PacketSource + AsRawFd> ReadableGuard<'a, S> {
    /// Pulls the next batch from the source while this readiness guard is held.
    ///
    /// Returns `None` when the source has no batch; in that case the guard's
    /// readiness is cleared so that a later readiness wait does not reuse the
    /// stale event.
    pub fn next_batch(&mut self) -> Option<PacketBatch<'_>> {
        // The batch must borrow for the whole `&mut self` lifetime, yet the
        // `None` path needs `self.guard` again; the current borrow checker
        // rejects that for a returned borrow, so detach only the source pointer.
        let source: *mut S = self.guard.get_inner_mut();
        // SAFETY: `source` points at the `S` inside the `AsyncFd` that
        // `self.guard` mutably borrows, and `self` is exclusively borrowed for
        // the returned batch's lifetime, so no other path can reach that `S`
        // while the batch is alive. On the `None` path no batch exists, so the
        // `clear_ready` call below does not alias a live `&mut S`.
        let batch = unsafe { (*source).next_batch() };
        if batch.is_none() {
            self.guard.clear_ready();
        }
        batch
    }

    /// Mutable access to the underlying source while the guard is held.
    pub fn get_inner_mut(&mut self) -> &mut S {
        self.guard.get_inner_mut()
    }
}