use crate::connection::ConnectionManager;
use crate::types::{ConnectionStatus, NetworkError, PeerId};
use dashmap::DashMap;
use libp2p::core::Multiaddr;
use parking_lot::RwLock;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, Mutex, Semaphore};
use tokio::time::{interval, sleep, timeout};
use tracing::{debug, error, info, warn};
/// STUN transaction identifier: 12 raw bytes.
type TransactionId = [u8; 12];
/// A STUN/TURN message: its kind, transaction id and attribute list.
#[derive(Debug, Clone)]
pub struct Message {
pub msg_type: MessageType,
pub transaction_id: TransactionId,
pub attributes: Vec<Attribute>,
}
/// The STUN/TURN message kinds modelled by this module.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MessageType {
BindingRequest,
BindingResponse,
BindingErrorResponse,
AllocateRequest,
AllocateResponse,
}
/// One STUN/TURN attribute; each variant carries that attribute's decoded value.
#[derive(Debug, Clone)]
pub enum Attribute {
MappedAddress(SocketAddr),
XorMappedAddress(SocketAddr),
ChangedAddress(SocketAddr),
Username(String),
MessageIntegrity(Vec<u8>),
/// Presumably (numeric error code, reason phrase) — confirm against producers.
ErrorCode(u16, String),
/// A list of 16-bit attribute type codes.
UnknownAttributes(Vec<u16>),
Realm(String),
Nonce(Vec<u8>),
}
/// Errors produced by the NAT traversal subsystem.
///
/// The `Display` text of each variant comes from its `#[error]` attribute.
#[derive(Debug, Error)]
pub enum NatTraversalError {
#[error("STUN error: {0}")]
StunError(String),
#[error("TURN error: {0}")]
TurnError(String),
#[error("UPnP error: {0}")]
UpnpError(String),
#[error("NAT-PMP error: {0}")]
NatPmpError(String),
#[error("Hole punching failed: {0}")]
HolePunchError(String),
#[error("Relay error: {0}")]
RelayError(String),
#[error("NAT detection failed: {0}")]
DetectionError(String),
#[error("Connection upgrade failed: {0}")]
UpgradeError(String),
/// Built automatically from a `NetworkError` via `#[from]`, so `?` lifts it.
#[error("Network error: {0}")]
NetworkError(#[from] NetworkError),
/// Built automatically from `std::io::Error` via `#[from]`, so `?` lifts it.
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
#[error("Operation timed out")]
Timeout,
/// Wraps a `NetworkError` explicitly. It has no `From` impl, because the
/// `NetworkError` variant already owns that conversion.
#[error("Connection error: {0}")]
ConnectionError(NetworkError),
}
/// Classification of the NAT in front of this host.
///
/// `Serialize`/`Deserialize` are derived, so renaming or removing variants
/// changes the serialized representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NatType {
/// No NAT: the mapped public address equals the local address.
None,
FullCone,
RestrictedCone,
PortRestrictedCone,
Symmetric,
Unknown,
}
/// Result of one NAT detection run.
#[derive(Debug, Clone)]
pub struct NatInfo {
pub nat_type: NatType,
/// IP of the mapped address from the first successful STUN query, if any.
pub public_ip: Option<IpAddr>,
/// Port of the mapped address from the first successful STUN query, if any.
pub public_port: Option<u16>,
pub local_ip: IpAddr,
pub local_port: u16,
/// Hairpinning support; `StunClient::detect_nat` does not test it and reports `false`.
pub hairpinning: bool,
/// When this result was produced.
pub detected_at: Instant,
/// Score in `[0.0, 1.0]` derived from how many servers answered.
pub confidence: f64,
}
/// A configured STUN server together with bookkeeping about it.
#[derive(Debug, Clone)]
pub struct StunServer {
pub address: SocketAddr,
/// Configured priority. NOTE(review): servers are queried in list order; this
/// value is not used for ordering by the code in this module.
pub priority: u32,
/// Inactive servers are skipped during detection.
pub is_active: bool,
/// NOTE(review): not updated by the code in this module.
pub last_success: Option<Instant>,
/// NOTE(review): not updated by the code in this module.
pub avg_response_ms: u64,
}
impl StunServer {
pub fn new(address: SocketAddr, priority: u32) -> Self {
Self {
address,
priority,
is_active: true,
last_success: None,
avg_response_ms: 0,
}
}
}
/// A TURN relay server and its long-term credentials.
#[derive(Debug, Clone)]
pub struct TurnServer {
pub address: SocketAddr,
pub username: String,
/// NOTE(review): kept in plain text and included in the derived `Debug` output.
pub password: String,
pub realm: Option<String>,
pub priority: u32,
/// Inactive servers are skipped when allocating a relay.
pub is_active: bool,
pub relay_address: Option<SocketAddr>,
}
/// Tunables for `NatTraversalManager`: which techniques are enabled, which
/// servers to use, and the timing of background work.
#[derive(Debug, Clone)]
pub struct NatTraversalConfig {
/// Use STUN for NAT detection.
pub enable_stun: bool,
/// Try a TURN allocation before falling back to peer relays.
pub enable_turn: bool,
/// Attempt UPnP gateway discovery and UPnP port mappings.
pub enable_upnp: bool,
/// Attempt NAT-PMP gateway discovery and NAT-PMP port mappings.
pub enable_nat_pmp: bool,
/// Try hole punching when a direct dial fails.
pub enable_hole_punching: bool,
/// Fall back to relayed connections when other methods fail.
pub enable_relay: bool,
/// NOTE(review): not read by the code in this module.
pub enable_ipv6: bool,
pub stun_servers: Vec<StunServer>,
pub turn_servers: Vec<TurnServer>,
/// Permit count for both the TURN allocation semaphore and the relay
/// connection semaphore.
pub max_relay_connections: usize,
/// How long a hole punch may run before it is abandoned.
pub hole_punch_timeout: Duration,
/// Period of the background NAT re-detection task.
pub detection_interval: Duration,
/// Minimum spacing between connection-upgrade attempts for one peer.
pub upgrade_interval: Duration,
/// Requested port-mapping lease; also passed to `UpnpManager::new`.
pub port_mapping_lifetime: Duration,
}
impl Default for NatTraversalConfig {
fn default() -> Self {
Self {
enable_stun: true,
enable_turn: true,
enable_upnp: true,
enable_nat_pmp: true,
enable_hole_punching: true,
enable_relay: true,
enable_ipv6: true,
stun_servers: vec![
StunServer::new("stun1.l.google.com:19302".parse().unwrap(), 1),
StunServer::new("stun2.l.google.com:19302".parse().unwrap(), 2),
StunServer::new("stun3.l.google.com:19302".parse().unwrap(), 3),
StunServer::new("stun4.l.google.com:19302".parse().unwrap(), 4),
],
turn_servers: vec![],
max_relay_connections: 50,
hole_punch_timeout: Duration::from_secs(30),
detection_interval: Duration::from_secs(300), upgrade_interval: Duration::from_secs(60), port_mapping_lifetime: Duration::from_secs(3600), }
}
}
/// Top-level coordinator for NAT traversal. It owns:
/// * the STUN and TURN clients;
/// * the UPnP / NAT-PMP port-mapping helpers;
/// * hole punching, relaying and connection upgrades;
/// * shared statistics.
pub struct NatTraversalManager {
config: NatTraversalConfig,
/// Most recent NAT detection result; `None` until a detection succeeds.
nat_info: Arc<RwLock<Option<NatInfo>>>,
connection_manager: Arc<ConnectionManager>,
stun_client: Arc<StunClient>,
turn_client: Arc<TurnClient>,
upnp_manager: Arc<UpnpManager>,
nat_pmp_client: Arc<NatPmpClient>,
hole_punch_coordinator: Arc<HolePunchCoordinator>,
relay_manager: Arc<RelayManager>,
upgrade_manager: Arc<ConnectionUpgradeManager>,
/// Port mappings created through this manager, keyed by local port.
port_mappings: Arc<DashMap<u16, PortMapping>>,
/// Handle of the background NAT re-detection task; aborted by `shutdown`.
detection_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
stats: Arc<NatTraversalStats>,
}
/// A port mapping created on the local gateway.
#[derive(Debug, Clone)]
pub struct PortMapping {
pub local_port: u16,
/// External port as reported by the mapping method.
pub external_port: u16,
pub protocol: PortMappingProtocol,
/// Which mechanism created the mapping.
pub method: PortMappingMethod,
pub created_at: Instant,
/// Creation time plus the lease length reported by the mapping method.
pub expires_at: Instant,
}
/// Transport protocol of a port mapping.
///
/// Declaration order fixes the discriminants, which are logged via `as u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PortMappingProtocol {
TCP,
UDP,
}
/// Mechanism that created a port mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PortMappingMethod {
Upnp,
NatPmp,
Manual,
}
/// Counters describing NAT traversal activity.
///
/// Every field is an atomic, so the counters can be updated through a shared
/// `Arc` without a lock.
#[derive(Debug)]
pub struct NatTraversalStats {
pub total_attempts: AtomicU64,
pub successful_traversals: AtomicU64,
pub failed_traversals: AtomicU64,
pub stun_success: AtomicU64,
pub stun_failures: AtomicU64,
pub hole_punch_success: AtomicU64,
pub hole_punch_failures: AtomicU64,
pub relay_connections: AtomicU32,
pub upgraded_connections: AtomicU64,
pub port_mappings_created: AtomicU64,
pub port_mappings_failed: AtomicU64,
pub avg_traversal_time_ms: AtomicU64,
}
impl Default for NatTraversalStats {
    /// Produces a statistics block in which every counter reads zero.
    fn default() -> Self {
        // All 64-bit counters share the same zero initialiser.
        let zeroed = || AtomicU64::new(0);
        Self {
            total_attempts: zeroed(),
            successful_traversals: zeroed(),
            failed_traversals: zeroed(),
            stun_success: zeroed(),
            stun_failures: zeroed(),
            hole_punch_success: zeroed(),
            hole_punch_failures: zeroed(),
            relay_connections: AtomicU32::new(0),
            upgraded_connections: zeroed(),
            port_mappings_created: zeroed(),
            port_mappings_failed: zeroed(),
            avg_traversal_time_ms: zeroed(),
        }
    }
}
/// Client for STUN-based NAT detection over one shared UDP socket.
pub struct StunClient {
/// Servers to query, in query order.
servers: Arc<RwLock<Vec<StunServer>>>,
/// Socket bound by `detect_nat`; `None` until the first detection.
socket: Arc<Mutex<Option<UdpSocket>>>,
// NOTE(review): only initialised in `new`; nothing in this module reads it,
// hence the `dead_code` allow.
#[allow(dead_code)]
transactions: Arc<DashMap<TransactionId, StunTransaction>>,
}
/// Bookkeeping for an in-flight STUN request (not used by this module yet).
#[derive(Debug)]
#[allow(dead_code)]
struct StunTransaction {
server: SocketAddr,
sent_at: Instant,
callback: Arc<Mutex<Option<mpsc::Sender<Result<Message, NatTraversalError>>>>>,
}
impl StunClient {
pub fn new(servers: Vec<StunServer>) -> Self {
Self {
servers: Arc::new(RwLock::new(servers)),
socket: Arc::new(Mutex::new(None)),
transactions: Arc::new(DashMap::new()),
}
}
pub async fn detect_nat(&self) -> Result<NatInfo, NatTraversalError> {
let local_addr = if false {
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
} else {
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
};
let socket = UdpSocket::bind(local_addr).await?;
let local_addr = socket.local_addr()?;
*self.socket.lock().await = Some(socket);
let mut results = Vec::new();
let servers = self.servers.read().clone();
for server in servers.iter().filter(|s| s.is_active) {
match self.query_stun_server(&server.address).await {
Ok(mapped_addr) => {
results.push((server.clone(), mapped_addr));
if results.len() >= 3 {
break; }
}
Err(e) => {
warn!("STUN query to {} failed: {}", server.address, e);
}
}
}
if results.is_empty() {
return Err(NatTraversalError::DetectionError(
"No STUN servers responded".to_string(),
));
}
let nat_type = self.analyze_nat_type(&results, local_addr).await?;
let (public_ip, public_port) = if let Some((_, addr)) = results.first() {
(Some(addr.ip()), Some(addr.port()))
} else {
(None, None)
};
Ok(NatInfo {
nat_type,
public_ip,
public_port,
local_ip: local_addr.ip(),
local_port: local_addr.port(),
hairpinning: false, detected_at: Instant::now(),
confidence: self.calculate_confidence(&results),
})
}
async fn query_stun_server(
&self,
server: &SocketAddr,
) -> Result<SocketAddr, NatTraversalError> {
let socket_guard = self.socket.lock().await;
let socket = socket_guard
.as_ref()
.ok_or_else(|| NatTraversalError::StunError("Socket not initialized".to_string()))?;
let request_data = b"STUN_REQUEST";
socket
.send_to(request_data, server)
.await
.map_err(|e| NatTraversalError::StunError(e.to_string()))?;
let mut response_buf = vec![0u8; 1024];
let (_len, from) = timeout(Duration::from_secs(5), socket.recv_from(&mut response_buf))
.await
.map_err(|_| NatTraversalError::Timeout)??;
if from != *server {
return Err(NatTraversalError::StunError(
"Response from wrong server".to_string(),
));
}
let local_addr = socket.local_addr()?;
Ok(SocketAddr::new(server.ip(), local_addr.port()))
}
async fn analyze_nat_type(
&self,
results: &[(StunServer, SocketAddr)],
local_addr: SocketAddr,
) -> Result<NatType, NatTraversalError> {
if let Some((_, public_addr)) = results.first() {
if public_addr.ip() == local_addr.ip() {
return Ok(NatType::None);
}
}
let all_same = results.windows(2).all(|w| w[0].1 == w[1].1);
if all_same {
Ok(NatType::RestrictedCone)
} else {
Ok(NatType::Symmetric)
}
}
fn calculate_confidence(&self, results: &[(StunServer, SocketAddr)]) -> f64 {
let base_confidence = results.len() as f64 / 3.0; base_confidence.min(1.0)
}
}
/// Client that obtains relay allocations from the configured TURN servers.
pub struct TurnClient {
servers: Arc<RwLock<Vec<TurnServer>>>,
/// Current allocations keyed by TURN server address (one entry per server).
allocations: Arc<DashMap<SocketAddr, TurnAllocation>>,
/// Semaphore created with `max_allocations` permits in `TurnClient::new`.
allocation_limit: Arc<Semaphore>,
}
/// A relay allocation obtained from a TURN server.
#[derive(Debug, Clone)]
pub struct TurnAllocation {
pub server: SocketAddr,
pub relay_address: SocketAddr,
pub lifetime: Duration,
pub created_at: Instant,
/// Handle of a background refresh task, if one was started; shared between clones.
pub refresh_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}
impl TurnClient {
pub fn new(servers: Vec<TurnServer>, max_allocations: usize) -> Self {
Self {
servers: Arc::new(RwLock::new(servers)),
allocations: Arc::new(DashMap::new()),
allocation_limit: Arc::new(Semaphore::new(max_allocations)),
}
}
pub async fn allocate_relay(&self) -> Result<TurnAllocation, NatTraversalError> {
let _permit =
self.allocation_limit.acquire().await.map_err(|_| {
NatTraversalError::TurnError("Allocation limit reached".to_string())
})?;
let servers = self.servers.read().clone();
for server in servers.iter().filter(|s| s.is_active) {
match self.allocate_from_server(server).await {
Ok(allocation) => {
self.allocations.insert(server.address, allocation.clone());
return Ok(allocation);
}
Err(e) => {
warn!("TURN allocation from {} failed: {}", server.address, e);
}
}
}
Err(NatTraversalError::TurnError(
"No TURN servers available".to_string(),
))
}
async fn allocate_from_server(
&self,
server: &TurnServer,
) -> Result<TurnAllocation, NatTraversalError> {
Ok(TurnAllocation {
server: server.address,
relay_address: server.address, lifetime: Duration::from_secs(600),
created_at: Instant::now(),
refresh_handle: Arc::new(Mutex::new(None)),
})
}
}
/// A discovered UPnP gateway.
#[derive(Debug, Clone)]
pub struct SimpleGateway {
pub address: SocketAddr,
/// Human-readable label used in logs.
pub name: String,
}
/// Discovers a UPnP gateway and tracks the port mappings requested through it.
pub struct UpnpManager {
gateway: Arc<Mutex<Option<SimpleGateway>>>,
/// Mappings keyed by local port.
mappings: Arc<DashMap<u16, UpnpMapping>>,
// NOTE(review): stored but not read anywhere yet, hence the allow.
#[allow(dead_code)]
refresh_interval: Duration,
}
/// A UPnP port mapping as recorded by `UpnpManager`.
#[derive(Debug, Clone)]
pub struct UpnpMapping {
pub local_port: u16,
pub external_port: u16,
pub protocol: PortMappingProtocol,
pub description: String,
pub lease_duration: Duration,
pub created_at: Instant,
}
impl UpnpManager {
pub fn new(refresh_interval: Duration) -> Self {
Self {
gateway: Arc::new(Mutex::new(None)),
mappings: Arc::new(DashMap::new()),
refresh_interval,
}
}
pub async fn discover_gateway(&self) -> Result<(), NatTraversalError> {
let potential_gateways = vec!["192.168.1.1:1900", "192.168.0.1:1900", "10.0.0.1:1900"];
for gateway_addr in potential_gateways {
if let Ok(addr) = gateway_addr.parse::<SocketAddr>() {
if let Ok(socket) = UdpSocket::bind("0.0.0.0:0").await {
if socket.send_to(b"M-SEARCH", addr).await.is_ok() {
info!("Discovered UPnP gateway at: {}", addr);
let gateway = SimpleGateway {
address: addr,
name: "UPnP Gateway".to_string(),
};
*self.gateway.lock().await = Some(gateway);
return Ok(());
}
}
}
}
Err(NatTraversalError::UpnpError(
"No UPnP gateway found".to_string(),
))
}
pub async fn create_mapping(
&self,
local_port: u16,
external_port: u16,
protocol: PortMappingProtocol,
description: &str,
lease_duration: Duration,
) -> Result<UpnpMapping, NatTraversalError> {
info!(
"Creating UPnP port mapping: {}:{} -> {} ({})",
local_port, external_port, protocol as u8, description
);
let mapping = UpnpMapping {
local_port,
external_port,
protocol,
description: description.to_string(),
lease_duration,
created_at: Instant::now(),
};
self.mappings.insert(local_port, mapping.clone());
Ok(mapping)
}
#[allow(dead_code)]
async fn get_local_ip(&self) -> Result<IpAddr, NatTraversalError> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect("8.8.8.8:80").await?;
let local_addr = socket.local_addr()?;
Ok(local_addr.ip())
}
}
/// Client for NAT-PMP gateway discovery and port mapping.
pub struct NatPmpClient {
/// Gateway found by `discover_gateway`; `None` until then.
gateway: Arc<Mutex<Option<IpAddr>>>,
/// Mappings keyed by local port.
mappings: Arc<DashMap<u16, NatPmpMapping>>,
}
/// A NAT-PMP port mapping as recorded by `NatPmpClient`.
#[derive(Debug, Clone)]
pub struct NatPmpMapping {
pub local_port: u16,
pub external_port: u16,
/// `true` for TCP, `false` for UDP.
pub is_tcp: bool,
pub lifetime: Duration,
pub created_at: Instant,
}
impl NatPmpClient {
pub fn new() -> Self {
Self {
gateway: Arc::new(Mutex::new(None)),
mappings: Arc::new(DashMap::new()),
}
}
pub async fn discover_gateway(&self) -> Result<(), NatTraversalError> {
let common_gateways = vec!["192.168.1.1", "192.168.0.1", "10.0.0.1"];
for gateway_str in common_gateways {
if let Ok(gateway) = gateway_str.parse::<IpAddr>() {
if self.test_gateway(&gateway).await {
*self.gateway.lock().await = Some(gateway);
info!("Discovered NAT-PMP gateway: {}", gateway);
return Ok(());
}
}
}
Err(NatTraversalError::NatPmpError(
"No NAT-PMP gateway found".to_string(),
))
}
async fn test_gateway(&self, _gateway: &IpAddr) -> bool {
false
}
pub async fn create_mapping(
&self,
local_port: u16,
external_port: u16,
is_tcp: bool,
lifetime: Duration,
) -> Result<NatPmpMapping, NatTraversalError> {
let gateway = self.gateway.lock().await;
let _gateway_addr = gateway
.as_ref()
.ok_or_else(|| NatTraversalError::NatPmpError("No gateway discovered".to_string()))?;
let mapping = NatPmpMapping {
local_port,
external_port,
is_tcp,
lifetime,
created_at: Instant::now(),
};
self.mappings.insert(local_port, mapping.clone());
Ok(mapping)
}
}
/// Runs UDP hole-punching attempts against peers.
pub struct HolePunchCoordinator {
/// Latest attempt per peer; a new attempt replaces the previous entry.
attempts: Arc<DashMap<PeerId, HolePunchAttempt>>,
/// Sender half of the success channel registered for a peer's punch.
success_handlers: Arc<DashMap<PeerId, mpsc::Sender<SocketAddr>>>,
/// How long `start_hole_punch` waits before giving up.
timeout: Duration,
}
/// State of one hole-punching attempt.
#[derive(Debug)]
pub struct HolePunchAttempt {
pub peer_id: PeerId,
pub local_candidates: Vec<SocketAddr>,
pub remote_candidates: Vec<SocketAddr>,
pub started_at: Instant,
pub phase: HolePunchPhase,
/// Set when the attempt succeeds; shared via `Arc` so it can be observed elsewhere.
pub succeeded: Arc<AtomicBool>,
}
/// Progress of a hole-punching attempt.
///
/// NOTE(review): `start_hole_punch` starts attempts directly in `Probing`; the
/// gathering/exchanging phases are not set by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HolePunchPhase {
GatheringCandidates,
ExchangingCandidates,
Probing,
Connected,
Failed,
}
impl HolePunchCoordinator {
    /// Creates a coordinator whose punches give up after `timeout`.
    pub fn new(timeout: Duration) -> Self {
        Self {
            attempts: Arc::new(DashMap::new()),
            success_handlers: Arc::new(DashMap::new()),
            timeout,
        }
    }

    /// Probes `remote_candidates` from each distinct local candidate.
    ///
    /// Returns the local address whose probe got an answer. Fails early when
    /// either candidate list is empty, or when every probe has finished
    /// without success; otherwise it gives up after the configured timeout.
    ///
    /// # Errors
    ///
    /// `HolePunchError` for missing candidates, a timeout, or no response.
    pub async fn start_hole_punch(
        &self,
        peer_id: PeerId,
        local_candidates: Vec<SocketAddr>,
        remote_candidates: Vec<SocketAddr>,
    ) -> Result<SocketAddr, NatTraversalError> {
        info!("Starting hole punch to peer {:?}", peer_id);
        let attempt = HolePunchAttempt {
            peer_id,
            local_candidates: local_candidates.clone(),
            remote_candidates: remote_candidates.clone(),
            started_at: Instant::now(),
            phase: HolePunchPhase::Probing,
            succeeded: Arc::new(AtomicBool::new(false)),
        };
        self.attempts.insert(peer_id, attempt);

        if local_candidates.is_empty() || remote_candidates.is_empty() {
            self.mark_failure(peer_id);
            return Err(NatTraversalError::HolePunchError(
                "No candidate pairs to probe".to_string(),
            ));
        }

        let (tx, mut rx) = mpsc::channel(1);
        self.success_handlers.insert(peer_id, tx.clone());

        // Futures are lazy: the old code collected probe futures without ever
        // polling them, so no probe ran. Spawn one task per distinct local
        // address; that task probes every remote from a single bound socket,
        // avoiding bind collisions on a shared local address.
        let mut probes = Vec::with_capacity(local_candidates.len());
        let mut seen_locals: Vec<SocketAddr> = Vec::with_capacity(local_candidates.len());
        for &local in &local_candidates {
            if seen_locals.contains(&local) {
                continue;
            }
            seen_locals.push(local);
            let remotes = remote_candidates.clone();
            let on_success = tx.clone();
            probes.push(tokio::spawn(async move {
                if let Err(e) = Self::probe_from_local(local, remotes, on_success).await {
                    debug!("Hole punch probes from {} failed: {}", local, e);
                }
            }));
        }
        drop(tx);

        // What ended the wait for a probe reply.
        enum Wait {
            Reply(Option<SocketAddr>),
            TimedOut,
            ProbesFinished,
        }

        let wait = tokio::select! {
            reply = rx.recv() => Wait::Reply(reply),
            _ = sleep(self.timeout) => Wait::TimedOut,
            _ = async {
                for probe in probes.iter_mut() {
                    let _ = probe.await;
                }
            } => Wait::ProbesFinished,
        };

        for probe in &probes {
            probe.abort();
        }
        self.success_handlers.remove(&peer_id);

        let outcome = match wait {
            Wait::Reply(Some(addr)) => Ok(addr),
            Wait::Reply(None) => Err(NatTraversalError::HolePunchError(
                "Channel closed".to_string(),
            )),
            Wait::TimedOut => Err(NatTraversalError::HolePunchError("Timeout".to_string())),
            // The probe tasks have ended and the registered sender was just
            // removed, so every sender is gone. This returns immediately with
            // either a buffered success or `None`.
            Wait::ProbesFinished => rx.recv().await.ok_or_else(|| {
                NatTraversalError::HolePunchError("No response from remote".to_string())
            }),
        };
        match &outcome {
            Ok(_) => self.mark_success(peer_id),
            Err(_) => self.mark_failure(peer_id),
        }
        outcome
    }

    /// Sends probes to every remote from one socket bound to `local`.
    ///
    /// Runs up to 5 rounds. On the first non-empty reply from any remote, it
    /// reports `local` through `on_success`.
    async fn probe_from_local(
        local: SocketAddr,
        remotes: Vec<SocketAddr>,
        on_success: mpsc::Sender<SocketAddr>,
    ) -> Result<(), NatTraversalError> {
        let socket = UdpSocket::bind(local).await?;
        let mut buf = vec![0u8; 1024];
        for round in 0..5 {
            let probe_data = format!("HOLE_PUNCH_PROBE_{}", round).into_bytes();
            for &remote in &remotes {
                debug!("Probing candidate pair: {} -> {}", local, remote);
                // One unreachable remote must not stop probing the others.
                if let Err(e) = socket.send_to(&probe_data, remote).await {
                    debug!("Probe send {} -> {} failed: {}", local, remote, e);
                }
            }
            match timeout(Duration::from_millis(200), socket.recv_from(&mut buf)).await {
                Ok(Ok((len, from))) => {
                    if len > 0 && remotes.contains(&from) {
                        // The receiver may already be gone (timeout or another
                        // probe won); that is fine.
                        let _ = on_success.send(local).await;
                        return Ok(());
                    }
                }
                _ => continue,
            }
            sleep(Duration::from_millis(100)).await;
        }
        Err(NatTraversalError::HolePunchError(
            "No response from remote".to_string(),
        ))
    }

    /// Marks the peer's attempt as connected.
    fn mark_success(&self, peer_id: PeerId) {
        if let Some(mut attempt) = self.attempts.get_mut(&peer_id) {
            attempt.phase = HolePunchPhase::Connected;
            attempt.succeeded.store(true, Ordering::Relaxed);
        }
    }

    /// Marks the peer's attempt as failed.
    fn mark_failure(&self, peer_id: PeerId) {
        if let Some(mut attempt) = self.attempts.get_mut(&peer_id) {
            attempt.phase = HolePunchPhase::Failed;
        }
    }
}
/// Manages connections relayed through third-party relay servers.
pub struct RelayManager {
relay_servers: Arc<RwLock<Vec<RelayServer>>>,
/// Relay connection per target peer.
relay_connections: Arc<DashMap<PeerId, RelayConnection>>,
/// Semaphore created with `max_connections` permits in `RelayManager::new`.
connection_limit: Arc<Semaphore>,
stats: Arc<RelayStats>,
}
/// A relay server that can forward traffic to peers.
#[derive(Debug, Clone)]
pub struct RelayServer {
pub id: PeerId,
pub address: Multiaddr,
/// NOTE(review): not consulted when a relay server is selected.
pub capacity: u32,
/// Number of relay connections routed through this server. It is shared
/// between clones via `Arc`; the least-loaded available server is selected.
pub load: Arc<AtomicU32>,
/// Unavailable servers are never selected.
pub is_available: bool,
pub last_health_check: Option<Instant>,
}
/// One relayed connection to a peer.
#[derive(Debug, Clone)]
pub struct RelayConnection {
/// Id of the relay server carrying the connection.
pub relay_server: PeerId,
pub target_peer: PeerId,
/// Random identifier assigned when the relay is established.
pub connection_id: u64,
pub established_at: Instant,
/// NOTE(review): not incremented by the code in this module.
pub bytes_relayed: Arc<AtomicU64>,
/// Cleared when the connection is closed; shared between clones.
pub is_active: Arc<AtomicBool>,
}
/// Aggregate counters for `RelayManager`.
#[derive(Debug)]
pub struct RelayStats {
pub total_connections: AtomicU64,
pub active_connections: AtomicU32,
pub bytes_relayed: AtomicU64,
pub failed_attempts: AtomicU64,
}
impl RelayManager {
    /// Creates a manager allowing at most `max_connections` open relays.
    pub fn new(max_connections: usize) -> Self {
        Self {
            relay_servers: Arc::new(RwLock::new(Vec::new())),
            relay_connections: Arc::new(DashMap::new()),
            connection_limit: Arc::new(Semaphore::new(max_connections)),
            stats: Arc::new(RelayStats {
                total_connections: AtomicU64::new(0),
                active_connections: AtomicU32::new(0),
                bytes_relayed: AtomicU64::new(0),
                failed_attempts: AtomicU64::new(0),
            }),
        }
    }

    /// Adds a relay server to the pool.
    pub async fn add_relay_server(&self, server: RelayServer) {
        self.relay_servers.write().push(server);
    }

    /// Opens a relay to `target_peer` through the least-loaded available server.
    ///
    /// Each open relay holds one slot of the connection limit until
    /// [`close_relay`](Self::close_relay) runs. Re-establishing a relay to the
    /// same peer releases the previous one.
    ///
    /// # Errors
    ///
    /// `RelayError` if the limit is reached or no server is available. Both
    /// cases are counted in `failed_attempts`.
    pub async fn establish_relay(
        &self,
        target_peer: PeerId,
    ) -> Result<RelayConnection, NatTraversalError> {
        // Fail fast; previously the permit was dropped on return, so the limit
        // never bounded open relays.
        let permit = self.connection_limit.try_acquire().map_err(|_| {
            self.stats.failed_attempts.fetch_add(1, Ordering::Relaxed);
            NatTraversalError::RelayError("Connection limit reached".to_string())
        })?;
        let relay_server = match self.select_relay_server().await {
            Ok(server) => server,
            Err(e) => {
                self.stats.failed_attempts.fetch_add(1, Ordering::Relaxed);
                return Err(e);
            }
        };
        let connection = RelayConnection {
            relay_server: relay_server.id,
            target_peer,
            connection_id: thread_rng().gen(),
            established_at: Instant::now(),
            bytes_relayed: Arc::new(AtomicU64::new(0)),
            is_active: Arc::new(AtomicBool::new(true)),
        };
        // The slot stays taken until the connection is released.
        permit.forget();
        self.stats.total_connections.fetch_add(1, Ordering::Relaxed);
        self.stats
            .active_connections
            .fetch_add(1, Ordering::Relaxed);
        relay_server.load.fetch_add(1, Ordering::Relaxed);
        if let Some(replaced) = self
            .relay_connections
            .insert(target_peer, connection.clone())
        {
            // The old relay to this peer is superseded; release it so the
            // counters, server load and slot do not leak.
            self.release_connection(&replaced);
        }
        info!(
            "Established relay connection to {:?} via {:?}",
            target_peer, relay_server.id
        );
        Ok(connection)
    }

    /// Returns the available server with the lowest current load.
    async fn select_relay_server(&self) -> Result<RelayServer, NatTraversalError> {
        let servers = self.relay_servers.read();
        servers
            .iter()
            .filter(|s| s.is_available)
            .min_by_key(|s| s.load.load(Ordering::Relaxed))
            .cloned()
            .ok_or_else(|| NatTraversalError::RelayError("No relay servers available".to_string()))
    }

    /// Closes the relay to `peer_id`, if any, and frees its slot.
    pub async fn close_relay(&self, peer_id: &PeerId) {
        if let Some((_, connection)) = self.relay_connections.remove(peer_id) {
            self.release_connection(&connection);
        }
    }

    /// Undoes the bookkeeping `establish_relay` did for `connection`:
    /// * marks it inactive;
    /// * decrements the active count;
    /// * returns its slot;
    /// * lowers its server's load.
    fn release_connection(&self, connection: &RelayConnection) {
        connection.is_active.store(false, Ordering::Relaxed);
        self.stats
            .active_connections
            .fetch_sub(1, Ordering::Relaxed);
        self.connection_limit.add_permits(1);
        let servers = self.relay_servers.read();
        if let Some(server) = servers.iter().find(|s| s.id == connection.relay_server) {
            server.load.fetch_sub(1, Ordering::Relaxed);
        }
    }
}
/// Tries to upgrade relayed or TURN connections to direct ones, rate-limited
/// per peer.
pub struct ConnectionUpgradeManager {
/// Upgrade bookkeeping per peer.
upgrade_attempts: Arc<DashMap<PeerId, UpgradeAttempt>>,
/// Minimum time between two upgrade attempts for the same peer.
upgrade_interval: Duration,
/// Manager used to attempt the direct connection; `None` until
/// `set_nat_manager` is called. NOTE(review): `NatTraversalManager::new`
/// never calls it, so within this module this stays `None`.
nat_manager: Option<Arc<NatTraversalManager>>,
}
/// Upgrade bookkeeping for one peer.
#[derive(Debug)]
pub struct UpgradeAttempt {
pub peer_id: PeerId,
/// Connection type recorded when the entry was first created.
pub current_type: ConnectionType,
pub attempt_count: u32,
pub last_attempt: Instant,
pub succeeded: bool,
}
/// How a peer is currently reached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionType {
Direct,
Relay,
Turn,
}
impl ConnectionUpgradeManager {
    /// Creates a manager that retries a peer at most once per `upgrade_interval`.
    pub fn new(upgrade_interval: Duration) -> Self {
        Self {
            upgrade_attempts: Arc::new(DashMap::new()),
            upgrade_interval,
            nat_manager: None,
        }
    }

    /// Sets the manager used to establish direct connections.
    pub fn set_nat_manager(&mut self, nat_manager: Arc<NatTraversalManager>) {
        self.nat_manager = Some(nat_manager);
    }

    /// Tries to replace `current_type` with a direct connection to `peer_id`.
    ///
    /// Direct connections are returned unchanged. A peer's first attempt runs
    /// immediately; later attempts must be at least `upgrade_interval` apart.
    ///
    /// # Errors
    ///
    /// * `UpgradeError` when retried too soon or without a NAT manager.
    /// * Otherwise the error from the direct-connection attempt.
    pub async fn try_upgrade(
        &self,
        peer_id: PeerId,
        current_type: ConnectionType,
    ) -> Result<ConnectionType, NatTraversalError> {
        if current_type == ConnectionType::Direct {
            return Ok(ConnectionType::Direct);
        }

        // Keep the DashMap guard in this scope: it is a synchronous shard lock
        // and must not be held across the long `.await` below.
        {
            let mut attempt = self
                .upgrade_attempts
                .entry(peer_id)
                .or_insert(UpgradeAttempt {
                    peer_id,
                    current_type,
                    attempt_count: 0,
                    last_attempt: Instant::now(),
                    succeeded: false,
                });
            // A fresh entry (no attempts yet) is not rate-limited; previously
            // its `last_attempt = now` made the first attempt always fail.
            if attempt.attempt_count > 0 && attempt.last_attempt.elapsed() < self.upgrade_interval
            {
                return Err(NatTraversalError::UpgradeError(
                    "Too soon to retry".to_string(),
                ));
            }
            attempt.attempt_count += 1;
            attempt.last_attempt = Instant::now();
        }

        let nat_manager = match &self.nat_manager {
            Some(nat_manager) => nat_manager,
            None => {
                return Err(NatTraversalError::UpgradeError(
                    "NAT manager not available".to_string(),
                ))
            }
        };
        match nat_manager.establish_direct_connection(peer_id).await {
            Ok(_) => {
                if let Some(mut attempt) = self.upgrade_attempts.get_mut(&peer_id) {
                    attempt.succeeded = true;
                }
                info!(
                    "Successfully upgraded connection to {:?} from {:?} to Direct",
                    peer_id, current_type
                );
                Ok(ConnectionType::Direct)
            }
            Err(e) => {
                warn!("Failed to upgrade connection to {:?}: {}", peer_id, e);
                Err(e)
            }
        }
    }
}
impl NatTraversalManager {
    /// Builds a manager and all of its sub-components from `config`.
    ///
    /// No network activity happens here; call [`initialize`](Self::initialize).
    pub fn new(config: NatTraversalConfig, connection_manager: Arc<ConnectionManager>) -> Self {
        let stats = Arc::new(NatTraversalStats::default());
        Self {
            config: config.clone(),
            nat_info: Arc::new(RwLock::new(None)),
            connection_manager,
            stun_client: Arc::new(StunClient::new(config.stun_servers.clone())),
            turn_client: Arc::new(TurnClient::new(
                config.turn_servers.clone(),
                config.max_relay_connections,
            )),
            upnp_manager: Arc::new(UpnpManager::new(config.port_mapping_lifetime)),
            nat_pmp_client: Arc::new(NatPmpClient::new()),
            hole_punch_coordinator: Arc::new(HolePunchCoordinator::new(config.hole_punch_timeout)),
            relay_manager: Arc::new(RelayManager::new(config.max_relay_connections)),
            upgrade_manager: Arc::new(ConnectionUpgradeManager::new(config.upgrade_interval)),
            port_mappings: Arc::new(DashMap::new()),
            detection_handle: Arc::new(Mutex::new(None)),
            stats,
        }
    }

    /// Runs initial NAT detection and gateway discovery, then starts background
    /// tasks.
    ///
    /// # Errors
    ///
    /// Propagates a STUN detection failure when `enable_stun` is set. Gateway
    /// discovery failures are only logged.
    pub async fn initialize(&self) -> Result<(), NatTraversalError> {
        info!("Initializing NAT traversal manager");
        if self.config.enable_stun {
            self.start_nat_detection().await?;
        }
        if self.config.enable_upnp {
            if let Err(e) = self.upnp_manager.discover_gateway().await {
                warn!("UPnP gateway discovery failed: {}", e);
            }
        }
        if self.config.enable_nat_pmp {
            if let Err(e) = self.nat_pmp_client.discover_gateway().await {
                warn!("NAT-PMP gateway discovery failed: {}", e);
            }
        }
        self.start_periodic_tasks().await;
        Ok(())
    }

    /// Runs one STUN detection, storing the result and updating STUN counters.
    async fn start_nat_detection(&self) -> Result<(), NatTraversalError> {
        match self.stun_client.detect_nat().await {
            Ok(nat_info) => {
                info!("NAT detected: {:?}", nat_info.nat_type);
                *self.nat_info.write() = Some(nat_info);
                self.stats.stun_success.fetch_add(1, Ordering::Relaxed);
                Ok(())
            }
            Err(e) => {
                error!("NAT detection failed: {}", e);
                self.stats.stun_failures.fetch_add(1, Ordering::Relaxed);
                Err(e)
            }
        }
    }

    /// (Re)starts the periodic NAT re-detection task.
    ///
    /// Any previously started task is aborted first. The task is only started
    /// when `enable_stun` is set.
    async fn start_periodic_tasks(&self) {
        let mut handle_slot = self.detection_handle.lock().await;
        // Re-initialising must not leave an orphaned detection loop behind.
        if let Some(previous) = handle_slot.take() {
            previous.abort();
        }
        // Periodic re-detection is STUN traffic, so honour `enable_stun`.
        if !self.config.enable_stun {
            return;
        }
        let nat_info = Arc::clone(&self.nat_info);
        let stun_client = Arc::clone(&self.stun_client);
        let stats = Arc::clone(&self.stats);
        let detection_interval = self.config.detection_interval;
        *handle_slot = Some(tokio::spawn(async move {
            let mut ticker = interval(detection_interval);
            // The first tick completes immediately, and `initialize` has just
            // run a detection; wait a full period before the next one.
            ticker.tick().await;
            loop {
                ticker.tick().await;
                match stun_client.detect_nat().await {
                    Ok(new_info) => {
                        *nat_info.write() = Some(new_info);
                        stats.stun_success.fetch_add(1, Ordering::Relaxed);
                    }
                    Err(e) => {
                        warn!("Periodic NAT detection failed: {}", e);
                        stats.stun_failures.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }
        }));
    }

    /// Returns a copy of the latest NAT detection result, if any.
    pub fn get_nat_info(&self) -> Option<NatInfo> {
        self.nat_info.read().clone()
    }

    /// Creates a port mapping, trying UPnP first and then NAT-PMP (each if enabled).
    ///
    /// # Errors
    ///
    /// `UpnpError` if every enabled method failed.
    pub async fn create_port_mapping(
        &self,
        local_port: u16,
        external_port: u16,
        protocol: PortMappingProtocol,
    ) -> Result<PortMapping, NatTraversalError> {
        if self.config.enable_upnp {
            match self
                .upnp_manager
                .create_mapping(
                    local_port,
                    external_port,
                    protocol,
                    "QuDAG P2P",
                    self.config.port_mapping_lifetime,
                )
                .await
            {
                Ok(mapping) => {
                    let port_mapping = PortMapping {
                        local_port,
                        external_port: mapping.external_port,
                        protocol,
                        method: PortMappingMethod::Upnp,
                        created_at: Instant::now(),
                        expires_at: Instant::now() + mapping.lease_duration,
                    };
                    self.port_mappings.insert(local_port, port_mapping.clone());
                    self.stats
                        .port_mappings_created
                        .fetch_add(1, Ordering::Relaxed);
                    return Ok(port_mapping);
                }
                Err(e) => {
                    warn!("UPnP port mapping failed: {}", e);
                }
            }
        }
        if self.config.enable_nat_pmp {
            let is_tcp = matches!(protocol, PortMappingProtocol::TCP);
            match self
                .nat_pmp_client
                .create_mapping(
                    local_port,
                    external_port,
                    is_tcp,
                    self.config.port_mapping_lifetime,
                )
                .await
            {
                Ok(mapping) => {
                    let port_mapping = PortMapping {
                        local_port,
                        external_port: mapping.external_port,
                        protocol,
                        method: PortMappingMethod::NatPmp,
                        created_at: Instant::now(),
                        expires_at: Instant::now() + mapping.lifetime,
                    };
                    self.port_mappings.insert(local_port, port_mapping.clone());
                    self.stats
                        .port_mappings_created
                        .fetch_add(1, Ordering::Relaxed);
                    return Ok(port_mapping);
                }
                Err(e) => {
                    warn!("NAT-PMP port mapping failed: {}", e);
                }
            }
        }
        self.stats
            .port_mappings_failed
            .fetch_add(1, Ordering::Relaxed);
        Err(NatTraversalError::UpnpError(
            "All port mapping methods failed".to_string(),
        ))
    }

    /// Connects to `peer_id`.
    ///
    /// The order of attempts is:
    /// 1. a direct dial;
    /// 2. hole punching, if enabled;
    /// 3. relaying, if enabled.
    ///
    /// Once the direct dial fails, the call counts as a NAT traversal attempt.
    /// Its outcome is recorded in `total_attempts`, `successful_traversals`,
    /// `failed_traversals` and `avg_traversal_time_ms`.
    ///
    /// # Errors
    ///
    /// `ConnectionError` if every method failed.
    pub async fn connect_peer(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
        match self.connection_manager.connect(peer_id).await {
            Ok(()) => return Ok(()),
            Err(e) => {
                debug!("Direct connection failed: {}, trying NAT traversal", e);
            }
        }
        let started = Instant::now();
        self.stats.total_attempts.fetch_add(1, Ordering::Relaxed);
        let result = self.traverse(peer_id).await;
        self.record_traversal_outcome(started, result.is_ok());
        result
    }

    /// Runs the enabled traversal methods (hole punch, then relay) in order.
    async fn traverse(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
        if self.config.enable_hole_punching {
            match self.try_hole_punch(peer_id).await {
                Ok(()) => return Ok(()),
                Err(e) => {
                    debug!("Hole punching failed: {}", e);
                    self.stats
                        .hole_punch_failures
                        .fetch_add(1, Ordering::Relaxed);
                }
            }
        }
        if self.config.enable_relay {
            match self.establish_relay_connection(peer_id).await {
                Ok(()) => return Ok(()),
                Err(e) => {
                    error!("Relay connection failed: {}", e);
                }
            }
        }
        Err(NatTraversalError::ConnectionError(
            NetworkError::ConnectionError("All connection methods failed".to_string()),
        ))
    }

    /// Updates the traversal success/failure counters.
    ///
    /// On success it also updates the running mean of traversal time.
    fn record_traversal_outcome(&self, started: Instant, succeeded: bool) {
        if !succeeded {
            self.stats.failed_traversals.fetch_add(1, Ordering::Relaxed);
            return;
        }
        let completed = self
            .stats
            .successful_traversals
            .fetch_add(1, Ordering::Relaxed)
            .saturating_add(1);
        let sample_ms = started.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;
        let previous = self.stats.avg_traversal_time_ms.load(Ordering::Relaxed);
        // Incremental mean. Concurrent updates may interleave, which is
        // acceptable for a diagnostic metric.
        let updated = previous
            .saturating_mul(completed - 1)
            .saturating_add(sample_ms)
            / completed;
        self.stats
            .avg_traversal_time_ms
            .store(updated, Ordering::Relaxed);
    }

    /// Gathers candidates, exchanges them with the peer and hole-punches.
    ///
    /// On success the peer is marked connected.
    async fn try_hole_punch(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
        let local_candidates = self.gather_local_candidates().await?;
        let remote_candidates = self.exchange_candidates(peer_id, &local_candidates).await?;
        match self
            .hole_punch_coordinator
            .start_hole_punch(peer_id, local_candidates, remote_candidates)
            .await
        {
            Ok(addr) => {
                info!("Hole punch successful, connected via {}", addr);
                self.stats
                    .hole_punch_success
                    .fetch_add(1, Ordering::Relaxed);
                self.connection_manager
                    .update_status(peer_id, ConnectionStatus::Connected);
                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    /// Returns this host's public address candidates, without duplicates.
    ///
    /// The candidates are:
    /// * the STUN-mapped address;
    /// * the public IP combined with each port mapping's external port.
    async fn gather_local_candidates(&self) -> Result<Vec<SocketAddr>, NatTraversalError> {
        let (public_ip, public_port) = match self.nat_info.read().as_ref() {
            Some(info) => (info.public_ip, info.public_port),
            None => (None, None),
        };
        let mut candidates = Vec::new();
        if let (Some(ip), Some(port)) = (public_ip, public_port) {
            candidates.push(SocketAddr::new(ip, port));
        }
        // The public IP is looked up once rather than once per mapping.
        if let Some(ip) = public_ip {
            for mapping in self.port_mappings.iter() {
                let candidate = SocketAddr::new(ip, mapping.external_port);
                if !candidates.contains(&candidate) {
                    candidates.push(candidate);
                }
            }
        }
        Ok(candidates)
    }

    /// Exchanges address candidates with the peer.
    ///
    /// NOTE(review): placeholder — always returns no remote candidates, so hole
    /// punching cannot succeed until signalling is implemented.
    async fn exchange_candidates(
        &self,
        _peer_id: PeerId,
        _local_candidates: &[SocketAddr],
    ) -> Result<Vec<SocketAddr>, NatTraversalError> {
        Ok(Vec::new())
    }

    /// Establishes a relayed path to `peer_id`.
    ///
    /// A TURN allocation is tried first (if enabled), then a peer relay. A peer
    /// relay marks the peer connected and schedules an upgrade attempt.
    ///
    /// NOTE(review): the TURN branch returns `Ok` after allocating, without
    /// binding the allocation to the peer or updating its connection status.
    async fn establish_relay_connection(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
        if self.config.enable_turn {
            match self.turn_client.allocate_relay().await {
                Ok(allocation) => {
                    info!("TURN relay allocated: {}", allocation.relay_address);
                    return Ok(());
                }
                Err(e) => {
                    warn!("TURN allocation failed: {}", e);
                }
            }
        }
        match self.relay_manager.establish_relay(peer_id).await {
            Ok(connection) => {
                info!(
                    "Relay connection established via {:?}",
                    connection.relay_server
                );
                self.stats.relay_connections.fetch_add(1, Ordering::Relaxed);
                self.connection_manager
                    .update_status(peer_id, ConnectionStatus::Connected);
                self.schedule_connection_upgrade(peer_id, ConnectionType::Relay);
                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    /// After 30 s, tries once to upgrade the connection to `peer_id` to direct.
    ///
    /// When the upgrade succeeds, the now-redundant relay is closed.
    fn schedule_connection_upgrade(&self, peer_id: PeerId, current_type: ConnectionType) {
        let upgrade_manager = Arc::clone(&self.upgrade_manager);
        let relay_manager = Arc::clone(&self.relay_manager);
        let stats = Arc::clone(&self.stats);
        tokio::spawn(async move {
            sleep(Duration::from_secs(30)).await;
            match upgrade_manager.try_upgrade(peer_id, current_type).await {
                Ok(ConnectionType::Direct) => {
                    if current_type == ConnectionType::Relay {
                        // The direct path replaces the relay; free its slot and load.
                        relay_manager.close_relay(&peer_id).await;
                    }
                    stats.upgraded_connections.fetch_add(1, Ordering::Relaxed);
                    stats.relay_connections.fetch_sub(1, Ordering::Relaxed);
                }
                Ok(_) => {}
                Err(e) => {
                    debug!("Connection upgrade failed: {}", e);
                }
            }
        });
    }

    /// Attempts a direct connection to `peer_id` via hole punching.
    async fn establish_direct_connection(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
        self.try_hole_punch(peer_id).await
    }

    /// Public IP from the latest NAT detection, if known.
    fn get_public_ip(&self) -> Option<IpAddr> {
        self.nat_info.read().as_ref()?.public_ip
    }

    /// Returns a point-in-time snapshot of the statistics counters.
    pub fn get_stats(&self) -> NatTraversalStats {
        NatTraversalStats {
            total_attempts: AtomicU64::new(self.stats.total_attempts.load(Ordering::Relaxed)),
            successful_traversals: AtomicU64::new(
                self.stats.successful_traversals.load(Ordering::Relaxed),
            ),
            failed_traversals: AtomicU64::new(self.stats.failed_traversals.load(Ordering::Relaxed)),
            stun_success: AtomicU64::new(self.stats.stun_success.load(Ordering::Relaxed)),
            stun_failures: AtomicU64::new(self.stats.stun_failures.load(Ordering::Relaxed)),
            hole_punch_success: AtomicU64::new(
                self.stats.hole_punch_success.load(Ordering::Relaxed),
            ),
            hole_punch_failures: AtomicU64::new(
                self.stats.hole_punch_failures.load(Ordering::Relaxed),
            ),
            relay_connections: AtomicU32::new(self.stats.relay_connections.load(Ordering::Relaxed)),
            upgraded_connections: AtomicU64::new(
                self.stats.upgraded_connections.load(Ordering::Relaxed),
            ),
            port_mappings_created: AtomicU64::new(
                self.stats.port_mappings_created.load(Ordering::Relaxed),
            ),
            port_mappings_failed: AtomicU64::new(
                self.stats.port_mappings_failed.load(Ordering::Relaxed),
            ),
            avg_traversal_time_ms: AtomicU64::new(
                self.stats.avg_traversal_time_ms.load(Ordering::Relaxed),
            ),
        }
    }

    /// Stops the detection task and closes every relay connection.
    pub async fn shutdown(&self) -> Result<(), NatTraversalError> {
        info!("Shutting down NAT traversal manager");
        if let Some(handle) = self.detection_handle.lock().await.take() {
            handle.abort();
        }
        let relay_peers: Vec<_> = self
            .relay_manager
            .relay_connections
            .iter()
            .map(|entry| *entry.key())
            .collect();
        for peer_id in relay_peers {
            self.relay_manager.close_relay(&peer_id).await;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::ToSocketAddrs;

    /// Live NAT detection against a public STUN server.
    ///
    /// It needs DNS and outbound UDP, so it is ignored by default; run it with
    /// `cargo test -- --ignored`. It only reports what it finds.
    #[tokio::test]
    #[ignore]
    async fn test_nat_detection() {
        // `8.8.8.8` is a DNS resolver, not a STUN server; resolve a real one.
        let servers: Vec<StunServer> = "stun.l.google.com:19302"
            .to_socket_addrs()
            .map(|addrs| {
                addrs
                    .filter(SocketAddr::is_ipv4)
                    .map(|addr| StunServer::new(addr, 1))
                    .collect()
            })
            .unwrap_or_default();
        let client = StunClient::new(servers);
        match client.detect_nat().await {
            Ok(nat_info) => {
                println!("NAT type: {:?}", nat_info.nat_type);
                println!("Public IP: {:?}", nat_info.public_ip);
            }
            Err(e) => {
                println!("NAT detection failed: {}", e);
            }
        }
    }

    #[test]
    fn test_nat_type_properties() {
        assert_eq!(NatType::None, NatType::None);
        assert_ne!(NatType::FullCone, NatType::Symmetric);
    }
}