use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
/// Identifier derived from how a probe's UDP checksum changed in transit.
///
/// It is the wrapping difference `received - sent`, so a value of zero means
/// the checksum quoted back is identical to the one that was sent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct NatId(pub u16);

impl NatId {
    /// Builds the ID as `received_checksum - sent_checksum` (wrapping on underflow).
    pub fn from_checksums(sent_checksum: u16, received_checksum: u16) -> Self {
        let delta = received_checksum.wrapping_sub(sent_checksum);
        NatId(delta)
    }

    /// Returns `true` when the ID is non-zero, i.e. the checksum differed.
    pub fn is_natted(&self) -> bool {
        let Self(raw) = *self;
        raw != 0
    }

    /// Returns the raw 16-bit ID.
    pub fn value(&self) -> u16 {
        let Self(raw) = *self;
        raw
    }
}
/// Values stamped onto a probe so that a reply quoting it can be matched back.
#[derive(Debug, Clone, Copy)]
pub struct IpIdMarker {
    /// IP ID the probe is sent with.
    pub ip_id: u16,
    /// UDP checksum the probe is expected to carry.
    pub udp_checksum: u16,
    /// TTL of the probe.
    pub ttl: u8,
    /// Flow identifier of the probe.
    pub flow_id: u16,
}

impl IpIdMarker {
    /// Creates a marker whose IP ID is set equal to `udp_checksum`.
    pub fn new(udp_checksum: u16, ttl: u8, flow_id: u16) -> Self {
        Self {
            ip_id: udp_checksum,
            udp_checksum,
            ttl,
            flow_id,
        }
    }

    /// Creates a marker using `ttl + flow_id` (wrapping) as both the IP ID and
    /// the UDP checksum.
    ///
    /// `_use_src_port` is accepted for API compatibility but currently ignored.
    pub fn from_probe(ttl: u8, flow_id: u16, _use_src_port: bool) -> Self {
        let tag = flow_id.wrapping_add(u16::from(ttl));
        Self {
            ip_id: tag,
            udp_checksum: tag,
            ttl,
            flow_id,
        }
    }

    /// Returns `true` if `received_ip_id` equals this marker's IP ID.
    pub fn matches(&self, received_ip_id: u16) -> bool {
        received_ip_id == self.ip_id
    }

    /// Returns `true` if `received_checksum` equals this marker's UDP checksum.
    pub fn matches_checksum(&self, received_checksum: u16) -> bool {
        received_checksum == self.udp_checksum
    }
}
/// A probe together with the marker used to recognise replies to it.
#[derive(Debug, Clone)]
pub struct NatProbe {
    /// Marker used to match replies to this probe.
    pub marker: IpIdMarker,
    /// Destination address of the probe.
    pub dst_addr: SocketAddr,
    /// Source address of the probe.
    pub src_addr: SocketAddr,
    /// Time the probe was created (set by [`NatProbe::new`]).
    pub sent_at: Instant,
    /// Probe payload bytes.
    pub payload: Vec<u8>,
}

impl NatProbe {
    /// Creates a probe, stamping `sent_at` with the current time.
    pub fn new(
        marker: IpIdMarker,
        src_addr: SocketAddr,
        dst_addr: SocketAddr,
        payload: Vec<u8>,
    ) -> Self {
        let sent_at = Instant::now();
        Self {
            marker,
            dst_addr,
            src_addr,
            sent_at,
            payload,
        }
    }

    /// Returns `base_payload` followed by the 16-bit tag `ttl + flow_id`
    /// (wrapping), encoded big-endian.
    pub fn create_payload(base_payload: &[u8], ttl: u8, flow_id: u16) -> Vec<u8> {
        let tag = flow_id.wrapping_add(u16::from(ttl));
        let mut out = Vec::with_capacity(base_payload.len() + 2);
        out.extend_from_slice(base_payload);
        out.extend_from_slice(&tag.to_be_bytes());
        out
    }
}
/// An ICMP reply that quotes one of our probes.
#[derive(Debug, Clone)]
pub struct NatProbeResponse {
    /// IP ID found in the quoted (inner) probe header.
    pub inner_ip_id: u16,
    /// UDP checksum found in the quoted (inner) probe header; compared with the
    /// sent checksum to derive the NAT ID.
    pub inner_udp_checksum: u16,
    /// Address the reply was received from.
    pub responder_addr: SocketAddr,
    /// IP ID of the reply packet itself.
    pub response_ip_id: u16,
    /// Time the reply was received.
    pub received_at: Instant,
    /// ICMP type of the reply.
    pub icmp_type: u8,
    /// ICMP code of the reply.
    pub icmp_code: u8,
}

impl NatProbeResponse {
    /// ICMP "Time Exceeded" message type.
    const ICMP_TIME_EXCEEDED: u8 = 11;
    /// "TTL exceeded in transit" code of a Time Exceeded message.
    const ICMP_CODE_TTL_IN_TRANSIT: u8 = 0;
    /// ICMP "Destination Unreachable" message type.
    const ICMP_DEST_UNREACHABLE: u8 = 3;
    /// "Port unreachable" code of a Destination Unreachable message.
    const ICMP_CODE_PORT_UNREACHABLE: u8 = 3;

    /// Derives the NAT ID from the checksum that was sent and the one quoted back.
    pub fn nat_id(&self, original_checksum: u16) -> NatId {
        let quoted = self.inner_udp_checksum;
        NatId::from_checksums(original_checksum, quoted)
    }

    /// Time between `sent_at` and reception; zero if `sent_at` is later.
    pub fn rtt(&self, sent_at: Instant) -> Duration {
        self.received_at
            .checked_duration_since(sent_at)
            .unwrap_or_default()
    }

    /// Returns `true` for ICMP type 11 / code 0 (TTL exceeded in transit).
    pub fn is_ttl_exceeded(&self) -> bool {
        (self.icmp_type, self.icmp_code)
            == (Self::ICMP_TIME_EXCEEDED, Self::ICMP_CODE_TTL_IN_TRANSIT)
    }

    /// Returns `true` for ICMP type 3 / code 3 (port unreachable), which this
    /// code treats as the reply from the destination.
    pub fn is_destination(&self) -> bool {
        (self.icmp_type, self.icmp_code)
            == (Self::ICMP_DEST_UNREACHABLE, Self::ICMP_CODE_PORT_UNREACHABLE)
    }
}
/// Per-hop NAT-ID bookkeeping used to locate where NAT IDs change along a path.
///
/// `hop_nat_ids[h]` is the NAT ID recorded for hop `h`, or `None` if nothing
/// has been recorded for that hop. A hop is a *NAT location* when its NAT ID
/// differs from the NAT ID of the nearest earlier hop that has one recorded.
#[derive(Debug, Clone)]
pub struct NatDetectionState {
    /// NAT ID passed to the most recent `update_hop` call.
    current_nat_id: NatId,
    /// NAT ID per hop index; `None` for hops not recorded yet.
    hop_nat_ids: Vec<Option<NatId>>,
    /// Hops where the NAT ID changes, in ascending order.
    nat_locations: Vec<usize>,
    /// Time of the last `update_hop` or `reset`.
    last_update: Instant,
}

impl Default for NatDetectionState {
    fn default() -> Self {
        Self::new()
    }
}

impl NatDetectionState {
    /// Creates an empty state with no hops recorded.
    pub fn new() -> Self {
        Self {
            current_nat_id: NatId::default(),
            hop_nat_ids: Vec::new(),
            nat_locations: Vec::new(),
            last_update: Instant::now(),
        }
    }

    /// Records `nat_id` for `hop` and recomputes the NAT locations.
    ///
    /// Replies can arrive out of order, and some hops may never answer. For
    /// that reason `hop` is compared with the nearest earlier hop whose NAT ID
    /// is known, not only with `hop - 1`, and all locations are recomputed.
    /// A late reply for an intermediate hop can therefore add, move or remove
    /// a location further down the path. Re-recording a hop overwrites its
    /// previous ID.
    ///
    /// Returns `true` if `hop` is a NAT location after this update.
    pub fn update_hop(&mut self, hop: usize, nat_id: NatId) -> bool {
        if self.hop_nat_ids.len() <= hop {
            self.hop_nat_ids.resize(hop + 1, None);
        }
        self.hop_nat_ids[hop] = Some(nat_id);
        self.current_nat_id = nat_id;
        self.last_update = Instant::now();
        self.recompute_locations();
        // `nat_locations` is rebuilt in ascending hop order.
        self.nat_locations.binary_search(&hop).is_ok()
    }

    /// Returns `true` if any NAT location is known, or the most recently
    /// recorded NAT ID is non-zero.
    pub fn has_nat(&self) -> bool {
        !self.nat_locations.is_empty() || self.current_nat_id.is_natted()
    }

    /// Number of hops where the NAT ID changes.
    pub fn nat_count(&self) -> usize {
        self.nat_locations.len()
    }

    /// Hops where the NAT ID changes, in ascending order.
    pub fn nat_locations(&self) -> &[usize] {
        &self.nat_locations
    }

    /// NAT ID passed to the most recent `update_hop` call.
    pub fn current_nat_id(&self) -> NatId {
        self.current_nat_id
    }

    /// Time of the last `update_hop` or `reset`.
    pub fn last_update(&self) -> Instant {
        self.last_update
    }

    /// Forgets all recorded hops and locations.
    pub fn reset(&mut self) {
        self.current_nat_id = NatId::default();
        self.hop_nat_ids.clear();
        self.nat_locations.clear();
        self.last_update = Instant::now();
    }

    /// Rebuilds `nat_locations` from `hop_nat_ids`. It collects, in ascending
    /// order, every recorded hop whose ID differs from the previous recorded
    /// hop's ID; unrecorded hops are skipped.
    fn recompute_locations(&mut self) {
        self.nat_locations.clear();
        let mut previous: Option<NatId> = None;
        for (hop, slot) in self.hop_nat_ids.iter().enumerate() {
            let Some(id) = *slot else { continue };
            if previous.is_some_and(|prev| prev != id) {
                self.nat_locations.push(hop);
            }
            previous = Some(id);
        }
    }
}
/// Tracks in-flight probes and matches ICMP responses back to them.
///
/// Probes are indexed by `marker.ip_id`. A secondary index maps
/// `marker.udp_checksum` to `marker.ip_id` and is used as a fallback when the
/// IP ID quoted in a response is not pending. Only one probe can be pending
/// per IP ID: registering another probe with the same IP ID replaces it.
#[derive(Debug)]
pub struct ProbeMatcher {
    /// Pending probes keyed by `marker.ip_id`.
    probes_by_ip_id: HashMap<u16, NatProbe>,
    /// Fallback index: `marker.udp_checksum` -> `marker.ip_id` of a pending probe.
    probes_by_checksum: HashMap<u16, u16>,
    /// Probes whose age (since `sent_at`) reaches this are dropped by `cleanup`.
    timeout: Duration,
}

impl ProbeMatcher {
    /// Creates an empty matcher whose `cleanup` drops probes older than `timeout`.
    pub fn new(timeout: Duration) -> Self {
        Self {
            probes_by_ip_id: HashMap::new(),
            probes_by_checksum: HashMap::new(),
            timeout,
        }
    }

    /// Adds `probe` to the pending set, indexed by IP ID and UDP checksum.
    ///
    /// If a probe with the same IP ID was already pending, it is replaced. Its
    /// checksum index entry is removed so that checksum no longer resolves to
    /// the new probe.
    pub fn register_probe(&mut self, probe: NatProbe) {
        let ip_id = probe.marker.ip_id;
        let checksum = probe.marker.udp_checksum;
        if let Some(replaced) = self.probes_by_ip_id.insert(ip_id, probe) {
            self.unindex_checksum(replaced.marker.udp_checksum, ip_id);
        }
        self.probes_by_checksum.insert(checksum, ip_id);
    }

    /// Finds and removes the pending probe that `response` refers to.
    ///
    /// Matches on the quoted IP ID first, then falls back to the quoted UDP
    /// checksum. Returns `None` if neither identifies a pending probe.
    pub fn match_response(&mut self, response: &NatProbeResponse) -> Option<NatProbe> {
        let ip_id = if self.probes_by_ip_id.contains_key(&response.inner_ip_id) {
            response.inner_ip_id
        } else {
            *self.probes_by_checksum.get(&response.inner_udp_checksum)?
        };
        let probe = self.probes_by_ip_id.remove(&ip_id)?;
        // Another pending probe with the same checksum may own the index
        // entry; only drop it if it still points at the probe being removed.
        self.unindex_checksum(probe.marker.udp_checksum, ip_id);
        Some(probe)
    }

    /// Drops probes whose age is `timeout` or more. It then rebuilds the
    /// checksum index from the probes that remain.
    pub fn cleanup(&mut self) {
        let now = Instant::now();
        let timeout = self.timeout;
        self.probes_by_ip_id
            .retain(|_, probe| now.duration_since(probe.sent_at) < timeout);
        self.probes_by_checksum.clear();
        self.probes_by_checksum.extend(
            self.probes_by_ip_id
                .values()
                .map(|probe| (probe.marker.udp_checksum, probe.marker.ip_id)),
        );
    }

    /// Number of probes still waiting for a response.
    pub fn pending_count(&self) -> usize {
        self.probes_by_ip_id.len()
    }

    /// Removes the checksum index entry for `checksum` only if it maps to
    /// `ip_id`, leaving entries owned by other pending probes intact.
    fn unindex_checksum(&mut self, checksum: u16, ip_id: u16) {
        if self.probes_by_checksum.get(&checksum) == Some(&ip_id) {
            self.probes_by_checksum.remove(&checksum);
        }
    }
}
/// NAT state for one uplink: per-hop detection plus a coarse summary.
#[derive(Debug, Clone)]
pub struct UplinkNatState {
    /// Per-hop NAT-ID tracking fed by `update`.
    detection: NatDetectionState,
    /// Whether a NAT has been observed (or declared via `set_nat_type`).
    is_natted: bool,
    /// External address, if one has been set.
    external_addr: Option<SocketAddr>,
    /// NAT classification, `Unknown` until set.
    nat_type: NatType,
}

impl Default for UplinkNatState {
    fn default() -> Self {
        Self::new()
    }
}

impl UplinkNatState {
    /// Creates a state with nothing detected and an `Unknown` NAT type.
    pub fn new() -> Self {
        Self {
            detection: NatDetectionState::new(),
            is_natted: false,
            external_addr: None,
            nat_type: NatType::Unknown,
        }
    }

    /// Whether a NAT has been observed or declared.
    pub fn is_natted(&self) -> bool {
        self.is_natted
    }

    /// External address, if one has been set.
    pub fn external_addr(&self) -> Option<SocketAddr> {
        self.external_addr
    }

    /// Current NAT classification.
    pub fn nat_type(&self) -> NatType {
        self.nat_type
    }

    /// Feeds one matched probe/response pair into the NAT state.
    ///
    /// The NAT ID is derived from the probe's UDP checksum and the checksum
    /// quoted in the response, and is recorded for hop `probe.marker.ttl`.
    /// Every hop is recorded, including hops with a zero NAT ID: NAT locations
    /// are found by comparing a hop with earlier hops, so zero-ID hops are the
    /// baseline the first NAT is detected against.
    ///
    /// `is_natted` is sticky. Once any response shows a non-zero NAT ID it
    /// stays `true` until `reset` or `set_nat_type`, so the result does not
    /// depend on the order in which replies arrive.
    pub fn update(&mut self, probe: &NatProbe, response: &NatProbeResponse) {
        let nat_id = response.nat_id(probe.marker.udp_checksum);
        self.detection
            .update_hop(usize::from(probe.marker.ttl), nat_id);
        self.is_natted = self.is_natted || nat_id.is_natted();
    }

    /// Sets the external address.
    pub fn set_external_addr(&mut self, addr: SocketAddr) {
        self.external_addr = Some(addr);
    }

    /// Sets the NAT type. `is_natted` becomes `true` for every type except
    /// `None` and `Unknown`.
    pub fn set_nat_type(&mut self, nat_type: NatType) {
        self.nat_type = nat_type;
        self.is_natted = nat_type != NatType::None && nat_type != NatType::Unknown;
    }

    /// Per-hop detection state.
    pub fn detection_state(&self) -> &NatDetectionState {
        &self.detection
    }

    /// Clears detection, the NAT flag, the external address and the NAT type.
    pub fn reset(&mut self) {
        self.detection.reset();
        self.is_natted = false;
        self.external_addr = None;
        self.nat_type = NatType::Unknown;
    }
}
/// NAT behaviour classification, using the classic STUN (RFC 3489) category names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NatType {
    /// Not determined yet (the default).
    #[default]
    Unknown,
    /// No NAT.
    None,
    FullCone,
    RestrictedCone,
    PortRestrictedCone,
    Symmetric,
}

impl NatType {
    /// `true` only for `None` and `FullCone`.
    pub fn allows_direct_connect(&self) -> bool {
        match self {
            Self::None | Self::FullCone => true,
            Self::Unknown | Self::RestrictedCone | Self::PortRestrictedCone | Self::Symmetric => {
                false
            }
        }
    }

    /// `true` only for the two restricted-cone variants.
    pub fn requires_hole_punch(&self) -> bool {
        match self {
            Self::RestrictedCone | Self::PortRestrictedCone => true,
            Self::Unknown | Self::None | Self::FullCone | Self::Symmetric => false,
        }
    }

    /// `true` only for `Symmetric`.
    pub fn requires_relay(&self) -> bool {
        *self == Self::Symmetric
    }
}
/// Computes the UDP checksum (RFC 768) over the pseudo-header, UDP header and payload.
///
/// `src_addr` and `dst_addr` are raw address bytes (e.g. 4 bytes for IPv4).
/// Addresses and payload are summed as big-endian 16-bit words, with an odd
/// trailing byte padded with zero. The UDP length (`8 + payload.len()`) is
/// counted twice, once in the pseudo-header and once in the UDP header. The
/// header's own checksum field is taken as zero.
///
/// A computed checksum of zero is returned as `0xffff`, because zero on the
/// wire means "no checksum".
///
/// The running sum is kept in a `u64` so carries are never dropped. The
/// previous `u32` wrapping accumulator lost a carry (an off-by-one result)
/// once the sum passed 2^32.
pub fn compute_udp_checksum(
    src_addr: &[u8],
    dst_addr: &[u8],
    src_port: u16,
    dst_port: u16,
    payload: &[u8],
) -> u16 {
    /// Sums `bytes` as big-endian 16-bit words, zero-padding an odd trailing byte.
    fn sum_be_words(bytes: &[u8]) -> u64 {
        let words = bytes.chunks_exact(2);
        let tail = match words.remainder() {
            [last] => u64::from(*last) << 8,
            _ => 0,
        };
        words
            .map(|pair| u64::from(u16::from_be_bytes([pair[0], pair[1]])))
            .sum::<u64>()
            + tail
    }

    /// Folds carries back into the low 16 bits (ones'-complement end-around carry).
    fn fold_to_u16(mut sum: u64) -> u16 {
        while sum > 0xffff {
            sum = (sum & 0xffff) + (sum >> 16);
        }
        // Lossless: the loop leaves `sum <= 0xffff`.
        sum as u16
    }

    const IPPROTO_UDP: u64 = 17;
    const UDP_HEADER_LEN: usize = 8;

    // `usize` -> `u64` is lossless on all supported targets.
    let udp_len = (UDP_HEADER_LEN + payload.len()) as u64;

    // Pseudo-header: source address, destination address, protocol, UDP length.
    let pseudo_header = sum_be_words(src_addr) + sum_be_words(dst_addr) + IPPROTO_UDP + udp_len;
    // UDP header with a zero checksum field: source port, destination port, length.
    let udp_header = u64::from(src_port) + u64::from(dst_port) + udp_len;
    let checksum = !fold_to_u16(pseudo_header + udp_header + sum_be_words(payload));

    // A transmitted zero means "no checksum", so a computed zero is sent as all ones.
    if checksum == 0 {
        0xffff
    } else {
        checksum
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nat_id_from_checksums() {
        let nat_id = NatId::from_checksums(0x1234, 0x1234);
        assert!(!nat_id.is_natted());
        assert_eq!(nat_id.value(), 0);
        let nat_id = NatId::from_checksums(0x1234, 0x1235);
        assert!(nat_id.is_natted());
        assert_eq!(nat_id.value(), 1);
    }

    #[test]
    fn test_ip_id_marker() {
        let marker = IpIdMarker::new(0xABCD, 64, 12345);
        assert_eq!(marker.ip_id, 0xABCD);
        assert!(marker.matches(0xABCD));
        assert!(!marker.matches(0x1234));
    }

    #[test]
    fn test_ip_id_marker_from_probe() {
        // 64 + 12345 = 12409, used as both IP ID and checksum.
        let marker = IpIdMarker::from_probe(64, 12345, false);
        assert_eq!(marker.ip_id, 12409);
        assert!(marker.matches_checksum(12409));
    }

    #[test]
    fn test_nat_detection_state() {
        let mut state = NatDetectionState::new();
        let nat_detected = state.update_hop(0, NatId(0));
        assert!(!nat_detected);
        assert!(!state.has_nat());
        let nat_detected = state.update_hop(1, NatId(0));
        assert!(!nat_detected);
        let nat_detected = state.update_hop(2, NatId(100));
        assert!(nat_detected);
        assert!(state.has_nat());
        assert_eq!(state.nat_count(), 1);
        assert_eq!(state.nat_locations(), &[2]);
    }

    #[test]
    fn test_probe_matcher() {
        let mut matcher = ProbeMatcher::new(Duration::from_secs(5));
        let marker = IpIdMarker::new(0xABCD, 64, 12345);
        let probe = NatProbe::new(
            marker,
            SocketAddr::from(([192, 168, 1, 1], 12345)),
            SocketAddr::from(([8, 8, 8, 8], 53)),
            vec![0; 8],
        );
        matcher.register_probe(probe);
        let response = NatProbeResponse {
            inner_ip_id: 0xABCD,
            inner_udp_checksum: 0xABCD,
            responder_addr: SocketAddr::from(([1, 2, 3, 4], 0)),
            response_ip_id: 0,
            received_at: Instant::now(),
            icmp_type: 11,
            icmp_code: 0,
        };
        let matched = matcher.match_response(&response);
        assert!(matched.is_some());
        assert_eq!(matcher.pending_count(), 0);
    }

    #[test]
    fn test_uplink_nat_state() {
        let mut state = UplinkNatState::new();
        assert!(!state.is_natted());
        assert_eq!(state.nat_type(), NatType::Unknown);
        state.set_nat_type(NatType::FullCone);
        assert!(state.is_natted());
        assert_eq!(state.nat_type(), NatType::FullCone);
        state.set_external_addr(SocketAddr::from(([1, 2, 3, 4], 12345)));
        assert!(state.external_addr().is_some());
        state.reset();
        assert!(!state.is_natted());
        assert!(state.external_addr().is_none());
    }

    #[test]
    fn test_nat_type() {
        assert!(NatType::None.allows_direct_connect());
        assert!(NatType::FullCone.allows_direct_connect());
        assert!(!NatType::Symmetric.allows_direct_connect());
        assert!(NatType::RestrictedCone.requires_hole_punch());
        assert!(NatType::Symmetric.requires_relay());
    }

    #[test]
    fn test_udp_checksum() {
        let src = [192, 168, 1, 1];
        let dst = [8, 8, 8, 8];
        let checksum = compute_udp_checksum(&src, &dst, 12345, 53, b"test");
        // Hand-computed: the folded ones'-complement sum is 0xEA2A.
        assert_eq!(checksum, 0x15D5);
    }

    #[test]
    fn test_udp_checksum_odd_length_payload() {
        // Single byte 0x01 is padded to the word 0x0100.
        assert_eq!(compute_udp_checksum(&[], &[], 0, 0, &[1]), 0xFEDC);
    }

    #[test]
    fn test_udp_checksum_zero_is_sent_as_all_ones() {
        // Header words sum to exactly 0xFFFF, so the complement is zero.
        assert_eq!(compute_udp_checksum(&[], &[], 65502, 0, &[]), 0xFFFF);
    }

    #[test]
    fn test_probe_payload_creation() {
        let payload = NatProbe::create_payload(b"HELLO", 64, 12345);
        // 64 + 12345 = 12409 = 0x3079, appended big-endian.
        assert_eq!(payload, [b'H', b'E', b'L', b'L', b'O', 0x30, 0x79]);
    }
}