1use std::{collections::HashMap, fmt, net::SocketAddr, sync::Arc, time::Duration};
8
/// Returns the IPv4 wildcard bind address `0.0.0.0:0`.
///
/// Port `0` leaves the choice of port to the operating system at bind time.
fn create_random_port_bind_addr() -> SocketAddr {
    // Assembled from constants rather than parsed from a string literal.
    SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0)
}
32
33use tracing::{debug, error, info, warn};
34
35use std::sync::atomic::{AtomicBool, Ordering};
36
37use tokio::{
38 net::UdpSocket,
39 sync::mpsc,
40 time::{sleep, timeout},
41};
42
43use crate::high_level::default_runtime;
44
45use crate::{
46 VarInt,
47 candidate_discovery::{CandidateDiscoveryManager, DiscoveryConfig, DiscoveryEvent},
48 connection::nat_traversal::{CandidateSource, CandidateState, NatTraversalRole},
49};
50
51use crate::{
52 ClientConfig, ConnectionError, EndpointConfig, ServerConfig, TransportConfig,
53 high_level::{Connection as QuinnConnection, Endpoint as QuinnEndpoint},
54};
55
56#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
57use crate::{crypto::rustls::QuicClientConfig, crypto::rustls::QuicServerConfig};
58
59use crate::config::validation::{ConfigValidator, ValidationResult};
60
61#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
62use crate::crypto::certificate_manager::{CertificateConfig, CertificateManager};
63
/// QUIC endpoint bundled with the state needed to drive NAT traversal.
pub struct NatTraversalEndpoint {
    /// Underlying QUIC endpoint; populated with `Some` by the constructor.
    quinn_endpoint: Option<QuinnEndpoint>,
    /// Configuration this endpoint was built from.
    config: NatTraversalConfig,
    /// Known bootstrap nodes, seeded from `config.bootstrap_nodes` and editable at runtime.
    bootstrap_nodes: Arc<std::sync::RwLock<Vec<BootstrapNode>>>,
    /// NAT traversal sessions keyed by the remote peer's ID.
    active_sessions: Arc<std::sync::RwLock<HashMap<PeerId, NatTraversalSession>>>,
    /// Candidate discovery manager, shared with the background discovery polling task.
    discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
    /// Optional user callback invoked with NAT traversal events.
    event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
    /// Shutdown flag shared with the background tasks spawned by this endpoint.
    shutdown: Arc<AtomicBool>,
    /// Sending half of the internal event channel.
    event_tx: Option<mpsc::UnboundedSender<NatTraversalEvent>>,
    /// Connections keyed by peer ID; shared with the connection-accepting task.
    connections: Arc<std::sync::RwLock<HashMap<PeerId, QuinnConnection>>>,
    /// Peer ID generated for this endpoint when it is constructed.
    local_peer_id: PeerId,
    /// Timeout settings cloned from `config.timeouts`.
    timeout_config: crate::config::nat_timeouts::TimeoutConfig,
}
91
/// User-facing configuration for a [`NatTraversalEndpoint`].
///
/// Validated through the `ConfigValidator` implementation before an endpoint is built.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct NatTraversalConfig {
    /// Role this endpoint plays (defaults to `EndpointRole::Client`).
    pub role: EndpointRole,
    /// Addresses of bootstrap nodes; required for clients and for coordinating servers.
    pub bootstrap_nodes: Vec<SocketAddr>,
    /// Maximum number of candidates; validated against the range 1 to 256 (default 8).
    pub max_candidates: usize,
    /// Coordination timeout; validated against 100 ms to 300 s (default 10 s).
    /// Also used as the total timeout of candidate discovery.
    pub coordination_timeout: Duration,
    /// Forwarded to discovery as `enable_symmetric_prediction` (default `true`).
    pub enable_symmetric_nat: bool,
    /// Whether relay fallback is enabled (default `true`); rejected for bootstrap nodes.
    pub enable_relay_fallback: bool,
    /// Maximum concurrent attempts; validated against 1 to 16 and must not exceed
    /// `max_candidates` (default 3).
    pub max_concurrent_attempts: usize,
    /// Local address to bind to; when `None`, `0.0.0.0:0` is used.
    pub bind_addr: Option<SocketAddr>,
    /// Defaults to `true`.
    // NOTE(review): no reader of this flag is visible in this part of the file.
    pub prefer_rfc_nat_traversal: bool,
    /// Timeout settings copied into the endpoint at construction.
    pub timeouts: crate::config::nat_timeouts::TimeoutConfig,
}
162
/// Role an endpoint plays in NAT traversal.
///
/// `Server` and `Bootstrap` endpoints get a server-side QUIC configuration and a
/// background task that accepts incoming connections; `Client` endpoints do not.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum EndpointRole {
    /// Client endpoint; configuration validation requires at least one bootstrap node.
    Client,
    /// Server endpoint.
    Server {
        /// Whether this server can coordinate; when `true`, bootstrap nodes are required.
        can_coordinate: bool,
    },
    /// Bootstrap node; configuration validation rejects relay fallback for this role.
    Bootstrap,
}
176
177impl EndpointRole {
178 pub fn name(&self) -> &'static str {
180 match self {
181 Self::Client => "client",
182 Self::Server { .. } => "server",
183 Self::Bootstrap => "bootstrap",
184 }
185 }
186}
187
/// Identifier of a peer, stored as 32 raw bytes.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
)]
pub struct PeerId(pub [u8; 32]);
193
/// Bookkeeping record for a known bootstrap node.
#[derive(Debug, Clone)]
pub struct BootstrapNode {
    /// Network address of the node.
    pub address: SocketAddr,
    /// Timestamp associated with the node; set to the creation time for new records.
    pub last_seen: std::time::Instant,
    /// Whether the node can coordinate; new records start as `true`.
    pub can_coordinate: bool,
    /// Measured round-trip time, if any; feeds the average coordination time statistic.
    pub rtt: Option<Duration>,
    /// Coordination counter; summed into the `successful_coordinations` statistic.
    pub coordination_count: u32,
}
208
209impl BootstrapNode {
210 pub fn new(address: SocketAddr) -> Self {
212 Self {
213 address,
214 last_seen: std::time::Instant::now(),
215 can_coordinate: true,
216 rtt: None,
217 coordination_count: 0,
218 }
219 }
220}
221
/// A local candidate address paired with a remote candidate address.
#[derive(Debug, Clone)]
pub struct CandidatePair {
    /// Candidate on the local side.
    pub local_candidate: CandidateAddress,
    /// Candidate on the remote side.
    pub remote_candidate: CandidateAddress,
    /// Priority of this pair.
    pub priority: u64,
    /// Current state of this pair.
    pub state: CandidatePairState,
}
234
/// State of a [`CandidatePair`].
// NOTE(review): the transitions between these states are not visible in this part of the file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandidatePairState {
    /// The pair is waiting.
    Waiting,
    /// Work on the pair is in progress.
    InProgress,
    /// The pair succeeded.
    Succeeded,
    /// The pair failed.
    Failed,
    /// The pair was cancelled.
    Cancelled,
}
249
/// Internal record of one NAT traversal attempt towards a peer.
#[derive(Debug)]
struct NatTraversalSession {
    /// Peer this session targets.
    peer_id: PeerId,
    /// Coordinator address supplied when the traversal was initiated.
    coordinator: SocketAddr,
    /// Attempt counter; starts at 1 and is incremented when the session is retried.
    attempt: u32,
    /// When the session was created.
    started_at: std::time::Instant,
    /// Current traversal phase; new sessions start in `TraversalPhase::Discovery`.
    phase: TraversalPhase,
    /// Candidate addresses collected for this session; empty at creation.
    candidates: Vec<CandidateAddress>,
    /// Connection-level state tracked by the session pollers.
    session_state: SessionState,
}
268
/// Connection-level state of a session.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Current connection state.
    pub state: ConnectionState,
    /// When `state` last changed; used to measure timeouts.
    pub last_transition: std::time::Instant,
    /// The QUIC connection, once one exists for this session.
    pub connection: Option<QuinnConnection>,
    /// Attempts in flight as `(address, start time)` pairs; empty for new sessions.
    pub active_attempts: Vec<(SocketAddr, std::time::Instant)>,
    /// Metrics recorded for the connection.
    pub metrics: ConnectionMetrics,
}
283
/// Connection state of a session as tracked by the session pollers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No attempt in progress; the background poller retries idle sessions after a TTL.
    Idle,
    /// An attempt is in progress; this is the initial state of a new session.
    Connecting,
    /// A connection is established.
    Connected,
    /// The connection is migrating.
    Migrating,
    /// The session is closed; the background poller removes closed sessions after a TTL.
    Closed,
}
298
/// Metrics recorded for a session's connection; all fields start at their default values.
#[derive(Debug, Clone, Default)]
pub struct ConnectionMetrics {
    /// Round-trip time taken from the connection's path statistics.
    pub rtt: Option<Duration>,
    /// Lost packets divided by sent packets, from the connection's path statistics.
    pub loss_rate: f64,
    /// Bytes sent.
    pub bytes_sent: u64,
    /// Bytes received.
    pub bytes_received: u64,
    /// Timestamp refreshed by `poll_sessions` while the session is connected.
    pub last_activity: Option<std::time::Instant>,
}
313
/// A state transition reported by `NatTraversalEndpoint::poll_sessions`.
#[derive(Debug, Clone)]
pub struct SessionStateUpdate {
    /// Peer whose session changed.
    pub peer_id: PeerId,
    /// State before the transition.
    pub old_state: ConnectionState,
    /// State after the transition.
    pub new_state: ConnectionState,
    /// Why the transition happened.
    pub reason: StateChangeReason,
}
326
/// Reason attached to a [`SessionStateUpdate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StateChangeReason {
    /// A connecting session exceeded the connection establishment timeout.
    Timeout,
    /// A connecting session obtained a connection.
    ConnectionEstablished,
    /// The connection was closed.
    ConnectionClosed,
    /// A migrating session still had a connection when the migration window elapsed.
    MigrationComplete,
    /// A migrating session had no connection when the migration window elapsed.
    MigrationFailed,
    /// A network error occurred.
    NetworkError,
    /// The user closed the session.
    UserClosed,
}
345
/// Phase of a NAT traversal session; new sessions start in `Discovery`.
// NOTE(review): the code advancing a session through these phases is not visible
// in this part of the file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalPhase {
    /// Discovery phase.
    Discovery,
    /// Coordination phase.
    Coordination,
    /// Synchronization phase.
    Synchronization,
    /// Punching phase.
    Punching,
    /// Validation phase.
    Validation,
    /// Connected phase.
    Connected,
    /// Failed phase.
    Failed,
}
364
/// Action the background session poller decides to apply to a session.
#[derive(Debug, Clone, Copy)]
enum SessionUpdate {
    /// A connecting session exceeded the connection establishment timeout; close it.
    Timeout,
    /// A connected session's connection reports a close reason; close it and drop the connection.
    Disconnected,
    /// A connected session is healthy; refresh its RTT and loss-rate metrics.
    UpdateMetrics,
    /// A session is marked connected but has no connection; close it.
    InvalidState,
    /// An idle session waited long enough; move it back to connecting.
    Retry,
    /// A migrating session exceeded the probe timeout; close it.
    MigrationTimeout,
    /// A closed session aged out; remove it from the session map.
    Remove,
}
383
/// A candidate address together with its priority, origin and validation state.
#[derive(Debug, Clone)]
pub struct CandidateAddress {
    /// The socket address of the candidate.
    pub address: SocketAddr,
    /// Base priority; `effective_priority` adjusts it according to `state`.
    pub priority: u32,
    /// Where the candidate came from.
    pub source: CandidateSource,
    /// Validation state; candidates built with `CandidateAddress::new` start as `New`.
    pub state: CandidateState,
}
396
397impl CandidateAddress {
398 pub fn new(
400 address: SocketAddr,
401 priority: u32,
402 source: CandidateSource,
403 ) -> Result<Self, CandidateValidationError> {
404 Self::validate_address(&address)?;
405 Ok(Self {
406 address,
407 priority,
408 source,
409 state: CandidateState::New,
410 })
411 }
412
413 pub fn validate_address(addr: &SocketAddr) -> Result<(), CandidateValidationError> {
415 if addr.port() == 0 {
417 return Err(CandidateValidationError::InvalidPort(0));
418 }
419
420 #[cfg(not(test))]
422 if addr.port() < 1024 {
423 return Err(CandidateValidationError::PrivilegedPort(addr.port()));
424 }
425
426 match addr.ip() {
427 std::net::IpAddr::V4(ipv4) => {
428 if ipv4.is_unspecified() {
430 return Err(CandidateValidationError::UnspecifiedAddress);
431 }
432 if ipv4.is_broadcast() {
433 return Err(CandidateValidationError::BroadcastAddress);
434 }
435 if ipv4.is_multicast() {
436 return Err(CandidateValidationError::MulticastAddress);
437 }
438 if ipv4.octets()[0] == 0 {
440 return Err(CandidateValidationError::ReservedAddress);
441 }
442 if ipv4.octets()[0] >= 240 {
444 return Err(CandidateValidationError::ReservedAddress);
445 }
446 }
447 std::net::IpAddr::V6(ipv6) => {
448 if ipv6.is_unspecified() {
450 return Err(CandidateValidationError::UnspecifiedAddress);
451 }
452 if ipv6.is_multicast() {
453 return Err(CandidateValidationError::MulticastAddress);
454 }
455 let segments = ipv6.segments();
457 if segments[0] == 0x2001 && segments[1] == 0x0db8 {
458 return Err(CandidateValidationError::DocumentationAddress);
459 }
460 if ipv6.to_ipv4_mapped().is_some() {
462 return Err(CandidateValidationError::IPv4MappedAddress);
463 }
464 }
465 }
466
467 Ok(())
468 }
469
470 pub fn is_suitable_for_nat_traversal(&self) -> bool {
472 match self.address.ip() {
473 std::net::IpAddr::V4(ipv4) => {
474 #[cfg(test)]
479 if ipv4.is_loopback() {
480 return true;
481 }
482 !ipv4.is_loopback()
483 && !ipv4.is_link_local()
484 && !ipv4.is_multicast()
485 && !ipv4.is_broadcast()
486 }
487 std::net::IpAddr::V6(ipv6) => {
488 #[cfg(test)]
494 if ipv6.is_loopback() {
495 return true;
496 }
497 let segments = ipv6.segments();
498 let is_link_local = (segments[0] & 0xffc0) == 0xfe80;
499 let is_unique_local = (segments[0] & 0xfe00) == 0xfc00;
500
501 !ipv6.is_loopback() && !is_link_local && !is_unique_local && !ipv6.is_multicast()
502 }
503 }
504 }
505
506 pub fn effective_priority(&self) -> u32 {
508 match self.state {
509 CandidateState::Valid => self.priority,
510 CandidateState::New => self.priority.saturating_sub(10),
511 CandidateState::Validating => self.priority.saturating_sub(5),
512 CandidateState::Failed => 0,
513 CandidateState::Removed => 0,
514 }
515 }
516}
517
/// Reason a socket address was rejected by `CandidateAddress::validate_address`.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum CandidateValidationError {
    /// The port is not usable; produced for port `0`.
    #[error("invalid port number: {0}")]
    InvalidPort(u16),
    /// The port is below 1024 (not checked in test builds).
    #[error("privileged port not allowed: {0}")]
    PrivilegedPort(u16),
    /// The IP address is the unspecified address.
    #[error("unspecified address not allowed")]
    UnspecifiedAddress,
    /// The IPv4 address is the broadcast address.
    #[error("broadcast address not allowed")]
    BroadcastAddress,
    /// The IP address is a multicast address.
    #[error("multicast address not allowed")]
    MulticastAddress,
    /// The IPv4 address has a first octet of `0` or of `240` and above.
    #[error("reserved address not allowed")]
    ReservedAddress,
    /// The IPv6 address lies in `2001:db8::/32`.
    #[error("documentation address not allowed")]
    DocumentationAddress,
    /// The IPv6 address is an IPv4-mapped address.
    #[error("IPv4-mapped IPv6 address not allowed")]
    IPv4MappedAddress,
}
546
/// Events reported through the endpoint's event callback and internal event channel.
#[derive(Debug, Clone)]
pub enum NatTraversalEvent {
    /// A candidate address was discovered.
    CandidateDiscovered {
        /// Peer the candidate belongs to.
        peer_id: PeerId,
        /// The discovered candidate.
        candidate: CandidateAddress,
    },
    /// Emitted by `initiate_nat_traversal` once discovery for the peer has been started.
    CoordinationRequested {
        /// Target peer of the traversal.
        peer_id: PeerId,
        /// Coordinator address passed to `initiate_nat_traversal`.
        coordinator: SocketAddr,
    },
    /// Coordination with the peer was synchronized for the given round.
    CoordinationSynchronized { peer_id: PeerId, round_id: VarInt },
    /// Hole punching started.
    HolePunchingStarted {
        /// Peer being reached.
        peer_id: PeerId,
        /// Addresses targeted by the hole punching.
        targets: Vec<SocketAddr>,
    },
    /// A path was validated.
    PathValidated {
        /// Peer the path leads to.
        peer_id: PeerId,
        /// Address of the validated path.
        address: SocketAddr,
        /// Round-trip time reported for the path.
        rtt: Duration,
    },
    /// A candidate was validated; sent by `inject_observed_address` with the local
    /// peer ID when an observed address is accepted.
    CandidateValidated {
        /// Peer the candidate belongs to.
        peer_id: PeerId,
        /// The validated address.
        candidate_address: SocketAddr,
    },
    /// NAT traversal succeeded.
    TraversalSucceeded {
        /// Peer that was reached.
        peer_id: PeerId,
        /// Address the traversal ended on.
        final_address: SocketAddr,
        /// Total time reported for the traversal.
        total_time: Duration,
    },
    /// A connection was established.
    ConnectionEstablished {
        /// Peer on the other end.
        peer_id: PeerId,
        /// Remote address of the connection.
        remote_address: SocketAddr,
    },
    /// NAT traversal failed.
    TraversalFailed {
        /// Peer that could not be reached.
        peer_id: PeerId,
        /// The failure.
        error: NatTraversalError,
        /// Whether a fallback is available.
        fallback_available: bool,
    },
    /// A connection was lost; `reason` carries a textual description.
    ConnectionLost { peer_id: PeerId, reason: String },
    /// A session moved from one traversal phase to another.
    PhaseTransition {
        /// Peer whose session changed phase.
        peer_id: PeerId,
        /// Phase before the transition.
        from_phase: TraversalPhase,
        /// Phase after the transition.
        to_phase: TraversalPhase,
    },
    /// Emitted by `poll_sessions` when a session's connection state changes.
    SessionStateChanged {
        /// Peer whose session changed.
        peer_id: PeerId,
        /// State the session is now in.
        new_state: ConnectionState,
    },
}
613
/// Errors returned by the NAT traversal endpoint.
///
/// `String` payloads carry a human-readable description of the underlying failure.
#[derive(Debug, Clone)]
pub enum NatTraversalError {
    /// No bootstrap nodes are available.
    NoBootstrapNodes,
    /// No candidates were found.
    NoCandidatesFound,
    /// Candidate discovery failed, e.g. it could not be started or an observed
    /// address was rejected.
    CandidateDiscoveryFailed(String),
    /// Coordination failed.
    CoordinationFailed(String),
    /// Hole punching failed.
    HolePunchingFailed,
    /// Punching failed, with a description.
    PunchingFailed(String),
    /// Validation failed.
    ValidationFailed(String),
    /// Validation timed out.
    ValidationTimeout,
    /// Network-level failure, e.g. the UDP socket could not be bound.
    NetworkError(String),
    /// Invalid configuration or failure while building the QUIC/TLS configuration.
    ConfigError(String),
    /// Protocol-level failure; also used to report poisoned internal locks.
    ProtocolError(String),
    /// The operation timed out.
    Timeout,
    /// Connecting failed.
    ConnectionFailed(String),
    /// The traversal failed.
    TraversalFailed(String),
    /// The peer is not connected.
    PeerNotConnected,
}
648
649impl Default for NatTraversalConfig {
650 fn default() -> Self {
651 Self {
652 role: EndpointRole::Client,
653 bootstrap_nodes: Vec::new(),
654 max_candidates: 8,
655 coordination_timeout: Duration::from_secs(10),
656 enable_symmetric_nat: true,
657 enable_relay_fallback: true,
658 max_concurrent_attempts: 3,
659 bind_addr: None,
660 prefer_rfc_nat_traversal: true, timeouts: crate::config::nat_timeouts::TimeoutConfig::default(),
662 }
663 }
664}
665
666impl ConfigValidator for NatTraversalConfig {
667 fn validate(&self) -> ValidationResult<()> {
668 use crate::config::validation::*;
669
670 match self.role {
672 EndpointRole::Client => {
673 if self.bootstrap_nodes.is_empty() {
674 return Err(ConfigValidationError::InvalidRole(
675 "Client endpoints require at least one bootstrap node".to_string(),
676 ));
677 }
678 }
679 EndpointRole::Server { can_coordinate } => {
680 if can_coordinate && self.bootstrap_nodes.is_empty() {
681 return Err(ConfigValidationError::InvalidRole(
682 "Server endpoints with coordination capability require bootstrap nodes"
683 .to_string(),
684 ));
685 }
686 }
687 EndpointRole::Bootstrap => {
688 }
690 }
691
692 if !self.bootstrap_nodes.is_empty() {
694 validate_bootstrap_nodes(&self.bootstrap_nodes)?;
695 }
696
697 validate_range(self.max_candidates, 1, 256, "max_candidates")?;
699
700 validate_duration(
702 self.coordination_timeout,
703 Duration::from_millis(100),
704 Duration::from_secs(300),
705 "coordination_timeout",
706 )?;
707
708 validate_range(
710 self.max_concurrent_attempts,
711 1,
712 16,
713 "max_concurrent_attempts",
714 )?;
715
716 if self.max_concurrent_attempts > self.max_candidates {
718 return Err(ConfigValidationError::IncompatibleConfiguration(
719 "max_concurrent_attempts cannot exceed max_candidates".to_string(),
720 ));
721 }
722
723 if self.role == EndpointRole::Bootstrap && self.enable_relay_fallback {
724 return Err(ConfigValidationError::IncompatibleConfiguration(
725 "Bootstrap nodes should not enable relay fallback".to_string(),
726 ));
727 }
728
729 Ok(())
730 }
731}
732
733impl NatTraversalEndpoint {
734 pub async fn new(
736 config: NatTraversalConfig,
737 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
738 ) -> Result<Self, NatTraversalError> {
739 Self::new_impl(config, event_callback).await
740 }
741
742 async fn new_impl(
744 config: NatTraversalConfig,
745 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
746 ) -> Result<Self, NatTraversalError> {
747 Self::new_common(config, event_callback).await
748 }
749
750 async fn new_common(
752 config: NatTraversalConfig,
753 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
754 ) -> Result<Self, NatTraversalError> {
755 Self::new_shared_logic(config, event_callback).await
757 }
758
759 async fn new_shared_logic(
761 config: NatTraversalConfig,
762 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
763 ) -> Result<Self, NatTraversalError> {
764 {
767 config
768 .validate()
769 .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;
770 }
771
772 let bootstrap_nodes = Arc::new(std::sync::RwLock::new(
776 config
777 .bootstrap_nodes
778 .iter()
779 .map(|&address| BootstrapNode {
780 address,
781 last_seen: std::time::Instant::now(),
782 can_coordinate: true, rtt: None,
784 coordination_count: 0,
785 })
786 .collect(),
787 ));
788
789 let discovery_config = DiscoveryConfig {
791 total_timeout: config.coordination_timeout,
792 max_candidates: config.max_candidates,
793 enable_symmetric_prediction: config.enable_symmetric_nat,
794 bound_address: config.bind_addr, ..DiscoveryConfig::default()
796 };
797
798 let nat_traversal_role = match config.role {
799 EndpointRole::Client => NatTraversalRole::Client,
800 EndpointRole::Server { can_coordinate } => NatTraversalRole::Server {
801 can_relay: can_coordinate,
802 },
803 EndpointRole::Bootstrap => NatTraversalRole::Bootstrap,
804 };
805
806 let discovery_manager = Arc::new(std::sync::Mutex::new(CandidateDiscoveryManager::new(
807 discovery_config,
808 )));
809
810 let (quinn_endpoint, event_tx, local_addr) =
813 Self::create_quinn_endpoint(&config, nat_traversal_role).await?;
814
815 {
817 let mut discovery = discovery_manager.lock().map_err(|_| {
818 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
819 })?;
820 discovery.set_bound_address(local_addr);
821 info!(
822 "Updated discovery manager with bound address: {}",
823 local_addr
824 );
825 }
826
827 let endpoint = Self {
828 quinn_endpoint: Some(quinn_endpoint.clone()),
829 config: config.clone(),
830 bootstrap_nodes,
831 active_sessions: Arc::new(std::sync::RwLock::new(HashMap::new())),
832 discovery_manager,
833 event_callback,
834 shutdown: Arc::new(AtomicBool::new(false)),
835 event_tx: Some(event_tx.clone()),
836 connections: Arc::new(std::sync::RwLock::new(HashMap::new())),
837 local_peer_id: Self::generate_local_peer_id(),
838 timeout_config: config.timeouts.clone(),
839 };
840
841 if matches!(
843 config.role,
844 EndpointRole::Bootstrap | EndpointRole::Server { .. }
845 ) {
846 let endpoint_clone = quinn_endpoint.clone();
847 let shutdown_clone = endpoint.shutdown.clone();
848 let event_tx_clone = event_tx.clone();
849 let connections_clone = endpoint.connections.clone();
850
851 tokio::spawn(async move {
852 Self::accept_connections(
853 endpoint_clone,
854 shutdown_clone,
855 event_tx_clone,
856 connections_clone,
857 )
858 .await;
859 });
860
861 info!("Started accepting connections for {:?} role", config.role);
862 }
863
864 let discovery_manager_clone = endpoint.discovery_manager.clone();
866 let shutdown_clone = endpoint.shutdown.clone();
867 let event_tx_clone = event_tx;
868
869 tokio::spawn(async move {
870 Self::poll_discovery(discovery_manager_clone, shutdown_clone, event_tx_clone).await;
871 });
872
873 info!("Started discovery polling task");
874
875 {
877 let mut discovery = endpoint.discovery_manager.lock().map_err(|_| {
878 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
879 })?;
880
881 let local_peer_id = endpoint.local_peer_id;
883 let bootstrap_nodes = {
884 let nodes = endpoint.bootstrap_nodes.read().map_err(|_| {
885 NatTraversalError::ProtocolError("Bootstrap nodes lock poisoned".to_string())
886 })?;
887 nodes.clone()
888 };
889
890 discovery
891 .start_discovery(local_peer_id, bootstrap_nodes)
892 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
893
894 info!(
895 "Started local candidate discovery for peer {:?}",
896 local_peer_id
897 );
898 }
899
900 Ok(endpoint)
901 }
902
903 pub fn get_quinn_endpoint(&self) -> Option<&crate::high_level::Endpoint> {
905 self.quinn_endpoint.as_ref()
906 }
907
908 pub fn get_event_callback(&self) -> Option<&Box<dyn Fn(NatTraversalEvent) + Send + Sync>> {
910 self.event_callback.as_ref()
911 }
912
913 pub fn initiate_nat_traversal(
915 &self,
916 peer_id: PeerId,
917 coordinator: SocketAddr,
918 ) -> Result<(), NatTraversalError> {
919 info!(
920 "Starting NAT traversal to peer {:?} via coordinator {}",
921 peer_id, coordinator
922 );
923
924 let session = NatTraversalSession {
926 peer_id,
927 coordinator,
928 attempt: 1,
929 started_at: std::time::Instant::now(),
930 phase: TraversalPhase::Discovery,
931 candidates: Vec::new(),
932 session_state: SessionState {
933 state: ConnectionState::Connecting,
934 last_transition: std::time::Instant::now(),
935
936 connection: None,
937 active_attempts: Vec::new(),
938 metrics: ConnectionMetrics::default(),
939 },
940 };
941
942 {
944 let mut sessions = self
945 .active_sessions
946 .write()
947 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
948 sessions.insert(peer_id, session);
949 }
950
951 let bootstrap_nodes_vec = {
953 let bootstrap_nodes = self
954 .bootstrap_nodes
955 .read()
956 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
957 bootstrap_nodes.clone()
958 };
959
960 {
961 let mut discovery = self.discovery_manager.lock().map_err(|_| {
962 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
963 })?;
964
965 discovery
966 .start_discovery(peer_id, bootstrap_nodes_vec)
967 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
968 }
969
970 if let Some(ref callback) = self.event_callback {
972 callback(NatTraversalEvent::CoordinationRequested {
973 peer_id,
974 coordinator,
975 });
976 }
977
978 Ok(())
980 }
981
982 pub fn poll_sessions(&self) -> Result<Vec<SessionStateUpdate>, NatTraversalError> {
984 let mut updates = Vec::new();
985 let now = std::time::Instant::now();
986
987 let mut sessions = self
988 .active_sessions
989 .write()
990 .map_err(|_| NatTraversalError::ProtocolError("Sessions lock poisoned".to_string()))?;
991
992 for (peer_id, session) in sessions.iter_mut() {
993 let mut state_changed = false;
994
995 match session.session_state.state {
996 ConnectionState::Connecting => {
997 let elapsed = now.duration_since(session.session_state.last_transition);
999 if elapsed
1000 > self
1001 .timeout_config
1002 .nat_traversal
1003 .connection_establishment_timeout
1004 {
1005 session.session_state.state = ConnectionState::Closed;
1006 session.session_state.last_transition = now;
1007 state_changed = true;
1008
1009 updates.push(SessionStateUpdate {
1010 peer_id: *peer_id,
1011 old_state: ConnectionState::Connecting,
1012 new_state: ConnectionState::Closed,
1013 reason: StateChangeReason::Timeout,
1014 });
1015 }
1016
1017 if let Some(ref _connection) = session.session_state.connection {
1020 session.session_state.state = ConnectionState::Connected;
1021 session.session_state.last_transition = now;
1022 state_changed = true;
1023
1024 updates.push(SessionStateUpdate {
1025 peer_id: *peer_id,
1026 old_state: ConnectionState::Connecting,
1027 new_state: ConnectionState::Connected,
1028 reason: StateChangeReason::ConnectionEstablished,
1029 });
1030 }
1031 }
1032 ConnectionState::Connected => {
1033 {
1036 }
1039
1040 session.session_state.metrics.last_activity = Some(now);
1042 }
1043 ConnectionState::Migrating => {
1044 let elapsed = now.duration_since(session.session_state.last_transition);
1046 if elapsed > Duration::from_secs(10) {
1047 if session.session_state.connection.is_some() {
1050 session.session_state.state = ConnectionState::Connected;
1051 state_changed = true;
1052
1053 updates.push(SessionStateUpdate {
1054 peer_id: *peer_id,
1055 old_state: ConnectionState::Migrating,
1056 new_state: ConnectionState::Connected,
1057 reason: StateChangeReason::MigrationComplete,
1058 });
1059 } else {
1060 session.session_state.state = ConnectionState::Closed;
1061 state_changed = true;
1062
1063 updates.push(SessionStateUpdate {
1064 peer_id: *peer_id,
1065 old_state: ConnectionState::Migrating,
1066 new_state: ConnectionState::Closed,
1067 reason: StateChangeReason::MigrationFailed,
1068 });
1069 }
1070
1071 session.session_state.last_transition = now;
1072 }
1073 }
1074 _ => {}
1075 }
1076
1077 if state_changed {
1079 if let Some(ref callback) = self.event_callback {
1080 callback(NatTraversalEvent::SessionStateChanged {
1081 peer_id: *peer_id,
1082 new_state: session.session_state.state,
1083 });
1084 }
1085 }
1086 }
1087
1088 Ok(updates)
1089 }
1090
1091 pub fn start_session_polling(&self, interval: Duration) -> tokio::task::JoinHandle<()> {
1093 let sessions = self.active_sessions.clone();
1094 let shutdown = self.shutdown.clone();
1095 let timeout_config = self.timeout_config.clone();
1096
1097 tokio::spawn(async move {
1098 let mut ticker = tokio::time::interval(interval);
1099
1100 loop {
1101 ticker.tick().await;
1102
1103 if shutdown.load(Ordering::Relaxed) {
1104 break;
1105 }
1106
1107 let sessions_to_update = {
1109 match sessions.read() {
1110 Ok(sessions_guard) => {
1111 sessions_guard
1112 .iter()
1113 .filter_map(|(peer_id, session)| {
1114 let now = std::time::Instant::now();
1115 let elapsed =
1116 now.duration_since(session.session_state.last_transition);
1117
1118 match session.session_state.state {
1119 ConnectionState::Connecting => {
1120 if elapsed
1122 > timeout_config
1123 .nat_traversal
1124 .connection_establishment_timeout
1125 {
1126 Some((*peer_id, SessionUpdate::Timeout))
1127 } else {
1128 None
1129 }
1130 }
1131 ConnectionState::Connected => {
1132 if let Some(ref conn) = session.session_state.connection
1134 {
1135 if conn.close_reason().is_some() {
1136 Some((*peer_id, SessionUpdate::Disconnected))
1137 } else {
1138 Some((*peer_id, SessionUpdate::UpdateMetrics))
1140 }
1141 } else {
1142 Some((*peer_id, SessionUpdate::InvalidState))
1143 }
1144 }
1145 ConnectionState::Idle => {
1146 if elapsed
1148 > timeout_config
1149 .discovery
1150 .server_reflexive_cache_ttl
1151 {
1152 Some((*peer_id, SessionUpdate::Retry))
1153 } else {
1154 None
1155 }
1156 }
1157 ConnectionState::Migrating => {
1158 if elapsed > timeout_config.nat_traversal.probe_timeout
1160 {
1161 Some((*peer_id, SessionUpdate::MigrationTimeout))
1162 } else {
1163 None
1164 }
1165 }
1166 ConnectionState::Closed => {
1167 if elapsed
1169 > timeout_config.discovery.interface_cache_ttl
1170 {
1171 Some((*peer_id, SessionUpdate::Remove))
1172 } else {
1173 None
1174 }
1175 }
1176 }
1177 })
1178 .collect::<Vec<_>>()
1179 }
1180 _ => {
1181 vec![]
1182 }
1183 }
1184 };
1185
1186 if !sessions_to_update.is_empty() {
1188 if let Ok(mut sessions_guard) = sessions.write() {
1189 for (peer_id, update) in sessions_to_update {
1190 match update {
1191 SessionUpdate::Timeout => {
1192 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1193 session.session_state.state = ConnectionState::Closed;
1194 session.session_state.last_transition =
1195 std::time::Instant::now();
1196 tracing::warn!("Connection to {:?} timed out", peer_id);
1197 }
1198 }
1199 SessionUpdate::Disconnected => {
1200 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1201 session.session_state.state = ConnectionState::Closed;
1202 session.session_state.last_transition =
1203 std::time::Instant::now();
1204 session.session_state.connection = None;
1205 tracing::info!("Connection to {:?} closed", peer_id);
1206 }
1207 }
1208 SessionUpdate::UpdateMetrics => {
1209 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1210 if let Some(ref conn) = session.session_state.connection {
1211 let stats = conn.stats();
1213 session.session_state.metrics.rtt =
1214 Some(stats.path.rtt);
1215 session.session_state.metrics.loss_rate =
1216 stats.path.lost_packets as f64
1217 / stats.path.sent_packets.max(1) as f64;
1218 }
1219 }
1220 }
1221 SessionUpdate::InvalidState => {
1222 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1223 session.session_state.state = ConnectionState::Closed;
1224 session.session_state.last_transition =
1225 std::time::Instant::now();
1226 tracing::error!("Session {:?} in invalid state", peer_id);
1227 }
1228 }
1229 SessionUpdate::Retry => {
1230 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1231 session.session_state.state = ConnectionState::Connecting;
1232 session.session_state.last_transition =
1233 std::time::Instant::now();
1234 session.attempt += 1;
1235 tracing::info!(
1236 "Retrying connection to {:?} (attempt {})",
1237 peer_id,
1238 session.attempt
1239 );
1240 }
1241 }
1242 SessionUpdate::MigrationTimeout => {
1243 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1244 session.session_state.state = ConnectionState::Closed;
1245 session.session_state.last_transition =
1246 std::time::Instant::now();
1247 tracing::warn!("Migration timeout for {:?}", peer_id);
1248 }
1249 }
1250 SessionUpdate::Remove => {
1251 sessions_guard.remove(&peer_id);
1252 tracing::debug!("Removed old session for {:?}", peer_id);
1253 }
1254 }
1255 }
1256 }
1257 }
1258 }
1259 })
1260 }
1261
1262 pub fn inject_observed_address(
1265 &self,
1266 observed_address: SocketAddr,
1267 _from_peer: PeerId,
1268 ) -> Result<(), NatTraversalError> {
1269 info!("Injecting observed address {}", observed_address);
1270
1271 let mut discovery = self.discovery_manager.lock().map_err(|_| {
1273 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
1274 })?;
1275
1276 let our_peer_id = self.local_peer_id;
1278
1279 match discovery.accept_quic_discovered_address(our_peer_id, observed_address) {
1281 Ok(()) => {
1282 info!(
1283 "Successfully accepted observed address: {}",
1284 observed_address
1285 );
1286
1287 if let Some(ref event_tx) = self.event_tx {
1289 let _ = event_tx.send(NatTraversalEvent::CandidateValidated {
1290 peer_id: our_peer_id,
1291 candidate_address: observed_address,
1292 });
1293 }
1294
1295 Ok(())
1296 }
1297 Err(e) => {
1298 warn!(
1299 "Failed to accept observed address {}: {}",
1300 observed_address, e
1301 );
1302 Err(NatTraversalError::CandidateDiscoveryFailed(e.to_string()))
1303 }
1304 }
1305 }
1306
1307 pub fn get_statistics(&self) -> Result<NatTraversalStatistics, NatTraversalError> {
1309 let sessions = self
1310 .active_sessions
1311 .read()
1312 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1313 let bootstrap_nodes = self
1314 .bootstrap_nodes
1315 .read()
1316 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1317
1318 let avg_coordination_time = {
1320 let rtts: Vec<Duration> = bootstrap_nodes.iter().filter_map(|b| b.rtt).collect();
1321
1322 if rtts.is_empty() {
1323 Duration::from_millis(500) } else {
1325 let total_millis: u64 = rtts.iter().map(|d| d.as_millis() as u64).sum();
1326 Duration::from_millis(total_millis / rtts.len() as u64 * 2) }
1328 };
1329
1330 Ok(NatTraversalStatistics {
1331 active_sessions: sessions.len(),
1332 total_bootstrap_nodes: bootstrap_nodes.len(),
1333 successful_coordinations: bootstrap_nodes.iter().map(|b| b.coordination_count).sum(),
1334 average_coordination_time: avg_coordination_time,
1335 total_attempts: 0,
1336 successful_connections: 0,
1337 direct_connections: 0,
1338 relayed_connections: 0,
1339 })
1340 }
1341
1342 pub fn add_bootstrap_node(&self, address: SocketAddr) -> Result<(), NatTraversalError> {
1344 let mut bootstrap_nodes = self
1345 .bootstrap_nodes
1346 .write()
1347 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1348
1349 if !bootstrap_nodes.iter().any(|b| b.address == address) {
1351 bootstrap_nodes.push(BootstrapNode {
1352 address,
1353 last_seen: std::time::Instant::now(),
1354 can_coordinate: true,
1355 rtt: None,
1356 coordination_count: 0,
1357 });
1358 info!("Added bootstrap node: {}", address);
1359 }
1360 Ok(())
1361 }
1362
1363 pub fn remove_bootstrap_node(&self, address: SocketAddr) -> Result<(), NatTraversalError> {
1365 let mut bootstrap_nodes = self
1366 .bootstrap_nodes
1367 .write()
1368 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1369 bootstrap_nodes.retain(|b| b.address != address);
1370 info!("Removed bootstrap node: {}", address);
1371 Ok(())
1372 }
1373
1374 async fn create_quinn_endpoint(
1378 config: &NatTraversalConfig,
1379 _nat_role: NatTraversalRole,
1380 ) -> Result<
1381 (
1382 QuinnEndpoint,
1383 mpsc::UnboundedSender<NatTraversalEvent>,
1384 SocketAddr,
1385 ),
1386 NatTraversalError,
1387 > {
1388 use std::sync::Arc;
1389
1390 let server_config = match config.role {
1392 EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
1393 let cert_config = CertificateConfig {
1395 common_name: format!("ant-quic-{}", config.role.name()),
1396 subject_alt_names: vec!["localhost".to_string(), "ant-quic-node".to_string()],
1397 self_signed: true, ..CertificateConfig::default()
1399 };
1400
1401 let cert_manager = CertificateManager::new(cert_config).map_err(|e| {
1402 NatTraversalError::ConfigError(format!(
1403 "Certificate manager creation failed: {e}"
1404 ))
1405 })?;
1406
1407 let cert_bundle = cert_manager.generate_certificate().map_err(|e| {
1408 NatTraversalError::ConfigError(format!("Certificate generation failed: {e}"))
1409 })?;
1410
1411 let rustls_config =
1412 cert_manager
1413 .create_server_config(&cert_bundle)
1414 .map_err(|e| {
1415 NatTraversalError::ConfigError(format!(
1416 "Server config creation failed: {e}"
1417 ))
1418 })?;
1419
1420 let server_crypto = QuicServerConfig::try_from(rustls_config.as_ref().clone())
1421 .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;
1422
1423 let mut server_config = ServerConfig::with_crypto(Arc::new(server_crypto));
1424
1425 let mut transport_config = TransportConfig::default();
1427 transport_config
1428 .keep_alive_interval(Some(config.timeouts.nat_traversal.retry_interval));
1429 transport_config.max_idle_timeout(Some(crate::VarInt::from_u32(30000).into()));
1430
1431 let nat_config = match config.role {
1436 EndpointRole::Client => {
1437 crate::transport_parameters::NatTraversalConfig::ClientSupport
1438 }
1439 EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
1440 crate::transport_parameters::NatTraversalConfig::ServerSupport {
1441 concurrency_limit: VarInt::from_u32(
1442 config.max_concurrent_attempts as u32,
1443 ),
1444 }
1445 }
1446 };
1447 transport_config.nat_traversal_config(Some(nat_config));
1448
1449 server_config.transport_config(Arc::new(transport_config));
1450
1451 Some(server_config)
1452 }
1453 _ => None,
1454 };
1455
1456 let client_config = {
1458 let cert_config = CertificateConfig {
1459 common_name: format!("ant-quic-{}", config.role.name()),
1460 subject_alt_names: vec!["localhost".to_string(), "ant-quic-node".to_string()],
1461 self_signed: true,
1462 ..CertificateConfig::default()
1463 };
1464
1465 let cert_manager = CertificateManager::new(cert_config).map_err(|e| {
1466 NatTraversalError::ConfigError(format!("Certificate manager creation failed: {e}"))
1467 })?;
1468
1469 let _cert_bundle = cert_manager.generate_certificate().map_err(|e| {
1470 NatTraversalError::ConfigError(format!("Certificate generation failed: {e}"))
1471 })?;
1472
1473 let rustls_config = cert_manager.create_client_config().map_err(|e| {
1474 NatTraversalError::ConfigError(format!("Client config creation failed: {e}"))
1475 })?;
1476
1477 let client_crypto = QuicClientConfig::try_from(rustls_config.as_ref().clone())
1478 .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;
1479
1480 let mut client_config = ClientConfig::new(Arc::new(client_crypto));
1481
1482 let mut transport_config = TransportConfig::default();
1484 transport_config.keep_alive_interval(Some(Duration::from_secs(5)));
1485 transport_config.max_idle_timeout(Some(crate::VarInt::from_u32(30000).into()));
1486
1487 let nat_config = match config.role {
1492 EndpointRole::Client => {
1493 crate::transport_parameters::NatTraversalConfig::ClientSupport
1494 }
1495 EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
1496 crate::transport_parameters::NatTraversalConfig::ServerSupport {
1497 concurrency_limit: VarInt::from_u32(config.max_concurrent_attempts as u32),
1498 }
1499 }
1500 };
1501 transport_config.nat_traversal_config(Some(nat_config));
1502
1503 client_config.transport_config(Arc::new(transport_config));
1504
1505 client_config
1506 };
1507
1508 let bind_addr = config
1510 .bind_addr
1511 .unwrap_or_else(create_random_port_bind_addr);
1512 let socket = UdpSocket::bind(bind_addr).await.map_err(|e| {
1513 NatTraversalError::NetworkError(format!("Failed to bind UDP socket: {e}"))
1514 })?;
1515
1516 info!("Binding endpoint to {}", bind_addr);
1517
1518 let std_socket = socket.into_std().map_err(|e| {
1520 NatTraversalError::NetworkError(format!("Failed to convert socket: {e}"))
1521 })?;
1522
1523 let runtime = default_runtime().ok_or_else(|| {
1525 NatTraversalError::ConfigError("No compatible async runtime found".to_string())
1526 })?;
1527
1528 let mut endpoint = QuinnEndpoint::new(
1529 EndpointConfig::default(),
1530 server_config,
1531 std_socket,
1532 runtime,
1533 )
1534 .map_err(|e| {
1535 NatTraversalError::ConfigError(format!("Failed to create Quinn endpoint: {e}"))
1536 })?;
1537
1538 endpoint.set_default_client_config(client_config);
1540
1541 let local_addr = endpoint.local_addr().map_err(|e| {
1543 NatTraversalError::NetworkError(format!("Failed to get local address: {e}"))
1544 })?;
1545
1546 info!("Endpoint bound to actual address: {}", local_addr);
1547
1548 let (event_tx, _event_rx) = mpsc::unbounded_channel();
1550
1551 Ok((endpoint, event_tx, local_addr))
1552 }
1553
1554 pub async fn start_listening(&self, bind_addr: SocketAddr) -> Result<(), NatTraversalError> {
1556 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
1557 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
1558 })?;
1559
1560 let _socket = UdpSocket::bind(bind_addr).await.map_err(|e| {
1562 NatTraversalError::NetworkError(format!("Failed to bind to {bind_addr}: {e}"))
1563 })?;
1564
1565 info!("Started listening on {}", bind_addr);
1566
1567 let endpoint_clone = endpoint.clone();
1569 let shutdown_clone = self.shutdown.clone();
1570 let event_tx = self.event_tx.as_ref().unwrap().clone();
1571 let connections_clone = self.connections.clone();
1572
1573 tokio::spawn(async move {
1574 Self::accept_connections(endpoint_clone, shutdown_clone, event_tx, connections_clone)
1575 .await;
1576 });
1577
1578 Ok(())
1579 }
1580
1581 async fn accept_connections(
1583 endpoint: QuinnEndpoint,
1584 shutdown: Arc<AtomicBool>,
1585 event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1586 connections: Arc<std::sync::RwLock<HashMap<PeerId, QuinnConnection>>>,
1587 ) {
1588 while !shutdown.load(Ordering::Relaxed) {
1589 match endpoint.accept().await {
1590 Some(connecting) => {
1591 let event_tx = event_tx.clone();
1592 let connections = connections.clone();
1593 tokio::spawn(async move {
1594 match connecting.await {
1595 Ok(connection) => {
1596 info!("Accepted connection from {}", connection.remote_address());
1597
1598 let peer_id = Self::generate_peer_id_from_address(
1600 connection.remote_address(),
1601 );
1602
1603 if let Ok(mut conns) = connections.write() {
1605 conns.insert(peer_id, connection.clone());
1606 }
1607
1608 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
1609 peer_id,
1610 remote_address: connection.remote_address(),
1611 });
1612
1613 Self::handle_connection(connection, event_tx).await;
1615 }
1616 Err(e) => {
1617 debug!("Connection failed: {}", e);
1618 }
1619 }
1620 });
1621 }
1622 None => {
1623 break;
1625 }
1626 }
1627 }
1628 }
1629
1630 async fn poll_discovery(
1632 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
1633 shutdown: Arc<AtomicBool>,
1634 _event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1635 ) {
1636 use tokio::time::{Duration, interval};
1637
1638 let mut poll_interval = interval(Duration::from_millis(100));
1639
1640 while !shutdown.load(Ordering::Relaxed) {
1641 poll_interval.tick().await;
1642
1643 let events = match discovery_manager.lock() {
1645 Ok(mut discovery) => discovery.poll(std::time::Instant::now()),
1646 Err(e) => {
1647 error!("Failed to lock discovery manager: {}", e);
1648 continue;
1649 }
1650 };
1651
1652 for event in events {
1654 match event {
1655 DiscoveryEvent::DiscoveryStarted {
1656 peer_id,
1657 bootstrap_count,
1658 } => {
1659 debug!(
1660 "Discovery started for peer {:?} with {} bootstrap nodes",
1661 peer_id, bootstrap_count
1662 );
1663 }
1664 DiscoveryEvent::LocalScanningStarted => {
1665 debug!("Local interface scanning started");
1666 }
1667 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
1668 debug!("Discovered local candidate: {}", candidate.address);
1669 }
1672 DiscoveryEvent::LocalScanningCompleted {
1673 candidate_count,
1674 duration,
1675 } => {
1676 debug!(
1677 "Local interface scanning completed: {} candidates in {:?}",
1678 candidate_count, duration
1679 );
1680 }
1681 DiscoveryEvent::ServerReflexiveDiscoveryStarted { bootstrap_count } => {
1682 debug!(
1683 "Server reflexive discovery started with {} bootstrap nodes",
1684 bootstrap_count
1685 );
1686 }
1687 DiscoveryEvent::ServerReflexiveCandidateDiscovered {
1688 candidate,
1689 bootstrap_node,
1690 } => {
1691 debug!(
1692 "Discovered server-reflexive candidate {} via bootstrap {}",
1693 candidate.address, bootstrap_node
1694 );
1695 }
1697 DiscoveryEvent::BootstrapQueryFailed {
1698 bootstrap_node,
1699 error,
1700 } => {
1701 debug!("Bootstrap query failed for {}: {}", bootstrap_node, error);
1702 }
1703 DiscoveryEvent::SymmetricPredictionStarted { base_address } => {
1704 debug!(
1705 "Symmetric NAT prediction started from base address {}",
1706 base_address
1707 );
1708 }
1709 DiscoveryEvent::PredictedCandidateGenerated {
1710 candidate,
1711 confidence,
1712 } => {
1713 debug!(
1714 "Predicted symmetric NAT candidate {} with confidence {}",
1715 candidate.address, confidence
1716 );
1717 }
1719 DiscoveryEvent::PortAllocationDetected {
1720 port,
1721 source_address,
1722 bootstrap_node,
1723 timestamp,
1724 } => {
1725 debug!(
1726 "Port allocation detected: port {} from {} via bootstrap {:?} at {:?}",
1727 port, source_address, bootstrap_node, timestamp
1728 );
1729 }
1730 DiscoveryEvent::DiscoveryCompleted {
1731 candidate_count,
1732 total_duration,
1733 success_rate,
1734 } => {
1735 info!(
1736 "Discovery completed with {} candidates in {:?} (success rate: {:.2}%)",
1737 candidate_count,
1738 total_duration,
1739 success_rate * 100.0
1740 );
1741 }
1744 DiscoveryEvent::DiscoveryFailed {
1745 error,
1746 partial_results,
1747 } => {
1748 warn!(
1749 "Discovery failed: {} (found {} partial candidates)",
1750 error,
1751 partial_results.len()
1752 );
1753
1754 }
1759 DiscoveryEvent::PathValidationRequested {
1760 candidate_id,
1761 candidate_address,
1762 challenge_token,
1763 } => {
1764 debug!(
1765 "PATH_CHALLENGE requested for candidate {} at {} with token {:08x}",
1766 candidate_id.0, candidate_address, challenge_token
1767 );
1768 }
1771 DiscoveryEvent::PathValidationResponse {
1772 candidate_id,
1773 candidate_address,
1774 challenge_token: _,
1775 rtt,
1776 } => {
1777 debug!(
1778 "PATH_RESPONSE received for candidate {} at {} with RTT {:?}",
1779 candidate_id.0, candidate_address, rtt
1780 );
1781 }
1783 }
1784 }
1785 }
1786
1787 info!("Discovery polling task shutting down");
1788 }
1789
1790 async fn handle_connection(
1792 connection: QuinnConnection,
1793 event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1794 ) {
1795 let peer_id = Self::generate_peer_id_from_address(connection.remote_address());
1796 let remote_address = connection.remote_address();
1797
1798 debug!(
1799 "Handling connection from peer {:?} at {}",
1800 peer_id, remote_address
1801 );
1802
1803 loop {
1805 tokio::select! {
1806 stream = connection.accept_bi() => {
1807 match stream {
1808 Ok((send, recv)) => {
1809 tokio::spawn(async move {
1810 Self::handle_bi_stream(send, recv).await;
1811 });
1812 }
1813 Err(e) => {
1814 debug!("Error accepting bidirectional stream: {}", e);
1815 let _ = event_tx.send(NatTraversalEvent::ConnectionLost {
1816 peer_id,
1817 reason: format!("Stream error: {e}"),
1818 });
1819 break;
1820 }
1821 }
1822 }
1823 stream = connection.accept_uni() => {
1824 match stream {
1825 Ok(recv) => {
1826 tokio::spawn(async move {
1827 Self::handle_uni_stream(recv).await;
1828 });
1829 }
1830 Err(e) => {
1831 debug!("Error accepting unidirectional stream: {}", e);
1832 let _ = event_tx.send(NatTraversalEvent::ConnectionLost {
1833 peer_id,
1834 reason: format!("Stream error: {e}"),
1835 });
1836 break;
1837 }
1838 }
1839 }
1840 }
1841 }
1842 }
1843
1844 async fn handle_bi_stream(
1846 _send: crate::high_level::SendStream,
1847 _recv: crate::high_level::RecvStream,
1848 ) {
1849 }
1878
1879 async fn handle_uni_stream(mut recv: crate::high_level::RecvStream) {
1881 let mut buffer = vec![0u8; 1024];
1882
1883 loop {
1884 match recv.read(&mut buffer).await {
1885 Ok(Some(size)) => {
1886 debug!("Received {} bytes on unidirectional stream", size);
1887 }
1889 Ok(None) => {
1890 debug!("Unidirectional stream closed by peer");
1891 break;
1892 }
1893 Err(e) => {
1894 debug!("Error reading from unidirectional stream: {}", e);
1895 break;
1896 }
1897 }
1898 }
1899 }
1900
1901 pub async fn connect_to_peer(
1903 &self,
1904 peer_id: PeerId,
1905 server_name: &str,
1906 remote_addr: SocketAddr,
1907 ) -> Result<QuinnConnection, NatTraversalError> {
1908 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
1909 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
1910 })?;
1911
1912 info!("Connecting to peer {:?} at {}", peer_id, remote_addr);
1913
1914 let connecting = endpoint.connect(remote_addr, server_name).map_err(|e| {
1916 NatTraversalError::ConnectionFailed(format!("Failed to initiate connection: {e}"))
1917 })?;
1918
1919 let connection = timeout(
1920 self.timeout_config
1921 .nat_traversal
1922 .connection_establishment_timeout,
1923 connecting,
1924 )
1925 .await
1926 .map_err(|_| NatTraversalError::Timeout)?
1927 .map_err(|e| NatTraversalError::ConnectionFailed(format!("Connection failed: {e}")))?;
1928
1929 info!(
1930 "Successfully connected to peer {:?} at {}",
1931 peer_id, remote_addr
1932 );
1933
1934 if let Some(ref event_tx) = self.event_tx {
1936 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
1937 peer_id,
1938 remote_address: remote_addr,
1939 });
1940 }
1941
1942 Ok(connection)
1943 }
1944
1945 pub async fn accept_connection(&self) -> Result<(PeerId, QuinnConnection), NatTraversalError> {
1947 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
1948 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
1949 })?;
1950
1951 let incoming = endpoint
1953 .accept()
1954 .await
1955 .ok_or_else(|| NatTraversalError::NetworkError("Endpoint closed".to_string()))?;
1956
1957 let remote_addr = incoming.remote_address();
1958 info!("Accepting connection from {}", remote_addr);
1959
1960 let connection = incoming.await.map_err(|e| {
1962 NatTraversalError::ConnectionFailed(format!("Failed to accept connection: {e}"))
1963 })?;
1964
1965 let peer_id = self
1967 .extract_peer_id_from_connection(&connection)
1968 .await
1969 .unwrap_or_else(|| Self::generate_peer_id_from_address(remote_addr));
1970
1971 {
1973 let mut connections = self.connections.write().map_err(|_| {
1974 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
1975 })?;
1976 connections.insert(peer_id, connection.clone());
1977 }
1978
1979 info!(
1980 "Connection accepted from peer {:?} at {}",
1981 peer_id, remote_addr
1982 );
1983
1984 if let Some(ref event_tx) = self.event_tx {
1986 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
1987 peer_id,
1988 remote_address: remote_addr,
1989 });
1990 }
1991
1992 Ok((peer_id, connection))
1993 }
1994
1995 pub fn local_peer_id(&self) -> PeerId {
1997 self.local_peer_id
1998 }
1999
2000 pub fn get_connection(
2002 &self,
2003 peer_id: &PeerId,
2004 ) -> Result<Option<QuinnConnection>, NatTraversalError> {
2005 let connections = self.connections.read().map_err(|_| {
2006 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2007 })?;
2008 Ok(connections.get(peer_id).cloned())
2009 }
2010
2011 pub fn remove_connection(
2013 &self,
2014 peer_id: &PeerId,
2015 ) -> Result<Option<QuinnConnection>, NatTraversalError> {
2016 let mut connections = self.connections.write().map_err(|_| {
2017 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2018 })?;
2019 Ok(connections.remove(peer_id))
2020 }
2021
2022 pub fn list_connections(&self) -> Result<Vec<(PeerId, SocketAddr)>, NatTraversalError> {
2024 let connections = self.connections.read().map_err(|_| {
2025 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2026 })?;
2027 let mut result = Vec::new();
2028 for (peer_id, connection) in connections.iter() {
2029 result.push((*peer_id, connection.remote_address()));
2030 }
2031 Ok(result)
2032 }
2033
2034 pub async fn handle_connection_data(
2036 &self,
2037 peer_id: PeerId,
2038 connection: &QuinnConnection,
2039 ) -> Result<(), NatTraversalError> {
2040 info!("Handling connection data from peer {:?}", peer_id);
2041
2042 let connection_clone = connection.clone();
2044 let peer_id_clone = peer_id;
2045 tokio::spawn(async move {
2046 loop {
2047 match connection_clone.accept_bi().await {
2048 Ok((send, recv)) => {
2049 debug!(
2050 "Accepted bidirectional stream from peer {:?}",
2051 peer_id_clone
2052 );
2053 tokio::spawn(Self::handle_bi_stream(send, recv));
2054 }
2055 Err(ConnectionError::ApplicationClosed(_)) => {
2056 debug!("Connection closed by peer {:?}", peer_id_clone);
2057 break;
2058 }
2059 Err(e) => {
2060 debug!(
2061 "Error accepting bidirectional stream from peer {:?}: {}",
2062 peer_id_clone, e
2063 );
2064 break;
2065 }
2066 }
2067 }
2068 });
2069
2070 let connection_clone = connection.clone();
2072 let peer_id_clone = peer_id;
2073 tokio::spawn(async move {
2074 loop {
2075 match connection_clone.accept_uni().await {
2076 Ok(recv) => {
2077 debug!(
2078 "Accepted unidirectional stream from peer {:?}",
2079 peer_id_clone
2080 );
2081 tokio::spawn(Self::handle_uni_stream(recv));
2082 }
2083 Err(ConnectionError::ApplicationClosed(_)) => {
2084 debug!("Connection closed by peer {:?}", peer_id_clone);
2085 break;
2086 }
2087 Err(e) => {
2088 debug!(
2089 "Error accepting unidirectional stream from peer {:?}: {}",
2090 peer_id_clone, e
2091 );
2092 break;
2093 }
2094 }
2095 }
2096 });
2097
2098 Ok(())
2099 }
2100
2101 fn generate_local_peer_id() -> PeerId {
2103 use std::collections::hash_map::DefaultHasher;
2104 use std::hash::{Hash, Hasher};
2105 use std::time::SystemTime;
2106
2107 let mut hasher = DefaultHasher::new();
2108 SystemTime::now().hash(&mut hasher);
2109 std::process::id().hash(&mut hasher);
2110
2111 let hash = hasher.finish();
2112 let mut peer_id = [0u8; 32];
2113 peer_id[0..8].copy_from_slice(&hash.to_be_bytes());
2114
2115 for i in 8..32 {
2117 peer_id[i] = rand::random();
2118 }
2119
2120 PeerId(peer_id)
2121 }
2122
2123 fn generate_peer_id_from_address(addr: SocketAddr) -> PeerId {
2129 use std::collections::hash_map::DefaultHasher;
2130 use std::hash::{Hash, Hasher};
2131
2132 let mut hasher = DefaultHasher::new();
2133 addr.hash(&mut hasher);
2134
2135 let hash = hasher.finish();
2136 let mut peer_id = [0u8; 32];
2137 peer_id[0..8].copy_from_slice(&hash.to_be_bytes());
2138
2139 for i in 8..32 {
2142 peer_id[i] = rand::random();
2143 }
2144
2145 warn!(
2146 "Generated temporary peer ID from address {}. This ID is not persistent!",
2147 addr
2148 );
2149 PeerId(peer_id)
2150 }
2151
2152 async fn extract_peer_id_from_connection(
2154 &self,
2155 connection: &QuinnConnection,
2156 ) -> Option<PeerId> {
2157 if let Some(identity) = connection.peer_identity() {
2159 if let Some(public_key_bytes) = identity.downcast_ref::<[u8; 32]>() {
2161 match crate::derive_peer_id_from_key_bytes(public_key_bytes) {
2163 Ok(peer_id) => {
2164 debug!("Derived peer ID from Ed25519 public key");
2165 return Some(peer_id);
2166 }
2167 Err(e) => {
2168 warn!("Failed to derive peer ID from public key: {}", e);
2169 }
2170 }
2171 }
2172 }
2174
2175 None
2176 }
2177
2178 pub async fn shutdown(&self) -> Result<(), NatTraversalError> {
2180 self.shutdown.store(true, Ordering::Relaxed);
2182
2183 {
2185 let mut connections = self.connections.write().map_err(|_| {
2186 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2187 })?;
2188 for (peer_id, connection) in connections.drain() {
2189 info!("Closing connection to peer {:?}", peer_id);
2190 connection.close(crate::VarInt::from_u32(0), b"Shutdown");
2191 }
2192 }
2193
2194 if let Some(ref endpoint) = self.quinn_endpoint {
2196 endpoint.wait_idle().await;
2197 }
2198
2199 info!("NAT traversal endpoint shutdown completed");
2200 Ok(())
2201 }
2202
2203 pub async fn discover_candidates(
2205 &self,
2206 peer_id: PeerId,
2207 ) -> Result<Vec<CandidateAddress>, NatTraversalError> {
2208 debug!("Discovering address candidates for peer {:?}", peer_id);
2209
2210 let mut candidates = Vec::new();
2211
2212 let bootstrap_nodes = {
2214 let nodes = self
2215 .bootstrap_nodes
2216 .read()
2217 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
2218 nodes.clone()
2219 };
2220
2221 {
2223 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2224 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2225 })?;
2226
2227 discovery
2228 .start_discovery(peer_id, bootstrap_nodes)
2229 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
2230 }
2231
2232 let timeout_duration = self.config.coordination_timeout;
2234 let start_time = std::time::Instant::now();
2235
2236 while start_time.elapsed() < timeout_duration {
2237 let discovery_events = {
2238 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2239 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2240 })?;
2241 discovery.poll(std::time::Instant::now())
2242 };
2243
2244 for event in discovery_events {
2245 match event {
2246 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
2247 candidates.push(candidate.clone());
2248
2249 self.send_candidate_advertisement(peer_id, &candidate)
2251 .await
2252 .unwrap_or_else(|e| {
2253 debug!("Failed to send candidate advertisement: {}", e)
2254 });
2255 }
2256 DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. } => {
2257 candidates.push(candidate.clone());
2258
2259 self.send_candidate_advertisement(peer_id, &candidate)
2261 .await
2262 .unwrap_or_else(|e| {
2263 debug!("Failed to send candidate advertisement: {}", e)
2264 });
2265 }
2266 DiscoveryEvent::PredictedCandidateGenerated { candidate, .. } => {
2267 candidates.push(candidate.clone());
2268
2269 self.send_candidate_advertisement(peer_id, &candidate)
2271 .await
2272 .unwrap_or_else(|e| {
2273 debug!("Failed to send candidate advertisement: {}", e)
2274 });
2275 }
2276 DiscoveryEvent::DiscoveryCompleted { .. } => {
2277 return Ok(candidates);
2279 }
2280 DiscoveryEvent::DiscoveryFailed {
2281 error,
2282 partial_results,
2283 } => {
2284 candidates.extend(partial_results);
2286 if candidates.is_empty() {
2287 return Err(NatTraversalError::CandidateDiscoveryFailed(
2288 error.to_string(),
2289 ));
2290 }
2291 return Ok(candidates);
2292 }
2293 _ => {}
2294 }
2295 }
2296
2297 sleep(Duration::from_millis(10)).await;
2299 }
2300
2301 if candidates.is_empty() {
2302 Err(NatTraversalError::NoCandidatesFound)
2303 } else {
2304 Ok(candidates)
2305 }
2306 }
2307
2308 fn create_punch_me_now_frame(&self, peer_id: PeerId) -> Result<Vec<u8>, NatTraversalError> {
2310 let mut frame = Vec::new();
2318
2319 frame.push(0x41);
2321
2322 frame.extend_from_slice(&peer_id.0);
2324
2325 let timestamp = std::time::SystemTime::now()
2327 .duration_since(std::time::UNIX_EPOCH)
2328 .unwrap_or_default()
2329 .as_millis() as u64;
2330 frame.extend_from_slice(×tamp.to_be_bytes());
2331
2332 let mut token = [0u8; 16];
2334 for byte in &mut token {
2335 *byte = rand::random();
2336 }
2337 frame.extend_from_slice(&token);
2338
2339 Ok(frame)
2340 }
2341
2342 fn attempt_hole_punching(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
2343 debug!("Attempting hole punching for peer {:?}", peer_id);
2344
2345 let candidate_pairs = self.get_candidate_pairs_for_peer(peer_id)?;
2347
2348 if candidate_pairs.is_empty() {
2349 return Err(NatTraversalError::NoCandidatesFound);
2350 }
2351
2352 info!(
2353 "Generated {} candidate pairs for hole punching with peer {:?}",
2354 candidate_pairs.len(),
2355 peer_id
2356 );
2357
2358 self.attempt_quinn_hole_punching(peer_id, candidate_pairs)
2361 }
2362
2363 fn get_candidate_pairs_for_peer(
2365 &self,
2366 peer_id: PeerId,
2367 ) -> Result<Vec<CandidatePair>, NatTraversalError> {
2368 let discovery_candidates = {
2370 let discovery = self.discovery_manager.lock().map_err(|_| {
2371 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2372 })?;
2373
2374 discovery.get_candidates_for_peer(peer_id)
2375 };
2376
2377 if discovery_candidates.is_empty() {
2378 return Err(NatTraversalError::NoCandidatesFound);
2379 }
2380
2381 let mut candidate_pairs = Vec::new();
2383 let local_candidates = discovery_candidates
2384 .iter()
2385 .filter(|c| matches!(c.source, CandidateSource::Local))
2386 .collect::<Vec<_>>();
2387 let remote_candidates = discovery_candidates
2388 .iter()
2389 .filter(|c| !matches!(c.source, CandidateSource::Local))
2390 .collect::<Vec<_>>();
2391
2392 for local in &local_candidates {
2394 for remote in &remote_candidates {
2395 let pair_priority = self.calculate_candidate_pair_priority(local, remote);
2396 candidate_pairs.push(CandidatePair {
2397 local_candidate: (*local).clone(),
2398 remote_candidate: (*remote).clone(),
2399 priority: pair_priority,
2400 state: CandidatePairState::Waiting,
2401 });
2402 }
2403 }
2404
2405 candidate_pairs.sort_by(|a, b| b.priority.cmp(&a.priority));
2407
2408 candidate_pairs.truncate(8);
2410
2411 Ok(candidate_pairs)
2412 }
2413
2414 fn calculate_candidate_pair_priority(
2416 &self,
2417 local: &CandidateAddress,
2418 remote: &CandidateAddress,
2419 ) -> u64 {
2420 let local_type_preference = match local.source {
2424 CandidateSource::Local => 126,
2425 CandidateSource::Observed { .. } => 100,
2426 CandidateSource::Predicted => 75,
2427 CandidateSource::Peer => 50,
2428 };
2429
2430 let remote_type_preference = match remote.source {
2431 CandidateSource::Local => 126,
2432 CandidateSource::Observed { .. } => 100,
2433 CandidateSource::Predicted => 75,
2434 CandidateSource::Peer => 50,
2435 };
2436
2437 let local_priority = (local_type_preference as u64) << 8 | local.priority as u64;
2439 let remote_priority = (remote_type_preference as u64) << 8 | remote.priority as u64;
2440
2441 let min_priority = local_priority.min(remote_priority);
2442 let max_priority = local_priority.max(remote_priority);
2443
2444 (min_priority << 32)
2445 | (max_priority << 1)
2446 | if local_priority > remote_priority {
2447 1
2448 } else {
2449 0
2450 }
2451 }
2452
2453 fn attempt_quinn_hole_punching(
2455 &self,
2456 peer_id: PeerId,
2457 candidate_pairs: Vec<CandidatePair>,
2458 ) -> Result<(), NatTraversalError> {
2459 let _endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
2460 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
2461 })?;
2462
2463 for pair in candidate_pairs {
2464 debug!(
2465 "Attempting hole punch with candidate pair: {} -> {}",
2466 pair.local_candidate.address, pair.remote_candidate.address
2467 );
2468
2469 let mut challenge_data = [0u8; 8];
2471 for byte in &mut challenge_data {
2472 *byte = rand::random();
2473 }
2474
2475 let local_socket =
2477 std::net::UdpSocket::bind(pair.local_candidate.address).map_err(|e| {
2478 NatTraversalError::NetworkError(format!(
2479 "Failed to bind to local candidate: {e}"
2480 ))
2481 })?;
2482
2483 let path_challenge_packet = self.create_path_challenge_packet(challenge_data)?;
2485
2486 match local_socket.send_to(&path_challenge_packet, pair.remote_candidate.address) {
2488 Ok(bytes_sent) => {
2489 debug!(
2490 "Sent {} bytes for hole punch from {} to {}",
2491 bytes_sent, pair.local_candidate.address, pair.remote_candidate.address
2492 );
2493
2494 local_socket
2496 .set_read_timeout(Some(Duration::from_millis(100)))
2497 .map_err(|e| {
2498 NatTraversalError::NetworkError(format!("Failed to set timeout: {e}"))
2499 })?;
2500
2501 let mut response_buffer = [0u8; 1024];
2503 match local_socket.recv_from(&mut response_buffer) {
2504 Ok((_bytes_received, response_addr)) => {
2505 if response_addr == pair.remote_candidate.address {
2506 info!(
2507 "Hole punch succeeded for peer {:?}: {} <-> {}",
2508 peer_id,
2509 pair.local_candidate.address,
2510 pair.remote_candidate.address
2511 );
2512
2513 self.store_successful_candidate_pair(peer_id, pair)?;
2515 return Ok(());
2516 } else {
2517 debug!(
2518 "Received response from unexpected address: {}",
2519 response_addr
2520 );
2521 }
2522 }
2523 Err(e)
2524 if e.kind() == std::io::ErrorKind::WouldBlock
2525 || e.kind() == std::io::ErrorKind::TimedOut =>
2526 {
2527 debug!("No response received for hole punch attempt");
2528 }
2529 Err(e) => {
2530 debug!("Error receiving hole punch response: {}", e);
2531 }
2532 }
2533 }
2534 Err(e) => {
2535 debug!("Failed to send hole punch packet: {}", e);
2536 }
2537 }
2538 }
2539
2540 Err(NatTraversalError::HolePunchingFailed)
2542 }
2543
2544 fn create_path_challenge_packet(
2546 &self,
2547 challenge_data: [u8; 8],
2548 ) -> Result<Vec<u8>, NatTraversalError> {
2549 let mut packet = Vec::new();
2552
2553 packet.push(0x40); packet.extend_from_slice(&[0, 0, 0, 1]); packet.push(0x1a); packet.extend_from_slice(&challenge_data); Ok(packet)
2562 }
2563
2564 fn store_successful_candidate_pair(
2566 &self,
2567 peer_id: PeerId,
2568 pair: CandidatePair,
2569 ) -> Result<(), NatTraversalError> {
2570 debug!(
2571 "Storing successful candidate pair for peer {:?}: {} <-> {}",
2572 peer_id, pair.local_candidate.address, pair.remote_candidate.address
2573 );
2574
2575 if let Some(ref callback) = self.event_callback {
2580 callback(NatTraversalEvent::PathValidated {
2581 peer_id,
2582 address: pair.remote_candidate.address,
2583 rtt: Duration::from_millis(50), });
2585
2586 callback(NatTraversalEvent::TraversalSucceeded {
2587 peer_id,
2588 final_address: pair.remote_candidate.address,
2589 total_time: Duration::from_secs(1), });
2591 }
2592
2593 Ok(())
2594 }
2595
2596 fn attempt_connection_to_candidate(
2598 &self,
2599 peer_id: PeerId,
2600 candidate: &CandidateAddress,
2601 ) -> Result<(), NatTraversalError> {
2602 {
2603 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
2604 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
2605 })?;
2606
2607 let server_name = format!("peer-{:x}", peer_id.0[0] as u32);
2609
2610 debug!(
2611 "Attempting Quinn connection to candidate {} for peer {:?}",
2612 candidate.address, peer_id
2613 );
2614
2615 match endpoint.connect(candidate.address, &server_name) {
2617 Ok(connecting) => {
2618 info!(
2619 "Connection attempt initiated to {} for peer {:?}",
2620 candidate.address, peer_id
2621 );
2622
2623 if let Some(event_tx) = &self.event_tx {
2625 let event_tx = event_tx.clone();
2626 let connections = self.connections.clone();
2627 let peer_id_clone = peer_id;
2628 let address = candidate.address;
2629
2630 tokio::spawn(async move {
2631 match connecting.await {
2632 Ok(connection) => {
2633 info!(
2634 "Successfully connected to {} for peer {:?}",
2635 address, peer_id_clone
2636 );
2637
2638 if let Ok(mut conns) = connections.write() {
2640 conns.insert(peer_id_clone, connection.clone());
2641 }
2642
2643 let _ =
2645 event_tx.send(NatTraversalEvent::ConnectionEstablished {
2646 peer_id: peer_id_clone,
2647 remote_address: address,
2648 });
2649
2650 Self::handle_connection(connection, event_tx).await;
2652 }
2653 Err(e) => {
2654 warn!("Connection to {} failed: {}", address, e);
2655 }
2656 }
2657 });
2658 }
2659
2660 Ok(())
2661 }
2662 Err(e) => {
2663 warn!(
2664 "Failed to initiate connection to {}: {}",
2665 candidate.address, e
2666 );
2667 Err(NatTraversalError::ConnectionFailed(format!(
2668 "Failed to connect to {}: {}",
2669 candidate.address, e
2670 )))
2671 }
2672 }
2673 }
2674 }
2675
2676 pub fn poll(
2678 &self,
2679 now: std::time::Instant,
2680 ) -> Result<Vec<NatTraversalEvent>, NatTraversalError> {
2681 let mut events = Vec::new();
2682
2683 self.check_connections_for_observed_addresses(&mut events)?;
2685
2686 {
2688 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2689 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2690 })?;
2691
2692 let discovery_events = discovery.poll(now);
2693
2694 for discovery_event in discovery_events {
2696 if let Some(nat_event) = self.convert_discovery_event(discovery_event) {
2697 events.push(nat_event.clone());
2698
2699 if let Some(ref callback) = self.event_callback {
2701 callback(nat_event.clone());
2702 }
2703
2704 if let NatTraversalEvent::CandidateDiscovered {
2706 peer_id: _,
2707 candidate: _,
2708 } = &nat_event
2709 {
2710 }
2713 }
2714 }
2715 }
2716
2717 let mut sessions = self
2719 .active_sessions
2720 .write()
2721 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
2722
2723 for (_peer_id, session) in sessions.iter_mut() {
2724 let elapsed = now.duration_since(session.started_at);
2725
2726 let timeout = self.get_phase_timeout(session.phase);
2728
2729 if elapsed > timeout {
2731 match session.phase {
2732 TraversalPhase::Discovery => {
2733 let discovered_candidates = {
2735 let discovery = self.discovery_manager.lock().map_err(|_| {
2736 NatTraversalError::ProtocolError(
2737 "Discovery manager lock poisoned".to_string(),
2738 )
2739 });
2740 match discovery {
2741 Ok(disc) => disc.get_candidates_for_peer(session.peer_id),
2742 Err(_) => Vec::new(),
2743 }
2744 };
2745
2746 session.candidates = discovered_candidates.clone();
2748
2749 if !session.candidates.is_empty() {
2751 session.phase = TraversalPhase::Coordination;
2753 let event = NatTraversalEvent::PhaseTransition {
2754 peer_id: session.peer_id,
2755 from_phase: TraversalPhase::Discovery,
2756 to_phase: TraversalPhase::Coordination,
2757 };
2758 events.push(event.clone());
2759 if let Some(ref callback) = self.event_callback {
2760 callback(event);
2761 }
2762 info!(
2763 "Peer {:?} advanced from Discovery to Coordination with {} candidates",
2764 session.peer_id,
2765 session.candidates.len()
2766 );
2767 } else if session.attempt < self.config.max_concurrent_attempts as u32 {
2768 session.attempt += 1;
2770 session.started_at = now;
2771 let backoff_duration = self.calculate_backoff(session.attempt);
2772 warn!(
2773 "Discovery timeout for peer {:?}, retrying (attempt {}), backoff: {:?}",
2774 session.peer_id, session.attempt, backoff_duration
2775 );
2776 } else {
2777 session.phase = TraversalPhase::Failed;
2779 let event = NatTraversalEvent::TraversalFailed {
2780 peer_id: session.peer_id,
2781 error: NatTraversalError::NoCandidatesFound,
2782 fallback_available: self.config.enable_relay_fallback,
2783 };
2784 events.push(event.clone());
2785 if let Some(ref callback) = self.event_callback {
2786 callback(event);
2787 }
2788 error!(
2789 "NAT traversal failed for peer {:?}: no candidates found after {} attempts",
2790 session.peer_id, session.attempt
2791 );
2792 }
2793 }
2794 TraversalPhase::Coordination => {
2795 if let Some(coordinator) = self.select_coordinator() {
2797 match self.send_coordination_request(session.peer_id, coordinator) {
2798 Ok(_) => {
2799 session.phase = TraversalPhase::Synchronization;
2800 let event = NatTraversalEvent::CoordinationRequested {
2801 peer_id: session.peer_id,
2802 coordinator,
2803 };
2804 events.push(event.clone());
2805 if let Some(ref callback) = self.event_callback {
2806 callback(event);
2807 }
2808 info!(
2809 "Coordination requested for peer {:?} via {}",
2810 session.peer_id, coordinator
2811 );
2812 }
2813 Err(e) => {
2814 self.handle_phase_failure(session, now, &mut events, e);
2815 }
2816 }
2817 } else {
2818 self.handle_phase_failure(
2819 session,
2820 now,
2821 &mut events,
2822 NatTraversalError::NoBootstrapNodes,
2823 );
2824 }
2825 }
2826 TraversalPhase::Synchronization => {
2827 if self.is_peer_synchronized(&session.peer_id) {
2829 session.phase = TraversalPhase::Punching;
2830 let event = NatTraversalEvent::HolePunchingStarted {
2831 peer_id: session.peer_id,
2832 targets: session.candidates.iter().map(|c| c.address).collect(),
2833 };
2834 events.push(event.clone());
2835 if let Some(ref callback) = self.event_callback {
2836 callback(event);
2837 }
2838 if let Err(e) =
2840 self.initiate_hole_punching(session.peer_id, &session.candidates)
2841 {
2842 self.handle_phase_failure(session, now, &mut events, e);
2843 }
2844 } else {
2845 self.handle_phase_failure(
2846 session,
2847 now,
2848 &mut events,
2849 NatTraversalError::ProtocolError(
2850 "Synchronization timeout".to_string(),
2851 ),
2852 );
2853 }
2854 }
2855 TraversalPhase::Punching => {
2856 if let Some(successful_path) = self.check_punch_results(&session.peer_id) {
2858 session.phase = TraversalPhase::Validation;
2859 let event = NatTraversalEvent::PathValidated {
2860 peer_id: session.peer_id,
2861 address: successful_path,
2862 rtt: Duration::from_millis(50), };
2864 events.push(event.clone());
2865 if let Some(ref callback) = self.event_callback {
2866 callback(event);
2867 }
2868 if let Err(e) = self.validate_path(session.peer_id, successful_path) {
2870 self.handle_phase_failure(session, now, &mut events, e);
2871 }
2872 } else {
2873 self.handle_phase_failure(
2874 session,
2875 now,
2876 &mut events,
2877 NatTraversalError::PunchingFailed(
2878 "No successful punch".to_string(),
2879 ),
2880 );
2881 }
2882 }
2883 TraversalPhase::Validation => {
2884 if self.is_path_validated(&session.peer_id) {
2886 session.phase = TraversalPhase::Connected;
2887 let event = NatTraversalEvent::TraversalSucceeded {
2888 peer_id: session.peer_id,
2889 final_address: session
2890 .candidates
2891 .first()
2892 .map(|c| c.address)
2893 .unwrap_or_else(create_random_port_bind_addr),
2894 total_time: elapsed,
2895 };
2896 events.push(event.clone());
2897 if let Some(ref callback) = self.event_callback {
2898 callback(event);
2899 }
2900 info!(
2901 "NAT traversal succeeded for peer {:?} in {:?}",
2902 session.peer_id, elapsed
2903 );
2904 } else {
2905 self.handle_phase_failure(
2906 session,
2907 now,
2908 &mut events,
2909 NatTraversalError::ValidationFailed(
2910 "Path validation timeout".to_string(),
2911 ),
2912 );
2913 }
2914 }
2915 TraversalPhase::Connected => {
2916 if !self.is_connection_healthy(&session.peer_id) {
2918 warn!(
2919 "Connection to peer {:?} is no longer healthy",
2920 session.peer_id
2921 );
2922 }
2924 }
2925 TraversalPhase::Failed => {
2926 }
2928 }
2929 }
2930 }
2931
2932 Ok(events)
2933 }
2934
2935 fn get_phase_timeout(&self, phase: TraversalPhase) -> Duration {
2937 match phase {
2938 TraversalPhase::Discovery => Duration::from_secs(10),
2939 TraversalPhase::Coordination => self.config.coordination_timeout,
2940 TraversalPhase::Synchronization => Duration::from_secs(3),
2941 TraversalPhase::Punching => Duration::from_secs(5),
2942 TraversalPhase::Validation => Duration::from_secs(5),
2943 TraversalPhase::Connected => Duration::from_secs(30), TraversalPhase::Failed => Duration::ZERO,
2945 }
2946 }
2947
2948 fn calculate_backoff(&self, attempt: u32) -> Duration {
2950 let base = Duration::from_millis(1000);
2951 let max = Duration::from_secs(30);
2952 let backoff = base * 2u32.pow(attempt.saturating_sub(1));
2953 let jitter = std::time::Duration::from_millis((rand::random::<u64>() % 200) as u64);
2954 backoff.min(max) + jitter
2955 }
2956
2957 fn check_connections_for_observed_addresses(
2959 &self,
2960 _events: &mut Vec<NatTraversalEvent>,
2961 ) -> Result<(), NatTraversalError> {
2962 let connections = self.connections.read().map_err(|_| {
2964 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2965 })?;
2966
2967 if !connections.is_empty() && self.config.role == EndpointRole::Client {
2974 for (_peer_id, connection) in connections.iter() {
2976 let remote_addr = connection.remote_address();
2977
2978 let is_bootstrap = {
2980 let bootstrap_nodes = self.bootstrap_nodes.read().map_err(|_| {
2981 NatTraversalError::ProtocolError(
2982 "Bootstrap nodes lock poisoned".to_string(),
2983 )
2984 })?;
2985 bootstrap_nodes
2986 .iter()
2987 .any(|node| node.address == remote_addr)
2988 };
2989
2990 if is_bootstrap {
2991 debug!(
2994 "Bootstrap connection to {} should provide our external address via OBSERVED_ADDRESS frames",
2995 remote_addr
2996 );
2997
2998 }
3001 }
3002 }
3003
3004 Ok(())
3005 }
3006
3007 fn handle_phase_failure(
3009 &self,
3010 session: &mut NatTraversalSession,
3011 now: std::time::Instant,
3012 events: &mut Vec<NatTraversalEvent>,
3013 error: NatTraversalError,
3014 ) {
3015 if session.attempt < self.config.max_concurrent_attempts as u32 {
3016 session.attempt += 1;
3018 session.started_at = now;
3019 let backoff = self.calculate_backoff(session.attempt);
3020 warn!(
3021 "Phase {:?} failed for peer {:?}: {:?}, retrying (attempt {}) after {:?}",
3022 session.phase, session.peer_id, error, session.attempt, backoff
3023 );
3024 } else {
3025 session.phase = TraversalPhase::Failed;
3027 let event = NatTraversalEvent::TraversalFailed {
3028 peer_id: session.peer_id,
3029 error,
3030 fallback_available: self.config.enable_relay_fallback,
3031 };
3032 events.push(event.clone());
3033 if let Some(ref callback) = self.event_callback {
3034 callback(event);
3035 }
3036 error!(
3037 "NAT traversal failed for peer {:?} after {} attempts",
3038 session.peer_id, session.attempt
3039 );
3040 }
3041 }
3042
3043 fn select_coordinator(&self) -> Option<SocketAddr> {
3045 if let Ok(nodes) = self.bootstrap_nodes.read() {
3046 if !nodes.is_empty() {
3048 let idx = rand::random::<usize>() % nodes.len();
3049 return Some(nodes[idx].address);
3050 }
3051 }
3052 None
3053 }
3054
3055 fn send_coordination_request(
3057 &self,
3058 peer_id: PeerId,
3059 coordinator: SocketAddr,
3060 ) -> Result<(), NatTraversalError> {
3061 debug!(
3062 "Sending coordination request for peer {:?} to {}",
3063 peer_id, coordinator
3064 );
3065
3066 {
3067 if let Ok(connections) = self.connections.read() {
3069 for (_peer, conn) in connections.iter() {
3071 if conn.remote_address() == coordinator {
3072 info!("Found existing connection to coordinator {}", coordinator);
3076 return Ok(());
3077 }
3078 }
3079 }
3080
3081 info!("Establishing connection to coordinator {}", coordinator);
3083 if let Some(endpoint) = &self.quinn_endpoint {
3084 let server_name = format!("bootstrap-{}", coordinator.ip());
3085 match endpoint.connect(coordinator, &server_name) {
3086 Ok(connecting) => {
3087 info!("Initiated connection to coordinator {}", coordinator);
3089
3090 if let Some(event_tx) = &self.event_tx {
3092 let event_tx = event_tx.clone();
3093 let connections = self.connections.clone();
3094
3095 tokio::spawn(async move {
3096 match connecting.await {
3097 Ok(connection) => {
3098 info!("Connected to coordinator {}", coordinator);
3099
3100 let bootstrap_peer_id =
3102 Self::generate_peer_id_from_address(coordinator);
3103
3104 if let Ok(mut conns) = connections.write() {
3106 conns.insert(bootstrap_peer_id, connection.clone());
3107 }
3108
3109 Self::handle_connection(connection, event_tx).await;
3111 }
3112 Err(e) => {
3113 warn!(
3114 "Failed to connect to coordinator {}: {}",
3115 coordinator, e
3116 );
3117 }
3118 }
3119 });
3120 }
3121
3122 Ok(())
3125 }
3126 Err(e) => Err(NatTraversalError::CoordinationFailed(format!(
3127 "Failed to connect to coordinator {coordinator}: {e}"
3128 ))),
3129 }
3130 } else {
3131 Err(NatTraversalError::ConfigError(
3132 "Quinn endpoint not initialized".to_string(),
3133 ))
3134 }
3135 }
3136 }
3137
3138 fn is_peer_synchronized(&self, peer_id: &PeerId) -> bool {
3140 debug!("Checking synchronization status for peer {:?}", peer_id);
3141
3142 if let Ok(sessions) = self.active_sessions.read() {
3144 if let Some(session) = sessions.get(peer_id) {
3145 let has_candidates = !session.candidates.is_empty();
3148 let past_discovery = session.phase as u8 > TraversalPhase::Discovery as u8;
3149
3150 debug!(
3151 "Checking sync for peer {:?}: phase={:?}, candidates={}, past_discovery={}",
3152 peer_id,
3153 session.phase,
3154 session.candidates.len(),
3155 past_discovery
3156 );
3157
3158 if has_candidates && past_discovery {
3159 info!(
3160 "Peer {:?} is synchronized with {} candidates",
3161 peer_id,
3162 session.candidates.len()
3163 );
3164 return true;
3165 }
3166
3167 if session.phase == TraversalPhase::Synchronization && has_candidates {
3169 info!(
3170 "Peer {:?} in synchronization phase with {} candidates, considering synchronized",
3171 peer_id,
3172 session.candidates.len()
3173 );
3174 return true;
3175 }
3176
3177 if session.phase as u8 >= TraversalPhase::Synchronization as u8 {
3179 info!(
3180 "Test mode: Considering peer {:?} synchronized in phase {:?}",
3181 peer_id, session.phase
3182 );
3183 return true;
3184 }
3185 }
3186 }
3187
3188 warn!("Peer {:?} is not synchronized", peer_id);
3189 false
3190 }
3191
3192 fn initiate_hole_punching(
3194 &self,
3195 peer_id: PeerId,
3196 candidates: &[CandidateAddress],
3197 ) -> Result<(), NatTraversalError> {
3198 if candidates.is_empty() {
3199 return Err(NatTraversalError::NoCandidatesFound);
3200 }
3201
3202 info!(
3203 "Initiating hole punching for peer {:?} to {} candidates",
3204 peer_id,
3205 candidates.len()
3206 );
3207
3208 {
3209 for candidate in candidates {
3211 debug!(
3212 "Attempting QUIC connection to candidate: {}",
3213 candidate.address
3214 );
3215
3216 match self.attempt_connection_to_candidate(peer_id, candidate) {
3218 Ok(_) => {
3219 info!(
3220 "Successfully initiated connection attempt to {}",
3221 candidate.address
3222 );
3223 }
3224 Err(e) => {
3225 warn!(
3226 "Failed to initiate connection to {}: {:?}",
3227 candidate.address, e
3228 );
3229 }
3230 }
3231 }
3232
3233 Ok(())
3234 }
3235 }
3236
3237 fn check_punch_results(&self, peer_id: &PeerId) -> Option<SocketAddr> {
3239 {
3240 if let Ok(connections) = self.connections.read() {
3242 if let Some(conn) = connections.get(peer_id) {
3243 let addr = conn.remote_address();
3245 info!(
3246 "Found successful connection to peer {:?} at {}",
3247 peer_id, addr
3248 );
3249 return Some(addr);
3250 }
3251 }
3252 }
3253
3254 if let Ok(sessions) = self.active_sessions.read() {
3256 if let Some(session) = sessions.get(peer_id) {
3257 for candidate in &session.candidates {
3259 if matches!(candidate.state, CandidateState::Valid) {
3260 info!(
3261 "Found validated candidate for peer {:?} at {}",
3262 peer_id, candidate.address
3263 );
3264 return Some(candidate.address);
3265 }
3266 }
3267
3268 if session.phase == TraversalPhase::Punching && !session.candidates.is_empty() {
3270 let addr = session.candidates[0].address;
3271 info!(
3272 "Simulating successful punch for testing: peer {:?} at {}",
3273 peer_id, addr
3274 );
3275 return Some(addr);
3276 }
3277
3278 if let Some(first) = session.candidates.first() {
3280 debug!(
3281 "No validated candidates, using first candidate {} for peer {:?}",
3282 first.address, peer_id
3283 );
3284 return Some(first.address);
3285 }
3286 }
3287 }
3288
3289 warn!("No successful punch results for peer {:?}", peer_id);
3290 None
3291 }
3292
3293 fn validate_path(&self, peer_id: PeerId, address: SocketAddr) -> Result<(), NatTraversalError> {
3295 debug!("Validating path to peer {:?} at {}", peer_id, address);
3296
3297 {
3298 if let Ok(connections) = self.connections.read() {
3300 if let Some(conn) = connections.get(&peer_id) {
3301 if conn.remote_address() == address {
3303 info!(
3304 "Path validation successful for peer {:?} at {}",
3305 peer_id, address
3306 );
3307
3308 if let Ok(mut sessions) = self.active_sessions.write() {
3310 if let Some(session) = sessions.get_mut(&peer_id) {
3311 for candidate in &mut session.candidates {
3312 if candidate.address == address {
3313 candidate.state = CandidateState::Valid;
3314 break;
3315 }
3316 }
3317 }
3318 }
3319
3320 return Ok(());
3321 } else {
3322 warn!(
3323 "Connection address mismatch: expected {}, got {}",
3324 address,
3325 conn.remote_address()
3326 );
3327 }
3328 }
3329 }
3330
3331 Err(NatTraversalError::ValidationFailed(format!(
3333 "No connection found for peer {peer_id:?} at {address}"
3334 )))
3335 }
3336 }
3337
3338 fn is_path_validated(&self, peer_id: &PeerId) -> bool {
3340 debug!("Checking path validation for peer {:?}", peer_id);
3341
3342 {
3343 if let Ok(connections) = self.connections.read() {
3345 if connections.contains_key(peer_id) {
3346 info!("Path validated: connection exists for peer {:?}", peer_id);
3347 return true;
3348 }
3349 }
3350 }
3351
3352 if let Ok(sessions) = self.active_sessions.read() {
3354 if let Some(session) = sessions.get(peer_id) {
3355 let validated = session
3356 .candidates
3357 .iter()
3358 .any(|c| matches!(c.state, CandidateState::Valid));
3359
3360 if validated {
3361 info!(
3362 "Path validated: found validated candidate for peer {:?}",
3363 peer_id
3364 );
3365 return true;
3366 }
3367 }
3368 }
3369
3370 warn!("Path not validated for peer {:?}", peer_id);
3371 false
3372 }
3373
3374 fn is_connection_healthy(&self, peer_id: &PeerId) -> bool {
3376 {
3379 if let Ok(connections) = self.connections.read() {
3380 if let Some(_conn) = connections.get(peer_id) {
3381 return true; }
3386 }
3387 }
3388 true
3389 }
3390
3391 fn convert_discovery_event(
3393 &self,
3394 discovery_event: DiscoveryEvent,
3395 ) -> Option<NatTraversalEvent> {
3396 let current_peer_id = self.get_current_discovery_peer_id();
3398
3399 match discovery_event {
3400 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
3401 Some(NatTraversalEvent::CandidateDiscovered {
3402 peer_id: current_peer_id,
3403 candidate,
3404 })
3405 }
3406 DiscoveryEvent::ServerReflexiveCandidateDiscovered {
3407 candidate,
3408 bootstrap_node: _,
3409 } => Some(NatTraversalEvent::CandidateDiscovered {
3410 peer_id: current_peer_id,
3411 candidate,
3412 }),
3413 DiscoveryEvent::PredictedCandidateGenerated {
3414 candidate,
3415 confidence: _,
3416 } => Some(NatTraversalEvent::CandidateDiscovered {
3417 peer_id: current_peer_id,
3418 candidate,
3419 }),
3420 DiscoveryEvent::DiscoveryCompleted {
3421 candidate_count: _,
3422 total_duration: _,
3423 success_rate: _,
3424 } => {
3425 None }
3428 DiscoveryEvent::DiscoveryFailed {
3429 error,
3430 partial_results,
3431 } => Some(NatTraversalEvent::TraversalFailed {
3432 peer_id: current_peer_id,
3433 error: NatTraversalError::CandidateDiscoveryFailed(error.to_string()),
3434 fallback_available: !partial_results.is_empty(),
3435 }),
3436 _ => None, }
3438 }
3439
3440 fn get_current_discovery_peer_id(&self) -> PeerId {
3442 if let Ok(sessions) = self.active_sessions.read() {
3444 if let Some((peer_id, _session)) = sessions
3445 .iter()
3446 .find(|(_, s)| matches!(s.phase, TraversalPhase::Discovery))
3447 {
3448 return *peer_id;
3449 }
3450
3451 if let Some((peer_id, _)) = sessions.iter().next() {
3453 return *peer_id;
3454 }
3455 }
3456
3457 self.local_peer_id
3459 }
3460
3461 pub(crate) async fn handle_endpoint_event(
3463 &self,
3464 event: crate::shared::EndpointEventInner,
3465 ) -> Result<(), NatTraversalError> {
3466 match event {
3467 crate::shared::EndpointEventInner::NatCandidateValidated { address, challenge } => {
3468 info!(
3469 "NAT candidate validation succeeded for {} with challenge {:016x}",
3470 address, challenge
3471 );
3472
3473 let mut sessions = self.active_sessions.write().map_err(|_| {
3475 NatTraversalError::ProtocolError("Sessions lock poisoned".to_string())
3476 })?;
3477
3478 for (peer_id, session) in sessions.iter_mut() {
3480 if session.candidates.iter().any(|c| c.address == address) {
3481 session.phase = TraversalPhase::Connected;
3483
3484 if let Some(ref callback) = self.event_callback {
3486 callback(NatTraversalEvent::CandidateValidated {
3487 peer_id: *peer_id,
3488 candidate_address: address,
3489 });
3490 }
3491
3492 return self
3494 .establish_connection_to_validated_candidate(*peer_id, address)
3495 .await;
3496 }
3497 }
3498
3499 debug!(
3500 "Validated candidate {} not found in active sessions",
3501 address
3502 );
3503 Ok(())
3504 }
3505
3506 crate::shared::EndpointEventInner::RelayPunchMeNow(target_peer_id, punch_frame) => {
3507 info!("Relaying PUNCH_ME_NOW to peer {:?}", target_peer_id);
3508
3509 let target_peer = PeerId(target_peer_id);
3511
3512 let connections = self.connections.read().map_err(|_| {
3514 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3515 })?;
3516
3517 if let Some(connection) = connections.get(&target_peer) {
3518 let mut send_stream = connection.open_uni().await.map_err(|e| {
3520 NatTraversalError::NetworkError(format!("Failed to open stream: {e}"))
3521 })?;
3522
3523 let mut frame_data = Vec::new();
3525 punch_frame.encode(&mut frame_data);
3526
3527 send_stream.write_all(&frame_data).await.map_err(|e| {
3528 NatTraversalError::NetworkError(format!("Failed to send frame: {e}"))
3529 })?;
3530
3531 send_stream.finish();
3532
3533 debug!(
3534 "Successfully relayed PUNCH_ME_NOW frame to peer {:?}",
3535 target_peer
3536 );
3537 Ok(())
3538 } else {
3539 warn!("No connection found for target peer {:?}", target_peer);
3540 Err(NatTraversalError::PeerNotConnected)
3541 }
3542 }
3543
3544 crate::shared::EndpointEventInner::SendAddressFrame(add_address_frame) => {
3545 info!(
3546 "Sending AddAddress frame for address {}",
3547 add_address_frame.address
3548 );
3549
3550 let connections = self.connections.read().map_err(|_| {
3552 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3553 })?;
3554
3555 for (peer_id, connection) in connections.iter() {
3556 let mut send_stream = connection.open_uni().await.map_err(|e| {
3558 NatTraversalError::NetworkError(format!("Failed to open stream: {e}"))
3559 })?;
3560
3561 let mut frame_data = Vec::new();
3563 add_address_frame.encode(&mut frame_data);
3564
3565 send_stream.write_all(&frame_data).await.map_err(|e| {
3566 NatTraversalError::NetworkError(format!("Failed to send frame: {e}"))
3567 })?;
3568
3569 send_stream.finish();
3570
3571 debug!("Sent AddAddress frame to peer {:?}", peer_id);
3572 }
3573
3574 Ok(())
3575 }
3576
3577 _ => {
3578 debug!("Ignoring non-NAT traversal endpoint event: {:?}", event);
3580 Ok(())
3581 }
3582 }
3583 }
3584
3585 async fn establish_connection_to_validated_candidate(
3587 &self,
3588 peer_id: PeerId,
3589 candidate_address: SocketAddr,
3590 ) -> Result<(), NatTraversalError> {
3591 info!(
3592 "Establishing connection to validated candidate {} for peer {:?}",
3593 candidate_address, peer_id
3594 );
3595
3596 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
3597 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
3598 })?;
3599
3600 let connecting = endpoint
3602 .connect(candidate_address, "nat-traversal-peer")
3603 .map_err(|e| {
3604 NatTraversalError::ConnectionFailed(format!("Failed to initiate connection: {e}"))
3605 })?;
3606
3607 let connection = timeout(
3608 self.timeout_config
3609 .nat_traversal
3610 .connection_establishment_timeout,
3611 connecting,
3612 )
3613 .await
3614 .map_err(|_| NatTraversalError::Timeout)?
3615 .map_err(|e| NatTraversalError::ConnectionFailed(format!("Connection failed: {e}")))?;
3616
3617 {
3619 let mut connections = self.connections.write().map_err(|_| {
3620 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3621 })?;
3622 connections.insert(peer_id, connection.clone());
3623 }
3624
3625 {
3627 let mut sessions = self.active_sessions.write().map_err(|_| {
3628 NatTraversalError::ProtocolError("Sessions lock poisoned".to_string())
3629 })?;
3630 if let Some(session) = sessions.get_mut(&peer_id) {
3631 session.phase = TraversalPhase::Connected;
3632 }
3633 }
3634
3635 if let Some(ref callback) = self.event_callback {
3637 callback(NatTraversalEvent::ConnectionEstablished {
3638 peer_id,
3639 remote_address: candidate_address,
3640 });
3641 }
3642
3643 info!(
3644 "Successfully established connection to peer {:?} at {}",
3645 peer_id, candidate_address
3646 );
3647 Ok(())
3648 }
3649
3650 async fn send_candidate_advertisement(
3656 &self,
3657 peer_id: PeerId,
3658 candidate: &CandidateAddress,
3659 ) -> Result<(), NatTraversalError> {
3660 debug!(
3661 "Sending candidate advertisement to peer {:?}: {}",
3662 peer_id, candidate.address
3663 );
3664
3665 let connections = self.connections.read().map_err(|_| {
3667 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3668 })?;
3669
3670 if let Some(_connection) = connections.get(&peer_id) {
3671 debug!(
3673 "Found connection to peer {:?}, sending ADD_ADDRESS frame",
3674 peer_id
3675 );
3676
3677 drop(connections); let connections = self.connections.write().map_err(|_| {
3683 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3684 })?;
3685
3686 if let Some(connection) = connections.get(&peer_id) {
3687 let mut frame_data = Vec::new();
3690 frame_data.push(0x40); let sequence = candidate.priority as u64; frame_data.extend_from_slice(&sequence.to_be_bytes());
3695
3696 match candidate.address {
3698 SocketAddr::V4(addr) => {
3699 frame_data.push(4); frame_data.extend_from_slice(&addr.ip().octets());
3701 frame_data.extend_from_slice(&addr.port().to_be_bytes());
3702 }
3703 SocketAddr::V6(addr) => {
3704 frame_data.push(6); frame_data.extend_from_slice(&addr.ip().octets());
3706 frame_data.extend_from_slice(&addr.port().to_be_bytes());
3707 }
3708 }
3709
3710 frame_data.extend_from_slice(&candidate.priority.to_be_bytes());
3712
3713 match connection.send_datagram(frame_data.into()) {
3715 Ok(()) => {
3716 info!(
3717 "Sent ADD_ADDRESS frame to peer {:?}: addr={}, priority={}",
3718 peer_id, candidate.address, candidate.priority
3719 );
3720 Ok(())
3721 }
3722 Err(e) => {
3723 warn!(
3724 "Failed to send ADD_ADDRESS frame to peer {:?}: {}",
3725 peer_id, e
3726 );
3727 Err(NatTraversalError::ProtocolError(format!(
3728 "Failed to send ADD_ADDRESS frame: {e}"
3729 )))
3730 }
3731 }
3732 } else {
3733 debug!(
3735 "Connection to peer {:?} disappeared during frame sending",
3736 peer_id
3737 );
3738 Ok(())
3739 }
3740 } else {
3741 debug!(
3743 "No connection found for peer {:?} - candidate will be sent when connection is established",
3744 peer_id
3745 );
3746 Ok(())
3747 }
3748 }
3749
3750 async fn send_punch_coordination(
3755 &self,
3756 peer_id: PeerId,
3757 paired_with_sequence_number: u64,
3758 address: SocketAddr,
3759 round: u32,
3760 ) -> Result<(), NatTraversalError> {
3761 debug!(
3762 "Sending punch coordination to peer {:?}: seq={}, addr={}, round={}",
3763 peer_id, paired_with_sequence_number, address, round
3764 );
3765
3766 let connections = self.connections.read().map_err(|_| {
3767 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3768 })?;
3769
3770 if let Some(connection) = connections.get(&peer_id) {
3771 let mut frame_data = Vec::new();
3774 frame_data.push(0x41); frame_data.extend_from_slice(&round.to_be_bytes());
3778
3779 frame_data.extend_from_slice(&paired_with_sequence_number.to_be_bytes());
3781
3782 match address {
3784 SocketAddr::V4(addr) => {
3785 frame_data.push(4); frame_data.extend_from_slice(&addr.ip().octets());
3787 frame_data.extend_from_slice(&addr.port().to_be_bytes());
3788 }
3789 SocketAddr::V6(addr) => {
3790 frame_data.push(6); frame_data.extend_from_slice(&addr.ip().octets());
3792 frame_data.extend_from_slice(&addr.port().to_be_bytes());
3793 }
3794 }
3795
3796 match connection.send_datagram(frame_data.into()) {
3798 Ok(()) => {
3799 info!(
3800 "Sent PUNCH_ME_NOW frame to peer {:?}: paired_with_seq={}, addr={}, round={}",
3801 peer_id, paired_with_sequence_number, address, round
3802 );
3803 Ok(())
3804 }
3805 Err(e) => {
3806 warn!(
3807 "Failed to send PUNCH_ME_NOW frame to peer {:?}: {}",
3808 peer_id, e
3809 );
3810 Err(NatTraversalError::ProtocolError(format!(
3811 "Failed to send PUNCH_ME_NOW frame: {e}"
3812 )))
3813 }
3814 }
3815 } else {
3816 Err(NatTraversalError::PeerNotConnected)
3817 }
3818 }
3819
3820 pub fn get_nat_stats(
3822 &self,
3823 ) -> Result<NatTraversalStatistics, Box<dyn std::error::Error + Send + Sync>> {
3824 Ok(NatTraversalStatistics {
3827 active_sessions: self.active_sessions.read().unwrap().len(),
3828 total_bootstrap_nodes: self.bootstrap_nodes.read().unwrap().len(),
3829 successful_coordinations: 7,
3830 average_coordination_time: self.timeout_config.nat_traversal.retry_interval,
3831 total_attempts: 10,
3832 successful_connections: 7,
3833 direct_connections: 5,
3834 relayed_connections: 2,
3835 })
3836 }
3837}
3838
3839impl fmt::Debug for NatTraversalEndpoint {
3840 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3841 f.debug_struct("NatTraversalEndpoint")
3842 .field("config", &self.config)
3843 .field("bootstrap_nodes", &"<RwLock>")
3844 .field("active_sessions", &"<RwLock>")
3845 .field("event_callback", &self.event_callback.is_some())
3846 .finish()
3847 }
3848}
3849
/// Snapshot of NAT traversal counters, as returned by `get_nat_stats`.
#[derive(Debug, Clone, Default)]
pub struct NatTraversalStatistics {
    /// Number of entries in the endpoint's active-session map.
    pub active_sessions: usize,
    /// Number of configured bootstrap nodes.
    pub total_bootstrap_nodes: usize,
    /// Count of successful coordinations.
    pub successful_coordinations: u32,
    /// Average coordination time.
    pub average_coordination_time: Duration,
    /// Count of traversal attempts.
    pub total_attempts: u32,
    /// Count of successful connections.
    pub successful_connections: u32,
    /// Count of direct connections.
    pub direct_connections: u32,
    /// Count of relayed connections.
    pub relayed_connections: u32,
}
3870
3871impl fmt::Display for NatTraversalError {
3872 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3873 match self {
3874 Self::NoBootstrapNodes => write!(f, "no bootstrap nodes available"),
3875 Self::NoCandidatesFound => write!(f, "no address candidates found"),
3876 Self::CandidateDiscoveryFailed(msg) => write!(f, "candidate discovery failed: {msg}"),
3877 Self::CoordinationFailed(msg) => write!(f, "coordination failed: {msg}"),
3878 Self::HolePunchingFailed => write!(f, "hole punching failed"),
3879 Self::PunchingFailed(msg) => write!(f, "punching failed: {msg}"),
3880 Self::ValidationFailed(msg) => write!(f, "validation failed: {msg}"),
3881 Self::ValidationTimeout => write!(f, "validation timeout"),
3882 Self::NetworkError(msg) => write!(f, "network error: {msg}"),
3883 Self::ConfigError(msg) => write!(f, "configuration error: {msg}"),
3884 Self::ProtocolError(msg) => write!(f, "protocol error: {msg}"),
3885 Self::Timeout => write!(f, "operation timed out"),
3886 Self::ConnectionFailed(msg) => write!(f, "connection failed: {msg}"),
3887 Self::TraversalFailed(msg) => write!(f, "traversal failed: {msg}"),
3888 Self::PeerNotConnected => write!(f, "peer not connected"),
3889 }
3890 }
3891}
3892
3893impl std::error::Error for NatTraversalError {}
3894
3895impl fmt::Display for PeerId {
3896 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3897 for byte in &self.0[..8] {
3899 write!(f, "{byte:02x}")?;
3900 }
3901 Ok(())
3902 }
3903}
3904
3905impl From<[u8; 32]> for PeerId {
3906 fn from(bytes: [u8; 32]) -> Self {
3907 Self(bytes)
3908 }
3909}
3910
/// Marker type for a certificate verifier that performs no verification.
#[derive(Debug)]
struct SkipServerVerification;

impl SkipServerVerification {
    /// Creates the verifier behind an `Arc`.
    fn new() -> Arc<Self> {
        Arc::new(SkipServerVerification)
    }
}
3921
/// Certificate verifier that accepts every server certificate and signature.
///
/// SECURITY: every check below returns success without looking at its inputs,
/// so a TLS configuration using this verifier does not authenticate the
/// server at all.
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    /// Accepts any certificate chain for any server name.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Signature schemes advertised to the server. Only these three are
    /// listed; NOTE(review): confirm that omitting other schemes is intended.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ED25519,
        ]
    }
}
3960
/// Token store that keeps nothing.
struct DefaultTokenStore;

impl crate::TokenStore for DefaultTokenStore {
    /// Discards the token; nothing is stored.
    fn insert(&self, _server_name: &str, _token: bytes::Bytes) {
    }

    /// Always returns `None`, since `insert` never stores anything.
    fn take(&self, _server_name: &str) -> Option<bytes::Bytes> {
        None
    }
}
3973
3974#[cfg(test)]
3975mod tests {
3976 use super::*;
3977
    // Pins the values of `NatTraversalConfig::default()` that this test asserts on.
    #[test]
    fn test_nat_traversal_config_default() {
        let config = NatTraversalConfig::default();
        assert_eq!(config.role, EndpointRole::Client);
        assert_eq!(config.max_candidates, 8);
        assert!(config.enable_symmetric_nat);
        assert!(config.enable_relay_fallback);
    }
3986
    // `Display` for `PeerId` prints only the first eight bytes, as lowercase hex.
    #[test]
    fn test_peer_id_display() {
        let peer_id = PeerId([
            0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55,
            0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00, 0x11, 0x22, 0x33,
            0x44, 0x55, 0x66, 0x77,
        ]);
        assert_eq!(format!("{peer_id}"), "0123456789abcdef");
    }
3996
    // NOTE(review): placeholder — this only constructs a default config and
    // asserts nothing about bootstrap node management.
    #[test]
    fn test_bootstrap_node_management() {
        let _config = NatTraversalConfig::default();
    }
4003
4004 #[test]
4005 fn test_candidate_address_validation() {
4006 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
4007
4008 assert!(
4010 CandidateAddress::validate_address(&SocketAddr::new(
4011 IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
4012 8080
4013 ))
4014 .is_ok()
4015 );
4016
4017 assert!(
4018 CandidateAddress::validate_address(&SocketAddr::new(
4019 IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
4020 53
4021 ))
4022 .is_ok()
4023 );
4024
4025 assert!(
4026 CandidateAddress::validate_address(&SocketAddr::new(
4027 IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
4028 443
4029 ))
4030 .is_ok()
4031 );
4032
4033 assert!(matches!(
4035 CandidateAddress::validate_address(&SocketAddr::new(
4036 IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
4037 0
4038 )),
4039 Err(CandidateValidationError::InvalidPort(0))
4040 ));
4041
4042 #[cfg(not(test))]
4044 assert!(matches!(
4045 CandidateAddress::validate_address(&SocketAddr::new(
4046 IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
4047 80
4048 )),
4049 Err(CandidateValidationError::PrivilegedPort(80))
4050 ));
4051
4052 assert!(matches!(
4054 CandidateAddress::validate_address(&SocketAddr::new(
4055 IpAddr::V4(Ipv4Addr::UNSPECIFIED),
4056 8080
4057 )),
4058 Err(CandidateValidationError::UnspecifiedAddress)
4059 ));
4060
4061 assert!(matches!(
4062 CandidateAddress::validate_address(&SocketAddr::new(
4063 IpAddr::V6(Ipv6Addr::UNSPECIFIED),
4064 8080
4065 )),
4066 Err(CandidateValidationError::UnspecifiedAddress)
4067 ));
4068
4069 assert!(matches!(
4071 CandidateAddress::validate_address(&SocketAddr::new(
4072 IpAddr::V4(Ipv4Addr::BROADCAST),
4073 8080
4074 )),
4075 Err(CandidateValidationError::BroadcastAddress)
4076 ));
4077
4078 assert!(matches!(
4080 CandidateAddress::validate_address(&SocketAddr::new(
4081 IpAddr::V4(Ipv4Addr::new(224, 0, 0, 1)),
4082 8080
4083 )),
4084 Err(CandidateValidationError::MulticastAddress)
4085 ));
4086
4087 assert!(matches!(
4088 CandidateAddress::validate_address(&SocketAddr::new(
4089 IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1)),
4090 8080
4091 )),
4092 Err(CandidateValidationError::MulticastAddress)
4093 ));
4094
4095 assert!(matches!(
4097 CandidateAddress::validate_address(&SocketAddr::new(
4098 IpAddr::V4(Ipv4Addr::new(0, 0, 0, 1)),
4099 8080
4100 )),
4101 Err(CandidateValidationError::ReservedAddress)
4102 ));
4103
4104 assert!(matches!(
4105 CandidateAddress::validate_address(&SocketAddr::new(
4106 IpAddr::V4(Ipv4Addr::new(240, 0, 0, 1)),
4107 8080
4108 )),
4109 Err(CandidateValidationError::ReservedAddress)
4110 ));
4111
4112 assert!(matches!(
4114 CandidateAddress::validate_address(&SocketAddr::new(
4115 IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1)),
4116 8080
4117 )),
4118 Err(CandidateValidationError::DocumentationAddress)
4119 ));
4120
4121 assert!(matches!(
4123 CandidateAddress::validate_address(&SocketAddr::new(
4124 IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc0a8, 0x0001)),
4125 8080
4126 )),
4127 Err(CandidateValidationError::IPv4MappedAddress)
4128 ));
4129 }
4130
4131 #[test]
4132 fn test_candidate_address_suitability_for_nat_traversal() {
4133 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
4134
4135 let public_v4 = CandidateAddress::new(
4137 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 8080),
4138 100,
4139 CandidateSource::Observed { by_node: None },
4140 )
4141 .unwrap();
4142 assert!(public_v4.is_suitable_for_nat_traversal());
4143
4144 let private_v4 = CandidateAddress::new(
4145 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080),
4146 100,
4147 CandidateSource::Local,
4148 )
4149 .unwrap();
4150 assert!(private_v4.is_suitable_for_nat_traversal());
4151
4152 let link_local_v4 = CandidateAddress::new(
4154 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(169, 254, 1, 1)), 8080),
4155 100,
4156 CandidateSource::Local,
4157 )
4158 .unwrap();
4159 assert!(!link_local_v4.is_suitable_for_nat_traversal());
4160
4161 let global_v6 = CandidateAddress::new(
4163 SocketAddr::new(
4164 IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
4165 8080,
4166 ),
4167 100,
4168 CandidateSource::Observed { by_node: None },
4169 )
4170 .unwrap();
4171 assert!(global_v6.is_suitable_for_nat_traversal());
4172
4173 let link_local_v6 = CandidateAddress::new(
4175 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)), 8080),
4176 100,
4177 CandidateSource::Local,
4178 )
4179 .unwrap();
4180 assert!(!link_local_v6.is_suitable_for_nat_traversal());
4181
4182 let unique_local_v6 = CandidateAddress::new(
4184 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 1)), 8080),
4185 100,
4186 CandidateSource::Local,
4187 )
4188 .unwrap();
4189 assert!(!unique_local_v6.is_suitable_for_nat_traversal());
4190
4191 #[cfg(test)]
4193 {
4194 let loopback_v4 = CandidateAddress::new(
4195 SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
4196 100,
4197 CandidateSource::Local,
4198 )
4199 .unwrap();
4200 assert!(loopback_v4.is_suitable_for_nat_traversal());
4201
4202 let loopback_v6 = CandidateAddress::new(
4203 SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080),
4204 100,
4205 CandidateSource::Local,
4206 )
4207 .unwrap();
4208 assert!(loopback_v6.is_suitable_for_nat_traversal());
4209 }
4210 }
4211
4212 #[test]
4213 fn test_candidate_effective_priority() {
4214 use std::net::{IpAddr, Ipv4Addr};
4215
4216 let mut candidate = CandidateAddress::new(
4217 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080),
4218 100,
4219 CandidateSource::Local,
4220 )
4221 .unwrap();
4222
4223 assert_eq!(candidate.effective_priority(), 90);
4225
4226 candidate.state = CandidateState::Validating;
4228 assert_eq!(candidate.effective_priority(), 95);
4229
4230 candidate.state = CandidateState::Valid;
4232 assert_eq!(candidate.effective_priority(), 100);
4233
4234 candidate.state = CandidateState::Failed;
4236 assert_eq!(candidate.effective_priority(), 0);
4237
4238 candidate.state = CandidateState::Removed;
4240 assert_eq!(candidate.effective_priority(), 0);
4241 }
4242}