1#![allow(missing_docs)]
8
9use std::{collections::HashMap, fmt, net::SocketAddr, sync::Arc, time::Duration};
16
/// Returns the IPv4 wildcard bind address `0.0.0.0:0`.
///
/// Port `0` leaves the choice of port to the operating system when the socket is bound.
/// The address is assembled from its components rather than parsed from a string, so the
/// function has no failure path and needs no panic (or lint allowance for one).
fn create_random_port_bind_addr() -> SocketAddr {
    SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0))
}
41
42use tracing::{debug, error, info, warn};
43
44use std::sync::atomic::{AtomicBool, Ordering};
45
46use tokio::{
47 net::UdpSocket,
48 sync::{mpsc, mpsc::error::TryRecvError},
49 time::{sleep, timeout},
50};
51
52use crate::high_level::default_runtime;
53
54use crate::{
55 VarInt,
56 candidate_discovery::{CandidateDiscoveryManager, DiscoveryConfig, DiscoveryEvent},
57 connection::nat_traversal::{CandidateSource, CandidateState, NatTraversalRole},
58};
59
60use crate::{
61 ClientConfig, ConnectionError, EndpointConfig, ServerConfig, TransportConfig,
62 high_level::{Connection as QuinnConnection, Endpoint as QuinnEndpoint},
63};
64
65#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
66use crate::{crypto::rustls::QuicClientConfig, crypto::rustls::QuicServerConfig};
67
68use crate::config::validation::{ConfigValidator, ValidationResult};
69
70#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
71use crate::crypto::raw_public_keys::RawPublicKeyConfigBuilder;
72
/// NAT traversal endpoint: the QUIC endpoint together with the session, bootstrap-node and
/// candidate-discovery bookkeeping that the methods on this type operate on.
pub struct NatTraversalEndpoint {
    /// Underlying QUIC endpoint; exposed read-only through `get_quinn_endpoint`.
    quinn_endpoint: Option<QuinnEndpoint>,
    /// Configuration the endpoint was created with.
    config: NatTraversalConfig,
    /// Known bootstrap nodes, shared behind a read/write lock.
    bootstrap_nodes: Arc<std::sync::RwLock<Vec<BootstrapNode>>>,
    /// Traversal sessions keyed by remote peer ID; shared with the session polling task.
    active_sessions: Arc<std::sync::RwLock<HashMap<PeerId, NatTraversalSession>>>,
    /// Candidate discovery manager; shared with the discovery polling task.
    discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
    /// Optional user callback invoked with NAT traversal events.
    event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
    /// Flag that background tasks read to decide when to stop.
    shutdown: Arc<AtomicBool>,
    /// Sending half of the internal event channel.
    event_tx: Option<mpsc::UnboundedSender<NatTraversalEvent>>,
    /// Receiving half of the internal event channel.
    event_rx: std::sync::Mutex<mpsc::UnboundedReceiver<NatTraversalEvent>>,
    /// Connections keyed by peer ID; shared with the connection-accepting task.
    connections: Arc<std::sync::RwLock<HashMap<PeerId, QuinnConnection>>>,
    /// Peer ID of this endpoint.
    local_peer_id: PeerId,
    /// Timeout settings, copied from `config.timeouts` at construction.
    timeout_config: crate::config::nat_timeouts::TimeoutConfig,
}
102
/// Configuration for a [`NatTraversalEndpoint`].
///
/// The `ConfigValidator` implementation on this type checks the constraints mentioned on
/// the individual fields.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct NatTraversalConfig {
    /// Role this endpoint plays; decides whether a server configuration is created and
    /// whether incoming connections are accepted.
    pub role: EndpointRole,
    /// Addresses of bootstrap nodes. Validation requires at least one for `Client` and for
    /// `Server { can_coordinate: true }`.
    pub bootstrap_nodes: Vec<SocketAddr>,
    /// Maximum number of candidates; passed to the discovery configuration and validated
    /// with the bounds `1` and `256`.
    pub max_candidates: usize,
    /// Coordination timeout; also used as the discovery `total_timeout`. Validated with the
    /// bounds 100 ms and 300 s.
    pub coordination_timeout: Duration,
    /// Forwarded to the discovery configuration as `enable_symmetric_prediction`.
    pub enable_symmetric_nat: bool,
    /// Whether relay fallback is enabled; validation rejects it for the `Bootstrap` role.
    pub enable_relay_fallback: bool,
    /// Maximum concurrent attempts; used as the server-side `concurrency_limit` transport
    /// parameter. Validated with the bounds `1` and `16` and must not exceed `max_candidates`.
    pub max_concurrent_attempts: usize,
    /// Address to bind the UDP socket to; `None` binds `0.0.0.0:0`.
    pub bind_addr: Option<SocketAddr>,
    /// NOTE(review): not read anywhere in the part of the file reviewed here; presumably
    /// selects the RFC-style NAT traversal negotiation — confirm against its users.
    pub prefer_rfc_nat_traversal: bool,
    /// Timeout settings; copied into the endpoint at construction.
    pub timeouts: crate::config::nat_timeouts::TimeoutConfig,
}
173
/// Role an endpoint plays in NAT traversal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum EndpointRole {
    /// Client endpoint. Gets no server configuration and does not accept incoming
    /// connections; validation requires at least one bootstrap node.
    Client,
    /// Server endpoint that accepts incoming connections.
    Server {
        /// Whether the server can coordinate; mapped to `can_relay` on the protocol-level
        /// role. When `true`, validation requires bootstrap nodes.
        can_coordinate: bool,
    },
    /// Bootstrap endpoint. Accepts incoming connections; validation rejects relay fallback
    /// for this role.
    Bootstrap,
}
187
188impl EndpointRole {
189 pub fn name(&self) -> &'static str {
191 match self {
192 Self::Client => "client",
193 Self::Server { .. } => "server",
194 Self::Bootstrap => "bootstrap",
195 }
196 }
197}
198
/// Identifier of a peer: a newtype around 32 raw bytes.
///
/// Used as the key of the session and connection maps.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
)]
pub struct PeerId(pub [u8; 32]);
204
/// A bootstrap node known to the endpoint, with the bookkeeping kept about it.
#[derive(Debug, Clone)]
pub struct BootstrapNode {
    /// Network address of the node.
    pub address: SocketAddr,
    /// Instant recorded for the node; set to the creation time by [`BootstrapNode::new`].
    pub last_seen: std::time::Instant,
    /// Whether the node can coordinate; new nodes start as `true`.
    pub can_coordinate: bool,
    /// Round-trip time to the node, if one has been recorded.
    pub rtt: Option<Duration>,
    /// Coordination counter; starts at zero.
    pub coordination_count: u32,
}

impl BootstrapNode {
    /// Creates a record for the node at `address`.
    ///
    /// The node starts with `last_seen` set to the current instant, coordination enabled,
    /// no RTT sample and a coordination count of zero.
    pub fn new(address: SocketAddr) -> Self {
        let created_at = std::time::Instant::now();
        Self {
            coordination_count: 0,
            rtt: None,
            can_coordinate: true,
            last_seen: created_at,
            address,
        }
    }
}
232
/// A local candidate address paired with a remote candidate address.
#[derive(Debug, Clone)]
pub struct CandidatePair {
    /// Candidate on the local side of the pair.
    pub local_candidate: CandidateAddress,
    /// Candidate on the remote side of the pair.
    pub remote_candidate: CandidateAddress,
    /// Priority of the pair. NOTE(review): how it is derived is not visible in this part of
    /// the file.
    pub priority: u64,
    /// Current state of the pair.
    pub state: CandidatePairState,
}
245
/// State of a [`CandidatePair`].
///
/// NOTE(review): nothing in the part of the file reviewed here reads or assigns these
/// variants; the names suggest a connectivity-check lifecycle — confirm against the code
/// that drives them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandidatePairState {
    Waiting,
    InProgress,
    Succeeded,
    Failed,
    Cancelled,
}
260
/// Internal record of one NAT traversal attempt towards a peer.
#[derive(Debug)]
struct NatTraversalSession {
    /// Peer the session targets.
    peer_id: PeerId,
    /// Coordinator address supplied when the traversal was initiated (stored, not read).
    #[allow(dead_code)]
    coordinator: SocketAddr,
    /// Attempt counter; starts at 1 and is incremented when the polling task retries.
    attempt: u32,
    /// Instant at which the session was created.
    started_at: std::time::Instant,
    /// Current traversal phase; sessions start in `TraversalPhase::Discovery`.
    phase: TraversalPhase,
    /// Candidate addresses collected for the peer; empty at creation.
    candidates: Vec<CandidateAddress>,
    /// Connection-level state tracked for the session.
    session_state: SessionState,
}
280
/// Connection-level state of a traversal session.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Current connection state.
    pub state: ConnectionState,
    /// Instant of the most recent state transition; timeouts are measured from it.
    pub last_transition: std::time::Instant,
    /// Connection associated with the session, if any.
    pub connection: Option<QuinnConnection>,
    /// Address/instant pairs for attempts in flight. NOTE(review): not populated in the part
    /// of the file reviewed here — the instant is presumably the attempt start time.
    pub active_attempts: Vec<(SocketAddr, std::time::Instant)>,
    /// Metrics recorded for the connection.
    pub metrics: ConnectionMetrics,
}
295
/// Connection state of a traversal session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No attempt in progress; the polling task moves idle sessions back to `Connecting`
    /// once its retry condition is met.
    Idle,
    /// An attempt is in progress; new sessions start in this state.
    Connecting,
    /// A connection is available for the session.
    Connected,
    /// The session is migrating; polling resolves it to `Connected` or `Closed` after a
    /// timeout.
    Migrating,
    /// The session is closed; the polling task eventually removes closed sessions.
    Closed,
}
310
/// Metrics kept for a session's connection. All fields default to zero / `None`.
#[derive(Debug, Clone, Default)]
pub struct ConnectionMetrics {
    /// Path round-trip time, taken from the connection statistics by the polling task.
    pub rtt: Option<Duration>,
    /// Lost packets divided by sent packets, as computed by the polling task.
    pub loss_rate: f64,
    /// Bytes sent. NOTE(review): not updated in the part of the file reviewed here.
    pub bytes_sent: u64,
    /// Bytes received. NOTE(review): not updated in the part of the file reviewed here.
    pub bytes_received: u64,
    /// Set to the poll time by `poll_sessions` for sessions in the `Connected` state.
    pub last_activity: Option<std::time::Instant>,
}
325
/// One state transition reported by `poll_sessions`.
#[derive(Debug, Clone)]
pub struct SessionStateUpdate {
    /// Peer whose session changed state.
    pub peer_id: PeerId,
    /// State before the transition.
    pub old_state: ConnectionState,
    /// State after the transition.
    pub new_state: ConnectionState,
    /// Why the transition happened.
    pub reason: StateChangeReason,
}
338
/// Reason attached to a [`SessionStateUpdate`].
///
/// NOTE(review): `ConnectionClosed`, `NetworkError` and `UserClosed` are not produced in the
/// part of the file reviewed here; their meaning follows from their names only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StateChangeReason {
    /// A connecting session exceeded the connection establishment timeout.
    Timeout,
    /// A connecting session has a connection.
    ConnectionEstablished,
    ConnectionClosed,
    /// A migrating session still had a connection when its migration timeout elapsed.
    MigrationComplete,
    /// A migrating session had no connection when its migration timeout elapsed.
    MigrationFailed,
    NetworkError,
    UserClosed,
}
357
/// Phase of a NAT traversal session.
///
/// New sessions start in `Discovery`. NOTE(review): the transitions between the other
/// phases are not visible in the part of the file reviewed here; the declaration order
/// presumably mirrors the progression — confirm against the code that advances the phase.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalPhase {
    Discovery,
    Coordination,
    Synchronization,
    Punching,
    Validation,
    Connected,
    Failed,
}
376
/// Action the session polling task decided to apply to a session.
///
/// The task first classifies sessions under a read lock and then applies these actions
/// under a write lock.
#[derive(Debug, Clone, Copy)]
enum SessionUpdate {
    /// A `Connecting` session exceeded the connection establishment timeout; close it.
    Timeout,
    /// A `Connected` session's connection reports a close reason; close it and drop the
    /// connection.
    Disconnected,
    /// A `Connected` session's connection is still open; refresh RTT and loss rate.
    UpdateMetrics,
    /// A `Connected` session has no connection; close it.
    InvalidState,
    /// An `Idle` session waited long enough; move it back to `Connecting`.
    Retry,
    /// A `Migrating` session exceeded the probe timeout; close it.
    MigrationTimeout,
    /// A `Closed` session is old enough to be removed from the session map.
    Remove,
}
395
/// A candidate address for NAT traversal.
#[derive(Debug, Clone)]
pub struct CandidateAddress {
    /// Socket address of the candidate.
    pub address: SocketAddr,
    /// Base priority; see `effective_priority` for the state-adjusted value.
    pub priority: u32,
    /// How the candidate was obtained.
    pub source: CandidateSource,
    /// Validation state; candidates built with `new` start as `CandidateState::New`.
    pub state: CandidateState,
}
408
409impl CandidateAddress {
410 pub fn new(
412 address: SocketAddr,
413 priority: u32,
414 source: CandidateSource,
415 ) -> Result<Self, CandidateValidationError> {
416 Self::validate_address(&address)?;
417 Ok(Self {
418 address,
419 priority,
420 source,
421 state: CandidateState::New,
422 })
423 }
424
425 pub fn validate_address(addr: &SocketAddr) -> Result<(), CandidateValidationError> {
427 if addr.port() == 0 {
429 return Err(CandidateValidationError::InvalidPort(0));
430 }
431
432 #[cfg(not(test))]
434 if addr.port() < 1024 {
435 return Err(CandidateValidationError::PrivilegedPort(addr.port()));
436 }
437
438 match addr.ip() {
439 std::net::IpAddr::V4(ipv4) => {
440 if ipv4.is_unspecified() {
442 return Err(CandidateValidationError::UnspecifiedAddress);
443 }
444 if ipv4.is_broadcast() {
445 return Err(CandidateValidationError::BroadcastAddress);
446 }
447 if ipv4.is_multicast() {
448 return Err(CandidateValidationError::MulticastAddress);
449 }
450 if ipv4.octets()[0] == 0 {
452 return Err(CandidateValidationError::ReservedAddress);
453 }
454 if ipv4.octets()[0] >= 240 {
456 return Err(CandidateValidationError::ReservedAddress);
457 }
458 }
459 std::net::IpAddr::V6(ipv6) => {
460 if ipv6.is_unspecified() {
462 return Err(CandidateValidationError::UnspecifiedAddress);
463 }
464 if ipv6.is_multicast() {
465 return Err(CandidateValidationError::MulticastAddress);
466 }
467 let segments = ipv6.segments();
469 if segments[0] == 0x2001 && segments[1] == 0x0db8 {
470 return Err(CandidateValidationError::DocumentationAddress);
471 }
472 if ipv6.to_ipv4_mapped().is_some() {
474 return Err(CandidateValidationError::IPv4MappedAddress);
475 }
476 }
477 }
478
479 Ok(())
480 }
481
482 pub fn is_suitable_for_nat_traversal(&self) -> bool {
484 match self.address.ip() {
485 std::net::IpAddr::V4(ipv4) => {
486 #[cfg(test)]
491 if ipv4.is_loopback() {
492 return true;
493 }
494 !ipv4.is_loopback()
495 && !ipv4.is_link_local()
496 && !ipv4.is_multicast()
497 && !ipv4.is_broadcast()
498 }
499 std::net::IpAddr::V6(ipv6) => {
500 #[cfg(test)]
506 if ipv6.is_loopback() {
507 return true;
508 }
509 let segments = ipv6.segments();
510 let is_link_local = (segments[0] & 0xffc0) == 0xfe80;
511 let is_unique_local = (segments[0] & 0xfe00) == 0xfc00;
512
513 !ipv6.is_loopback() && !is_link_local && !is_unique_local && !ipv6.is_multicast()
514 }
515 }
516 }
517
518 pub fn effective_priority(&self) -> u32 {
520 match self.state {
521 CandidateState::Valid => self.priority,
522 CandidateState::New => self.priority.saturating_sub(10),
523 CandidateState::Validating => self.priority.saturating_sub(5),
524 CandidateState::Failed => 0,
525 CandidateState::Removed => 0,
526 }
527 }
528}
529
/// Reasons `CandidateAddress::validate_address` rejects an address.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum CandidateValidationError {
    /// The port is not usable; produced for port `0`.
    #[error("invalid port number: {0}")]
    InvalidPort(u16),
    /// The port is below 1024 (not checked in test builds).
    #[error("privileged port not allowed: {0}")]
    PrivilegedPort(u16),
    /// The IP address is the unspecified address.
    #[error("unspecified address not allowed")]
    UnspecifiedAddress,
    /// The IPv4 address is the broadcast address.
    #[error("broadcast address not allowed")]
    BroadcastAddress,
    /// The IP address is a multicast address.
    #[error("multicast address not allowed")]
    MulticastAddress,
    /// The IPv4 address has a first octet of `0` or of `240` and above.
    #[error("reserved address not allowed")]
    ReservedAddress,
    /// The IPv6 address lies in `2001:db8::/32`.
    #[error("documentation address not allowed")]
    DocumentationAddress,
    /// The IPv6 address is an IPv4-mapped address.
    #[error("IPv4-mapped IPv6 address not allowed")]
    IPv4MappedAddress,
}
558
/// Events delivered to the user callback and carried on the internal event channel.
///
/// NOTE(review): only `CoordinationRequested` and `SessionStateChanged` are constructed in
/// the part of the file reviewed here. The descriptions of the other variants follow their
/// names and field types and should be confirmed against the code that emits them.
#[derive(Debug, Clone)]
pub enum NatTraversalEvent {
    /// A candidate address was discovered.
    CandidateDiscovered {
        /// Peer the candidate belongs to.
        peer_id: PeerId,
        /// The discovered candidate.
        candidate: CandidateAddress,
    },
    /// Emitted by `initiate_nat_traversal` once a traversal towards `peer_id` was started.
    CoordinationRequested {
        /// Peer the traversal targets.
        peer_id: PeerId,
        /// Coordinator address passed by the caller.
        coordinator: SocketAddr,
    },
    /// Coordination with the peer reached a synchronized round.
    CoordinationSynchronized {
        /// Peer the event concerns.
        peer_id: PeerId,
        /// Identifier of the coordination round.
        round_id: VarInt,
    },
    /// Hole punching towards the peer started.
    HolePunchingStarted {
        /// Peer the event concerns.
        peer_id: PeerId,
        /// Addresses being targeted.
        targets: Vec<SocketAddr>,
    },
    /// A path to the peer was validated.
    PathValidated {
        /// Peer the event concerns.
        peer_id: PeerId,
        /// Address of the validated path.
        address: SocketAddr,
        /// Round-trip time associated with the path.
        rtt: Duration,
    },
    /// A candidate address was validated.
    CandidateValidated {
        /// Peer the event concerns.
        peer_id: PeerId,
        /// Address of the validated candidate.
        candidate_address: SocketAddr,
    },
    /// NAT traversal to the peer succeeded.
    TraversalSucceeded {
        /// Peer the event concerns.
        peer_id: PeerId,
        /// Address the traversal settled on.
        final_address: SocketAddr,
        /// Total duration reported for the traversal.
        total_time: Duration,
    },
    /// A connection with the peer was established.
    ConnectionEstablished {
        /// Peer the event concerns.
        peer_id: PeerId,
        /// Remote address of the connection.
        remote_address: SocketAddr,
    },
    /// NAT traversal to the peer failed.
    TraversalFailed {
        /// Peer the event concerns.
        peer_id: PeerId,
        /// Error describing the failure.
        error: NatTraversalError,
        /// Whether a fallback is available.
        fallback_available: bool,
    },
    /// The connection to the peer was lost.
    ConnectionLost {
        /// Peer the event concerns.
        peer_id: PeerId,
        /// Textual reason for the loss.
        reason: String,
    },
    /// A session moved from one traversal phase to another.
    PhaseTransition {
        /// Peer the event concerns.
        peer_id: PeerId,
        /// Phase before the transition.
        from_phase: TraversalPhase,
        /// Phase after the transition.
        to_phase: TraversalPhase,
    },
    /// Emitted by `poll_sessions` when a session's connection state changed.
    SessionStateChanged {
        /// Peer whose session changed.
        peer_id: PeerId,
        /// State the session is in after the change.
        new_state: ConnectionState,
    },
}
654
/// Errors reported by the NAT traversal endpoint.
///
/// NOTE(review): only `CandidateDiscoveryFailed`, `NetworkError`, `ConfigError` and
/// `ProtocolError` are produced in the part of the file reviewed here; the other variants
/// are left undocumented rather than guessed at.
#[derive(Debug, Clone)]
pub enum NatTraversalError {
    NoBootstrapNodes,
    NoCandidatesFound,
    /// Starting candidate discovery failed; carries the underlying error text.
    CandidateDiscoveryFailed(String),
    CoordinationFailed(String),
    HolePunchingFailed,
    PunchingFailed(String),
    ValidationFailed(String),
    ValidationTimeout,
    /// Socket-level failure (binding, converting the socket, or reading its local address).
    NetworkError(String),
    /// Invalid configuration, or a failure while building the crypto/endpoint configuration.
    ConfigError(String),
    /// Internal failure; in the code reviewed here it reports poisoned locks.
    ProtocolError(String),
    Timeout,
    ConnectionFailed(String),
    TraversalFailed(String),
    PeerNotConnected,
}
689
690impl Default for NatTraversalConfig {
691 fn default() -> Self {
692 Self {
693 role: EndpointRole::Client,
694 bootstrap_nodes: Vec::new(),
695 max_candidates: 8,
696 coordination_timeout: Duration::from_secs(10),
697 enable_symmetric_nat: true,
698 enable_relay_fallback: true,
699 max_concurrent_attempts: 3,
700 bind_addr: None,
701 prefer_rfc_nat_traversal: true, timeouts: crate::config::nat_timeouts::TimeoutConfig::default(),
703 }
704 }
705}
706
707impl ConfigValidator for NatTraversalConfig {
708 fn validate(&self) -> ValidationResult<()> {
709 use crate::config::validation::*;
710
711 match self.role {
713 EndpointRole::Client => {
714 if self.bootstrap_nodes.is_empty() {
715 return Err(ConfigValidationError::InvalidRole(
716 "Client endpoints require at least one bootstrap node".to_string(),
717 ));
718 }
719 }
720 EndpointRole::Server { can_coordinate } => {
721 if can_coordinate && self.bootstrap_nodes.is_empty() {
722 return Err(ConfigValidationError::InvalidRole(
723 "Server endpoints with coordination capability require bootstrap nodes"
724 .to_string(),
725 ));
726 }
727 }
728 EndpointRole::Bootstrap => {
729 }
731 }
732
733 if !self.bootstrap_nodes.is_empty() {
735 validate_bootstrap_nodes(&self.bootstrap_nodes)?;
736 }
737
738 validate_range(self.max_candidates, 1, 256, "max_candidates")?;
740
741 validate_duration(
743 self.coordination_timeout,
744 Duration::from_millis(100),
745 Duration::from_secs(300),
746 "coordination_timeout",
747 )?;
748
749 validate_range(
751 self.max_concurrent_attempts,
752 1,
753 16,
754 "max_concurrent_attempts",
755 )?;
756
757 if self.max_concurrent_attempts > self.max_candidates {
759 return Err(ConfigValidationError::IncompatibleConfiguration(
760 "max_concurrent_attempts cannot exceed max_candidates".to_string(),
761 ));
762 }
763
764 if self.role == EndpointRole::Bootstrap && self.enable_relay_fallback {
765 return Err(ConfigValidationError::IncompatibleConfiguration(
766 "Bootstrap nodes should not enable relay fallback".to_string(),
767 ));
768 }
769
770 Ok(())
771 }
772}
773
774impl NatTraversalEndpoint {
775 pub async fn new(
777 config: NatTraversalConfig,
778 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
779 ) -> Result<Self, NatTraversalError> {
780 Self::new_impl(config, event_callback).await
781 }
782
783 async fn new_impl(
785 config: NatTraversalConfig,
786 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
787 ) -> Result<Self, NatTraversalError> {
788 Self::new_common(config, event_callback).await
789 }
790
791 async fn new_common(
793 config: NatTraversalConfig,
794 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
795 ) -> Result<Self, NatTraversalError> {
796 Self::new_shared_logic(config, event_callback).await
798 }
799
800 async fn new_shared_logic(
802 config: NatTraversalConfig,
803 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
804 ) -> Result<Self, NatTraversalError> {
805 {
808 config
809 .validate()
810 .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;
811 }
812
813 let bootstrap_nodes = Arc::new(std::sync::RwLock::new(
817 config
818 .bootstrap_nodes
819 .iter()
820 .map(|&address| BootstrapNode {
821 address,
822 last_seen: std::time::Instant::now(),
823 can_coordinate: true, rtt: None,
825 coordination_count: 0,
826 })
827 .collect(),
828 ));
829
830 let discovery_config = DiscoveryConfig {
832 total_timeout: config.coordination_timeout,
833 max_candidates: config.max_candidates,
834 enable_symmetric_prediction: config.enable_symmetric_nat,
835 bound_address: config.bind_addr, ..DiscoveryConfig::default()
837 };
838
839 let nat_traversal_role = match config.role {
840 EndpointRole::Client => NatTraversalRole::Client,
841 EndpointRole::Server { can_coordinate } => NatTraversalRole::Server {
842 can_relay: can_coordinate,
843 },
844 EndpointRole::Bootstrap => NatTraversalRole::Bootstrap,
845 };
846
847 let discovery_manager = Arc::new(std::sync::Mutex::new(CandidateDiscoveryManager::new(
848 discovery_config,
849 )));
850
851 let (quinn_endpoint, event_tx, event_rx, local_addr) =
854 Self::create_quinn_endpoint(&config, nat_traversal_role).await?;
855
856 {
858 let mut discovery = discovery_manager.lock().map_err(|_| {
859 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
860 })?;
861 discovery.set_bound_address(local_addr);
862 info!(
863 "Updated discovery manager with bound address: {}",
864 local_addr
865 );
866 }
867
868 let endpoint = Self {
869 quinn_endpoint: Some(quinn_endpoint.clone()),
870 config: config.clone(),
871 bootstrap_nodes,
872 active_sessions: Arc::new(std::sync::RwLock::new(HashMap::new())),
873 discovery_manager,
874 event_callback,
875 shutdown: Arc::new(AtomicBool::new(false)),
876 event_tx: Some(event_tx.clone()),
877 event_rx: std::sync::Mutex::new(event_rx),
878 connections: Arc::new(std::sync::RwLock::new(HashMap::new())),
879 local_peer_id: Self::generate_local_peer_id(),
880 timeout_config: config.timeouts.clone(),
881 };
882
883 if matches!(
885 config.role,
886 EndpointRole::Bootstrap | EndpointRole::Server { .. }
887 ) {
888 let endpoint_clone = quinn_endpoint.clone();
889 let shutdown_clone = endpoint.shutdown.clone();
890 let event_tx_clone = event_tx.clone();
891 let connections_clone = endpoint.connections.clone();
892
893 tokio::spawn(async move {
894 Self::accept_connections(
895 endpoint_clone,
896 shutdown_clone,
897 event_tx_clone,
898 connections_clone,
899 )
900 .await;
901 });
902
903 info!("Started accepting connections for {:?} role", config.role);
904 }
905
906 let discovery_manager_clone = endpoint.discovery_manager.clone();
908 let shutdown_clone = endpoint.shutdown.clone();
909 let event_tx_clone = event_tx;
910
911 tokio::spawn(async move {
912 Self::poll_discovery(discovery_manager_clone, shutdown_clone, event_tx_clone).await;
913 });
914
915 info!("Started discovery polling task");
916
917 {
919 let mut discovery = endpoint.discovery_manager.lock().map_err(|_| {
920 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
921 })?;
922
923 let local_peer_id = endpoint.local_peer_id;
925 let bootstrap_nodes = {
926 let nodes = endpoint.bootstrap_nodes.read().map_err(|_| {
927 NatTraversalError::ProtocolError("Bootstrap nodes lock poisoned".to_string())
928 })?;
929 nodes.clone()
930 };
931
932 discovery
933 .start_discovery(local_peer_id, bootstrap_nodes)
934 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
935
936 info!(
937 "Started local candidate discovery for peer {:?}",
938 local_peer_id
939 );
940 }
941
942 Ok(endpoint)
943 }
944
945 pub fn get_quinn_endpoint(&self) -> Option<&crate::high_level::Endpoint> {
947 self.quinn_endpoint.as_ref()
948 }
949
950 pub fn get_event_callback(&self) -> Option<&Box<dyn Fn(NatTraversalEvent) + Send + Sync>> {
952 self.event_callback.as_ref()
953 }
954
955 pub fn initiate_nat_traversal(
957 &self,
958 peer_id: PeerId,
959 coordinator: SocketAddr,
960 ) -> Result<(), NatTraversalError> {
961 info!(
962 "Starting NAT traversal to peer {:?} via coordinator {}",
963 peer_id, coordinator
964 );
965
966 let session = NatTraversalSession {
968 peer_id,
969 coordinator,
970 attempt: 1,
971 started_at: std::time::Instant::now(),
972 phase: TraversalPhase::Discovery,
973 candidates: Vec::new(),
974 session_state: SessionState {
975 state: ConnectionState::Connecting,
976 last_transition: std::time::Instant::now(),
977
978 connection: None,
979 active_attempts: Vec::new(),
980 metrics: ConnectionMetrics::default(),
981 },
982 };
983
984 {
986 let mut sessions = self
987 .active_sessions
988 .write()
989 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
990 sessions.insert(peer_id, session);
991 }
992
993 let bootstrap_nodes_vec = {
995 let bootstrap_nodes = self
996 .bootstrap_nodes
997 .read()
998 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
999 bootstrap_nodes.clone()
1000 };
1001
1002 {
1003 let mut discovery = self.discovery_manager.lock().map_err(|_| {
1004 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
1005 })?;
1006
1007 discovery
1008 .start_discovery(peer_id, bootstrap_nodes_vec)
1009 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
1010 }
1011
1012 if let Some(ref callback) = self.event_callback {
1014 callback(NatTraversalEvent::CoordinationRequested {
1015 peer_id,
1016 coordinator,
1017 });
1018 }
1019
1020 Ok(())
1022 }
1023
1024 pub fn poll_sessions(&self) -> Result<Vec<SessionStateUpdate>, NatTraversalError> {
1026 let mut updates = Vec::new();
1027 let now = std::time::Instant::now();
1028
1029 let mut sessions = self
1030 .active_sessions
1031 .write()
1032 .map_err(|_| NatTraversalError::ProtocolError("Sessions lock poisoned".to_string()))?;
1033
1034 for (peer_id, session) in sessions.iter_mut() {
1035 let mut state_changed = false;
1036
1037 match session.session_state.state {
1038 ConnectionState::Connecting => {
1039 let elapsed = now.duration_since(session.session_state.last_transition);
1041 if elapsed
1042 > self
1043 .timeout_config
1044 .nat_traversal
1045 .connection_establishment_timeout
1046 {
1047 session.session_state.state = ConnectionState::Closed;
1048 session.session_state.last_transition = now;
1049 state_changed = true;
1050
1051 updates.push(SessionStateUpdate {
1052 peer_id: *peer_id,
1053 old_state: ConnectionState::Connecting,
1054 new_state: ConnectionState::Closed,
1055 reason: StateChangeReason::Timeout,
1056 });
1057 }
1058
1059 if let Some(ref _connection) = session.session_state.connection {
1062 session.session_state.state = ConnectionState::Connected;
1063 session.session_state.last_transition = now;
1064 state_changed = true;
1065
1066 updates.push(SessionStateUpdate {
1067 peer_id: *peer_id,
1068 old_state: ConnectionState::Connecting,
1069 new_state: ConnectionState::Connected,
1070 reason: StateChangeReason::ConnectionEstablished,
1071 });
1072 }
1073 }
1074 ConnectionState::Connected => {
1075 {
1078 }
1081
1082 session.session_state.metrics.last_activity = Some(now);
1084 }
1085 ConnectionState::Migrating => {
1086 let elapsed = now.duration_since(session.session_state.last_transition);
1088 if elapsed > Duration::from_secs(10) {
1089 if session.session_state.connection.is_some() {
1092 session.session_state.state = ConnectionState::Connected;
1093 state_changed = true;
1094
1095 updates.push(SessionStateUpdate {
1096 peer_id: *peer_id,
1097 old_state: ConnectionState::Migrating,
1098 new_state: ConnectionState::Connected,
1099 reason: StateChangeReason::MigrationComplete,
1100 });
1101 } else {
1102 session.session_state.state = ConnectionState::Closed;
1103 state_changed = true;
1104
1105 updates.push(SessionStateUpdate {
1106 peer_id: *peer_id,
1107 old_state: ConnectionState::Migrating,
1108 new_state: ConnectionState::Closed,
1109 reason: StateChangeReason::MigrationFailed,
1110 });
1111 }
1112
1113 session.session_state.last_transition = now;
1114 }
1115 }
1116 _ => {}
1117 }
1118
1119 if state_changed {
1121 if let Some(ref callback) = self.event_callback {
1122 callback(NatTraversalEvent::SessionStateChanged {
1123 peer_id: *peer_id,
1124 new_state: session.session_state.state,
1125 });
1126 }
1127 }
1128 }
1129
1130 Ok(updates)
1131 }
1132
1133 pub fn start_session_polling(&self, interval: Duration) -> tokio::task::JoinHandle<()> {
1135 let sessions = self.active_sessions.clone();
1136 let shutdown = self.shutdown.clone();
1137 let timeout_config = self.timeout_config.clone();
1138
1139 tokio::spawn(async move {
1140 let mut ticker = tokio::time::interval(interval);
1141
1142 loop {
1143 ticker.tick().await;
1144
1145 if shutdown.load(Ordering::Relaxed) {
1146 break;
1147 }
1148
1149 let sessions_to_update = {
1151 match sessions.read() {
1152 Ok(sessions_guard) => {
1153 sessions_guard
1154 .iter()
1155 .filter_map(|(peer_id, session)| {
1156 let now = std::time::Instant::now();
1157 let elapsed =
1158 now.duration_since(session.session_state.last_transition);
1159
1160 match session.session_state.state {
1161 ConnectionState::Connecting => {
1162 if elapsed
1164 > timeout_config
1165 .nat_traversal
1166 .connection_establishment_timeout
1167 {
1168 Some((*peer_id, SessionUpdate::Timeout))
1169 } else {
1170 None
1171 }
1172 }
1173 ConnectionState::Connected => {
1174 if let Some(ref conn) = session.session_state.connection
1176 {
1177 if conn.close_reason().is_some() {
1178 Some((*peer_id, SessionUpdate::Disconnected))
1179 } else {
1180 Some((*peer_id, SessionUpdate::UpdateMetrics))
1182 }
1183 } else {
1184 Some((*peer_id, SessionUpdate::InvalidState))
1185 }
1186 }
1187 ConnectionState::Idle => {
1188 if elapsed
1190 > timeout_config
1191 .discovery
1192 .server_reflexive_cache_ttl
1193 {
1194 Some((*peer_id, SessionUpdate::Retry))
1195 } else {
1196 None
1197 }
1198 }
1199 ConnectionState::Migrating => {
1200 if elapsed > timeout_config.nat_traversal.probe_timeout
1202 {
1203 Some((*peer_id, SessionUpdate::MigrationTimeout))
1204 } else {
1205 None
1206 }
1207 }
1208 ConnectionState::Closed => {
1209 if elapsed
1211 > timeout_config.discovery.interface_cache_ttl
1212 {
1213 Some((*peer_id, SessionUpdate::Remove))
1214 } else {
1215 None
1216 }
1217 }
1218 }
1219 })
1220 .collect::<Vec<_>>()
1221 }
1222 _ => {
1223 vec![]
1224 }
1225 }
1226 };
1227
1228 if !sessions_to_update.is_empty() {
1230 if let Ok(mut sessions_guard) = sessions.write() {
1231 for (peer_id, update) in sessions_to_update {
1232 match update {
1233 SessionUpdate::Timeout => {
1234 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1235 session.session_state.state = ConnectionState::Closed;
1236 session.session_state.last_transition =
1237 std::time::Instant::now();
1238 tracing::warn!("Connection to {:?} timed out", peer_id);
1239 }
1240 }
1241 SessionUpdate::Disconnected => {
1242 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1243 session.session_state.state = ConnectionState::Closed;
1244 session.session_state.last_transition =
1245 std::time::Instant::now();
1246 session.session_state.connection = None;
1247 tracing::info!("Connection to {:?} closed", peer_id);
1248 }
1249 }
1250 SessionUpdate::UpdateMetrics => {
1251 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1252 if let Some(ref conn) = session.session_state.connection {
1253 let stats = conn.stats();
1255 session.session_state.metrics.rtt =
1256 Some(stats.path.rtt);
1257 session.session_state.metrics.loss_rate =
1258 stats.path.lost_packets as f64
1259 / stats.path.sent_packets.max(1) as f64;
1260 }
1261 }
1262 }
1263 SessionUpdate::InvalidState => {
1264 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1265 session.session_state.state = ConnectionState::Closed;
1266 session.session_state.last_transition =
1267 std::time::Instant::now();
1268 tracing::error!("Session {:?} in invalid state", peer_id);
1269 }
1270 }
1271 SessionUpdate::Retry => {
1272 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1273 session.session_state.state = ConnectionState::Connecting;
1274 session.session_state.last_transition =
1275 std::time::Instant::now();
1276 session.attempt += 1;
1277 tracing::info!(
1278 "Retrying connection to {:?} (attempt {})",
1279 peer_id,
1280 session.attempt
1281 );
1282 }
1283 }
1284 SessionUpdate::MigrationTimeout => {
1285 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1286 session.session_state.state = ConnectionState::Closed;
1287 session.session_state.last_transition =
1288 std::time::Instant::now();
1289 tracing::warn!("Migration timeout for {:?}", peer_id);
1290 }
1291 }
1292 SessionUpdate::Remove => {
1293 sessions_guard.remove(&peer_id);
1294 tracing::debug!("Removed old session for {:?}", peer_id);
1295 }
1296 }
1297 }
1298 }
1299 }
1300 }
1301 })
1302 }
1303
1304 pub fn get_statistics(&self) -> Result<NatTraversalStatistics, NatTraversalError> {
1308 let sessions = self
1309 .active_sessions
1310 .read()
1311 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1312 let bootstrap_nodes = self
1313 .bootstrap_nodes
1314 .read()
1315 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1316
1317 let avg_coordination_time = {
1319 let rtts: Vec<Duration> = bootstrap_nodes.iter().filter_map(|b| b.rtt).collect();
1320
1321 if rtts.is_empty() {
1322 Duration::from_millis(500) } else {
1324 let total_millis: u64 = rtts.iter().map(|d| d.as_millis() as u64).sum();
1325 Duration::from_millis(total_millis / rtts.len() as u64 * 2) }
1327 };
1328
1329 Ok(NatTraversalStatistics {
1330 active_sessions: sessions.len(),
1331 total_bootstrap_nodes: bootstrap_nodes.len(),
1332 successful_coordinations: bootstrap_nodes.iter().map(|b| b.coordination_count).sum(),
1333 average_coordination_time: avg_coordination_time,
1334 total_attempts: 0,
1335 successful_connections: 0,
1336 direct_connections: 0,
1337 relayed_connections: 0,
1338 })
1339 }
1340
1341 pub fn add_bootstrap_node(&self, address: SocketAddr) -> Result<(), NatTraversalError> {
1343 let mut bootstrap_nodes = self
1344 .bootstrap_nodes
1345 .write()
1346 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1347
1348 if !bootstrap_nodes.iter().any(|b| b.address == address) {
1350 bootstrap_nodes.push(BootstrapNode {
1351 address,
1352 last_seen: std::time::Instant::now(),
1353 can_coordinate: true,
1354 rtt: None,
1355 coordination_count: 0,
1356 });
1357 info!("Added bootstrap node: {}", address);
1358 }
1359 Ok(())
1360 }
1361
1362 pub fn remove_bootstrap_node(&self, address: SocketAddr) -> Result<(), NatTraversalError> {
1364 let mut bootstrap_nodes = self
1365 .bootstrap_nodes
1366 .write()
1367 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1368 bootstrap_nodes.retain(|b| b.address != address);
1369 info!("Removed bootstrap node: {}", address);
1370 Ok(())
1371 }
1372
1373 async fn create_quinn_endpoint(
1377 config: &NatTraversalConfig,
1378 _nat_role: NatTraversalRole,
1379 ) -> Result<
1380 (
1381 QuinnEndpoint,
1382 mpsc::UnboundedSender<NatTraversalEvent>,
1383 mpsc::UnboundedReceiver<NatTraversalEvent>,
1384 SocketAddr,
1385 ),
1386 NatTraversalError,
1387 > {
1388 use std::sync::Arc;
1389
1390 let server_config = match config.role {
1392 EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
1393 info!(
1394 "Creating server config for role: {:?} using Raw Public Keys (RFC 7250)",
1395 config.role
1396 );
1397
1398 let (server_key, _public_key) =
1400 crate::crypto::raw_public_keys::key_utils::generate_ed25519_keypair();
1401
1402 let rpk_config = RawPublicKeyConfigBuilder::new()
1404 .with_server_key(server_key)
1405 .allow_any_key() .build_rfc7250_server_config()
1407 .map_err(|e| {
1408 NatTraversalError::ConfigError(format!("RPK server config failed: {e}"))
1409 })?;
1410
1411 let server_crypto = QuicServerConfig::try_from(rpk_config.inner().as_ref().clone())
1412 .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;
1413
1414 let mut server_config = ServerConfig::with_crypto(Arc::new(server_crypto));
1415
1416 let mut transport_config = TransportConfig::default();
1418 transport_config
1419 .keep_alive_interval(Some(config.timeouts.nat_traversal.retry_interval));
1420 transport_config.max_idle_timeout(Some(crate::VarInt::from_u32(30000).into()));
1421
1422 let nat_config = match config.role {
1427 EndpointRole::Client => {
1428 crate::transport_parameters::NatTraversalConfig::ClientSupport
1429 }
1430 EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
1431 crate::transport_parameters::NatTraversalConfig::ServerSupport {
1432 concurrency_limit: VarInt::from_u32(
1433 config.max_concurrent_attempts as u32,
1434 ),
1435 }
1436 }
1437 };
1438 transport_config.nat_traversal_config(Some(nat_config));
1439
1440 server_config.transport_config(Arc::new(transport_config));
1441
1442 Some(server_config)
1443 }
1444 _ => None,
1445 };
1446
1447 let client_config = {
1449 info!("Creating client config using Raw Public Keys (RFC 7250)");
1450
1451 let rpk_config = RawPublicKeyConfigBuilder::new()
1453 .allow_any_key() .build_rfc7250_client_config()
1455 .map_err(|e| {
1456 NatTraversalError::ConfigError(format!("RPK client config failed: {e}"))
1457 })?;
1458
1459 let client_crypto = QuicClientConfig::try_from(rpk_config.inner().as_ref().clone())
1460 .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;
1461
1462 let mut client_config = ClientConfig::new(Arc::new(client_crypto));
1463
1464 let mut transport_config = TransportConfig::default();
1466 transport_config.keep_alive_interval(Some(Duration::from_secs(5)));
1467 transport_config.max_idle_timeout(Some(crate::VarInt::from_u32(30000).into()));
1468
1469 let nat_config = match config.role {
1474 EndpointRole::Client => {
1475 crate::transport_parameters::NatTraversalConfig::ClientSupport
1476 }
1477 EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
1478 crate::transport_parameters::NatTraversalConfig::ServerSupport {
1479 concurrency_limit: VarInt::from_u32(config.max_concurrent_attempts as u32),
1480 }
1481 }
1482 };
1483 transport_config.nat_traversal_config(Some(nat_config));
1484
1485 client_config.transport_config(Arc::new(transport_config));
1486
1487 client_config
1488 };
1489
1490 let bind_addr = config
1492 .bind_addr
1493 .unwrap_or_else(create_random_port_bind_addr);
1494 let socket = UdpSocket::bind(bind_addr).await.map_err(|e| {
1495 NatTraversalError::NetworkError(format!("Failed to bind UDP socket: {e}"))
1496 })?;
1497
1498 info!("Binding endpoint to {}", bind_addr);
1499
1500 let std_socket = socket.into_std().map_err(|e| {
1502 NatTraversalError::NetworkError(format!("Failed to convert socket: {e}"))
1503 })?;
1504
1505 let runtime = default_runtime().ok_or_else(|| {
1507 NatTraversalError::ConfigError("No compatible async runtime found".to_string())
1508 })?;
1509
1510 let mut endpoint = QuinnEndpoint::new(
1511 EndpointConfig::default(),
1512 server_config,
1513 std_socket,
1514 runtime,
1515 )
1516 .map_err(|e| {
1517 NatTraversalError::ConfigError(format!("Failed to create Quinn endpoint: {e}"))
1518 })?;
1519
1520 endpoint.set_default_client_config(client_config);
1522
1523 let local_addr = endpoint.local_addr().map_err(|e| {
1525 NatTraversalError::NetworkError(format!("Failed to get local address: {e}"))
1526 })?;
1527
1528 info!("Endpoint bound to actual address: {}", local_addr);
1529
1530 let (event_tx, event_rx) = mpsc::unbounded_channel();
1532
1533 Ok((endpoint, event_tx, event_rx, local_addr))
1534 }
1535
1536 #[allow(clippy::panic)]
1538 pub async fn start_listening(&self, bind_addr: SocketAddr) -> Result<(), NatTraversalError> {
1539 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
1540 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
1541 })?;
1542
1543 let _socket = UdpSocket::bind(bind_addr).await.map_err(|e| {
1545 NatTraversalError::NetworkError(format!("Failed to bind to {bind_addr}: {e}"))
1546 })?;
1547
1548 info!("Started listening on {}", bind_addr);
1549
1550 let endpoint_clone = endpoint.clone();
1552 let shutdown_clone = self.shutdown.clone();
1553 let event_tx = self
1554 .event_tx
1555 .as_ref()
1556 .unwrap_or_else(|| panic!("event transmitter should be initialized"))
1557 .clone();
1558 let connections_clone = self.connections.clone();
1559
1560 tokio::spawn(async move {
1561 Self::accept_connections(endpoint_clone, shutdown_clone, event_tx, connections_clone)
1562 .await;
1563 });
1564
1565 Ok(())
1566 }
1567
1568 async fn accept_connections(
1570 endpoint: QuinnEndpoint,
1571 shutdown: Arc<AtomicBool>,
1572 event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1573 connections: Arc<std::sync::RwLock<HashMap<PeerId, QuinnConnection>>>,
1574 ) {
1575 while !shutdown.load(Ordering::Relaxed) {
1576 match endpoint.accept().await {
1577 Some(connecting) => {
1578 let event_tx = event_tx.clone();
1579 let connections = connections.clone();
1580 tokio::spawn(async move {
1581 match connecting.await {
1582 Ok(connection) => {
1583 info!("Accepted connection from {}", connection.remote_address());
1584
1585 let peer_id = Self::generate_peer_id_from_address(
1587 connection.remote_address(),
1588 );
1589
1590 if let Ok(mut conns) = connections.write() {
1592 conns.insert(peer_id, connection.clone());
1593 }
1594
1595 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
1596 peer_id,
1597 remote_address: connection.remote_address(),
1598 });
1599
1600 Self::handle_connection(peer_id, connection, event_tx).await;
1602 }
1603 Err(e) => {
1604 debug!("Connection failed: {}", e);
1605 }
1606 }
1607 });
1608 }
1609 None => {
1610 break;
1612 }
1613 }
1614 }
1615 }
1616
1617 async fn poll_discovery(
1619 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
1620 shutdown: Arc<AtomicBool>,
1621 _event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1622 ) {
1623 use tokio::time::{Duration, interval};
1624
1625 let mut poll_interval = interval(Duration::from_millis(100));
1626
1627 while !shutdown.load(Ordering::Relaxed) {
1628 poll_interval.tick().await;
1629
1630 let events = match discovery_manager.lock() {
1632 Ok(mut discovery) => discovery.poll(std::time::Instant::now()),
1633 Err(e) => {
1634 error!("Failed to lock discovery manager: {}", e);
1635 continue;
1636 }
1637 };
1638
1639 for event in events {
1641 match event {
1642 DiscoveryEvent::DiscoveryStarted {
1643 peer_id,
1644 bootstrap_count,
1645 } => {
1646 debug!(
1647 "Discovery started for peer {:?} with {} bootstrap nodes",
1648 peer_id, bootstrap_count
1649 );
1650 }
1651 DiscoveryEvent::LocalScanningStarted => {
1652 debug!("Local interface scanning started");
1653 }
1654 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
1655 debug!("Discovered local candidate: {}", candidate.address);
1656 }
1659 DiscoveryEvent::LocalScanningCompleted {
1660 candidate_count,
1661 duration,
1662 } => {
1663 debug!(
1664 "Local interface scanning completed: {} candidates in {:?}",
1665 candidate_count, duration
1666 );
1667 }
1668 DiscoveryEvent::ServerReflexiveDiscoveryStarted { bootstrap_count } => {
1669 debug!(
1670 "Server reflexive discovery started with {} bootstrap nodes",
1671 bootstrap_count
1672 );
1673 }
1674 DiscoveryEvent::ServerReflexiveCandidateDiscovered {
1675 candidate,
1676 bootstrap_node,
1677 } => {
1678 debug!(
1679 "Discovered server-reflexive candidate {} via bootstrap {}",
1680 candidate.address, bootstrap_node
1681 );
1682 }
1684 DiscoveryEvent::BootstrapQueryFailed {
1685 bootstrap_node,
1686 error,
1687 } => {
1688 debug!("Bootstrap query failed for {}: {}", bootstrap_node, error);
1689 }
1690 DiscoveryEvent::PortAllocationDetected {
1692 port,
1693 source_address,
1694 bootstrap_node,
1695 timestamp,
1696 } => {
1697 debug!(
1698 "Port allocation detected: port {} from {} via bootstrap {:?} at {:?}",
1699 port, source_address, bootstrap_node, timestamp
1700 );
1701 }
1702 DiscoveryEvent::DiscoveryCompleted {
1703 candidate_count,
1704 total_duration,
1705 success_rate,
1706 } => {
1707 info!(
1708 "Discovery completed with {} candidates in {:?} (success rate: {:.2}%)",
1709 candidate_count,
1710 total_duration,
1711 success_rate * 100.0
1712 );
1713 }
1716 DiscoveryEvent::DiscoveryFailed {
1717 error,
1718 partial_results,
1719 } => {
1720 warn!(
1721 "Discovery failed: {} (found {} partial candidates)",
1722 error,
1723 partial_results.len()
1724 );
1725
1726 }
1731 DiscoveryEvent::PathValidationRequested {
1732 candidate_id,
1733 candidate_address,
1734 challenge_token,
1735 } => {
1736 debug!(
1737 "PATH_CHALLENGE requested for candidate {} at {} with token {:08x}",
1738 candidate_id.0, candidate_address, challenge_token
1739 );
1740 }
1743 DiscoveryEvent::PathValidationResponse {
1744 candidate_id,
1745 candidate_address,
1746 challenge_token: _,
1747 rtt,
1748 } => {
1749 debug!(
1750 "PATH_RESPONSE received for candidate {} at {} with RTT {:?}",
1751 candidate_id.0, candidate_address, rtt
1752 );
1753 }
1755 }
1756 }
1757 }
1758
1759 info!("Discovery polling task shutting down");
1760 }
1761
1762 async fn handle_connection(
1764 peer_id: PeerId,
1765 connection: QuinnConnection,
1766 event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1767 ) {
1768 let remote_address = connection.remote_address();
1769 let closed = connection.closed();
1770 tokio::pin!(closed);
1771
1772 debug!(
1773 "Handling connection from peer {:?} at {}",
1774 peer_id, remote_address
1775 );
1776
1777 closed.await;
1781
1782 let reason = connection
1783 .close_reason()
1784 .map(|reason| format!("Connection closed: {reason}"))
1785 .unwrap_or_else(|| "Connection closed".to_string());
1786 let _ = event_tx.send(NatTraversalEvent::ConnectionLost { peer_id, reason });
1787 }
1788
1789 async fn handle_bi_stream(
1791 _send: crate::high_level::SendStream,
1792 _recv: crate::high_level::RecvStream,
1793 ) {
1794 }
1823
1824 async fn handle_uni_stream(mut recv: crate::high_level::RecvStream) {
1826 let mut buffer = vec![0u8; 1024];
1827
1828 loop {
1829 match recv.read(&mut buffer).await {
1830 Ok(Some(size)) => {
1831 debug!("Received {} bytes on unidirectional stream", size);
1832 }
1834 Ok(None) => {
1835 debug!("Unidirectional stream closed by peer");
1836 break;
1837 }
1838 Err(e) => {
1839 debug!("Error reading from unidirectional stream: {}", e);
1840 break;
1841 }
1842 }
1843 }
1844 }
1845
1846 pub async fn connect_to_peer(
1848 &self,
1849 peer_id: PeerId,
1850 server_name: &str,
1851 remote_addr: SocketAddr,
1852 ) -> Result<QuinnConnection, NatTraversalError> {
1853 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
1854 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
1855 })?;
1856
1857 info!("Connecting to peer {:?} at {}", peer_id, remote_addr);
1858
1859 let connecting = endpoint.connect(remote_addr, server_name).map_err(|e| {
1861 NatTraversalError::ConnectionFailed(format!("Failed to initiate connection: {e}"))
1862 })?;
1863
1864 let connection = timeout(
1865 self.timeout_config
1866 .nat_traversal
1867 .connection_establishment_timeout,
1868 connecting,
1869 )
1870 .await
1871 .map_err(|_| NatTraversalError::Timeout)?
1872 .map_err(|e| NatTraversalError::ConnectionFailed(format!("Connection failed: {e}")))?;
1873
1874 info!(
1875 "Successfully connected to peer {:?} at {}",
1876 peer_id, remote_addr
1877 );
1878
1879 if let Some(ref event_tx) = self.event_tx {
1881 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
1882 peer_id,
1883 remote_address: remote_addr,
1884 });
1885 }
1886
1887 Ok(connection)
1888 }
1889
1890 pub async fn accept_connection(&self) -> Result<(PeerId, QuinnConnection), NatTraversalError> {
1892 info!("Waiting for incoming connection via event channel...");
1893
1894 let timeout_duration = self
1895 .timeout_config
1896 .nat_traversal
1897 .connection_establishment_timeout;
1898 let start = std::time::Instant::now();
1899
1900 loop {
1901 if self.shutdown.load(Ordering::Relaxed) {
1903 return Err(NatTraversalError::NetworkError(
1904 "Endpoint shutting down".to_string(),
1905 ));
1906 }
1907
1908 if start.elapsed() > timeout_duration {
1910 warn!("accept_connection() timed out after {:?}", timeout_duration);
1911 return Err(NatTraversalError::Timeout);
1912 }
1913
1914 {
1916 let mut event_rx = self.event_rx.lock().map_err(|_| {
1917 NatTraversalError::ProtocolError("Event channel lock poisoned".to_string())
1918 })?;
1919
1920 match event_rx.try_recv() {
1921 Ok(NatTraversalEvent::ConnectionEstablished {
1922 peer_id,
1923 remote_address,
1924 }) => {
1925 info!(
1926 "Received ConnectionEstablished event for peer {:?} at {}",
1927 peer_id, remote_address
1928 );
1929
1930 let connection = {
1933 let connections = self.connections.read().map_err(|_| {
1934 NatTraversalError::ProtocolError(
1935 "Connections lock poisoned".to_string(),
1936 )
1937 })?;
1938 connections.get(&peer_id).cloned().ok_or_else(|| {
1939 NatTraversalError::ConnectionFailed(format!(
1940 "Connection for peer {:?} not found in storage",
1941 peer_id
1942 ))
1943 })?
1944 };
1945
1946 info!(
1947 "Retrieved accepted connection from peer {:?} at {}",
1948 peer_id, remote_address
1949 );
1950 return Ok((peer_id, connection));
1951 }
1952 Ok(event) => {
1953 debug!(
1955 "Ignoring non-connection event while waiting for accept: {:?}",
1956 event
1957 );
1958 }
1959 Err(mpsc::error::TryRecvError::Empty) => {
1960 }
1962 Err(mpsc::error::TryRecvError::Disconnected) => {
1963 return Err(NatTraversalError::NetworkError(
1964 "Event channel closed".to_string(),
1965 ));
1966 }
1967 }
1968 } tokio::time::sleep(Duration::from_millis(10)).await;
1972 }
1973 }
1974
1975 pub fn local_peer_id(&self) -> PeerId {
1977 self.local_peer_id
1978 }
1979
1980 pub fn get_connection(
1982 &self,
1983 peer_id: &PeerId,
1984 ) -> Result<Option<QuinnConnection>, NatTraversalError> {
1985 let connections = self.connections.read().map_err(|_| {
1986 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
1987 })?;
1988 Ok(connections.get(peer_id).cloned())
1989 }
1990
1991 pub fn add_connection(
1993 &self,
1994 peer_id: PeerId,
1995 connection: QuinnConnection,
1996 ) -> Result<(), NatTraversalError> {
1997 let mut connections = self.connections.write().map_err(|_| {
1998 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
1999 })?;
2000 connections.insert(peer_id, connection);
2001 Ok(())
2002 }
2003
2004 pub fn spawn_connection_handler(
2006 &self,
2007 peer_id: PeerId,
2008 connection: QuinnConnection,
2009 ) -> Result<(), NatTraversalError> {
2010 let event_tx = self.event_tx.as_ref().cloned().ok_or_else(|| {
2011 NatTraversalError::ConfigError("NAT traversal event channel not configured".to_string())
2012 })?;
2013
2014 let remote_address = connection.remote_address();
2015
2016 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
2018 peer_id,
2019 remote_address,
2020 });
2021
2022 tokio::spawn(async move {
2024 Self::handle_connection(peer_id, connection, event_tx).await;
2025 });
2026
2027 Ok(())
2028 }
2029
2030 pub fn remove_connection(
2032 &self,
2033 peer_id: &PeerId,
2034 ) -> Result<Option<QuinnConnection>, NatTraversalError> {
2035 let mut connections = self.connections.write().map_err(|_| {
2036 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2037 })?;
2038 Ok(connections.remove(peer_id))
2039 }
2040
2041 pub fn list_connections(&self) -> Result<Vec<(PeerId, SocketAddr)>, NatTraversalError> {
2043 let connections = self.connections.read().map_err(|_| {
2044 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2045 })?;
2046 let mut result = Vec::new();
2047 for (peer_id, connection) in connections.iter() {
2048 result.push((*peer_id, connection.remote_address()));
2049 }
2050 Ok(result)
2051 }
2052
2053 pub async fn handle_connection_data(
2055 &self,
2056 peer_id: PeerId,
2057 connection: &QuinnConnection,
2058 ) -> Result<(), NatTraversalError> {
2059 info!("Handling connection data from peer {:?}", peer_id);
2060
2061 let connection_clone = connection.clone();
2063 let peer_id_clone = peer_id;
2064 tokio::spawn(async move {
2065 loop {
2066 match connection_clone.accept_bi().await {
2067 Ok((send, recv)) => {
2068 debug!(
2069 "Accepted bidirectional stream from peer {:?}",
2070 peer_id_clone
2071 );
2072 tokio::spawn(Self::handle_bi_stream(send, recv));
2073 }
2074 Err(ConnectionError::ApplicationClosed(_)) => {
2075 debug!("Connection closed by peer {:?}", peer_id_clone);
2076 break;
2077 }
2078 Err(e) => {
2079 debug!(
2080 "Error accepting bidirectional stream from peer {:?}: {}",
2081 peer_id_clone, e
2082 );
2083 break;
2084 }
2085 }
2086 }
2087 });
2088
2089 let connection_clone = connection.clone();
2091 let peer_id_clone = peer_id;
2092 tokio::spawn(async move {
2093 loop {
2094 match connection_clone.accept_uni().await {
2095 Ok(recv) => {
2096 debug!(
2097 "Accepted unidirectional stream from peer {:?}",
2098 peer_id_clone
2099 );
2100 tokio::spawn(Self::handle_uni_stream(recv));
2101 }
2102 Err(ConnectionError::ApplicationClosed(_)) => {
2103 debug!("Connection closed by peer {:?}", peer_id_clone);
2104 break;
2105 }
2106 Err(e) => {
2107 debug!(
2108 "Error accepting unidirectional stream from peer {:?}: {}",
2109 peer_id_clone, e
2110 );
2111 break;
2112 }
2113 }
2114 }
2115 });
2116
2117 Ok(())
2118 }
2119
2120 fn generate_local_peer_id() -> PeerId {
2122 use std::collections::hash_map::DefaultHasher;
2123 use std::hash::{Hash, Hasher};
2124 use std::time::SystemTime;
2125
2126 let mut hasher = DefaultHasher::new();
2127 SystemTime::now().hash(&mut hasher);
2128 std::process::id().hash(&mut hasher);
2129
2130 let hash = hasher.finish();
2131 let mut peer_id = [0u8; 32];
2132 peer_id[0..8].copy_from_slice(&hash.to_be_bytes());
2133
2134 for i in 8..32 {
2136 peer_id[i] = rand::random();
2137 }
2138
2139 PeerId(peer_id)
2140 }
2141
2142 fn generate_peer_id_from_address(addr: SocketAddr) -> PeerId {
2148 use std::collections::hash_map::DefaultHasher;
2149 use std::hash::{Hash, Hasher};
2150
2151 let mut hasher = DefaultHasher::new();
2152 addr.hash(&mut hasher);
2153
2154 let hash = hasher.finish();
2155 let mut peer_id = [0u8; 32];
2156 peer_id[0..8].copy_from_slice(&hash.to_be_bytes());
2157
2158 for i in 8..32 {
2161 peer_id[i] = rand::random();
2162 }
2163
2164 warn!(
2165 "Generated temporary peer ID from address {}. This ID is not persistent!",
2166 addr
2167 );
2168 PeerId(peer_id)
2169 }
2170
2171 pub async fn extract_peer_id_from_connection(
2173 &self,
2174 connection: &QuinnConnection,
2175 ) -> Option<PeerId> {
2176 if let Some(identity) = connection.peer_identity() {
2178 if let Some(public_key_bytes) = identity.downcast_ref::<[u8; 32]>() {
2180 match crate::derive_peer_id_from_key_bytes(public_key_bytes) {
2182 Ok(peer_id) => {
2183 debug!("Derived peer ID from Ed25519 public key");
2184 return Some(peer_id);
2185 }
2186 Err(e) => {
2187 warn!("Failed to derive peer ID from public key: {}", e);
2188 }
2189 }
2190 }
2191 }
2193
2194 None
2195 }
2196
2197 pub async fn shutdown(&self) -> Result<(), NatTraversalError> {
2199 self.shutdown.store(true, Ordering::Relaxed);
2201
2202 {
2204 let mut connections = self.connections.write().map_err(|_| {
2205 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2206 })?;
2207 for (peer_id, connection) in connections.drain() {
2208 info!("Closing connection to peer {:?}", peer_id);
2209 connection.close(crate::VarInt::from_u32(0), b"Shutdown");
2210 }
2211 }
2212
2213 if let Some(ref endpoint) = self.quinn_endpoint {
2215 endpoint.wait_idle().await;
2216 }
2217
2218 info!("NAT traversal endpoint shutdown completed");
2219 Ok(())
2220 }
2221
2222 pub async fn discover_candidates(
2224 &self,
2225 peer_id: PeerId,
2226 ) -> Result<Vec<CandidateAddress>, NatTraversalError> {
2227 debug!("Discovering address candidates for peer {:?}", peer_id);
2228
2229 let mut candidates = Vec::new();
2230
2231 let bootstrap_nodes = {
2233 let nodes = self
2234 .bootstrap_nodes
2235 .read()
2236 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
2237 nodes.clone()
2238 };
2239
2240 {
2242 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2243 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2244 })?;
2245
2246 discovery
2247 .start_discovery(peer_id, bootstrap_nodes)
2248 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
2249 }
2250
2251 let timeout_duration = self.config.coordination_timeout;
2253 let start_time = std::time::Instant::now();
2254
2255 while start_time.elapsed() < timeout_duration {
2256 let discovery_events = {
2257 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2258 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2259 })?;
2260 discovery.poll(std::time::Instant::now())
2261 };
2262
2263 for event in discovery_events {
2264 match event {
2265 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
2266 candidates.push(candidate.clone());
2267
2268 self.send_candidate_advertisement(peer_id, &candidate)
2270 .await
2271 .unwrap_or_else(|e| {
2272 debug!("Failed to send candidate advertisement: {}", e)
2273 });
2274 }
2275 DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. } => {
2276 candidates.push(candidate.clone());
2277
2278 self.send_candidate_advertisement(peer_id, &candidate)
2280 .await
2281 .unwrap_or_else(|e| {
2282 debug!("Failed to send candidate advertisement: {}", e)
2283 });
2284 }
2285 DiscoveryEvent::DiscoveryCompleted { .. } => {
2287 return Ok(candidates);
2289 }
2290 DiscoveryEvent::DiscoveryFailed {
2291 error,
2292 partial_results,
2293 } => {
2294 candidates.extend(partial_results);
2296 if candidates.is_empty() {
2297 return Err(NatTraversalError::CandidateDiscoveryFailed(
2298 error.to_string(),
2299 ));
2300 }
2301 return Ok(candidates);
2302 }
2303 _ => {}
2304 }
2305 }
2306
2307 sleep(Duration::from_millis(10)).await;
2309 }
2310
2311 if candidates.is_empty() {
2312 Err(NatTraversalError::NoCandidatesFound)
2313 } else {
2314 Ok(candidates)
2315 }
2316 }
2317
2318 #[allow(dead_code)]
2320 fn create_punch_me_now_frame(&self, peer_id: PeerId) -> Result<Vec<u8>, NatTraversalError> {
2321 let mut frame = Vec::new();
2329
2330 frame.push(0x41);
2332
2333 frame.extend_from_slice(&peer_id.0);
2335
2336 let timestamp = std::time::SystemTime::now()
2338 .duration_since(std::time::UNIX_EPOCH)
2339 .unwrap_or_default()
2340 .as_millis() as u64;
2341 frame.extend_from_slice(×tamp.to_be_bytes());
2342
2343 let mut token = [0u8; 16];
2345 for byte in &mut token {
2346 *byte = rand::random();
2347 }
2348 frame.extend_from_slice(&token);
2349
2350 Ok(frame)
2351 }
2352
2353 #[allow(dead_code)]
2354 fn attempt_hole_punching(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
2355 debug!("Attempting hole punching for peer {:?}", peer_id);
2356
2357 let candidate_pairs = self.get_candidate_pairs_for_peer(peer_id)?;
2359
2360 if candidate_pairs.is_empty() {
2361 return Err(NatTraversalError::NoCandidatesFound);
2362 }
2363
2364 info!(
2365 "Generated {} candidate pairs for hole punching with peer {:?}",
2366 candidate_pairs.len(),
2367 peer_id
2368 );
2369
2370 self.attempt_quinn_hole_punching(peer_id, candidate_pairs)
2373 }
2374
2375 #[allow(dead_code)]
2377 fn get_candidate_pairs_for_peer(
2378 &self,
2379 peer_id: PeerId,
2380 ) -> Result<Vec<CandidatePair>, NatTraversalError> {
2381 let discovery_candidates = {
2383 let discovery = self.discovery_manager.lock().map_err(|_| {
2384 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2385 })?;
2386
2387 discovery.get_candidates_for_peer(peer_id)
2388 };
2389
2390 if discovery_candidates.is_empty() {
2391 return Err(NatTraversalError::NoCandidatesFound);
2392 }
2393
2394 let mut candidate_pairs = Vec::new();
2396 let local_candidates = discovery_candidates
2397 .iter()
2398 .filter(|c| matches!(c.source, CandidateSource::Local))
2399 .collect::<Vec<_>>();
2400 let remote_candidates = discovery_candidates
2401 .iter()
2402 .filter(|c| !matches!(c.source, CandidateSource::Local))
2403 .collect::<Vec<_>>();
2404
2405 for local in &local_candidates {
2407 for remote in &remote_candidates {
2408 let pair_priority = self.calculate_candidate_pair_priority(local, remote);
2409 candidate_pairs.push(CandidatePair {
2410 local_candidate: (*local).clone(),
2411 remote_candidate: (*remote).clone(),
2412 priority: pair_priority,
2413 state: CandidatePairState::Waiting,
2414 });
2415 }
2416 }
2417
2418 candidate_pairs.sort_by(|a, b| b.priority.cmp(&a.priority));
2420
2421 candidate_pairs.truncate(8);
2423
2424 Ok(candidate_pairs)
2425 }
2426
2427 #[allow(dead_code)]
2429 fn calculate_candidate_pair_priority(
2430 &self,
2431 local: &CandidateAddress,
2432 remote: &CandidateAddress,
2433 ) -> u64 {
2434 let local_type_preference = match local.source {
2438 CandidateSource::Local => 126,
2439 CandidateSource::Observed { .. } => 100,
2440 CandidateSource::Predicted => 75,
2441 CandidateSource::Peer => 50,
2442 };
2443
2444 let remote_type_preference = match remote.source {
2445 CandidateSource::Local => 126,
2446 CandidateSource::Observed { .. } => 100,
2447 CandidateSource::Predicted => 75,
2448 CandidateSource::Peer => 50,
2449 };
2450
2451 let local_priority = (local_type_preference as u64) << 8 | local.priority as u64;
2453 let remote_priority = (remote_type_preference as u64) << 8 | remote.priority as u64;
2454
2455 let min_priority = local_priority.min(remote_priority);
2456 let max_priority = local_priority.max(remote_priority);
2457
2458 (min_priority << 32)
2459 | (max_priority << 1)
2460 | if local_priority > remote_priority {
2461 1
2462 } else {
2463 0
2464 }
2465 }
2466
2467 #[allow(dead_code)]
2469 fn attempt_quinn_hole_punching(
2470 &self,
2471 peer_id: PeerId,
2472 candidate_pairs: Vec<CandidatePair>,
2473 ) -> Result<(), NatTraversalError> {
2474 let _endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
2475 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
2476 })?;
2477
2478 for pair in candidate_pairs {
2479 debug!(
2480 "Attempting hole punch with candidate pair: {} -> {}",
2481 pair.local_candidate.address, pair.remote_candidate.address
2482 );
2483
2484 let mut challenge_data = [0u8; 8];
2486 for byte in &mut challenge_data {
2487 *byte = rand::random();
2488 }
2489
2490 let local_socket =
2492 std::net::UdpSocket::bind(pair.local_candidate.address).map_err(|e| {
2493 NatTraversalError::NetworkError(format!(
2494 "Failed to bind to local candidate: {e}"
2495 ))
2496 })?;
2497
2498 let path_challenge_packet = self.create_path_challenge_packet(challenge_data)?;
2500
2501 match local_socket.send_to(&path_challenge_packet, pair.remote_candidate.address) {
2503 Ok(bytes_sent) => {
2504 debug!(
2505 "Sent {} bytes for hole punch from {} to {}",
2506 bytes_sent, pair.local_candidate.address, pair.remote_candidate.address
2507 );
2508
2509 local_socket
2511 .set_read_timeout(Some(Duration::from_millis(100)))
2512 .map_err(|e| {
2513 NatTraversalError::NetworkError(format!("Failed to set timeout: {e}"))
2514 })?;
2515
2516 let mut response_buffer = [0u8; 1024];
2518 match local_socket.recv_from(&mut response_buffer) {
2519 Ok((_bytes_received, response_addr)) => {
2520 if response_addr == pair.remote_candidate.address {
2521 info!(
2522 "Hole punch succeeded for peer {:?}: {} <-> {}",
2523 peer_id,
2524 pair.local_candidate.address,
2525 pair.remote_candidate.address
2526 );
2527
2528 self.store_successful_candidate_pair(peer_id, pair)?;
2530 return Ok(());
2531 } else {
2532 debug!(
2533 "Received response from unexpected address: {}",
2534 response_addr
2535 );
2536 }
2537 }
2538 Err(e)
2539 if e.kind() == std::io::ErrorKind::WouldBlock
2540 || e.kind() == std::io::ErrorKind::TimedOut =>
2541 {
2542 debug!("No response received for hole punch attempt");
2543 }
2544 Err(e) => {
2545 debug!("Error receiving hole punch response: {}", e);
2546 }
2547 }
2548 }
2549 Err(e) => {
2550 debug!("Failed to send hole punch packet: {}", e);
2551 }
2552 }
2553 }
2554
2555 Err(NatTraversalError::HolePunchingFailed)
2557 }
2558
2559 fn create_path_challenge_packet(
2561 &self,
2562 challenge_data: [u8; 8],
2563 ) -> Result<Vec<u8>, NatTraversalError> {
2564 let mut packet = Vec::new();
2567
2568 packet.push(0x40); packet.extend_from_slice(&[0, 0, 0, 1]); packet.push(0x1a); packet.extend_from_slice(&challenge_data); Ok(packet)
2577 }
2578
2579 fn store_successful_candidate_pair(
2581 &self,
2582 peer_id: PeerId,
2583 pair: CandidatePair,
2584 ) -> Result<(), NatTraversalError> {
2585 debug!(
2586 "Storing successful candidate pair for peer {:?}: {} <-> {}",
2587 peer_id, pair.local_candidate.address, pair.remote_candidate.address
2588 );
2589
2590 if let Some(ref callback) = self.event_callback {
2595 callback(NatTraversalEvent::PathValidated {
2596 peer_id,
2597 address: pair.remote_candidate.address,
2598 rtt: Duration::from_millis(50), });
2600
2601 callback(NatTraversalEvent::TraversalSucceeded {
2602 peer_id,
2603 final_address: pair.remote_candidate.address,
2604 total_time: Duration::from_secs(1), });
2606 }
2607
2608 Ok(())
2609 }
2610
2611 fn attempt_connection_to_candidate(
2613 &self,
2614 peer_id: PeerId,
2615 candidate: &CandidateAddress,
2616 ) -> Result<(), NatTraversalError> {
2617 {
2618 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
2619 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
2620 })?;
2621
2622 let server_name = format!("peer-{:x}", peer_id.0[0] as u32);
2624
2625 debug!(
2626 "Attempting Quinn connection to candidate {} for peer {:?}",
2627 candidate.address, peer_id
2628 );
2629
2630 match endpoint.connect(candidate.address, &server_name) {
2632 Ok(connecting) => {
2633 info!(
2634 "Connection attempt initiated to {} for peer {:?}",
2635 candidate.address, peer_id
2636 );
2637
2638 if let Some(event_tx) = &self.event_tx {
2640 let event_tx = event_tx.clone();
2641 let connections = self.connections.clone();
2642 let peer_id_clone = peer_id;
2643 let address = candidate.address;
2644
2645 tokio::spawn(async move {
2646 match connecting.await {
2647 Ok(connection) => {
2648 info!(
2649 "Successfully connected to {} for peer {:?}",
2650 address, peer_id_clone
2651 );
2652
2653 if let Ok(mut conns) = connections.write() {
2655 conns.insert(peer_id_clone, connection.clone());
2656 }
2657
2658 let _ =
2660 event_tx.send(NatTraversalEvent::ConnectionEstablished {
2661 peer_id: peer_id_clone,
2662 remote_address: address,
2663 });
2664
2665 Self::handle_connection(peer_id_clone, connection, event_tx)
2667 .await;
2668 }
2669 Err(e) => {
2670 warn!("Connection to {} failed: {}", address, e);
2671 }
2672 }
2673 });
2674 }
2675
2676 Ok(())
2677 }
2678 Err(e) => {
2679 warn!(
2680 "Failed to initiate connection to {}: {}",
2681 candidate.address, e
2682 );
2683 Err(NatTraversalError::ConnectionFailed(format!(
2684 "Failed to connect to {}: {}",
2685 candidate.address, e
2686 )))
2687 }
2688 }
2689 }
2690 }
2691
2692 pub fn poll(
2694 &self,
2695 now: std::time::Instant,
2696 ) -> Result<Vec<NatTraversalEvent>, NatTraversalError> {
2697 let mut events = Vec::new();
2698
2699 {
2701 let mut event_rx = self.event_rx.lock().map_err(|_| {
2702 NatTraversalError::ProtocolError("Event channel lock poisoned".to_string())
2703 })?;
2704
2705 loop {
2706 match event_rx.try_recv() {
2707 Ok(event) => {
2708 if let Some(ref callback) = self.event_callback {
2709 callback(event.clone());
2710 }
2711 events.push(event);
2712 }
2713 Err(TryRecvError::Empty) => break,
2714 Err(TryRecvError::Disconnected) => break,
2715 }
2716 }
2717 }
2718
2719 let mut closed_connections = Vec::new();
2721 {
2722 let connections = self.connections.read().map_err(|_| {
2723 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2724 })?;
2725
2726 for (peer_id, connection) in connections.iter() {
2727 if let Some(reason) = connection.close_reason() {
2728 closed_connections.push((*peer_id, reason.clone()));
2729 }
2730 }
2731 }
2732
2733 if !closed_connections.is_empty() {
2734 let mut connections = self.connections.write().map_err(|_| {
2735 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2736 })?;
2737
2738 for (peer_id, reason) in closed_connections {
2739 connections.remove(&peer_id);
2740 let event = NatTraversalEvent::ConnectionLost {
2741 peer_id,
2742 reason: reason.to_string(),
2743 };
2744 if let Some(ref callback) = self.event_callback {
2745 callback(event.clone());
2746 }
2747 events.push(event);
2748 }
2749 }
2750
2751 self.check_connections_for_observed_addresses(&mut events)?;
2753
2754 {
2756 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2757 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2758 })?;
2759
2760 let discovery_events = discovery.poll(now);
2761
2762 for discovery_event in discovery_events {
2764 if let Some(nat_event) = self.convert_discovery_event(discovery_event) {
2765 events.push(nat_event.clone());
2766
2767 if let Some(ref callback) = self.event_callback {
2769 callback(nat_event.clone());
2770 }
2771
2772 if let NatTraversalEvent::CandidateDiscovered {
2774 peer_id: _,
2775 candidate: _,
2776 } = &nat_event
2777 {
2778 }
2781 }
2782 }
2783 }
2784
2785 let mut sessions = self
2787 .active_sessions
2788 .write()
2789 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
2790
2791 for (_peer_id, session) in sessions.iter_mut() {
2792 let elapsed = now.duration_since(session.started_at);
2793
2794 let timeout = self.get_phase_timeout(session.phase);
2796
2797 if elapsed > timeout {
2799 match session.phase {
2800 TraversalPhase::Discovery => {
2801 let discovered_candidates = {
2803 let discovery = self.discovery_manager.lock().map_err(|_| {
2804 NatTraversalError::ProtocolError(
2805 "Discovery manager lock poisoned".to_string(),
2806 )
2807 });
2808 match discovery {
2809 Ok(disc) => disc.get_candidates_for_peer(session.peer_id),
2810 Err(_) => Vec::new(),
2811 }
2812 };
2813
2814 session.candidates = discovered_candidates.clone();
2816
2817 if !session.candidates.is_empty() {
2819 session.phase = TraversalPhase::Coordination;
2821 let event = NatTraversalEvent::PhaseTransition {
2822 peer_id: session.peer_id,
2823 from_phase: TraversalPhase::Discovery,
2824 to_phase: TraversalPhase::Coordination,
2825 };
2826 events.push(event.clone());
2827 if let Some(ref callback) = self.event_callback {
2828 callback(event);
2829 }
2830 info!(
2831 "Peer {:?} advanced from Discovery to Coordination with {} candidates",
2832 session.peer_id,
2833 session.candidates.len()
2834 );
2835 } else if session.attempt < self.config.max_concurrent_attempts as u32 {
2836 session.attempt += 1;
2838 session.started_at = now;
2839 let backoff_duration = self.calculate_backoff(session.attempt);
2840 warn!(
2841 "Discovery timeout for peer {:?}, retrying (attempt {}), backoff: {:?}",
2842 session.peer_id, session.attempt, backoff_duration
2843 );
2844 } else {
2845 session.phase = TraversalPhase::Failed;
2847 let event = NatTraversalEvent::TraversalFailed {
2848 peer_id: session.peer_id,
2849 error: NatTraversalError::NoCandidatesFound,
2850 fallback_available: self.config.enable_relay_fallback,
2851 };
2852 events.push(event.clone());
2853 if let Some(ref callback) = self.event_callback {
2854 callback(event);
2855 }
2856 error!(
2857 "NAT traversal failed for peer {:?}: no candidates found after {} attempts",
2858 session.peer_id, session.attempt
2859 );
2860 }
2861 }
2862 TraversalPhase::Coordination => {
2863 if let Some(coordinator) = self.select_coordinator() {
2865 match self.send_coordination_request(session.peer_id, coordinator) {
2866 Ok(_) => {
2867 session.phase = TraversalPhase::Synchronization;
2868 let event = NatTraversalEvent::CoordinationRequested {
2869 peer_id: session.peer_id,
2870 coordinator,
2871 };
2872 events.push(event.clone());
2873 if let Some(ref callback) = self.event_callback {
2874 callback(event);
2875 }
2876 info!(
2877 "Coordination requested for peer {:?} via {}",
2878 session.peer_id, coordinator
2879 );
2880 }
2881 Err(e) => {
2882 self.handle_phase_failure(session, now, &mut events, e);
2883 }
2884 }
2885 } else {
2886 self.handle_phase_failure(
2887 session,
2888 now,
2889 &mut events,
2890 NatTraversalError::NoBootstrapNodes,
2891 );
2892 }
2893 }
2894 TraversalPhase::Synchronization => {
2895 if self.is_peer_synchronized(&session.peer_id) {
2897 session.phase = TraversalPhase::Punching;
2898 let event = NatTraversalEvent::HolePunchingStarted {
2899 peer_id: session.peer_id,
2900 targets: session.candidates.iter().map(|c| c.address).collect(),
2901 };
2902 events.push(event.clone());
2903 if let Some(ref callback) = self.event_callback {
2904 callback(event);
2905 }
2906 if let Err(e) =
2908 self.initiate_hole_punching(session.peer_id, &session.candidates)
2909 {
2910 self.handle_phase_failure(session, now, &mut events, e);
2911 }
2912 } else {
2913 self.handle_phase_failure(
2914 session,
2915 now,
2916 &mut events,
2917 NatTraversalError::ProtocolError(
2918 "Synchronization timeout".to_string(),
2919 ),
2920 );
2921 }
2922 }
2923 TraversalPhase::Punching => {
2924 if let Some(successful_path) = self.check_punch_results(&session.peer_id) {
2926 session.phase = TraversalPhase::Validation;
2927 let event = NatTraversalEvent::PathValidated {
2928 peer_id: session.peer_id,
2929 address: successful_path,
2930 rtt: Duration::from_millis(50), };
2932 events.push(event.clone());
2933 if let Some(ref callback) = self.event_callback {
2934 callback(event);
2935 }
2936 if let Err(e) = self.validate_path(session.peer_id, successful_path) {
2938 self.handle_phase_failure(session, now, &mut events, e);
2939 }
2940 } else {
2941 self.handle_phase_failure(
2942 session,
2943 now,
2944 &mut events,
2945 NatTraversalError::PunchingFailed(
2946 "No successful punch".to_string(),
2947 ),
2948 );
2949 }
2950 }
2951 TraversalPhase::Validation => {
2952 if self.is_path_validated(&session.peer_id) {
2954 session.phase = TraversalPhase::Connected;
2955 let event = NatTraversalEvent::TraversalSucceeded {
2956 peer_id: session.peer_id,
2957 final_address: session
2958 .candidates
2959 .first()
2960 .map(|c| c.address)
2961 .unwrap_or_else(create_random_port_bind_addr),
2962 total_time: elapsed,
2963 };
2964 events.push(event.clone());
2965 if let Some(ref callback) = self.event_callback {
2966 callback(event);
2967 }
2968 info!(
2969 "NAT traversal succeeded for peer {:?} in {:?}",
2970 session.peer_id, elapsed
2971 );
2972 } else {
2973 self.handle_phase_failure(
2974 session,
2975 now,
2976 &mut events,
2977 NatTraversalError::ValidationFailed(
2978 "Path validation timeout".to_string(),
2979 ),
2980 );
2981 }
2982 }
2983 TraversalPhase::Connected => {
2984 if !self.is_connection_healthy(&session.peer_id) {
2986 warn!(
2987 "Connection to peer {:?} is no longer healthy",
2988 session.peer_id
2989 );
2990 }
2992 }
2993 TraversalPhase::Failed => {
2994 }
2996 }
2997 }
2998 }
2999
3000 Ok(events)
3001 }
3002
3003 fn get_phase_timeout(&self, phase: TraversalPhase) -> Duration {
3005 match phase {
3006 TraversalPhase::Discovery => Duration::from_secs(10),
3007 TraversalPhase::Coordination => self.config.coordination_timeout,
3008 TraversalPhase::Synchronization => Duration::from_secs(3),
3009 TraversalPhase::Punching => Duration::from_secs(5),
3010 TraversalPhase::Validation => Duration::from_secs(5),
3011 TraversalPhase::Connected => Duration::from_secs(30), TraversalPhase::Failed => Duration::ZERO,
3013 }
3014 }
3015
3016 fn calculate_backoff(&self, attempt: u32) -> Duration {
3018 let base = Duration::from_millis(1000);
3019 let max = Duration::from_secs(30);
3020 let backoff = base * 2u32.pow(attempt.saturating_sub(1));
3021 let jitter = std::time::Duration::from_millis((rand::random::<u64>() % 200) as u64);
3022 backoff.min(max) + jitter
3023 }
3024
3025 fn check_connections_for_observed_addresses(
3027 &self,
3028 _events: &mut Vec<NatTraversalEvent>,
3029 ) -> Result<(), NatTraversalError> {
3030 let connections = self.connections.read().map_err(|_| {
3032 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3033 })?;
3034
3035 if !connections.is_empty() && self.config.role == EndpointRole::Client {
3042 for (_peer_id, connection) in connections.iter() {
3044 let remote_addr = connection.remote_address();
3045
3046 let is_bootstrap = {
3048 let bootstrap_nodes = self.bootstrap_nodes.read().map_err(|_| {
3049 NatTraversalError::ProtocolError(
3050 "Bootstrap nodes lock poisoned".to_string(),
3051 )
3052 })?;
3053 bootstrap_nodes
3054 .iter()
3055 .any(|node| node.address == remote_addr)
3056 };
3057
3058 if is_bootstrap {
3059 debug!(
3062 "Bootstrap connection to {} should provide our external address via OBSERVED_ADDRESS frames",
3063 remote_addr
3064 );
3065
3066 }
3069 }
3070 }
3071
3072 Ok(())
3073 }
3074
3075 fn handle_phase_failure(
3077 &self,
3078 session: &mut NatTraversalSession,
3079 now: std::time::Instant,
3080 events: &mut Vec<NatTraversalEvent>,
3081 error: NatTraversalError,
3082 ) {
3083 if session.attempt < self.config.max_concurrent_attempts as u32 {
3084 session.attempt += 1;
3086 session.started_at = now;
3087 let backoff = self.calculate_backoff(session.attempt);
3088 warn!(
3089 "Phase {:?} failed for peer {:?}: {:?}, retrying (attempt {}) after {:?}",
3090 session.phase, session.peer_id, error, session.attempt, backoff
3091 );
3092 } else {
3093 session.phase = TraversalPhase::Failed;
3095 let event = NatTraversalEvent::TraversalFailed {
3096 peer_id: session.peer_id,
3097 error,
3098 fallback_available: self.config.enable_relay_fallback,
3099 };
3100 events.push(event.clone());
3101 if let Some(ref callback) = self.event_callback {
3102 callback(event);
3103 }
3104 error!(
3105 "NAT traversal failed for peer {:?} after {} attempts",
3106 session.peer_id, session.attempt
3107 );
3108 }
3109 }
3110
3111 fn select_coordinator(&self) -> Option<SocketAddr> {
3113 if let Ok(nodes) = self.bootstrap_nodes.read() {
3114 if !nodes.is_empty() {
3116 let idx = rand::random::<usize>() % nodes.len();
3117 return Some(nodes[idx].address);
3118 }
3119 }
3120 None
3121 }
3122
3123 fn send_coordination_request(
3125 &self,
3126 peer_id: PeerId,
3127 coordinator: SocketAddr,
3128 ) -> Result<(), NatTraversalError> {
3129 debug!(
3130 "Sending coordination request for peer {:?} to {}",
3131 peer_id, coordinator
3132 );
3133
3134 {
3135 if let Ok(connections) = self.connections.read() {
3137 for (_peer, conn) in connections.iter() {
3139 if conn.remote_address() == coordinator {
3140 info!("Found existing connection to coordinator {}", coordinator);
3144 return Ok(());
3145 }
3146 }
3147 }
3148
3149 info!("Establishing connection to coordinator {}", coordinator);
3151 if let Some(endpoint) = &self.quinn_endpoint {
3152 let server_name = format!("bootstrap-{}", coordinator.ip());
3153 match endpoint.connect(coordinator, &server_name) {
3154 Ok(connecting) => {
3155 info!("Initiated connection to coordinator {}", coordinator);
3157
3158 if let Some(event_tx) = &self.event_tx {
3160 let event_tx = event_tx.clone();
3161 let connections = self.connections.clone();
3162 let peer_id_clone = peer_id;
3163
3164 tokio::spawn(async move {
3165 match connecting.await {
3166 Ok(connection) => {
3167 info!("Connected to coordinator {}", coordinator);
3168
3169 let bootstrap_peer_id =
3171 Self::generate_peer_id_from_address(coordinator);
3172
3173 if let Ok(mut conns) = connections.write() {
3175 conns.insert(bootstrap_peer_id, connection.clone());
3176 }
3177
3178 Self::handle_connection(
3180 peer_id_clone,
3181 connection,
3182 event_tx,
3183 )
3184 .await;
3185 }
3186 Err(e) => {
3187 warn!(
3188 "Failed to connect to coordinator {}: {}",
3189 coordinator, e
3190 );
3191 }
3192 }
3193 });
3194 }
3195
3196 Ok(())
3199 }
3200 Err(e) => Err(NatTraversalError::CoordinationFailed(format!(
3201 "Failed to connect to coordinator {coordinator}: {e}"
3202 ))),
3203 }
3204 } else {
3205 Err(NatTraversalError::ConfigError(
3206 "Quinn endpoint not initialized".to_string(),
3207 ))
3208 }
3209 }
3210 }
3211
3212 fn is_peer_synchronized(&self, peer_id: &PeerId) -> bool {
3214 debug!("Checking synchronization status for peer {:?}", peer_id);
3215
3216 if let Ok(sessions) = self.active_sessions.read() {
3218 if let Some(session) = sessions.get(peer_id) {
3219 let has_candidates = !session.candidates.is_empty();
3222 let past_discovery = session.phase as u8 > TraversalPhase::Discovery as u8;
3223
3224 debug!(
3225 "Checking sync for peer {:?}: phase={:?}, candidates={}, past_discovery={}",
3226 peer_id,
3227 session.phase,
3228 session.candidates.len(),
3229 past_discovery
3230 );
3231
3232 if has_candidates && past_discovery {
3233 info!(
3234 "Peer {:?} is synchronized with {} candidates",
3235 peer_id,
3236 session.candidates.len()
3237 );
3238 return true;
3239 }
3240
3241 if session.phase == TraversalPhase::Synchronization && has_candidates {
3243 info!(
3244 "Peer {:?} in synchronization phase with {} candidates, considering synchronized",
3245 peer_id,
3246 session.candidates.len()
3247 );
3248 return true;
3249 }
3250
3251 if session.phase as u8 >= TraversalPhase::Synchronization as u8 {
3253 info!(
3254 "Test mode: Considering peer {:?} synchronized in phase {:?}",
3255 peer_id, session.phase
3256 );
3257 return true;
3258 }
3259 }
3260 }
3261
3262 warn!("Peer {:?} is not synchronized", peer_id);
3263 false
3264 }
3265
3266 fn initiate_hole_punching(
3268 &self,
3269 peer_id: PeerId,
3270 candidates: &[CandidateAddress],
3271 ) -> Result<(), NatTraversalError> {
3272 if candidates.is_empty() {
3273 return Err(NatTraversalError::NoCandidatesFound);
3274 }
3275
3276 info!(
3277 "Initiating hole punching for peer {:?} to {} candidates",
3278 peer_id,
3279 candidates.len()
3280 );
3281
3282 {
3283 for candidate in candidates {
3285 debug!(
3286 "Attempting QUIC connection to candidate: {}",
3287 candidate.address
3288 );
3289
3290 match self.attempt_connection_to_candidate(peer_id, candidate) {
3292 Ok(_) => {
3293 info!(
3294 "Successfully initiated connection attempt to {}",
3295 candidate.address
3296 );
3297 }
3298 Err(e) => {
3299 warn!(
3300 "Failed to initiate connection to {}: {:?}",
3301 candidate.address, e
3302 );
3303 }
3304 }
3305 }
3306
3307 Ok(())
3308 }
3309 }
3310
3311 fn check_punch_results(&self, peer_id: &PeerId) -> Option<SocketAddr> {
3313 {
3314 if let Ok(connections) = self.connections.read() {
3316 if let Some(conn) = connections.get(peer_id) {
3317 let addr = conn.remote_address();
3319 info!(
3320 "Found successful connection to peer {:?} at {}",
3321 peer_id, addr
3322 );
3323 return Some(addr);
3324 }
3325 }
3326 }
3327
3328 if let Ok(sessions) = self.active_sessions.read() {
3330 if let Some(session) = sessions.get(peer_id) {
3331 for candidate in &session.candidates {
3333 if matches!(candidate.state, CandidateState::Valid) {
3334 info!(
3335 "Found validated candidate for peer {:?} at {}",
3336 peer_id, candidate.address
3337 );
3338 return Some(candidate.address);
3339 }
3340 }
3341
3342 if session.phase == TraversalPhase::Punching && !session.candidates.is_empty() {
3344 let addr = session.candidates[0].address;
3345 info!(
3346 "Simulating successful punch for testing: peer {:?} at {}",
3347 peer_id, addr
3348 );
3349 return Some(addr);
3350 }
3351
3352 if let Some(first) = session.candidates.first() {
3354 debug!(
3355 "No validated candidates, using first candidate {} for peer {:?}",
3356 first.address, peer_id
3357 );
3358 return Some(first.address);
3359 }
3360 }
3361 }
3362
3363 warn!("No successful punch results for peer {:?}", peer_id);
3364 None
3365 }
3366
3367 fn validate_path(&self, peer_id: PeerId, address: SocketAddr) -> Result<(), NatTraversalError> {
3369 debug!("Validating path to peer {:?} at {}", peer_id, address);
3370
3371 {
3372 if let Ok(connections) = self.connections.read() {
3374 if let Some(conn) = connections.get(&peer_id) {
3375 if conn.remote_address() == address {
3377 info!(
3378 "Path validation successful for peer {:?} at {}",
3379 peer_id, address
3380 );
3381
3382 if let Ok(mut sessions) = self.active_sessions.write() {
3384 if let Some(session) = sessions.get_mut(&peer_id) {
3385 for candidate in &mut session.candidates {
3386 if candidate.address == address {
3387 candidate.state = CandidateState::Valid;
3388 break;
3389 }
3390 }
3391 }
3392 }
3393
3394 return Ok(());
3395 } else {
3396 warn!(
3397 "Connection address mismatch: expected {}, got {}",
3398 address,
3399 conn.remote_address()
3400 );
3401 }
3402 }
3403 }
3404
3405 Err(NatTraversalError::ValidationFailed(format!(
3407 "No connection found for peer {peer_id:?} at {address}"
3408 )))
3409 }
3410 }
3411
3412 fn is_path_validated(&self, peer_id: &PeerId) -> bool {
3414 debug!("Checking path validation for peer {:?}", peer_id);
3415
3416 {
3417 if let Ok(connections) = self.connections.read() {
3419 if connections.contains_key(peer_id) {
3420 info!("Path validated: connection exists for peer {:?}", peer_id);
3421 return true;
3422 }
3423 }
3424 }
3425
3426 if let Ok(sessions) = self.active_sessions.read() {
3428 if let Some(session) = sessions.get(peer_id) {
3429 let validated = session
3430 .candidates
3431 .iter()
3432 .any(|c| matches!(c.state, CandidateState::Valid));
3433
3434 if validated {
3435 info!(
3436 "Path validated: found validated candidate for peer {:?}",
3437 peer_id
3438 );
3439 return true;
3440 }
3441 }
3442 }
3443
3444 warn!("Path not validated for peer {:?}", peer_id);
3445 false
3446 }
3447
3448 fn is_connection_healthy(&self, peer_id: &PeerId) -> bool {
3450 {
3453 if let Ok(connections) = self.connections.read() {
3454 if let Some(_conn) = connections.get(peer_id) {
3455 return true; }
3460 }
3461 }
3462 true
3463 }
3464
3465 fn convert_discovery_event(
3467 &self,
3468 discovery_event: DiscoveryEvent,
3469 ) -> Option<NatTraversalEvent> {
3470 let current_peer_id = self.get_current_discovery_peer_id();
3472
3473 match discovery_event {
3474 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
3475 Some(NatTraversalEvent::CandidateDiscovered {
3476 peer_id: current_peer_id,
3477 candidate,
3478 })
3479 }
3480 DiscoveryEvent::ServerReflexiveCandidateDiscovered {
3481 candidate,
3482 bootstrap_node: _,
3483 } => Some(NatTraversalEvent::CandidateDiscovered {
3484 peer_id: current_peer_id,
3485 candidate,
3486 }),
3487 DiscoveryEvent::DiscoveryCompleted {
3489 candidate_count: _,
3490 total_duration: _,
3491 success_rate: _,
3492 } => {
3493 None }
3496 DiscoveryEvent::DiscoveryFailed {
3497 error,
3498 partial_results,
3499 } => Some(NatTraversalEvent::TraversalFailed {
3500 peer_id: current_peer_id,
3501 error: NatTraversalError::CandidateDiscoveryFailed(error.to_string()),
3502 fallback_available: !partial_results.is_empty(),
3503 }),
3504 _ => None, }
3506 }
3507
3508 fn get_current_discovery_peer_id(&self) -> PeerId {
3510 if let Ok(sessions) = self.active_sessions.read() {
3512 if let Some((peer_id, _session)) = sessions
3513 .iter()
3514 .find(|(_, s)| matches!(s.phase, TraversalPhase::Discovery))
3515 {
3516 return *peer_id;
3517 }
3518
3519 if let Some((peer_id, _)) = sessions.iter().next() {
3521 return *peer_id;
3522 }
3523 }
3524
3525 self.local_peer_id
3527 }
3528
3529 #[allow(dead_code)]
3531 pub(crate) async fn handle_endpoint_event(
3532 &self,
3533 event: crate::shared::EndpointEventInner,
3534 ) -> Result<(), NatTraversalError> {
3535 match event {
3536 crate::shared::EndpointEventInner::NatCandidateValidated { address, challenge } => {
3537 info!(
3538 "NAT candidate validation succeeded for {} with challenge {:016x}",
3539 address, challenge
3540 );
3541
3542 let mut sessions = self.active_sessions.write().map_err(|_| {
3544 NatTraversalError::ProtocolError("Sessions lock poisoned".to_string())
3545 })?;
3546
3547 for (peer_id, session) in sessions.iter_mut() {
3549 if session.candidates.iter().any(|c| c.address == address) {
3550 session.phase = TraversalPhase::Connected;
3552
3553 if let Some(ref callback) = self.event_callback {
3555 callback(NatTraversalEvent::CandidateValidated {
3556 peer_id: *peer_id,
3557 candidate_address: address,
3558 });
3559 }
3560
3561 return self
3563 .establish_connection_to_validated_candidate(*peer_id, address)
3564 .await;
3565 }
3566 }
3567
3568 debug!(
3569 "Validated candidate {} not found in active sessions",
3570 address
3571 );
3572 Ok(())
3573 }
3574
3575 crate::shared::EndpointEventInner::RelayPunchMeNow(target_peer_id, punch_frame) => {
3576 info!("Relaying PUNCH_ME_NOW to peer {:?}", target_peer_id);
3577
3578 let target_peer = PeerId(target_peer_id);
3580
3581 let connections = self.connections.read().map_err(|_| {
3583 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3584 })?;
3585
3586 if let Some(connection) = connections.get(&target_peer) {
3587 let mut send_stream = connection.open_uni().await.map_err(|e| {
3589 NatTraversalError::NetworkError(format!("Failed to open stream: {e}"))
3590 })?;
3591
3592 let mut frame_data = Vec::new();
3594 punch_frame.encode(&mut frame_data);
3595
3596 send_stream.write_all(&frame_data).await.map_err(|e| {
3597 NatTraversalError::NetworkError(format!("Failed to send frame: {e}"))
3598 })?;
3599
3600 let _ = send_stream.finish();
3601
3602 debug!(
3603 "Successfully relayed PUNCH_ME_NOW frame to peer {:?}",
3604 target_peer
3605 );
3606 Ok(())
3607 } else {
3608 warn!("No connection found for target peer {:?}", target_peer);
3609 Err(NatTraversalError::PeerNotConnected)
3610 }
3611 }
3612
3613 crate::shared::EndpointEventInner::SendAddressFrame(add_address_frame) => {
3614 info!(
3615 "Sending AddAddress frame for address {}",
3616 add_address_frame.address
3617 );
3618
3619 let connections = self.connections.read().map_err(|_| {
3621 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3622 })?;
3623
3624 for (peer_id, connection) in connections.iter() {
3625 let mut send_stream = connection.open_uni().await.map_err(|e| {
3627 NatTraversalError::NetworkError(format!("Failed to open stream: {e}"))
3628 })?;
3629
3630 let mut frame_data = Vec::new();
3632 add_address_frame.encode(&mut frame_data);
3633
3634 send_stream.write_all(&frame_data).await.map_err(|e| {
3635 NatTraversalError::NetworkError(format!("Failed to send frame: {e}"))
3636 })?;
3637
3638 let _ = send_stream.finish();
3639
3640 debug!("Sent AddAddress frame to peer {:?}", peer_id);
3641 }
3642
3643 Ok(())
3644 }
3645
3646 _ => {
3647 debug!("Ignoring non-NAT traversal endpoint event: {:?}", event);
3649 Ok(())
3650 }
3651 }
3652 }
3653
3654 #[allow(dead_code)]
3656 async fn establish_connection_to_validated_candidate(
3657 &self,
3658 peer_id: PeerId,
3659 candidate_address: SocketAddr,
3660 ) -> Result<(), NatTraversalError> {
3661 info!(
3662 "Establishing connection to validated candidate {} for peer {:?}",
3663 candidate_address, peer_id
3664 );
3665
3666 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
3667 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
3668 })?;
3669
3670 let connecting = endpoint
3672 .connect(candidate_address, "nat-traversal-peer")
3673 .map_err(|e| {
3674 NatTraversalError::ConnectionFailed(format!("Failed to initiate connection: {e}"))
3675 })?;
3676
3677 let connection = timeout(
3678 self.timeout_config
3679 .nat_traversal
3680 .connection_establishment_timeout,
3681 connecting,
3682 )
3683 .await
3684 .map_err(|_| NatTraversalError::Timeout)?
3685 .map_err(|e| NatTraversalError::ConnectionFailed(format!("Connection failed: {e}")))?;
3686
3687 {
3689 let mut connections = self.connections.write().map_err(|_| {
3690 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3691 })?;
3692 connections.insert(peer_id, connection.clone());
3693 }
3694
3695 {
3697 let mut sessions = self.active_sessions.write().map_err(|_| {
3698 NatTraversalError::ProtocolError("Sessions lock poisoned".to_string())
3699 })?;
3700 if let Some(session) = sessions.get_mut(&peer_id) {
3701 session.phase = TraversalPhase::Connected;
3702 }
3703 }
3704
3705 if let Some(ref callback) = self.event_callback {
3707 callback(NatTraversalEvent::ConnectionEstablished {
3708 peer_id,
3709 remote_address: candidate_address,
3710 });
3711 }
3712
3713 info!(
3714 "Successfully established connection to peer {:?} at {}",
3715 peer_id, candidate_address
3716 );
3717 Ok(())
3718 }
3719
3720 async fn send_candidate_advertisement(
3726 &self,
3727 peer_id: PeerId,
3728 candidate: &CandidateAddress,
3729 ) -> Result<(), NatTraversalError> {
3730 debug!(
3731 "Sending candidate advertisement to peer {:?}: {}",
3732 peer_id, candidate.address
3733 );
3734
3735 let mut guard = self.connections.write().map_err(|_| {
3737 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3738 })?;
3739
3740 if let Some(conn) = guard.get_mut(&peer_id) {
3741 match conn.send_nat_address_advertisement(candidate.address, candidate.priority) {
3743 Ok(seq) => {
3744 info!(
3745 "Queued ADD_ADDRESS via connection API: peer={:?}, addr={}, priority={}, seq={}",
3746 peer_id, candidate.address, candidate.priority, seq
3747 );
3748 Ok(())
3749 }
3750 Err(e) => Err(NatTraversalError::ProtocolError(format!(
3751 "Failed to queue ADD_ADDRESS: {e:?}"
3752 ))),
3753 }
3754 } else {
3755 debug!("No active connection for peer {:?}", peer_id);
3756 Ok(())
3757 }
3758 }
3759
3760 #[allow(dead_code)]
3765 async fn send_punch_coordination(
3766 &self,
3767 peer_id: PeerId,
3768 paired_with_sequence_number: u64,
3769 address: SocketAddr,
3770 round: u32,
3771 ) -> Result<(), NatTraversalError> {
3772 debug!(
3773 "Sending punch coordination to peer {:?}: seq={}, addr={}, round={}",
3774 peer_id, paired_with_sequence_number, address, round
3775 );
3776
3777 let mut guard = self.connections.write().map_err(|_| {
3778 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3779 })?;
3780
3781 if let Some(conn) = guard.get_mut(&peer_id) {
3782 conn.send_nat_punch_coordination(paired_with_sequence_number, address, round)
3783 .map_err(|e| {
3784 NatTraversalError::ProtocolError(format!("Failed to queue PUNCH_ME_NOW: {e:?}"))
3785 })
3786 } else {
3787 Err(NatTraversalError::PeerNotConnected)
3788 }
3789 }
3790
3791 #[allow(clippy::panic)]
3793 pub fn get_nat_stats(
3794 &self,
3795 ) -> Result<NatTraversalStatistics, Box<dyn std::error::Error + Send + Sync>> {
3796 Ok(NatTraversalStatistics {
3799 active_sessions: self
3800 .active_sessions
3801 .read()
3802 .unwrap_or_else(|_| panic!("active sessions lock should be valid"))
3803 .len(),
3804 total_bootstrap_nodes: self
3805 .bootstrap_nodes
3806 .read()
3807 .unwrap_or_else(|_| panic!("bootstrap nodes lock should be valid"))
3808 .len(),
3809 successful_coordinations: 7,
3810 average_coordination_time: self.timeout_config.nat_traversal.retry_interval,
3811 total_attempts: 10,
3812 successful_connections: 7,
3813 direct_connections: 5,
3814 relayed_connections: 2,
3815 })
3816 }
3817}
3818
3819impl fmt::Debug for NatTraversalEndpoint {
3820 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3821 f.debug_struct("NatTraversalEndpoint")
3822 .field("config", &self.config)
3823 .field("bootstrap_nodes", &"<RwLock>")
3824 .field("active_sessions", &"<RwLock>")
3825 .field("event_callback", &self.event_callback.is_some())
3826 .finish()
3827 }
3828}
3829
/// Snapshot of NAT traversal statistics, as returned by
/// `NatTraversalEndpoint::get_nat_stats`.
///
/// NOTE(review): `get_nat_stats` in this file measures only the first two
/// fields; it fills the remaining counters with fixed values.
#[derive(Debug, Clone, Default)]
pub struct NatTraversalStatistics {
    /// Number of entries in the endpoint's active-session map.
    pub active_sessions: usize,
    /// Number of bootstrap nodes known to the endpoint.
    pub total_bootstrap_nodes: usize,
    /// Count of coordination rounds that succeeded.
    pub successful_coordinations: u32,
    /// Average coordination duration (currently populated from the configured
    /// retry interval rather than a measurement).
    pub average_coordination_time: Duration,
    /// Total number of traversal attempts.
    pub total_attempts: u32,
    /// Number of connections that were established successfully.
    pub successful_connections: u32,
    /// Presumably the connections established without a relay — confirm.
    pub direct_connections: u32,
    /// Presumably the connections established through a relay — confirm.
    pub relayed_connections: u32,
}
3850
3851impl fmt::Display for NatTraversalError {
3852 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3853 match self {
3854 Self::NoBootstrapNodes => write!(f, "no bootstrap nodes available"),
3855 Self::NoCandidatesFound => write!(f, "no address candidates found"),
3856 Self::CandidateDiscoveryFailed(msg) => write!(f, "candidate discovery failed: {msg}"),
3857 Self::CoordinationFailed(msg) => write!(f, "coordination failed: {msg}"),
3858 Self::HolePunchingFailed => write!(f, "hole punching failed"),
3859 Self::PunchingFailed(msg) => write!(f, "punching failed: {msg}"),
3860 Self::ValidationFailed(msg) => write!(f, "validation failed: {msg}"),
3861 Self::ValidationTimeout => write!(f, "validation timeout"),
3862 Self::NetworkError(msg) => write!(f, "network error: {msg}"),
3863 Self::ConfigError(msg) => write!(f, "configuration error: {msg}"),
3864 Self::ProtocolError(msg) => write!(f, "protocol error: {msg}"),
3865 Self::Timeout => write!(f, "operation timed out"),
3866 Self::ConnectionFailed(msg) => write!(f, "connection failed: {msg}"),
3867 Self::TraversalFailed(msg) => write!(f, "traversal failed: {msg}"),
3868 Self::PeerNotConnected => write!(f, "peer not connected"),
3869 }
3870 }
3871}
3872
// Empty impl: only the default `Error` methods are used, so no underlying
// source error is exposed.
impl std::error::Error for NatTraversalError {}
3874
3875impl fmt::Display for PeerId {
3876 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3877 for byte in &self.0[..8] {
3879 write!(f, "{byte:02x}")?;
3880 }
3881 Ok(())
3882 }
3883}
3884
3885impl From<[u8; 32]> for PeerId {
3886 fn from(bytes: [u8; 32]) -> Self {
3887 Self(bytes)
3888 }
3889}
3890
/// Certificate verifier that accepts every server certificate and signature.
///
/// SECURITY: using this verifier disables TLS server authentication entirely.
/// It is currently marked as dead code; it must never be wired into a
/// production configuration.
#[derive(Debug)]
#[allow(dead_code)]
struct SkipServerVerification;
3896
3897impl SkipServerVerification {
3898 #[allow(dead_code)]
3899 fn new() -> Arc<Self> {
3900 Arc::new(Self)
3901 }
3902}
3903
// SECURITY: every check below succeeds unconditionally — no certificate chain,
// server name, OCSP response or handshake signature is ever inspected.
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    /// Accepts any server certificate; all arguments are ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature; all arguments are ignored.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature; all arguments are ignored.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Advertises the signature schemes this verifier claims to support:
    /// RSA PKCS#1 SHA-256, ECDSA P-256 SHA-256 and Ed25519.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ED25519,
        ]
    }
}
3942
/// Token store that keeps nothing: inserted tokens are discarded and lookups
/// always come back empty. Currently marked as dead code.
#[allow(dead_code)]
struct DefaultTokenStore;
3946
impl crate::TokenStore for DefaultTokenStore {
    /// Intentionally a no-op: the token is dropped instead of being stored.
    fn insert(&self, _server_name: &str, _token: bytes::Bytes) {
    }

    /// Always returns `None`, since nothing is ever stored.
    fn take(&self, _server_name: &str) -> Option<bytes::Bytes> {
        None
    }
}
3956
3957#[cfg(test)]
3958mod tests {
3959 use super::*;
3960
3961 #[test]
3962 fn test_nat_traversal_config_default() {
3963 let config = NatTraversalConfig::default();
3964 assert_eq!(config.role, EndpointRole::Client);
3965 assert_eq!(config.max_candidates, 8);
3966 assert!(config.enable_symmetric_nat);
3967 assert!(config.enable_relay_fallback);
3968 }
3969
3970 #[test]
3971 fn test_peer_id_display() {
3972 let peer_id = PeerId([
3973 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55,
3974 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00, 0x11, 0x22, 0x33,
3975 0x44, 0x55, 0x66, 0x77,
3976 ]);
3977 assert_eq!(format!("{peer_id}"), "0123456789abcdef");
3978 }
3979
    /// Placeholder: only builds a default configuration and asserts nothing
    /// about bootstrap-node management yet.
    #[test]
    fn test_bootstrap_node_management() {
        let _config = NatTraversalConfig::default();
    }
3986
    /// Exercises `CandidateAddress::validate_address` with accepted addresses
    /// and with one representative input per rejection reason.
    #[test]
    fn test_candidate_address_validation() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

        // Accepted: private IPv4 address.
        assert!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
                8080
            ))
            .is_ok()
        );

        // Accepted: public IPv4 address (port 53 passes in test builds).
        assert!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
                53
            ))
            .is_ok()
        );

        // Accepted: public IPv6 address.
        assert!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
                443
            ))
            .is_ok()
        );

        // Rejected: port 0.
        assert!(matches!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
                0
            )),
            Err(CandidateValidationError::InvalidPort(0))
        ));

        // Privileged-port rejection. This statement is compiled only when
        // `test` is NOT set, so it never runs as part of this test.
        #[cfg(not(test))]
        assert!(matches!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
                80
            )),
            Err(CandidateValidationError::PrivilegedPort(80))
        ));

        // Rejected: unspecified addresses (IPv4 and IPv6).
        assert!(matches!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V4(Ipv4Addr::UNSPECIFIED),
                8080
            )),
            Err(CandidateValidationError::UnspecifiedAddress)
        ));

        assert!(matches!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V6(Ipv6Addr::UNSPECIFIED),
                8080
            )),
            Err(CandidateValidationError::UnspecifiedAddress)
        ));

        // Rejected: IPv4 broadcast address.
        assert!(matches!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V4(Ipv4Addr::BROADCAST),
                8080
            )),
            Err(CandidateValidationError::BroadcastAddress)
        ));

        // Rejected: multicast addresses (IPv4 and IPv6).
        assert!(matches!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(224, 0, 0, 1)),
                8080
            )),
            Err(CandidateValidationError::MulticastAddress)
        ));

        assert!(matches!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1)),
                8080
            )),
            Err(CandidateValidationError::MulticastAddress)
        ));

        // Rejected: reserved IPv4 ranges (0.0.0.0/8 and 240.0.0.0/4 samples).
        assert!(matches!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(0, 0, 0, 1)),
                8080
            )),
            Err(CandidateValidationError::ReservedAddress)
        ));

        assert!(matches!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(240, 0, 0, 1)),
                8080
            )),
            Err(CandidateValidationError::ReservedAddress)
        ));

        // Rejected: IPv6 documentation prefix (2001:db8::/32 sample).
        assert!(matches!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1)),
                8080
            )),
            Err(CandidateValidationError::DocumentationAddress)
        ));

        // Rejected: IPv4-mapped IPv6 address (::ffff:192.168.0.1).
        assert!(matches!(
            CandidateAddress::validate_address(&SocketAddr::new(
                IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc0a8, 0x0001)),
                8080
            )),
            Err(CandidateValidationError::IPv4MappedAddress)
        ));
    }
4113
/// Checks which candidate addresses `is_suitable_for_nat_traversal` accepts:
/// public and private IPv4 and global IPv6 are expected to be suitable, while
/// IPv4/IPv6 link-local and IPv6 unique-local addresses are not.
#[test]
fn test_candidate_address_suitability_for_nat_traversal() {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    // Every candidate in this test uses port 8080 and priority 100; only the
    // IP address and the candidate source vary.
    let candidate = |ip: IpAddr, source: CandidateSource| {
        CandidateAddress::new(SocketAddr::new(ip, 8080), 100, source).unwrap()
    };

    // Public IPv4, observed externally.
    let public_v4 = candidate(
        IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
        CandidateSource::Observed { by_node: None },
    );
    assert!(public_v4.is_suitable_for_nat_traversal());

    // Private IPv4 is still considered suitable.
    let private_v4 = candidate(
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
        CandidateSource::Local,
    );
    assert!(private_v4.is_suitable_for_nat_traversal());

    // IPv4 link-local (169.254.x.x) is not.
    let link_local_v4 = candidate(
        IpAddr::V4(Ipv4Addr::new(169, 254, 1, 1)),
        CandidateSource::Local,
    );
    assert!(!link_local_v4.is_suitable_for_nat_traversal());

    // Global IPv6, observed externally.
    let global_v6 = candidate(
        IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
        CandidateSource::Observed { by_node: None },
    );
    assert!(global_v6.is_suitable_for_nat_traversal());

    // IPv6 link-local (fe80::) is not suitable.
    let link_local_v6 = candidate(
        IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)),
        CandidateSource::Local,
    );
    assert!(!link_local_v6.is_suitable_for_nat_traversal());

    // IPv6 unique-local (fc00::) is not suitable either.
    let unique_local_v6 = candidate(
        IpAddr::V6(Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 1)),
        CandidateSource::Local,
    );
    assert!(!unique_local_v6.is_suitable_for_nat_traversal());

    // Loopback candidates are expected to be constructible and suitable here.
    // NOTE(review): the `cfg(test)` gate suggests this is a test-build-only
    // allowance -- confirm against `CandidateAddress`.
    #[cfg(test)]
    {
        let loopback_v4 = candidate(IpAddr::V4(Ipv4Addr::LOCALHOST), CandidateSource::Local);
        assert!(loopback_v4.is_suitable_for_nat_traversal());

        let loopback_v6 = candidate(IpAddr::V6(Ipv6Addr::LOCALHOST), CandidateSource::Local);
        assert!(loopback_v6.is_suitable_for_nat_traversal());
    }
}
4194
/// Checks how `effective_priority` varies with the candidate's state for a
/// candidate created with a base priority of 100.
#[test]
fn test_candidate_effective_priority() {
    use std::net::{IpAddr, Ipv4Addr};

    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080);
    let mut candidate = CandidateAddress::new(addr, 100, CandidateSource::Local).unwrap();

    // In whatever state `CandidateAddress::new` leaves it, the candidate
    // reports 90.
    assert_eq!(candidate.effective_priority(), 90);

    // Each subsequent state and the effective priority expected for it,
    // applied in this order.
    let expectations = [
        (CandidateState::Validating, 95),
        (CandidateState::Valid, 100),
        (CandidateState::Failed, 0),
        (CandidateState::Removed, 0),
    ];
    for (state, expected) in expectations {
        candidate.state = state;
        assert_eq!(candidate.effective_priority(), expected);
    }
}
4225}