use bytes::{Bytes, BytesMut};
use dashmap::DashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::net::UdpSocket;
use super::route::{RoutingHeader, ROUTING_HEADER_SIZE};
/// Construction parameters for a `NetProxy`.
#[derive(Debug, Clone)]
pub struct ProxyConfig {
    /// Identifier of this node; packets whose header `dest_id` equals it are
    /// delivered locally instead of forwarded.
    pub local_id: u64,
    /// Address the proxy's UDP socket is bound to.
    pub bind_addr: SocketAddr,
    /// Maximum packet size in bytes (defaults to 65535).
    pub max_packet_size: usize,
    /// Latency-tracking switch (defaults to `true`).
    pub track_latency: bool,
}

impl Default for ProxyConfig {
    /// Node id 0, bound to `0.0.0.0` on port 0 (OS-chosen), 65535-byte
    /// packets, latency tracking enabled.
    fn default() -> Self {
        let unspecified_any_port = SocketAddr::from(([0u8; 4], 0));
        ProxyConfig {
            local_id: 0,
            bind_addr: unspecified_any_port,
            max_packet_size: 65_535,
            track_latency: true,
        }
    }
}

impl ProxyConfig {
    /// Builds a config for node `local_id` bound to `bind_addr`; every other
    /// field takes its `Default` value.
    pub fn new(local_id: u64, bind_addr: SocketAddr) -> Self {
        Self {
            local_id,
            bind_addr,
            ..Self::default()
        }
    }
}
/// Per-destination forwarding counters backed by relaxed atomics, so they can
/// be updated through a shared reference.
#[derive(Debug, Default)]
pub struct HopStats {
    /// Number of packets recorded via `record_forward`.
    pub packets_forwarded: AtomicU64,
    /// Number of packets recorded via `record_drop`.
    pub packets_dropped: AtomicU64,
    /// Sum of the `bytes` arguments passed to `record_forward`.
    pub bytes_forwarded: AtomicU64,
    /// Sum of all non-zero latency samples, in nanoseconds.
    total_latency_ns: AtomicU64,
    /// Number of samples accumulated in `total_latency_ns`.
    latency_samples: AtomicU64,
}

impl HopStats {
    /// Returns a zeroed set of counters (identical to `HopStats::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Counts one forwarded packet of `bytes` bytes.
    ///
    /// A `latency_ns` of zero is treated as "not measured": it contributes
    /// neither to the latency sum nor to the sample count.
    #[inline]
    pub fn record_forward(&self, bytes: u64, latency_ns: u64) {
        let relaxed = Ordering::Relaxed;
        self.packets_forwarded.fetch_add(1, relaxed);
        self.bytes_forwarded.fetch_add(bytes, relaxed);
        if latency_ns == 0 {
            return;
        }
        self.total_latency_ns.fetch_add(latency_ns, relaxed);
        self.latency_samples.fetch_add(1, relaxed);
    }

    /// Counts one dropped packet.
    #[inline]
    pub fn record_drop(&self) {
        self.packets_dropped.fetch_add(1, Ordering::Relaxed);
    }

    /// Mean of the recorded latency samples in nanoseconds (integer
    /// division), or 0 when no sample has been recorded.
    pub fn avg_latency_ns(&self) -> u64 {
        let sample_count = self.latency_samples.load(Ordering::Relaxed);
        self.total_latency_ns
            .load(Ordering::Relaxed)
            .checked_div(sample_count)
            .unwrap_or(0)
    }

    /// Current value of the forwarded-packet counter.
    pub fn forwarded(&self) -> u64 {
        self.packets_forwarded.load(Ordering::Relaxed)
    }

    /// Current value of the dropped-packet counter.
    pub fn dropped(&self) -> u64 {
        self.packets_dropped.load(Ordering::Relaxed)
    }
}
/// Plain-value snapshot of a proxy's global counters.
#[derive(Clone, Debug, Default)]
pub struct ProxyStats {
    /// Packets handed to the proxy for routing, whatever their outcome.
    pub packets_received: u64,
    /// Packets counted as forwarded.
    pub packets_forwarded: u64,
    /// Packets counted as dropped.
    pub packets_dropped: u64,
    /// Bytes counted as forwarded.
    pub bytes_forwarded: u64,
    /// Average forwarding latency in nanoseconds (0 when no samples exist).
    pub avg_latency_ns: u64,
    /// Number of installed routes at snapshot time.
    pub routes: usize,
}
/// Reasons a packet could not be delivered or forwarded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyError {
    /// The packet is shorter than a routing header.
    PacketTooSmall,
    /// The routing header could not be decoded.
    InvalidHeader,
    /// The routing header reports an expired TTL.
    TtlExpired,
    /// No next hop is installed for the destination.
    NoRoute,
    /// Sending to the next hop failed; carries the error's text.
    SendFailed(String),
}

impl std::fmt::Display for ProxyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::SendFailed(reason) => return write!(f, "send failed: {}", reason),
            Self::PacketTooSmall => "packet too small",
            Self::InvalidHeader => "invalid routing header",
            Self::TtlExpired => "TTL expired",
            Self::NoRoute => "no route to destination",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ProxyError {}
/// Outcome of routing a single packet through `NetProxy::forward`.
#[derive(Debug)]
pub enum ForwardResult {
    /// The packet should be sent on to `next_hop`.
    Forwarded {
        /// Next-hop address taken from the proxy's routing table.
        next_hop: SocketAddr,
        /// The packet with its routing header rewritten for the next hop.
        packet: Bytes,
        /// Processing time measured by `forward`, in nanoseconds.
        latency_ns: u64,
    },
    /// The packet is addressed to this node; holds the payload with the
    /// routing header stripped.
    Local(Bytes),
    /// The packet was discarded for the given reason.
    Dropped(ProxyError),
}
/// UDP-based multi-hop forwarder: maps each packet's destination id to a
/// next-hop address and keeps global and per-destination counters.
pub struct NetProxy {
    // Not every field of the stored config is read by this type
    // (e.g. `max_packet_size` is never consulted here), hence the allow.
    #[allow(dead_code)]
    config: ProxyConfig,
    /// Socket bound to `config.bind_addr`, shared via `Arc`.
    socket: Arc<UdpSocket>,
    /// Routing table: destination id -> next-hop address.
    next_hop: DashMap<u64, SocketAddr>,
    /// Copy of `config.local_id`; packets addressed to it are delivered locally.
    local_id: u64,
    /// Per-destination counters, removed together with the route.
    hop_stats: DashMap<u64, HopStats>,
    /// Packets handed to `forward`, whatever their outcome.
    packets_received: AtomicU64,
    /// Packets counted as forwarded.
    packets_forwarded: AtomicU64,
    /// Packets counted as dropped.
    packets_dropped: AtomicU64,
    /// Bytes counted as forwarded.
    bytes_forwarded: AtomicU64,
    /// Sum of recorded forwarding latencies, in nanoseconds.
    total_latency_ns: AtomicU64,
    /// Number of latencies accumulated in `total_latency_ns`.
    latency_samples: AtomicU64,
}
impl NetProxy {
    /// Binds a UDP socket on `config.bind_addr` and returns a proxy with an
    /// empty routing table and zeroed counters.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error if the socket cannot be bound.
    pub async fn new(config: ProxyConfig) -> std::io::Result<Self> {
        let socket = UdpSocket::bind(config.bind_addr).await?;
        let local_id = config.local_id;
        Ok(Self {
            config,
            socket: Arc::new(socket),
            next_hop: DashMap::new(),
            local_id,
            hop_stats: DashMap::new(),
            packets_received: AtomicU64::new(0),
            packets_forwarded: AtomicU64::new(0),
            packets_dropped: AtomicU64::new(0),
            bytes_forwarded: AtomicU64::new(0),
            total_latency_ns: AtomicU64::new(0),
            latency_samples: AtomicU64::new(0),
        })
    }

    /// Returns the local address of the bound socket.
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        self.socket.local_addr()
    }

    /// Installs the next hop for `dest_id`, replacing any existing route.
    pub fn add_route(&self, dest_id: u64, next_hop: SocketAddr) {
        self.next_hop.insert(dest_id, next_hop);
    }

    /// Removes the route for `dest_id` and discards its per-destination
    /// counters, so churned destinations do not accumulate entries.
    pub fn remove_route(&self, dest_id: u64) {
        self.next_hop.remove(&dest_id);
        self.hop_stats.remove(&dest_id);
    }

    /// Returns the next hop configured for `dest_id`, if any.
    pub fn lookup(&self, dest_id: u64) -> Option<SocketAddr> {
        self.next_hop.get(&dest_id).map(|r| *r)
    }

    /// Routes one raw packet without performing any I/O.
    ///
    /// * [`ForwardResult::Local`]: the header's `dest_id` is this node. The
    ///   payload is returned with the routing header stripped.
    /// * [`ForwardResult::Forwarded`]: a route exists and the header is still
    ///   unexpired after `RoutingHeader::forward`. The rewritten packet and its
    ///   next hop are returned.
    /// * [`ForwardResult::Dropped`]: the packet is too short, the header is
    ///   undecodable, the TTL is expired (before or after the hop), or there is
    ///   no route.
    ///
    /// Global counters are always updated. Per-destination drop counters are
    /// only kept for destinations that already have an entry or an installed
    /// route. Packets addressed to arbitrary unknown ids therefore cannot grow
    /// `hop_stats` without bound.
    ///
    /// Latency is measured only when `config.track_latency` is set. Otherwise
    /// `latency_ns` is 0 and no latency sample is recorded.
    pub fn forward(&self, data: Bytes) -> ForwardResult {
        // Skip the clock read entirely when latency tracking is disabled.
        let start = self.config.track_latency.then(Instant::now);
        let len = data.len() as u64;
        self.packets_received.fetch_add(1, Ordering::Relaxed);
        if data.len() < ROUTING_HEADER_SIZE {
            self.packets_dropped.fetch_add(1, Ordering::Relaxed);
            return ForwardResult::Dropped(ProxyError::PacketTooSmall);
        }
        let header = match RoutingHeader::from_bytes(&data[..ROUTING_HEADER_SIZE]) {
            Some(h) => h,
            None => {
                self.packets_dropped.fetch_add(1, Ordering::Relaxed);
                return ForwardResult::Dropped(ProxyError::InvalidHeader);
            }
        };
        let dest_id = header.dest_id;
        if dest_id == self.local_id {
            return ForwardResult::Local(data.slice(ROUTING_HEADER_SIZE..));
        }
        if header.is_expired() {
            return self.drop_packet(dest_id, ProxyError::TtlExpired);
        }
        let next_hop = match self.lookup(dest_id) {
            Some(addr) => addr,
            None => return self.drop_packet(dest_id, ProxyError::NoRoute),
        };
        let mut new_header = header;
        new_header.forward();
        // The hop itself may exhaust the TTL; such a packet must not go out.
        if new_header.is_expired() {
            return self.drop_packet(dest_id, ProxyError::TtlExpired);
        }
        let mut fwd_data = BytesMut::with_capacity(data.len());
        new_header.write_to(&mut fwd_data);
        fwd_data.extend_from_slice(&data[ROUTING_HEADER_SIZE..]);
        let latency_ns = start.map_or(0, |t| {
            u64::try_from(t.elapsed().as_nanos()).unwrap_or(u64::MAX)
        });
        self.record_forward(dest_id, len, latency_ns);
        ForwardResult::Forwarded {
            next_hop,
            packet: fwd_data.freeze(),
            latency_ns,
        }
    }

    /// Runs [`Self::forward`] and, for a forwarded packet, sends it to the
    /// next hop over the proxy's socket.
    ///
    /// # Errors
    ///
    /// * Returns the drop reason for dropped packets.
    /// * Returns [`ProxyError::SendFailed`] when the send fails. In that case
    ///   the forward is un-counted, both globally and per destination, and the
    ///   packet is counted as dropped instead.
    pub async fn forward_and_send(&self, data: Bytes) -> Result<ForwardResult, ProxyError> {
        // `forward` counts the *input* length as forwarded bytes; remember it
        // so a failed send rolls back exactly what was added.
        let in_len = data.len() as u64;
        match self.forward(data) {
            ForwardResult::Forwarded {
                next_hop,
                packet,
                latency_ns,
            } => match self.socket.send_to(&packet, next_hop).await {
                Ok(_) => Ok(ForwardResult::Forwarded {
                    next_hop,
                    packet,
                    latency_ns,
                }),
                Err(e) => {
                    // The header was written by `forward`, so decoding it
                    // recovers the destination whose counters were bumped.
                    let dest_id = packet
                        .get(..ROUTING_HEADER_SIZE)
                        .and_then(|raw| RoutingHeader::from_bytes(raw))
                        .map(|h| h.dest_id);
                    self.undo_forward(dest_id, in_len, latency_ns);
                    self.packets_dropped.fetch_add(1, Ordering::Relaxed);
                    if let Some(id) = dest_id {
                        self.record_hop_drop(id);
                    }
                    Err(ProxyError::SendFailed(e.to_string()))
                }
            },
            ForwardResult::Local(payload) => Ok(ForwardResult::Local(payload)),
            ForwardResult::Dropped(e) => Err(e),
        }
    }

    /// Sends raw bytes to `dest` over the proxy's socket.
    pub async fn send_to(&self, data: &[u8], dest: SocketAddr) -> std::io::Result<usize> {
        self.socket.send_to(data, dest).await
    }

    /// Receives one datagram into `buf`, returning its length and sender.
    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
        self.socket.recv_from(buf).await
    }

    /// Returns a snapshot of the global counters and the current route count.
    pub fn stats(&self) -> ProxyStats {
        let samples = self.latency_samples.load(Ordering::Relaxed);
        let avg_latency = self
            .total_latency_ns
            .load(Ordering::Relaxed)
            .checked_div(samples)
            .unwrap_or(0);
        ProxyStats {
            packets_received: self.packets_received.load(Ordering::Relaxed),
            packets_forwarded: self.packets_forwarded.load(Ordering::Relaxed),
            packets_dropped: self.packets_dropped.load(Ordering::Relaxed),
            bytes_forwarded: self.bytes_forwarded.load(Ordering::Relaxed),
            avg_latency_ns: avg_latency,
            routes: self.next_hop.len(),
        }
    }

    /// Zeroes the global counters; per-destination `hop_stats` are untouched.
    pub fn reset_stats(&self) {
        self.packets_received.store(0, Ordering::Relaxed);
        self.packets_forwarded.store(0, Ordering::Relaxed);
        self.packets_dropped.store(0, Ordering::Relaxed);
        self.bytes_forwarded.store(0, Ordering::Relaxed);
        self.total_latency_ns.store(0, Ordering::Relaxed);
        self.latency_samples.store(0, Ordering::Relaxed);
    }

    /// Returns `(forwarded, dropped, avg_latency_ns)` for `dest_id`, or `None`
    /// if no counters are kept for it.
    pub fn hop_stats(&self, dest_id: u64) -> Option<(u64, u64, u64)> {
        self.hop_stats
            .get(&dest_id)
            .map(|s| (s.forwarded(), s.dropped(), s.avg_latency_ns()))
    }

    /// Counts a drop globally and per destination, and wraps `err` as the
    /// drop reason.
    fn drop_packet(&self, dest_id: u64, err: ProxyError) -> ForwardResult {
        self.packets_dropped.fetch_add(1, Ordering::Relaxed);
        self.record_hop_drop(dest_id);
        ForwardResult::Dropped(err)
    }

    /// Counts a forwarded packet globally and per destination. A zero
    /// `latency_ns` means "not measured" and adds no sample, matching
    /// `HopStats::record_forward`.
    fn record_forward(&self, dest_id: u64, bytes: u64, latency_ns: u64) {
        self.packets_forwarded.fetch_add(1, Ordering::Relaxed);
        self.bytes_forwarded.fetch_add(bytes, Ordering::Relaxed);
        if latency_ns > 0 {
            self.total_latency_ns.fetch_add(latency_ns, Ordering::Relaxed);
            self.latency_samples.fetch_add(1, Ordering::Relaxed);
        }
        self.record_hop_forward(dest_id, bytes, latency_ns);
    }

    /// Reverses [`Self::record_forward`] for a packet that was never sent.
    /// Subtraction saturates because `reset_stats` may have zeroed the
    /// counters in the meantime.
    fn undo_forward(&self, dest_id: Option<u64>, bytes: u64, latency_ns: u64) {
        Self::saturating_decrement(&self.packets_forwarded, 1);
        Self::saturating_decrement(&self.bytes_forwarded, bytes);
        if latency_ns > 0 {
            Self::saturating_decrement(&self.total_latency_ns, latency_ns);
            Self::saturating_decrement(&self.latency_samples, 1);
        }
        if let Some(id) = dest_id {
            // The guard is released at the end of this `if let`, before any
            // caller touches `hop_stats` again.
            if let Some(hop) = self.hop_stats.get(&id) {
                Self::saturating_decrement(&hop.packets_forwarded, 1);
                Self::saturating_decrement(&hop.bytes_forwarded, bytes);
                if latency_ns > 0 {
                    Self::saturating_decrement(&hop.total_latency_ns, latency_ns);
                    Self::saturating_decrement(&hop.latency_samples, 1);
                }
            }
        }
    }

    /// Atomically subtracts `amount` from `counter`, clamping at zero.
    fn saturating_decrement(counter: &AtomicU64, amount: u64) {
        // The closure always returns `Some`, so `fetch_update` cannot fail.
        let _ = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
            Some(v.saturating_sub(amount))
        });
    }

    /// Adds a forwarded packet to `dest_id`'s counters, creating them if needed.
    fn record_hop_forward(&self, dest_id: u64, bytes: u64, latency_ns: u64) {
        self.hop_stats
            .entry(dest_id)
            .or_default()
            .record_forward(bytes, latency_ns);
    }

    /// Adds a drop to `dest_id`'s counters.
    ///
    /// An entry is only created when a route for `dest_id` is installed.
    /// Drops for unknown destinations are reflected solely in the global
    /// counter, so untrusted traffic cannot create unbounded entries.
    fn record_hop_drop(&self, dest_id: u64) {
        if let Some(hop) = self.hop_stats.get(&dest_id) {
            hop.record_drop();
            return;
        }
        if self.next_hop.contains_key(&dest_id) {
            self.hop_stats.entry(dest_id).or_default().record_drop();
        }
    }

    /// Number of installed routes.
    pub fn route_count(&self) -> usize {
        self.next_hop.len()
    }
}
impl std::fmt::Debug for NetProxy {
    /// Compact diagnostic view: the node id as 16 zero-padded hex digits, the
    /// route count, and the forwarded-packet total.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id_hex = format!("{:016x}", self.local_id);
        let route_count = self.next_hop.len();
        let forwarded = self.packets_forwarded.load(Ordering::Relaxed);

        let mut view = f.debug_struct("NetProxy");
        view.field("local_id", &id_hex);
        view.field("routes", &route_count);
        view.field("packets_forwarded", &forwarded);
        view.finish()
    }
}
pub struct MultiHopPacketBuilder {
src_id: u32,
}
impl MultiHopPacketBuilder {
pub fn new(src_id: u32) -> Self {
Self { src_id }
}
pub fn build(&self, dest_id: u64, ttl: u8, payload: &[u8]) -> Bytes {
let mut buf = BytesMut::with_capacity(ROUTING_HEADER_SIZE + payload.len());
let header = RoutingHeader::new(dest_id, self.src_id, ttl);
header.write_to(&mut buf);
buf.extend_from_slice(payload);
buf.freeze()
}
pub fn build_priority(&self, dest_id: u64, ttl: u8, payload: &[u8]) -> Bytes {
let mut buf = BytesMut::with_capacity(ROUTING_HEADER_SIZE + payload.len());
let header = RoutingHeader::priority(dest_id, self.src_id, ttl);
header.write_to(&mut buf);
buf.extend_from_slice(payload);
buf.freeze()
}
pub fn build_control(&self, dest_id: u64, ttl: u8, payload: &[u8]) -> Bytes {
let mut buf = BytesMut::with_capacity(ROUTING_HEADER_SIZE + payload.len());
let header = RoutingHeader::control(dest_id, self.src_id, ttl);
header.write_to(&mut buf);
buf.extend_from_slice(payload);
buf.freeze()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Node id given to every proxy constructed by these tests.
    const LOCAL: u64 = 0x1234;

    /// Binds a proxy with id `LOCAL` on an ephemeral loopback port.
    async fn make_proxy() -> NetProxy {
        let bind: SocketAddr = "127.0.0.1:0".parse().unwrap();
        NetProxy::new(ProxyConfig::new(LOCAL, bind)).await.unwrap()
    }

    /// Fixed loopback address used as a next hop (never actually contacted).
    fn hop_addr() -> SocketAddr {
        "127.0.0.1:9001".parse().unwrap()
    }

    /// Decodes the routing header at the front of `packet`.
    fn header_of(packet: &[u8]) -> RoutingHeader {
        RoutingHeader::from_bytes(&packet[..ROUTING_HEADER_SIZE]).unwrap()
    }

    #[test]
    fn test_forward_result() {
        let packet = MultiHopPacketBuilder::new(0xABCD).build(0x1234, 8, b"hello world");
        assert_eq!(packet.len(), ROUTING_HEADER_SIZE + 11);
        let header = header_of(&packet);
        assert_eq!(header.dest_id, 0x1234);
        assert_eq!(header.src_id, 0xABCD);
        assert_eq!(header.ttl, 8);
        assert_eq!(header.hop_count, 0);
    }

    #[test]
    fn test_priority_packet() {
        let packet = MultiHopPacketBuilder::new(0x1111).build_priority(0x2222, 4, b"urgent");
        assert!(header_of(&packet).flags.is_priority());
    }

    #[test]
    fn test_control_packet() {
        let packet = MultiHopPacketBuilder::new(0x1111).build_control(0x2222, 2, b"ping");
        assert!(header_of(&packet).flags.is_control());
    }

    #[tokio::test]
    async fn test_proxy_creation() {
        let proxy = make_proxy().await;
        assert_eq!(proxy.route_count(), 0);
        assert_eq!(proxy.stats().packets_received, 0);
    }

    #[tokio::test]
    async fn test_proxy_routing() {
        let proxy = make_proxy().await;
        let dest = hop_addr();
        proxy.add_route(0x5678, dest);
        assert_eq!(proxy.lookup(0x5678), Some(dest));
        assert_eq!(proxy.lookup(0x9999), None);
    }

    #[tokio::test]
    async fn test_proxy_forward() {
        let proxy = make_proxy().await;
        let next_hop = hop_addr();
        proxy.add_route(0x5678, next_hop);
        let packet = MultiHopPacketBuilder::new(0xABCD).build(0x5678, 8, b"test payload");
        if let ForwardResult::Forwarded { next_hop: addr, .. } = proxy.forward(packet) {
            assert_eq!(addr, next_hop);
        } else {
            panic!("expected forwarded");
        }
        let stats = proxy.stats();
        assert_eq!(stats.packets_received, 1);
        assert_eq!(stats.packets_forwarded, 1);
        assert_eq!(stats.packets_dropped, 0);
    }

    #[tokio::test]
    async fn test_proxy_local_delivery() {
        let proxy = make_proxy().await;
        let packet = MultiHopPacketBuilder::new(0xABCD).build(LOCAL, 8, b"local payload");
        match proxy.forward(packet) {
            ForwardResult::Local(payload) => assert_eq!(&payload[..], b"local payload"),
            _ => panic!("expected local delivery"),
        }
    }

    #[tokio::test]
    async fn test_proxy_ttl_expired() {
        let proxy = make_proxy().await;
        proxy.add_route(0x5678, hop_addr());
        let packet = MultiHopPacketBuilder::new(0xABCD).build(0x5678, 0, b"expired");
        assert!(
            matches!(
                proxy.forward(packet),
                ForwardResult::Dropped(ProxyError::TtlExpired)
            ),
            "expected TTL expired"
        );
        assert_eq!(proxy.stats().packets_dropped, 1);
    }

    #[tokio::test]
    async fn test_proxy_no_route() {
        let proxy = make_proxy().await;
        let packet = MultiHopPacketBuilder::new(0xABCD).build(0x9999, 8, b"no route");
        assert!(
            matches!(
                proxy.forward(packet),
                ForwardResult::Dropped(ProxyError::NoRoute)
            ),
            "expected no route"
        );
    }

    #[tokio::test]
    async fn remove_route_also_drops_hop_stats() {
        let proxy = make_proxy().await;
        proxy.add_route(0x5678, hop_addr());
        proxy.record_hop_forward(0x5678, 100, 1000);
        proxy.record_hop_drop(0x5678);
        assert!(
            proxy.hop_stats(0x5678).is_some(),
            "hop_stats must be present after recording activity"
        );
        proxy.remove_route(0x5678);
        assert!(
            proxy.hop_stats(0x5678).is_none(),
            "hop_stats entry must be dropped along with the route — \
             pre-fix this leaked memory linearly with churned destinations"
        );
    }

    #[test]
    fn test_hop_stats() {
        let stats = HopStats::new();
        for (bytes, latency) in [(100, 1000), (200, 2000)] {
            stats.record_forward(bytes, latency);
        }
        stats.record_drop();
        assert_eq!(stats.forwarded(), 2);
        assert_eq!(stats.dropped(), 1);
        assert_eq!(stats.avg_latency_ns(), 1500);
    }

    #[tokio::test]
    async fn test_forward_returns_packet_data() {
        let proxy = make_proxy().await;
        let next_hop = hop_addr();
        proxy.add_route(0x5678, next_hop);
        let packet = MultiHopPacketBuilder::new(0xABCD).build(0x5678, 4, b"payload");
        match proxy.forward(packet) {
            ForwardResult::Forwarded {
                next_hop: addr,
                packet: fwd_packet,
                ..
            } => {
                assert_eq!(addr, next_hop);
                let header = header_of(&fwd_packet);
                assert_eq!(header.ttl, 3, "TTL should be decremented");
                assert_eq!(header.hop_count, 1, "hop_count should be incremented");
                assert_eq!(&fwd_packet[ROUTING_HEADER_SIZE..], b"payload");
            }
            _ => panic!("expected forwarded"),
        }
    }
}