use std::sync::{Arc, atomic::{AtomicU64, AtomicBool, Ordering}};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use std::mem::size_of;
use std::ptr;
use memmap2::MmapMut;
use crossbeam_utils::CachePadded;
use anyhow::Result;
use log::{info, warn};
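/// Maximum payload bytes reserved per ring slot; 1500 matches the standard
/// Ethernet MTU assumed by the original slot sizing.
const MAX_PAYLOAD: usize = 1500;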
pub trait UserSpaceNetworking {
fn send_raw_packet(&self, data: &[u8], dst_addr: std::net::SocketAddr) -> Result<()>;
fn receive_raw_packet(&self, buffer: &mut [u8]) -> Result<(usize, std::net::SocketAddr)>;
fn get_network_stats(&self) -> NetworkStats;
}
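// A minimal sketch of how a caller might drive any `UserSpaceNetworking`
// backend. `echo_once` is a hypothetical helper, not part of this module's
// API; note that the RX path below is currently a stub, so the receive call
// is illustrative only.
#[allow(dead_code)]
fn echo_once<N: UserSpaceNetworking>(net: &N, peer: std::net::SocketAddr) -> Result<()> {
    net.send_raw_packet(b"ping", peer)?;
    let mut buf = [0u8; MAX_PAYLOAD];
    let (len, from) = net.receive_raw_packet(&mut buf)?;
    info!("echoed {} bytes from {}", len, from);
    Ok(())
}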
#[derive(Debug, Clone, Default)]
pub struct NetworkStats {
pub packets_sent: u64,
pub packets_received: u64,
pub bytes_sent: u64,
pub bytes_received: u64,
pub send_errors: u64,
pub receive_errors: u64,
pub avg_send_latency_ns: f64,
pub avg_receive_latency_ns: f64,
}
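/// User-space UDP transport that stages packets in memory-mapped rings
/// instead of kernel sockets. In this module the TX drain and RX fill are
/// simulated by background tasks; no real NIC driver is attached.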
pub struct KernelBypassUDP {
interface_name: String,
tx_queue: Arc<TxQueue>,
rx_queue: Arc<RxQueue>,
stats: Arc<CachePadded<AtomicNetworkStats>>,
running: Arc<AtomicBool>,
cpu_affinity: Option<usize>,
}
pub struct AtomicNetworkStats {
pub packets_sent: AtomicU64,
pub packets_received: AtomicU64,
pub bytes_sent: AtomicU64,
pub bytes_received: AtomicU64,
pub send_errors: AtomicU64,
pub receive_errors: AtomicU64,
pub total_send_latency_ns: AtomicU64,
pub total_receive_latency_ns: AtomicU64,
}
impl Default for AtomicNetworkStats {
fn default() -> Self {
Self {
packets_sent: AtomicU64::new(0),
packets_received: AtomicU64::new(0),
bytes_sent: AtomicU64::new(0),
bytes_received: AtomicU64::new(0),
send_errors: AtomicU64::new(0),
receive_errors: AtomicU64::new(0),
total_send_latency_ns: AtomicU64::new(0),
total_receive_latency_ns: AtomicU64::new(0),
}
}
}
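/// Single-producer/single-consumer transmit ring backed by an anonymous
/// mmap. `head` is the producer cursor, `tail` the consumer cursor; both
/// grow without bound and are reduced modulo `capacity` only when computing
/// slot offsets, so `head - tail` is the number of in-flight packets.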
pub struct TxQueue {
ring_buffer: Arc<MmapMut>,
capacity: usize,
head: CachePadded<AtomicU64>,
tail: CachePadded<AtomicU64>,
descriptor_size: usize,
}
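/// Single-producer/single-consumer receive ring with the same cursor
/// discipline as `TxQueue`: the producer advances `head`, the consumer
/// advances `tail`, and slot offsets wrap modulo `capacity`.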
pub struct RxQueue {
ring_buffer: Arc<MmapMut>,
capacity: usize,
head: CachePadded<AtomicU64>,
tail: CachePadded<AtomicU64>,
descriptor_size: usize,
}
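/// Fixed-size header written at the start of each ring slot, immediately
/// followed by up to `MAX_PAYLOAD` bytes of packet data.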
#[repr(C)]
#[derive(Debug, Clone)]
pub struct PacketDescriptor {
pub length: u32,
pub timestamp_ns: u64,
pub dst_addr: u32,
pub dst_port: u16,
pub flags: u16,
pub data_offset: u32,
_padding: [u8; 4],
}
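// With #[repr(C)] the compiler inserts 4 bytes of padding before
// `timestamp_ns`, giving a 32-byte descriptor on x86_64
// (4 + 4 + 8 + 4 + 2 + 2 + 4 + 4); a compile-time check keeps that
// layout assumption honest.
#[cfg(target_arch = "x86_64")]
const _: () = assert!(size_of::<PacketDescriptor>() == 32);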
impl TxQueue {
pub fn new(capacity: usize) -> Result<Self> {
let descriptor_size = size_of::<PacketDescriptor>();
        let entry_size = descriptor_size + MAX_PAYLOAD;
let total_size = capacity * entry_size;
let ring_buffer = Arc::new(MmapMut::map_anon(total_size)?);
info!("📤 Created TX queue: capacity={}, size={}MB",
capacity, total_size / 1024 / 1024);
Ok(Self {
ring_buffer,
capacity,
head: CachePadded::new(AtomicU64::new(0)),
tail: CachePadded::new(AtomicU64::new(0)),
descriptor_size,
})
}
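    /// Stages one datagram into the ring. "Zero copy" here means no kernel
    /// crossing; the payload is still copied once into the mapped buffer.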
#[inline(always)]
pub fn send_packet_zero_copy(&self, data: &[u8], dst_addr: std::net::SocketAddr) -> Result<()> {
        if data.len() > MAX_PAYLOAD {
            return Err(anyhow::anyhow!("Payload too large: {} > {} bytes", data.len(), MAX_PAYLOAD));
        }
        let current_head = self.head.load(Ordering::Relaxed);
        let current_tail = self.tail.load(Ordering::Acquire);
        // head/tail are unbounded counters (only slot offsets wrap), so the
        // ring is full once the producer is a full capacity ahead.
        if current_head.wrapping_sub(current_tail) >= self.capacity as u64 {
            return Err(anyhow::anyhow!("TX queue is full"));
        }
        let entry_size = self.descriptor_size + MAX_PAYLOAD;
        let entry_offset = (current_head % self.capacity as u64) as usize * entry_size;
        // NOTE: writing through a pointer derived from `&MmapMut` relies on
        // head/tail handing each slot to exactly one owner at a time; a
        // hardened version would use interior mutability instead.
        let buffer_ptr = unsafe {
            self.ring_buffer.as_ptr().add(entry_offset)
        };
let descriptor = PacketDescriptor {
length: data.len() as u32,
            // Wall-clock nanoseconds since the Unix epoch; the previous
            // `Instant::now().elapsed()` always measured ~0 ns.
            timestamp_ns: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos() as u64,
dst_addr: match dst_addr.ip() {
std::net::IpAddr::V4(ipv4) => u32::from(ipv4),
_ => return Err(anyhow::anyhow!("Only IPv4 supported")),
},
dst_port: dst_addr.port(),
flags: 0,
data_offset: self.descriptor_size as u32,
_padding: [0; 4],
};
        unsafe {
            // Slot stride is descriptor_size + MAX_PAYLOAD, which is not a
            // multiple of 8, so descriptor addresses may be misaligned:
            // write_unaligned avoids undefined behavior.
            ptr::write_unaligned(buffer_ptr as *mut PacketDescriptor, descriptor);
            let data_ptr = buffer_ptr.add(self.descriptor_size);
            self.fast_memcpy(data_ptr as *mut u8, data.as_ptr(), data.len());
        }
self.head.store(current_head + 1, Ordering::Release);
Ok(())
}
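    /// Copies `len` non-overlapping bytes, in 32-byte AVX chunks when the
    /// `avx` feature is compiled in. Caller must guarantee both ranges are
    /// valid for `len` bytes and do not overlap.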
#[inline(always)]
unsafe fn fast_memcpy(&self, dst: *mut u8, src: *const u8, len: usize) {
if len <= 32 {
ptr::copy_nonoverlapping(src, dst, len);
return;
}
        // AVX intrinsics are only sound when the `avx` target feature is
        // enabled at compile time (e.g. RUSTFLAGS="-C target-cpu=native").
        #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
{
use std::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};
let mut offset = 0;
let chunks = len / 32;
for _ in 0..chunks {
let chunk = _mm256_loadu_si256(src.add(offset) as *const __m256i);
_mm256_storeu_si256(dst.add(offset) as *mut __m256i, chunk);
offset += 32;
}
let remaining = len % 32;
if remaining > 0 {
ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), remaining);
}
}
        #[cfg(not(all(target_arch = "x86_64", target_feature = "avx")))]
{
ptr::copy_nonoverlapping(src, dst, len);
}
}
#[inline(always)]
pub fn pending_packets(&self) -> u64 {
        // Counters are unbounded (only slot offsets wrap), so the backlog is
        // the producer/consumer distance; saturating_sub guards against a
        // stale relaxed read momentarily showing tail ahead of head.
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);
        head.saturating_sub(tail)
}
}
impl RxQueue {
pub fn new(capacity: usize) -> Result<Self> {
let descriptor_size = size_of::<PacketDescriptor>();
        let entry_size = descriptor_size + MAX_PAYLOAD;
let total_size = capacity * entry_size;
let ring_buffer = Arc::new(MmapMut::map_anon(total_size)?);
info!("📥 Created RX queue: capacity={}, size={}MB",
capacity, total_size / 1024 / 1024);
Ok(Self {
ring_buffer,
capacity,
head: CachePadded::new(AtomicU64::new(0)),
tail: CachePadded::new(AtomicU64::new(0)),
descriptor_size,
})
}
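    /// Pops one datagram from the ring into `buffer`. Like the send path,
    /// this makes a single user-space copy and never crosses the kernel.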
#[inline(always)]
pub fn receive_packet_zero_copy(&self, buffer: &mut [u8]) -> Result<(usize, std::net::SocketAddr)> {
let current_tail = self.tail.load(Ordering::Relaxed);
let current_head = self.head.load(Ordering::Acquire);
if current_tail == current_head {
return Err(anyhow::anyhow!("RX queue is empty"));
}
        let entry_size = self.descriptor_size + MAX_PAYLOAD;
let entry_offset = (current_tail % self.capacity as u64) as usize * entry_size;
let buffer_ptr = unsafe {
self.ring_buffer.as_ptr().add(entry_offset)
};
        // Descriptor addresses may be misaligned (slot stride is not a
        // multiple of 8), so read unaligned.
        let descriptor = unsafe {
            ptr::read_unaligned(buffer_ptr as *const PacketDescriptor)
        };
        let data_len = descriptor.length as usize;
        if data_len > MAX_PAYLOAD {
            return Err(anyhow::anyhow!("Corrupt descriptor: length {} exceeds slot payload {}",
                data_len, MAX_PAYLOAD));
        }
        if data_len > buffer.len() {
            return Err(anyhow::anyhow!("Buffer too small: need {}, got {}",
                data_len, buffer.len()));
        }
unsafe {
let data_ptr = buffer_ptr.add(self.descriptor_size);
self.fast_memcpy(buffer.as_mut_ptr(), data_ptr, data_len);
}
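        // The descriptor's destination fields double as the reported source
        // address in this user-space simulation.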
let src_addr = std::net::SocketAddr::new(
std::net::IpAddr::V4(std::net::Ipv4Addr::from(descriptor.dst_addr)),
descriptor.dst_port,
);
self.tail.store(current_tail + 1, Ordering::Release);
Ok((data_len, src_addr))
}
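    /// Identical copy routine to `TxQueue::fast_memcpy`; see that method
    /// for the safety contract.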
#[inline(always)]
unsafe fn fast_memcpy(&self, dst: *mut u8, src: *const u8, len: usize) {
if len <= 32 {
ptr::copy_nonoverlapping(src, dst, len);
return;
}
        // As in TxQueue: AVX intrinsics require the `avx` target feature.
        #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
{
use std::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};
let mut offset = 0;
let chunks = len / 32;
for _ in 0..chunks {
let chunk = _mm256_loadu_si256(src.add(offset) as *const __m256i);
_mm256_storeu_si256(dst.add(offset) as *mut __m256i, chunk);
offset += 32;
}
let remaining = len % 32;
if remaining > 0 {
ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), remaining);
}
}
        #[cfg(not(all(target_arch = "x86_64", target_feature = "avx")))]
{
ptr::copy_nonoverlapping(src, dst, len);
}
}
#[inline(always)]
pub fn available_packets(&self) -> u64 {
        // Same unbounded-counter arithmetic as TxQueue::pending_packets.
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);
        head.saturating_sub(tail)
}
}
impl KernelBypassUDP {
pub fn new(interface_name: String, cpu_affinity: Option<usize>) -> Result<Self> {
info!("🚀 Creating kernel bypass UDP on interface: {}", interface_name);
let tx_queue = Arc::new(TxQueue::new(1_000_000)?);
let rx_queue = Arc::new(RxQueue::new(1_000_000)?);
let instance = Self {
interface_name,
tx_queue,
rx_queue,
stats: Arc::new(CachePadded::new(AtomicNetworkStats::default())),
running: Arc::new(AtomicBool::new(false)),
cpu_affinity,
};
info!("✅ Kernel bypass UDP created successfully");
Ok(instance)
}
pub async fn start(&self) -> Result<()> {
info!("🚀 Starting kernel bypass networking...");
self.running.store(true, Ordering::Relaxed);
self.start_tx_thread().await?;
self.start_rx_thread().await?;
self.start_stats_thread().await;
info!("✅ Kernel bypass networking started");
Ok(())
}
async fn start_tx_thread(&self) -> Result<()> {
let tx_queue = Arc::clone(&self.tx_queue);
let stats = Arc::clone(&self.stats);
let running = Arc::clone(&self.running);
let cpu_affinity = self.cpu_affinity;
tokio::spawn(async move {
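            // Note: this pins whichever tokio worker thread happens to poll
            // the task; a dedicated std::thread would pin more predictably.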
if let Some(cpu_id) = cpu_affinity {
Self::set_thread_cpu_affinity(cpu_id);
}
info!("📤 TX thread started");
while running.load(Ordering::Relaxed) {
let pending = tx_queue.pending_packets();
if pending > 0 {
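                    // Simulated transmit: a real driver would hand each
                    // descriptor to the NIC before retiring it. Here the
                    // slots are counted and retired in one step.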
stats.packets_sent.fetch_add(pending, Ordering::Relaxed);
let current_tail = tx_queue.tail.load(Ordering::Relaxed);
tx_queue.tail.store(current_tail + pending, Ordering::Release);
} else {
tokio::task::yield_now().await;
}
}
info!("📤 TX thread stopped");
});
Ok(())
}
async fn start_rx_thread(&self) -> Result<()> {
let _rx_queue = Arc::clone(&self.rx_queue);
let _stats = Arc::clone(&self.stats);
let running = Arc::clone(&self.running);
let cpu_affinity = self.cpu_affinity.map(|id| id + 1);
tokio::spawn(async move {
if let Some(cpu_id) = cpu_affinity {
Self::set_thread_cpu_affinity(cpu_id);
}
info!("📥 RX thread started");
while running.load(Ordering::Relaxed) {
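                // Stub: a real implementation would poll the NIC RX ring
                // here and publish packets by advancing rx_queue.head.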
tokio::task::yield_now().await;
}
info!("📥 RX thread stopped");
});
Ok(())
}
async fn start_stats_thread(&self) {
let stats = Arc::clone(&self.stats);
let running = Arc::clone(&self.running);
tokio::spawn(async move {
info!("📊 Stats thread started");
let mut interval = tokio::time::interval(Duration::from_secs(5));
while running.load(Ordering::Relaxed) {
interval.tick().await;
let packets_sent = stats.packets_sent.load(Ordering::Relaxed);
let packets_received = stats.packets_received.load(Ordering::Relaxed);
let bytes_sent = stats.bytes_sent.load(Ordering::Relaxed);
let bytes_received = stats.bytes_received.load(Ordering::Relaxed);
if packets_sent > 0 || packets_received > 0 {
info!("🌐 Network Stats: TX: {} pkts, {} bytes | RX: {} pkts, {} bytes",
packets_sent, bytes_sent, packets_received, bytes_received);
}
}
info!("📊 Stats thread stopped");
});
}
#[allow(unused_variables)]
fn set_thread_cpu_affinity(cpu_id: usize) {
#[cfg(target_os = "linux")]
{
use libc::{cpu_set_t, sched_setaffinity, CPU_SET, CPU_ZERO};
unsafe {
let mut cpuset: cpu_set_t = std::mem::zeroed();
CPU_ZERO(&mut cpuset);
CPU_SET(cpu_id, &mut cpuset);
if sched_setaffinity(0, std::mem::size_of::<cpu_set_t>(), &cpuset) == 0 {
info!("✅ Thread bound to CPU {}", cpu_id);
} else {
warn!("⚠️ Failed to bind thread to CPU {}", cpu_id);
}
}
}
#[cfg(not(target_os = "linux"))]
{
info!("💡 CPU affinity not supported on this platform");
}
}
pub async fn stop(&self) -> Result<()> {
info!("🛑 Stopping kernel bypass networking...");
self.running.store(false, Ordering::Relaxed);
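        // Give the TX/RX/stats tasks a moment to observe the cleared flag.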
tokio::time::sleep(Duration::from_millis(100)).await;
info!("✅ Kernel bypass networking stopped");
Ok(())
}
}
impl UserSpaceNetworking for KernelBypassUDP {
fn send_raw_packet(&self, data: &[u8], dst_addr: std::net::SocketAddr) -> Result<()> {
let send_start = Instant::now();
let result = self.tx_queue.send_packet_zero_copy(data, dst_addr);
if result.is_ok() {
let latency_ns = send_start.elapsed().as_nanos() as u64;
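            // Bytes and latency are recorded at submit time; packets_sent
            // is incremented by the TX task when it retires the slot.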
self.stats.bytes_sent.fetch_add(data.len() as u64, Ordering::Relaxed);
self.stats.total_send_latency_ns.fetch_add(latency_ns, Ordering::Relaxed);
} else {
self.stats.send_errors.fetch_add(1, Ordering::Relaxed);
}
result
}
fn receive_raw_packet(&self, buffer: &mut [u8]) -> Result<(usize, std::net::SocketAddr)> {
let receive_start = Instant::now();
let result = self.rx_queue.receive_packet_zero_copy(buffer);
match &result {
Ok((len, _addr)) => {
let latency_ns = receive_start.elapsed().as_nanos() as u64;
self.stats.packets_received.fetch_add(1, Ordering::Relaxed);
self.stats.bytes_received.fetch_add(*len as u64, Ordering::Relaxed);
self.stats.total_receive_latency_ns.fetch_add(latency_ns, Ordering::Relaxed);
}
Err(_) => {
self.stats.receive_errors.fetch_add(1, Ordering::Relaxed);
}
}
result
}
fn get_network_stats(&self) -> NetworkStats {
let packets_sent = self.stats.packets_sent.load(Ordering::Relaxed);
let packets_received = self.stats.packets_received.load(Ordering::Relaxed);
let total_send_latency = self.stats.total_send_latency_ns.load(Ordering::Relaxed);
let total_receive_latency = self.stats.total_receive_latency_ns.load(Ordering::Relaxed);
NetworkStats {
packets_sent,
packets_received,
bytes_sent: self.stats.bytes_sent.load(Ordering::Relaxed),
bytes_received: self.stats.bytes_received.load(Ordering::Relaxed),
send_errors: self.stats.send_errors.load(Ordering::Relaxed),
receive_errors: self.stats.receive_errors.load(Ordering::Relaxed),
avg_send_latency_ns: if packets_sent > 0 {
total_send_latency as f64 / packets_sent as f64
} else {
0.0
},
avg_receive_latency_ns: if packets_received > 0 {
total_receive_latency as f64 / packets_received as f64
} else {
0.0
},
}
}
}
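// A minimal lifecycle sketch, assuming a tokio runtime. "eth0", core 0, and
// the peer address are placeholder values; the RX side is a stub, so only
// the send path moves data.
#[allow(dead_code)]
async fn lifecycle_example() -> Result<()> {
    let udp = KernelBypassUDP::new("eth0".to_string(), Some(0))?;
    udp.start().await?;
    let peer: std::net::SocketAddr = "10.0.0.2:9000".parse()?;
    udp.send_raw_packet(b"payload", peer)?;
    udp.stop().await?;
    Ok(())
}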
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tx_queue_creation() {
let tx_queue = TxQueue::new(1000).unwrap();
assert_eq!(tx_queue.capacity, 1000);
assert_eq!(tx_queue.pending_packets(), 0);
}
#[test]
fn test_rx_queue_creation() {
let rx_queue = RxQueue::new(1000).unwrap();
assert_eq!(rx_queue.capacity, 1000);
assert_eq!(rx_queue.available_packets(), 0);
}
#[tokio::test]
async fn test_kernel_bypass_udp() {
let udp = KernelBypassUDP::new("eth0".to_string(), Some(0)).unwrap();
let stats = udp.get_network_stats();
assert_eq!(stats.packets_sent, 0);
assert_eq!(stats.packets_received, 0);
}
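    // Sketch of the TX fast path under test: enqueue one datagram, check the
    // producer/consumer distance, and confirm the oversized-payload guard.
    // The peer address is a placeholder.
    #[test]
    fn test_tx_enqueue_and_overflow_guard() {
        let tx_queue = TxQueue::new(16).unwrap();
        let peer: std::net::SocketAddr = "127.0.0.1:9000".parse().unwrap();
        tx_queue.send_packet_zero_copy(b"hello", peer).unwrap();
        assert_eq!(tx_queue.pending_packets(), 1);
        // Oversized payloads must be rejected rather than overrun the slot.
        let big = vec![0u8; MAX_PAYLOAD + 1];
        assert!(tx_queue.send_packet_zero_copy(&big, peer).is_err());
    }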
}