use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::atomic::{AtomicU16, Ordering};
use std::time::{Duration, Instant};
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use super::packet::{FlowTuple, IpPacket, IpPacketMut, TransportProtocol};
use crate::error::{Error, Result};
/// Configuration for [`NatTable`].
///
/// Every field has a serde default, so a partially specified config
/// deserializes cleanly. Durations are (de)serialized through `humantime_serde`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NatConfig {
/// Address used as the translated source for IPv4 flows. Inbound IPv4 traffic
/// must be addressed to it to be un-NATed. Default `10.0.85.1`.
#[serde(default = "default_tunnel_ipv4")]
pub tunnel_ipv4: Ipv4Addr,
/// IPv6 counterpart of `tunnel_ipv4`. When `None`, no IPv6 destination is
/// treated as the tunnel address on inbound traffic.
#[serde(default)]
pub tunnel_ipv6: Option<Ipv6Addr>,
/// Server-side tunnel address, default `10.0.85.2`.
/// NOTE(review): not read by `NatTable` in this file.
#[serde(default = "default_server_ipv4")]
pub server_ipv4: Ipv4Addr,
/// Optional IPv6 server-side tunnel address.
/// NOTE(review): not read by `NatTable` in this file.
#[serde(default)]
pub server_ipv6: Option<Ipv6Addr>,
/// Idle timeout for UDP entries. Also used for protocols that have no
/// dedicated timeout. Default 180 s.
#[serde(default = "default_udp_timeout", with = "humantime_serde")]
pub udp_timeout: Duration,
/// Idle timeout for TCP entries. Default 7200 s (2 h).
#[serde(default = "default_tcp_timeout", with = "humantime_serde")]
pub tcp_timeout: Duration,
/// Idle timeout for ICMP/ICMPv6 entries. Default 60 s.
#[serde(default = "default_icmp_timeout", with = "humantime_serde")]
pub icmp_timeout: Duration,
/// First port of the range that translated source ports are drawn from.
/// Default 32768.
#[serde(default = "default_port_start")]
pub port_range_start: u16,
/// End of the translated-port range. The tests in this file expect allocated
/// ports to stay strictly below it, i.e. it is treated as exclusive.
/// Default 61000.
#[serde(default = "default_port_end")]
pub port_range_end: u16,
/// Hairpin NAT toggle, default `false`.
/// NOTE(review): not consulted by `NatTable` in this file.
#[serde(default)]
pub hairpin_nat: bool,
}
/// Fallback for `NatConfig::tunnel_ipv4`: `10.0.85.1`.
fn default_tunnel_ipv4() -> Ipv4Addr {
    Ipv4Addr::new(10, 0, 85, 1)
}

/// Fallback for `NatConfig::server_ipv4`: `10.0.85.2`.
fn default_server_ipv4() -> Ipv4Addr {
    Ipv4Addr::new(10, 0, 85, 2)
}
/// Fallback idle timeout for UDP mappings: three minutes.
fn default_udp_timeout() -> Duration {
    Duration::from_secs(3 * 60)
}

/// Fallback idle timeout for TCP mappings: two hours.
fn default_tcp_timeout() -> Duration {
    Duration::from_secs(2 * 60 * 60)
}

/// Fallback idle timeout for ICMP mappings: one minute.
fn default_icmp_timeout() -> Duration {
    Duration::from_millis(60_000)
}

/// Fallback first port of the translated-port range (2^15 = 32768).
fn default_port_start() -> u16 {
    1 << 15
}

/// Fallback end of the translated-port range.
fn default_port_end() -> u16 {
    61_000
}
impl Default for NatConfig {
fn default() -> Self {
Self {
tunnel_ipv4: default_tunnel_ipv4(),
tunnel_ipv6: None,
server_ipv4: default_server_ipv4(),
server_ipv6: None,
udp_timeout: default_udp_timeout(),
tcp_timeout: default_tcp_timeout(),
icmp_timeout: default_icmp_timeout(),
port_range_start: default_port_start(),
port_range_end: default_port_end(),
hairpin_nat: false,
}
}
}
/// State for one translated flow, as stored in `NatTable`'s outbound map.
#[derive(Debug, Clone)]
pub struct NatEntry {
/// Source address of the local host that opened the flow.
pub original_src: IpAddr,
/// Source port the local host used.
pub original_src_port: u16,
/// Address written into outbound packets as the new source address.
pub translated_src: IpAddr,
/// Port written into outbound packets as the new source port.
pub translated_src_port: u16,
/// Destination (remote) address of the flow.
pub dst_addr: IpAddr,
/// Destination (remote) port of the flow.
pub dst_port: u16,
/// Transport protocol of the flow.
pub protocol: TransportProtocol,
/// When the entry was created.
pub created_at: Instant,
/// Last time a packet matched this entry (outbound or inbound).
/// Expiry is measured from this instant.
pub last_used: Instant,
/// Number of packets that matched this entry, in both directions.
pub packet_count: u64,
/// Sum of the lengths of the packet buffers that matched this entry,
/// in both directions.
pub byte_count: u64,
}
impl NatEntry {
fn new(
original_src: IpAddr,
original_src_port: u16,
translated_src: IpAddr,
translated_src_port: u16,
dst_addr: IpAddr,
dst_port: u16,
protocol: TransportProtocol,
) -> Self {
let now = Instant::now();
Self {
original_src,
original_src_port,
translated_src,
translated_src_port,
dst_addr,
dst_port,
protocol,
created_at: now,
last_used: now,
packet_count: 0,
byte_count: 0,
}
}
fn touch(&mut self, bytes: usize) {
self.last_used = Instant::now();
self.packet_count += 1;
self.byte_count += bytes as u64;
}
fn is_expired(&self, timeout: Duration) -> bool {
self.last_used.elapsed() > timeout
}
}
/// Forward-lookup key: the untranslated flow as sent by the local host.
/// `protocol` holds the value of `TransportProtocol::protocol_number()`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct OutboundKey {
src_addr: IpAddr,
src_port: u16,
dst_addr: IpAddr,
dst_port: u16,
protocol: u8,
}
/// Reverse-lookup key for reply traffic. It combines:
/// - the translated port the reply is addressed to;
/// - the remote endpoint it comes from (the original flow's destination);
/// - the value of `TransportProtocol::protocol_number()`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct InboundKey {
translated_port: u16,
src_addr: IpAddr,
src_port: u16,
protocol: u8,
}
/// Source-NAT table mapping local flows to (tunnel address, translated port)
/// pairs and back.
///
/// State lives in `DashMap`s and an `RwLock`, so every method takes `&self`.
pub struct NatTable {
/// Configuration passed to `NatTable::new`.
config: NatConfig,
/// Forward map: original 5-tuple -> mapping state.
outbound: DashMap<OutboundKey, NatEntry>,
/// Reverse map for reply traffic: (translated port, remote endpoint,
/// protocol) -> key of the corresponding forward entry.
inbound: DashMap<InboundKey, OutboundKey>,
/// Cursor for the next candidate translated port.
next_port: AtomicU16,
/// Aggregate traffic counters.
stats: RwLock<NatStats>,
}
/// Aggregate NAT counters, returned as a snapshot by `NatTable::stats`.
#[derive(Debug, Clone, Default)]
pub struct NatStats {
/// Number of outbound mappings. Only filled in on the snapshot returned by
/// `NatTable::stats`.
pub active_entries: usize,
/// Number of outbound packets translated.
pub total_translations: u64,
/// Outbound packets translated.
pub packets_outbound: u64,
/// Inbound packets translated.
pub packets_inbound: u64,
/// Total length of translated outbound packet buffers.
pub bytes_outbound: u64,
/// Total length of translated inbound packet buffers.
pub bytes_inbound: u64,
}
impl NatTable {
pub fn new(config: NatConfig) -> Self {
Self {
next_port: AtomicU16::new(config.port_range_start),
config,
outbound: DashMap::new(),
inbound: DashMap::new(),
stats: RwLock::new(NatStats::default()),
}
}
pub fn config(&self) -> &NatConfig {
&self.config
}
pub fn stats(&self) -> NatStats {
let mut stats = self.stats.read().clone();
stats.active_entries = self.outbound.len();
stats
}
pub fn translate_outbound(&self, packet: &mut [u8]) -> Result<Option<NatEntry>> {
let (flow, header_len, protocol) = {
let parsed = IpPacket::parse(packet)?;
(parsed.flow_tuple(), parsed.header_len, parsed.protocol)
};
if !self.is_local_address(flow.src_addr) {
return Ok(None);
}
let key = OutboundKey {
src_addr: flow.src_addr,
src_port: flow.src_port,
dst_addr: flow.dst_addr,
dst_port: flow.dst_port,
protocol: flow.protocol.protocol_number(),
};
let entry = if let Some(mut existing) = self.outbound.get_mut(&key) {
existing.touch(packet.len());
existing.clone()
} else {
let translated_port = self.allocate_port()?;
let translated_addr = self.get_tunnel_address(flow.src_addr);
let entry = NatEntry::new(
flow.src_addr,
flow.src_port,
translated_addr,
translated_port,
flow.dst_addr,
flow.dst_port,
flow.protocol,
);
let inbound_key = InboundKey {
translated_port,
src_addr: flow.dst_addr,
src_port: flow.dst_port,
protocol: flow.protocol.protocol_number(),
};
self.outbound.insert(key.clone(), entry.clone());
self.inbound.insert(inbound_key, key);
tracing::debug!(
original = %flow.src_addr,
original_port = flow.src_port,
translated = %entry.translated_src,
translated_port = entry.translated_src_port,
dst = %flow.dst_addr,
dst_port = flow.dst_port,
"Created NAT entry"
);
entry
};
let mut pkt = IpPacketMut::new(packet)?;
pkt.update_transport_checksum(entry.original_src, entry.translated_src)?;
pkt.set_src_addr(entry.translated_src)?;
self.set_src_port_direct(pkt.data_mut(), header_len, protocol, entry.translated_src_port)?;
{
let mut stats = self.stats.write();
stats.packets_outbound += 1;
stats.bytes_outbound += packet.len() as u64;
stats.total_translations += 1;
}
Ok(Some(entry))
}
pub fn translate_inbound(&self, packet: &mut [u8]) -> Result<Option<(IpAddr, u16)>> {
let (flow, header_len, protocol) = {
let parsed = IpPacket::parse(packet)?;
(parsed.flow_tuple(), parsed.header_len, parsed.protocol)
};
if !self.is_tunnel_address(flow.dst_addr) {
return Ok(None);
}
let inbound_key = InboundKey {
translated_port: flow.dst_port,
src_addr: flow.src_addr,
src_port: flow.src_port,
protocol: flow.protocol.protocol_number(),
};
let outbound_key = match self.inbound.get(&inbound_key) {
Some(key) => key.clone(),
None => {
tracing::trace!(
dst_port = flow.dst_port,
src = %flow.src_addr,
src_port = flow.src_port,
"No NAT entry for inbound packet"
);
return Ok(None);
}
};
let entry = match self.outbound.get_mut(&outbound_key) {
Some(mut e) => {
e.touch(packet.len());
e.clone()
}
None => return Ok(None),
};
let mut pkt = IpPacketMut::new(packet)?;
pkt.update_transport_checksum(entry.translated_src, entry.original_src)?;
pkt.set_dst_addr(entry.original_src)?;
self.set_dst_port_direct(pkt.data_mut(), header_len, protocol, entry.original_src_port)?;
{
let mut stats = self.stats.write();
stats.packets_inbound += 1;
stats.bytes_inbound += packet.len() as u64;
}
Ok(Some((entry.original_src, entry.original_src_port)))
}
pub fn cleanup_expired(&self) {
let mut to_remove = Vec::new();
for entry in self.outbound.iter() {
let timeout = match entry.protocol {
TransportProtocol::Tcp => self.config.tcp_timeout,
TransportProtocol::Udp => self.config.udp_timeout,
TransportProtocol::Icmp | TransportProtocol::Icmpv6 => self.config.icmp_timeout,
_ => self.config.udp_timeout,
};
if entry.is_expired(timeout) {
to_remove.push(entry.key().clone());
}
}
for key in to_remove {
if let Some((_, entry)) = self.outbound.remove(&key) {
let inbound_key = InboundKey {
translated_port: entry.translated_src_port,
src_addr: entry.dst_addr,
src_port: entry.dst_port,
protocol: entry.protocol.protocol_number(),
};
self.inbound.remove(&inbound_key);
tracing::trace!(
original = %entry.original_src,
original_port = entry.original_src_port,
"Removed expired NAT entry"
);
}
}
}
pub fn entry_count(&self) -> usize {
self.outbound.len()
}
pub fn clear(&self) {
self.outbound.clear();
self.inbound.clear();
}
fn is_local_address(&self, addr: IpAddr) -> bool {
match addr {
IpAddr::V4(ipv4) => {
ipv4.is_private() || ipv4.is_loopback() || ipv4.is_link_local()
}
IpAddr::V6(ipv6) => {
ipv6.is_loopback()
|| (ipv6.segments()[0] & 0xfe00) == 0xfc00 || (ipv6.segments()[0] & 0xffc0) == 0xfe80 }
}
}
fn is_tunnel_address(&self, addr: IpAddr) -> bool {
match addr {
IpAddr::V4(ipv4) => ipv4 == self.config.tunnel_ipv4,
IpAddr::V6(ipv6) => self.config.tunnel_ipv6.map_or(false, |t| t == ipv6),
}
}
fn get_tunnel_address(&self, original: IpAddr) -> IpAddr {
match original {
IpAddr::V4(_) => IpAddr::V4(self.config.tunnel_ipv4),
IpAddr::V6(_) => {
self.config.tunnel_ipv6
.map(IpAddr::V6)
.unwrap_or_else(|| IpAddr::V4(self.config.tunnel_ipv4))
}
}
}
fn allocate_port(&self) -> Result<u16> {
let range_size = self.config.port_range_end - self.config.port_range_start;
for _ in 0..range_size {
let port = self.next_port.fetch_add(1, Ordering::SeqCst);
if port >= self.config.port_range_end {
self.next_port.store(self.config.port_range_start, Ordering::SeqCst);
}
let in_use = self.outbound.iter().any(|e| e.translated_src_port == port);
if !in_use {
return Ok(port);
}
}
Err(Error::Config("NAT port exhaustion".into()))
}
fn set_src_port_direct(&self, packet: &mut [u8], header_len: usize, protocol: TransportProtocol, port: u16) -> Result<()> {
if packet.len() < header_len + 2 {
return Ok(());
}
match protocol {
TransportProtocol::Tcp | TransportProtocol::Udp => {
packet[header_len..header_len + 2].copy_from_slice(&port.to_be_bytes());
}
_ => {}
}
Ok(())
}
fn set_dst_port_direct(&self, packet: &mut [u8], header_len: usize, protocol: TransportProtocol, port: u16) -> Result<()> {
if packet.len() < header_len + 4 {
return Ok(());
}
match protocol {
TransportProtocol::Tcp | TransportProtocol::Udp => {
packet[header_len + 2..header_len + 4].copy_from_slice(&port.to_be_bytes());
}
_ => {}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an IPv4 `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    #[test]
    fn test_nat_config_default() {
        let NatConfig {
            tunnel_ipv4,
            port_range_start,
            ..
        } = NatConfig::default();
        assert_eq!(tunnel_ipv4, Ipv4Addr::new(10, 0, 85, 1));
        assert_eq!(port_range_start, 32768);
    }

    #[test]
    fn test_nat_entry_timeout() {
        let fresh = NatEntry::new(
            v4(192, 168, 1, 1),
            12345,
            v4(10, 0, 85, 1),
            32768,
            v4(8, 8, 8, 8),
            80,
            TransportProtocol::Tcp,
        );
        assert!(!fresh.is_expired(Duration::from_secs(60)));
    }

    #[test]
    fn test_port_allocation() {
        let nat = NatTable::new(NatConfig {
            port_range_start: 1000,
            port_range_end: 1010,
            ..Default::default()
        });
        let allowed = 1000..1010;
        let first = nat.allocate_port().unwrap();
        let second = nat.allocate_port().unwrap();
        assert!(allowed.contains(&first));
        assert!(allowed.contains(&second));
        assert_ne!(first, second);
    }
}