use std::os::fd::AsFd;
use std::time::Duration;
use crate::afpacket::rx::{Capture, CaptureBuilder};
use crate::afpacket::tx::{Injector, InjectorBuilder};
use crate::config::RingProfile;
use crate::error::Error;
use crate::packet::Packet;
use crate::stats::CaptureStats;
/// Verdict a bridge filter returns for each captured packet.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BridgeAction {
    /// Copy the packet onto the opposite interface's TX ring.
    Forward,
    /// Discard the packet; nothing is transmitted for it.
    Drop,
}
/// Which way a packet is travelling through the bridge.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BridgeDirection {
    /// Captured on interface A, transmitted on interface B.
    AtoB,
    /// Captured on interface B, transmitted on interface A.
    BtoA,
}
/// Snapshot of the bridge's counters for both directions.
///
/// `a_to_b` / `b_to_a` hold the capture statistics of the ring that feeds that direction
/// (side A's and side B's capture ring respectively). The `*_dropped_*` fields count packets
/// the bridge itself discarded after the filter had chosen to forward them.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct BridgeStats {
    /// Capture statistics of interface A (source of A→B traffic).
    pub a_to_b: CaptureStats,
    /// Capture statistics of interface B (source of B→A traffic).
    pub b_to_a: CaptureStats,
    /// A→B packets longer than B's TX frame capacity.
    pub a_to_b_dropped_too_large: u64,
    /// A→B packets for which no TX slot on B could be allocated.
    pub a_to_b_dropped_ring_full: u64,
    /// B→A packets longer than A's TX frame capacity.
    pub b_to_a_dropped_too_large: u64,
    /// B→A packets for which no TX slot on A could be allocated.
    pub b_to_a_dropped_ring_full: u64,
}
impl std::fmt::Display for BridgeStats {
    /// Renders both directions on one line. Each direction shows its capture statistics
    /// followed by the bridge's own drop counts, and the two directions are separated by ` | `.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            a_to_b,
            b_to_a,
            a_to_b_dropped_too_large,
            a_to_b_dropped_ring_full,
            b_to_a_dropped_too_large,
            b_to_a_dropped_ring_full,
        } = self;
        write!(
            f,
            "A→B: {} dropped(too_large={}, ring_full={}) | B→A: {} dropped(too_large={}, ring_full={})",
            a_to_b,
            a_to_b_dropped_too_large,
            a_to_b_dropped_ring_full,
            b_to_a,
            b_to_a_dropped_too_large,
            b_to_a_dropped_ring_full,
        )
    }
}
/// Bidirectional packet forwarder between two interfaces, "A" and "B".
///
/// Packets captured on A (`rx_a`) are re-transmitted on B (`tx_b`), and packets captured on
/// B (`rx_b`) are re-transmitted on A (`tx_a`), subject to a caller-supplied filter.
/// Construct one with [`Bridge::builder`].
#[must_use]
pub struct Bridge {
    /// Capture ring on interface A; feeds the A→B direction.
    rx_a: Capture,
    /// Injection ring on interface B; transmits A→B traffic.
    tx_b: Injector,
    /// Capture ring on interface B; feeds the B→A direction.
    rx_b: Capture,
    /// Injection ring on interface A; transmits B→A traffic.
    tx_a: Injector,
    /// Timeout for each poll cycle of the run loops.
    poll_timeout: Duration,
    /// Packets the bridge itself discarded, per direction and reason.
    drops: DropCounters,
}
/// Running totals of packets the bridge dropped after the filter chose to forward them.
///
/// These start at zero when the bridge is built and are only ever incremented.
#[derive(Clone, Copy, Debug, Default)]
struct DropCounters {
    /// A→B packets longer than the B-side TX frame capacity.
    a_to_b_too_large: u64,
    /// A→B packets with no free B-side TX slot.
    a_to_b_ring_full: u64,
    /// B→A packets longer than the A-side TX frame capacity.
    b_to_a_too_large: u64,
    /// B→A packets with no free A-side TX slot.
    b_to_a_ring_full: u64,
}
impl Bridge {
pub fn builder() -> BridgeBuilder {
BridgeBuilder::default()
}
pub fn run<F>(&mut self, mut filter: F) -> Result<(), Error>
where
F: FnMut(&Packet<'_>, BridgeDirection) -> BridgeAction,
{
loop {
let [a_ready, b_ready] = self.poll_both(self.poll_timeout)?;
if a_ready {
self.drain_direction(&mut filter, BridgeDirection::AtoB)?;
}
if b_ready {
self.drain_direction(&mut filter, BridgeDirection::BtoA)?;
}
}
}
pub fn run_iterations<F>(&mut self, iterations: usize, mut filter: F) -> Result<(), Error>
where
F: FnMut(&Packet<'_>, BridgeDirection) -> BridgeAction,
{
for _ in 0..iterations {
let [a_ready, b_ready] = self.poll_both(self.poll_timeout)?;
if a_ready {
self.drain_direction(&mut filter, BridgeDirection::AtoB)?;
}
if b_ready {
self.drain_direction(&mut filter, BridgeDirection::BtoA)?;
}
}
Ok(())
}
fn poll_both(&self, timeout: Duration) -> Result<[bool; 2], Error> {
use nix::poll::{PollFd, PollFlags};
let mut pfds = [
PollFd::new(self.rx_a.as_fd(), PollFlags::POLLIN),
PollFd::new(self.rx_b.as_fd(), PollFlags::POLLIN),
];
crate::syscall::poll_eintr_safe(&mut pfds, timeout).map_err(Error::Io)?;
Ok([
pfds[0]
.revents()
.is_some_and(|r| r.contains(PollFlags::POLLIN)),
pfds[1]
.revents()
.is_some_and(|r| r.contains(PollFlags::POLLIN)),
])
}
#[cfg(feature = "tokio")]
pub async fn run_async<F>(&mut self, mut filter: F) -> Result<(), Error>
where
F: FnMut(&Packet<'_>, BridgeDirection) -> BridgeAction,
{
use std::os::fd::{AsRawFd, RawFd};
use tokio::io::Interest;
use tokio::io::unix::AsyncFd;
struct FdHolder(RawFd);
impl AsRawFd for FdHolder {
fn as_raw_fd(&self) -> RawFd {
self.0
}
}
let async_a =
AsyncFd::with_interest(FdHolder(self.rx_a.as_fd().as_raw_fd()), Interest::READABLE)
.map_err(Error::Io)?;
let async_b =
AsyncFd::with_interest(FdHolder(self.rx_b.as_fd().as_raw_fd()), Interest::READABLE)
.map_err(Error::Io)?;
loop {
tokio::select! {
result = async_a.readable() => {
let mut guard = result.map_err(Error::Io)?;
self.drain_direction(&mut filter, BridgeDirection::AtoB)?;
guard.clear_ready();
}
result = async_b.readable() => {
let mut guard = result.map_err(Error::Io)?;
self.drain_direction(&mut filter, BridgeDirection::BtoA)?;
guard.clear_ready();
}
}
}
}
#[cfg(feature = "tokio")]
pub async fn run_iterations_async<F>(
&mut self,
iterations: usize,
mut filter: F,
) -> Result<(), Error>
where
F: FnMut(&Packet<'_>, BridgeDirection) -> BridgeAction,
{
use std::os::fd::{AsRawFd, RawFd};
use tokio::io::Interest;
use tokio::io::unix::AsyncFd;
struct FdHolder(RawFd);
impl AsRawFd for FdHolder {
fn as_raw_fd(&self) -> RawFd {
self.0
}
}
let async_a =
AsyncFd::with_interest(FdHolder(self.rx_a.as_fd().as_raw_fd()), Interest::READABLE)
.map_err(Error::Io)?;
let async_b =
AsyncFd::with_interest(FdHolder(self.rx_b.as_fd().as_raw_fd()), Interest::READABLE)
.map_err(Error::Io)?;
for _ in 0..iterations {
tokio::select! {
result = async_a.readable() => {
let mut guard = result.map_err(Error::Io)?;
self.drain_direction(&mut filter, BridgeDirection::AtoB)?;
guard.clear_ready();
}
result = async_b.readable() => {
let mut guard = result.map_err(Error::Io)?;
self.drain_direction(&mut filter, BridgeDirection::BtoA)?;
guard.clear_ready();
}
_ = tokio::time::sleep(self.poll_timeout) => {
}
}
}
Ok(())
}
fn drain_direction<F>(
&mut self,
filter: &mut F,
direction: BridgeDirection,
) -> Result<(), Error>
where
F: FnMut(&Packet<'_>, BridgeDirection) -> BridgeAction,
{
let (rx, tx, too_large, ring_full) = match direction {
BridgeDirection::AtoB => (
&mut self.rx_a,
&mut self.tx_b,
&mut self.drops.a_to_b_too_large,
&mut self.drops.a_to_b_ring_full,
),
BridgeDirection::BtoA => (
&mut self.rx_b,
&mut self.tx_a,
&mut self.drops.b_to_a_too_large,
&mut self.drops.b_to_a_ring_full,
),
};
let tx_capacity = tx.frame_capacity();
while let Some(batch) = rx.next_batch() {
for pkt in &batch {
if filter(&pkt, direction) != BridgeAction::Forward {
continue;
}
if pkt.len() > tx_capacity {
*too_large += 1;
tracing::warn!(
pkt_len = pkt.len(),
tx_capacity,
"Bridge: dropping packet — exceeds TX frame capacity"
);
continue;
}
match tx.allocate(pkt.len()) {
Some(mut slot) => {
slot.data_mut()[..pkt.len()].copy_from_slice(pkt.data());
slot.set_len(pkt.len());
slot.send();
}
None => {
*ring_full += 1;
tracing::debug!(
pkt_len = pkt.len(),
"Bridge: dropping packet — TX ring full"
);
}
}
}
tx.flush()?;
}
Ok(())
}
pub fn stats(&self) -> Result<BridgeStats, Error> {
Ok(BridgeStats {
a_to_b: self.rx_a.stats()?,
b_to_a: self.rx_b.stats()?,
a_to_b_dropped_too_large: self.drops.a_to_b_too_large,
a_to_b_dropped_ring_full: self.drops.a_to_b_ring_full,
b_to_a_dropped_too_large: self.drops.b_to_a_too_large,
b_to_a_dropped_ring_full: self.drops.b_to_a_ring_full,
})
}
pub fn cumulative_stats(&self) -> Result<BridgeStats, Error> {
Ok(BridgeStats {
a_to_b: self.rx_a.cumulative_stats()?,
b_to_a: self.rx_b.cumulative_stats()?,
a_to_b_dropped_too_large: self.drops.a_to_b_too_large,
a_to_b_dropped_ring_full: self.drops.a_to_b_ring_full,
b_to_a_dropped_too_large: self.drops.b_to_a_too_large,
b_to_a_dropped_ring_full: self.drops.b_to_a_ring_full,
})
}
pub fn into_inner(self) -> BridgeHandles {
BridgeHandles {
rx_a: self.rx_a,
tx_b: self.tx_b,
rx_b: self.rx_b,
tx_a: self.tx_a,
}
}
pub fn handles(&self) -> (&Capture, &Injector, &Capture, &Injector) {
(&self.rx_a, &self.tx_b, &self.rx_b, &self.tx_a)
}
}
/// The four ring handles of a [`Bridge`], as returned by [`Bridge::into_inner`].
#[derive(Debug)]
pub struct BridgeHandles {
    /// Capture ring on interface A.
    pub rx_a: Capture,
    /// Injection ring on interface B.
    pub tx_b: Injector,
    /// Capture ring on interface B.
    pub rx_b: Capture,
    /// Injection ring on interface A.
    pub tx_a: Injector,
}
impl std::fmt::Debug for Bridge {
    /// Shows the capture handles, the poll timeout and the drop counters. The injector
    /// handles are elided, and the trailing `..` marks that fields were omitted rather than
    /// silently dropping them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Bridge")
            .field("rx_a", &self.rx_a)
            .field("rx_b", &self.rx_b)
            .field("poll_timeout", &self.poll_timeout)
            .field("drops", &self.drops)
            .finish_non_exhaustive()
    }
}
/// Step-by-step configuration for a [`Bridge`].
///
/// Both interfaces are required. The ring geometry comes from a [`RingProfile`], and any
/// value can be overridden per side. Unset (`None`) overrides fall back to the profile in
/// [`BridgeBuilder::build`].
#[must_use]
pub struct BridgeBuilder {
    /// Name of interface A (required).
    interface_a: Option<String>,
    /// Name of interface B (required).
    interface_b: Option<String>,
    /// Source of default ring geometry for both sides.
    profile: RingProfile,
    /// Passed to both capture rings.
    promiscuous: bool,
    /// Passed to both injection rings.
    qdisc_bypass: bool,
    /// Timeout of each poll cycle in the bridge's run loops.
    poll_timeout: Duration,
    // Side-A capture ring overrides.
    a_block_size: Option<usize>,
    a_block_count: Option<usize>,
    a_frame_size: Option<usize>,
    a_block_timeout_ms: Option<u32>,
    // Side-B capture ring overrides.
    b_block_size: Option<usize>,
    b_block_count: Option<usize>,
    b_frame_size: Option<usize>,
    b_block_timeout_ms: Option<u32>,
    // Injection ring overrides; frame sizes default to the same side's capture frame size.
    tx_a_frame_size: Option<usize>,
    tx_b_frame_size: Option<usize>,
    tx_a_frame_count: Option<usize>,
    tx_b_frame_count: Option<usize>,
}
impl Default for BridgeBuilder {
    /// Default builder state:
    ///
    /// - no interfaces selected
    /// - [`RingProfile::Default`] ring geometry
    /// - promiscuous mode and qdisc bypass enabled
    /// - a 100 ms poll timeout
    /// - no per-side overrides
    fn default() -> Self {
        const DEFAULT_POLL_TIMEOUT_MS: u64 = 100;
        Self {
            profile: RingProfile::Default,
            promiscuous: true,
            qdisc_bypass: true,
            poll_timeout: Duration::from_millis(DEFAULT_POLL_TIMEOUT_MS),
            interface_a: None,
            interface_b: None,
            // Every ring override starts unset so `build` uses the profile's values.
            a_block_size: None,
            a_block_count: None,
            a_frame_size: None,
            a_block_timeout_ms: None,
            b_block_size: None,
            b_block_count: None,
            b_frame_size: None,
            b_block_timeout_ms: None,
            tx_a_frame_size: None,
            tx_b_frame_size: None,
            tx_a_frame_count: None,
            tx_b_frame_count: None,
        }
    }
}
impl BridgeBuilder {
pub fn interface_a(mut self, name: &str) -> Self {
self.interface_a = Some(name.to_string());
self
}
pub fn interface_b(mut self, name: &str) -> Self {
self.interface_b = Some(name.to_string());
self
}
pub fn profile(mut self, profile: RingProfile) -> Self {
self.profile = profile;
self
}
pub fn promiscuous(mut self, enable: bool) -> Self {
self.promiscuous = enable;
self
}
pub fn qdisc_bypass(mut self, enable: bool) -> Self {
self.qdisc_bypass = enable;
self
}
pub fn poll_timeout(mut self, timeout: Duration) -> Self {
self.poll_timeout = timeout;
self
}
pub fn a_block_size(mut self, bytes: usize) -> Self {
self.a_block_size = Some(bytes);
self
}
pub fn a_block_count(mut self, n: usize) -> Self {
self.a_block_count = Some(n);
self
}
pub fn a_frame_size(mut self, bytes: usize) -> Self {
self.a_frame_size = Some(bytes);
self
}
pub fn a_block_timeout_ms(mut self, ms: u32) -> Self {
self.a_block_timeout_ms = Some(ms);
self
}
pub fn b_block_size(mut self, bytes: usize) -> Self {
self.b_block_size = Some(bytes);
self
}
pub fn b_block_count(mut self, n: usize) -> Self {
self.b_block_count = Some(n);
self
}
pub fn b_frame_size(mut self, bytes: usize) -> Self {
self.b_frame_size = Some(bytes);
self
}
pub fn b_block_timeout_ms(mut self, ms: u32) -> Self {
self.b_block_timeout_ms = Some(ms);
self
}
pub fn tx_a_frame_size(mut self, bytes: usize) -> Self {
self.tx_a_frame_size = Some(bytes);
self
}
pub fn tx_b_frame_size(mut self, bytes: usize) -> Self {
self.tx_b_frame_size = Some(bytes);
self
}
pub fn tx_a_frame_count(mut self, n: usize) -> Self {
self.tx_a_frame_count = Some(n);
self
}
pub fn tx_b_frame_count(mut self, n: usize) -> Self {
self.tx_b_frame_count = Some(n);
self
}
pub fn build(self) -> Result<Bridge, Error> {
let iface_a = self
.interface_a
.ok_or_else(|| Error::Config("interface_a is required".into()))?;
let iface_b = self
.interface_b
.ok_or_else(|| Error::Config("interface_b is required".into()))?;
let (bs, bc, fs, timeout) = self.profile.params();
let a_bs = self.a_block_size.unwrap_or(bs);
let a_bc = self.a_block_count.unwrap_or(bc);
let a_fs = self.a_frame_size.unwrap_or(fs);
let a_to = self.a_block_timeout_ms.unwrap_or(timeout);
let b_bs = self.b_block_size.unwrap_or(bs);
let b_bc = self.b_block_count.unwrap_or(bc);
let b_fs = self.b_frame_size.unwrap_or(fs);
let b_to = self.b_block_timeout_ms.unwrap_or(timeout);
let tx_b_fs = self.tx_b_frame_size.unwrap_or(b_fs);
let tx_a_fs = self.tx_a_frame_size.unwrap_or(a_fs);
let rx_a = CaptureBuilder::default()
.interface(&iface_a)
.block_size(a_bs)
.block_count(a_bc)
.frame_size(a_fs)
.block_timeout_ms(a_to)
.promiscuous(self.promiscuous)
.build()?;
let mut tx_b_builder = InjectorBuilder::default()
.interface(&iface_b)
.frame_size(tx_b_fs)
.qdisc_bypass(self.qdisc_bypass);
if let Some(n) = self.tx_b_frame_count {
tx_b_builder = tx_b_builder.frame_count(n);
}
let tx_b = tx_b_builder.build()?;
let rx_b = CaptureBuilder::default()
.interface(&iface_b)
.block_size(b_bs)
.block_count(b_bc)
.frame_size(b_fs)
.block_timeout_ms(b_to)
.promiscuous(self.promiscuous)
.build()?;
let mut tx_a_builder = InjectorBuilder::default()
.interface(&iface_a)
.frame_size(tx_a_fs)
.qdisc_bypass(self.qdisc_bypass);
if let Some(n) = self.tx_a_frame_count {
tx_a_builder = tx_a_builder.frame_count(n);
}
let tx_a = tx_a_builder.build()?;
Ok(Bridge {
rx_a,
tx_b,
rx_b,
tx_a,
poll_timeout: self.poll_timeout,
drops: DropCounters::default(),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_rejects_missing_a() {
        let err = BridgeBuilder::default()
            .interface_b("lo")
            .build()
            .unwrap_err();
        assert!(matches!(err, Error::Config(_)));
    }

    #[test]
    fn builder_rejects_missing_b() {
        let err = BridgeBuilder::default()
            .interface_a("lo")
            .build()
            .unwrap_err();
        assert!(matches!(err, Error::Config(_)));
    }

    #[test]
    fn builder_defaults() {
        let b = BridgeBuilder::default();
        assert!(b.promiscuous);
        assert!(b.qdisc_bypass);
        assert_eq!(b.profile, RingProfile::Default);
        assert_eq!(b.poll_timeout, Duration::from_millis(100));
    }

    #[test]
    fn bridge_builder_matches_default() {
        let b = Bridge::builder();
        assert_eq!(b.profile, RingProfile::Default);
        assert_eq!(b.poll_timeout, Duration::from_millis(100));
        assert!(b.interface_a.is_none());
        assert!(b.interface_b.is_none());
    }

    #[test]
    fn builder_overrides_unset_by_default() {
        let b = BridgeBuilder::default();
        assert_eq!(b.a_block_size, None);
        assert_eq!(b.a_block_count, None);
        assert_eq!(b.a_frame_size, None);
        assert_eq!(b.a_block_timeout_ms, None);
        assert_eq!(b.b_block_size, None);
        assert_eq!(b.b_block_count, None);
        assert_eq!(b.b_frame_size, None);
        assert_eq!(b.b_block_timeout_ms, None);
        assert_eq!(b.tx_a_frame_size, None);
        assert_eq!(b.tx_b_frame_size, None);
        assert_eq!(b.tx_a_frame_count, None);
        assert_eq!(b.tx_b_frame_count, None);
    }

    #[test]
    fn builder_poll_timeout_setter() {
        let b = BridgeBuilder::default().poll_timeout(Duration::from_millis(25));
        assert_eq!(b.poll_timeout, Duration::from_millis(25));
    }

    #[test]
    fn builder_interface_and_flag_setters() {
        let b = BridgeBuilder::default()
            .interface_a("eth0")
            .interface_b("eth1")
            .promiscuous(false)
            .qdisc_bypass(false)
            .profile(RingProfile::Default);
        assert_eq!(b.interface_a.as_deref(), Some("eth0"));
        assert_eq!(b.interface_b.as_deref(), Some("eth1"));
        assert!(!b.promiscuous);
        assert!(!b.qdisc_bypass);
        assert_eq!(b.profile, RingProfile::Default);
    }

    #[test]
    fn builder_per_direction_overrides_stored() {
        let b = BridgeBuilder::default()
            .a_block_size(1 << 20)
            .a_block_count(8)
            .a_frame_size(4096)
            .a_block_timeout_ms(20)
            .b_block_size(1 << 21)
            .b_frame_size(8192)
            .tx_a_frame_size(2048)
            .tx_b_frame_size(65536)
            .tx_b_frame_count(512);
        assert_eq!(b.a_block_size, Some(1 << 20));
        assert_eq!(b.a_block_count, Some(8));
        assert_eq!(b.a_frame_size, Some(4096));
        assert_eq!(b.a_block_timeout_ms, Some(20));
        assert_eq!(b.b_block_size, Some(1 << 21));
        assert_eq!(b.b_frame_size, Some(8192));
        assert_eq!(b.tx_a_frame_size, Some(2048));
        assert_eq!(b.tx_b_frame_size, Some(65536));
        assert_eq!(b.tx_b_frame_count, Some(512));
    }

    #[test]
    fn builder_remaining_overrides_stored() {
        let b = BridgeBuilder::default()
            .b_block_count(16)
            .b_block_timeout_ms(5)
            .tx_a_frame_count(256);
        assert_eq!(b.b_block_count, Some(16));
        assert_eq!(b.b_block_timeout_ms, Some(5));
        assert_eq!(b.tx_a_frame_count, Some(256));
        // Setting side-B / tx_a values must not leak into the other side's overrides.
        assert_eq!(b.a_block_count, None);
        assert_eq!(b.a_block_timeout_ms, None);
        assert_eq!(b.tx_b_frame_count, None);
    }

    #[test]
    fn bridge_action_eq() {
        assert_eq!(BridgeAction::Forward, BridgeAction::Forward);
        assert_ne!(BridgeAction::Forward, BridgeAction::Drop);
    }

    #[test]
    fn bridge_direction_eq() {
        assert_ne!(BridgeDirection::AtoB, BridgeDirection::BtoA);
    }

    #[test]
    fn bridge_stats_display() {
        let stats = BridgeStats::default();
        let s = stats.to_string();
        assert!(s.contains("A→B"));
        assert!(s.contains("B→A"));
        assert!(s.contains("too_large=0"));
        assert!(s.contains("ring_full=0"));
    }

    #[test]
    fn bridge_stats_drop_counters_display() {
        let stats = BridgeStats {
            a_to_b_dropped_too_large: 7,
            b_to_a_dropped_ring_full: 13,
            ..BridgeStats::default()
        };
        let s = stats.to_string();
        assert!(s.contains("too_large=7"));
        assert!(s.contains("ring_full=13"));
    }

    #[test]
    fn bridge_stats_display_keeps_directions_apart() {
        let stats = BridgeStats {
            a_to_b_dropped_too_large: 1,
            a_to_b_dropped_ring_full: 2,
            b_to_a_dropped_too_large: 3,
            b_to_a_dropped_ring_full: 4,
            ..BridgeStats::default()
        };
        let s = stats.to_string();
        let a_to_b = s
            .find("dropped(too_large=1, ring_full=2)")
            .expect("A→B drop counters missing");
        let b_to_a = s
            .find("dropped(too_large=3, ring_full=4)")
            .expect("B→A drop counters missing");
        assert!(a_to_b < b_to_a, "A→B counters must be rendered before B→A: {s}");
    }
}