1#![allow(missing_docs)]
8
9use std::{collections::HashMap, fmt, net::SocketAddr, sync::Arc, time::Duration};
16
/// Returns the IPv4 wildcard address with port 0 (`0.0.0.0:0`).
///
/// Binding to this address lets the operating system pick a free port. The
/// address is built directly from its components, so no string parsing (and
/// therefore no panic path) is involved.
fn create_random_port_bind_addr() -> SocketAddr {
    SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0))
}
41
42use tracing::{debug, error, info, warn};
43
44use std::sync::atomic::{AtomicBool, Ordering};
45
46use tokio::{
47 net::UdpSocket,
48 sync::{mpsc, mpsc::error::TryRecvError},
49 time::{sleep, timeout},
50};
51
52use crate::high_level::default_runtime;
53
54use crate::{
55 VarInt,
56 candidate_discovery::{CandidateDiscoveryManager, DiscoveryConfig, DiscoveryEvent},
57 connection::nat_traversal::{CandidateSource, CandidateState, NatTraversalRole},
58};
59
60use crate::{
61 ClientConfig, ConnectionError, EndpointConfig, ServerConfig, TransportConfig,
62 high_level::{Connection as QuinnConnection, Endpoint as QuinnEndpoint},
63};
64
65#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
66use crate::{crypto::rustls::QuicClientConfig, crypto::rustls::QuicServerConfig};
67
68use crate::config::validation::{ConfigValidator, ValidationResult};
69
70#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
71use crate::crypto::raw_public_keys::RawPublicKeyConfigBuilder;
72
/// Endpoint state for NAT traversal layered on a Quinn QUIC endpoint.
///
/// Most shared collections are wrapped in `Arc` so that clones can be handed
/// to the background tasks spawned by the constructor and by
/// `start_session_polling`.
pub struct NatTraversalEndpoint {
    /// Underlying QUIC endpoint, if one is present.
    quinn_endpoint: Option<QuinnEndpoint>,
    /// Configuration this endpoint was created from.
    config: NatTraversalConfig,
    /// Known bootstrap nodes; initially seeded from `config.bootstrap_nodes`.
    bootstrap_nodes: Arc<std::sync::RwLock<Vec<BootstrapNode>>>,
    /// NAT traversal sessions keyed by the remote peer.
    active_sessions: Arc<std::sync::RwLock<HashMap<PeerId, NatTraversalSession>>>,
    /// Candidate discovery manager shared with the discovery polling task.
    discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
    /// Optional user callback invoked with NAT traversal events.
    event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
    /// Shutdown flag shared with the background tasks.
    shutdown: Arc<AtomicBool>,
    /// Sending half of the internal event channel.
    event_tx: Option<mpsc::UnboundedSender<NatTraversalEvent>>,
    /// Receiving half of the internal event channel.
    event_rx: std::sync::Mutex<mpsc::UnboundedReceiver<NatTraversalEvent>>,
    /// Connections keyed by peer; shared with the accept task.
    connections: Arc<std::sync::RwLock<HashMap<PeerId, QuinnConnection>>>,
    /// Identifier of the local peer.
    local_peer_id: PeerId,
    /// Timeout settings, copied from `config.timeouts` at construction.
    timeout_config: crate::config::nat_timeouts::TimeoutConfig,
    /// Set of peers shared with the accept task.
    // NOTE(review): presumably used to avoid emitting `ConnectionEstablished`
    // more than once per peer — confirm against the accept/connect code.
    emitted_established_events: Arc<std::sync::RwLock<std::collections::HashSet<PeerId>>>,
}
105
/// User-facing configuration for a [`NatTraversalEndpoint`].
///
/// The values are checked by the `ConfigValidator` implementation before an
/// endpoint is created.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct NatTraversalConfig {
    /// Role this endpoint plays; determines whether a server configuration is
    /// created and which bootstrap requirements apply.
    pub role: EndpointRole,
    /// Addresses of bootstrap nodes. Required for clients and for servers
    /// that can coordinate.
    pub bootstrap_nodes: Vec<SocketAddr>,
    /// Maximum number of candidates; validated against the range 1–256.
    pub max_candidates: usize,
    /// Coordination timeout; validated against 100 ms – 300 s and also used as
    /// the discovery `total_timeout`.
    pub coordination_timeout: Duration,
    /// Forwarded to the discovery configuration as
    /// `enable_symmetric_prediction`.
    pub enable_symmetric_nat: bool,
    /// Whether relay fallback is enabled. Validation rejects this for the
    /// `Bootstrap` role.
    pub enable_relay_fallback: bool,
    /// Maximum concurrent attempts; validated against the range 1–16 and must
    /// not exceed `max_candidates`.
    pub max_concurrent_attempts: usize,
    /// Address to bind the UDP socket to; `None` binds `0.0.0.0:0`.
    pub bind_addr: Option<SocketAddr>,
    /// Preference flag for RFC-style NAT traversal.
    // NOTE(review): not read anywhere in the visible part of this file —
    // confirm where it takes effect.
    pub prefer_rfc_nat_traversal: bool,
    /// Timeout settings copied into the endpoint at construction.
    pub timeouts: crate::config::nat_timeouts::TimeoutConfig,
}
176
/// Role an endpoint plays in NAT traversal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum EndpointRole {
    /// Client endpoint; validation requires at least one bootstrap node.
    Client,
    /// Server endpoint; accepts incoming connections.
    Server {
        /// Whether this server can coordinate. When `true`, validation
        /// requires bootstrap nodes; the flag is also mapped onto the
        /// transport-level `can_relay` setting.
        can_coordinate: bool,
    },
    /// Bootstrap endpoint; accepts incoming connections and may be configured
    /// without bootstrap nodes of its own.
    Bootstrap,
}
190
191impl EndpointRole {
192 pub fn name(&self) -> &'static str {
194 match self {
195 Self::Client => "client",
196 Self::Server { .. } => "server",
197 Self::Bootstrap => "bootstrap",
198 }
199 }
200}
201
/// Peer identifier: a 32-byte value usable as a hash-map key and orderable.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
)]
pub struct PeerId(pub [u8; 32]);
207
/// Bookkeeping record for a single bootstrap node.
#[derive(Debug, Clone)]
pub struct BootstrapNode {
    /// Network address of the node.
    pub address: SocketAddr,
    /// Time the node was last seen; initialised to the creation time.
    pub last_seen: std::time::Instant,
    /// Whether the node can coordinate; new records default to `true`.
    pub can_coordinate: bool,
    /// Measured round-trip time, if any. Used by `get_statistics` to estimate
    /// the average coordination time.
    pub rtt: Option<Duration>,
    /// Coordination counter; summed by `get_statistics` as the number of
    /// successful coordinations.
    pub coordination_count: u32,
}
222
223impl BootstrapNode {
224 pub fn new(address: SocketAddr) -> Self {
226 Self {
227 address,
228 last_seen: std::time::Instant::now(),
229 can_coordinate: true,
230 rtt: None,
231 coordination_count: 0,
232 }
233 }
234}
235
/// A pairing of one local and one remote candidate address.
#[derive(Debug, Clone)]
pub struct CandidatePair {
    /// Candidate on the local side.
    pub local_candidate: CandidateAddress,
    /// Candidate on the remote side.
    pub remote_candidate: CandidateAddress,
    /// Priority of the pair.
    // NOTE(review): how this is derived from the two candidate priorities is
    // not visible here — confirm.
    pub priority: u64,
    /// Current state of the pair.
    pub state: CandidatePairState,
}
248
/// State of a [`CandidatePair`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandidatePairState {
    /// The pair is waiting.
    Waiting,
    /// The pair is in progress.
    InProgress,
    /// The pair succeeded.
    Succeeded,
    /// The pair failed.
    Failed,
    /// The pair was cancelled.
    Cancelled,
}
263
/// Internal record of one NAT traversal session with a remote peer.
#[derive(Debug)]
struct NatTraversalSession {
    /// Remote peer this session targets.
    peer_id: PeerId,
    /// Coordinator address supplied when the traversal was initiated.
    #[allow(dead_code)]
    coordinator: SocketAddr,
    /// Attempt counter; starts at 1 and is incremented on each retry by the
    /// session polling task.
    attempt: u32,
    /// Time the session was created.
    started_at: std::time::Instant,
    /// Current traversal phase; sessions start in `Discovery`.
    phase: TraversalPhase,
    /// Candidate addresses collected for the session; initially empty.
    candidates: Vec<CandidateAddress>,
    /// Connection-level state tracked by the polling functions.
    session_state: SessionState,
}
283
/// Connection-level state of a session.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Current connection state.
    pub state: ConnectionState,
    /// Time of the most recent state transition; timeouts are measured
    /// from this instant.
    pub last_transition: std::time::Instant,
    /// Established connection, if any.
    pub connection: Option<QuinnConnection>,
    /// Address/timestamp pairs of attempts.
    // NOTE(review): presumably the attempts currently in flight and their
    // start times — confirm against the code that fills this vector.
    pub active_attempts: Vec<(SocketAddr, std::time::Instant)>,
    /// Metrics collected for the connection.
    pub metrics: ConnectionMetrics,
}
298
/// Connection state of a session as tracked by the polling functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection activity; the polling task retries idle sessions after a
    /// configured interval.
    Idle,
    /// A connection attempt is in progress; subject to the connection
    /// establishment timeout.
    Connecting,
    /// A connection is established.
    Connected,
    /// The connection is migrating; subject to a migration timeout.
    Migrating,
    /// The session is closed; the polling task removes it after a configured
    /// interval.
    Closed,
}
313
/// Metrics recorded for a session's connection.
#[derive(Debug, Clone, Default)]
pub struct ConnectionMetrics {
    /// Path round-trip time, refreshed from connection statistics by the
    /// session polling task.
    pub rtt: Option<Duration>,
    /// Fraction of lost packets (lost / sent), refreshed by the session
    /// polling task.
    pub loss_rate: f64,
    /// Bytes sent.
    // NOTE(review): not updated in the visible part of this file.
    pub bytes_sent: u64,
    /// Bytes received.
    // NOTE(review): not updated in the visible part of this file.
    pub bytes_received: u64,
    /// Last time activity was recorded; `poll_sessions` sets this for
    /// connected sessions.
    pub last_activity: Option<std::time::Instant>,
}
328
/// Description of one state transition reported by `poll_sessions`.
#[derive(Debug, Clone)]
pub struct SessionStateUpdate {
    /// Peer whose session changed state.
    pub peer_id: PeerId,
    /// State before the transition.
    pub old_state: ConnectionState,
    /// State after the transition.
    pub new_state: ConnectionState,
    /// Why the transition happened.
    pub reason: StateChangeReason,
}
341
/// Reason attached to a [`SessionStateUpdate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StateChangeReason {
    /// A timeout elapsed (reported when a connecting session times out).
    Timeout,
    /// A connection became available for a connecting session.
    ConnectionEstablished,
    /// The connection was closed.
    ConnectionClosed,
    /// Migration finished with a connection still present.
    MigrationComplete,
    /// Migration finished without a connection.
    MigrationFailed,
    /// A network error occurred.
    NetworkError,
    /// The user closed the connection.
    UserClosed,
}
360
/// Phase of a NAT traversal session. New sessions start in `Discovery`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalPhase {
    /// Candidate discovery.
    Discovery,
    /// Coordination.
    Coordination,
    /// Synchronization.
    Synchronization,
    /// Hole punching.
    Punching,
    /// Validation.
    Validation,
    /// Connected.
    Connected,
    /// Failed.
    Failed,
}
379
/// Action decided for a session by the session polling task.
///
/// The task first classifies sessions under a read lock and then applies the
/// chosen action under a write lock.
#[derive(Debug, Clone, Copy)]
enum SessionUpdate {
    /// A connecting session exceeded the connection establishment timeout.
    Timeout,
    /// A connected session's connection reports a close reason.
    Disconnected,
    /// A connected session is healthy; refresh its metrics.
    UpdateMetrics,
    /// A session is marked connected but holds no connection.
    InvalidState,
    /// An idle session has waited long enough to be retried.
    Retry,
    /// A migrating session exceeded the probe timeout.
    MigrationTimeout,
    /// A closed session is old enough to be removed.
    Remove,
}
398
/// A candidate address together with its priority, origin and state.
#[derive(Debug, Clone)]
pub struct CandidateAddress {
    /// Socket address of the candidate.
    pub address: SocketAddr,
    /// Base priority; see `effective_priority` for the state-adjusted value.
    pub priority: u32,
    /// Where the candidate came from.
    pub source: CandidateSource,
    /// Validation state; candidates created through `new` start as
    /// `CandidateState::New`.
    pub state: CandidateState,
}
411
412impl CandidateAddress {
413 pub fn new(
415 address: SocketAddr,
416 priority: u32,
417 source: CandidateSource,
418 ) -> Result<Self, CandidateValidationError> {
419 Self::validate_address(&address)?;
420 Ok(Self {
421 address,
422 priority,
423 source,
424 state: CandidateState::New,
425 })
426 }
427
428 pub fn validate_address(addr: &SocketAddr) -> Result<(), CandidateValidationError> {
430 if addr.port() == 0 {
432 return Err(CandidateValidationError::InvalidPort(0));
433 }
434
435 #[cfg(not(test))]
437 if addr.port() < 1024 {
438 return Err(CandidateValidationError::PrivilegedPort(addr.port()));
439 }
440
441 match addr.ip() {
442 std::net::IpAddr::V4(ipv4) => {
443 if ipv4.is_unspecified() {
445 return Err(CandidateValidationError::UnspecifiedAddress);
446 }
447 if ipv4.is_broadcast() {
448 return Err(CandidateValidationError::BroadcastAddress);
449 }
450 if ipv4.is_multicast() {
451 return Err(CandidateValidationError::MulticastAddress);
452 }
453 if ipv4.octets()[0] == 0 {
455 return Err(CandidateValidationError::ReservedAddress);
456 }
457 if ipv4.octets()[0] >= 240 {
459 return Err(CandidateValidationError::ReservedAddress);
460 }
461 }
462 std::net::IpAddr::V6(ipv6) => {
463 if ipv6.is_unspecified() {
465 return Err(CandidateValidationError::UnspecifiedAddress);
466 }
467 if ipv6.is_multicast() {
468 return Err(CandidateValidationError::MulticastAddress);
469 }
470 let segments = ipv6.segments();
472 if segments[0] == 0x2001 && segments[1] == 0x0db8 {
473 return Err(CandidateValidationError::DocumentationAddress);
474 }
475 if ipv6.to_ipv4_mapped().is_some() {
477 return Err(CandidateValidationError::IPv4MappedAddress);
478 }
479 }
480 }
481
482 Ok(())
483 }
484
485 pub fn is_suitable_for_nat_traversal(&self) -> bool {
487 match self.address.ip() {
488 std::net::IpAddr::V4(ipv4) => {
489 #[cfg(test)]
494 if ipv4.is_loopback() {
495 return true;
496 }
497 !ipv4.is_loopback()
498 && !ipv4.is_link_local()
499 && !ipv4.is_multicast()
500 && !ipv4.is_broadcast()
501 }
502 std::net::IpAddr::V6(ipv6) => {
503 #[cfg(test)]
509 if ipv6.is_loopback() {
510 return true;
511 }
512 let segments = ipv6.segments();
513 let is_link_local = (segments[0] & 0xffc0) == 0xfe80;
514 let is_unique_local = (segments[0] & 0xfe00) == 0xfc00;
515
516 !ipv6.is_loopback() && !is_link_local && !is_unique_local && !ipv6.is_multicast()
517 }
518 }
519 }
520
521 pub fn effective_priority(&self) -> u32 {
523 match self.state {
524 CandidateState::Valid => self.priority,
525 CandidateState::New => self.priority.saturating_sub(10),
526 CandidateState::Validating => self.priority.saturating_sub(5),
527 CandidateState::Failed => 0,
528 CandidateState::Removed => 0,
529 }
530 }
531}
532
/// Reasons `CandidateAddress::validate_address` rejects an address.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum CandidateValidationError {
    /// The port is not usable (port 0).
    #[error("invalid port number: {0}")]
    InvalidPort(u16),
    /// The port is below 1024 (only checked outside of test builds).
    #[error("privileged port not allowed: {0}")]
    PrivilegedPort(u16),
    /// The IP address is the unspecified address.
    #[error("unspecified address not allowed")]
    UnspecifiedAddress,
    /// The IPv4 address is the broadcast address.
    #[error("broadcast address not allowed")]
    BroadcastAddress,
    /// The IP address is a multicast address.
    #[error("multicast address not allowed")]
    MulticastAddress,
    /// The IPv4 address has a first octet of 0 or of 240 and above.
    #[error("reserved address not allowed")]
    ReservedAddress,
    /// The IPv6 address lies in the `2001:db8::/32` documentation range.
    #[error("documentation address not allowed")]
    DocumentationAddress,
    /// The IPv6 address is an IPv4-mapped address.
    #[error("IPv4-mapped IPv6 address not allowed")]
    IPv4MappedAddress,
}
561
/// Events delivered to the event callback and the internal event channel.
#[derive(Debug, Clone)]
pub enum NatTraversalEvent {
    /// A candidate address was discovered.
    CandidateDiscovered {
        /// Peer the candidate belongs to.
        peer_id: PeerId,
        /// The discovered candidate.
        candidate: CandidateAddress,
    },
    /// Coordination was requested; emitted by `initiate_nat_traversal`.
    CoordinationRequested {
        /// Peer the traversal targets.
        peer_id: PeerId,
        /// Coordinator address passed to `initiate_nat_traversal`.
        coordinator: SocketAddr,
    },
    /// Coordination was synchronized.
    CoordinationSynchronized {
        /// Peer involved.
        peer_id: PeerId,
        /// Identifier of the coordination round.
        round_id: VarInt,
    },
    /// Hole punching started.
    HolePunchingStarted {
        /// Peer involved.
        peer_id: PeerId,
        /// Addresses being targeted.
        targets: Vec<SocketAddr>,
    },
    /// A path was validated.
    PathValidated {
        /// Peer involved.
        peer_id: PeerId,
        /// Validated address.
        address: SocketAddr,
        /// Round-trip time reported for the path.
        rtt: Duration,
    },
    /// A candidate was validated.
    CandidateValidated {
        /// Peer involved.
        peer_id: PeerId,
        /// Address of the validated candidate.
        candidate_address: SocketAddr,
    },
    /// NAT traversal succeeded.
    TraversalSucceeded {
        /// Peer involved.
        peer_id: PeerId,
        /// Address finally used.
        final_address: SocketAddr,
        /// Total time taken.
        total_time: Duration,
    },
    /// A connection was established.
    ConnectionEstablished {
        /// Peer involved.
        peer_id: PeerId,
        /// Remote address of the connection.
        remote_address: SocketAddr,
    },
    /// NAT traversal failed.
    TraversalFailed {
        /// Peer involved.
        peer_id: PeerId,
        /// The error that caused the failure.
        error: NatTraversalError,
        /// Whether a fallback is available.
        fallback_available: bool,
    },
    /// A connection was lost.
    ConnectionLost {
        /// Peer involved.
        peer_id: PeerId,
        /// Human-readable reason.
        reason: String,
    },
    /// A session moved from one traversal phase to another.
    PhaseTransition {
        /// Peer involved.
        peer_id: PeerId,
        /// Phase before the transition.
        from_phase: TraversalPhase,
        /// Phase after the transition.
        to_phase: TraversalPhase,
    },
    /// A session's connection state changed; emitted by `poll_sessions`.
    SessionStateChanged {
        /// Peer involved.
        peer_id: PeerId,
        /// The state the session is now in.
        new_state: ConnectionState,
    },
}
657
/// Errors returned by NAT traversal operations.
#[derive(Debug, Clone)]
pub enum NatTraversalError {
    /// No bootstrap nodes are available.
    NoBootstrapNodes,
    /// No candidates were found.
    NoCandidatesFound,
    /// Candidate discovery could not be started or failed; carries the
    /// underlying error text.
    CandidateDiscoveryFailed(String),
    /// Coordination failed; carries a description.
    CoordinationFailed(String),
    /// Hole punching failed.
    HolePunchingFailed,
    /// Punching failed; carries a description.
    PunchingFailed(String),
    /// Validation failed; carries a description.
    ValidationFailed(String),
    /// Validation timed out.
    ValidationTimeout,
    /// A network operation failed (for example binding the UDP socket).
    NetworkError(String),
    /// The configuration was rejected or the endpoint/TLS setup failed.
    ConfigError(String),
    /// A protocol-level error; also used in this file for poisoned locks.
    ProtocolError(String),
    /// The operation timed out.
    Timeout,
    /// Establishing a connection failed; carries a description.
    ConnectionFailed(String),
    /// The traversal failed; carries a description.
    TraversalFailed(String),
    /// The peer is not connected.
    PeerNotConnected,
}
692
693impl Default for NatTraversalConfig {
694 fn default() -> Self {
695 Self {
696 role: EndpointRole::Client,
697 bootstrap_nodes: Vec::new(),
698 max_candidates: 8,
699 coordination_timeout: Duration::from_secs(10),
700 enable_symmetric_nat: true,
701 enable_relay_fallback: true,
702 max_concurrent_attempts: 3,
703 bind_addr: None,
704 prefer_rfc_nat_traversal: true, timeouts: crate::config::nat_timeouts::TimeoutConfig::default(),
706 }
707 }
708}
709
710impl ConfigValidator for NatTraversalConfig {
711 fn validate(&self) -> ValidationResult<()> {
712 use crate::config::validation::*;
713
714 match self.role {
716 EndpointRole::Client => {
717 if self.bootstrap_nodes.is_empty() {
718 return Err(ConfigValidationError::InvalidRole(
719 "Client endpoints require at least one bootstrap node".to_string(),
720 ));
721 }
722 }
723 EndpointRole::Server { can_coordinate } => {
724 if can_coordinate && self.bootstrap_nodes.is_empty() {
725 return Err(ConfigValidationError::InvalidRole(
726 "Server endpoints with coordination capability require bootstrap nodes"
727 .to_string(),
728 ));
729 }
730 }
731 EndpointRole::Bootstrap => {
732 }
734 }
735
736 if !self.bootstrap_nodes.is_empty() {
738 validate_bootstrap_nodes(&self.bootstrap_nodes)?;
739 }
740
741 validate_range(self.max_candidates, 1, 256, "max_candidates")?;
743
744 validate_duration(
746 self.coordination_timeout,
747 Duration::from_millis(100),
748 Duration::from_secs(300),
749 "coordination_timeout",
750 )?;
751
752 validate_range(
754 self.max_concurrent_attempts,
755 1,
756 16,
757 "max_concurrent_attempts",
758 )?;
759
760 if self.max_concurrent_attempts > self.max_candidates {
762 return Err(ConfigValidationError::IncompatibleConfiguration(
763 "max_concurrent_attempts cannot exceed max_candidates".to_string(),
764 ));
765 }
766
767 if self.role == EndpointRole::Bootstrap && self.enable_relay_fallback {
768 return Err(ConfigValidationError::IncompatibleConfiguration(
769 "Bootstrap nodes should not enable relay fallback".to_string(),
770 ));
771 }
772
773 Ok(())
774 }
775}
776
777impl NatTraversalEndpoint {
    /// Creates a new NAT traversal endpoint from `config`.
    ///
    /// `event_callback`, when provided, is stored on the endpoint and invoked
    /// with NAT traversal events. This is a thin wrapper that delegates to
    /// `new_impl`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying construction reports.
    pub async fn new(
        config: NatTraversalConfig,
        event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
    ) -> Result<Self, NatTraversalError> {
        Self::new_impl(config, event_callback).await
    }
785
    /// Internal constructor step; forwards its arguments unchanged to
    /// `new_common`.
    async fn new_impl(
        config: NatTraversalConfig,
        event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
    ) -> Result<Self, NatTraversalError> {
        Self::new_common(config, event_callback).await
    }
793
    /// Internal constructor step; forwards its arguments unchanged to
    /// `new_shared_logic`, which does the actual work.
    async fn new_common(
        config: NatTraversalConfig,
        event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
    ) -> Result<Self, NatTraversalError> {
        Self::new_shared_logic(config, event_callback).await
    }
802
803 async fn new_shared_logic(
805 config: NatTraversalConfig,
806 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
807 ) -> Result<Self, NatTraversalError> {
808 {
811 config
812 .validate()
813 .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;
814 }
815
816 let bootstrap_nodes = Arc::new(std::sync::RwLock::new(
820 config
821 .bootstrap_nodes
822 .iter()
823 .map(|&address| BootstrapNode {
824 address,
825 last_seen: std::time::Instant::now(),
826 can_coordinate: true, rtt: None,
828 coordination_count: 0,
829 })
830 .collect(),
831 ));
832
833 let discovery_config = DiscoveryConfig {
835 total_timeout: config.coordination_timeout,
836 max_candidates: config.max_candidates,
837 enable_symmetric_prediction: config.enable_symmetric_nat,
838 bound_address: config.bind_addr, ..DiscoveryConfig::default()
840 };
841
842 let nat_traversal_role = match config.role {
843 EndpointRole::Client => NatTraversalRole::Client,
844 EndpointRole::Server { can_coordinate } => NatTraversalRole::Server {
845 can_relay: can_coordinate,
846 },
847 EndpointRole::Bootstrap => NatTraversalRole::Bootstrap,
848 };
849
850 let discovery_manager = Arc::new(std::sync::Mutex::new(CandidateDiscoveryManager::new(
851 discovery_config,
852 )));
853
854 let (quinn_endpoint, event_tx, event_rx, local_addr) =
857 Self::create_quinn_endpoint(&config, nat_traversal_role).await?;
858
859 {
861 let mut discovery = discovery_manager.lock().map_err(|_| {
862 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
863 })?;
864 discovery.set_bound_address(local_addr);
865 info!(
866 "Updated discovery manager with bound address: {}",
867 local_addr
868 );
869 }
870
871 let emitted_established_events = Arc::new(std::sync::RwLock::new(std::collections::HashSet::new()));
872
873 let endpoint = Self {
874 quinn_endpoint: Some(quinn_endpoint.clone()),
875 config: config.clone(),
876 bootstrap_nodes,
877 active_sessions: Arc::new(std::sync::RwLock::new(HashMap::new())),
878 discovery_manager,
879 event_callback,
880 shutdown: Arc::new(AtomicBool::new(false)),
881 event_tx: Some(event_tx.clone()),
882 event_rx: std::sync::Mutex::new(event_rx),
883 connections: Arc::new(std::sync::RwLock::new(HashMap::new())),
884 local_peer_id: Self::generate_local_peer_id(),
885 timeout_config: config.timeouts.clone(),
886 emitted_established_events: emitted_established_events.clone(),
887 };
888
889 if matches!(
891 config.role,
892 EndpointRole::Bootstrap | EndpointRole::Server { .. }
893 ) {
894 let endpoint_clone = quinn_endpoint.clone();
895 let shutdown_clone = endpoint.shutdown.clone();
896 let event_tx_clone = event_tx.clone();
897 let connections_clone = endpoint.connections.clone();
898 let emitted_events_clone = emitted_established_events.clone();
899
900 tokio::spawn(async move {
901 Self::accept_connections(
902 endpoint_clone,
903 shutdown_clone,
904 event_tx_clone,
905 connections_clone,
906 emitted_events_clone,
907 )
908 .await;
909 });
910
911 info!("Started accepting connections for {:?} role", config.role);
912 }
913
914 let discovery_manager_clone = endpoint.discovery_manager.clone();
916 let shutdown_clone = endpoint.shutdown.clone();
917 let event_tx_clone = event_tx;
918
919 tokio::spawn(async move {
920 Self::poll_discovery(discovery_manager_clone, shutdown_clone, event_tx_clone).await;
921 });
922
923 info!("Started discovery polling task");
924
925 {
927 let mut discovery = endpoint.discovery_manager.lock().map_err(|_| {
928 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
929 })?;
930
931 let local_peer_id = endpoint.local_peer_id;
933 let bootstrap_nodes = {
934 let nodes = endpoint.bootstrap_nodes.read().map_err(|_| {
935 NatTraversalError::ProtocolError("Bootstrap nodes lock poisoned".to_string())
936 })?;
937 nodes.clone()
938 };
939
940 discovery
941 .start_discovery(local_peer_id, bootstrap_nodes)
942 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
943
944 info!(
945 "Started local candidate discovery for peer {:?}",
946 local_peer_id
947 );
948 }
949
950 Ok(endpoint)
951 }
952
    /// Returns a reference to the underlying Quinn endpoint, or `None` if the
    /// endpoint is not present.
    pub fn get_quinn_endpoint(&self) -> Option<&crate::high_level::Endpoint> {
        self.quinn_endpoint.as_ref()
    }
957
    /// Returns a reference to the registered event callback, or `None` if no
    /// callback was supplied at construction.
    pub fn get_event_callback(&self) -> Option<&Box<dyn Fn(NatTraversalEvent) + Send + Sync>> {
        self.event_callback.as_ref()
    }
962
963 pub fn initiate_nat_traversal(
965 &self,
966 peer_id: PeerId,
967 coordinator: SocketAddr,
968 ) -> Result<(), NatTraversalError> {
969 info!(
970 "Starting NAT traversal to peer {:?} via coordinator {}",
971 peer_id, coordinator
972 );
973
974 let session = NatTraversalSession {
976 peer_id,
977 coordinator,
978 attempt: 1,
979 started_at: std::time::Instant::now(),
980 phase: TraversalPhase::Discovery,
981 candidates: Vec::new(),
982 session_state: SessionState {
983 state: ConnectionState::Connecting,
984 last_transition: std::time::Instant::now(),
985
986 connection: None,
987 active_attempts: Vec::new(),
988 metrics: ConnectionMetrics::default(),
989 },
990 };
991
992 {
994 let mut sessions = self
995 .active_sessions
996 .write()
997 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
998 sessions.insert(peer_id, session);
999 }
1000
1001 let bootstrap_nodes_vec = {
1003 let bootstrap_nodes = self
1004 .bootstrap_nodes
1005 .read()
1006 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1007 bootstrap_nodes.clone()
1008 };
1009
1010 {
1011 let mut discovery = self.discovery_manager.lock().map_err(|_| {
1012 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
1013 })?;
1014
1015 discovery
1016 .start_discovery(peer_id, bootstrap_nodes_vec)
1017 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
1018 }
1019
1020 if let Some(ref callback) = self.event_callback {
1022 callback(NatTraversalEvent::CoordinationRequested {
1023 peer_id,
1024 coordinator,
1025 });
1026 }
1027
1028 Ok(())
1030 }
1031
1032 pub fn poll_sessions(&self) -> Result<Vec<SessionStateUpdate>, NatTraversalError> {
1034 let mut updates = Vec::new();
1035 let now = std::time::Instant::now();
1036
1037 let mut sessions = self
1038 .active_sessions
1039 .write()
1040 .map_err(|_| NatTraversalError::ProtocolError("Sessions lock poisoned".to_string()))?;
1041
1042 for (peer_id, session) in sessions.iter_mut() {
1043 let mut state_changed = false;
1044
1045 match session.session_state.state {
1046 ConnectionState::Connecting => {
1047 let elapsed = now.duration_since(session.session_state.last_transition);
1049 if elapsed
1050 > self
1051 .timeout_config
1052 .nat_traversal
1053 .connection_establishment_timeout
1054 {
1055 session.session_state.state = ConnectionState::Closed;
1056 session.session_state.last_transition = now;
1057 state_changed = true;
1058
1059 updates.push(SessionStateUpdate {
1060 peer_id: *peer_id,
1061 old_state: ConnectionState::Connecting,
1062 new_state: ConnectionState::Closed,
1063 reason: StateChangeReason::Timeout,
1064 });
1065 }
1066
1067 if let Some(ref _connection) = session.session_state.connection {
1070 session.session_state.state = ConnectionState::Connected;
1071 session.session_state.last_transition = now;
1072 state_changed = true;
1073
1074 updates.push(SessionStateUpdate {
1075 peer_id: *peer_id,
1076 old_state: ConnectionState::Connecting,
1077 new_state: ConnectionState::Connected,
1078 reason: StateChangeReason::ConnectionEstablished,
1079 });
1080 }
1081 }
1082 ConnectionState::Connected => {
1083 {
1086 }
1089
1090 session.session_state.metrics.last_activity = Some(now);
1092 }
1093 ConnectionState::Migrating => {
1094 let elapsed = now.duration_since(session.session_state.last_transition);
1096 if elapsed > Duration::from_secs(10) {
1097 if session.session_state.connection.is_some() {
1100 session.session_state.state = ConnectionState::Connected;
1101 state_changed = true;
1102
1103 updates.push(SessionStateUpdate {
1104 peer_id: *peer_id,
1105 old_state: ConnectionState::Migrating,
1106 new_state: ConnectionState::Connected,
1107 reason: StateChangeReason::MigrationComplete,
1108 });
1109 } else {
1110 session.session_state.state = ConnectionState::Closed;
1111 state_changed = true;
1112
1113 updates.push(SessionStateUpdate {
1114 peer_id: *peer_id,
1115 old_state: ConnectionState::Migrating,
1116 new_state: ConnectionState::Closed,
1117 reason: StateChangeReason::MigrationFailed,
1118 });
1119 }
1120
1121 session.session_state.last_transition = now;
1122 }
1123 }
1124 _ => {}
1125 }
1126
1127 if state_changed {
1129 if let Some(ref callback) = self.event_callback {
1130 callback(NatTraversalEvent::SessionStateChanged {
1131 peer_id: *peer_id,
1132 new_state: session.session_state.state,
1133 });
1134 }
1135 }
1136 }
1137
1138 Ok(updates)
1139 }
1140
    /// Spawns a background task that inspects all sessions every `interval`
    /// and applies timeouts, metric refreshes, retries and cleanup.
    ///
    /// Each tick works in two phases:
    /// 1. under a read lock, every session is classified into an optional
    ///    [`SessionUpdate`];
    /// 2. if anything needs doing, the updates are applied under a write lock.
    ///
    /// The task exits once the shutdown flag is observed after a tick. The
    /// returned handle can be used to await or abort the task.
    ///
    /// NOTE(review): the classification and the application happen under
    /// separate lock acquisitions and the state is not re-checked in phase 2,
    /// so an update may be applied to a session whose state changed in
    /// between.
    pub fn start_session_polling(&self, interval: Duration) -> tokio::task::JoinHandle<()> {
        // Clone the shared handles so the task owns everything it needs.
        let sessions = self.active_sessions.clone();
        let shutdown = self.shutdown.clone();
        let timeout_config = self.timeout_config.clone();

        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(interval);

            loop {
                ticker.tick().await;

                if shutdown.load(Ordering::Relaxed) {
                    break;
                }

                // Phase 1: decide, under a read lock, what to do per session.
                let sessions_to_update = {
                    match sessions.read() {
                        Ok(sessions_guard) => {
                            sessions_guard
                                .iter()
                                .filter_map(|(peer_id, session)| {
                                    let now = std::time::Instant::now();
                                    // Time spent in the current state.
                                    let elapsed =
                                        now.duration_since(session.session_state.last_transition);

                                    match session.session_state.state {
                                        ConnectionState::Connecting => {
                                            if elapsed
                                                > timeout_config
                                                    .nat_traversal
                                                    .connection_establishment_timeout
                                            {
                                                Some((*peer_id, SessionUpdate::Timeout))
                                            } else {
                                                None
                                            }
                                        }
                                        ConnectionState::Connected => {
                                            if let Some(ref conn) = session.session_state.connection
                                            {
                                                // A close reason means the
                                                // connection is gone.
                                                if conn.close_reason().is_some() {
                                                    Some((*peer_id, SessionUpdate::Disconnected))
                                                } else {
                                                    Some((*peer_id, SessionUpdate::UpdateMetrics))
                                                }
                                            } else {
                                                // Connected without a
                                                // connection is inconsistent.
                                                Some((*peer_id, SessionUpdate::InvalidState))
                                            }
                                        }
                                        ConnectionState::Idle => {
                                            if elapsed
                                                > timeout_config
                                                    .discovery
                                                    .server_reflexive_cache_ttl
                                            {
                                                Some((*peer_id, SessionUpdate::Retry))
                                            } else {
                                                None
                                            }
                                        }
                                        ConnectionState::Migrating => {
                                            if elapsed > timeout_config.nat_traversal.probe_timeout
                                            {
                                                Some((*peer_id, SessionUpdate::MigrationTimeout))
                                            } else {
                                                None
                                            }
                                        }
                                        ConnectionState::Closed => {
                                            if elapsed
                                                > timeout_config.discovery.interface_cache_ttl
                                            {
                                                Some((*peer_id, SessionUpdate::Remove))
                                            } else {
                                                None
                                            }
                                        }
                                    }
                                })
                                .collect::<Vec<_>>()
                        }
                        // A poisoned lock is treated as "nothing to do".
                        _ => {
                            vec![]
                        }
                    }
                };

                // Phase 2: apply the decisions under a write lock. A poisoned
                // lock silently skips this tick's updates.
                if !sessions_to_update.is_empty() {
                    if let Ok(mut sessions_guard) = sessions.write() {
                        for (peer_id, update) in sessions_to_update {
                            match update {
                                SessionUpdate::Timeout => {
                                    if let Some(session) = sessions_guard.get_mut(&peer_id) {
                                        session.session_state.state = ConnectionState::Closed;
                                        session.session_state.last_transition =
                                            std::time::Instant::now();
                                        tracing::warn!("Connection to {:?} timed out", peer_id);
                                    }
                                }
                                SessionUpdate::Disconnected => {
                                    if let Some(session) = sessions_guard.get_mut(&peer_id) {
                                        session.session_state.state = ConnectionState::Closed;
                                        session.session_state.last_transition =
                                            std::time::Instant::now();
                                        // Drop the closed connection handle.
                                        session.session_state.connection = None;
                                        tracing::info!("Connection to {:?} closed", peer_id);
                                    }
                                }
                                SessionUpdate::UpdateMetrics => {
                                    if let Some(session) = sessions_guard.get_mut(&peer_id) {
                                        if let Some(ref conn) = session.session_state.connection {
                                            let stats = conn.stats();
                                            session.session_state.metrics.rtt =
                                                Some(stats.path.rtt);
                                            // `max(1)` avoids dividing by zero
                                            // before any packet was sent.
                                            session.session_state.metrics.loss_rate =
                                                stats.path.lost_packets as f64
                                                    / stats.path.sent_packets.max(1) as f64;
                                        }
                                    }
                                }
                                SessionUpdate::InvalidState => {
                                    if let Some(session) = sessions_guard.get_mut(&peer_id) {
                                        session.session_state.state = ConnectionState::Closed;
                                        session.session_state.last_transition =
                                            std::time::Instant::now();
                                        tracing::error!("Session {:?} in invalid state", peer_id);
                                    }
                                }
                                SessionUpdate::Retry => {
                                    if let Some(session) = sessions_guard.get_mut(&peer_id) {
                                        session.session_state.state = ConnectionState::Connecting;
                                        session.session_state.last_transition =
                                            std::time::Instant::now();
                                        session.attempt += 1;
                                        tracing::info!(
                                            "Retrying connection to {:?} (attempt {})",
                                            peer_id,
                                            session.attempt
                                        );
                                    }
                                }
                                SessionUpdate::MigrationTimeout => {
                                    if let Some(session) = sessions_guard.get_mut(&peer_id) {
                                        session.session_state.state = ConnectionState::Closed;
                                        session.session_state.last_transition =
                                            std::time::Instant::now();
                                        tracing::warn!("Migration timeout for {:?}", peer_id);
                                    }
                                }
                                SessionUpdate::Remove => {
                                    sessions_guard.remove(&peer_id);
                                    tracing::debug!("Removed old session for {:?}", peer_id);
                                }
                            }
                        }
                    }
                }
            }
        })
    }
1311
1312 pub fn get_statistics(&self) -> Result<NatTraversalStatistics, NatTraversalError> {
1316 let sessions = self
1317 .active_sessions
1318 .read()
1319 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1320 let bootstrap_nodes = self
1321 .bootstrap_nodes
1322 .read()
1323 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1324
1325 let avg_coordination_time = {
1327 let rtts: Vec<Duration> = bootstrap_nodes.iter().filter_map(|b| b.rtt).collect();
1328
1329 if rtts.is_empty() {
1330 Duration::from_millis(500) } else {
1332 let total_millis: u64 = rtts.iter().map(|d| d.as_millis() as u64).sum();
1333 Duration::from_millis(total_millis / rtts.len() as u64 * 2) }
1335 };
1336
1337 Ok(NatTraversalStatistics {
1338 active_sessions: sessions.len(),
1339 total_bootstrap_nodes: bootstrap_nodes.len(),
1340 successful_coordinations: bootstrap_nodes.iter().map(|b| b.coordination_count).sum(),
1341 average_coordination_time: avg_coordination_time,
1342 total_attempts: 0,
1343 successful_connections: 0,
1344 direct_connections: 0,
1345 relayed_connections: 0,
1346 })
1347 }
1348
1349 pub fn add_bootstrap_node(&self, address: SocketAddr) -> Result<(), NatTraversalError> {
1351 let mut bootstrap_nodes = self
1352 .bootstrap_nodes
1353 .write()
1354 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1355
1356 if !bootstrap_nodes.iter().any(|b| b.address == address) {
1358 bootstrap_nodes.push(BootstrapNode {
1359 address,
1360 last_seen: std::time::Instant::now(),
1361 can_coordinate: true,
1362 rtt: None,
1363 coordination_count: 0,
1364 });
1365 info!("Added bootstrap node: {}", address);
1366 }
1367 Ok(())
1368 }
1369
1370 pub fn remove_bootstrap_node(&self, address: SocketAddr) -> Result<(), NatTraversalError> {
1372 let mut bootstrap_nodes = self
1373 .bootstrap_nodes
1374 .write()
1375 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1376 bootstrap_nodes.retain(|b| b.address != address);
1377 info!("Removed bootstrap node: {}", address);
1378 Ok(())
1379 }
1380
    /// Creates the Quinn endpoint, its event channel and reports the address
    /// the socket was actually bound to.
    ///
    /// * For `Bootstrap` and `Server` roles a server configuration based on
    ///   raw public keys (RFC 7250) with a freshly generated Ed25519 key pair
    ///   is created; clients get no server configuration.
    /// * A client configuration based on raw public keys is always created
    ///   and installed as the endpoint's default client configuration.
    /// * The UDP socket is bound to `config.bind_addr`, or to `0.0.0.0:0`
    ///   when no address is configured.
    ///
    /// Returns `(endpoint, event_tx, event_rx, local_addr)`.
    ///
    /// # Errors
    ///
    /// * `ConfigError` if a crypto configuration, the async runtime lookup or
    ///   the endpoint construction fails.
    /// * `NetworkError` if the socket cannot be bound, converted or queried
    ///   for its local address.
    ///
    /// NOTE(review): `_nat_role` is currently unused; the role is read from
    /// `config.role` instead.
    async fn create_quinn_endpoint(
        config: &NatTraversalConfig,
        _nat_role: NatTraversalRole,
    ) -> Result<
        (
            QuinnEndpoint,
            mpsc::UnboundedSender<NatTraversalEvent>,
            mpsc::UnboundedReceiver<NatTraversalEvent>,
            SocketAddr,
        ),
        NatTraversalError,
    > {
        use std::sync::Arc;

        // Only roles that accept incoming connections get a server config.
        let server_config = match config.role {
            EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
                info!(
                    "Creating server config for role: {:?} using Raw Public Keys (RFC 7250)",
                    config.role
                );

                // A new key pair is generated on every call; the public half
                // is not used here.
                let (server_key, _public_key) =
                    crate::crypto::raw_public_keys::key_utils::generate_ed25519_keypair();

                // NOTE(review): `allow_any_key()` presumably disables pinning
                // of the peer's public key — confirm this is acceptable for
                // production use.
                let rpk_config = RawPublicKeyConfigBuilder::new()
                    .with_server_key(server_key)
                    .allow_any_key()
                    .build_rfc7250_server_config()
                    .map_err(|e| {
                        NatTraversalError::ConfigError(format!("RPK server config failed: {e}"))
                    })?;

                let server_crypto = QuicServerConfig::try_from(rpk_config.inner().as_ref().clone())
                    .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;

                let mut server_config = ServerConfig::with_crypto(Arc::new(server_crypto));

                let mut transport_config = TransportConfig::default();
                // Server side keep-alive follows the configured retry interval.
                transport_config
                    .keep_alive_interval(Some(config.timeouts.nat_traversal.retry_interval));
                // NOTE(review): 30000 is presumably milliseconds (30 s) —
                // confirm against the idle-timeout conversion.
                transport_config.max_idle_timeout(Some(crate::VarInt::from_u32(30000).into()));

                // NOTE(review): inside this arm the role is always `Bootstrap`
                // or `Server`, so the `Client` case below cannot be reached.
                let nat_config = match config.role {
                    EndpointRole::Client => {
                        crate::transport_parameters::NatTraversalConfig::ClientSupport
                    }
                    EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
                        crate::transport_parameters::NatTraversalConfig::ServerSupport {
                            concurrency_limit: VarInt::from_u32(
                                config.max_concurrent_attempts as u32,
                            ),
                        }
                    }
                };
                transport_config.nat_traversal_config(Some(nat_config));

                server_config.transport_config(Arc::new(transport_config));

                Some(server_config)
            }
            _ => None,
        };

        // Every role gets a client config so it can make outgoing connections.
        let client_config = {
            info!("Creating client config using Raw Public Keys (RFC 7250)");

            let rpk_config = RawPublicKeyConfigBuilder::new()
                .allow_any_key()
                .build_rfc7250_client_config()
                .map_err(|e| {
                    NatTraversalError::ConfigError(format!("RPK client config failed: {e}"))
                })?;

            let client_crypto = QuicClientConfig::try_from(rpk_config.inner().as_ref().clone())
                .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;

            let mut client_config = ClientConfig::new(Arc::new(client_crypto));

            let mut transport_config = TransportConfig::default();
            // NOTE(review): the client keep-alive is a fixed 5 seconds while
            // the server side uses the configured retry interval — confirm
            // the difference is intentional.
            transport_config.keep_alive_interval(Some(Duration::from_secs(5)));
            transport_config.max_idle_timeout(Some(crate::VarInt::from_u32(30000).into()));

            let nat_config = match config.role {
                EndpointRole::Client => {
                    crate::transport_parameters::NatTraversalConfig::ClientSupport
                }
                EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
                    crate::transport_parameters::NatTraversalConfig::ServerSupport {
                        concurrency_limit: VarInt::from_u32(config.max_concurrent_attempts as u32),
                    }
                }
            };
            transport_config.nat_traversal_config(Some(nat_config));

            client_config.transport_config(Arc::new(transport_config));

            client_config
        };

        // Fall back to the wildcard address with an OS-assigned port.
        let bind_addr = config
            .bind_addr
            .unwrap_or_else(create_random_port_bind_addr);
        let socket = UdpSocket::bind(bind_addr).await.map_err(|e| {
            NatTraversalError::NetworkError(format!("Failed to bind UDP socket: {e}"))
        })?;

        info!("Binding endpoint to {}", bind_addr);

        // The endpoint constructor takes a std socket, not a tokio one.
        let std_socket = socket.into_std().map_err(|e| {
            NatTraversalError::NetworkError(format!("Failed to convert socket: {e}"))
        })?;

        let runtime = default_runtime().ok_or_else(|| {
            NatTraversalError::ConfigError("No compatible async runtime found".to_string())
        })?;

        let mut endpoint = QuinnEndpoint::new(
            EndpointConfig::default(),
            server_config,
            std_socket,
            runtime,
        )
        .map_err(|e| {
            NatTraversalError::ConfigError(format!("Failed to create Quinn endpoint: {e}"))
        })?;

        endpoint.set_default_client_config(client_config);

        // With port 0 the real port is only known after binding.
        let local_addr = endpoint.local_addr().map_err(|e| {
            NatTraversalError::NetworkError(format!("Failed to get local address: {e}"))
        })?;

        info!("Endpoint bound to actual address: {}", local_addr);

        let (event_tx, event_rx) = mpsc::unbounded_channel();

        Ok((endpoint, event_tx, event_rx, local_addr))
    }
1543
1544 #[allow(clippy::panic)]
1546 pub async fn start_listening(&self, bind_addr: SocketAddr) -> Result<(), NatTraversalError> {
1547 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
1548 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
1549 })?;
1550
1551 let _socket = UdpSocket::bind(bind_addr).await.map_err(|e| {
1553 NatTraversalError::NetworkError(format!("Failed to bind to {bind_addr}: {e}"))
1554 })?;
1555
1556 info!("Started listening on {}", bind_addr);
1557
1558 let endpoint_clone = endpoint.clone();
1560 let shutdown_clone = self.shutdown.clone();
1561 let event_tx = self
1562 .event_tx
1563 .as_ref()
1564 .unwrap_or_else(|| panic!("event transmitter should be initialized"))
1565 .clone();
1566 let connections_clone = self.connections.clone();
1567 let emitted_events_clone = self.emitted_established_events.clone();
1568
1569 tokio::spawn(async move {
1570 Self::accept_connections(endpoint_clone, shutdown_clone, event_tx, connections_clone, emitted_events_clone)
1571 .await;
1572 });
1573
1574 Ok(())
1575 }
1576
1577 async fn accept_connections(
1579 endpoint: QuinnEndpoint,
1580 shutdown: Arc<AtomicBool>,
1581 event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1582 connections: Arc<std::sync::RwLock<HashMap<PeerId, QuinnConnection>>>,
1583 emitted_events: Arc<std::sync::RwLock<std::collections::HashSet<PeerId>>>,
1584 ) {
1585 while !shutdown.load(Ordering::Relaxed) {
1586 match endpoint.accept().await {
1587 Some(connecting) => {
1588 let event_tx = event_tx.clone();
1589 let connections = connections.clone();
1590 let emitted_events = emitted_events.clone();
1591 tokio::spawn(async move {
1592 match connecting.await {
1593 Ok(connection) => {
1594 info!("Accepted connection from {}", connection.remote_address());
1595
1596 let peer_id = Self::generate_peer_id_from_address(
1598 connection.remote_address(),
1599 );
1600
1601 if let Ok(mut conns) = connections.write() {
1603 conns.insert(peer_id, connection.clone());
1604 }
1605
1606 let should_emit = if let Ok(mut emitted) = emitted_events.write() {
1608 emitted.insert(peer_id) } else {
1610 true };
1612
1613 if should_emit {
1614 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
1615 peer_id,
1616 remote_address: connection.remote_address(),
1617 });
1618 }
1619
1620 Self::handle_connection(peer_id, connection, event_tx).await;
1622 }
1623 Err(e) => {
1624 debug!("Connection failed: {}", e);
1625 }
1626 }
1627 });
1628 }
1629 None => {
1630 break;
1632 }
1633 }
1634 }
1635 }
1636
1637 async fn poll_discovery(
1639 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
1640 shutdown: Arc<AtomicBool>,
1641 _event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1642 ) {
1643 use tokio::time::{Duration, interval};
1644
1645 let mut poll_interval = interval(Duration::from_millis(100));
1646
1647 while !shutdown.load(Ordering::Relaxed) {
1648 poll_interval.tick().await;
1649
1650 let events = match discovery_manager.lock() {
1652 Ok(mut discovery) => discovery.poll(std::time::Instant::now()),
1653 Err(e) => {
1654 error!("Failed to lock discovery manager: {}", e);
1655 continue;
1656 }
1657 };
1658
1659 for event in events {
1661 match event {
1662 DiscoveryEvent::DiscoveryStarted {
1663 peer_id,
1664 bootstrap_count,
1665 } => {
1666 debug!(
1667 "Discovery started for peer {:?} with {} bootstrap nodes",
1668 peer_id, bootstrap_count
1669 );
1670 }
1671 DiscoveryEvent::LocalScanningStarted => {
1672 debug!("Local interface scanning started");
1673 }
1674 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
1675 debug!("Discovered local candidate: {}", candidate.address);
1676 }
1679 DiscoveryEvent::LocalScanningCompleted {
1680 candidate_count,
1681 duration,
1682 } => {
1683 debug!(
1684 "Local interface scanning completed: {} candidates in {:?}",
1685 candidate_count, duration
1686 );
1687 }
1688 DiscoveryEvent::ServerReflexiveDiscoveryStarted { bootstrap_count } => {
1689 debug!(
1690 "Server reflexive discovery started with {} bootstrap nodes",
1691 bootstrap_count
1692 );
1693 }
1694 DiscoveryEvent::ServerReflexiveCandidateDiscovered {
1695 candidate,
1696 bootstrap_node,
1697 } => {
1698 debug!(
1699 "Discovered server-reflexive candidate {} via bootstrap {}",
1700 candidate.address, bootstrap_node
1701 );
1702 }
1704 DiscoveryEvent::BootstrapQueryFailed {
1705 bootstrap_node,
1706 error,
1707 } => {
1708 debug!("Bootstrap query failed for {}: {}", bootstrap_node, error);
1709 }
1710 DiscoveryEvent::PortAllocationDetected {
1712 port,
1713 source_address,
1714 bootstrap_node,
1715 timestamp,
1716 } => {
1717 debug!(
1718 "Port allocation detected: port {} from {} via bootstrap {:?} at {:?}",
1719 port, source_address, bootstrap_node, timestamp
1720 );
1721 }
1722 DiscoveryEvent::DiscoveryCompleted {
1723 candidate_count,
1724 total_duration,
1725 success_rate,
1726 } => {
1727 info!(
1728 "Discovery completed with {} candidates in {:?} (success rate: {:.2}%)",
1729 candidate_count,
1730 total_duration,
1731 success_rate * 100.0
1732 );
1733 }
1736 DiscoveryEvent::DiscoveryFailed {
1737 error,
1738 partial_results,
1739 } => {
1740 warn!(
1741 "Discovery failed: {} (found {} partial candidates)",
1742 error,
1743 partial_results.len()
1744 );
1745
1746 }
1751 DiscoveryEvent::PathValidationRequested {
1752 candidate_id,
1753 candidate_address,
1754 challenge_token,
1755 } => {
1756 debug!(
1757 "PATH_CHALLENGE requested for candidate {} at {} with token {:08x}",
1758 candidate_id.0, candidate_address, challenge_token
1759 );
1760 }
1763 DiscoveryEvent::PathValidationResponse {
1764 candidate_id,
1765 candidate_address,
1766 challenge_token: _,
1767 rtt,
1768 } => {
1769 debug!(
1770 "PATH_RESPONSE received for candidate {} at {} with RTT {:?}",
1771 candidate_id.0, candidate_address, rtt
1772 );
1773 }
1775 }
1776 }
1777 }
1778
1779 info!("Discovery polling task shutting down");
1780 }
1781
1782 async fn handle_connection(
1784 peer_id: PeerId,
1785 connection: QuinnConnection,
1786 event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1787 ) {
1788 let remote_address = connection.remote_address();
1789 let closed = connection.closed();
1790 tokio::pin!(closed);
1791
1792 debug!(
1793 "Handling connection from peer {:?} at {}",
1794 peer_id, remote_address
1795 );
1796
1797 closed.await;
1801
1802 let reason = connection
1803 .close_reason()
1804 .map(|reason| format!("Connection closed: {reason}"))
1805 .unwrap_or_else(|| "Connection closed".to_string());
1806 let _ = event_tx.send(NatTraversalEvent::ConnectionLost { peer_id, reason });
1807 }
1808
    /// Placeholder handler for an accepted bidirectional stream.
    ///
    /// The body is empty: both stream halves are taken by value and dropped
    /// once the future completes, without any data being read or written.
    async fn handle_bi_stream(
        _send: crate::high_level::SendStream,
        _recv: crate::high_level::RecvStream,
    ) {
    }
1843
1844 async fn handle_uni_stream(mut recv: crate::high_level::RecvStream) {
1846 let mut buffer = vec![0u8; 1024];
1847
1848 loop {
1849 match recv.read(&mut buffer).await {
1850 Ok(Some(size)) => {
1851 debug!("Received {} bytes on unidirectional stream", size);
1852 }
1854 Ok(None) => {
1855 debug!("Unidirectional stream closed by peer");
1856 break;
1857 }
1858 Err(e) => {
1859 debug!("Error reading from unidirectional stream: {}", e);
1860 break;
1861 }
1862 }
1863 }
1864 }
1865
1866 pub async fn connect_to_peer(
1868 &self,
1869 peer_id: PeerId,
1870 server_name: &str,
1871 remote_addr: SocketAddr,
1872 ) -> Result<QuinnConnection, NatTraversalError> {
1873 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
1874 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
1875 })?;
1876
1877 info!("Connecting to peer {:?} at {}", peer_id, remote_addr);
1878
1879 let connecting = endpoint.connect(remote_addr, server_name).map_err(|e| {
1881 NatTraversalError::ConnectionFailed(format!("Failed to initiate connection: {e}"))
1882 })?;
1883
1884 let connection = timeout(
1885 self.timeout_config
1886 .nat_traversal
1887 .connection_establishment_timeout,
1888 connecting,
1889 )
1890 .await
1891 .map_err(|_| NatTraversalError::Timeout)?
1892 .map_err(|e| NatTraversalError::ConnectionFailed(format!("Connection failed: {e}")))?;
1893
1894 info!(
1895 "Successfully connected to peer {:?} at {}",
1896 peer_id, remote_addr
1897 );
1898
1899 if let Some(ref event_tx) = self.event_tx {
1901 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
1902 peer_id,
1903 remote_address: remote_addr,
1904 });
1905 }
1906
1907 Ok(connection)
1908 }
1909
1910 pub async fn accept_connection(&self) -> Result<(PeerId, QuinnConnection), NatTraversalError> {
1912 info!("Waiting for incoming connection via event channel...");
1913
1914 let timeout_duration = self
1915 .timeout_config
1916 .nat_traversal
1917 .connection_establishment_timeout;
1918 let start = std::time::Instant::now();
1919
1920 loop {
1921 if self.shutdown.load(Ordering::Relaxed) {
1923 return Err(NatTraversalError::NetworkError(
1924 "Endpoint shutting down".to_string(),
1925 ));
1926 }
1927
1928 if start.elapsed() > timeout_duration {
1930 warn!("accept_connection() timed out after {:?}", timeout_duration);
1931 return Err(NatTraversalError::Timeout);
1932 }
1933
1934 {
1936 let mut event_rx = self.event_rx.lock().map_err(|_| {
1937 NatTraversalError::ProtocolError("Event channel lock poisoned".to_string())
1938 })?;
1939
1940 match event_rx.try_recv() {
1941 Ok(NatTraversalEvent::ConnectionEstablished {
1942 peer_id,
1943 remote_address,
1944 }) => {
1945 info!(
1946 "Received ConnectionEstablished event for peer {:?} at {}",
1947 peer_id, remote_address
1948 );
1949
1950 let connection = {
1953 let connections = self.connections.read().map_err(|_| {
1954 NatTraversalError::ProtocolError(
1955 "Connections lock poisoned".to_string(),
1956 )
1957 })?;
1958 connections.get(&peer_id).cloned().ok_or_else(|| {
1959 NatTraversalError::ConnectionFailed(format!(
1960 "Connection for peer {:?} not found in storage",
1961 peer_id
1962 ))
1963 })?
1964 };
1965
1966 info!(
1967 "Retrieved accepted connection from peer {:?} at {}",
1968 peer_id, remote_address
1969 );
1970 return Ok((peer_id, connection));
1971 }
1972 Ok(event) => {
1973 debug!(
1975 "Ignoring non-connection event while waiting for accept: {:?}",
1976 event
1977 );
1978 }
1979 Err(mpsc::error::TryRecvError::Empty) => {
1980 }
1982 Err(mpsc::error::TryRecvError::Disconnected) => {
1983 return Err(NatTraversalError::NetworkError(
1984 "Event channel closed".to_string(),
1985 ));
1986 }
1987 }
1988 } tokio::time::sleep(Duration::from_millis(10)).await;
1992 }
1993 }
1994
    /// Returns this endpoint's own peer ID (the value of the `local_peer_id` field).
    pub fn local_peer_id(&self) -> PeerId {
        self.local_peer_id
    }
1999
2000 pub fn get_connection(
2002 &self,
2003 peer_id: &PeerId,
2004 ) -> Result<Option<QuinnConnection>, NatTraversalError> {
2005 let connections = self.connections.read().map_err(|_| {
2006 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2007 })?;
2008 Ok(connections.get(peer_id).cloned())
2009 }
2010
2011 pub fn add_connection(
2013 &self,
2014 peer_id: PeerId,
2015 connection: QuinnConnection,
2016 ) -> Result<(), NatTraversalError> {
2017 let mut connections = self.connections.write().map_err(|_| {
2018 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2019 })?;
2020 connections.insert(peer_id, connection);
2021 Ok(())
2022 }
2023
2024 pub fn spawn_connection_handler(
2026 &self,
2027 peer_id: PeerId,
2028 connection: QuinnConnection,
2029 ) -> Result<(), NatTraversalError> {
2030 let event_tx = self.event_tx.as_ref().cloned().ok_or_else(|| {
2031 NatTraversalError::ConfigError("NAT traversal event channel not configured".to_string())
2032 })?;
2033
2034 let remote_address = connection.remote_address();
2035
2036 let should_emit = if let Ok(mut emitted) = self.emitted_established_events.write() {
2038 emitted.insert(peer_id) } else {
2040 true };
2042
2043 if should_emit {
2044 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
2045 peer_id,
2046 remote_address,
2047 });
2048 }
2049
2050 tokio::spawn(async move {
2052 Self::handle_connection(peer_id, connection, event_tx).await;
2053 });
2054
2055 Ok(())
2056 }
2057
2058 pub fn remove_connection(
2060 &self,
2061 peer_id: &PeerId,
2062 ) -> Result<Option<QuinnConnection>, NatTraversalError> {
2063 if let Ok(mut emitted) = self.emitted_established_events.write() {
2065 emitted.remove(peer_id);
2066 }
2067
2068 let mut connections = self.connections.write().map_err(|_| {
2069 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2070 })?;
2071 Ok(connections.remove(peer_id))
2072 }
2073
2074 pub fn list_connections(&self) -> Result<Vec<(PeerId, SocketAddr)>, NatTraversalError> {
2076 let connections = self.connections.read().map_err(|_| {
2077 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2078 })?;
2079 let mut result = Vec::new();
2080 for (peer_id, connection) in connections.iter() {
2081 result.push((*peer_id, connection.remote_address()));
2082 }
2083 Ok(result)
2084 }
2085
2086 pub async fn handle_connection_data(
2088 &self,
2089 peer_id: PeerId,
2090 connection: &QuinnConnection,
2091 ) -> Result<(), NatTraversalError> {
2092 info!("Handling connection data from peer {:?}", peer_id);
2093
2094 let connection_clone = connection.clone();
2096 let peer_id_clone = peer_id;
2097 tokio::spawn(async move {
2098 loop {
2099 match connection_clone.accept_bi().await {
2100 Ok((send, recv)) => {
2101 debug!(
2102 "Accepted bidirectional stream from peer {:?}",
2103 peer_id_clone
2104 );
2105 tokio::spawn(Self::handle_bi_stream(send, recv));
2106 }
2107 Err(ConnectionError::ApplicationClosed(_)) => {
2108 debug!("Connection closed by peer {:?}", peer_id_clone);
2109 break;
2110 }
2111 Err(e) => {
2112 debug!(
2113 "Error accepting bidirectional stream from peer {:?}: {}",
2114 peer_id_clone, e
2115 );
2116 break;
2117 }
2118 }
2119 }
2120 });
2121
2122 let connection_clone = connection.clone();
2124 let peer_id_clone = peer_id;
2125 tokio::spawn(async move {
2126 loop {
2127 match connection_clone.accept_uni().await {
2128 Ok(recv) => {
2129 debug!(
2130 "Accepted unidirectional stream from peer {:?}",
2131 peer_id_clone
2132 );
2133 tokio::spawn(Self::handle_uni_stream(recv));
2134 }
2135 Err(ConnectionError::ApplicationClosed(_)) => {
2136 debug!("Connection closed by peer {:?}", peer_id_clone);
2137 break;
2138 }
2139 Err(e) => {
2140 debug!(
2141 "Error accepting unidirectional stream from peer {:?}: {}",
2142 peer_id_clone, e
2143 );
2144 break;
2145 }
2146 }
2147 }
2148 });
2149
2150 Ok(())
2151 }
2152
2153 fn generate_local_peer_id() -> PeerId {
2155 use std::collections::hash_map::DefaultHasher;
2156 use std::hash::{Hash, Hasher};
2157 use std::time::SystemTime;
2158
2159 let mut hasher = DefaultHasher::new();
2160 SystemTime::now().hash(&mut hasher);
2161 std::process::id().hash(&mut hasher);
2162
2163 let hash = hasher.finish();
2164 let mut peer_id = [0u8; 32];
2165 peer_id[0..8].copy_from_slice(&hash.to_be_bytes());
2166
2167 for i in 8..32 {
2169 peer_id[i] = rand::random();
2170 }
2171
2172 PeerId(peer_id)
2173 }
2174
2175 fn generate_peer_id_from_address(addr: SocketAddr) -> PeerId {
2181 use std::collections::hash_map::DefaultHasher;
2182 use std::hash::{Hash, Hasher};
2183
2184 let mut hasher = DefaultHasher::new();
2185 addr.hash(&mut hasher);
2186
2187 let hash = hasher.finish();
2188 let mut peer_id = [0u8; 32];
2189 peer_id[0..8].copy_from_slice(&hash.to_be_bytes());
2190
2191 for i in 8..32 {
2194 peer_id[i] = rand::random();
2195 }
2196
2197 warn!(
2198 "Generated temporary peer ID from address {}. This ID is not persistent!",
2199 addr
2200 );
2201 PeerId(peer_id)
2202 }
2203
2204 pub async fn extract_peer_id_from_connection(
2206 &self,
2207 connection: &QuinnConnection,
2208 ) -> Option<PeerId> {
2209 if let Some(identity) = connection.peer_identity() {
2211 if let Some(public_key_bytes) = identity.downcast_ref::<[u8; 32]>() {
2213 match crate::derive_peer_id_from_key_bytes(public_key_bytes) {
2215 Ok(peer_id) => {
2216 debug!("Derived peer ID from Ed25519 public key");
2217 return Some(peer_id);
2218 }
2219 Err(e) => {
2220 warn!("Failed to derive peer ID from public key: {}", e);
2221 }
2222 }
2223 }
2224 }
2226
2227 None
2228 }
2229
2230 pub async fn shutdown(&self) -> Result<(), NatTraversalError> {
2232 self.shutdown.store(true, Ordering::Relaxed);
2234
2235 {
2237 let mut connections = self.connections.write().map_err(|_| {
2238 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2239 })?;
2240 for (peer_id, connection) in connections.drain() {
2241 info!("Closing connection to peer {:?}", peer_id);
2242 connection.close(crate::VarInt::from_u32(0), b"Shutdown");
2243 }
2244 }
2245
2246 if let Some(ref endpoint) = self.quinn_endpoint {
2248 endpoint.wait_idle().await;
2249 }
2250
2251 info!("NAT traversal endpoint shutdown completed");
2252 Ok(())
2253 }
2254
2255 pub async fn discover_candidates(
2257 &self,
2258 peer_id: PeerId,
2259 ) -> Result<Vec<CandidateAddress>, NatTraversalError> {
2260 debug!("Discovering address candidates for peer {:?}", peer_id);
2261
2262 let mut candidates = Vec::new();
2263
2264 let bootstrap_nodes = {
2266 let nodes = self
2267 .bootstrap_nodes
2268 .read()
2269 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
2270 nodes.clone()
2271 };
2272
2273 {
2275 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2276 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2277 })?;
2278
2279 discovery
2280 .start_discovery(peer_id, bootstrap_nodes)
2281 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
2282 }
2283
2284 let timeout_duration = self.config.coordination_timeout;
2286 let start_time = std::time::Instant::now();
2287
2288 while start_time.elapsed() < timeout_duration {
2289 let discovery_events = {
2290 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2291 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2292 })?;
2293 discovery.poll(std::time::Instant::now())
2294 };
2295
2296 for event in discovery_events {
2297 match event {
2298 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
2299 candidates.push(candidate.clone());
2300
2301 self.send_candidate_advertisement(peer_id, &candidate)
2303 .await
2304 .unwrap_or_else(|e| {
2305 debug!("Failed to send candidate advertisement: {}", e)
2306 });
2307 }
2308 DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. } => {
2309 candidates.push(candidate.clone());
2310
2311 self.send_candidate_advertisement(peer_id, &candidate)
2313 .await
2314 .unwrap_or_else(|e| {
2315 debug!("Failed to send candidate advertisement: {}", e)
2316 });
2317 }
2318 DiscoveryEvent::DiscoveryCompleted { .. } => {
2320 return Ok(candidates);
2322 }
2323 DiscoveryEvent::DiscoveryFailed {
2324 error,
2325 partial_results,
2326 } => {
2327 candidates.extend(partial_results);
2329 if candidates.is_empty() {
2330 return Err(NatTraversalError::CandidateDiscoveryFailed(
2331 error.to_string(),
2332 ));
2333 }
2334 return Ok(candidates);
2335 }
2336 _ => {}
2337 }
2338 }
2339
2340 sleep(Duration::from_millis(10)).await;
2342 }
2343
2344 if candidates.is_empty() {
2345 Err(NatTraversalError::NoCandidatesFound)
2346 } else {
2347 Ok(candidates)
2348 }
2349 }
2350
2351 #[allow(dead_code)]
2353 fn create_punch_me_now_frame(&self, peer_id: PeerId) -> Result<Vec<u8>, NatTraversalError> {
2354 let mut frame = Vec::new();
2362
2363 frame.push(0x41);
2365
2366 frame.extend_from_slice(&peer_id.0);
2368
2369 let timestamp = std::time::SystemTime::now()
2371 .duration_since(std::time::UNIX_EPOCH)
2372 .unwrap_or_default()
2373 .as_millis() as u64;
2374 frame.extend_from_slice(×tamp.to_be_bytes());
2375
2376 let mut token = [0u8; 16];
2378 for byte in &mut token {
2379 *byte = rand::random();
2380 }
2381 frame.extend_from_slice(&token);
2382
2383 Ok(frame)
2384 }
2385
2386 #[allow(dead_code)]
2387 fn attempt_hole_punching(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
2388 debug!("Attempting hole punching for peer {:?}", peer_id);
2389
2390 let candidate_pairs = self.get_candidate_pairs_for_peer(peer_id)?;
2392
2393 if candidate_pairs.is_empty() {
2394 return Err(NatTraversalError::NoCandidatesFound);
2395 }
2396
2397 info!(
2398 "Generated {} candidate pairs for hole punching with peer {:?}",
2399 candidate_pairs.len(),
2400 peer_id
2401 );
2402
2403 self.attempt_quinn_hole_punching(peer_id, candidate_pairs)
2406 }
2407
2408 #[allow(dead_code)]
2410 fn get_candidate_pairs_for_peer(
2411 &self,
2412 peer_id: PeerId,
2413 ) -> Result<Vec<CandidatePair>, NatTraversalError> {
2414 let discovery_candidates = {
2416 let discovery = self.discovery_manager.lock().map_err(|_| {
2417 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2418 })?;
2419
2420 discovery.get_candidates_for_peer(peer_id)
2421 };
2422
2423 if discovery_candidates.is_empty() {
2424 return Err(NatTraversalError::NoCandidatesFound);
2425 }
2426
2427 let mut candidate_pairs = Vec::new();
2429 let local_candidates = discovery_candidates
2430 .iter()
2431 .filter(|c| matches!(c.source, CandidateSource::Local))
2432 .collect::<Vec<_>>();
2433 let remote_candidates = discovery_candidates
2434 .iter()
2435 .filter(|c| !matches!(c.source, CandidateSource::Local))
2436 .collect::<Vec<_>>();
2437
2438 for local in &local_candidates {
2440 for remote in &remote_candidates {
2441 let pair_priority = self.calculate_candidate_pair_priority(local, remote);
2442 candidate_pairs.push(CandidatePair {
2443 local_candidate: (*local).clone(),
2444 remote_candidate: (*remote).clone(),
2445 priority: pair_priority,
2446 state: CandidatePairState::Waiting,
2447 });
2448 }
2449 }
2450
2451 candidate_pairs.sort_by(|a, b| b.priority.cmp(&a.priority));
2453
2454 candidate_pairs.truncate(8);
2456
2457 Ok(candidate_pairs)
2458 }
2459
2460 #[allow(dead_code)]
2462 fn calculate_candidate_pair_priority(
2463 &self,
2464 local: &CandidateAddress,
2465 remote: &CandidateAddress,
2466 ) -> u64 {
2467 let local_type_preference = match local.source {
2471 CandidateSource::Local => 126,
2472 CandidateSource::Observed { .. } => 100,
2473 CandidateSource::Predicted => 75,
2474 CandidateSource::Peer => 50,
2475 };
2476
2477 let remote_type_preference = match remote.source {
2478 CandidateSource::Local => 126,
2479 CandidateSource::Observed { .. } => 100,
2480 CandidateSource::Predicted => 75,
2481 CandidateSource::Peer => 50,
2482 };
2483
2484 let local_priority = (local_type_preference as u64) << 8 | local.priority as u64;
2486 let remote_priority = (remote_type_preference as u64) << 8 | remote.priority as u64;
2487
2488 let min_priority = local_priority.min(remote_priority);
2489 let max_priority = local_priority.max(remote_priority);
2490
2491 (min_priority << 32)
2492 | (max_priority << 1)
2493 | if local_priority > remote_priority {
2494 1
2495 } else {
2496 0
2497 }
2498 }
2499
2500 #[allow(dead_code)]
2502 fn attempt_quinn_hole_punching(
2503 &self,
2504 peer_id: PeerId,
2505 candidate_pairs: Vec<CandidatePair>,
2506 ) -> Result<(), NatTraversalError> {
2507 let _endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
2508 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
2509 })?;
2510
2511 for pair in candidate_pairs {
2512 debug!(
2513 "Attempting hole punch with candidate pair: {} -> {}",
2514 pair.local_candidate.address, pair.remote_candidate.address
2515 );
2516
2517 let mut challenge_data = [0u8; 8];
2519 for byte in &mut challenge_data {
2520 *byte = rand::random();
2521 }
2522
2523 let local_socket =
2525 std::net::UdpSocket::bind(pair.local_candidate.address).map_err(|e| {
2526 NatTraversalError::NetworkError(format!(
2527 "Failed to bind to local candidate: {e}"
2528 ))
2529 })?;
2530
2531 let path_challenge_packet = self.create_path_challenge_packet(challenge_data)?;
2533
2534 match local_socket.send_to(&path_challenge_packet, pair.remote_candidate.address) {
2536 Ok(bytes_sent) => {
2537 debug!(
2538 "Sent {} bytes for hole punch from {} to {}",
2539 bytes_sent, pair.local_candidate.address, pair.remote_candidate.address
2540 );
2541
2542 local_socket
2544 .set_read_timeout(Some(Duration::from_millis(100)))
2545 .map_err(|e| {
2546 NatTraversalError::NetworkError(format!("Failed to set timeout: {e}"))
2547 })?;
2548
2549 let mut response_buffer = [0u8; 1024];
2551 match local_socket.recv_from(&mut response_buffer) {
2552 Ok((_bytes_received, response_addr)) => {
2553 if response_addr == pair.remote_candidate.address {
2554 info!(
2555 "Hole punch succeeded for peer {:?}: {} <-> {}",
2556 peer_id,
2557 pair.local_candidate.address,
2558 pair.remote_candidate.address
2559 );
2560
2561 self.store_successful_candidate_pair(peer_id, pair)?;
2563 return Ok(());
2564 } else {
2565 debug!(
2566 "Received response from unexpected address: {}",
2567 response_addr
2568 );
2569 }
2570 }
2571 Err(e)
2572 if e.kind() == std::io::ErrorKind::WouldBlock
2573 || e.kind() == std::io::ErrorKind::TimedOut =>
2574 {
2575 debug!("No response received for hole punch attempt");
2576 }
2577 Err(e) => {
2578 debug!("Error receiving hole punch response: {}", e);
2579 }
2580 }
2581 }
2582 Err(e) => {
2583 debug!("Failed to send hole punch packet: {}", e);
2584 }
2585 }
2586 }
2587
2588 Err(NatTraversalError::HolePunchingFailed)
2590 }
2591
2592 fn create_path_challenge_packet(
2594 &self,
2595 challenge_data: [u8; 8],
2596 ) -> Result<Vec<u8>, NatTraversalError> {
2597 let mut packet = Vec::new();
2600
2601 packet.push(0x40); packet.extend_from_slice(&[0, 0, 0, 1]); packet.push(0x1a); packet.extend_from_slice(&challenge_data); Ok(packet)
2610 }
2611
2612 fn store_successful_candidate_pair(
2614 &self,
2615 peer_id: PeerId,
2616 pair: CandidatePair,
2617 ) -> Result<(), NatTraversalError> {
2618 debug!(
2619 "Storing successful candidate pair for peer {:?}: {} <-> {}",
2620 peer_id, pair.local_candidate.address, pair.remote_candidate.address
2621 );
2622
2623 if let Some(ref callback) = self.event_callback {
2628 callback(NatTraversalEvent::PathValidated {
2629 peer_id,
2630 address: pair.remote_candidate.address,
2631 rtt: Duration::from_millis(50), });
2633
2634 callback(NatTraversalEvent::TraversalSucceeded {
2635 peer_id,
2636 final_address: pair.remote_candidate.address,
2637 total_time: Duration::from_secs(1), });
2639 }
2640
2641 Ok(())
2642 }
2643
2644 fn attempt_connection_to_candidate(
2646 &self,
2647 peer_id: PeerId,
2648 candidate: &CandidateAddress,
2649 ) -> Result<(), NatTraversalError> {
2650 {
2651 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
2652 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
2653 })?;
2654
2655 let server_name = format!("peer-{:x}", peer_id.0[0] as u32);
2657
2658 debug!(
2659 "Attempting Quinn connection to candidate {} for peer {:?}",
2660 candidate.address, peer_id
2661 );
2662
2663 match endpoint.connect(candidate.address, &server_name) {
2665 Ok(connecting) => {
2666 info!(
2667 "Connection attempt initiated to {} for peer {:?}",
2668 candidate.address, peer_id
2669 );
2670
2671 if let Some(event_tx) = &self.event_tx {
2673 let event_tx = event_tx.clone();
2674 let connections = self.connections.clone();
2675 let peer_id_clone = peer_id;
2676 let address = candidate.address;
2677
2678 tokio::spawn(async move {
2679 match connecting.await {
2680 Ok(connection) => {
2681 info!(
2682 "Successfully connected to {} for peer {:?}",
2683 address, peer_id_clone
2684 );
2685
2686 if let Ok(mut conns) = connections.write() {
2688 conns.insert(peer_id_clone, connection.clone());
2689 }
2690
2691 let _ =
2693 event_tx.send(NatTraversalEvent::ConnectionEstablished {
2694 peer_id: peer_id_clone,
2695 remote_address: address,
2696 });
2697
2698 Self::handle_connection(peer_id_clone, connection, event_tx)
2700 .await;
2701 }
2702 Err(e) => {
2703 warn!("Connection to {} failed: {}", address, e);
2704 }
2705 }
2706 });
2707 }
2708
2709 Ok(())
2710 }
2711 Err(e) => {
2712 warn!(
2713 "Failed to initiate connection to {}: {}",
2714 candidate.address, e
2715 );
2716 Err(NatTraversalError::ConnectionFailed(format!(
2717 "Failed to connect to {}: {}",
2718 candidate.address, e
2719 )))
2720 }
2721 }
2722 }
2723 }
2724
2725 pub fn poll(
2727 &self,
2728 now: std::time::Instant,
2729 ) -> Result<Vec<NatTraversalEvent>, NatTraversalError> {
2730 let mut events = Vec::new();
2731
2732 {
2734 let mut event_rx = self.event_rx.lock().map_err(|_| {
2735 NatTraversalError::ProtocolError("Event channel lock poisoned".to_string())
2736 })?;
2737
2738 loop {
2739 match event_rx.try_recv() {
2740 Ok(event) => {
2741 if let Some(ref callback) = self.event_callback {
2742 callback(event.clone());
2743 }
2744 events.push(event);
2745 }
2746 Err(TryRecvError::Empty) => break,
2747 Err(TryRecvError::Disconnected) => break,
2748 }
2749 }
2750 }
2751
2752 let mut closed_connections = Vec::new();
2754 {
2755 let connections = self.connections.read().map_err(|_| {
2756 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2757 })?;
2758
2759 for (peer_id, connection) in connections.iter() {
2760 if let Some(reason) = connection.close_reason() {
2761 closed_connections.push((*peer_id, reason.clone()));
2762 }
2763 }
2764 }
2765
2766 if !closed_connections.is_empty() {
2767 let mut connections = self.connections.write().map_err(|_| {
2768 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2769 })?;
2770
2771 for (peer_id, reason) in closed_connections {
2772 connections.remove(&peer_id);
2773 let event = NatTraversalEvent::ConnectionLost {
2774 peer_id,
2775 reason: reason.to_string(),
2776 };
2777 if let Some(ref callback) = self.event_callback {
2778 callback(event.clone());
2779 }
2780 events.push(event);
2781 }
2782 }
2783
2784 self.check_connections_for_observed_addresses(&mut events)?;
2786
2787 {
2789 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2790 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2791 })?;
2792
2793 let discovery_events = discovery.poll(now);
2794
2795 for discovery_event in discovery_events {
2797 if let Some(nat_event) = self.convert_discovery_event(discovery_event) {
2798 events.push(nat_event.clone());
2799
2800 if let Some(ref callback) = self.event_callback {
2802 callback(nat_event.clone());
2803 }
2804
2805 if let NatTraversalEvent::CandidateDiscovered {
2807 peer_id: _,
2808 candidate: _,
2809 } = &nat_event
2810 {
2811 }
2814 }
2815 }
2816 }
2817
2818 let mut sessions = self
2820 .active_sessions
2821 .write()
2822 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
2823
2824 for (_peer_id, session) in sessions.iter_mut() {
2825 let elapsed = now.duration_since(session.started_at);
2826
2827 let timeout = self.get_phase_timeout(session.phase);
2829
2830 if elapsed > timeout {
2832 match session.phase {
2833 TraversalPhase::Discovery => {
2834 let discovered_candidates = {
2836 let discovery = self.discovery_manager.lock().map_err(|_| {
2837 NatTraversalError::ProtocolError(
2838 "Discovery manager lock poisoned".to_string(),
2839 )
2840 });
2841 match discovery {
2842 Ok(disc) => disc.get_candidates_for_peer(session.peer_id),
2843 Err(_) => Vec::new(),
2844 }
2845 };
2846
2847 session.candidates = discovered_candidates.clone();
2849
2850 if !session.candidates.is_empty() {
2852 session.phase = TraversalPhase::Coordination;
2854 let event = NatTraversalEvent::PhaseTransition {
2855 peer_id: session.peer_id,
2856 from_phase: TraversalPhase::Discovery,
2857 to_phase: TraversalPhase::Coordination,
2858 };
2859 events.push(event.clone());
2860 if let Some(ref callback) = self.event_callback {
2861 callback(event);
2862 }
2863 info!(
2864 "Peer {:?} advanced from Discovery to Coordination with {} candidates",
2865 session.peer_id,
2866 session.candidates.len()
2867 );
2868 } else if session.attempt < self.config.max_concurrent_attempts as u32 {
2869 session.attempt += 1;
2871 session.started_at = now;
2872 let backoff_duration = self.calculate_backoff(session.attempt);
2873 warn!(
2874 "Discovery timeout for peer {:?}, retrying (attempt {}), backoff: {:?}",
2875 session.peer_id, session.attempt, backoff_duration
2876 );
2877 } else {
2878 session.phase = TraversalPhase::Failed;
2880 let event = NatTraversalEvent::TraversalFailed {
2881 peer_id: session.peer_id,
2882 error: NatTraversalError::NoCandidatesFound,
2883 fallback_available: self.config.enable_relay_fallback,
2884 };
2885 events.push(event.clone());
2886 if let Some(ref callback) = self.event_callback {
2887 callback(event);
2888 }
2889 error!(
2890 "NAT traversal failed for peer {:?}: no candidates found after {} attempts",
2891 session.peer_id, session.attempt
2892 );
2893 }
2894 }
2895 TraversalPhase::Coordination => {
2896 if let Some(coordinator) = self.select_coordinator() {
2898 match self.send_coordination_request(session.peer_id, coordinator) {
2899 Ok(_) => {
2900 session.phase = TraversalPhase::Synchronization;
2901 let event = NatTraversalEvent::CoordinationRequested {
2902 peer_id: session.peer_id,
2903 coordinator,
2904 };
2905 events.push(event.clone());
2906 if let Some(ref callback) = self.event_callback {
2907 callback(event);
2908 }
2909 info!(
2910 "Coordination requested for peer {:?} via {}",
2911 session.peer_id, coordinator
2912 );
2913 }
2914 Err(e) => {
2915 self.handle_phase_failure(session, now, &mut events, e);
2916 }
2917 }
2918 } else {
2919 self.handle_phase_failure(
2920 session,
2921 now,
2922 &mut events,
2923 NatTraversalError::NoBootstrapNodes,
2924 );
2925 }
2926 }
2927 TraversalPhase::Synchronization => {
2928 if self.is_peer_synchronized(&session.peer_id) {
2930 session.phase = TraversalPhase::Punching;
2931 let event = NatTraversalEvent::HolePunchingStarted {
2932 peer_id: session.peer_id,
2933 targets: session.candidates.iter().map(|c| c.address).collect(),
2934 };
2935 events.push(event.clone());
2936 if let Some(ref callback) = self.event_callback {
2937 callback(event);
2938 }
2939 if let Err(e) =
2941 self.initiate_hole_punching(session.peer_id, &session.candidates)
2942 {
2943 self.handle_phase_failure(session, now, &mut events, e);
2944 }
2945 } else {
2946 self.handle_phase_failure(
2947 session,
2948 now,
2949 &mut events,
2950 NatTraversalError::ProtocolError(
2951 "Synchronization timeout".to_string(),
2952 ),
2953 );
2954 }
2955 }
2956 TraversalPhase::Punching => {
2957 if let Some(successful_path) = self.check_punch_results(&session.peer_id) {
2959 session.phase = TraversalPhase::Validation;
2960 let event = NatTraversalEvent::PathValidated {
2961 peer_id: session.peer_id,
2962 address: successful_path,
2963 rtt: Duration::from_millis(50), };
2965 events.push(event.clone());
2966 if let Some(ref callback) = self.event_callback {
2967 callback(event);
2968 }
2969 if let Err(e) = self.validate_path(session.peer_id, successful_path) {
2971 self.handle_phase_failure(session, now, &mut events, e);
2972 }
2973 } else {
2974 self.handle_phase_failure(
2975 session,
2976 now,
2977 &mut events,
2978 NatTraversalError::PunchingFailed(
2979 "No successful punch".to_string(),
2980 ),
2981 );
2982 }
2983 }
2984 TraversalPhase::Validation => {
2985 if self.is_path_validated(&session.peer_id) {
2987 session.phase = TraversalPhase::Connected;
2988 let event = NatTraversalEvent::TraversalSucceeded {
2989 peer_id: session.peer_id,
2990 final_address: session
2991 .candidates
2992 .first()
2993 .map(|c| c.address)
2994 .unwrap_or_else(create_random_port_bind_addr),
2995 total_time: elapsed,
2996 };
2997 events.push(event.clone());
2998 if let Some(ref callback) = self.event_callback {
2999 callback(event);
3000 }
3001 info!(
3002 "NAT traversal succeeded for peer {:?} in {:?}",
3003 session.peer_id, elapsed
3004 );
3005 } else {
3006 self.handle_phase_failure(
3007 session,
3008 now,
3009 &mut events,
3010 NatTraversalError::ValidationFailed(
3011 "Path validation timeout".to_string(),
3012 ),
3013 );
3014 }
3015 }
3016 TraversalPhase::Connected => {
3017 if !self.is_connection_healthy(&session.peer_id) {
3019 warn!(
3020 "Connection to peer {:?} is no longer healthy",
3021 session.peer_id
3022 );
3023 }
3025 }
3026 TraversalPhase::Failed => {
3027 }
3029 }
3030 }
3031 }
3032
3033 Ok(events)
3034 }
3035
3036 fn get_phase_timeout(&self, phase: TraversalPhase) -> Duration {
3038 match phase {
3039 TraversalPhase::Discovery => Duration::from_secs(10),
3040 TraversalPhase::Coordination => self.config.coordination_timeout,
3041 TraversalPhase::Synchronization => Duration::from_secs(3),
3042 TraversalPhase::Punching => Duration::from_secs(5),
3043 TraversalPhase::Validation => Duration::from_secs(5),
3044 TraversalPhase::Connected => Duration::from_secs(30), TraversalPhase::Failed => Duration::ZERO,
3046 }
3047 }
3048
3049 fn calculate_backoff(&self, attempt: u32) -> Duration {
3051 let base = Duration::from_millis(1000);
3052 let max = Duration::from_secs(30);
3053 let backoff = base * 2u32.pow(attempt.saturating_sub(1));
3054 let jitter = std::time::Duration::from_millis((rand::random::<u64>() % 200) as u64);
3055 backoff.min(max) + jitter
3056 }
3057
3058 fn check_connections_for_observed_addresses(
3060 &self,
3061 _events: &mut Vec<NatTraversalEvent>,
3062 ) -> Result<(), NatTraversalError> {
3063 let connections = self.connections.read().map_err(|_| {
3065 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3066 })?;
3067
3068 if !connections.is_empty() && self.config.role == EndpointRole::Client {
3075 for (_peer_id, connection) in connections.iter() {
3077 let remote_addr = connection.remote_address();
3078
3079 let is_bootstrap = {
3081 let bootstrap_nodes = self.bootstrap_nodes.read().map_err(|_| {
3082 NatTraversalError::ProtocolError(
3083 "Bootstrap nodes lock poisoned".to_string(),
3084 )
3085 })?;
3086 bootstrap_nodes
3087 .iter()
3088 .any(|node| node.address == remote_addr)
3089 };
3090
3091 if is_bootstrap {
3092 debug!(
3095 "Bootstrap connection to {} should provide our external address via OBSERVED_ADDRESS frames",
3096 remote_addr
3097 );
3098
3099 }
3102 }
3103 }
3104
3105 Ok(())
3106 }
3107
3108 fn handle_phase_failure(
3110 &self,
3111 session: &mut NatTraversalSession,
3112 now: std::time::Instant,
3113 events: &mut Vec<NatTraversalEvent>,
3114 error: NatTraversalError,
3115 ) {
3116 if session.attempt < self.config.max_concurrent_attempts as u32 {
3117 session.attempt += 1;
3119 session.started_at = now;
3120 let backoff = self.calculate_backoff(session.attempt);
3121 warn!(
3122 "Phase {:?} failed for peer {:?}: {:?}, retrying (attempt {}) after {:?}",
3123 session.phase, session.peer_id, error, session.attempt, backoff
3124 );
3125 } else {
3126 session.phase = TraversalPhase::Failed;
3128 let event = NatTraversalEvent::TraversalFailed {
3129 peer_id: session.peer_id,
3130 error,
3131 fallback_available: self.config.enable_relay_fallback,
3132 };
3133 events.push(event.clone());
3134 if let Some(ref callback) = self.event_callback {
3135 callback(event);
3136 }
3137 error!(
3138 "NAT traversal failed for peer {:?} after {} attempts",
3139 session.peer_id, session.attempt
3140 );
3141 }
3142 }
3143
3144 fn select_coordinator(&self) -> Option<SocketAddr> {
3146 if let Ok(nodes) = self.bootstrap_nodes.read() {
3147 if !nodes.is_empty() {
3149 let idx = rand::random::<usize>() % nodes.len();
3150 return Some(nodes[idx].address);
3151 }
3152 }
3153 None
3154 }
3155
3156 fn send_coordination_request(
3158 &self,
3159 peer_id: PeerId,
3160 coordinator: SocketAddr,
3161 ) -> Result<(), NatTraversalError> {
3162 debug!(
3163 "Sending coordination request for peer {:?} to {}",
3164 peer_id, coordinator
3165 );
3166
3167 {
3168 if let Ok(connections) = self.connections.read() {
3170 for (_peer, conn) in connections.iter() {
3172 if conn.remote_address() == coordinator {
3173 info!("Found existing connection to coordinator {}", coordinator);
3177 return Ok(());
3178 }
3179 }
3180 }
3181
3182 info!("Establishing connection to coordinator {}", coordinator);
3184 if let Some(endpoint) = &self.quinn_endpoint {
3185 let server_name = format!("bootstrap-{}", coordinator.ip());
3186 match endpoint.connect(coordinator, &server_name) {
3187 Ok(connecting) => {
3188 info!("Initiated connection to coordinator {}", coordinator);
3190
3191 if let Some(event_tx) = &self.event_tx {
3193 let event_tx = event_tx.clone();
3194 let connections = self.connections.clone();
3195 let peer_id_clone = peer_id;
3196
3197 tokio::spawn(async move {
3198 match connecting.await {
3199 Ok(connection) => {
3200 info!("Connected to coordinator {}", coordinator);
3201
3202 let bootstrap_peer_id =
3204 Self::generate_peer_id_from_address(coordinator);
3205
3206 if let Ok(mut conns) = connections.write() {
3208 conns.insert(bootstrap_peer_id, connection.clone());
3209 }
3210
3211 Self::handle_connection(
3213 peer_id_clone,
3214 connection,
3215 event_tx,
3216 )
3217 .await;
3218 }
3219 Err(e) => {
3220 warn!(
3221 "Failed to connect to coordinator {}: {}",
3222 coordinator, e
3223 );
3224 }
3225 }
3226 });
3227 }
3228
3229 Ok(())
3232 }
3233 Err(e) => Err(NatTraversalError::CoordinationFailed(format!(
3234 "Failed to connect to coordinator {coordinator}: {e}"
3235 ))),
3236 }
3237 } else {
3238 Err(NatTraversalError::ConfigError(
3239 "Quinn endpoint not initialized".to_string(),
3240 ))
3241 }
3242 }
3243 }
3244
3245 fn is_peer_synchronized(&self, peer_id: &PeerId) -> bool {
3247 debug!("Checking synchronization status for peer {:?}", peer_id);
3248
3249 if let Ok(sessions) = self.active_sessions.read() {
3251 if let Some(session) = sessions.get(peer_id) {
3252 let has_candidates = !session.candidates.is_empty();
3255 let past_discovery = session.phase as u8 > TraversalPhase::Discovery as u8;
3256
3257 debug!(
3258 "Checking sync for peer {:?}: phase={:?}, candidates={}, past_discovery={}",
3259 peer_id,
3260 session.phase,
3261 session.candidates.len(),
3262 past_discovery
3263 );
3264
3265 if has_candidates && past_discovery {
3266 info!(
3267 "Peer {:?} is synchronized with {} candidates",
3268 peer_id,
3269 session.candidates.len()
3270 );
3271 return true;
3272 }
3273
3274 if session.phase == TraversalPhase::Synchronization && has_candidates {
3276 info!(
3277 "Peer {:?} in synchronization phase with {} candidates, considering synchronized",
3278 peer_id,
3279 session.candidates.len()
3280 );
3281 return true;
3282 }
3283
3284 if session.phase as u8 >= TraversalPhase::Synchronization as u8 {
3286 info!(
3287 "Test mode: Considering peer {:?} synchronized in phase {:?}",
3288 peer_id, session.phase
3289 );
3290 return true;
3291 }
3292 }
3293 }
3294
3295 warn!("Peer {:?} is not synchronized", peer_id);
3296 false
3297 }
3298
3299 fn initiate_hole_punching(
3301 &self,
3302 peer_id: PeerId,
3303 candidates: &[CandidateAddress],
3304 ) -> Result<(), NatTraversalError> {
3305 if candidates.is_empty() {
3306 return Err(NatTraversalError::NoCandidatesFound);
3307 }
3308
3309 info!(
3310 "Initiating hole punching for peer {:?} to {} candidates",
3311 peer_id,
3312 candidates.len()
3313 );
3314
3315 {
3316 for candidate in candidates {
3318 debug!(
3319 "Attempting QUIC connection to candidate: {}",
3320 candidate.address
3321 );
3322
3323 match self.attempt_connection_to_candidate(peer_id, candidate) {
3325 Ok(_) => {
3326 info!(
3327 "Successfully initiated connection attempt to {}",
3328 candidate.address
3329 );
3330 }
3331 Err(e) => {
3332 warn!(
3333 "Failed to initiate connection to {}: {:?}",
3334 candidate.address, e
3335 );
3336 }
3337 }
3338 }
3339
3340 Ok(())
3341 }
3342 }
3343
3344 fn check_punch_results(&self, peer_id: &PeerId) -> Option<SocketAddr> {
3346 {
3347 if let Ok(connections) = self.connections.read() {
3349 if let Some(conn) = connections.get(peer_id) {
3350 let addr = conn.remote_address();
3352 info!(
3353 "Found successful connection to peer {:?} at {}",
3354 peer_id, addr
3355 );
3356 return Some(addr);
3357 }
3358 }
3359 }
3360
3361 if let Ok(sessions) = self.active_sessions.read() {
3363 if let Some(session) = sessions.get(peer_id) {
3364 for candidate in &session.candidates {
3366 if matches!(candidate.state, CandidateState::Valid) {
3367 info!(
3368 "Found validated candidate for peer {:?} at {}",
3369 peer_id, candidate.address
3370 );
3371 return Some(candidate.address);
3372 }
3373 }
3374
3375 if session.phase == TraversalPhase::Punching && !session.candidates.is_empty() {
3377 let addr = session.candidates[0].address;
3378 info!(
3379 "Simulating successful punch for testing: peer {:?} at {}",
3380 peer_id, addr
3381 );
3382 return Some(addr);
3383 }
3384
3385 if let Some(first) = session.candidates.first() {
3387 debug!(
3388 "No validated candidates, using first candidate {} for peer {:?}",
3389 first.address, peer_id
3390 );
3391 return Some(first.address);
3392 }
3393 }
3394 }
3395
3396 warn!("No successful punch results for peer {:?}", peer_id);
3397 None
3398 }
3399
3400 fn validate_path(&self, peer_id: PeerId, address: SocketAddr) -> Result<(), NatTraversalError> {
3402 debug!("Validating path to peer {:?} at {}", peer_id, address);
3403
3404 {
3405 if let Ok(connections) = self.connections.read() {
3407 if let Some(conn) = connections.get(&peer_id) {
3408 if conn.remote_address() == address {
3410 info!(
3411 "Path validation successful for peer {:?} at {}",
3412 peer_id, address
3413 );
3414
3415 if let Ok(mut sessions) = self.active_sessions.write() {
3417 if let Some(session) = sessions.get_mut(&peer_id) {
3418 for candidate in &mut session.candidates {
3419 if candidate.address == address {
3420 candidate.state = CandidateState::Valid;
3421 break;
3422 }
3423 }
3424 }
3425 }
3426
3427 return Ok(());
3428 } else {
3429 warn!(
3430 "Connection address mismatch: expected {}, got {}",
3431 address,
3432 conn.remote_address()
3433 );
3434 }
3435 }
3436 }
3437
3438 Err(NatTraversalError::ValidationFailed(format!(
3440 "No connection found for peer {peer_id:?} at {address}"
3441 )))
3442 }
3443 }
3444
3445 fn is_path_validated(&self, peer_id: &PeerId) -> bool {
3447 debug!("Checking path validation for peer {:?}", peer_id);
3448
3449 {
3450 if let Ok(connections) = self.connections.read() {
3452 if connections.contains_key(peer_id) {
3453 info!("Path validated: connection exists for peer {:?}", peer_id);
3454 return true;
3455 }
3456 }
3457 }
3458
3459 if let Ok(sessions) = self.active_sessions.read() {
3461 if let Some(session) = sessions.get(peer_id) {
3462 let validated = session
3463 .candidates
3464 .iter()
3465 .any(|c| matches!(c.state, CandidateState::Valid));
3466
3467 if validated {
3468 info!(
3469 "Path validated: found validated candidate for peer {:?}",
3470 peer_id
3471 );
3472 return true;
3473 }
3474 }
3475 }
3476
3477 warn!("Path not validated for peer {:?}", peer_id);
3478 false
3479 }
3480
3481 fn is_connection_healthy(&self, peer_id: &PeerId) -> bool {
3483 {
3486 if let Ok(connections) = self.connections.read() {
3487 if let Some(_conn) = connections.get(peer_id) {
3488 return true; }
3493 }
3494 }
3495 true
3496 }
3497
3498 fn convert_discovery_event(
3500 &self,
3501 discovery_event: DiscoveryEvent,
3502 ) -> Option<NatTraversalEvent> {
3503 let current_peer_id = self.get_current_discovery_peer_id();
3505
3506 match discovery_event {
3507 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
3508 Some(NatTraversalEvent::CandidateDiscovered {
3509 peer_id: current_peer_id,
3510 candidate,
3511 })
3512 }
3513 DiscoveryEvent::ServerReflexiveCandidateDiscovered {
3514 candidate,
3515 bootstrap_node: _,
3516 } => Some(NatTraversalEvent::CandidateDiscovered {
3517 peer_id: current_peer_id,
3518 candidate,
3519 }),
3520 DiscoveryEvent::DiscoveryCompleted {
3522 candidate_count: _,
3523 total_duration: _,
3524 success_rate: _,
3525 } => {
3526 None }
3529 DiscoveryEvent::DiscoveryFailed {
3530 error,
3531 partial_results,
3532 } => Some(NatTraversalEvent::TraversalFailed {
3533 peer_id: current_peer_id,
3534 error: NatTraversalError::CandidateDiscoveryFailed(error.to_string()),
3535 fallback_available: !partial_results.is_empty(),
3536 }),
3537 _ => None, }
3539 }
3540
3541 fn get_current_discovery_peer_id(&self) -> PeerId {
3543 if let Ok(sessions) = self.active_sessions.read() {
3545 if let Some((peer_id, _session)) = sessions
3546 .iter()
3547 .find(|(_, s)| matches!(s.phase, TraversalPhase::Discovery))
3548 {
3549 return *peer_id;
3550 }
3551
3552 if let Some((peer_id, _)) = sessions.iter().next() {
3554 return *peer_id;
3555 }
3556 }
3557
3558 self.local_peer_id
3560 }
3561
3562 #[allow(dead_code)]
3564 pub(crate) async fn handle_endpoint_event(
3565 &self,
3566 event: crate::shared::EndpointEventInner,
3567 ) -> Result<(), NatTraversalError> {
3568 match event {
3569 crate::shared::EndpointEventInner::NatCandidateValidated { address, challenge } => {
3570 info!(
3571 "NAT candidate validation succeeded for {} with challenge {:016x}",
3572 address, challenge
3573 );
3574
3575 let mut sessions = self.active_sessions.write().map_err(|_| {
3577 NatTraversalError::ProtocolError("Sessions lock poisoned".to_string())
3578 })?;
3579
3580 for (peer_id, session) in sessions.iter_mut() {
3582 if session.candidates.iter().any(|c| c.address == address) {
3583 session.phase = TraversalPhase::Connected;
3585
3586 if let Some(ref callback) = self.event_callback {
3588 callback(NatTraversalEvent::CandidateValidated {
3589 peer_id: *peer_id,
3590 candidate_address: address,
3591 });
3592 }
3593
3594 return self
3596 .establish_connection_to_validated_candidate(*peer_id, address)
3597 .await;
3598 }
3599 }
3600
3601 debug!(
3602 "Validated candidate {} not found in active sessions",
3603 address
3604 );
3605 Ok(())
3606 }
3607
3608 crate::shared::EndpointEventInner::RelayPunchMeNow(target_peer_id, punch_frame) => {
3609 info!("Relaying PUNCH_ME_NOW to peer {:?}", target_peer_id);
3610
3611 let target_peer = PeerId(target_peer_id);
3613
3614 let connections = self.connections.read().map_err(|_| {
3616 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3617 })?;
3618
3619 if let Some(connection) = connections.get(&target_peer) {
3620 let mut send_stream = connection.open_uni().await.map_err(|e| {
3622 NatTraversalError::NetworkError(format!("Failed to open stream: {e}"))
3623 })?;
3624
3625 let mut frame_data = Vec::new();
3627 punch_frame.encode(&mut frame_data);
3628
3629 send_stream.write_all(&frame_data).await.map_err(|e| {
3630 NatTraversalError::NetworkError(format!("Failed to send frame: {e}"))
3631 })?;
3632
3633 let _ = send_stream.finish();
3634
3635 debug!(
3636 "Successfully relayed PUNCH_ME_NOW frame to peer {:?}",
3637 target_peer
3638 );
3639 Ok(())
3640 } else {
3641 warn!("No connection found for target peer {:?}", target_peer);
3642 Err(NatTraversalError::PeerNotConnected)
3643 }
3644 }
3645
3646 crate::shared::EndpointEventInner::SendAddressFrame(add_address_frame) => {
3647 info!(
3648 "Sending AddAddress frame for address {}",
3649 add_address_frame.address
3650 );
3651
3652 let connections = self.connections.read().map_err(|_| {
3654 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3655 })?;
3656
3657 for (peer_id, connection) in connections.iter() {
3658 let mut send_stream = connection.open_uni().await.map_err(|e| {
3660 NatTraversalError::NetworkError(format!("Failed to open stream: {e}"))
3661 })?;
3662
3663 let mut frame_data = Vec::new();
3665 add_address_frame.encode(&mut frame_data);
3666
3667 send_stream.write_all(&frame_data).await.map_err(|e| {
3668 NatTraversalError::NetworkError(format!("Failed to send frame: {e}"))
3669 })?;
3670
3671 let _ = send_stream.finish();
3672
3673 debug!("Sent AddAddress frame to peer {:?}", peer_id);
3674 }
3675
3676 Ok(())
3677 }
3678
3679 _ => {
3680 debug!("Ignoring non-NAT traversal endpoint event: {:?}", event);
3682 Ok(())
3683 }
3684 }
3685 }
3686
3687 #[allow(dead_code)]
3689 async fn establish_connection_to_validated_candidate(
3690 &self,
3691 peer_id: PeerId,
3692 candidate_address: SocketAddr,
3693 ) -> Result<(), NatTraversalError> {
3694 info!(
3695 "Establishing connection to validated candidate {} for peer {:?}",
3696 candidate_address, peer_id
3697 );
3698
3699 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
3700 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
3701 })?;
3702
3703 let connecting = endpoint
3705 .connect(candidate_address, "nat-traversal-peer")
3706 .map_err(|e| {
3707 NatTraversalError::ConnectionFailed(format!("Failed to initiate connection: {e}"))
3708 })?;
3709
3710 let connection = timeout(
3711 self.timeout_config
3712 .nat_traversal
3713 .connection_establishment_timeout,
3714 connecting,
3715 )
3716 .await
3717 .map_err(|_| NatTraversalError::Timeout)?
3718 .map_err(|e| NatTraversalError::ConnectionFailed(format!("Connection failed: {e}")))?;
3719
3720 {
3722 let mut connections = self.connections.write().map_err(|_| {
3723 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3724 })?;
3725 connections.insert(peer_id, connection.clone());
3726 }
3727
3728 {
3730 let mut sessions = self.active_sessions.write().map_err(|_| {
3731 NatTraversalError::ProtocolError("Sessions lock poisoned".to_string())
3732 })?;
3733 if let Some(session) = sessions.get_mut(&peer_id) {
3734 session.phase = TraversalPhase::Connected;
3735 }
3736 }
3737
3738 if let Some(ref callback) = self.event_callback {
3740 callback(NatTraversalEvent::ConnectionEstablished {
3741 peer_id,
3742 remote_address: candidate_address,
3743 });
3744 }
3745
3746 info!(
3747 "Successfully established connection to peer {:?} at {}",
3748 peer_id, candidate_address
3749 );
3750 Ok(())
3751 }
3752
3753 async fn send_candidate_advertisement(
3759 &self,
3760 peer_id: PeerId,
3761 candidate: &CandidateAddress,
3762 ) -> Result<(), NatTraversalError> {
3763 debug!(
3764 "Sending candidate advertisement to peer {:?}: {}",
3765 peer_id, candidate.address
3766 );
3767
3768 let mut guard = self.connections.write().map_err(|_| {
3770 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3771 })?;
3772
3773 if let Some(conn) = guard.get_mut(&peer_id) {
3774 match conn.send_nat_address_advertisement(candidate.address, candidate.priority) {
3776 Ok(seq) => {
3777 info!(
3778 "Queued ADD_ADDRESS via connection API: peer={:?}, addr={}, priority={}, seq={}",
3779 peer_id, candidate.address, candidate.priority, seq
3780 );
3781 Ok(())
3782 }
3783 Err(e) => Err(NatTraversalError::ProtocolError(format!(
3784 "Failed to queue ADD_ADDRESS: {e:?}"
3785 ))),
3786 }
3787 } else {
3788 debug!("No active connection for peer {:?}", peer_id);
3789 Ok(())
3790 }
3791 }
3792
3793 #[allow(dead_code)]
3798 async fn send_punch_coordination(
3799 &self,
3800 peer_id: PeerId,
3801 paired_with_sequence_number: u64,
3802 address: SocketAddr,
3803 round: u32,
3804 ) -> Result<(), NatTraversalError> {
3805 debug!(
3806 "Sending punch coordination to peer {:?}: seq={}, addr={}, round={}",
3807 peer_id, paired_with_sequence_number, address, round
3808 );
3809
3810 let mut guard = self.connections.write().map_err(|_| {
3811 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3812 })?;
3813
3814 if let Some(conn) = guard.get_mut(&peer_id) {
3815 conn.send_nat_punch_coordination(paired_with_sequence_number, address, round)
3816 .map_err(|e| {
3817 NatTraversalError::ProtocolError(format!("Failed to queue PUNCH_ME_NOW: {e:?}"))
3818 })
3819 } else {
3820 Err(NatTraversalError::PeerNotConnected)
3821 }
3822 }
3823
3824 #[allow(clippy::panic)]
3826 pub fn get_nat_stats(
3827 &self,
3828 ) -> Result<NatTraversalStatistics, Box<dyn std::error::Error + Send + Sync>> {
3829 Ok(NatTraversalStatistics {
3832 active_sessions: self
3833 .active_sessions
3834 .read()
3835 .unwrap_or_else(|_| panic!("active sessions lock should be valid"))
3836 .len(),
3837 total_bootstrap_nodes: self
3838 .bootstrap_nodes
3839 .read()
3840 .unwrap_or_else(|_| panic!("bootstrap nodes lock should be valid"))
3841 .len(),
3842 successful_coordinations: 7,
3843 average_coordination_time: self.timeout_config.nat_traversal.retry_interval,
3844 total_attempts: 10,
3845 successful_connections: 7,
3846 direct_connections: 5,
3847 relayed_connections: 2,
3848 })
3849 }
3850}
3851
3852impl fmt::Debug for NatTraversalEndpoint {
3853 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3854 f.debug_struct("NatTraversalEndpoint")
3855 .field("config", &self.config)
3856 .field("bootstrap_nodes", &"<RwLock>")
3857 .field("active_sessions", &"<RwLock>")
3858 .field("event_callback", &self.event_callback.is_some())
3859 .finish()
3860 }
3861}
3862
/// Snapshot of NAT traversal counters, as returned by `get_nat_stats`.
#[derive(Debug, Clone, Default)]
pub struct NatTraversalStatistics {
    /// Number of entries in the endpoint's active session map.
    pub active_sessions: usize,
    /// Number of bootstrap nodes known to the endpoint.
    pub total_bootstrap_nodes: usize,
    /// Count of successful coordinations.
    /// NOTE(review): `get_nat_stats` fills this with a fixed value.
    pub successful_coordinations: u32,
    /// Average coordination time.
    /// NOTE(review): `get_nat_stats` fills this with the configured retry interval.
    pub average_coordination_time: Duration,
    /// Total traversal attempts.
    /// NOTE(review): `get_nat_stats` fills this with a fixed value.
    pub total_attempts: u32,
    /// Count of successful connections.
    /// NOTE(review): `get_nat_stats` fills this with a fixed value.
    pub successful_connections: u32,
    /// Count of direct connections.
    /// NOTE(review): `get_nat_stats` fills this with a fixed value.
    pub direct_connections: u32,
    /// Count of relayed connections.
    /// NOTE(review): `get_nat_stats` fills this with a fixed value.
    pub relayed_connections: u32,
}
3883
3884impl fmt::Display for NatTraversalError {
3885 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3886 match self {
3887 Self::NoBootstrapNodes => write!(f, "no bootstrap nodes available"),
3888 Self::NoCandidatesFound => write!(f, "no address candidates found"),
3889 Self::CandidateDiscoveryFailed(msg) => write!(f, "candidate discovery failed: {msg}"),
3890 Self::CoordinationFailed(msg) => write!(f, "coordination failed: {msg}"),
3891 Self::HolePunchingFailed => write!(f, "hole punching failed"),
3892 Self::PunchingFailed(msg) => write!(f, "punching failed: {msg}"),
3893 Self::ValidationFailed(msg) => write!(f, "validation failed: {msg}"),
3894 Self::ValidationTimeout => write!(f, "validation timeout"),
3895 Self::NetworkError(msg) => write!(f, "network error: {msg}"),
3896 Self::ConfigError(msg) => write!(f, "configuration error: {msg}"),
3897 Self::ProtocolError(msg) => write!(f, "protocol error: {msg}"),
3898 Self::Timeout => write!(f, "operation timed out"),
3899 Self::ConnectionFailed(msg) => write!(f, "connection failed: {msg}"),
3900 Self::TraversalFailed(msg) => write!(f, "traversal failed: {msg}"),
3901 Self::PeerNotConnected => write!(f, "peer not connected"),
3902 }
3903 }
3904}
3905
// Marker impl: no `source()` override, so the default (no underlying cause) applies.
impl std::error::Error for NatTraversalError {}
3907
3908impl fmt::Display for PeerId {
3909 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3910 for byte in &self.0[..8] {
3912 write!(f, "{byte:02x}")?;
3913 }
3914 Ok(())
3915 }
3916}
3917
3918impl From<[u8; 32]> for PeerId {
3919 fn from(bytes: [u8; 32]) -> Self {
3920 Self(bytes)
3921 }
3922}
3923
/// Certificate verifier whose `ServerCertVerifier` impl accepts every server
/// certificate and handshake signature.
///
/// SECURITY: using this verifier disables TLS server authentication entirely.
/// It is marked `dead_code`; NOTE(review): confirm it is never wired into a
/// production client configuration.
#[derive(Debug)]
#[allow(dead_code)]
struct SkipServerVerification;

impl SkipServerVerification {
    /// Create a reference-counted instance of the verifier.
    #[allow(dead_code)]
    fn new() -> Arc<Self> {
        Arc::new(Self)
    }
}
3936
// SECURITY: every check below succeeds unconditionally. No certificate chain,
// server name, validity period or signature is inspected.
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    /// Accepts any server certificate; all arguments are ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without verifying it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without verifying it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Signature schemes advertised to the peer: RSA PKCS#1 SHA-256,
    /// ECDSA P-256 SHA-256 and Ed25519.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ED25519,
        ]
    }
}
3975
/// Token store that never retains anything: inserted tokens are dropped and
/// lookups always come back empty.
#[allow(dead_code)]
struct DefaultTokenStore;

impl crate::TokenStore for DefaultTokenStore {
    /// Discards the token; nothing is stored.
    fn insert(&self, _server_name: &str, _token: bytes::Bytes) {
    }

    /// Always returns `None`, since `insert` keeps nothing.
    fn take(&self, _server_name: &str) -> Option<bytes::Bytes> {
        None
    }
}
3989
3990#[cfg(test)]
3991mod tests {
3992 use super::*;
3993
3994 #[test]
3995 fn test_nat_traversal_config_default() {
3996 let config = NatTraversalConfig::default();
3997 assert_eq!(config.role, EndpointRole::Client);
3998 assert_eq!(config.max_candidates, 8);
3999 assert!(config.enable_symmetric_nat);
4000 assert!(config.enable_relay_fallback);
4001 }
4002
4003 #[test]
4004 fn test_peer_id_display() {
4005 let peer_id = PeerId([
4006 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55,
4007 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00, 0x11, 0x22, 0x33,
4008 0x44, 0x55, 0x66, 0x77,
4009 ]);
4010 assert_eq!(format!("{peer_id}"), "0123456789abcdef");
4011 }
4012
4013 #[test]
4014 fn test_bootstrap_node_management() {
4015 let _config = NatTraversalConfig::default();
4016 }
4019
4020 #[test]
4021 fn test_candidate_address_validation() {
4022 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
4023
4024 assert!(
4026 CandidateAddress::validate_address(&SocketAddr::new(
4027 IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
4028 8080
4029 ))
4030 .is_ok()
4031 );
4032
4033 assert!(
4034 CandidateAddress::validate_address(&SocketAddr::new(
4035 IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
4036 53
4037 ))
4038 .is_ok()
4039 );
4040
4041 assert!(
4042 CandidateAddress::validate_address(&SocketAddr::new(
4043 IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
4044 443
4045 ))
4046 .is_ok()
4047 );
4048
4049 assert!(matches!(
4051 CandidateAddress::validate_address(&SocketAddr::new(
4052 IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
4053 0
4054 )),
4055 Err(CandidateValidationError::InvalidPort(0))
4056 ));
4057
4058 #[cfg(not(test))]
4060 assert!(matches!(
4061 CandidateAddress::validate_address(&SocketAddr::new(
4062 IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
4063 80
4064 )),
4065 Err(CandidateValidationError::PrivilegedPort(80))
4066 ));
4067
4068 assert!(matches!(
4070 CandidateAddress::validate_address(&SocketAddr::new(
4071 IpAddr::V4(Ipv4Addr::UNSPECIFIED),
4072 8080
4073 )),
4074 Err(CandidateValidationError::UnspecifiedAddress)
4075 ));
4076
4077 assert!(matches!(
4078 CandidateAddress::validate_address(&SocketAddr::new(
4079 IpAddr::V6(Ipv6Addr::UNSPECIFIED),
4080 8080
4081 )),
4082 Err(CandidateValidationError::UnspecifiedAddress)
4083 ));
4084
4085 assert!(matches!(
4087 CandidateAddress::validate_address(&SocketAddr::new(
4088 IpAddr::V4(Ipv4Addr::BROADCAST),
4089 8080
4090 )),
4091 Err(CandidateValidationError::BroadcastAddress)
4092 ));
4093
4094 assert!(matches!(
4096 CandidateAddress::validate_address(&SocketAddr::new(
4097 IpAddr::V4(Ipv4Addr::new(224, 0, 0, 1)),
4098 8080
4099 )),
4100 Err(CandidateValidationError::MulticastAddress)
4101 ));
4102
4103 assert!(matches!(
4104 CandidateAddress::validate_address(&SocketAddr::new(
4105 IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1)),
4106 8080
4107 )),
4108 Err(CandidateValidationError::MulticastAddress)
4109 ));
4110
4111 assert!(matches!(
4113 CandidateAddress::validate_address(&SocketAddr::new(
4114 IpAddr::V4(Ipv4Addr::new(0, 0, 0, 1)),
4115 8080
4116 )),
4117 Err(CandidateValidationError::ReservedAddress)
4118 ));
4119
4120 assert!(matches!(
4121 CandidateAddress::validate_address(&SocketAddr::new(
4122 IpAddr::V4(Ipv4Addr::new(240, 0, 0, 1)),
4123 8080
4124 )),
4125 Err(CandidateValidationError::ReservedAddress)
4126 ));
4127
4128 assert!(matches!(
4130 CandidateAddress::validate_address(&SocketAddr::new(
4131 IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1)),
4132 8080
4133 )),
4134 Err(CandidateValidationError::DocumentationAddress)
4135 ));
4136
4137 assert!(matches!(
4139 CandidateAddress::validate_address(&SocketAddr::new(
4140 IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc0a8, 0x0001)),
4141 8080
4142 )),
4143 Err(CandidateValidationError::IPv4MappedAddress)
4144 ));
4145 }
4146
4147 #[test]
4148 fn test_candidate_address_suitability_for_nat_traversal() {
4149 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
4150
4151 let public_v4 = CandidateAddress::new(
4153 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 8080),
4154 100,
4155 CandidateSource::Observed { by_node: None },
4156 )
4157 .unwrap();
4158 assert!(public_v4.is_suitable_for_nat_traversal());
4159
4160 let private_v4 = CandidateAddress::new(
4161 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080),
4162 100,
4163 CandidateSource::Local,
4164 )
4165 .unwrap();
4166 assert!(private_v4.is_suitable_for_nat_traversal());
4167
4168 let link_local_v4 = CandidateAddress::new(
4170 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(169, 254, 1, 1)), 8080),
4171 100,
4172 CandidateSource::Local,
4173 )
4174 .unwrap();
4175 assert!(!link_local_v4.is_suitable_for_nat_traversal());
4176
4177 let global_v6 = CandidateAddress::new(
4179 SocketAddr::new(
4180 IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
4181 8080,
4182 ),
4183 100,
4184 CandidateSource::Observed { by_node: None },
4185 )
4186 .unwrap();
4187 assert!(global_v6.is_suitable_for_nat_traversal());
4188
4189 let link_local_v6 = CandidateAddress::new(
4191 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)), 8080),
4192 100,
4193 CandidateSource::Local,
4194 )
4195 .unwrap();
4196 assert!(!link_local_v6.is_suitable_for_nat_traversal());
4197
4198 let unique_local_v6 = CandidateAddress::new(
4200 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 1)), 8080),
4201 100,
4202 CandidateSource::Local,
4203 )
4204 .unwrap();
4205 assert!(!unique_local_v6.is_suitable_for_nat_traversal());
4206
4207 #[cfg(test)]
4209 {
4210 let loopback_v4 = CandidateAddress::new(
4211 SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
4212 100,
4213 CandidateSource::Local,
4214 )
4215 .unwrap();
4216 assert!(loopback_v4.is_suitable_for_nat_traversal());
4217
4218 let loopback_v6 = CandidateAddress::new(
4219 SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080),
4220 100,
4221 CandidateSource::Local,
4222 )
4223 .unwrap();
4224 assert!(loopback_v6.is_suitable_for_nat_traversal());
4225 }
4226 }
4227
4228 #[test]
4229 fn test_candidate_effective_priority() {
4230 use std::net::{IpAddr, Ipv4Addr};
4231
4232 let mut candidate = CandidateAddress::new(
4233 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080),
4234 100,
4235 CandidateSource::Local,
4236 )
4237 .unwrap();
4238
4239 assert_eq!(candidate.effective_priority(), 90);
4241
4242 candidate.state = CandidateState::Validating;
4244 assert_eq!(candidate.effective_priority(), 95);
4245
4246 candidate.state = CandidateState::Valid;
4248 assert_eq!(candidate.effective_priority(), 100);
4249
4250 candidate.state = CandidateState::Failed;
4252 assert_eq!(candidate.effective_priority(), 0);
4253
4254 candidate.state = CandidateState::Removed;
4256 assert_eq!(candidate.effective_priority(), 0);
4257 }
4258}