use std::net::{Ipv4Addr, SocketAddrV4};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Instant;
use hashbrown::HashMap;
use crate::datapath::CachePadded;
/// Lifecycle state of a tracked connection.
///
/// `#[repr(u8)]` pins every variant to the explicit one-byte discriminant
/// written next to it.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ConnState {
    /// Initial state, assigned by `ConnTrackEntry::new_snat`; also the `Default`.
    #[default]
    New = 0,
    /// Entered via `ConnTrackEntry::mark_reply_seen`.
    Established = 1,
    /// NOTE(review): not assigned by the code in this module — presumably TCP teardown.
    FinWait = 2,
    /// NOTE(review): not assigned by the code in this module.
    Closing = 3,
    /// NOTE(review): not assigned by the code in this module.
    TimedOut = 4,
}
/// Five-tuple identifying one direction of an IPv4 flow.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnTrackKey {
    /// Source address of the packet.
    pub src_ip: Ipv4Addr,
    /// Destination address of the packet.
    pub dst_ip: Ipv4Addr,
    /// Source port of the packet.
    pub src_port: u16,
    /// Destination port of the packet.
    pub dst_port: u16,
    /// Protocol number (the tests use 6, i.e. TCP).
    pub protocol: u8,
}

impl ConnTrackKey {
    /// 64-bit FNV-1a offset basis, the starting accumulator of [`Self::fast_hash`].
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    /// 64-bit FNV prime used to mix each step of [`Self::fast_hash`].
    const FNV_PRIME: u64 = 0x0100_0000_01b3;

    /// Builds a key from its five components.
    #[inline]
    #[must_use]
    pub const fn new(
        src_ip: Ipv4Addr,
        dst_ip: Ipv4Addr,
        src_port: u16,
        dst_port: u16,
        protocol: u8,
    ) -> Self {
        Self { src_ip, dst_ip, src_port, dst_port, protocol }
    }

    /// Returns the key of the opposite direction: addresses and ports are
    /// swapped, the protocol is kept.
    #[inline]
    #[must_use]
    pub const fn reverse(&self) -> Self {
        Self::new(self.dst_ip, self.src_ip, self.dst_port, self.src_port, self.protocol)
    }

    /// Cheap FNV-1a-style hash used to pick fast-cache slots.
    ///
    /// The eight address octets (source first) are mixed one byte at a time;
    /// the two ports and the protocol are then each mixed in as a single
    /// widened value rather than byte-by-byte.
    #[inline]
    #[must_use]
    pub fn fast_hash(&self) -> u64 {
        let mix = |acc: u64, value: u64| (acc ^ value).wrapping_mul(Self::FNV_PRIME);

        let src_octets = self.src_ip.octets();
        let dst_octets = self.dst_ip.octets();
        let after_addresses = src_octets
            .iter()
            .chain(dst_octets.iter())
            .fold(Self::FNV_OFFSET_BASIS, |acc, &octet| mix(acc, u64::from(octet)));

        let trailing = [
            u64::from(self.src_port),
            u64::from(self.dst_port),
            u64::from(self.protocol),
        ];
        trailing.iter().fold(after_addresses, |acc, &value| mix(acc, value))
    }
}
/// Per-flow connection-tracking record: original and translated endpoints,
/// NAT flags, lifecycle state and activity counters.
///
/// The counters and `last_seen` are atomics, so they can be updated through a
/// shared `&ConnTrackEntry`.
#[repr(C, align(64))]
pub struct ConnTrackEntry {
    /// Source endpoint before translation.
    pub orig_src: SocketAddrV4,
    /// Destination endpoint before translation.
    pub orig_dst: SocketAddrV4,
    /// Source endpoint after translation (the external address for SNAT).
    pub nat_src: SocketAddrV4,
    /// Destination endpoint after translation (equal to `orig_dst` for SNAT-only entries).
    pub nat_dst: SocketAddrV4,
    /// Current lifecycle state.
    pub state: ConnState,
    /// Protocol number copied from the flow key.
    pub protocol: u8,
    /// Bit set of the `FLAG_*` constants.
    pub flags: u8,
    /// Seconds since the module's monotonic epoch at the most recent activity.
    pub last_seen: AtomicU32,
    /// Number of packets recorded via [`ConnTrackEntry::record_packet`].
    pub packets: AtomicU64,
    /// Number of bytes recorded via [`ConnTrackEntry::record_packet`].
    pub bytes: AtomicU64,
    /// Seconds since the module's monotonic epoch at creation time.
    pub created_at: u32,
}

impl ConnTrackEntry {
    /// Source address/port of this flow are translated.
    pub const FLAG_SNAT: u8 = 1 << 0;
    /// Destination address/port of this flow are translated.
    pub const FLAG_DNAT: u8 = 1 << 1;
    /// Traffic in the reply direction has been observed.
    pub const FLAG_REPLY_SEEN: u8 = 1 << 2;

    /// Whole seconds elapsed since a lazily initialised, process-wide monotonic epoch.
    ///
    /// `Instant::now().elapsed()` measures the time since *that same* instant and is
    /// therefore always ~0; a shared epoch is required for timestamps to advance.
    /// Saturates at `u32::MAX` (~136 years) instead of truncating.
    fn now_secs() -> u32 {
        static EPOCH: std::sync::OnceLock<Instant> = std::sync::OnceLock::new();
        let secs = EPOCH.get_or_init(Instant::now).elapsed().as_secs();
        secs.min(u64::from(u32::MAX)) as u32
    }

    /// Creates a source-NAT entry in state [`ConnState::New`].
    ///
    /// `nat_dst` is set to `orig_dst` (the destination is not rewritten) and both
    /// `created_at` and `last_seen` are stamped with the current time.
    #[must_use]
    pub fn new_snat(
        orig_src: SocketAddrV4,
        orig_dst: SocketAddrV4,
        nat_src: SocketAddrV4,
        protocol: u8,
    ) -> Self {
        let now = Self::now_secs();
        Self {
            orig_src,
            orig_dst,
            nat_src,
            nat_dst: orig_dst,
            state: ConnState::New,
            protocol,
            flags: Self::FLAG_SNAT,
            last_seen: AtomicU32::new(now),
            packets: AtomicU64::new(0),
            bytes: AtomicU64::new(0),
            created_at: now,
        }
    }

    /// Refreshes `last_seen` to the current time.
    #[inline]
    pub fn touch(&self) {
        self.last_seen.store(Self::now_secs(), Ordering::Relaxed);
    }

    /// Counts one packet of `bytes` bytes and refreshes `last_seen`.
    #[inline]
    pub fn record_packet(&self, bytes: u64) {
        self.packets.fetch_add(1, Ordering::Relaxed);
        self.bytes.fetch_add(bytes, Ordering::Relaxed);
        self.touch();
    }

    /// Returns `true` when more than `timeout_secs` seconds have passed since `last_seen`.
    ///
    /// A `last_seen` value ahead of the current time yields an elapsed time of 0
    /// (saturating subtraction), so such an entry is never reported as expired.
    #[inline]
    #[must_use]
    pub fn is_expired(&self, timeout_secs: u32) -> bool {
        let idle = Self::now_secs().saturating_sub(self.last_seen.load(Ordering::Relaxed));
        idle > timeout_secs
    }

    /// Whether [`Self::FLAG_SNAT`] is set.
    #[inline]
    #[must_use]
    pub const fn has_snat(&self) -> bool {
        self.flags & Self::FLAG_SNAT != 0
    }

    /// Whether [`Self::FLAG_DNAT`] is set.
    #[inline]
    #[must_use]
    pub const fn has_dnat(&self) -> bool {
        self.flags & Self::FLAG_DNAT != 0
    }

    /// Whether [`Self::FLAG_REPLY_SEEN`] is set.
    #[inline]
    #[must_use]
    pub const fn reply_seen(&self) -> bool {
        self.flags & Self::FLAG_REPLY_SEEN != 0
    }

    /// Sets [`Self::FLAG_REPLY_SEEN`] and moves the entry to [`ConnState::Established`].
    #[inline]
    pub fn mark_reply_seen(&mut self) {
        self.flags |= Self::FLAG_REPLY_SEEN;
        self.state = ConnState::Established;
    }
}

impl std::fmt::Debug for ConnTrackEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConnTrackEntry")
            .field("orig_src", &self.orig_src)
            .field("orig_dst", &self.orig_dst)
            .field("nat_src", &self.nat_src)
            .field("nat_dst", &self.nat_dst)
            .field("state", &self.state)
            .field("protocol", &self.protocol)
            .field("flags", &self.flags)
            .field("packets", &self.packets.load(Ordering::Relaxed))
            .field("bytes", &self.bytes.load(Ordering::Relaxed))
            .field("last_seen", &self.last_seen.load(Ordering::Relaxed))
            .field("created_at", &self.created_at)
            .finish()
    }
}
/// One slot of the table's fast-path lookup cache: a key, its precomputed
/// hash, and a raw pointer to the cached [`ConnTrackEntry`].
#[derive(Debug)]
pub struct FastCacheEntry {
    /// `key.fast_hash()`, cached so a slot can be rejected without comparing keys.
    pub key_hash: u64,
    /// Flow key this slot caches.
    pub key: ConnTrackKey,
    /// Pointer to the cached entry; only valid while that entry is still owned by its table.
    pub entry_ptr: *const ConnTrackEntry,
    /// Hit counter for this slot; starts at 0.
    pub hits: u32,
}

// SAFETY: `FastCacheEntry` never dereferences `entry_ptr` itself, so moving or
// sharing the struct across threads only copies an address. Every dereference
// happens in an `unsafe` block at the use site, which must ensure the pointee is
// still alive (`ConnTrackTable` clears a slot when it removes or clears the entry).
// The pointee's fields are plain values and atomics, so shared references to it
// are thread-safe.
unsafe impl Send for FastCacheEntry {}
// SAFETY: see the `Send` impl above; `&FastCacheEntry` exposes no interior mutability.
unsafe impl Sync for FastCacheEntry {}

impl FastCacheEntry {
    /// Creates a slot for `key` pointing at `entry_ptr`, with its hash precomputed
    /// and the hit counter at zero.
    #[must_use]
    pub fn new(key: ConnTrackKey, entry_ptr: *const ConnTrackEntry) -> Self {
        let key_hash = key.fast_hash();
        Self {
            hits: 0,
            entry_ptr,
            key,
            key_hash,
        }
    }
}
/// Round-robin allocator over the inclusive port range `start..=end`.
///
/// The cursor is an atomic counter, so [`PortAllocator::allocate`] works through
/// `&self`. It does not track which ports are in use: after one pass over the
/// range, ports are handed out again.
pub struct PortAllocator {
    /// Wrapping ticket counter; starts at `start` and advances once per allocation.
    current: AtomicU32,
    start: u16,
    end: u16,
}

impl PortAllocator {
    /// Creates an allocator whose first allocation returns `start`.
    ///
    /// # Panics
    ///
    /// Panics if `start > end`.
    #[must_use]
    pub fn new(start: u16, end: u16) -> Self {
        assert!(
            start <= end,
            "PortAllocator: start port {} must not exceed end port {}",
            start,
            end
        );
        Self {
            current: AtomicU32::new(u32::from(start)),
            start,
            end,
        }
    }

    /// Returns the next port in round-robin order, always within `start..=end`.
    #[inline]
    pub fn allocate(&self) -> u16 {
        // Widen before `+ 1`: the full range 0..=65535 holds 65536 ports, which
        // overflows u16 (and previously led to a `% 0` panic).
        let range = u32::from(self.end) - u32::from(self.start) + 1;
        // `fetch_add` wraps at u32::MAX, so the subtraction must wrap as well
        // instead of underflowing once the ticket drops below `start`.
        let ticket = self.current.fetch_add(1, Ordering::Relaxed);
        let offset = ticket.wrapping_sub(u32::from(self.start)) % range;
        // offset < range, hence start + offset <= end <= u16::MAX: no truncation.
        self.start + offset as u16
    }
}
/// Connection-tracking table with source-NAT port allocation and a
/// direct-mapped lookup cache (one slot per `fast_hash() & fast_cache_mask`)
/// in front of the main hash map.
pub struct ConnTrackTable {
    /// Forward flows keyed by their original (pre-NAT) tuple. Each entry is boxed,
    /// so its heap address does not move when the map rehashes; `fast_cache`
    /// relies on that when it stores raw pointers to entries.
    entries: HashMap<ConnTrackKey, Box<ConnTrackEntry>>,
    /// Reply tuple `(remote_ip -> external_ip, remote_port -> nat_port)` mapped
    /// back to the original forward key.
    reverse: HashMap<ConnTrackKey, ConnTrackKey>,
    /// Fast-path cache; its length is always a power of two.
    fast_cache: Vec<Option<FastCacheEntry>>,
    /// `fast_cache.len() - 1`, used to turn a hash into a slot index.
    fast_cache_mask: usize,
    /// Source of NAT source ports.
    port_alloc: PortAllocator,
    /// Address written into `nat_src` for every SNAT entry.
    external_ip: Ipv4Addr,
    /// Idle timeout passed to `ConnTrackEntry::is_expired` by `expire_old`.
    timeout_secs: u32,
    /// Lookup / creation / expiry counters.
    stats: ConnTrackStats,
}
/// Counters maintained by `ConnTrackTable`.
///
/// Each counter is wrapped in `CachePadded` — presumably so that each counter sits on
/// its own cache line; confirm against the `CachePadded` definition.
#[derive(Debug, Default)]
pub struct ConnTrackStats {
    /// Calls to `ConnTrackTable::lookup` (including those made by
    /// `lookup_reverse` and by `get_or_create` for existing flows).
    pub lookups: CachePadded<AtomicU64>,
    /// Lookups answered from the fast-path cache.
    pub fast_hits: CachePadded<AtomicU64>,
    /// Lookups that fell through to the hash map (whether or not the key was found).
    pub slow_lookups: CachePadded<AtomicU64>,
    /// Entries inserted by `ConnTrackTable::get_or_create`.
    pub created: CachePadded<AtomicU64>,
    /// Entries removed by `ConnTrackTable::expire_old`.
    pub expired: CachePadded<AtomicU64>,
}
impl ConnTrackTable {
    /// Creates an empty table.
    ///
    /// * `external_ip` — address written into `nat_src` of every SNAT entry.
    /// * `port_start`, `port_end` — inclusive NAT source-port pool.
    /// * `fast_cache_size` — requested number of fast-path slots; rounded up to the
    ///   next power of two (0 becomes 1) so a slot is selected with a bit mask.
    /// * `timeout_secs` — idle time after which [`Self::expire_old`] drops an entry.
    #[must_use]
    pub fn new(
        external_ip: Ipv4Addr,
        port_start: u16,
        port_end: u16,
        fast_cache_size: usize,
        timeout_secs: u32,
    ) -> Self {
        let fast_cache_size = fast_cache_size.next_power_of_two();
        let fast_cache = (0..fast_cache_size).map(|_| None).collect();
        Self {
            entries: HashMap::new(),
            reverse: HashMap::new(),
            fast_cache,
            fast_cache_mask: fast_cache_size - 1,
            port_alloc: PortAllocator::new(port_start, port_end),
            external_ip,
            timeout_secs,
            stats: ConnTrackStats::default(),
        }
    }

    /// Fast-cache slot index for a key hash.
    #[inline]
    fn cache_slot(&self, hash: u64) -> usize {
        (hash as usize) & self.fast_cache_mask
    }

    /// Reply-direction tuple registered in `reverse` for forward flow `key`
    /// translated to `external_ip:nat_port`.
    #[inline]
    fn reverse_key_for(&self, key: &ConnTrackKey, nat_port: u16) -> ConnTrackKey {
        ConnTrackKey::new(key.dst_ip, self.external_ip, key.dst_port, nat_port, key.protocol)
    }

    /// Draws a NAT port for `key` whose reply tuple is not already mapped.
    ///
    /// The allocator is round-robin and forgets which ports are live, so after a
    /// wraparound it can hand out a port still used by another flow to the same
    /// remote endpoint. That would silently overwrite that flow's reverse mapping,
    /// so such ports are skipped. After one full pass over the pool, the last
    /// port drawn is returned regardless; this is the previous overwrite
    /// behaviour, used only when every port is taken.
    fn allocate_nat_port(&self, key: &ConnTrackKey) -> u16 {
        let pool_size = u32::from(self.port_alloc.end)
            .saturating_sub(u32::from(self.port_alloc.start))
            + 1;
        let mut port = self.port_alloc.allocate();
        for _ in 1..pool_size {
            if !self.reverse.contains_key(&self.reverse_key_for(key, port)) {
                break;
            }
            port = self.port_alloc.allocate();
        }
        port
    }

    /// Looks up a forward flow by its original tuple.
    ///
    /// Tries the fast-path slot first. On a miss it falls back to the hash map,
    /// and on success it installs the entry in that slot. Updates `lookups`,
    /// `fast_hits` / `slow_lookups` and the slot's `hits` counter.
    pub fn lookup(&mut self, key: &ConnTrackKey) -> Option<&ConnTrackEntry> {
        self.stats.lookups.0.fetch_add(1, Ordering::Relaxed);
        let hash = key.fast_hash();
        let slot = self.cache_slot(hash);
        if let Some(cached) = self.fast_cache[slot].as_mut() {
            if cached.key_hash == hash && cached.key == *key {
                cached.hits = cached.hits.saturating_add(1);
                self.stats.fast_hits.0.fetch_add(1, Ordering::Relaxed);
                let entry_ptr = cached.entry_ptr;
                // SAFETY: a slot only ever points at a boxed entry owned by
                // `self.entries`; `remove` clears the slot for its key and `clear`
                // wipes every slot, so a matching slot never outlives its entry.
                // The returned borrow is tied to `&mut self`, so the entry cannot be
                // removed while it is in use.
                return Some(unsafe { &*entry_ptr });
            }
        }
        self.stats.slow_lookups.0.fetch_add(1, Ordering::Relaxed);
        let entry = self.entries.get(key)?;
        let entry_ptr = entry.as_ref() as *const ConnTrackEntry;
        self.fast_cache[slot] = Some(FastCacheEntry::new(*key, entry_ptr));
        Some(entry)
    }

    /// Looks up a flow by its reply-direction tuple
    /// `(remote_ip -> external_ip, remote_port -> nat_port)`.
    pub fn lookup_reverse(&mut self, nat_key: &ConnTrackKey) -> Option<&ConnTrackEntry> {
        let orig_key = self.reverse.get(nat_key).copied()?;
        self.lookup(&orig_key)
    }

    /// Returns the entry for `key`, creating a new SNAT entry if none exists.
    ///
    /// A new entry gets a NAT source port that is not already used by another
    /// flow to the same remote endpoint (see `allocate_nat_port`). Its
    /// reply tuple is registered in `reverse`, and it is installed in the fast cache.
    pub fn get_or_create(&mut self, key: ConnTrackKey) -> &ConnTrackEntry {
        if self.entries.contains_key(&key) {
            return self
                .lookup(&key)
                .expect("contains_key just reported the entry as present");
        }
        let nat_port = self.allocate_nat_port(&key);
        let entry = Box::new(ConnTrackEntry::new_snat(
            SocketAddrV4::new(key.src_ip, key.src_port),
            SocketAddrV4::new(key.dst_ip, key.dst_port),
            SocketAddrV4::new(self.external_ip, nat_port),
            key.protocol,
        ));
        let reverse_key = self.reverse_key_for(&key, nat_port);
        self.reverse.insert(reverse_key, key);
        self.entries.insert(key, entry);
        self.stats.created.0.fetch_add(1, Ordering::Relaxed);

        let slot = self.cache_slot(key.fast_hash());
        let entry_ref = self
            .entries
            .get(&key)
            .expect("entry was inserted just above");
        let entry_ptr = entry_ref.as_ref() as *const ConnTrackEntry;
        self.fast_cache[slot] = Some(FastCacheEntry::new(key, entry_ptr));
        entry_ref
    }

    /// Removes every entry whose idle time exceeds `timeout_secs` and returns
    /// how many were removed (also added to `stats.expired`).
    pub fn expire_old(&mut self) -> usize {
        let timeout = self.timeout_secs;
        let expired_keys: Vec<ConnTrackKey> = self
            .entries
            .iter()
            .filter(|(_, entry)| entry.is_expired(timeout))
            .map(|(key, _)| *key)
            .collect();
        let count = expired_keys.len();
        for key in &expired_keys {
            self.remove(key);
        }
        self.stats
            .expired
            .0
            .fetch_add(count as u64, Ordering::Relaxed);
        count
    }

    /// Removes the flow `key`, its reply mapping and its fast-cache slot.
    /// Removing an unknown key is a no-op.
    pub fn remove(&mut self, key: &ConnTrackKey) {
        if let Some(entry) = self.entries.remove(key) {
            let reverse_key = self.reverse_key_for(key, entry.nat_src.port());
            // Only drop the reply mapping if it still belongs to this flow; after
            // a NAT port was reused it may point at a different, live flow.
            if self.reverse.get(&reverse_key) == Some(key) {
                self.reverse.remove(&reverse_key);
            }
            let slot = self.cache_slot(key.fast_hash());
            if matches!(&self.fast_cache[slot], Some(cached) if cached.key == *key) {
                self.fast_cache[slot] = None;
            }
        }
    }

    /// Number of tracked flows.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether no flows are tracked.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Table counters.
    #[must_use]
    pub fn stats(&self) -> &ConnTrackStats {
        &self.stats
    }

    /// Drops every flow, reply mapping and fast-cache slot. Counters are kept.
    pub fn clear(&mut self) {
        // Invalidate the cached raw pointers before the boxes they point to are freed.
        self.fast_cache.iter_mut().for_each(|slot| *slot = None);
        self.entries.clear();
        self.reverse.clear();
    }
}
impl std::fmt::Debug for ConnTrackTable {
    /// Summarises the table: the number of tracked flows plus its configuration.
    /// Individual entries are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let flow_count = self.entries.len();
        let mut builder = f.debug_struct("ConnTrackTable");
        builder.field("entries", &flow_count);
        builder.field("external_ip", &self.external_ip);
        builder.field("timeout_secs", &self.timeout_secs);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table configuration shared by the table-level tests.
    fn sample_table() -> ConnTrackTable {
        ConnTrackTable::new(Ipv4Addr::new(10, 0, 0, 1), 49152, 65535, 256, 300)
    }

    /// Flow 192.168.1.100:12345 -> 8.8.8.8:80, protocol 6.
    fn sample_key() -> ConnTrackKey {
        ConnTrackKey::new(
            Ipv4Addr::new(192, 168, 1, 100),
            Ipv4Addr::new(8, 8, 8, 8),
            12345,
            80,
            6,
        )
    }

    /// Reply tuple registered for `sample_key()` once translated to `10.0.0.1:nat_port`.
    fn sample_reply_key(nat_port: u16) -> ConnTrackKey {
        ConnTrackKey::new(
            Ipv4Addr::new(8, 8, 8, 8),
            Ipv4Addr::new(10, 0, 0, 1),
            80,
            nat_port,
            6,
        )
    }

    #[test]
    fn test_conntrack_key() {
        let key = sample_key();
        let reverse = key.reverse();
        assert_eq!(reverse.src_ip, Ipv4Addr::new(8, 8, 8, 8));
        assert_eq!(reverse.dst_ip, Ipv4Addr::new(192, 168, 1, 100));
        assert_eq!(reverse.src_port, 80);
        assert_eq!(reverse.dst_port, 12345);
        assert_eq!(reverse.reverse(), key);
    }

    #[test]
    fn test_conntrack_entry() {
        let entry = ConnTrackEntry::new_snat(
            SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 100), 12345),
            SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 80),
            SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 54321),
            6,
        );
        assert!(entry.has_snat());
        assert!(!entry.has_dnat());
        assert!(!entry.reply_seen());
        assert_eq!(entry.state, ConnState::New);
    }

    #[test]
    fn test_conntrack_table_create() {
        let mut table = sample_table();
        let entry = table.get_or_create(sample_key());
        assert_eq!(entry.nat_src.ip(), &Ipv4Addr::new(10, 0, 0, 1));
        assert_eq!(table.len(), 1);
    }

    #[test]
    fn test_conntrack_table_lookup() {
        let mut table = sample_table();
        let key = sample_key();
        let _ = table.get_or_create(key);
        assert!(table.lookup(&key).is_some());
        let other_key = ConnTrackKey::new(
            Ipv4Addr::new(192, 168, 1, 200),
            Ipv4Addr::new(8, 8, 8, 8),
            12346,
            80,
            6,
        );
        assert!(table.lookup(&other_key).is_none());
    }

    #[test]
    fn test_get_or_create_is_idempotent() {
        let mut table = sample_table();
        let key = sample_key();
        let first_port = table.get_or_create(key).nat_src.port();
        let second_port = table.get_or_create(key).nat_src.port();
        assert_eq!(first_port, second_port);
        assert_eq!(table.len(), 1);
    }

    #[test]
    fn test_distinct_flows_get_distinct_nat_ports() {
        let mut table = sample_table();
        let first = sample_key();
        let second = ConnTrackKey::new(
            Ipv4Addr::new(192, 168, 1, 100),
            Ipv4Addr::new(8, 8, 8, 8),
            12346,
            80,
            6,
        );
        let first_port = table.get_or_create(first).nat_src.port();
        let second_port = table.get_or_create(second).nat_src.port();
        assert_ne!(first_port, second_port);
        assert_eq!(table.len(), 2);
    }

    #[test]
    fn test_lookup_reverse() {
        let mut table = sample_table();
        let key = sample_key();
        let nat_port = table.get_or_create(key).nat_src.port();
        let entry = table
            .lookup_reverse(&sample_reply_key(nat_port))
            .expect("reply tuple should map back to the flow");
        assert_eq!(
            entry.orig_src,
            SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 100), 12345)
        );
        assert!(table.lookup_reverse(&key).is_none());
    }

    #[test]
    fn test_remove() {
        let mut table = sample_table();
        let key = sample_key();
        let nat_port = table.get_or_create(key).nat_src.port();
        table.remove(&key);
        assert!(table.is_empty());
        assert!(table.lookup(&key).is_none());
        assert!(table.lookup_reverse(&sample_reply_key(nat_port)).is_none());
        // Removing an absent key is a no-op.
        table.remove(&key);
        assert!(table.is_empty());
    }

    #[test]
    fn test_fast_cache_stats() {
        let mut table = sample_table();
        let key = sample_key();
        let _ = table.get_or_create(key);
        assert!(table.lookup(&key).is_some());
        assert!(table.lookup(&key).is_some());
        let stats = table.stats();
        assert_eq!(stats.created.0.load(Ordering::Relaxed), 1);
        assert_eq!(stats.lookups.0.load(Ordering::Relaxed), 2);
        assert_eq!(stats.fast_hits.0.load(Ordering::Relaxed), 2);
        assert_eq!(stats.slow_lookups.0.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_expire_old_keeps_fresh_entries() {
        let mut table = sample_table();
        let _ = table.get_or_create(sample_key());
        assert_eq!(table.expire_old(), 0);
        assert_eq!(table.len(), 1);
        assert_eq!(table.stats().expired.0.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_clear() {
        let mut table = sample_table();
        let key = sample_key();
        let nat_port = table.get_or_create(key).nat_src.port();
        table.clear();
        assert!(table.is_empty());
        assert!(table.lookup(&key).is_none());
        assert!(table.lookup_reverse(&sample_reply_key(nat_port)).is_none());
    }

    #[test]
    fn test_port_allocator() {
        let alloc = PortAllocator::new(1000, 1010);
        let ports: Vec<u16> = (0..20).map(|_| alloc.allocate()).collect();
        for (i, port) in ports.iter().enumerate().take(11) {
            assert_eq!(*port, 1000 + (i as u16));
        }
        assert_eq!(ports[11], 1000);
    }
}