use std::net::SocketAddr;
/// Coarse NAT behaviour class, as produced by [`ClassifyFsm::classify`].
///
/// The explicit discriminants define the compact `u8` encoding exposed via
/// [`NatClass::as_u8`] / [`NatClass::from_u8`]; `Unknown` encodes as `0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum NatClass {
    /// Not enough evidence (fewer than two peer observations), or an
    /// unrecognised `u8` encoding.
    Unknown = 0,
    /// Some observed reflexive address matched the local bind address.
    Open = 1,
    /// No reflex matched the bind address, and every reflex shared one port.
    Cone = 2,
    /// No reflex matched the bind address, and reflex ports differed.
    Symmetric = 3,
}

impl NatClass {
    /// All variants, listed so that index `i` holds the variant whose
    /// discriminant is `i` (`from_u8` relies on this ordering).
    const VARIANTS: [NatClass; 4] = [
        NatClass::Unknown,
        NatClass::Open,
        NatClass::Cone,
        NatClass::Symmetric,
    ];

    /// Returns the stable `"nat:*"` string tag for this class.
    pub fn tag(&self) -> &'static str {
        match *self {
            NatClass::Unknown => "nat:unknown",
            NatClass::Open => "nat:open",
            NatClass::Cone => "nat:cone",
            NatClass::Symmetric => "nat:symmetric",
        }
    }

    /// Parses a tag produced by [`NatClass::tag`].
    ///
    /// Matching is exact and case-sensitive; any other string yields `None`.
    pub fn from_tag(tag: &str) -> Option<Self> {
        // Derive the reverse mapping from `tag()` so the strings exist once.
        Self::VARIANTS
            .iter()
            .copied()
            .find(|variant| variant.tag() == tag)
    }

    /// Returns the `u8` discriminant of this class.
    #[inline]
    pub fn as_u8(self) -> u8 {
        self as u8
    }

    /// Decodes a `u8` discriminant.
    ///
    /// This is lenient: values with no matching variant (`4..=255`) collapse to
    /// [`NatClass::Unknown`] instead of failing.
    #[inline]
    pub fn from_u8(raw: u8) -> Self {
        Self::VARIANTS
            .get(usize::from(raw))
            .copied()
            .unwrap_or(NatClass::Unknown)
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PairAction {
Direct,
SinglePunch,
SkipPunch,
}
pub fn pair_action(local: NatClass, remote: NatClass) -> PairAction {
use NatClass::*;
match (local, remote) {
(Open, Open) => PairAction::Direct,
(Open, Cone) => PairAction::Direct,
(Open, Symmetric) => PairAction::SinglePunch,
(Open, Unknown) => PairAction::Direct,
(Cone, Open) => PairAction::Direct,
(Cone, Cone) => PairAction::SinglePunch,
(Cone, Symmetric) => PairAction::SinglePunch,
(Cone, Unknown) => PairAction::SinglePunch,
(Symmetric, Open) => PairAction::SinglePunch,
(Symmetric, Cone) => PairAction::SinglePunch,
(Symmetric, Symmetric) => PairAction::SkipPunch,
(Symmetric, Unknown) => PairAction::SkipPunch,
(Unknown, Open) => PairAction::Direct,
(Unknown, Cone) => PairAction::SinglePunch,
(Unknown, Symmetric) => PairAction::SkipPunch,
(Unknown, Unknown) => PairAction::Direct,
}
}
/// Collects the reflexive addresses reported by peers and derives a
/// [`NatClass`] from them.
///
/// At most one observation is kept per peer id. Observing the same peer again
/// overwrites its previous address, so [`ClassifyFsm::observation_count`]
/// equals the number of distinct peers seen since construction or the last
/// [`ClassifyFsm::clear`].
#[derive(Debug, Clone, Default)]
pub struct ClassifyFsm {
    /// `(peer_id, reflex)` pairs in first-seen order, one entry per peer id.
    /// `reflex` is presumably this host's address as seen by that peer —
    /// confirm against callers.
    probes: Vec<(u64, SocketAddr)>,
}

impl ClassifyFsm {
    /// Creates a classifier with no observations; equivalent to `Default`.
    pub fn new() -> Self {
        ClassifyFsm { probes: Vec::new() }
    }

    /// Records `reflex` as the latest address reported by `peer`.
    ///
    /// A repeated `peer` has its earlier address replaced in place, keeping its
    /// position; a new `peer` is appended. Cost is linear in the number of
    /// recorded peers.
    pub fn observe(&mut self, peer: u64, reflex: SocketAddr) {
        match self.probes.iter().position(|&(id, _)| id == peer) {
            Some(index) => self.probes[index].1 = reflex,
            None => self.probes.push((peer, reflex)),
        }
    }

    /// Number of distinct peers currently recorded.
    pub fn observation_count(&self) -> usize {
        self.probes.len()
    }

    /// Forgets every observation, returning the classifier to its initial
    /// state.
    pub fn clear(&mut self) {
        self.probes.clear();
    }

    /// Classifies the local NAT from the recorded observations.
    ///
    /// * Fewer than two peers → [`NatClass::Unknown`].
    /// * Any reflex matching `bind_addr` → [`NatClass::Open`]. The port must
    ///   match, and the IP must match too unless `bind_addr`'s IP is
    ///   unspecified (`0.0.0.0` / `::`), in which case the port alone decides.
    /// * Otherwise, all reflexes on one port → [`NatClass::Cone`]; differing
    ///   ports → [`NatClass::Symmetric`].
    ///
    /// Only ports are compared in the Cone/Symmetric step, so reflex IPs may
    /// differ between peers without affecting the result. The outcome does not
    /// depend on the order in which observations were recorded.
    pub fn classify(&self, bind_addr: SocketAddr) -> NatClass {
        if self.probes.len() <= 1 {
            return NatClass::Unknown;
        }

        // NOTE(review): `IpAddr` equality treats an IPv4 address and its
        // IPv4-mapped IPv6 form as different; confirm reflexes are reported in
        // the same address family as `bind_addr`.
        let any_ip = bind_addr.ip().is_unspecified();
        let reaches_bind = |reflex: &SocketAddr| {
            reflex.port() == bind_addr.port() && (any_ip || reflex.ip() == bind_addr.ip())
        };
        if self.probes.iter().map(|(_, reflex)| reflex).any(reaches_bind) {
            return NatClass::Open;
        }

        // Indexing is safe: at least two probes exist past the guard above.
        let anchor_port = self.probes[0].1.port();
        let port_varies = self
            .probes
            .iter()
            .any(|&(_, reflex)| reflex.port() != anchor_port);
        if port_varies {
            NatClass::Symmetric
        } else {
            NatClass::Cone
        }
    }
}
// Unit tests for `NatClass`, `pair_action`, and `ClassifyFsm`. The randomized
// tests draw inputs from a fixed-seed `Lcg`, so every run sees the same data;
// their inputs depend on the exact order of RNG calls.
#[cfg(test)]
mod tests {
use super::*;
// Parses a socket-address literal; panics on malformed test input.
fn sa(addr: &str) -> SocketAddr {
addr.parse().unwrap()
}
// ---- ClassifyFsm: basic classification ----
#[test]
fn empty_classifies_as_unknown() {
let fsm = ClassifyFsm::new();
assert_eq!(fsm.classify(sa("10.0.0.1:9001")), NatClass::Unknown);
}
// A single observation stays Unknown even when it equals the bind address:
// classify requires at least two peers.
#[test]
fn one_probe_classifies_as_unknown() {
let mut fsm = ClassifyFsm::new();
fsm.observe(1, sa("10.0.0.1:9001"));
assert_eq!(fsm.classify(sa("10.0.0.1:9001")), NatClass::Unknown);
}
#[test]
fn reflex_matching_bind_is_open() {
let bind = sa("192.0.2.1:9001");
let mut fsm = ClassifyFsm::new();
fsm.observe(1, bind);
fsm.observe(2, bind);
assert_eq!(fsm.classify(bind), NatClass::Open);
}
#[test]
fn stable_port_across_peers_is_cone() {
let mut fsm = ClassifyFsm::new();
fsm.observe(1, sa("198.51.100.5:54321"));
fsm.observe(2, sa("198.51.100.5:54321"));
assert_eq!(fsm.classify(sa("192.0.2.1:9001")), NatClass::Cone);
}
#[test]
fn varying_port_is_symmetric() {
let mut fsm = ClassifyFsm::new();
fsm.observe(1, sa("198.51.100.5:54321"));
fsm.observe(2, sa("198.51.100.5:54322"));
assert_eq!(fsm.classify(sa("192.0.2.1:9001")), NatClass::Symmetric);
}
// Re-observing peer 1 replaces its entry (count stays 2); the new port then
// differs from peer 2's, so the class flips to Symmetric.
#[test]
fn later_observation_from_same_peer_replaces_earlier() {
let mut fsm = ClassifyFsm::new();
fsm.observe(1, sa("198.51.100.5:54321"));
fsm.observe(2, sa("198.51.100.5:54321"));
fsm.observe(1, sa("198.51.100.5:54322"));
assert_eq!(fsm.observation_count(), 2);
assert_eq!(fsm.classify(sa("192.0.2.1:9001")), NatClass::Symmetric);
}
#[test]
fn clear_resets_to_unknown() {
let mut fsm = ClassifyFsm::new();
fsm.observe(1, sa("198.51.100.5:54321"));
fsm.observe(2, sa("198.51.100.5:54321"));
fsm.clear();
assert_eq!(fsm.observation_count(), 0);
assert_eq!(fsm.classify(sa("192.0.2.1:9001")), NatClass::Unknown);
}
// The Open check runs before the port-stability check, so a single reflex
// equal to the bind address is enough.
#[test]
fn open_beats_cone_when_bind_equals_one_reflex() {
let bind = sa("192.0.2.1:9001");
let mut fsm = ClassifyFsm::new();
fsm.observe(1, bind);
fsm.observe(2, sa("198.51.100.5:54321"));
assert_eq!(fsm.classify(bind), NatClass::Open);
}
// ---- ClassifyFsm: wildcard (unspecified-IP) binds ----
// With an unspecified bind IP only the port is compared, so reflexes on other
// IPs but the bind port count as Open.
#[test]
fn wildcard_bind_v4_recognizes_open() {
let mut fsm = ClassifyFsm::new();
fsm.observe(1, sa("192.0.2.1:9001"));
fsm.observe(2, sa("203.0.113.7:9001"));
let bind = sa("0.0.0.0:9001");
assert_eq!(
fsm.classify(bind),
NatClass::Open,
"wildcard bind must classify port-matching reflex as Open"
);
}
#[test]
fn wildcard_bind_v6_recognizes_open() {
let mut fsm = ClassifyFsm::new();
fsm.observe(1, sa("[2001:db8::1]:9001"));
fsm.observe(2, sa("[2001:db8::2]:9001"));
let bind = sa("[::]:9001");
assert_eq!(
fsm.classify(bind),
NatClass::Open,
"wildcard v6 bind must classify port-matching reflex as Open"
);
}
#[test]
fn wildcard_bind_with_varying_ports_is_symmetric() {
let mut fsm = ClassifyFsm::new();
fsm.observe(1, sa("192.0.2.1:54321"));
fsm.observe(2, sa("203.0.113.7:54322"));
let bind = sa("0.0.0.0:9001");
assert_eq!(fsm.classify(bind), NatClass::Symmetric);
}
// ---- ClassifyFsm: determinism and ordering ----
// Same observations fed in forward, reverse, and interleaved order must yield
// the same class and the same observation count.
#[test]
fn classification_is_stable_under_observation_permutation() {
let bind = sa("192.0.2.1:9001");
let obs = vec![
(1u64, sa("198.51.100.5:54321")),
(2u64, sa("198.51.100.5:54321")),
(3u64, sa("198.51.100.6:54321")),
(4u64, sa("198.51.100.7:54321")),
];
let mut fsm_a = ClassifyFsm::new();
for (p, r) in &obs {
fsm_a.observe(*p, *r);
}
let class_a = fsm_a.classify(bind);
let mut fsm_b = ClassifyFsm::new();
for (p, r) in obs.iter().rev() {
fsm_b.observe(*p, *r);
}
let class_b = fsm_b.classify(bind);
let mut fsm_c = ClassifyFsm::new();
for i in [0usize, 2, 1, 3] {
let (p, r) = obs[i];
fsm_c.observe(p, r);
}
let class_c = fsm_c.classify(bind);
assert_eq!(class_a, class_b, "ordering A vs reverse must agree");
assert_eq!(class_a, class_c, "ordering A vs interleaved must agree");
assert_eq!(fsm_a.observation_count(), fsm_b.observation_count());
assert_eq!(fsm_a.observation_count(), fsm_c.observation_count());
}
// classify takes &self; repeated calls must neither change the answer nor the
// recorded observations.
#[test]
fn classify_is_idempotent_across_many_calls() {
let bind = sa("192.0.2.1:9001");
let mut fsm = ClassifyFsm::new();
fsm.observe(1, sa("198.51.100.5:54321"));
fsm.observe(2, sa("198.51.100.5:54321"));
let first = fsm.classify(bind);
for _ in 0..1_000 {
assert_eq!(fsm.classify(bind), first);
}
assert_eq!(fsm.observation_count(), 2);
}
// Eight peers on differing IPs but one port classify as Cone (only ports are
// compared); a ninth peer on a different port flips it to Symmetric.
#[test]
fn fsm_accepts_many_observations_and_reflects_latest_in_class() {
let bind = sa("192.0.2.1:9001");
let mut fsm = ClassifyFsm::new();
for i in 1..=8 {
fsm.observe(i, sa(&format!("198.51.100.{i}:54321")));
}
assert_eq!(fsm.observation_count(), 8);
assert_eq!(fsm.classify(bind), NatClass::Cone);
fsm.observe(9, sa("198.51.100.9:54322"));
assert_eq!(fsm.observation_count(), 9);
assert_eq!(fsm.classify(bind), NatClass::Symmetric);
}
// Shares one FSM through an Arc across 8 threads (no lock, since classify
// only borrows &self) and checks every read matches the single-threaded result.
#[test]
fn concurrent_classify_reads_are_consistent() {
use std::sync::Arc;
use std::thread;
let bind = sa("192.0.2.1:9001");
let mut fsm = ClassifyFsm::new();
fsm.observe(1, sa("198.51.100.5:54321"));
fsm.observe(2, sa("198.51.100.5:54321"));
let fsm = Arc::new(fsm);
let expected = fsm.classify(bind);
let mut handles = Vec::new();
for _ in 0..8 {
let fsm = fsm.clone();
handles.push(thread::spawn(move || {
let mut seen = Vec::with_capacity(200);
for _ in 0..200 {
seen.push(fsm.classify(bind));
}
seen
}));
}
for h in handles {
let results = h.join().expect("thread panicked");
assert!(
results.iter().all(|c| *c == expected),
"some thread saw an inconsistent classification — \
got {results:?}, expected all = {expected:?}",
);
}
}
// ---- NatClass encodings ----
#[test]
fn tag_roundtrip() {
for variant in [
NatClass::Open,
NatClass::Cone,
NatClass::Symmetric,
NatClass::Unknown,
] {
let tag = variant.tag();
assert_eq!(NatClass::from_tag(tag), Some(variant));
}
}
#[test]
fn unknown_tag_rejects() {
assert_eq!(NatClass::from_tag("gpu"), None);
assert_eq!(NatClass::from_tag("nat:"), None);
assert_eq!(NatClass::from_tag("nat:weird"), None);
assert_eq!(NatClass::from_tag(""), None);
}
// Also pins the Unknown discriminant to 0.
#[test]
fn u8_roundtrip() {
assert_eq!(NatClass::Unknown.as_u8(), 0);
for variant in [
NatClass::Unknown,
NatClass::Open,
NatClass::Cone,
NatClass::Symmetric,
] {
assert_eq!(NatClass::from_u8(variant.as_u8()), variant);
}
}
// from_u8 is lenient: out-of-range values decode to Unknown rather than fail.
#[test]
fn from_u8_unknown_collapses_to_unknown() {
assert_eq!(NatClass::from_u8(4), NatClass::Unknown);
assert_eq!(NatClass::from_u8(255), NatClass::Unknown);
}
// ---- pair_action decision matrix ----
#[test]
fn pair_action_open_with_non_symmetric_is_direct() {
for peer in [NatClass::Open, NatClass::Cone, NatClass::Unknown] {
assert_eq!(
pair_action(NatClass::Open, peer),
PairAction::Direct,
"Open × {peer:?} should be Direct",
);
assert_eq!(
pair_action(peer, NatClass::Open),
PairAction::Direct,
"{peer:?} × Open should be Direct",
);
}
}
#[test]
fn pair_action_open_with_symmetric_is_single_punch() {
assert_eq!(
pair_action(NatClass::Open, NatClass::Symmetric),
PairAction::SinglePunch,
"Open × Symmetric needs coordinator-driven reverse connect",
);
assert_eq!(
pair_action(NatClass::Symmetric, NatClass::Open),
PairAction::SinglePunch,
"Symmetric × Open needs the same coordinator-driven flow",
);
}
#[test]
fn pair_action_symmetric_symmetric_skips_punch() {
assert_eq!(
pair_action(NatClass::Symmetric, NatClass::Symmetric),
PairAction::SkipPunch,
);
}
#[test]
fn pair_action_cone_cone_single_punch() {
assert_eq!(
pair_action(NatClass::Cone, NatClass::Cone),
PairAction::SinglePunch,
);
}
#[test]
fn pair_action_symmetric_cone_attempts_one() {
assert_eq!(
pair_action(NatClass::Symmetric, NatClass::Cone),
PairAction::SinglePunch,
);
assert_eq!(
pair_action(NatClass::Cone, NatClass::Symmetric),
PairAction::SinglePunch,
);
}
#[test]
fn pair_action_unknown_unknown_is_direct() {
assert_eq!(
pair_action(NatClass::Unknown, NatClass::Unknown),
PairAction::Direct,
);
}
#[test]
fn pair_action_symmetric_unknown_skips_punch() {
assert_eq!(
pair_action(NatClass::Symmetric, NatClass::Unknown),
PairAction::SkipPunch,
);
assert_eq!(
pair_action(NatClass::Unknown, NatClass::Symmetric),
PairAction::SkipPunch,
);
}
#[test]
fn pair_action_cone_unknown_attempts_one() {
assert_eq!(
pair_action(NatClass::Cone, NatClass::Unknown),
PairAction::SinglePunch,
);
assert_eq!(
pair_action(NatClass::Unknown, NatClass::Cone),
PairAction::SinglePunch,
);
}
// The complete 4 × 4 table; must be kept in sync with `pair_action`.
#[test]
fn pair_action_matches_plan_matrix() {
use NatClass::*;
use PairAction::*;
let cases: &[(NatClass, NatClass, PairAction)] = &[
(Open, Open, Direct),
(Open, Cone, Direct),
(Open, Symmetric, SinglePunch),
(Open, Unknown, Direct),
(Cone, Open, Direct),
(Cone, Cone, SinglePunch),
(Cone, Symmetric, SinglePunch),
(Cone, Unknown, SinglePunch),
(Symmetric, Open, SinglePunch),
(Symmetric, Cone, SinglePunch),
(Symmetric, Symmetric, SkipPunch),
(Symmetric, Unknown, SkipPunch),
(Unknown, Open, Direct),
(Unknown, Cone, SinglePunch),
(Unknown, Symmetric, SkipPunch),
(Unknown, Unknown, Direct),
];
assert_eq!(cases.len(), 16, "matrix has 16 cells (4 × 4)");
for &(local, remote, expected) in cases {
assert_eq!(
pair_action(local, remote),
expected,
"pair_action({local:?}, {remote:?})",
);
}
}
// ---- Deterministic pseudo-random input generation ----
// Minimal 64-bit linear congruential generator (Knuth's MMIX multiplier and
// increment) used only to produce reproducible test inputs. The seed is offset
// by a fixed constant before use.
struct Lcg(u64);
impl Lcg {
fn new(seed: u64) -> Self {
Self(seed.wrapping_add(0x9E37_79B9_7F4A_7C15))
}
// Advances the state and returns its high 32 bits.
fn next_u32(&mut self) -> u32 {
self.0 = self
.0
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1_442_695_040_888_963_407);
(self.0 >> 32) as u32
}
// Picks one of the four NatClass variants.
fn pick_class(&mut self) -> NatClass {
match self.next_u32() % 4 {
0 => NatClass::Unknown,
1 => NatClass::Open,
2 => NatClass::Cone,
_ => NatClass::Symmetric,
}
}
// Roughly half the time returns the fixed port 54_321 (to make repeated ports,
// and thus Cone results, likely); otherwise a port in 40_000..=49_999. The
// second branch consumes an extra RNG draw.
fn pick_port(&mut self) -> u16 {
if self.next_u32() & 1 == 0 {
54_321
} else {
40_000 + (self.next_u32() % 10_000) as u16
}
}
// Returns a last IPv4 octet in 5..=254.
fn pick_ip_last_octet(&mut self) -> u8 {
(self.next_u32() % 250 + 5) as u8
}
}
// ---- Randomized / property-style tests ----
// Calls pair_action on random pairs and then on all 16 pairs exhaustively. The
// exhaustive loop discards its results, so it only checks that no call panics.
#[test]
fn pair_action_is_total_and_yields_one_of_three_actions() {
const N: usize = 4_000;
let mut rng = Lcg::new(0x00C0_FFEE_F00D);
let valid = |a: PairAction| {
matches!(
a,
PairAction::Direct | PairAction::SinglePunch | PairAction::SkipPunch,
)
};
for _ in 0..N {
let local = rng.pick_class();
let remote = rng.pick_class();
let action = pair_action(local, remote);
assert!(
valid(action),
"pair_action({local:?}, {remote:?}) returned {action:?} — not a valid variant",
);
}
for &local in &[
NatClass::Open,
NatClass::Cone,
NatClass::Symmetric,
NatClass::Unknown,
] {
for &remote in &[
NatClass::Open,
NatClass::Cone,
NatClass::Symmetric,
NatClass::Unknown,
] {
let _ = pair_action(local, remote);
}
}
}
#[test]
fn unknown_pair_resolves_to_direct() {
assert_eq!(
pair_action(NatClass::Unknown, NatClass::Unknown),
PairAction::Direct,
"the 'attempt direct, fall back on failure' contract for \
Unknown × Unknown must not regress",
);
}
// Random batches of 0..=12 observations over peer ids 0..=3. Repeated ids
// exercise the replace path, so observation_count must equal the number of
// distinct peer ids actually drawn.
#[test]
fn fsm_classify_never_panics_under_random_observation_storms() {
const N: usize = 500;
let mut rng = Lcg::new(0xDEAD_BEEF_CAFE);
for iter in 0..N {
let mut fsm = ClassifyFsm::new();
let bind_port = rng.pick_port();
let bind: SocketAddr = format!("10.0.0.1:{bind_port}").parse().unwrap();
let obs_count = (rng.next_u32() % 13) as usize;
let mut unique_peers = std::collections::HashSet::new();
for _ in 0..obs_count {
let peer = (rng.next_u32() % 4) as u64;
unique_peers.insert(peer);
let ip_octet = rng.pick_ip_last_octet();
let port = rng.pick_port();
let reflex: SocketAddr = format!("198.51.100.{ip_octet}:{port}").parse().unwrap();
fsm.observe(peer, reflex);
}
assert_eq!(
fsm.observation_count(),
unique_peers.len(),
"iter {iter}: observation_count drifted from distinct-peer count",
);
let class = fsm.classify(bind);
assert!(
matches!(
class,
NatClass::Unknown | NatClass::Open | NatClass::Cone | NatClass::Symmetric,
),
"iter {iter}: classify returned {class:?} — invalid variant",
);
// Feeds the classification into pair_action to check it accepts every
// classify output; the resulting action is not asserted.
let remote = rng.pick_class();
let action = pair_action(class, remote);
let _ = action; }
}
// One observation must stay Unknown; two peers must yield a concrete class.
// pick_port never yields 9001, so in practice the result is Cone or Symmetric.
#[test]
fn unknown_classification_always_recovers_on_enough_observations() {
const N: usize = 200;
let mut rng = Lcg::new(0xABCD_1234_5678);
for iter in 0..N {
let mut fsm = ClassifyFsm::new();
let bind: SocketAddr = "10.0.0.1:9001".parse().unwrap();
assert_eq!(fsm.classify(bind), NatClass::Unknown);
let p1 = rng.pick_port();
fsm.observe(1, format!("198.51.100.5:{p1}").parse().unwrap());
assert_eq!(
fsm.classify(bind),
NatClass::Unknown,
"iter {iter}: one observation must stay Unknown",
);
let p2 = rng.pick_port();
fsm.observe(2, format!("198.51.100.6:{p2}").parse().unwrap());
let class = fsm.classify(bind);
assert!(
matches!(class, NatClass::Open | NatClass::Cone | NatClass::Symmetric),
"iter {iter}: after 2 observations, class must be non-Unknown (was {class:?})",
);
}
}
// Random sets of 2..=6 distinct peers, fed forward and reversed, must classify
// identically.
#[test]
fn reclassification_is_order_independent_over_random_samples() {
const N: usize = 200;
let mut rng = Lcg::new(0x7A7A_B0B0);
let bind: SocketAddr = "192.0.2.1:9001".parse().unwrap();
for iter in 0..N {
let count = 2 + (rng.next_u32() % 5) as u64;
let obs: Vec<(u64, SocketAddr)> = (0..count)
.map(|i| {
let port = rng.pick_port();
let ip_octet = rng.pick_ip_last_octet();
let sa: SocketAddr = format!("198.51.100.{ip_octet}:{port}").parse().unwrap();
(i, sa)
})
.collect();
let mut fwd = ClassifyFsm::new();
for (p, r) in &obs {
fwd.observe(*p, *r);
}
let mut rev = ClassifyFsm::new();
for (p, r) in obs.iter().rev() {
rev.observe(*p, *r);
}
assert_eq!(
fwd.classify(bind),
rev.classify(bind),
"iter {iter}: classification must not depend on observe order",
);
}
}
}