1use crate::connection::ConnectionManager;
13use crate::types::{ConnectionStatus, NetworkError, PeerId};
14use dashmap::DashMap;
15use libp2p::core::Multiaddr;
16use parking_lot::RwLock;
17use rand::{thread_rng, Rng};
18use serde::{Deserialize, Serialize};
19use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
20use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23use thiserror::Error;
24use tokio::net::UdpSocket;
25use tokio::sync::{mpsc, Mutex, Semaphore};
26use tokio::time::{interval, sleep, timeout};
27use tracing::{debug, error, info, warn};
28
/// Identifier correlating a STUN request with its response (12 bytes).
///
/// Used as the key of `StunClient::transactions`.
type TransactionId = [u8; 12];
31
#[derive(Debug, Clone)]
/// A STUN message: its type, the transaction it belongs to, and its attributes.
///
/// NOTE(review): not constructed by any code in this section of the file.
pub struct Message {
    /// Kind of STUN message (request / response / error).
    pub msg_type: MessageType,
    /// Transaction identifier shared by a request and its response.
    pub transaction_id: TransactionId,
    /// Attributes carried by the message, in order.
    pub attributes: Vec<Attribute>,
}
42
#[derive(Debug, Clone, Copy, PartialEq)]
/// The STUN/TURN message kinds this module models.
pub enum MessageType {
    /// Binding request sent to a STUN server.
    BindingRequest,
    /// Successful binding response.
    BindingResponse,
    /// Binding error response.
    BindingErrorResponse,
    /// TURN allocate request.
    AllocateRequest,
    /// TURN allocate response.
    AllocateResponse,
}
57
#[derive(Debug, Clone)]
/// A decoded STUN/TURN message attribute.
pub enum Attribute {
    /// Reflexive transport address, sent in the clear.
    MappedAddress(SocketAddr),
    /// Reflexive transport address, XOR-obfuscated on the wire.
    XorMappedAddress(SocketAddr),
    /// Alternate address reported by the server.
    ChangedAddress(SocketAddr),
    /// Username used for authentication.
    Username(String),
    /// Raw message-integrity bytes.
    MessageIntegrity(Vec<u8>),
    /// Error code attribute — presumably (numeric code, reason phrase); confirm with the encoder.
    ErrorCode(u16, String),
    /// Attribute type codes the peer did not understand.
    UnknownAttributes(Vec<u16>),
    /// Authentication realm.
    Realm(String),
    /// Authentication nonce bytes.
    Nonce(Vec<u8>),
}
80
#[derive(Debug, Error)]
/// Errors produced by the NAT traversal subsystem.
pub enum NatTraversalError {
    /// A STUN query or response could not be completed.
    #[error("STUN error: {0}")]
    StunError(String),

    /// A TURN allocation failed.
    #[error("TURN error: {0}")]
    TurnError(String),

    /// UPnP discovery or port mapping failed (also used when every mapping method fails).
    #[error("UPnP error: {0}")]
    UpnpError(String),

    /// NAT-PMP discovery or mapping failed.
    #[error("NAT-PMP error: {0}")]
    NatPmpError(String),

    /// UDP hole punching did not produce a connection.
    #[error("Hole punching failed: {0}")]
    HolePunchError(String),

    /// Relay selection or setup failed.
    #[error("Relay error: {0}")]
    RelayError(String),

    /// NAT type detection failed (e.g. no STUN server answered).
    #[error("NAT detection failed: {0}")]
    DetectionError(String),

    /// Upgrading a relayed connection to a direct one failed or was throttled.
    #[error("Connection upgrade failed: {0}")]
    UpgradeError(String),

    /// Network error converted automatically via `?` (`#[from]`).
    #[error("Network error: {0}")]
    NetworkError(#[from] NetworkError),

    /// Socket-level I/O error, converted automatically via `?`.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// A network operation exceeded its deadline.
    #[error("Operation timed out")]
    Timeout,

    /// Network error constructed explicitly (e.g. when every connection method fails);
    /// unlike `NetworkError`, this variant is never produced by `?`.
    #[error("Connection error: {0}")]
    ConnectionError(NetworkError),
}
134
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
/// Classic NAT behaviour classification.
///
/// NOTE: `StunClient::analyze_nat_type` only ever yields `None`, `RestrictedCone`
/// or `Symmetric`; the other variants are not produced by the detection code here.
pub enum NatType {
    /// No NAT: the public address equals the local address.
    None,
    /// Full-cone NAT.
    FullCone,
    /// Address-restricted cone NAT.
    RestrictedCone,
    /// Port-restricted cone NAT.
    PortRestrictedCone,
    /// Symmetric NAT: different mappings per destination.
    Symmetric,
    /// Behaviour has not been (or could not be) determined.
    Unknown,
}
151
#[derive(Debug, Clone)]
/// Result of a NAT detection run.
pub struct NatInfo {
    /// Classified NAT behaviour.
    pub nat_type: NatType,
    /// Public IP reported by the first responding STUN server, if any.
    pub public_ip: Option<IpAddr>,
    /// Public port reported by the first responding STUN server, if any.
    pub public_port: Option<u16>,
    /// IP the detection socket was bound to.
    pub local_ip: IpAddr,
    /// Port the detection socket was bound to.
    pub local_port: u16,
    /// Whether the NAT supports hairpinning (`StunClient` always reports `false`).
    pub hairpinning: bool,
    /// When this detection completed.
    pub detected_at: Instant,
    /// Confidence in `[0.0, 1.0]`, derived from how many STUN servers answered.
    pub confidence: f64,
}
172
#[derive(Debug, Clone)]
/// A STUN server entry together with its bookkeeping data.
pub struct StunServer {
    /// Server transport address.
    pub address: SocketAddr,
    /// Priority value supplied at construction.
    pub priority: u32,
    /// Whether the server should be queried (`StunClient` skips inactive entries).
    pub is_active: bool,
    /// Time of the last successful query, if any.
    pub last_success: Option<Instant>,
    /// Average response time in milliseconds (0 until measured).
    pub avg_response_ms: u64,
}

impl StunServer {
    /// Builds an active entry for `address` that has no recorded history yet.
    pub fn new(address: SocketAddr, priority: u32) -> Self {
        let (is_active, last_success, avg_response_ms) = (true, None, 0);
        Self {
            address,
            priority,
            is_active,
            last_success,
            avg_response_ms,
        }
    }
}
200
/// A TURN relay server and the credentials used to authenticate with it.
///
/// `Debug` is implemented by hand so the password is never written to logs.
/// The derived implementation would print it, for example whenever a
/// `NatTraversalConfig` is debug-formatted.
#[derive(Clone)]
pub struct TurnServer {
    /// Server transport address.
    pub address: SocketAddr,
    /// Authentication username.
    pub username: String,
    /// Authentication password (redacted in `Debug` output).
    pub password: String,
    /// Optional authentication realm.
    pub realm: Option<String>,
    /// Priority value supplied by configuration.
    pub priority: u32,
    /// Whether the server should be tried (`TurnClient` skips inactive entries).
    pub is_active: bool,
    /// Relay address obtained from this server, if known.
    pub relay_address: Option<SocketAddr>,
}

impl std::fmt::Debug for TurnServer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TurnServer")
            .field("address", &self.address)
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .field("realm", &self.realm)
            .field("priority", &self.priority)
            .field("is_active", &self.is_active)
            .field("relay_address", &self.relay_address)
            .finish()
    }
}
219
#[derive(Debug, Clone)]
/// Tunables and feature switches for `NatTraversalManager`.
pub struct NatTraversalConfig {
    /// Run STUN NAT detection during `initialize`.
    pub enable_stun: bool,
    /// Try a TURN allocation before falling back to the relay manager.
    pub enable_turn: bool,
    /// Discover a UPnP gateway and prefer it for port mappings.
    pub enable_upnp: bool,
    /// Discover a NAT-PMP gateway and use it for port mappings.
    pub enable_nat_pmp: bool,
    /// Attempt UDP hole punching when a direct dial fails.
    pub enable_hole_punching: bool,
    /// Fall back to relayed connections when other methods fail.
    pub enable_relay: bool,
    /// IPv6 switch. NOTE(review): not read by any code in this section.
    pub enable_ipv6: bool,
    /// STUN servers queried (in order) during detection.
    pub stun_servers: Vec<StunServer>,
    /// TURN servers tried (in order) for relay allocation.
    pub turn_servers: Vec<TurnServer>,
    /// Semaphore size for both `TurnClient` and `RelayManager`.
    pub max_relay_connections: usize,
    /// How long a hole-punch attempt may run before it is declared failed.
    pub hole_punch_timeout: Duration,
    /// Period of the background NAT re-detection task.
    pub detection_interval: Duration,
    /// Minimum spacing between connection-upgrade attempts for one peer.
    pub upgrade_interval: Duration,
    /// Requested lifetime for UPnP / NAT-PMP port mappings.
    pub port_mapping_lifetime: Duration,
}
252
253impl Default for NatTraversalConfig {
254 fn default() -> Self {
255 Self {
256 enable_stun: true,
257 enable_turn: true,
258 enable_upnp: true,
259 enable_nat_pmp: true,
260 enable_hole_punching: true,
261 enable_relay: true,
262 enable_ipv6: true,
263 stun_servers: vec![
264 StunServer::new("stun1.l.google.com:19302".parse().unwrap(), 1),
265 StunServer::new("stun2.l.google.com:19302".parse().unwrap(), 2),
266 StunServer::new("stun3.l.google.com:19302".parse().unwrap(), 3),
267 StunServer::new("stun4.l.google.com:19302".parse().unwrap(), 4),
268 ],
269 turn_servers: vec![],
270 max_relay_connections: 50,
271 hole_punch_timeout: Duration::from_secs(30),
272 detection_interval: Duration::from_secs(300), upgrade_interval: Duration::from_secs(60), port_mapping_lifetime: Duration::from_secs(3600), }
276 }
277}
278
/// Top-level coordinator that combines NAT detection, port mapping, hole
/// punching, relaying and connection upgrades.
pub struct NatTraversalManager {
    /// Configuration supplied at construction.
    config: NatTraversalConfig,
    /// Latest NAT detection result (`None` until detection succeeds).
    nat_info: Arc<RwLock<Option<NatInfo>>>,
    /// Used for direct dials and to report connection status changes.
    connection_manager: Arc<ConnectionManager>,
    /// STUN client used for NAT detection.
    stun_client: Arc<StunClient>,
    /// TURN client used for relay allocation.
    turn_client: Arc<TurnClient>,
    /// UPnP gateway discovery and port mapping.
    upnp_manager: Arc<UpnpManager>,
    /// NAT-PMP gateway discovery and port mapping.
    nat_pmp_client: Arc<NatPmpClient>,
    /// Runs UDP hole-punch attempts.
    hole_punch_coordinator: Arc<HolePunchCoordinator>,
    /// Selects relay servers and tracks relayed connections.
    relay_manager: Arc<RelayManager>,
    /// Throttles and performs relay-to-direct upgrades.
    upgrade_manager: Arc<ConnectionUpgradeManager>,
    /// Successful port mappings keyed by local port.
    port_mappings: Arc<DashMap<u16, PortMapping>>,
    /// Handle of the background periodic NAT detection task.
    detection_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
    /// Shared traversal counters.
    stats: Arc<NatTraversalStats>,
}
308
#[derive(Debug, Clone)]
/// A port mapping recorded by `NatTraversalManager`.
pub struct PortMapping {
    /// Local port being forwarded.
    pub local_port: u16,
    /// External port on the gateway.
    pub external_port: u16,
    /// Transport protocol of the mapping.
    pub protocol: PortMappingProtocol,
    /// Mechanism that created the mapping.
    pub method: PortMappingMethod,
    /// When the mapping was recorded.
    pub created_at: Instant,
    /// When the mapping's lease is expected to lapse.
    pub expires_at: Instant,
}
325
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Transport protocol of a port mapping.
pub enum PortMappingProtocol {
    /// TCP mapping.
    TCP,
    /// UDP mapping.
    UDP,
}
334
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Mechanism used to create a port mapping.
pub enum PortMappingMethod {
    /// Created through UPnP.
    Upnp,
    /// Created through NAT-PMP.
    NatPmp,
    /// Configured by hand (not produced by the code in this section).
    Manual,
}
345
#[derive(Debug)]
/// Lock-free counters describing NAT traversal activity.
pub struct NatTraversalStats {
    /// Traversal attempts started.
    pub total_attempts: AtomicU64,
    /// Traversals that succeeded.
    pub successful_traversals: AtomicU64,
    /// Traversals that failed.
    pub failed_traversals: AtomicU64,
    /// Successful STUN detections.
    pub stun_success: AtomicU64,
    /// Failed STUN detections.
    pub stun_failures: AtomicU64,
    /// Successful hole punches.
    pub hole_punch_success: AtomicU64,
    /// Failed hole punches.
    pub hole_punch_failures: AtomicU64,
    /// Relayed connections established.
    pub relay_connections: AtomicU32,
    /// Connections upgraded from relayed to direct.
    pub upgraded_connections: AtomicU64,
    /// Port mappings created.
    pub port_mappings_created: AtomicU64,
    /// Port mapping requests where every method failed.
    pub port_mappings_failed: AtomicU64,
    /// Average traversal time in milliseconds.
    pub avg_traversal_time_ms: AtomicU64,
}

impl Default for NatTraversalStats {
    /// Returns a stats block with every counter at zero.
    fn default() -> Self {
        Self {
            total_attempts: 0.into(),
            successful_traversals: 0.into(),
            failed_traversals: 0.into(),
            stun_success: 0.into(),
            stun_failures: 0.into(),
            hole_punch_success: 0.into(),
            hole_punch_failures: 0.into(),
            relay_connections: 0.into(),
            upgraded_connections: 0.into(),
            port_mappings_created: 0.into(),
            port_mappings_failed: 0.into(),
            avg_traversal_time_ms: 0.into(),
        }
    }
}
393
/// STUN client that queries a list of servers to learn the public mapping.
pub struct StunClient {
    /// Servers to query, in order.
    servers: Arc<RwLock<Vec<StunServer>>>,
    /// Socket created by `detect_nat` and reused by each server query.
    socket: Arc<Mutex<Option<UdpSocket>>>,
    // Reserved for request/response correlation; not used by the code in this section.
    #[allow(dead_code)]
    transactions: Arc<DashMap<TransactionId, StunTransaction>>,
}
404
#[derive(Debug)]
// Not constructed by the code in this section; kept for future use.
#[allow(dead_code)]
/// An in-flight STUN request awaiting its response.
struct StunTransaction {
    /// Server the request was sent to.
    server: SocketAddr,
    /// When the request was sent.
    sent_at: Instant,
    /// Channel on which the decoded response (or error) would be delivered.
    callback: Arc<Mutex<Option<mpsc::Sender<Result<Message, NatTraversalError>>>>>,
}
416
417impl StunClient {
418 pub fn new(servers: Vec<StunServer>) -> Self {
420 Self {
421 servers: Arc::new(RwLock::new(servers)),
422 socket: Arc::new(Mutex::new(None)),
423 transactions: Arc::new(DashMap::new()),
424 }
425 }
426
427 pub async fn detect_nat(&self) -> Result<NatInfo, NatTraversalError> {
429 let local_addr = if false {
431 SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
433 } else {
434 SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
435 };
436
437 let socket = UdpSocket::bind(local_addr).await?;
438 let local_addr = socket.local_addr()?;
439
440 *self.socket.lock().await = Some(socket);
442
443 let mut results = Vec::new();
445 let servers = self.servers.read().clone();
446
447 for server in servers.iter().filter(|s| s.is_active) {
448 match self.query_stun_server(&server.address).await {
449 Ok(mapped_addr) => {
450 results.push((server.clone(), mapped_addr));
451 if results.len() >= 3 {
452 break; }
454 }
455 Err(e) => {
456 warn!("STUN query to {} failed: {}", server.address, e);
457 }
458 }
459 }
460
461 if results.is_empty() {
462 return Err(NatTraversalError::DetectionError(
463 "No STUN servers responded".to_string(),
464 ));
465 }
466
467 let nat_type = self.analyze_nat_type(&results, local_addr).await?;
469 let (public_ip, public_port) = if let Some((_, addr)) = results.first() {
470 (Some(addr.ip()), Some(addr.port()))
471 } else {
472 (None, None)
473 };
474
475 Ok(NatInfo {
476 nat_type,
477 public_ip,
478 public_port,
479 local_ip: local_addr.ip(),
480 local_port: local_addr.port(),
481 hairpinning: false, detected_at: Instant::now(),
483 confidence: self.calculate_confidence(&results),
484 })
485 }
486
487 async fn query_stun_server(
489 &self,
490 server: &SocketAddr,
491 ) -> Result<SocketAddr, NatTraversalError> {
492 let socket_guard = self.socket.lock().await;
494 let socket = socket_guard
495 .as_ref()
496 .ok_or_else(|| NatTraversalError::StunError("Socket not initialized".to_string()))?;
497
498 let request_data = b"STUN_REQUEST";
500
501 socket
503 .send_to(request_data, server)
504 .await
505 .map_err(|e| NatTraversalError::StunError(e.to_string()))?;
506
507 let mut response_buf = vec![0u8; 1024];
509 let (_len, from) = timeout(Duration::from_secs(5), socket.recv_from(&mut response_buf))
510 .await
511 .map_err(|_| NatTraversalError::Timeout)??;
512
513 if from != *server {
514 return Err(NatTraversalError::StunError(
515 "Response from wrong server".to_string(),
516 ));
517 }
518
519 let local_addr = socket.local_addr()?;
522
523 Ok(SocketAddr::new(server.ip(), local_addr.port()))
525 }
526
527 async fn analyze_nat_type(
529 &self,
530 results: &[(StunServer, SocketAddr)],
531 local_addr: SocketAddr,
532 ) -> Result<NatType, NatTraversalError> {
533 if let Some((_, public_addr)) = results.first() {
535 if public_addr.ip() == local_addr.ip() {
536 return Ok(NatType::None);
537 }
538 }
539
540 let all_same = results.windows(2).all(|w| w[0].1 == w[1].1);
542
543 if all_same {
544 Ok(NatType::RestrictedCone)
547 } else {
548 Ok(NatType::Symmetric)
550 }
551 }
552
553 fn calculate_confidence(&self, results: &[(StunServer, SocketAddr)]) -> f64 {
555 let base_confidence = results.len() as f64 / 3.0; base_confidence.min(1.0)
557 }
558}
559
/// TURN client that requests relay allocations from configured servers.
pub struct TurnClient {
    /// Servers to try, in order.
    servers: Arc<RwLock<Vec<TurnServer>>>,
    /// Most recent allocation per server address.
    allocations: Arc<DashMap<SocketAddr, TurnAllocation>>,
    /// Semaphore sized by `max_allocations`. A permit is held only for the duration
    /// of `allocate_relay`, so it bounds concurrent allocation attempts.
    allocation_limit: Arc<Semaphore>,
}
569
#[derive(Debug, Clone)]
/// A relay allocation obtained from a TURN server.
pub struct TurnAllocation {
    /// Server that granted the allocation.
    pub server: SocketAddr,
    /// Relayed transport address (currently set to the server address by `allocate_from_server`).
    pub relay_address: SocketAddr,
    /// Allocation lifetime.
    pub lifetime: Duration,
    /// When the allocation was created.
    pub created_at: Instant,
    /// Handle of a refresh task, if one is running (none is started in this section).
    pub refresh_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}
584
585impl TurnClient {
586 pub fn new(servers: Vec<TurnServer>, max_allocations: usize) -> Self {
588 Self {
589 servers: Arc::new(RwLock::new(servers)),
590 allocations: Arc::new(DashMap::new()),
591 allocation_limit: Arc::new(Semaphore::new(max_allocations)),
592 }
593 }
594
595 pub async fn allocate_relay(&self) -> Result<TurnAllocation, NatTraversalError> {
597 let _permit =
599 self.allocation_limit.acquire().await.map_err(|_| {
600 NatTraversalError::TurnError("Allocation limit reached".to_string())
601 })?;
602
603 let servers = self.servers.read().clone();
605 for server in servers.iter().filter(|s| s.is_active) {
606 match self.allocate_from_server(server).await {
607 Ok(allocation) => {
608 self.allocations.insert(server.address, allocation.clone());
609 return Ok(allocation);
610 }
611 Err(e) => {
612 warn!("TURN allocation from {} failed: {}", server.address, e);
613 }
614 }
615 }
616
617 Err(NatTraversalError::TurnError(
618 "No TURN servers available".to_string(),
619 ))
620 }
621
622 async fn allocate_from_server(
624 &self,
625 server: &TurnServer,
626 ) -> Result<TurnAllocation, NatTraversalError> {
627 Ok(TurnAllocation {
630 server: server.address,
631 relay_address: server.address, lifetime: Duration::from_secs(600),
633 created_at: Instant::now(),
634 refresh_handle: Arc::new(Mutex::new(None)),
635 })
636 }
637}
638
#[derive(Debug, Clone)]
/// A discovered UPnP gateway.
pub struct SimpleGateway {
    /// Address recorded for the gateway during discovery.
    pub address: SocketAddr,
    /// Display name (discovery stores a fixed "UPnP Gateway").
    pub name: String,
}
647
/// UPnP gateway discovery and port-mapping bookkeeping.
pub struct UpnpManager {
    /// Gateway found by `discover_gateway`, if any.
    gateway: Arc<Mutex<Option<SimpleGateway>>>,
    /// Mappings recorded by `create_mapping`, keyed by local port.
    mappings: Arc<DashMap<u16, UpnpMapping>>,
    // Stored for a future refresh loop; not read by the code in this section.
    #[allow(dead_code)]
    refresh_interval: Duration,
}
658
#[derive(Debug, Clone)]
/// A port mapping recorded through UPnP.
pub struct UpnpMapping {
    /// Local port being forwarded.
    pub local_port: u16,
    /// External port requested on the gateway.
    pub external_port: u16,
    /// Transport protocol.
    pub protocol: PortMappingProtocol,
    /// Human-readable description attached to the mapping.
    pub description: String,
    /// Requested lease duration.
    pub lease_duration: Duration,
    /// When the mapping was recorded.
    pub created_at: Instant,
}
675
676impl UpnpManager {
677 pub fn new(refresh_interval: Duration) -> Self {
679 Self {
680 gateway: Arc::new(Mutex::new(None)),
681 mappings: Arc::new(DashMap::new()),
682 refresh_interval,
683 }
684 }
685
686 pub async fn discover_gateway(&self) -> Result<(), NatTraversalError> {
688 let potential_gateways = vec!["192.168.1.1:1900", "192.168.0.1:1900", "10.0.0.1:1900"];
691
692 for gateway_addr in potential_gateways {
693 if let Ok(addr) = gateway_addr.parse::<SocketAddr>() {
694 if let Ok(socket) = UdpSocket::bind("0.0.0.0:0").await {
696 if socket.send_to(b"M-SEARCH", addr).await.is_ok() {
697 info!("Discovered UPnP gateway at: {}", addr);
698 let gateway = SimpleGateway {
699 address: addr,
700 name: "UPnP Gateway".to_string(),
701 };
702 *self.gateway.lock().await = Some(gateway);
703 return Ok(());
704 }
705 }
706 }
707 }
708
709 Err(NatTraversalError::UpnpError(
710 "No UPnP gateway found".to_string(),
711 ))
712 }
713
714 pub async fn create_mapping(
716 &self,
717 local_port: u16,
718 external_port: u16,
719 protocol: PortMappingProtocol,
720 description: &str,
721 lease_duration: Duration,
722 ) -> Result<UpnpMapping, NatTraversalError> {
723 info!(
726 "Creating UPnP port mapping: {}:{} -> {} ({})",
727 local_port, external_port, protocol as u8, description
728 );
729
730 let mapping = UpnpMapping {
731 local_port,
732 external_port,
733 protocol,
734 description: description.to_string(),
735 lease_duration,
736 created_at: Instant::now(),
737 };
738
739 self.mappings.insert(local_port, mapping.clone());
740 Ok(mapping)
741 }
742
743 #[allow(dead_code)]
745 async fn get_local_ip(&self) -> Result<IpAddr, NatTraversalError> {
746 let socket = UdpSocket::bind("0.0.0.0:0").await?;
748 socket.connect("8.8.8.8:80").await?;
749 let local_addr = socket.local_addr()?;
750 Ok(local_addr.ip())
751 }
752}
753
/// NAT-PMP gateway discovery and port-mapping bookkeeping.
pub struct NatPmpClient {
    /// Gateway found by `discover_gateway`, if any.
    gateway: Arc<Mutex<Option<IpAddr>>>,
    /// Mappings recorded by `create_mapping`, keyed by local port.
    mappings: Arc<DashMap<u16, NatPmpMapping>>,
}
761
#[derive(Debug, Clone)]
/// A port mapping recorded through NAT-PMP.
pub struct NatPmpMapping {
    /// Local port being forwarded.
    pub local_port: u16,
    /// External port on the gateway.
    pub external_port: u16,
    /// `true` for TCP, `false` for UDP.
    pub is_tcp: bool,
    /// Mapping lifetime.
    pub lifetime: Duration,
    /// When the mapping was recorded.
    pub created_at: Instant,
}
776
777impl NatPmpClient {
778 pub fn new() -> Self {
780 Self {
781 gateway: Arc::new(Mutex::new(None)),
782 mappings: Arc::new(DashMap::new()),
783 }
784 }
785
786 pub async fn discover_gateway(&self) -> Result<(), NatTraversalError> {
788 let common_gateways = vec!["192.168.1.1", "192.168.0.1", "10.0.0.1"];
791
792 for gateway_str in common_gateways {
793 if let Ok(gateway) = gateway_str.parse::<IpAddr>() {
794 if self.test_gateway(&gateway).await {
796 *self.gateway.lock().await = Some(gateway);
797 info!("Discovered NAT-PMP gateway: {}", gateway);
798 return Ok(());
799 }
800 }
801 }
802
803 Err(NatTraversalError::NatPmpError(
804 "No NAT-PMP gateway found".to_string(),
805 ))
806 }
807
808 async fn test_gateway(&self, _gateway: &IpAddr) -> bool {
810 false
813 }
814
815 pub async fn create_mapping(
817 &self,
818 local_port: u16,
819 external_port: u16,
820 is_tcp: bool,
821 lifetime: Duration,
822 ) -> Result<NatPmpMapping, NatTraversalError> {
823 let gateway = self.gateway.lock().await;
824 let _gateway_addr = gateway
825 .as_ref()
826 .ok_or_else(|| NatTraversalError::NatPmpError("No gateway discovered".to_string()))?;
827
828 let mapping = NatPmpMapping {
831 local_port,
832 external_port,
833 is_tcp,
834 lifetime,
835 created_at: Instant::now(),
836 };
837
838 self.mappings.insert(local_port, mapping.clone());
839 Ok(mapping)
840 }
841}
842
/// Coordinates UDP hole-punch attempts per peer.
pub struct HolePunchCoordinator {
    /// Latest attempt record per peer.
    attempts: Arc<DashMap<PeerId, HolePunchAttempt>>,
    /// Channels on which a successful local candidate is reported for a peer.
    success_handlers: Arc<DashMap<PeerId, mpsc::Sender<SocketAddr>>>,
    /// How long `start_hole_punch` waits for a success before failing.
    timeout: Duration,
}
852
#[derive(Debug)]
/// State of one hole-punch attempt toward a peer.
pub struct HolePunchAttempt {
    /// Target peer.
    pub peer_id: PeerId,
    /// Our candidate addresses.
    pub local_candidates: Vec<SocketAddr>,
    /// The peer's candidate addresses.
    pub remote_candidates: Vec<SocketAddr>,
    /// When the attempt started.
    pub started_at: Instant,
    /// Current phase (`Connected` / `Failed` once finished).
    pub phase: HolePunchPhase,
    /// Set to `true` when the attempt succeeds.
    pub succeeded: Arc<AtomicBool>,
}
869
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Progress of a hole-punch attempt.
pub enum HolePunchPhase {
    /// Collecting our own candidates.
    GatheringCandidates,
    /// Swapping candidates with the peer.
    ExchangingCandidates,
    /// Sending probes over candidate pairs (the phase new attempts start in).
    Probing,
    /// A candidate pair succeeded.
    Connected,
    /// The attempt timed out or failed.
    Failed,
}
884
885impl HolePunchCoordinator {
886 pub fn new(timeout: Duration) -> Self {
888 Self {
889 attempts: Arc::new(DashMap::new()),
890 success_handlers: Arc::new(DashMap::new()),
891 timeout,
892 }
893 }
894
895 pub async fn start_hole_punch(
897 &self,
898 peer_id: PeerId,
899 local_candidates: Vec<SocketAddr>,
900 remote_candidates: Vec<SocketAddr>,
901 ) -> Result<SocketAddr, NatTraversalError> {
902 info!("Starting hole punch to peer {:?}", peer_id);
903
904 let attempt = HolePunchAttempt {
905 peer_id,
906 local_candidates: local_candidates.clone(),
907 remote_candidates: remote_candidates.clone(),
908 started_at: Instant::now(),
909 phase: HolePunchPhase::Probing,
910 succeeded: Arc::new(AtomicBool::new(false)),
911 };
912
913 self.attempts.insert(peer_id, attempt);
914
915 let (tx, mut rx) = mpsc::channel(1);
917 self.success_handlers.insert(peer_id, tx);
918
919 let _probe_tasks: Vec<_> = local_candidates
921 .iter()
922 .flat_map(|local| {
923 remote_candidates
924 .iter()
925 .map(move |remote| self.probe_candidate_pair(*local, *remote, peer_id))
926 })
927 .collect();
928
929 tokio::select! {
931 result = rx.recv() => {
932 match result {
933 Some(addr) => {
934 self.mark_success(peer_id);
935 Ok(addr)
936 }
937 None => Err(NatTraversalError::HolePunchError("Channel closed".to_string()))
938 }
939 }
940 _ = sleep(self.timeout) => {
941 self.mark_failure(peer_id);
942 Err(NatTraversalError::HolePunchError("Timeout".to_string()))
943 }
944 }
945 }
946
947 async fn probe_candidate_pair(
949 &self,
950 local: SocketAddr,
951 remote: SocketAddr,
952 peer_id: PeerId,
953 ) -> Result<(), NatTraversalError> {
954 debug!("Probing candidate pair: {} -> {}", local, remote);
955
956 let socket = UdpSocket::bind(local).await?;
957
958 for i in 0..5 {
960 let probe_data = format!("HOLE_PUNCH_PROBE_{}", i).into_bytes();
961 socket.send_to(&probe_data, remote).await?;
962
963 let mut buf = vec![0u8; 1024];
965 match timeout(Duration::from_millis(200), socket.recv_from(&mut buf)).await {
966 Ok(Ok((len, from))) => {
967 if from == remote && len > 0 {
968 if let Some(handler) = self.success_handlers.get(&peer_id) {
970 let _ = handler.send(local).await;
971 }
972 return Ok(());
973 }
974 }
975 _ => continue, }
977
978 sleep(Duration::from_millis(100)).await;
979 }
980
981 Err(NatTraversalError::HolePunchError(
982 "No response from remote".to_string(),
983 ))
984 }
985
986 fn mark_success(&self, peer_id: PeerId) {
988 if let Some(mut attempt) = self.attempts.get_mut(&peer_id) {
989 attempt.phase = HolePunchPhase::Connected;
990 attempt.succeeded.store(true, Ordering::Relaxed);
991 }
992 }
993
994 fn mark_failure(&self, peer_id: PeerId) {
996 if let Some(mut attempt) = self.attempts.get_mut(&peer_id) {
997 attempt.phase = HolePunchPhase::Failed;
998 }
999 }
1000}
1001
/// Selects relay servers and tracks relayed connections per target peer.
pub struct RelayManager {
    /// Known relay servers.
    relay_servers: Arc<RwLock<Vec<RelayServer>>>,
    /// Current relayed connection per target peer.
    relay_connections: Arc<DashMap<PeerId, RelayConnection>>,
    /// Semaphore sized by `max_connections`, acquired in `establish_relay`.
    connection_limit: Arc<Semaphore>,
    /// Relay counters.
    stats: Arc<RelayStats>,
}
1013
#[derive(Debug, Clone)]
/// A relay server and its live load.
pub struct RelayServer {
    /// Peer id of the relay.
    pub id: PeerId,
    /// Multiaddress of the relay.
    pub address: Multiaddr,
    /// Advertised capacity (not consulted by the selection code in this section).
    pub capacity: u32,
    /// Active connections through this relay. Shared via `Arc`, so clones see the same counter.
    pub load: Arc<AtomicU32>,
    /// Whether the relay may be selected.
    pub is_available: bool,
    /// Time of the last health check, if any.
    pub last_health_check: Option<Instant>,
}
1030
#[derive(Debug, Clone)]
/// A relayed connection to a peer.
pub struct RelayConnection {
    /// Relay carrying the traffic.
    pub relay_server: PeerId,
    /// Peer reached through the relay.
    pub target_peer: PeerId,
    /// Random identifier for this connection.
    pub connection_id: u64,
    /// When the connection was established.
    pub established_at: Instant,
    /// Bytes relayed so far (shared across clones).
    pub bytes_relayed: Arc<AtomicU64>,
    /// Cleared when the connection is closed (shared across clones).
    pub is_active: Arc<AtomicBool>,
}
1047
#[derive(Debug)]
/// Counters maintained by `RelayManager`.
pub struct RelayStats {
    /// Relayed connections ever established.
    pub total_connections: AtomicU64,
    /// Relayed connections currently open.
    pub active_connections: AtomicU32,
    /// Bytes relayed (not written by the code in this section).
    pub bytes_relayed: AtomicU64,
    /// Failed relay attempts (not written by the code in this section).
    pub failed_attempts: AtomicU64,
}
1060
1061impl RelayManager {
1062 pub fn new(max_connections: usize) -> Self {
1064 Self {
1065 relay_servers: Arc::new(RwLock::new(Vec::new())),
1066 relay_connections: Arc::new(DashMap::new()),
1067 connection_limit: Arc::new(Semaphore::new(max_connections)),
1068 stats: Arc::new(RelayStats {
1069 total_connections: AtomicU64::new(0),
1070 active_connections: AtomicU32::new(0),
1071 bytes_relayed: AtomicU64::new(0),
1072 failed_attempts: AtomicU64::new(0),
1073 }),
1074 }
1075 }
1076
1077 pub async fn add_relay_server(&self, server: RelayServer) {
1079 self.relay_servers.write().push(server);
1080 }
1081
1082 pub async fn establish_relay(
1084 &self,
1085 target_peer: PeerId,
1086 ) -> Result<RelayConnection, NatTraversalError> {
1087 let _permit =
1089 self.connection_limit.acquire().await.map_err(|_| {
1090 NatTraversalError::RelayError("Connection limit reached".to_string())
1091 })?;
1092
1093 let relay_server = self.select_relay_server().await?;
1095
1096 let connection = RelayConnection {
1098 relay_server: relay_server.id,
1099 target_peer,
1100 connection_id: thread_rng().gen(),
1101 established_at: Instant::now(),
1102 bytes_relayed: Arc::new(AtomicU64::new(0)),
1103 is_active: Arc::new(AtomicBool::new(true)),
1104 };
1105
1106 self.stats.total_connections.fetch_add(1, Ordering::Relaxed);
1108 self.stats
1109 .active_connections
1110 .fetch_add(1, Ordering::Relaxed);
1111 relay_server.load.fetch_add(1, Ordering::Relaxed);
1112
1113 self.relay_connections
1114 .insert(target_peer, connection.clone());
1115
1116 info!(
1117 "Established relay connection to {:?} via {:?}",
1118 target_peer, relay_server.id
1119 );
1120 Ok(connection)
1121 }
1122
1123 async fn select_relay_server(&self) -> Result<RelayServer, NatTraversalError> {
1125 let servers = self.relay_servers.read();
1126
1127 servers
1128 .iter()
1129 .filter(|s| s.is_available)
1130 .min_by_key(|s| s.load.load(Ordering::Relaxed))
1131 .cloned()
1132 .ok_or_else(|| NatTraversalError::RelayError("No relay servers available".to_string()))
1133 }
1134
1135 pub async fn close_relay(&self, peer_id: &PeerId) {
1137 if let Some((_, connection)) = self.relay_connections.remove(peer_id) {
1138 connection.is_active.store(false, Ordering::Relaxed);
1139 self.stats
1140 .active_connections
1141 .fetch_sub(1, Ordering::Relaxed);
1142
1143 let servers = self.relay_servers.read();
1145 if let Some(server) = servers.iter().find(|s| s.id == connection.relay_server) {
1146 server.load.fetch_sub(1, Ordering::Relaxed);
1147 }
1148 }
1149 }
1150}
1151
/// Periodically tries to replace relayed connections with direct ones.
pub struct ConnectionUpgradeManager {
    /// Upgrade bookkeeping per peer.
    upgrade_attempts: Arc<DashMap<PeerId, UpgradeAttempt>>,
    /// Minimum spacing between attempts for one peer.
    upgrade_interval: Duration,
    /// Manager used to open direct connections; `None` until `set_nat_manager`.
    /// NOTE(review): holding an `Arc<NatTraversalManager>` here while the manager holds
    /// this struct in an `Arc` would form a reference cycle — confirm ownership.
    nat_manager: Option<Arc<NatTraversalManager>>,
}
1161
#[derive(Debug)]
/// Upgrade history for one peer.
pub struct UpgradeAttempt {
    /// Peer being upgraded.
    pub peer_id: PeerId,
    /// Connection type recorded when the entry was created.
    pub current_type: ConnectionType,
    /// Number of upgrade attempts made.
    pub attempt_count: u32,
    /// When the entry was created or last attempted.
    pub last_attempt: Instant,
    /// Whether an upgrade has succeeded.
    pub succeeded: bool,
}
1176
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// How traffic to a peer is carried.
pub enum ConnectionType {
    /// Direct peer-to-peer connection.
    Direct,
    /// Through a relay server managed by `RelayManager`.
    Relay,
    /// Through a TURN allocation.
    Turn,
}
1187
1188impl ConnectionUpgradeManager {
1189 pub fn new(upgrade_interval: Duration) -> Self {
1191 Self {
1192 upgrade_attempts: Arc::new(DashMap::new()),
1193 upgrade_interval,
1194 nat_manager: None,
1195 }
1196 }
1197
1198 pub fn set_nat_manager(&mut self, nat_manager: Arc<NatTraversalManager>) {
1200 self.nat_manager = Some(nat_manager);
1201 }
1202
1203 pub async fn try_upgrade(
1205 &self,
1206 peer_id: PeerId,
1207 current_type: ConnectionType,
1208 ) -> Result<ConnectionType, NatTraversalError> {
1209 if current_type == ConnectionType::Direct {
1210 return Ok(ConnectionType::Direct); }
1212
1213 let mut attempt = self
1214 .upgrade_attempts
1215 .entry(peer_id)
1216 .or_insert(UpgradeAttempt {
1217 peer_id,
1218 current_type,
1219 attempt_count: 0,
1220 last_attempt: Instant::now(),
1221 succeeded: false,
1222 });
1223
1224 if attempt.last_attempt.elapsed() < self.upgrade_interval {
1226 return Err(NatTraversalError::UpgradeError(
1227 "Too soon to retry".to_string(),
1228 ));
1229 }
1230
1231 attempt.attempt_count += 1;
1232 attempt.last_attempt = Instant::now();
1233
1234 if let Some(nat_manager) = &self.nat_manager {
1236 match nat_manager.establish_direct_connection(peer_id).await {
1237 Ok(_) => {
1238 attempt.succeeded = true;
1239 info!(
1240 "Successfully upgraded connection to {:?} from {:?} to Direct",
1241 peer_id, current_type
1242 );
1243 Ok(ConnectionType::Direct)
1244 }
1245 Err(e) => {
1246 warn!("Failed to upgrade connection to {:?}: {}", peer_id, e);
1247 Err(e)
1248 }
1249 }
1250 } else {
1251 Err(NatTraversalError::UpgradeError(
1252 "NAT manager not available".to_string(),
1253 ))
1254 }
1255 }
1256}
1257
1258impl NatTraversalManager {
1259 pub fn new(config: NatTraversalConfig, connection_manager: Arc<ConnectionManager>) -> Self {
1261 let stats = Arc::new(NatTraversalStats::default());
1262
1263 Self {
1264 config: config.clone(),
1265 nat_info: Arc::new(RwLock::new(None)),
1266 connection_manager,
1267 stun_client: Arc::new(StunClient::new(config.stun_servers.clone())),
1268 turn_client: Arc::new(TurnClient::new(
1269 config.turn_servers.clone(),
1270 config.max_relay_connections,
1271 )),
1272 upnp_manager: Arc::new(UpnpManager::new(config.port_mapping_lifetime)),
1273 nat_pmp_client: Arc::new(NatPmpClient::new()),
1274 hole_punch_coordinator: Arc::new(HolePunchCoordinator::new(config.hole_punch_timeout)),
1275 relay_manager: Arc::new(RelayManager::new(config.max_relay_connections)),
1276 upgrade_manager: Arc::new(ConnectionUpgradeManager::new(config.upgrade_interval)),
1277 port_mappings: Arc::new(DashMap::new()),
1278 detection_handle: Arc::new(Mutex::new(None)),
1279 stats,
1280 }
1281 }
1282
1283 pub async fn initialize(&self) -> Result<(), NatTraversalError> {
1285 info!("Initializing NAT traversal manager");
1286
1287 if self.config.enable_stun {
1289 self.start_nat_detection().await?;
1290 }
1291
1292 if self.config.enable_upnp {
1294 if let Err(e) = self.upnp_manager.discover_gateway().await {
1295 warn!("UPnP gateway discovery failed: {}", e);
1296 }
1297 }
1298
1299 if self.config.enable_nat_pmp {
1300 if let Err(e) = self.nat_pmp_client.discover_gateway().await {
1301 warn!("NAT-PMP gateway discovery failed: {}", e);
1302 }
1303 }
1304
1305 self.start_periodic_tasks().await;
1307
1308 Ok(())
1309 }
1310
1311 async fn start_nat_detection(&self) -> Result<(), NatTraversalError> {
1313 match self.stun_client.detect_nat().await {
1314 Ok(nat_info) => {
1315 info!("NAT detected: {:?}", nat_info.nat_type);
1316 *self.nat_info.write() = Some(nat_info);
1317 self.stats.stun_success.fetch_add(1, Ordering::Relaxed);
1318 Ok(())
1319 }
1320 Err(e) => {
1321 error!("NAT detection failed: {}", e);
1322 self.stats.stun_failures.fetch_add(1, Ordering::Relaxed);
1323 Err(e)
1324 }
1325 }
1326 }
1327
1328 async fn start_periodic_tasks(&self) {
1330 let nat_info = Arc::clone(&self.nat_info);
1331 let stun_client = Arc::clone(&self.stun_client);
1332 let stats = Arc::clone(&self.stats);
1333 let detection_interval = self.config.detection_interval;
1334
1335 let detection_task = tokio::spawn(async move {
1337 let mut interval = interval(detection_interval);
1338 loop {
1339 interval.tick().await;
1340
1341 match stun_client.detect_nat().await {
1342 Ok(new_info) => {
1343 *nat_info.write() = Some(new_info);
1344 stats.stun_success.fetch_add(1, Ordering::Relaxed);
1345 }
1346 Err(e) => {
1347 warn!("Periodic NAT detection failed: {}", e);
1348 stats.stun_failures.fetch_add(1, Ordering::Relaxed);
1349 }
1350 }
1351 }
1352 });
1353
1354 *self.detection_handle.lock().await = Some(detection_task);
1355 }
1356
1357 pub fn get_nat_info(&self) -> Option<NatInfo> {
1359 self.nat_info.read().clone()
1360 }
1361
1362 pub async fn create_port_mapping(
1364 &self,
1365 local_port: u16,
1366 external_port: u16,
1367 protocol: PortMappingProtocol,
1368 ) -> Result<PortMapping, NatTraversalError> {
1369 if self.config.enable_upnp {
1371 match self
1372 .upnp_manager
1373 .create_mapping(
1374 local_port,
1375 external_port,
1376 protocol,
1377 "QuDAG P2P",
1378 self.config.port_mapping_lifetime,
1379 )
1380 .await
1381 {
1382 Ok(mapping) => {
1383 let port_mapping = PortMapping {
1384 local_port,
1385 external_port: mapping.external_port,
1386 protocol,
1387 method: PortMappingMethod::Upnp,
1388 created_at: Instant::now(),
1389 expires_at: Instant::now() + mapping.lease_duration,
1390 };
1391
1392 self.port_mappings.insert(local_port, port_mapping.clone());
1393 self.stats
1394 .port_mappings_created
1395 .fetch_add(1, Ordering::Relaxed);
1396 return Ok(port_mapping);
1397 }
1398 Err(e) => {
1399 warn!("UPnP port mapping failed: {}", e);
1400 }
1401 }
1402 }
1403
1404 if self.config.enable_nat_pmp {
1406 let is_tcp = matches!(protocol, PortMappingProtocol::TCP);
1407 match self
1408 .nat_pmp_client
1409 .create_mapping(
1410 local_port,
1411 external_port,
1412 is_tcp,
1413 self.config.port_mapping_lifetime,
1414 )
1415 .await
1416 {
1417 Ok(mapping) => {
1418 let port_mapping = PortMapping {
1419 local_port,
1420 external_port: mapping.external_port,
1421 protocol,
1422 method: PortMappingMethod::NatPmp,
1423 created_at: Instant::now(),
1424 expires_at: Instant::now() + mapping.lifetime,
1425 };
1426
1427 self.port_mappings.insert(local_port, port_mapping.clone());
1428 self.stats
1429 .port_mappings_created
1430 .fetch_add(1, Ordering::Relaxed);
1431 return Ok(port_mapping);
1432 }
1433 Err(e) => {
1434 warn!("NAT-PMP port mapping failed: {}", e);
1435 }
1436 }
1437 }
1438
1439 self.stats
1440 .port_mappings_failed
1441 .fetch_add(1, Ordering::Relaxed);
1442 Err(NatTraversalError::UpnpError(
1443 "All port mapping methods failed".to_string(),
1444 ))
1445 }
1446
1447 pub async fn connect_peer(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
1449 match self.connection_manager.connect(peer_id).await {
1451 Ok(()) => return Ok(()),
1452 Err(e) => {
1453 debug!("Direct connection failed: {}, trying NAT traversal", e);
1454 }
1455 }
1456
1457 if self.config.enable_hole_punching {
1459 match self.try_hole_punch(peer_id).await {
1460 Ok(()) => return Ok(()),
1461 Err(e) => {
1462 debug!("Hole punching failed: {}", e);
1463 self.stats
1464 .hole_punch_failures
1465 .fetch_add(1, Ordering::Relaxed);
1466 }
1467 }
1468 }
1469
1470 if self.config.enable_relay {
1472 match self.establish_relay_connection(peer_id).await {
1473 Ok(()) => return Ok(()),
1474 Err(e) => {
1475 error!("Relay connection failed: {}", e);
1476 }
1477 }
1478 }
1479
1480 Err(NatTraversalError::ConnectionError(
1481 NetworkError::ConnectionError("All connection methods failed".to_string()),
1482 ))
1483 }
1484
1485 async fn try_hole_punch(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
1487 let local_candidates = self.gather_local_candidates().await?;
1489
1490 let remote_candidates = self.exchange_candidates(peer_id, &local_candidates).await?;
1492
1493 match self
1495 .hole_punch_coordinator
1496 .start_hole_punch(peer_id, local_candidates, remote_candidates)
1497 .await
1498 {
1499 Ok(addr) => {
1500 info!("Hole punch successful, connected via {}", addr);
1501 self.stats
1502 .hole_punch_success
1503 .fetch_add(1, Ordering::Relaxed);
1504
1505 self.connection_manager
1507 .update_status(peer_id, ConnectionStatus::Connected);
1508 Ok(())
1509 }
1510 Err(e) => Err(e),
1511 }
1512 }
1513
1514 async fn gather_local_candidates(&self) -> Result<Vec<SocketAddr>, NatTraversalError> {
1516 let mut candidates = Vec::new();
1517
1518 if let Some(nat_info) = self.nat_info.read().as_ref() {
1520 if let (Some(ip), Some(port)) = (nat_info.public_ip, nat_info.public_port) {
1521 candidates.push(SocketAddr::new(ip, port));
1522 }
1523 }
1524
1525 for mapping in self.port_mappings.iter() {
1530 if let Some(public_ip) = self.get_public_ip() {
1531 candidates.push(SocketAddr::new(public_ip, mapping.external_port));
1532 }
1533 }
1534
1535 Ok(candidates)
1536 }
1537
1538 async fn exchange_candidates(
1540 &self,
1541 _peer_id: PeerId,
1542 _local_candidates: &[SocketAddr],
1543 ) -> Result<Vec<SocketAddr>, NatTraversalError> {
1544 Ok(Vec::new())
1547 }
1548
1549 async fn establish_relay_connection(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
1551 if self.config.enable_turn {
1553 match self.turn_client.allocate_relay().await {
1554 Ok(allocation) => {
1555 info!("TURN relay allocated: {}", allocation.relay_address);
1556 return Ok(());
1558 }
1559 Err(e) => {
1560 warn!("TURN allocation failed: {}", e);
1561 }
1562 }
1563 }
1564
1565 match self.relay_manager.establish_relay(peer_id).await {
1567 Ok(connection) => {
1568 info!(
1569 "Relay connection established via {:?}",
1570 connection.relay_server
1571 );
1572 self.stats.relay_connections.fetch_add(1, Ordering::Relaxed);
1573
1574 self.connection_manager
1576 .update_status(peer_id, ConnectionStatus::Connected);
1577
1578 self.schedule_connection_upgrade(peer_id, ConnectionType::Relay);
1580
1581 Ok(())
1582 }
1583 Err(e) => Err(e),
1584 }
1585 }
1586
1587 fn schedule_connection_upgrade(&self, peer_id: PeerId, current_type: ConnectionType) {
1589 let upgrade_manager = Arc::clone(&self.upgrade_manager);
1590 let stats = Arc::clone(&self.stats);
1591
1592 tokio::spawn(async move {
1593 sleep(Duration::from_secs(30)).await;
1595
1596 match upgrade_manager.try_upgrade(peer_id, current_type).await {
1597 Ok(ConnectionType::Direct) => {
1598 stats.upgraded_connections.fetch_add(1, Ordering::Relaxed);
1599 stats.relay_connections.fetch_sub(1, Ordering::Relaxed);
1600 }
1601 Ok(_) => {}
1602 Err(e) => {
1603 debug!("Connection upgrade failed: {}", e);
1604 }
1605 }
1606 });
1607 }
1608
1609 async fn establish_direct_connection(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
1611 self.try_hole_punch(peer_id).await
1613 }
1614
1615 fn get_public_ip(&self) -> Option<IpAddr> {
1617 self.nat_info.read().as_ref()?.public_ip
1618 }
1619
1620 pub fn get_stats(&self) -> NatTraversalStats {
1622 NatTraversalStats {
1623 total_attempts: AtomicU64::new(self.stats.total_attempts.load(Ordering::Relaxed)),
1624 successful_traversals: AtomicU64::new(
1625 self.stats.successful_traversals.load(Ordering::Relaxed),
1626 ),
1627 failed_traversals: AtomicU64::new(self.stats.failed_traversals.load(Ordering::Relaxed)),
1628 stun_success: AtomicU64::new(self.stats.stun_success.load(Ordering::Relaxed)),
1629 stun_failures: AtomicU64::new(self.stats.stun_failures.load(Ordering::Relaxed)),
1630 hole_punch_success: AtomicU64::new(
1631 self.stats.hole_punch_success.load(Ordering::Relaxed),
1632 ),
1633 hole_punch_failures: AtomicU64::new(
1634 self.stats.hole_punch_failures.load(Ordering::Relaxed),
1635 ),
1636 relay_connections: AtomicU32::new(self.stats.relay_connections.load(Ordering::Relaxed)),
1637 upgraded_connections: AtomicU64::new(
1638 self.stats.upgraded_connections.load(Ordering::Relaxed),
1639 ),
1640 port_mappings_created: AtomicU64::new(
1641 self.stats.port_mappings_created.load(Ordering::Relaxed),
1642 ),
1643 port_mappings_failed: AtomicU64::new(
1644 self.stats.port_mappings_failed.load(Ordering::Relaxed),
1645 ),
1646 avg_traversal_time_ms: AtomicU64::new(
1647 self.stats.avg_traversal_time_ms.load(Ordering::Relaxed),
1648 ),
1649 }
1650 }
1651
1652 pub async fn shutdown(&self) -> Result<(), NatTraversalError> {
1654 info!("Shutting down NAT traversal manager");
1655
1656 if let Some(handle) = self.detection_handle.lock().await.take() {
1658 handle.abort();
1659 }
1660
1661 let relay_peers: Vec<_> = self
1663 .relay_manager
1664 .relay_connections
1665 .iter()
1666 .map(|entry| *entry.key())
1667 .collect();
1668
1669 for peer_id in relay_peers {
1670 self.relay_manager.close_relay(&peer_id).await;
1671 }
1672
1673 Ok(())
1677 }
1678}
1679
#[cfg(test)]
mod tests {
    use super::*;

    /// Live smoke test of STUN-based NAT detection.
    ///
    /// Ignored by default because it needs outbound UDP access to a public host and asserts
    /// nothing: both outcomes are only printed. In offline or sandboxed CI it could only add
    /// latency, never catch a regression. Run it explicitly with `cargo test -- --ignored`.
    // NOTE(review): 8.8.8.8 is a Google Public DNS address; confirm it actually serves STUN on
    // port 3478 before relying on this test's output.
    #[tokio::test]
    #[ignore = "requires outbound network access and makes no assertions"]
    async fn test_nat_detection() {
        let servers = vec![StunServer::new("8.8.8.8:3478".parse().unwrap(), 1)];

        let client = StunClient::new(servers);

        match client.detect_nat().await {
            Ok(nat_info) => {
                println!("NAT type: {:?}", nat_info.nat_type);
                println!("Public IP: {:?}", nat_info.public_ip);
            }
            Err(e) => {
                println!("NAT detection failed: {}", e);
            }
        }
    }

    /// Sanity check of `NatType` equality: identical variants compare equal and distinct
    /// variants do not.
    #[test]
    fn test_nat_type_properties() {
        assert_eq!(NatType::None, NatType::None);
        assert_ne!(NatType::FullCone, NatType::Symmetric);
    }
}