use std::cell::Cell;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd};
use std::time::{Duration, Instant};
use crate::afpacket::ring::MmapRing;
use crate::afpacket::{fanout, ffi, filter, ring, socket};
use crate::config::{BpfFilter, BpfInsn, FanoutFlags, FanoutMode, RingProfile, TimestampSource};
use crate::error::Error;
use crate::packet::{BatchIter, Packet, PacketBatch};
use crate::stats::CaptureStats;
use crate::traits::PacketSource;
/// A packet socket together with its memory-mapped receive ring.
///
/// Instances are created through [`Capture::open`] or [`CaptureBuilder::build`].
pub struct Capture {
// Mapped ring. Declared before `fd` on purpose: Rust drops fields in
// declaration order, so the mapping is released before the socket is closed.
ring: MmapRing,
// The packet socket the ring belongs to.
fd: OwnedFd,
// Index of the next ring block to inspect; advanced modulo the block count.
current_block: usize,
// Sequence number expected on the next block; 0 means "no block seen yet".
expected_seq: u64,
// Default per-wait timeout handed to the `Packets` iterators.
poll_timeout: Duration,
// Running totals maintained by `cumulative_stats`; a `Cell` so they can be
// updated through `&self`.
cumulative: Cell<CaptureStats>,
}
impl Capture {
    /// Opens a capture on `interface` with the default [`CaptureBuilder`] settings.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`CaptureBuilder::build`].
    pub fn open(interface: &str) -> Result<Self, Error> {
        Self::builder().interface(interface).build()
    }

    /// Returns a builder initialised with the default settings.
    pub fn builder() -> CaptureBuilder {
        CaptureBuilder::default()
    }

    /// Reads the socket's packet statistics and converts them to [`CaptureStats`].
    ///
    /// NOTE(review): [`Capture::cumulative_stats`] treats each reading as a delta,
    /// which suggests the kernel counters reset on every read (as Linux
    /// `PACKET_STATISTICS` does). If so, counts returned here are never added to the
    /// cumulative total — confirm whether mixing both calls is intended.
    ///
    /// # Errors
    ///
    /// Propagates the error from `socket::get_packet_stats`.
    pub fn stats(&self) -> Result<CaptureStats, Error> {
        let raw = socket::get_packet_stats(self.fd.as_fd())?;
        Ok(CaptureStats::from(raw))
    }

    /// Reads the socket's packet statistics, adds them (saturating) to the running
    /// totals stored in this handle, and returns the updated totals.
    ///
    /// # Errors
    ///
    /// Propagates the error from `socket::get_packet_stats`; the stored totals are
    /// left untouched in that case.
    pub fn cumulative_stats(&self) -> Result<CaptureStats, Error> {
        let delta = socket::get_packet_stats(self.fd.as_fd())?;
        let total = self.cumulative.get();
        let new_total = CaptureStats {
            packets: total.packets.saturating_add(delta.tp_packets),
            drops: total.drops.saturating_add(delta.tp_drops),
            freeze_count: total.freeze_count.saturating_add(delta.tp_freeze_q_cnt),
        };
        self.cumulative.set(new_total);
        Ok(new_total)
    }

    /// Returns a blocking packet iterator with no deadline.
    ///
    /// Each wait uses the capture's configured poll timeout.
    pub fn packets(&mut self) -> Packets<'_> {
        self.packets_with_deadline(None)
    }

    /// Returns a packet iterator that stops once `deadline` has passed.
    pub fn packets_until(&mut self, deadline: Instant) -> Packets<'_> {
        self.packets_with_deadline(Some(deadline))
    }

    /// Returns a packet iterator that stops `total` after the time of this call.
    ///
    /// `Instant + Duration` panics on overflow (for example with `Duration::MAX`),
    /// so the deadline is computed with `checked_add`; a deadline that cannot be
    /// represented is treated as "no deadline".
    pub fn packets_for(&mut self, total: Duration) -> Packets<'_> {
        self.packets_with_deadline(Instant::now().checked_add(total))
    }

    /// Shared constructor for the `packets*` family.
    fn packets_with_deadline(&mut self, deadline: Option<Instant>) -> Packets<'_> {
        let timeout = self.poll_timeout;
        Packets {
            cap: self as *mut Capture,
            timeout,
            deadline,
            batch: None,
            iter: None,
            last_error: None,
            _marker: PhantomData,
        }
    }

    /// Attaches the eBPF program referred to by `prog` as this socket's filter.
    ///
    /// # Errors
    ///
    /// Propagates the error from `filter::attach_ebpf_socket_filter`.
    pub fn attach_ebpf_filter<F: AsFd>(&self, prog: F) -> Result<(), Error> {
        filter::attach_ebpf_socket_filter(self.fd.as_fd(), prog.as_fd())
    }

    /// Attaches the eBPF program referred to by `prog` as the fanout program.
    ///
    /// # Errors
    ///
    /// Propagates the error from `fanout::attach_fanout_ebpf`.
    pub fn attach_fanout_ebpf<F: AsFd>(&self, prog: F) -> Result<(), Error> {
        fanout::attach_fanout_ebpf(self.fd.as_fd(), prog.as_fd())
    }

    /// Detaches the filter currently attached to the socket.
    ///
    /// # Errors
    ///
    /// Propagates the error from `filter::detach_bpf_filter`.
    pub fn detach_filter(&self) -> Result<(), Error> {
        filter::detach_bpf_filter(self.fd.as_fd())
    }

    /// Returns the base address of the mapped ring.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not visible here. Presumably the pointer
    /// is only valid while this `Capture` is alive and the memory may be written
    /// concurrently by the kernel — confirm against `MmapRing`.
    pub unsafe fn ring_ptr(&self) -> *const u8 {
        self.ring.base().as_ptr()
    }

    /// Returns the size of the mapped ring as reported by the ring itself.
    pub fn ring_len(&self) -> usize {
        self.ring.size()
    }

    /// Returns the next block as a batch if it carries `TP_STATUS_USER`, without
    /// blocking.
    ///
    /// Returns `None` when the current block does not carry that status. A mismatch
    /// between the block's sequence number and the expected one is logged as a
    /// warning; the first block seen is never reported.
    pub fn next_batch(&mut self) -> Option<PacketBatch<'_>> {
        let bd = self.ring.block_ptr(self.current_block);
        // SAFETY: `bd` was just obtained from the ring for `current_block`.
        // Assumes `MmapRing::block_ptr` returns a valid block descriptor — TODO confirm.
        let status = unsafe { ring::read_block_status(bd) };
        if status & ffi::TP_STATUS_USER == 0 {
            return None;
        }

        // SAFETY: same pointer as above; the status check has passed, so the header
        // is read only for a block marked `TP_STATUS_USER`.
        let seq = unsafe { (*bd.as_ptr()).hdr.bh1.seq_num };
        if self.expected_seq != 0 && seq != self.expected_seq {
            tracing::warn!(
                expected = self.expected_seq,
                actual = seq,
                dropped = seq.saturating_sub(self.expected_seq),
                "block sequence gap"
            );
        }
        // Wrapping keeps a (theoretical) `u64::MAX` sequence number from panicking
        // in debug builds; the wrapped value 0 simply disables the next gap check.
        self.expected_seq = seq.wrapping_add(1);

        // SAFETY: the block carries `TP_STATUS_USER` (checked above).
        // Assumes that is what `PacketBatch::new` requires — TODO confirm.
        let batch = unsafe { PacketBatch::new(bd) };
        self.current_block = (self.current_block + 1) % self.ring.block_count();
        Some(batch)
    }

    /// Like [`Capture::next_batch`], but polls the socket for up to `timeout` when
    /// the current block is not ready yet.
    ///
    /// Returns `Ok(None)` when no block is ready after the wait.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Io`] when polling fails.
    pub fn next_batch_blocking(
        &mut self,
        timeout: Duration,
    ) -> Result<Option<PacketBatch<'_>>, Error> {
        let ready = {
            let bd = self.ring.block_ptr(self.current_block);
            // SAFETY: see `next_batch`; the pointer comes straight from the ring.
            let status = unsafe { ring::read_block_status(bd) };
            status & ffi::TP_STATUS_USER != 0
        };

        if !ready {
            let mut pfds = [nix::poll::PollFd::new(
                self.fd.as_fd(),
                nix::poll::PollFlags::POLLIN,
            )];
            crate::syscall::poll_eintr_safe(&mut pfds, timeout).map_err(Error::Io)?;
        }

        Ok(self.next_batch())
    }
}
/// Trait adapter: every method forwards to the inherent method of the same name.
impl PacketSource for Capture {
    fn next_batch(&mut self) -> Option<PacketBatch<'_>> {
        // `Self::` paths resolve to the inherent methods, which take priority over
        // the trait methods being defined here.
        Self::next_batch(self)
    }

    fn next_batch_blocking(&mut self, timeout: Duration) -> Result<Option<PacketBatch<'_>>, Error> {
        Self::next_batch_blocking(self, timeout)
    }

    fn stats(&self) -> Result<CaptureStats, Error> {
        Self::stats(self)
    }

    fn cumulative_stats(&self) -> Result<CaptureStats, Error> {
        Self::cumulative_stats(self)
    }
}
/// Debug output lists the ring geometry, the read position and the poll timeout.
impl std::fmt::Debug for Capture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Capture");
        out.field("ring_size", &self.ring.size());
        out.field("block_count", &self.ring.block_count());
        out.field("current_block", &self.current_block);
        out.field("poll_timeout", &self.poll_timeout);
        out.finish()
    }
}
impl AsFd for Capture {
fn as_fd(&self) -> BorrowedFd<'_> {
self.fd.as_fd()
}
}
/// Exposes the capture socket's raw file descriptor number.
impl AsRawFd for Capture {
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        AsRawFd::as_raw_fd(&self.fd)
    }
}
// SAFETY (NOTE(review)): presumably needed because `MmapRing` holds raw pointers
// into the mapping and is therefore not `Send` automatically; the claim is that the
// mapping is owned exclusively by this handle and is not tied to the creating
// thread — confirm against `MmapRing`. The `Cell` field keeps `Capture` `!Sync`.
unsafe impl Send for Capture {}
/// Blocking iterator over the packets of a [`Capture`], created by
/// `Capture::packets`, `Capture::packets_until` or `Capture::packets_for`.
///
/// When iteration stops because of an error, the error is kept and can be
/// retrieved with `take_error`.
pub struct Packets<'cap> {
// Raw pointer to the capture; `_marker` carries the exclusive borrow that keeps
// it valid for `'cap`.
cap: *mut Capture,
// Upper bound for each individual wait for a new batch.
timeout: Duration,
// Optional point in time after which `next` returns `None`.
deadline: Option<Instant>,
// Batch currently being iterated, with its lifetime erased to `'static`.
// `ManuallyDrop` means it is only dropped explicitly (see `drop_batch`).
batch: Option<ManuallyDrop<PacketBatch<'static>>>,
// Iterator over `batch`, lifetime-erased as well; must be cleared before
// `batch` is dropped.
iter: Option<BatchIter<'static>>,
// Error that terminated iteration, if any.
last_error: Option<Error>,
// Ties this value to the `&'cap mut Capture` borrow it was created from.
_marker: PhantomData<&'cap mut Capture>,
}
impl<'cap> Packets<'cap> {
    /// Removes and returns the error that ended iteration, if there was one.
    pub fn take_error(&mut self) -> Option<Error> {
        std::mem::take(&mut self.last_error)
    }

    /// Drops the batch iterator first and then the batch it was iterating.
    fn drop_batch(&mut self) {
        drop(self.iter.take());
        if let Some(held) = self.batch.take() {
            // The batch is stored as `ManuallyDrop`, so it has to be unwrapped to
            // actually run its destructor.
            drop(ManuallyDrop::into_inner(held));
        }
    }
}
/// Yields packets one by one, fetching a new batch from the capture whenever the
/// current one is exhausted.
///
/// NOTE(review): `Item` is `Packet<'cap>`, which is not tied to the borrow of the
/// iterator, while the batch a packet came from is dropped by a later call to
/// `next` (via `drop_batch`). A packet kept across calls can therefore outlive its
/// batch — confirm that callers are not expected to retain packets.
impl<'cap> Iterator for Packets<'cap> {
type Item = Packet<'cap>;
/// Returns the next packet, waiting for a new batch when needed.
///
/// Returns `None` once the deadline (if any) has passed, or when waiting for a
/// batch fails; in the latter case the error is stored in `last_error`.
fn next(&mut self) -> Option<Packet<'cap>> {
loop {
// First drain the batch that is currently held, if any.
if let Some(it) = self.iter.as_mut() {
if let Some(pkt) = it.next() {
// Re-attach the `'cap` lifetime that was erased to `'static` when the
// batch iterator was stored.
let pkt: Packet<'cap> = unsafe { std::mem::transmute(pkt) };
return Some(pkt);
}
// Batch exhausted: drop the iterator and the batch before fetching more.
self.drop_batch();
}
// `_marker` holds the exclusive borrow of the capture for `'cap`, which is what
// this dereference of the raw pointer relies on.
let cap = unsafe { &mut *self.cap };
// Wait for at most `self.timeout`, shortened to the time left until the
// deadline; stop once the deadline lies in the past.
let effective_timeout = match self.deadline {
Some(d) => match d.checked_duration_since(Instant::now()) {
Some(remaining) => remaining.min(self.timeout),
None => return None,
},
None => self.timeout,
};
match cap.next_batch_blocking(effective_timeout) {
Ok(Some(batch)) => {
// Nothing to yield from an empty batch: drop it and fetch the next one.
if batch.is_empty() {
drop(batch);
continue;
}
// Erase the batch lifetime so it can be stored in `self` next to the
// iterator that borrows from it.
let erased_batch: PacketBatch<'static> = unsafe { std::mem::transmute(batch) };
self.batch = Some(ManuallyDrop::new(erased_batch));
let iter: BatchIter<'_> = self.batch.as_ref().unwrap().iter();
// The iterator borrows `self.batch`; `drop_batch` clears the iterator before
// it drops the batch.
let iter_erased: BatchIter<'static> = unsafe { std::mem::transmute(iter) };
self.iter = Some(iter_erased);
}
// No batch was ready after the wait: loop again, which re-checks the deadline.
// Without a deadline this keeps waiting until a batch arrives or an error occurs.
Ok(None) => continue,
Err(e) => {
// Keep the error for `take_error` and end iteration.
self.last_error = Some(e);
return None;
}
}
}
}
}
/// The held batch is stored as `ManuallyDrop`, so it must be dropped explicitly
/// when the iterator goes away.
impl Drop for Packets<'_> {
    fn drop(&mut self) {
        Self::drop_batch(self);
    }
}
/// Collects the settings for a [`Capture`]; `build` performs the actual setup.
#[must_use]
#[derive(Clone)]
pub struct CaptureBuilder {
// Interface name to bind to; `build` fails with a config error when unset.
interface: Option<String>,
// Size of one ring block in bytes; validated at build time (power of two and
// a multiple of the page size).
block_size: usize,
// Requested number of ring blocks; `build` may retry with fewer on ENOMEM.
block_count: usize,
// Frame size in bytes passed in the ring request; must not exceed `block_size`.
frame_size: usize,
// Passed to the ring request as `tp_retire_blk_tov`.
block_timeout_ms: u32,
// When set, `TP_FT_REQ_FILL_RXHASH` is requested in the ring request.
fill_rxhash: bool,
// When set, promiscuous mode is enabled on the bound interface.
promiscuous: bool,
// When set, `socket::set_ignore_outgoing` is applied to the socket.
ignore_outgoing: bool,
// Optional value for `socket::set_busy_poll`.
busy_poll_us: Option<u32>,
// Optional value for `socket::set_prefer_busy_poll`.
prefer_busy_poll: Option<bool>,
// Optional value for `socket::set_busy_poll_budget`.
busy_poll_budget: Option<u16>,
// When set, `socket::set_reuseport` is enabled on the socket.
reuseport: bool,
// Optional receive buffer size in bytes.
rcvbuf: Option<usize>,
// Selects `socket::set_rcvbuf_force` instead of `socket::set_rcvbuf` for `rcvbuf`.
rcvbuf_force: bool,
// Timestamp source; only applied to the socket when it is not `Software`.
timestamp_source: TimestampSource,
// Default per-wait timeout stored in the resulting `Capture`.
poll_timeout: Duration,
// Optional fanout mode and group id to join after binding.
fanout: Option<(FanoutMode, u16)>,
// Flags passed along when joining the fanout group.
fanout_flags: FanoutFlags,
// Optional classic BPF program attached at the end of the build.
bpf_filter: Option<Vec<BpfInsn>>,
}
impl Default for CaptureBuilder {
    /// Default settings: 64 blocks of 4 MiB, 2048-byte frames, a 60 ms block
    /// timeout, rx-hash filling enabled, a 100 ms poll timeout and every optional
    /// socket setting left unset.
    fn default() -> Self {
        const DEFAULT_BLOCK_SIZE: usize = 1 << 22;
        const DEFAULT_BLOCK_COUNT: usize = 64;
        const DEFAULT_FRAME_SIZE: usize = 2048;
        const DEFAULT_BLOCK_TIMEOUT_MS: u32 = 60;
        const DEFAULT_POLL_TIMEOUT: Duration = Duration::from_millis(100);

        Self {
            // Target interface: must be supplied by the caller.
            interface: None,

            // Ring geometry.
            block_size: DEFAULT_BLOCK_SIZE,
            block_count: DEFAULT_BLOCK_COUNT,
            frame_size: DEFAULT_FRAME_SIZE,
            block_timeout_ms: DEFAULT_BLOCK_TIMEOUT_MS,
            fill_rxhash: true,

            // Socket behaviour.
            promiscuous: false,
            ignore_outgoing: false,
            busy_poll_us: None,
            prefer_busy_poll: None,
            busy_poll_budget: None,
            reuseport: false,
            rcvbuf: None,
            rcvbuf_force: false,
            timestamp_source: TimestampSource::default(),
            poll_timeout: DEFAULT_POLL_TIMEOUT,

            // Fanout and filtering.
            fanout: None,
            fanout_flags: FanoutFlags::empty(),
            bpf_filter: None,
        }
    }
}
impl CaptureBuilder {
pub fn interface(mut self, name: &str) -> Self {
self.interface = Some(name.to_string());
self
}
pub fn profile(mut self, profile: RingProfile) -> Self {
let (bs, bc, fs, timeout) = profile.params();
self.block_size = bs;
self.block_count = bc;
self.frame_size = fs;
self.block_timeout_ms = timeout;
self
}
pub fn snap_len(mut self, len: u32) -> Self {
let frame = ffi::tpacket_align(ffi::TPACKET3_HDRLEN + len as usize);
self.frame_size = frame;
self
}
pub fn block_size(mut self, bytes: usize) -> Self {
self.block_size = bytes;
self
}
pub fn block_count(mut self, n: usize) -> Self {
self.block_count = n;
self
}
pub fn frame_size(mut self, bytes: usize) -> Self {
self.frame_size = bytes;
self
}
pub fn block_timeout_ms(mut self, ms: u32) -> Self {
self.block_timeout_ms = ms;
self
}
pub fn fill_rxhash(mut self, enable: bool) -> Self {
self.fill_rxhash = enable;
self
}
pub fn promiscuous(mut self, enable: bool) -> Self {
self.promiscuous = enable;
self
}
pub fn ignore_outgoing(mut self, enable: bool) -> Self {
self.ignore_outgoing = enable;
self
}
pub fn busy_poll_us(mut self, us: u32) -> Self {
self.busy_poll_us = Some(us);
self
}
pub fn prefer_busy_poll(mut self, enable: bool) -> Self {
self.prefer_busy_poll = Some(enable);
self
}
pub fn busy_poll_budget(mut self, budget: u16) -> Self {
self.busy_poll_budget = Some(budget);
self
}
pub fn reuseport(mut self, enable: bool) -> Self {
self.reuseport = enable;
self
}
pub fn rcvbuf(mut self, bytes: usize) -> Self {
self.rcvbuf = Some(bytes);
self
}
pub fn rcvbuf_force(mut self, enable: bool) -> Self {
self.rcvbuf_force = enable;
self
}
pub fn timestamp_source(mut self, source: TimestampSource) -> Self {
self.timestamp_source = source;
self
}
pub fn poll_timeout(mut self, timeout: Duration) -> Self {
self.poll_timeout = timeout;
self
}
pub fn fanout(mut self, mode: FanoutMode, group_id: u16) -> Self {
self.fanout = Some((mode, group_id));
self
}
pub fn fanout_flags(mut self, flags: FanoutFlags) -> Self {
self.fanout_flags = flags;
self
}
pub fn bpf_filter(mut self, insns: Vec<BpfInsn>) -> Self {
self.bpf_filter = Some(insns);
self
}
pub fn build(self) -> Result<Capture, Error> {
let mut current_count = self.block_count;
let min_count = (self.block_count / 4).max(1);
loop {
match build_inner(&self, current_count) {
Ok(cap) => return Ok(cap),
Err(Error::Mmap(ref e)) if is_enomem(e) && current_count > min_count => {
current_count = (current_count * 3 / 4).max(min_count);
tracing::warn!(
"ENOMEM: retrying with {current_count} blocks (was {})",
self.block_count
);
}
Err(Error::SockOpt { ref source, .. })
if is_enomem(source) && current_count > min_count =>
{
current_count = (current_count * 3 / 4).max(min_count);
tracing::warn!(
"ENOMEM: retrying with {current_count} blocks (was {})",
self.block_count
);
}
Err(e) => return Err(e),
}
}
}
}
/// Returns `true` when `e` carries the OS error code `ENOMEM`.
fn is_enomem(e: &std::io::Error) -> bool {
    matches!(e.raw_os_error(), Some(code) if code == libc::ENOMEM)
}
fn build_inner(b: &CaptureBuilder, block_count: usize) -> Result<Capture, Error> {
let interface = b
.interface
.as_deref()
.ok_or_else(|| Error::Config("interface is required".into()))?;
if !b.block_size.is_power_of_two() {
return Err(Error::Config(format!(
"block_size {} is not a power of 2",
b.block_size
)));
}
let page_size = 4096usize;
if b.block_size % page_size != 0 {
return Err(Error::Config(format!(
"block_size {} is not a multiple of PAGE_SIZE ({})",
b.block_size, page_size
)));
}
crate::afpacket::validate_frame_size(b.frame_size)?;
if b.frame_size > b.block_size {
return Err(Error::Config(format!(
"frame_size {} exceeds block_size {}",
b.frame_size, b.block_size
)));
}
if block_count == 0 {
return Err(Error::Config("block_count must be > 0".into()));
}
let frame_nr = (b.block_size / b.frame_size) * block_count;
let fd = socket::create_packet_socket()?;
socket::set_packet_version(fd.as_fd())?;
let mut req: ffi::tpacket_req3 = unsafe { std::mem::zeroed() };
req.tp_block_size = b.block_size as u32;
req.tp_block_nr = block_count as u32;
req.tp_frame_size = b.frame_size as u32;
req.tp_frame_nr = frame_nr as u32;
req.tp_retire_blk_tov = b.block_timeout_ms;
req.tp_sizeof_priv = 0;
req.tp_feature_req_word = if b.fill_rxhash {
ffi::TP_FT_REQ_FILL_RXHASH
} else {
0
};
socket::set_rx_ring(fd.as_fd(), &req)?;
let ring_size = b.block_size * block_count;
let ring = MmapRing::new(fd.as_fd(), ring_size, b.block_size, block_count)?;
let ifindex = socket::resolve_interface(interface)?;
socket::bind_to_interface(fd.as_fd(), ifindex)?;
if b.promiscuous {
socket::set_promiscuous(fd.as_fd(), ifindex)?;
}
if b.ignore_outgoing {
socket::set_ignore_outgoing(fd.as_fd())?;
}
if let Some(us) = b.busy_poll_us {
socket::set_busy_poll(fd.as_fd(), us)?;
}
if let Some(prefer) = b.prefer_busy_poll {
socket::set_prefer_busy_poll(fd.as_fd(), prefer)?;
}
if let Some(budget) = b.busy_poll_budget {
socket::set_busy_poll_budget(fd.as_fd(), budget)?;
}
if b.reuseport {
socket::set_reuseport(fd.as_fd(), true)?;
}
if let Some(bytes) = b.rcvbuf {
if b.rcvbuf_force {
socket::set_rcvbuf_force(fd.as_fd(), bytes)?;
} else {
socket::set_rcvbuf(fd.as_fd(), bytes)?;
}
}
if b.timestamp_source != TimestampSource::Software {
socket::set_timestamp_source(fd.as_fd(), b.timestamp_source)?;
}
if let Some((mode, group_id)) = b.fanout {
fanout::join_fanout(fd.as_fd(), group_id, mode, b.fanout_flags)?;
}
if let Some(insns) = &b.bpf_filter {
let filt = BpfFilter::new(insns.clone());
filter::attach_bpf_filter(fd.as_fd(), &filt)?;
}
Ok(Capture {
ring,
fd,
current_block: 0,
expected_seq: 0,
poll_timeout: b.poll_timeout,
cumulative: Cell::new(CaptureStats::default()),
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `builder` and asserts that the result is an `Error::Config`.
    fn assert_config_error(builder: CaptureBuilder) {
        let err = builder.build().unwrap_err();
        assert!(matches!(err, Error::Config(_)));
    }

    #[test]
    fn builder_rejects_missing_interface() {
        assert_config_error(CaptureBuilder::default());
    }

    #[test]
    fn builder_rejects_bad_block_size() {
        // 3000 is not a power of two.
        assert_config_error(CaptureBuilder::default().interface("lo").block_size(3000));
    }

    #[test]
    fn builder_rejects_bad_frame_size() {
        // Expected to be turned down by the frame-size validation.
        assert_config_error(CaptureBuilder::default().interface("lo").frame_size(100));
    }

    #[test]
    fn builder_rejects_small_frame_size() {
        // Expected to be turned down by the frame-size validation.
        assert_config_error(CaptureBuilder::default().interface("lo").frame_size(32));
    }

    #[test]
    fn builder_rejects_zero_block_count() {
        assert_config_error(CaptureBuilder::default().interface("lo").block_count(0));
    }

    #[test]
    fn builder_defaults() {
        let defaults = CaptureBuilder::default();

        // Ring geometry.
        assert_eq!(defaults.block_size, 1 << 22);
        assert_eq!(defaults.block_count, 64);
        assert_eq!(defaults.frame_size, 2048);
        assert_eq!(defaults.block_timeout_ms, 60);

        // Flags.
        assert!(defaults.fill_rxhash);
        assert!(!defaults.promiscuous);
        assert!(!defaults.ignore_outgoing);

        assert_eq!(defaults.poll_timeout, Duration::from_millis(100));
    }

    #[test]
    fn builder_fill_rxhash_setter() {
        let disabled = CaptureBuilder::default().fill_rxhash(false);
        assert!(!disabled.fill_rxhash);

        let enabled = CaptureBuilder::default().fill_rxhash(true);
        assert!(enabled.fill_rxhash);
    }

    #[test]
    fn builder_poll_timeout_setter() {
        let wanted = Duration::from_millis(25);
        let configured = CaptureBuilder::default().poll_timeout(wanted);
        assert_eq!(configured.poll_timeout, Duration::from_millis(25));
    }

    #[test]
    fn builder_busy_poll_trio_chain() {
        let configured = CaptureBuilder::default()
            .busy_poll_us(50)
            .prefer_busy_poll(true)
            .busy_poll_budget(64);
        assert_eq!(configured.busy_poll_us, Some(50));
        assert_eq!(configured.prefer_busy_poll, Some(true));
        assert_eq!(configured.busy_poll_budget, Some(64));
    }

    #[test]
    fn builder_busy_poll_default_unset() {
        let defaults = CaptureBuilder::default();
        assert_eq!(defaults.busy_poll_us, None);
        assert_eq!(defaults.prefer_busy_poll, None);
        assert_eq!(defaults.busy_poll_budget, None);
    }
}