1#![allow(missing_docs)]
8
9use std::{collections::HashMap, fmt, net::SocketAddr, sync::Arc, time::Duration};
16
/// Returns the IPv4 wildcard bind address `0.0.0.0:0`.
///
/// Port `0` lets the operating system pick a free port when the socket is
/// bound. The address is assembled from its components rather than parsed
/// from a string, so the function has no failure or panic path.
fn create_random_port_bind_addr() -> SocketAddr {
    SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0)
}
41
42use tracing::{debug, error, info, warn};
43
44use std::sync::atomic::{AtomicBool, Ordering};
45
46use tokio::{
47 net::UdpSocket,
48 sync::mpsc,
49 time::{sleep, timeout},
50};
51
52use crate::high_level::default_runtime;
53
54use crate::{
55 VarInt,
56 candidate_discovery::{CandidateDiscoveryManager, DiscoveryConfig, DiscoveryEvent},
57 connection::nat_traversal::{CandidateSource, CandidateState, NatTraversalRole},
58};
59
60use crate::{
61 ClientConfig, ConnectionError, EndpointConfig, ServerConfig, TransportConfig,
62 high_level::{Connection as QuinnConnection, Endpoint as QuinnEndpoint},
63};
64
65#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
66use crate::{crypto::rustls::QuicClientConfig, crypto::rustls::QuicServerConfig};
67
68use crate::config::validation::{ConfigValidator, ValidationResult};
69
70#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
71use crate::crypto::certificate_manager::{CertificateConfig, CertificateManager};
72
/// Endpoint that wraps a QUIC endpoint together with the state needed for
/// NAT traversal: bootstrap nodes, per-peer sessions, candidate discovery
/// and event delivery.
pub struct NatTraversalEndpoint {
    /// Underlying QUIC endpoint; populated by the constructor and exposed via
    /// `get_quinn_endpoint`.
    quinn_endpoint: Option<QuinnEndpoint>,
    /// Configuration this endpoint was created with.
    config: NatTraversalConfig,
    /// Known bootstrap nodes, seeded from `config.bootstrap_nodes` and edited
    /// by `add_bootstrap_node` / `remove_bootstrap_node`.
    bootstrap_nodes: Arc<std::sync::RwLock<Vec<BootstrapNode>>>,
    /// NAT traversal sessions keyed by remote peer.
    active_sessions: Arc<std::sync::RwLock<HashMap<PeerId, NatTraversalSession>>>,
    /// Candidate discovery state machine, shared with the discovery polling task.
    discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
    /// Optional user callback invoked synchronously with traversal events.
    event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
    /// Flag read by the background tasks; the session polling task exits once
    /// it observes `true`.
    shutdown: Arc<AtomicBool>,
    /// Sender half of the event channel returned by `create_quinn_endpoint`.
    event_tx: Option<mpsc::UnboundedSender<NatTraversalEvent>>,
    /// Connections keyed by peer; the map is shared with the accept task.
    connections: Arc<std::sync::RwLock<HashMap<PeerId, QuinnConnection>>>,
    /// Identifier of the local peer, generated at construction time.
    local_peer_id: PeerId,
    /// Copy of `config.timeouts` used by the session polling logic.
    timeout_config: crate::config::nat_timeouts::TimeoutConfig,
}
100
/// Configuration for a [`NatTraversalEndpoint`].
///
/// The `Default` implementation yields a client configuration without
/// bootstrap nodes; `ConfigValidator::validate` enforces the per-field
/// constraints mentioned below.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct NatTraversalConfig {
    /// Role this endpoint plays. Default: [`EndpointRole::Client`].
    pub role: EndpointRole,
    /// Addresses of bootstrap nodes. Validation requires at least one for
    /// clients and for servers with `can_coordinate` set.
    pub bootstrap_nodes: Vec<SocketAddr>,
    /// Candidate limit handed to candidate discovery. Default: 8; validated
    /// against the range 1 to 256.
    pub max_candidates: usize,
    /// Used as the total discovery timeout. Default: 10 s; validated against
    /// the range 100 ms to 300 s.
    pub coordination_timeout: Duration,
    /// Forwarded to discovery as `enable_symmetric_prediction`. Default: `true`.
    pub enable_symmetric_nat: bool,
    /// Default: `true`. Validation rejects this flag in combination with the
    /// `Bootstrap` role.
    pub enable_relay_fallback: bool,
    /// Default: 3; validated against the range 1 to 16 and must not exceed
    /// `max_candidates`. Also used as the server-side concurrency limit in the
    /// NAT traversal transport parameter.
    pub max_concurrent_attempts: usize,
    /// Address to bind the UDP socket to; `None` binds `0.0.0.0:0`.
    pub bind_addr: Option<SocketAddr>,
    /// Default: `true`. Not read by the code visible in this section.
    pub prefer_rfc_nat_traversal: bool,
    /// Timeout settings; copied into the endpoint at construction.
    pub timeouts: crate::config::nat_timeouts::TimeoutConfig,
}
171
/// Role an endpoint plays in NAT traversal.
///
/// `Server` and `Bootstrap` endpoints are given a server configuration and
/// accept incoming connections; `Client` endpoints do not.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum EndpointRole {
    /// Client endpoint; validation requires at least one bootstrap node.
    Client,
    /// Server endpoint.
    Server {
        /// Mapped to `can_relay` on the connection-level role; when `true`,
        /// validation requires bootstrap nodes.
        can_coordinate: bool,
    },
    /// Bootstrap endpoint; validation forbids relay fallback for this role.
    Bootstrap,
}
185
186impl EndpointRole {
187 pub fn name(&self) -> &'static str {
189 match self {
190 Self::Client => "client",
191 Self::Server { .. } => "server",
192 Self::Bootstrap => "bootstrap",
193 }
194 }
195}
196
/// 32-byte peer identifier, used as the key of the session and connection maps.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
)]
pub struct PeerId(pub [u8; 32]);
202
/// Bookkeeping record for one bootstrap node.
#[derive(Debug, Clone)]
pub struct BootstrapNode {
    /// Socket address of the node; used as its identity when adding or removing nodes.
    pub address: SocketAddr,
    /// Initialised to the creation time of the record.
    pub last_seen: std::time::Instant,
    /// Initialised to `true` for every newly created record.
    pub can_coordinate: bool,
    /// Measured round-trip time, if any; feeds the average in `get_statistics`.
    pub rtt: Option<Duration>,
    /// Coordination counter; summed into `successful_coordinations` by `get_statistics`.
    pub coordination_count: u32,
}
217
218impl BootstrapNode {
219 pub fn new(address: SocketAddr) -> Self {
221 Self {
222 address,
223 last_seen: std::time::Instant::now(),
224 can_coordinate: true,
225 rtt: None,
226 coordination_count: 0,
227 }
228 }
229}
230
/// Pairing of one local and one remote candidate address.
#[derive(Debug, Clone)]
pub struct CandidatePair {
    /// Candidate on the local side.
    pub local_candidate: CandidateAddress,
    /// Candidate on the remote side.
    pub remote_candidate: CandidateAddress,
    /// Priority of the pair (how it is derived is not shown in this section).
    pub priority: u64,
    /// Current state of the pair.
    pub state: CandidatePairState,
}
243
/// State of a [`CandidatePair`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandidatePairState {
    /// Not yet attempted.
    Waiting,
    /// An attempt is under way.
    InProgress,
    /// The attempt succeeded.
    Succeeded,
    /// The attempt failed.
    Failed,
    /// The attempt was cancelled.
    Cancelled,
}
258
/// Per-peer NAT traversal session stored in the endpoint's `active_sessions` map.
#[derive(Debug)]
struct NatTraversalSession {
    /// Remote peer this session targets.
    peer_id: PeerId,
    /// Coordinator address supplied to `initiate_nat_traversal`.
    #[allow(dead_code)]
    coordinator: SocketAddr,
    /// Attempt counter; starts at 1 and is incremented on each polling retry.
    attempt: u32,
    /// Instant at which the session was created.
    started_at: std::time::Instant,
    /// Current traversal phase; new sessions start in `Discovery`.
    phase: TraversalPhase,
    /// Candidate addresses for the peer; empty for a new session.
    candidates: Vec<CandidateAddress>,
    /// Connection-level state tracked by the polling logic.
    session_state: SessionState,
}
278
/// Connection-level state of a session.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Current connection state.
    pub state: ConnectionState,
    /// Instant of the most recent state change; timeouts are measured from it.
    pub last_transition: std::time::Instant,
    /// Established connection, if any.
    pub connection: Option<QuinnConnection>,
    /// Address/instant pairs; presumably in-flight attempts and their start
    /// times (not read by the code in this section).
    pub active_attempts: Vec<(SocketAddr, std::time::Instant)>,
    /// Metrics for the connection.
    pub metrics: ConnectionMetrics,
}
293
/// Connection state of a session as driven by the polling routines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No attempt in progress; background polling moves an idle session back
    /// to `Connecting` after a configured interval.
    Idle,
    /// An attempt is in progress; new sessions start here.
    Connecting,
    /// A connection is established.
    Connected,
    /// The connection is migrating.
    Migrating,
    /// The session is closed; background polling removes it after a
    /// configured interval.
    Closed,
}
308
/// Metrics recorded for a session's connection.
#[derive(Debug, Clone, Default)]
pub struct ConnectionMetrics {
    /// Path RTT copied from the connection statistics by background polling.
    pub rtt: Option<Duration>,
    /// Lost packets divided by sent packets (denominator clamped to at least 1).
    pub loss_rate: f64,
    /// Bytes sent (not updated by the code in this section).
    pub bytes_sent: u64,
    /// Bytes received (not updated by the code in this section).
    pub bytes_received: u64,
    /// Set to the poll time by `poll_sessions` while the session is connected.
    pub last_activity: Option<std::time::Instant>,
}
323
/// One state transition reported by `poll_sessions`.
#[derive(Debug, Clone)]
pub struct SessionStateUpdate {
    /// Peer whose session changed.
    pub peer_id: PeerId,
    /// State before the transition.
    pub old_state: ConnectionState,
    /// State after the transition.
    pub new_state: ConnectionState,
    /// Why the transition happened.
    pub reason: StateChangeReason,
}
336
/// Reason attached to a [`SessionStateUpdate`].
///
/// `poll_sessions` produces `Timeout`, `ConnectionEstablished`,
/// `MigrationComplete` and `MigrationFailed`; the remaining variants are not
/// produced by the code in this section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StateChangeReason {
    /// A connecting session exceeded the establishment timeout.
    Timeout,
    /// A connecting session obtained a connection.
    ConnectionEstablished,
    /// The connection was closed.
    ConnectionClosed,
    /// A migrating session still had a connection when the migration window elapsed.
    MigrationComplete,
    /// A migrating session had no connection when the migration window elapsed.
    MigrationFailed,
    /// A network error occurred.
    NetworkError,
    /// The user closed the session.
    UserClosed,
}
355
/// Phase of a NAT traversal session. New sessions start in `Discovery`; the
/// transitions between phases are not shown in this section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalPhase {
    /// Candidate discovery (initial phase).
    Discovery,
    /// Coordination phase.
    Coordination,
    /// Synchronization phase.
    Synchronization,
    /// Hole punching phase.
    Punching,
    /// Validation phase.
    Validation,
    /// Traversal finished with a connection.
    Connected,
    /// Traversal failed.
    Failed,
}
374
/// Action chosen for a session by the background polling task, computed
/// under a read lock and applied afterwards under a write lock.
#[derive(Debug, Clone, Copy)]
enum SessionUpdate {
    /// Connecting session exceeded the establishment timeout; close it.
    Timeout,
    /// Connected session whose connection reports a close reason; close it
    /// and drop the connection.
    Disconnected,
    /// Connected session with a live connection; refresh RTT and loss rate.
    UpdateMetrics,
    /// Session marked connected but holding no connection; close it.
    InvalidState,
    /// Idle session whose wait elapsed; move back to connecting and bump the
    /// attempt counter.
    Retry,
    /// Migrating session exceeded the probe timeout; close it.
    MigrationTimeout,
    /// Closed session whose retention elapsed; remove it from the map.
    Remove,
}
393
/// A candidate socket address together with its priority, origin and
/// validation state.
#[derive(Debug, Clone)]
pub struct CandidateAddress {
    /// The candidate socket address.
    pub address: SocketAddr,
    /// Base priority; see `effective_priority` for the state-adjusted value.
    pub priority: u32,
    /// Where the candidate came from.
    pub source: CandidateSource,
    /// Validation state; `CandidateState::New` for freshly constructed candidates.
    pub state: CandidateState,
}
406
407impl CandidateAddress {
408 pub fn new(
410 address: SocketAddr,
411 priority: u32,
412 source: CandidateSource,
413 ) -> Result<Self, CandidateValidationError> {
414 Self::validate_address(&address)?;
415 Ok(Self {
416 address,
417 priority,
418 source,
419 state: CandidateState::New,
420 })
421 }
422
423 pub fn validate_address(addr: &SocketAddr) -> Result<(), CandidateValidationError> {
425 if addr.port() == 0 {
427 return Err(CandidateValidationError::InvalidPort(0));
428 }
429
430 #[cfg(not(test))]
432 if addr.port() < 1024 {
433 return Err(CandidateValidationError::PrivilegedPort(addr.port()));
434 }
435
436 match addr.ip() {
437 std::net::IpAddr::V4(ipv4) => {
438 if ipv4.is_unspecified() {
440 return Err(CandidateValidationError::UnspecifiedAddress);
441 }
442 if ipv4.is_broadcast() {
443 return Err(CandidateValidationError::BroadcastAddress);
444 }
445 if ipv4.is_multicast() {
446 return Err(CandidateValidationError::MulticastAddress);
447 }
448 if ipv4.octets()[0] == 0 {
450 return Err(CandidateValidationError::ReservedAddress);
451 }
452 if ipv4.octets()[0] >= 240 {
454 return Err(CandidateValidationError::ReservedAddress);
455 }
456 }
457 std::net::IpAddr::V6(ipv6) => {
458 if ipv6.is_unspecified() {
460 return Err(CandidateValidationError::UnspecifiedAddress);
461 }
462 if ipv6.is_multicast() {
463 return Err(CandidateValidationError::MulticastAddress);
464 }
465 let segments = ipv6.segments();
467 if segments[0] == 0x2001 && segments[1] == 0x0db8 {
468 return Err(CandidateValidationError::DocumentationAddress);
469 }
470 if ipv6.to_ipv4_mapped().is_some() {
472 return Err(CandidateValidationError::IPv4MappedAddress);
473 }
474 }
475 }
476
477 Ok(())
478 }
479
480 pub fn is_suitable_for_nat_traversal(&self) -> bool {
482 match self.address.ip() {
483 std::net::IpAddr::V4(ipv4) => {
484 #[cfg(test)]
489 if ipv4.is_loopback() {
490 return true;
491 }
492 !ipv4.is_loopback()
493 && !ipv4.is_link_local()
494 && !ipv4.is_multicast()
495 && !ipv4.is_broadcast()
496 }
497 std::net::IpAddr::V6(ipv6) => {
498 #[cfg(test)]
504 if ipv6.is_loopback() {
505 return true;
506 }
507 let segments = ipv6.segments();
508 let is_link_local = (segments[0] & 0xffc0) == 0xfe80;
509 let is_unique_local = (segments[0] & 0xfe00) == 0xfc00;
510
511 !ipv6.is_loopback() && !is_link_local && !is_unique_local && !ipv6.is_multicast()
512 }
513 }
514 }
515
516 pub fn effective_priority(&self) -> u32 {
518 match self.state {
519 CandidateState::Valid => self.priority,
520 CandidateState::New => self.priority.saturating_sub(10),
521 CandidateState::Validating => self.priority.saturating_sub(5),
522 CandidateState::Failed => 0,
523 CandidateState::Removed => 0,
524 }
525 }
526}
527
/// Reasons `CandidateAddress::validate_address` rejects an address.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum CandidateValidationError {
    /// The port is `0`.
    #[error("invalid port number: {0}")]
    InvalidPort(u16),
    /// The port is below 1024 (checked outside test builds only).
    #[error("privileged port not allowed: {0}")]
    PrivilegedPort(u16),
    /// The IP address is the unspecified address.
    #[error("unspecified address not allowed")]
    UnspecifiedAddress,
    /// The IPv4 address is the broadcast address.
    #[error("broadcast address not allowed")]
    BroadcastAddress,
    /// The IP address is a multicast address.
    #[error("multicast address not allowed")]
    MulticastAddress,
    /// The IPv4 address has a first octet of `0` or of `240` and above.
    #[error("reserved address not allowed")]
    ReservedAddress,
    /// The IPv6 address lies in `2001:db8::/32`.
    #[error("documentation address not allowed")]
    DocumentationAddress,
    /// The IPv6 address is an IPv4-mapped address.
    #[error("IPv4-mapped IPv6 address not allowed")]
    IPv4MappedAddress,
}
556
/// Events reported through the endpoint's event callback and event channel.
///
/// Of these, the code in this section emits `CoordinationRequested` (from
/// `initiate_nat_traversal`) and `SessionStateChanged` (from `poll_sessions`).
#[derive(Debug, Clone)]
pub enum NatTraversalEvent {
    /// A candidate address was discovered.
    CandidateDiscovered {
        /// Peer the candidate belongs to.
        peer_id: PeerId,
        /// The discovered candidate.
        candidate: CandidateAddress,
    },
    /// NAT traversal towards a peer was initiated via a coordinator.
    CoordinationRequested {
        /// Target peer.
        peer_id: PeerId,
        /// Coordinator address passed by the caller.
        coordinator: SocketAddr,
    },
    /// Coordination with a peer was synchronized.
    CoordinationSynchronized {
        /// Peer involved.
        peer_id: PeerId,
        /// Identifier of the coordination round.
        round_id: VarInt,
    },
    /// Hole punching started.
    HolePunchingStarted {
        /// Peer involved.
        peer_id: PeerId,
        /// Addresses being targeted.
        targets: Vec<SocketAddr>,
    },
    /// A path was validated.
    PathValidated {
        /// Peer involved.
        peer_id: PeerId,
        /// Address of the validated path.
        address: SocketAddr,
        /// Round-trip time reported for the path.
        rtt: Duration,
    },
    /// A candidate was validated.
    CandidateValidated {
        /// Peer involved.
        peer_id: PeerId,
        /// Address of the validated candidate.
        candidate_address: SocketAddr,
    },
    /// Traversal succeeded.
    TraversalSucceeded {
        /// Peer involved.
        peer_id: PeerId,
        /// Address the traversal settled on.
        final_address: SocketAddr,
        /// Time the traversal took.
        total_time: Duration,
    },
    /// A connection was established.
    ConnectionEstablished {
        /// Peer involved.
        peer_id: PeerId,
        /// Remote address of the connection.
        remote_address: SocketAddr,
    },
    /// Traversal failed.
    TraversalFailed {
        /// Peer involved.
        peer_id: PeerId,
        /// Error describing the failure.
        error: NatTraversalError,
        /// Whether a fallback is available.
        fallback_available: bool,
    },
    /// A connection was lost.
    ConnectionLost {
        /// Peer involved.
        peer_id: PeerId,
        /// Textual reason.
        reason: String,
    },
    /// A session moved between traversal phases.
    PhaseTransition {
        /// Peer involved.
        peer_id: PeerId,
        /// Phase before the transition.
        from_phase: TraversalPhase,
        /// Phase after the transition.
        to_phase: TraversalPhase,
    },
    /// A session's connection state changed during polling.
    SessionStateChanged {
        /// Peer involved.
        peer_id: PeerId,
        /// State after the change.
        new_state: ConnectionState,
    },
}
652
/// Errors returned by NAT traversal operations.
///
/// `String` payloads carry a human-readable description of the underlying
/// failure.
#[derive(Debug, Clone)]
pub enum NatTraversalError {
    /// No bootstrap nodes are available.
    NoBootstrapNodes,
    /// No candidates were found.
    NoCandidatesFound,
    /// Starting candidate discovery failed.
    CandidateDiscoveryFailed(String),
    /// Coordination failed.
    CoordinationFailed(String),
    /// Hole punching failed.
    HolePunchingFailed,
    /// Punching failed, with detail.
    PunchingFailed(String),
    /// Validation failed.
    ValidationFailed(String),
    /// Validation timed out.
    ValidationTimeout,
    /// Network-level failure, e.g. binding the UDP socket.
    NetworkError(String),
    /// Configuration problem, e.g. failed validation or TLS setup.
    ConfigError(String),
    /// Protocol or internal-state failure; also used for poisoned locks.
    ProtocolError(String),
    /// The operation timed out.
    Timeout,
    /// Establishing a connection failed.
    ConnectionFailed(String),
    /// Traversal failed, with detail.
    TraversalFailed(String),
    /// The peer is not connected.
    PeerNotConnected,
}
687
688impl Default for NatTraversalConfig {
689 fn default() -> Self {
690 Self {
691 role: EndpointRole::Client,
692 bootstrap_nodes: Vec::new(),
693 max_candidates: 8,
694 coordination_timeout: Duration::from_secs(10),
695 enable_symmetric_nat: true,
696 enable_relay_fallback: true,
697 max_concurrent_attempts: 3,
698 bind_addr: None,
699 prefer_rfc_nat_traversal: true, timeouts: crate::config::nat_timeouts::TimeoutConfig::default(),
701 }
702 }
703}
704
705impl ConfigValidator for NatTraversalConfig {
706 fn validate(&self) -> ValidationResult<()> {
707 use crate::config::validation::*;
708
709 match self.role {
711 EndpointRole::Client => {
712 if self.bootstrap_nodes.is_empty() {
713 return Err(ConfigValidationError::InvalidRole(
714 "Client endpoints require at least one bootstrap node".to_string(),
715 ));
716 }
717 }
718 EndpointRole::Server { can_coordinate } => {
719 if can_coordinate && self.bootstrap_nodes.is_empty() {
720 return Err(ConfigValidationError::InvalidRole(
721 "Server endpoints with coordination capability require bootstrap nodes"
722 .to_string(),
723 ));
724 }
725 }
726 EndpointRole::Bootstrap => {
727 }
729 }
730
731 if !self.bootstrap_nodes.is_empty() {
733 validate_bootstrap_nodes(&self.bootstrap_nodes)?;
734 }
735
736 validate_range(self.max_candidates, 1, 256, "max_candidates")?;
738
739 validate_duration(
741 self.coordination_timeout,
742 Duration::from_millis(100),
743 Duration::from_secs(300),
744 "coordination_timeout",
745 )?;
746
747 validate_range(
749 self.max_concurrent_attempts,
750 1,
751 16,
752 "max_concurrent_attempts",
753 )?;
754
755 if self.max_concurrent_attempts > self.max_candidates {
757 return Err(ConfigValidationError::IncompatibleConfiguration(
758 "max_concurrent_attempts cannot exceed max_candidates".to_string(),
759 ));
760 }
761
762 if self.role == EndpointRole::Bootstrap && self.enable_relay_fallback {
763 return Err(ConfigValidationError::IncompatibleConfiguration(
764 "Bootstrap nodes should not enable relay fallback".to_string(),
765 ));
766 }
767
768 Ok(())
769 }
770}
771
772impl NatTraversalEndpoint {
    /// Creates a NAT traversal endpoint from `config`.
    ///
    /// `event_callback`, when provided, is stored on the endpoint and invoked
    /// with traversal events. This public entry point forwards both arguments
    /// to `new_impl` and returns its result unchanged.
    ///
    /// # Errors
    /// Returns whatever [`NatTraversalError`] the inner constructor reports.
    pub async fn new(
        config: NatTraversalConfig,
        event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
    ) -> Result<Self, NatTraversalError> {
        Self::new_impl(config, event_callback).await
    }
780
    /// Internal constructor step: forwards both arguments to `new_common` and
    /// returns its result unchanged.
    async fn new_impl(
        config: NatTraversalConfig,
        event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
    ) -> Result<Self, NatTraversalError> {
        Self::new_common(config, event_callback).await
    }
788
    /// Internal constructor step: forwards both arguments to
    /// `new_shared_logic`, which performs the actual construction, and returns
    /// its result unchanged.
    async fn new_common(
        config: NatTraversalConfig,
        event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
    ) -> Result<Self, NatTraversalError> {
        Self::new_shared_logic(config, event_callback).await
    }
797
798 async fn new_shared_logic(
800 config: NatTraversalConfig,
801 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
802 ) -> Result<Self, NatTraversalError> {
803 {
806 config
807 .validate()
808 .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;
809 }
810
811 let bootstrap_nodes = Arc::new(std::sync::RwLock::new(
815 config
816 .bootstrap_nodes
817 .iter()
818 .map(|&address| BootstrapNode {
819 address,
820 last_seen: std::time::Instant::now(),
821 can_coordinate: true, rtt: None,
823 coordination_count: 0,
824 })
825 .collect(),
826 ));
827
828 let discovery_config = DiscoveryConfig {
830 total_timeout: config.coordination_timeout,
831 max_candidates: config.max_candidates,
832 enable_symmetric_prediction: config.enable_symmetric_nat,
833 bound_address: config.bind_addr, ..DiscoveryConfig::default()
835 };
836
837 let nat_traversal_role = match config.role {
838 EndpointRole::Client => NatTraversalRole::Client,
839 EndpointRole::Server { can_coordinate } => NatTraversalRole::Server {
840 can_relay: can_coordinate,
841 },
842 EndpointRole::Bootstrap => NatTraversalRole::Bootstrap,
843 };
844
845 let discovery_manager = Arc::new(std::sync::Mutex::new(CandidateDiscoveryManager::new(
846 discovery_config,
847 )));
848
849 let (quinn_endpoint, event_tx, local_addr) =
852 Self::create_quinn_endpoint(&config, nat_traversal_role).await?;
853
854 {
856 let mut discovery = discovery_manager.lock().map_err(|_| {
857 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
858 })?;
859 discovery.set_bound_address(local_addr);
860 info!(
861 "Updated discovery manager with bound address: {}",
862 local_addr
863 );
864 }
865
866 let endpoint = Self {
867 quinn_endpoint: Some(quinn_endpoint.clone()),
868 config: config.clone(),
869 bootstrap_nodes,
870 active_sessions: Arc::new(std::sync::RwLock::new(HashMap::new())),
871 discovery_manager,
872 event_callback,
873 shutdown: Arc::new(AtomicBool::new(false)),
874 event_tx: Some(event_tx.clone()),
875 connections: Arc::new(std::sync::RwLock::new(HashMap::new())),
876 local_peer_id: Self::generate_local_peer_id(),
877 timeout_config: config.timeouts.clone(),
878 };
879
880 if matches!(
882 config.role,
883 EndpointRole::Bootstrap | EndpointRole::Server { .. }
884 ) {
885 let endpoint_clone = quinn_endpoint.clone();
886 let shutdown_clone = endpoint.shutdown.clone();
887 let event_tx_clone = event_tx.clone();
888 let connections_clone = endpoint.connections.clone();
889
890 tokio::spawn(async move {
891 Self::accept_connections(
892 endpoint_clone,
893 shutdown_clone,
894 event_tx_clone,
895 connections_clone,
896 )
897 .await;
898 });
899
900 info!("Started accepting connections for {:?} role", config.role);
901 }
902
903 let discovery_manager_clone = endpoint.discovery_manager.clone();
905 let shutdown_clone = endpoint.shutdown.clone();
906 let event_tx_clone = event_tx;
907
908 tokio::spawn(async move {
909 Self::poll_discovery(discovery_manager_clone, shutdown_clone, event_tx_clone).await;
910 });
911
912 info!("Started discovery polling task");
913
914 {
916 let mut discovery = endpoint.discovery_manager.lock().map_err(|_| {
917 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
918 })?;
919
920 let local_peer_id = endpoint.local_peer_id;
922 let bootstrap_nodes = {
923 let nodes = endpoint.bootstrap_nodes.read().map_err(|_| {
924 NatTraversalError::ProtocolError("Bootstrap nodes lock poisoned".to_string())
925 })?;
926 nodes.clone()
927 };
928
929 discovery
930 .start_discovery(local_peer_id, bootstrap_nodes)
931 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
932
933 info!(
934 "Started local candidate discovery for peer {:?}",
935 local_peer_id
936 );
937 }
938
939 Ok(endpoint)
940 }
941
    /// Returns a reference to the underlying QUIC endpoint, or `None` when the
    /// field is unset.
    pub fn get_quinn_endpoint(&self) -> Option<&crate::high_level::Endpoint> {
        self.quinn_endpoint.as_ref()
    }
946
    /// Returns a reference to the event callback supplied at construction, or
    /// `None` when no callback was provided.
    pub fn get_event_callback(&self) -> Option<&Box<dyn Fn(NatTraversalEvent) + Send + Sync>> {
        self.event_callback.as_ref()
    }
951
952 pub fn initiate_nat_traversal(
954 &self,
955 peer_id: PeerId,
956 coordinator: SocketAddr,
957 ) -> Result<(), NatTraversalError> {
958 info!(
959 "Starting NAT traversal to peer {:?} via coordinator {}",
960 peer_id, coordinator
961 );
962
963 let session = NatTraversalSession {
965 peer_id,
966 coordinator,
967 attempt: 1,
968 started_at: std::time::Instant::now(),
969 phase: TraversalPhase::Discovery,
970 candidates: Vec::new(),
971 session_state: SessionState {
972 state: ConnectionState::Connecting,
973 last_transition: std::time::Instant::now(),
974
975 connection: None,
976 active_attempts: Vec::new(),
977 metrics: ConnectionMetrics::default(),
978 },
979 };
980
981 {
983 let mut sessions = self
984 .active_sessions
985 .write()
986 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
987 sessions.insert(peer_id, session);
988 }
989
990 let bootstrap_nodes_vec = {
992 let bootstrap_nodes = self
993 .bootstrap_nodes
994 .read()
995 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
996 bootstrap_nodes.clone()
997 };
998
999 {
1000 let mut discovery = self.discovery_manager.lock().map_err(|_| {
1001 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
1002 })?;
1003
1004 discovery
1005 .start_discovery(peer_id, bootstrap_nodes_vec)
1006 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
1007 }
1008
1009 if let Some(ref callback) = self.event_callback {
1011 callback(NatTraversalEvent::CoordinationRequested {
1012 peer_id,
1013 coordinator,
1014 });
1015 }
1016
1017 Ok(())
1019 }
1020
1021 pub fn poll_sessions(&self) -> Result<Vec<SessionStateUpdate>, NatTraversalError> {
1023 let mut updates = Vec::new();
1024 let now = std::time::Instant::now();
1025
1026 let mut sessions = self
1027 .active_sessions
1028 .write()
1029 .map_err(|_| NatTraversalError::ProtocolError("Sessions lock poisoned".to_string()))?;
1030
1031 for (peer_id, session) in sessions.iter_mut() {
1032 let mut state_changed = false;
1033
1034 match session.session_state.state {
1035 ConnectionState::Connecting => {
1036 let elapsed = now.duration_since(session.session_state.last_transition);
1038 if elapsed
1039 > self
1040 .timeout_config
1041 .nat_traversal
1042 .connection_establishment_timeout
1043 {
1044 session.session_state.state = ConnectionState::Closed;
1045 session.session_state.last_transition = now;
1046 state_changed = true;
1047
1048 updates.push(SessionStateUpdate {
1049 peer_id: *peer_id,
1050 old_state: ConnectionState::Connecting,
1051 new_state: ConnectionState::Closed,
1052 reason: StateChangeReason::Timeout,
1053 });
1054 }
1055
1056 if let Some(ref _connection) = session.session_state.connection {
1059 session.session_state.state = ConnectionState::Connected;
1060 session.session_state.last_transition = now;
1061 state_changed = true;
1062
1063 updates.push(SessionStateUpdate {
1064 peer_id: *peer_id,
1065 old_state: ConnectionState::Connecting,
1066 new_state: ConnectionState::Connected,
1067 reason: StateChangeReason::ConnectionEstablished,
1068 });
1069 }
1070 }
1071 ConnectionState::Connected => {
1072 {
1075 }
1078
1079 session.session_state.metrics.last_activity = Some(now);
1081 }
1082 ConnectionState::Migrating => {
1083 let elapsed = now.duration_since(session.session_state.last_transition);
1085 if elapsed > Duration::from_secs(10) {
1086 if session.session_state.connection.is_some() {
1089 session.session_state.state = ConnectionState::Connected;
1090 state_changed = true;
1091
1092 updates.push(SessionStateUpdate {
1093 peer_id: *peer_id,
1094 old_state: ConnectionState::Migrating,
1095 new_state: ConnectionState::Connected,
1096 reason: StateChangeReason::MigrationComplete,
1097 });
1098 } else {
1099 session.session_state.state = ConnectionState::Closed;
1100 state_changed = true;
1101
1102 updates.push(SessionStateUpdate {
1103 peer_id: *peer_id,
1104 old_state: ConnectionState::Migrating,
1105 new_state: ConnectionState::Closed,
1106 reason: StateChangeReason::MigrationFailed,
1107 });
1108 }
1109
1110 session.session_state.last_transition = now;
1111 }
1112 }
1113 _ => {}
1114 }
1115
1116 if state_changed {
1118 if let Some(ref callback) = self.event_callback {
1119 callback(NatTraversalEvent::SessionStateChanged {
1120 peer_id: *peer_id,
1121 new_state: session.session_state.state,
1122 });
1123 }
1124 }
1125 }
1126
1127 Ok(updates)
1128 }
1129
1130 pub fn start_session_polling(&self, interval: Duration) -> tokio::task::JoinHandle<()> {
1132 let sessions = self.active_sessions.clone();
1133 let shutdown = self.shutdown.clone();
1134 let timeout_config = self.timeout_config.clone();
1135
1136 tokio::spawn(async move {
1137 let mut ticker = tokio::time::interval(interval);
1138
1139 loop {
1140 ticker.tick().await;
1141
1142 if shutdown.load(Ordering::Relaxed) {
1143 break;
1144 }
1145
1146 let sessions_to_update = {
1148 match sessions.read() {
1149 Ok(sessions_guard) => {
1150 sessions_guard
1151 .iter()
1152 .filter_map(|(peer_id, session)| {
1153 let now = std::time::Instant::now();
1154 let elapsed =
1155 now.duration_since(session.session_state.last_transition);
1156
1157 match session.session_state.state {
1158 ConnectionState::Connecting => {
1159 if elapsed
1161 > timeout_config
1162 .nat_traversal
1163 .connection_establishment_timeout
1164 {
1165 Some((*peer_id, SessionUpdate::Timeout))
1166 } else {
1167 None
1168 }
1169 }
1170 ConnectionState::Connected => {
1171 if let Some(ref conn) = session.session_state.connection
1173 {
1174 if conn.close_reason().is_some() {
1175 Some((*peer_id, SessionUpdate::Disconnected))
1176 } else {
1177 Some((*peer_id, SessionUpdate::UpdateMetrics))
1179 }
1180 } else {
1181 Some((*peer_id, SessionUpdate::InvalidState))
1182 }
1183 }
1184 ConnectionState::Idle => {
1185 if elapsed
1187 > timeout_config
1188 .discovery
1189 .server_reflexive_cache_ttl
1190 {
1191 Some((*peer_id, SessionUpdate::Retry))
1192 } else {
1193 None
1194 }
1195 }
1196 ConnectionState::Migrating => {
1197 if elapsed > timeout_config.nat_traversal.probe_timeout
1199 {
1200 Some((*peer_id, SessionUpdate::MigrationTimeout))
1201 } else {
1202 None
1203 }
1204 }
1205 ConnectionState::Closed => {
1206 if elapsed
1208 > timeout_config.discovery.interface_cache_ttl
1209 {
1210 Some((*peer_id, SessionUpdate::Remove))
1211 } else {
1212 None
1213 }
1214 }
1215 }
1216 })
1217 .collect::<Vec<_>>()
1218 }
1219 _ => {
1220 vec![]
1221 }
1222 }
1223 };
1224
1225 if !sessions_to_update.is_empty() {
1227 if let Ok(mut sessions_guard) = sessions.write() {
1228 for (peer_id, update) in sessions_to_update {
1229 match update {
1230 SessionUpdate::Timeout => {
1231 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1232 session.session_state.state = ConnectionState::Closed;
1233 session.session_state.last_transition =
1234 std::time::Instant::now();
1235 tracing::warn!("Connection to {:?} timed out", peer_id);
1236 }
1237 }
1238 SessionUpdate::Disconnected => {
1239 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1240 session.session_state.state = ConnectionState::Closed;
1241 session.session_state.last_transition =
1242 std::time::Instant::now();
1243 session.session_state.connection = None;
1244 tracing::info!("Connection to {:?} closed", peer_id);
1245 }
1246 }
1247 SessionUpdate::UpdateMetrics => {
1248 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1249 if let Some(ref conn) = session.session_state.connection {
1250 let stats = conn.stats();
1252 session.session_state.metrics.rtt =
1253 Some(stats.path.rtt);
1254 session.session_state.metrics.loss_rate =
1255 stats.path.lost_packets as f64
1256 / stats.path.sent_packets.max(1) as f64;
1257 }
1258 }
1259 }
1260 SessionUpdate::InvalidState => {
1261 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1262 session.session_state.state = ConnectionState::Closed;
1263 session.session_state.last_transition =
1264 std::time::Instant::now();
1265 tracing::error!("Session {:?} in invalid state", peer_id);
1266 }
1267 }
1268 SessionUpdate::Retry => {
1269 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1270 session.session_state.state = ConnectionState::Connecting;
1271 session.session_state.last_transition =
1272 std::time::Instant::now();
1273 session.attempt += 1;
1274 tracing::info!(
1275 "Retrying connection to {:?} (attempt {})",
1276 peer_id,
1277 session.attempt
1278 );
1279 }
1280 }
1281 SessionUpdate::MigrationTimeout => {
1282 if let Some(session) = sessions_guard.get_mut(&peer_id) {
1283 session.session_state.state = ConnectionState::Closed;
1284 session.session_state.last_transition =
1285 std::time::Instant::now();
1286 tracing::warn!("Migration timeout for {:?}", peer_id);
1287 }
1288 }
1289 SessionUpdate::Remove => {
1290 sessions_guard.remove(&peer_id);
1291 tracing::debug!("Removed old session for {:?}", peer_id);
1292 }
1293 }
1294 }
1295 }
1296 }
1297 }
1298 })
1299 }
1300
1301 pub fn get_statistics(&self) -> Result<NatTraversalStatistics, NatTraversalError> {
1305 let sessions = self
1306 .active_sessions
1307 .read()
1308 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1309 let bootstrap_nodes = self
1310 .bootstrap_nodes
1311 .read()
1312 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1313
1314 let avg_coordination_time = {
1316 let rtts: Vec<Duration> = bootstrap_nodes.iter().filter_map(|b| b.rtt).collect();
1317
1318 if rtts.is_empty() {
1319 Duration::from_millis(500) } else {
1321 let total_millis: u64 = rtts.iter().map(|d| d.as_millis() as u64).sum();
1322 Duration::from_millis(total_millis / rtts.len() as u64 * 2) }
1324 };
1325
1326 Ok(NatTraversalStatistics {
1327 active_sessions: sessions.len(),
1328 total_bootstrap_nodes: bootstrap_nodes.len(),
1329 successful_coordinations: bootstrap_nodes.iter().map(|b| b.coordination_count).sum(),
1330 average_coordination_time: avg_coordination_time,
1331 total_attempts: 0,
1332 successful_connections: 0,
1333 direct_connections: 0,
1334 relayed_connections: 0,
1335 })
1336 }
1337
1338 pub fn add_bootstrap_node(&self, address: SocketAddr) -> Result<(), NatTraversalError> {
1340 let mut bootstrap_nodes = self
1341 .bootstrap_nodes
1342 .write()
1343 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1344
1345 if !bootstrap_nodes.iter().any(|b| b.address == address) {
1347 bootstrap_nodes.push(BootstrapNode {
1348 address,
1349 last_seen: std::time::Instant::now(),
1350 can_coordinate: true,
1351 rtt: None,
1352 coordination_count: 0,
1353 });
1354 info!("Added bootstrap node: {}", address);
1355 }
1356 Ok(())
1357 }
1358
1359 pub fn remove_bootstrap_node(&self, address: SocketAddr) -> Result<(), NatTraversalError> {
1361 let mut bootstrap_nodes = self
1362 .bootstrap_nodes
1363 .write()
1364 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
1365 bootstrap_nodes.retain(|b| b.address != address);
1366 info!("Removed bootstrap node: {}", address);
1367 Ok(())
1368 }
1369
1370 async fn create_quinn_endpoint(
1374 config: &NatTraversalConfig,
1375 _nat_role: NatTraversalRole,
1376 ) -> Result<
1377 (
1378 QuinnEndpoint,
1379 mpsc::UnboundedSender<NatTraversalEvent>,
1380 SocketAddr,
1381 ),
1382 NatTraversalError,
1383 > {
1384 use std::sync::Arc;
1385
1386 let server_config = match config.role {
1388 EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
1389 let cert_config = CertificateConfig {
1391 common_name: format!("ant-quic-{}", config.role.name()),
1392 subject_alt_names: vec!["localhost".to_string(), "ant-quic-node".to_string()],
1393 self_signed: true, ..CertificateConfig::default()
1395 };
1396
1397 let cert_manager = CertificateManager::new(cert_config).map_err(|e| {
1398 NatTraversalError::ConfigError(format!(
1399 "Certificate manager creation failed: {e}"
1400 ))
1401 })?;
1402
1403 let cert_bundle = cert_manager.generate_certificate().map_err(|e| {
1404 NatTraversalError::ConfigError(format!("Certificate generation failed: {e}"))
1405 })?;
1406
1407 let rustls_config =
1408 cert_manager
1409 .create_server_config(&cert_bundle)
1410 .map_err(|e| {
1411 NatTraversalError::ConfigError(format!(
1412 "Server config creation failed: {e}"
1413 ))
1414 })?;
1415
1416 let server_crypto = QuicServerConfig::try_from(rustls_config.as_ref().clone())
1417 .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;
1418
1419 let mut server_config = ServerConfig::with_crypto(Arc::new(server_crypto));
1420
1421 let mut transport_config = TransportConfig::default();
1423 transport_config
1424 .keep_alive_interval(Some(config.timeouts.nat_traversal.retry_interval));
1425 transport_config.max_idle_timeout(Some(crate::VarInt::from_u32(30000).into()));
1426
1427 let nat_config = match config.role {
1432 EndpointRole::Client => {
1433 crate::transport_parameters::NatTraversalConfig::ClientSupport
1434 }
1435 EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
1436 crate::transport_parameters::NatTraversalConfig::ServerSupport {
1437 concurrency_limit: VarInt::from_u32(
1438 config.max_concurrent_attempts as u32,
1439 ),
1440 }
1441 }
1442 };
1443 transport_config.nat_traversal_config(Some(nat_config));
1444
1445 server_config.transport_config(Arc::new(transport_config));
1446
1447 Some(server_config)
1448 }
1449 _ => None,
1450 };
1451
1452 let client_config = {
1454 let cert_config = CertificateConfig {
1455 common_name: format!("ant-quic-{}", config.role.name()),
1456 subject_alt_names: vec!["localhost".to_string(), "ant-quic-node".to_string()],
1457 self_signed: true,
1458 ..CertificateConfig::default()
1459 };
1460
1461 let cert_manager = CertificateManager::new(cert_config).map_err(|e| {
1462 NatTraversalError::ConfigError(format!("Certificate manager creation failed: {e}"))
1463 })?;
1464
1465 let _cert_bundle = cert_manager.generate_certificate().map_err(|e| {
1466 NatTraversalError::ConfigError(format!("Certificate generation failed: {e}"))
1467 })?;
1468
1469 let rustls_config = cert_manager.create_client_config().map_err(|e| {
1470 NatTraversalError::ConfigError(format!("Client config creation failed: {e}"))
1471 })?;
1472
1473 let client_crypto = QuicClientConfig::try_from(rustls_config.as_ref().clone())
1474 .map_err(|e| NatTraversalError::ConfigError(e.to_string()))?;
1475
1476 let mut client_config = ClientConfig::new(Arc::new(client_crypto));
1477
1478 let mut transport_config = TransportConfig::default();
1480 transport_config.keep_alive_interval(Some(Duration::from_secs(5)));
1481 transport_config.max_idle_timeout(Some(crate::VarInt::from_u32(30000).into()));
1482
1483 let nat_config = match config.role {
1488 EndpointRole::Client => {
1489 crate::transport_parameters::NatTraversalConfig::ClientSupport
1490 }
1491 EndpointRole::Bootstrap | EndpointRole::Server { .. } => {
1492 crate::transport_parameters::NatTraversalConfig::ServerSupport {
1493 concurrency_limit: VarInt::from_u32(config.max_concurrent_attempts as u32),
1494 }
1495 }
1496 };
1497 transport_config.nat_traversal_config(Some(nat_config));
1498
1499 client_config.transport_config(Arc::new(transport_config));
1500
1501 client_config
1502 };
1503
1504 let bind_addr = config
1506 .bind_addr
1507 .unwrap_or_else(create_random_port_bind_addr);
1508 let socket = UdpSocket::bind(bind_addr).await.map_err(|e| {
1509 NatTraversalError::NetworkError(format!("Failed to bind UDP socket: {e}"))
1510 })?;
1511
1512 info!("Binding endpoint to {}", bind_addr);
1513
1514 let std_socket = socket.into_std().map_err(|e| {
1516 NatTraversalError::NetworkError(format!("Failed to convert socket: {e}"))
1517 })?;
1518
1519 let runtime = default_runtime().ok_or_else(|| {
1521 NatTraversalError::ConfigError("No compatible async runtime found".to_string())
1522 })?;
1523
1524 let mut endpoint = QuinnEndpoint::new(
1525 EndpointConfig::default(),
1526 server_config,
1527 std_socket,
1528 runtime,
1529 )
1530 .map_err(|e| {
1531 NatTraversalError::ConfigError(format!("Failed to create Quinn endpoint: {e}"))
1532 })?;
1533
1534 endpoint.set_default_client_config(client_config);
1536
1537 let local_addr = endpoint.local_addr().map_err(|e| {
1539 NatTraversalError::NetworkError(format!("Failed to get local address: {e}"))
1540 })?;
1541
1542 info!("Endpoint bound to actual address: {}", local_addr);
1543
1544 let (event_tx, _event_rx) = mpsc::unbounded_channel();
1546
1547 Ok((endpoint, event_tx, local_addr))
1548 }
1549
1550 #[allow(clippy::panic)]
1552 pub async fn start_listening(&self, bind_addr: SocketAddr) -> Result<(), NatTraversalError> {
1553 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
1554 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
1555 })?;
1556
1557 let _socket = UdpSocket::bind(bind_addr).await.map_err(|e| {
1559 NatTraversalError::NetworkError(format!("Failed to bind to {bind_addr}: {e}"))
1560 })?;
1561
1562 info!("Started listening on {}", bind_addr);
1563
1564 let endpoint_clone = endpoint.clone();
1566 let shutdown_clone = self.shutdown.clone();
1567 let event_tx = self
1568 .event_tx
1569 .as_ref()
1570 .unwrap_or_else(|| panic!("event transmitter should be initialized"))
1571 .clone();
1572 let connections_clone = self.connections.clone();
1573
1574 tokio::spawn(async move {
1575 Self::accept_connections(endpoint_clone, shutdown_clone, event_tx, connections_clone)
1576 .await;
1577 });
1578
1579 Ok(())
1580 }
1581
1582 async fn accept_connections(
1584 endpoint: QuinnEndpoint,
1585 shutdown: Arc<AtomicBool>,
1586 event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1587 connections: Arc<std::sync::RwLock<HashMap<PeerId, QuinnConnection>>>,
1588 ) {
1589 while !shutdown.load(Ordering::Relaxed) {
1590 match endpoint.accept().await {
1591 Some(connecting) => {
1592 let event_tx = event_tx.clone();
1593 let connections = connections.clone();
1594 tokio::spawn(async move {
1595 match connecting.await {
1596 Ok(connection) => {
1597 info!("Accepted connection from {}", connection.remote_address());
1598
1599 let peer_id = Self::generate_peer_id_from_address(
1601 connection.remote_address(),
1602 );
1603
1604 if let Ok(mut conns) = connections.write() {
1606 conns.insert(peer_id, connection.clone());
1607 }
1608
1609 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
1610 peer_id,
1611 remote_address: connection.remote_address(),
1612 });
1613
1614 Self::handle_connection(connection, event_tx).await;
1616 }
1617 Err(e) => {
1618 debug!("Connection failed: {}", e);
1619 }
1620 }
1621 });
1622 }
1623 None => {
1624 break;
1626 }
1627 }
1628 }
1629 }
1630
1631 async fn poll_discovery(
1633 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
1634 shutdown: Arc<AtomicBool>,
1635 _event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1636 ) {
1637 use tokio::time::{Duration, interval};
1638
1639 let mut poll_interval = interval(Duration::from_millis(100));
1640
1641 while !shutdown.load(Ordering::Relaxed) {
1642 poll_interval.tick().await;
1643
1644 let events = match discovery_manager.lock() {
1646 Ok(mut discovery) => discovery.poll(std::time::Instant::now()),
1647 Err(e) => {
1648 error!("Failed to lock discovery manager: {}", e);
1649 continue;
1650 }
1651 };
1652
1653 for event in events {
1655 match event {
1656 DiscoveryEvent::DiscoveryStarted {
1657 peer_id,
1658 bootstrap_count,
1659 } => {
1660 debug!(
1661 "Discovery started for peer {:?} with {} bootstrap nodes",
1662 peer_id, bootstrap_count
1663 );
1664 }
1665 DiscoveryEvent::LocalScanningStarted => {
1666 debug!("Local interface scanning started");
1667 }
1668 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
1669 debug!("Discovered local candidate: {}", candidate.address);
1670 }
1673 DiscoveryEvent::LocalScanningCompleted {
1674 candidate_count,
1675 duration,
1676 } => {
1677 debug!(
1678 "Local interface scanning completed: {} candidates in {:?}",
1679 candidate_count, duration
1680 );
1681 }
1682 DiscoveryEvent::ServerReflexiveDiscoveryStarted { bootstrap_count } => {
1683 debug!(
1684 "Server reflexive discovery started with {} bootstrap nodes",
1685 bootstrap_count
1686 );
1687 }
1688 DiscoveryEvent::ServerReflexiveCandidateDiscovered {
1689 candidate,
1690 bootstrap_node,
1691 } => {
1692 debug!(
1693 "Discovered server-reflexive candidate {} via bootstrap {}",
1694 candidate.address, bootstrap_node
1695 );
1696 }
1698 DiscoveryEvent::BootstrapQueryFailed {
1699 bootstrap_node,
1700 error,
1701 } => {
1702 debug!("Bootstrap query failed for {}: {}", bootstrap_node, error);
1703 }
1704 DiscoveryEvent::PortAllocationDetected {
1706 port,
1707 source_address,
1708 bootstrap_node,
1709 timestamp,
1710 } => {
1711 debug!(
1712 "Port allocation detected: port {} from {} via bootstrap {:?} at {:?}",
1713 port, source_address, bootstrap_node, timestamp
1714 );
1715 }
1716 DiscoveryEvent::DiscoveryCompleted {
1717 candidate_count,
1718 total_duration,
1719 success_rate,
1720 } => {
1721 info!(
1722 "Discovery completed with {} candidates in {:?} (success rate: {:.2}%)",
1723 candidate_count,
1724 total_duration,
1725 success_rate * 100.0
1726 );
1727 }
1730 DiscoveryEvent::DiscoveryFailed {
1731 error,
1732 partial_results,
1733 } => {
1734 warn!(
1735 "Discovery failed: {} (found {} partial candidates)",
1736 error,
1737 partial_results.len()
1738 );
1739
1740 }
1745 DiscoveryEvent::PathValidationRequested {
1746 candidate_id,
1747 candidate_address,
1748 challenge_token,
1749 } => {
1750 debug!(
1751 "PATH_CHALLENGE requested for candidate {} at {} with token {:08x}",
1752 candidate_id.0, candidate_address, challenge_token
1753 );
1754 }
1757 DiscoveryEvent::PathValidationResponse {
1758 candidate_id,
1759 candidate_address,
1760 challenge_token: _,
1761 rtt,
1762 } => {
1763 debug!(
1764 "PATH_RESPONSE received for candidate {} at {} with RTT {:?}",
1765 candidate_id.0, candidate_address, rtt
1766 );
1767 }
1769 }
1770 }
1771 }
1772
1773 info!("Discovery polling task shutting down");
1774 }
1775
1776 async fn handle_connection(
1778 connection: QuinnConnection,
1779 event_tx: mpsc::UnboundedSender<NatTraversalEvent>,
1780 ) {
1781 let peer_id = Self::generate_peer_id_from_address(connection.remote_address());
1782 let remote_address = connection.remote_address();
1783
1784 debug!(
1785 "Handling connection from peer {:?} at {}",
1786 peer_id, remote_address
1787 );
1788
1789 loop {
1791 tokio::select! {
1792 stream = connection.accept_bi() => {
1793 match stream {
1794 Ok((send, recv)) => {
1795 tokio::spawn(async move {
1796 Self::handle_bi_stream(send, recv).await;
1797 });
1798 }
1799 Err(e) => {
1800 debug!("Error accepting bidirectional stream: {}", e);
1801 let _ = event_tx.send(NatTraversalEvent::ConnectionLost {
1802 peer_id,
1803 reason: format!("Stream error: {e}"),
1804 });
1805 break;
1806 }
1807 }
1808 }
1809 stream = connection.accept_uni() => {
1810 match stream {
1811 Ok(recv) => {
1812 tokio::spawn(async move {
1813 Self::handle_uni_stream(recv).await;
1814 });
1815 }
1816 Err(e) => {
1817 debug!("Error accepting unidirectional stream: {}", e);
1818 let _ = event_tx.send(NatTraversalEvent::ConnectionLost {
1819 peer_id,
1820 reason: format!("Stream error: {e}"),
1821 });
1822 break;
1823 }
1824 }
1825 }
1826 }
1827 }
1828 }
1829
    /// Handler for an accepted bidirectional stream.
    ///
    /// The body is currently empty: nothing is read or written, and both stream
    /// halves are dropped as soon as the function returns.
    async fn handle_bi_stream(
        _send: crate::high_level::SendStream,
        _recv: crate::high_level::RecvStream,
    ) {
    }
1864
1865 async fn handle_uni_stream(mut recv: crate::high_level::RecvStream) {
1867 let mut buffer = vec![0u8; 1024];
1868
1869 loop {
1870 match recv.read(&mut buffer).await {
1871 Ok(Some(size)) => {
1872 debug!("Received {} bytes on unidirectional stream", size);
1873 }
1875 Ok(None) => {
1876 debug!("Unidirectional stream closed by peer");
1877 break;
1878 }
1879 Err(e) => {
1880 debug!("Error reading from unidirectional stream: {}", e);
1881 break;
1882 }
1883 }
1884 }
1885 }
1886
1887 pub async fn connect_to_peer(
1889 &self,
1890 peer_id: PeerId,
1891 server_name: &str,
1892 remote_addr: SocketAddr,
1893 ) -> Result<QuinnConnection, NatTraversalError> {
1894 let (stored_peer_id, connection) = self
1895 .connect_to_address(remote_addr, server_name, Some(peer_id))
1896 .await?;
1897
1898 if stored_peer_id != peer_id {
1899 warn!(
1900 expected = %hex::encode(peer_id.0),
1901 stored = %hex::encode(stored_peer_id.0),
1902 "Stored peer ID differs from expected during connect_to_peer"
1903 );
1904 }
1905
1906 Ok(connection)
1907 }
1908
1909 pub async fn connect_to_address(
1911 &self,
1912 remote_addr: SocketAddr,
1913 server_name: &str,
1914 expected_peer_id: Option<PeerId>,
1915 ) -> Result<(PeerId, QuinnConnection), NatTraversalError> {
1916 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
1917 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
1918 })?;
1919
1920 info!("Connecting to remote {}", remote_addr);
1921
1922 let connecting = endpoint.connect(remote_addr, server_name).map_err(|e| {
1924 NatTraversalError::ConnectionFailed(format!("Failed to initiate connection: {e}"))
1925 })?;
1926
1927 let connection = timeout(
1928 self.timeout_config
1929 .nat_traversal
1930 .connection_establishment_timeout,
1931 connecting,
1932 )
1933 .await
1934 .map_err(|_| NatTraversalError::Timeout)?
1935 .map_err(|e| NatTraversalError::ConnectionFailed(format!("Connection failed: {e}")))?;
1936
1937 let actual_peer_id = self.extract_peer_id_from_connection(&connection).await;
1938
1939 let peer_id_to_store = if let Some(expected) = expected_peer_id {
1940 if let Some(actual) = actual_peer_id {
1941 if actual != expected {
1942 warn!(
1943 expected = %hex::encode(expected.0),
1944 actual = %hex::encode(actual.0),
1945 "Extracted peer ID differs from expected"
1946 );
1947 }
1948 }
1949 expected
1950 } else if let Some(actual) = actual_peer_id {
1951 actual
1952 } else {
1953 Self::generate_peer_id_from_address(remote_addr)
1954 };
1955
1956 {
1957 let mut connections = self.connections.write().map_err(|_| {
1958 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
1959 })?;
1960
1961 connections.insert(peer_id_to_store, connection.clone());
1962
1963 if let Some(actual) = actual_peer_id {
1964 if actual != peer_id_to_store {
1965 connections.insert(actual, connection.clone());
1966 }
1967 }
1968 }
1969
1970 let event_peer_id = if let Some(expected) = expected_peer_id {
1971 expected
1972 } else if let Some(actual) = actual_peer_id {
1973 actual
1974 } else {
1975 peer_id_to_store
1976 };
1977
1978 let actual_peer_hex = actual_peer_id.map(|id| hex::encode(id.0));
1979 let actual_hex_display = actual_peer_hex.as_deref().unwrap_or("unknown-peer-id");
1980
1981 info!(
1982 peer = %hex::encode(peer_id_to_store.0),
1983 actual = actual_hex_display,
1984 "Successfully connected to remote {}",
1985 remote_addr
1986 );
1987
1988 if let Some(ref event_tx) = self.event_tx {
1989 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
1990 peer_id: event_peer_id,
1991 remote_address: remote_addr,
1992 });
1993 }
1994
1995 Ok((peer_id_to_store, connection))
1996 }
1997
1998 pub async fn accept_connection(&self) -> Result<(PeerId, QuinnConnection), NatTraversalError> {
2000 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
2001 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
2002 })?;
2003
2004 let incoming = endpoint
2006 .accept()
2007 .await
2008 .ok_or_else(|| NatTraversalError::NetworkError("Endpoint closed".to_string()))?;
2009
2010 let remote_addr = incoming.remote_address();
2011 info!("Accepting connection from {}", remote_addr);
2012
2013 let connection = incoming.await.map_err(|e| {
2015 NatTraversalError::ConnectionFailed(format!("Failed to accept connection: {e}"))
2016 })?;
2017
2018 let peer_id = self
2020 .extract_peer_id_from_connection(&connection)
2021 .await
2022 .unwrap_or_else(|| Self::generate_peer_id_from_address(remote_addr));
2023
2024 {
2026 let mut connections = self.connections.write().map_err(|_| {
2027 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2028 })?;
2029 connections.insert(peer_id, connection.clone());
2030 }
2031
2032 info!(
2033 "Connection accepted from peer {:?} at {}",
2034 peer_id, remote_addr
2035 );
2036
2037 if let Some(ref event_tx) = self.event_tx {
2039 let _ = event_tx.send(NatTraversalEvent::ConnectionEstablished {
2040 peer_id,
2041 remote_address: remote_addr,
2042 });
2043 }
2044
2045 Ok((peer_id, connection))
2046 }
2047
    /// Returns the peer ID stored for this endpoint.
    pub fn local_peer_id(&self) -> PeerId {
        self.local_peer_id
    }
2052
2053 pub fn get_connection(
2055 &self,
2056 peer_id: &PeerId,
2057 ) -> Result<Option<QuinnConnection>, NatTraversalError> {
2058 let connections = self.connections.read().map_err(|_| {
2059 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2060 })?;
2061 Ok(connections.get(peer_id).cloned())
2062 }
2063
2064 pub fn remove_connection(
2066 &self,
2067 peer_id: &PeerId,
2068 ) -> Result<Option<QuinnConnection>, NatTraversalError> {
2069 let mut connections = self.connections.write().map_err(|_| {
2070 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2071 })?;
2072 Ok(connections.remove(peer_id))
2073 }
2074
2075 pub fn relabel_connection(
2077 &self,
2078 old_peer_id: PeerId,
2079 new_peer_id: PeerId,
2080 ) -> Result<(), NatTraversalError> {
2081 if old_peer_id == new_peer_id {
2082 return Ok(());
2083 }
2084
2085 let mut connections = self.connections.write().map_err(|_| {
2086 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2087 })?;
2088
2089 if let Some(connection) = connections.remove(&old_peer_id) {
2090 connections.insert(new_peer_id, connection);
2091 info!(
2092 old = %hex::encode(old_peer_id.0),
2093 new = %hex::encode(new_peer_id.0),
2094 "Relabeled connection to updated peer ID"
2095 );
2096 }
2097
2098 Ok(())
2099 }
2100
2101 pub fn list_connections(&self) -> Result<Vec<(PeerId, SocketAddr)>, NatTraversalError> {
2103 let connections = self.connections.read().map_err(|_| {
2104 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2105 })?;
2106 let mut result = Vec::new();
2107 for (peer_id, connection) in connections.iter() {
2108 result.push((*peer_id, connection.remote_address()));
2109 }
2110 Ok(result)
2111 }
2112
2113 pub async fn handle_connection_data(
2115 &self,
2116 peer_id: PeerId,
2117 connection: &QuinnConnection,
2118 ) -> Result<(), NatTraversalError> {
2119 info!("Handling connection data from peer {:?}", peer_id);
2120
2121 let connection_clone = connection.clone();
2123 let peer_id_clone = peer_id;
2124 tokio::spawn(async move {
2125 loop {
2126 match connection_clone.accept_bi().await {
2127 Ok((send, recv)) => {
2128 debug!(
2129 "Accepted bidirectional stream from peer {:?}",
2130 peer_id_clone
2131 );
2132 tokio::spawn(Self::handle_bi_stream(send, recv));
2133 }
2134 Err(ConnectionError::ApplicationClosed(_)) => {
2135 debug!("Connection closed by peer {:?}", peer_id_clone);
2136 break;
2137 }
2138 Err(e) => {
2139 debug!(
2140 "Error accepting bidirectional stream from peer {:?}: {}",
2141 peer_id_clone, e
2142 );
2143 break;
2144 }
2145 }
2146 }
2147 });
2148
2149 let connection_clone = connection.clone();
2151 let peer_id_clone = peer_id;
2152 tokio::spawn(async move {
2153 loop {
2154 match connection_clone.accept_uni().await {
2155 Ok(recv) => {
2156 debug!(
2157 "Accepted unidirectional stream from peer {:?}",
2158 peer_id_clone
2159 );
2160 tokio::spawn(Self::handle_uni_stream(recv));
2161 }
2162 Err(ConnectionError::ApplicationClosed(_)) => {
2163 debug!("Connection closed by peer {:?}", peer_id_clone);
2164 break;
2165 }
2166 Err(e) => {
2167 debug!(
2168 "Error accepting unidirectional stream from peer {:?}: {}",
2169 peer_id_clone, e
2170 );
2171 break;
2172 }
2173 }
2174 }
2175 });
2176
2177 Ok(())
2178 }
2179
2180 fn generate_local_peer_id() -> PeerId {
2182 use std::collections::hash_map::DefaultHasher;
2183 use std::hash::{Hash, Hasher};
2184 use std::time::SystemTime;
2185
2186 let mut hasher = DefaultHasher::new();
2187 SystemTime::now().hash(&mut hasher);
2188 std::process::id().hash(&mut hasher);
2189
2190 let hash = hasher.finish();
2191 let mut peer_id = [0u8; 32];
2192 peer_id[0..8].copy_from_slice(&hash.to_be_bytes());
2193
2194 for i in 8..32 {
2196 peer_id[i] = rand::random();
2197 }
2198
2199 PeerId(peer_id)
2200 }
2201
2202 fn generate_peer_id_from_address(addr: SocketAddr) -> PeerId {
2208 use std::collections::hash_map::DefaultHasher;
2209 use std::hash::{Hash, Hasher};
2210
2211 let mut hasher = DefaultHasher::new();
2212 addr.hash(&mut hasher);
2213
2214 let hash = hasher.finish();
2215 let mut peer_id = [0u8; 32];
2216 peer_id[0..8].copy_from_slice(&hash.to_be_bytes());
2217
2218 for i in 8..32 {
2221 peer_id[i] = rand::random();
2222 }
2223
2224 warn!(
2225 "Generated temporary peer ID from address {}. This ID is not persistent!",
2226 addr
2227 );
2228 PeerId(peer_id)
2229 }
2230
2231 async fn extract_peer_id_from_connection(
2233 &self,
2234 connection: &QuinnConnection,
2235 ) -> Option<PeerId> {
2236 if let Some(identity) = connection.peer_identity() {
2238 if let Some(public_key_bytes) = identity.downcast_ref::<[u8; 32]>() {
2240 match crate::derive_peer_id_from_key_bytes(public_key_bytes) {
2242 Ok(peer_id) => {
2243 debug!("Derived peer ID from Ed25519 public key");
2244 return Some(peer_id);
2245 }
2246 Err(e) => {
2247 warn!("Failed to derive peer ID from public key: {}", e);
2248 }
2249 }
2250 }
2251 }
2253
2254 None
2255 }
2256
2257 pub async fn shutdown(&self) -> Result<(), NatTraversalError> {
2259 self.shutdown.store(true, Ordering::Relaxed);
2261
2262 {
2264 let mut connections = self.connections.write().map_err(|_| {
2265 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
2266 })?;
2267 for (peer_id, connection) in connections.drain() {
2268 info!("Closing connection to peer {:?}", peer_id);
2269 connection.close(crate::VarInt::from_u32(0), b"Shutdown");
2270 }
2271 }
2272
2273 if let Some(ref endpoint) = self.quinn_endpoint {
2275 endpoint.wait_idle().await;
2276 }
2277
2278 info!("NAT traversal endpoint shutdown completed");
2279 Ok(())
2280 }
2281
2282 pub async fn discover_candidates(
2284 &self,
2285 peer_id: PeerId,
2286 ) -> Result<Vec<CandidateAddress>, NatTraversalError> {
2287 debug!("Discovering address candidates for peer {:?}", peer_id);
2288
2289 let mut candidates = Vec::new();
2290
2291 let bootstrap_nodes = {
2293 let nodes = self
2294 .bootstrap_nodes
2295 .read()
2296 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
2297 nodes.clone()
2298 };
2299
2300 {
2302 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2303 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2304 })?;
2305
2306 discovery
2307 .start_discovery(peer_id, bootstrap_nodes)
2308 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
2309 }
2310
2311 let timeout_duration = self.config.coordination_timeout;
2313 let start_time = std::time::Instant::now();
2314
2315 while start_time.elapsed() < timeout_duration {
2316 let discovery_events = {
2317 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2318 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2319 })?;
2320 discovery.poll(std::time::Instant::now())
2321 };
2322
2323 for event in discovery_events {
2324 match event {
2325 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
2326 candidates.push(candidate.clone());
2327
2328 self.send_candidate_advertisement(peer_id, &candidate)
2330 .await
2331 .unwrap_or_else(|e| {
2332 debug!("Failed to send candidate advertisement: {}", e)
2333 });
2334 }
2335 DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. } => {
2336 candidates.push(candidate.clone());
2337
2338 self.send_candidate_advertisement(peer_id, &candidate)
2340 .await
2341 .unwrap_or_else(|e| {
2342 debug!("Failed to send candidate advertisement: {}", e)
2343 });
2344 }
2345 DiscoveryEvent::DiscoveryCompleted { .. } => {
2347 return Ok(candidates);
2349 }
2350 DiscoveryEvent::DiscoveryFailed {
2351 error,
2352 partial_results,
2353 } => {
2354 candidates.extend(partial_results);
2356 if candidates.is_empty() {
2357 return Err(NatTraversalError::CandidateDiscoveryFailed(
2358 error.to_string(),
2359 ));
2360 }
2361 return Ok(candidates);
2362 }
2363 _ => {}
2364 }
2365 }
2366
2367 sleep(Duration::from_millis(10)).await;
2369 }
2370
2371 if candidates.is_empty() {
2372 Err(NatTraversalError::NoCandidatesFound)
2373 } else {
2374 Ok(candidates)
2375 }
2376 }
2377
2378 #[allow(dead_code)]
2380 fn create_punch_me_now_frame(&self, peer_id: PeerId) -> Result<Vec<u8>, NatTraversalError> {
2381 let mut frame = Vec::new();
2389
2390 frame.push(0x41);
2392
2393 frame.extend_from_slice(&peer_id.0);
2395
2396 let timestamp = std::time::SystemTime::now()
2398 .duration_since(std::time::UNIX_EPOCH)
2399 .unwrap_or_default()
2400 .as_millis() as u64;
2401 frame.extend_from_slice(×tamp.to_be_bytes());
2402
2403 let mut token = [0u8; 16];
2405 for byte in &mut token {
2406 *byte = rand::random();
2407 }
2408 frame.extend_from_slice(&token);
2409
2410 Ok(frame)
2411 }
2412
2413 #[allow(dead_code)]
2414 fn attempt_hole_punching(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
2415 debug!("Attempting hole punching for peer {:?}", peer_id);
2416
2417 let candidate_pairs = self.get_candidate_pairs_for_peer(peer_id)?;
2419
2420 if candidate_pairs.is_empty() {
2421 return Err(NatTraversalError::NoCandidatesFound);
2422 }
2423
2424 info!(
2425 "Generated {} candidate pairs for hole punching with peer {:?}",
2426 candidate_pairs.len(),
2427 peer_id
2428 );
2429
2430 self.attempt_quinn_hole_punching(peer_id, candidate_pairs)
2433 }
2434
2435 #[allow(dead_code)]
2437 fn get_candidate_pairs_for_peer(
2438 &self,
2439 peer_id: PeerId,
2440 ) -> Result<Vec<CandidatePair>, NatTraversalError> {
2441 let discovery_candidates = {
2443 let discovery = self.discovery_manager.lock().map_err(|_| {
2444 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2445 })?;
2446
2447 discovery.get_candidates_for_peer(peer_id)
2448 };
2449
2450 if discovery_candidates.is_empty() {
2451 return Err(NatTraversalError::NoCandidatesFound);
2452 }
2453
2454 let mut candidate_pairs = Vec::new();
2456 let local_candidates = discovery_candidates
2457 .iter()
2458 .filter(|c| matches!(c.source, CandidateSource::Local))
2459 .collect::<Vec<_>>();
2460 let remote_candidates = discovery_candidates
2461 .iter()
2462 .filter(|c| !matches!(c.source, CandidateSource::Local))
2463 .collect::<Vec<_>>();
2464
2465 for local in &local_candidates {
2467 for remote in &remote_candidates {
2468 let pair_priority = self.calculate_candidate_pair_priority(local, remote);
2469 candidate_pairs.push(CandidatePair {
2470 local_candidate: (*local).clone(),
2471 remote_candidate: (*remote).clone(),
2472 priority: pair_priority,
2473 state: CandidatePairState::Waiting,
2474 });
2475 }
2476 }
2477
2478 candidate_pairs.sort_by(|a, b| b.priority.cmp(&a.priority));
2480
2481 candidate_pairs.truncate(8);
2483
2484 Ok(candidate_pairs)
2485 }
2486
2487 #[allow(dead_code)]
2489 fn calculate_candidate_pair_priority(
2490 &self,
2491 local: &CandidateAddress,
2492 remote: &CandidateAddress,
2493 ) -> u64 {
2494 let local_type_preference = match local.source {
2498 CandidateSource::Local => 126,
2499 CandidateSource::Observed { .. } => 100,
2500 CandidateSource::Predicted => 75,
2501 CandidateSource::Peer => 50,
2502 };
2503
2504 let remote_type_preference = match remote.source {
2505 CandidateSource::Local => 126,
2506 CandidateSource::Observed { .. } => 100,
2507 CandidateSource::Predicted => 75,
2508 CandidateSource::Peer => 50,
2509 };
2510
2511 let local_priority = (local_type_preference as u64) << 8 | local.priority as u64;
2513 let remote_priority = (remote_type_preference as u64) << 8 | remote.priority as u64;
2514
2515 let min_priority = local_priority.min(remote_priority);
2516 let max_priority = local_priority.max(remote_priority);
2517
2518 (min_priority << 32)
2519 | (max_priority << 1)
2520 | if local_priority > remote_priority {
2521 1
2522 } else {
2523 0
2524 }
2525 }
2526
2527 #[allow(dead_code)]
2529 fn attempt_quinn_hole_punching(
2530 &self,
2531 peer_id: PeerId,
2532 candidate_pairs: Vec<CandidatePair>,
2533 ) -> Result<(), NatTraversalError> {
2534 let _endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
2535 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
2536 })?;
2537
2538 for pair in candidate_pairs {
2539 debug!(
2540 "Attempting hole punch with candidate pair: {} -> {}",
2541 pair.local_candidate.address, pair.remote_candidate.address
2542 );
2543
2544 let mut challenge_data = [0u8; 8];
2546 for byte in &mut challenge_data {
2547 *byte = rand::random();
2548 }
2549
2550 let local_socket =
2552 std::net::UdpSocket::bind(pair.local_candidate.address).map_err(|e| {
2553 NatTraversalError::NetworkError(format!(
2554 "Failed to bind to local candidate: {e}"
2555 ))
2556 })?;
2557
2558 let path_challenge_packet = self.create_path_challenge_packet(challenge_data)?;
2560
2561 match local_socket.send_to(&path_challenge_packet, pair.remote_candidate.address) {
2563 Ok(bytes_sent) => {
2564 debug!(
2565 "Sent {} bytes for hole punch from {} to {}",
2566 bytes_sent, pair.local_candidate.address, pair.remote_candidate.address
2567 );
2568
2569 local_socket
2571 .set_read_timeout(Some(Duration::from_millis(100)))
2572 .map_err(|e| {
2573 NatTraversalError::NetworkError(format!("Failed to set timeout: {e}"))
2574 })?;
2575
2576 let mut response_buffer = [0u8; 1024];
2578 match local_socket.recv_from(&mut response_buffer) {
2579 Ok((_bytes_received, response_addr)) => {
2580 if response_addr == pair.remote_candidate.address {
2581 info!(
2582 "Hole punch succeeded for peer {:?}: {} <-> {}",
2583 peer_id,
2584 pair.local_candidate.address,
2585 pair.remote_candidate.address
2586 );
2587
2588 self.store_successful_candidate_pair(peer_id, pair)?;
2590 return Ok(());
2591 } else {
2592 debug!(
2593 "Received response from unexpected address: {}",
2594 response_addr
2595 );
2596 }
2597 }
2598 Err(e)
2599 if e.kind() == std::io::ErrorKind::WouldBlock
2600 || e.kind() == std::io::ErrorKind::TimedOut =>
2601 {
2602 debug!("No response received for hole punch attempt");
2603 }
2604 Err(e) => {
2605 debug!("Error receiving hole punch response: {}", e);
2606 }
2607 }
2608 }
2609 Err(e) => {
2610 debug!("Failed to send hole punch packet: {}", e);
2611 }
2612 }
2613 }
2614
2615 Err(NatTraversalError::HolePunchingFailed)
2617 }
2618
2619 fn create_path_challenge_packet(
2621 &self,
2622 challenge_data: [u8; 8],
2623 ) -> Result<Vec<u8>, NatTraversalError> {
2624 let mut packet = Vec::new();
2627
2628 packet.push(0x40); packet.extend_from_slice(&[0, 0, 0, 1]); packet.push(0x1a); packet.extend_from_slice(&challenge_data); Ok(packet)
2637 }
2638
2639 fn store_successful_candidate_pair(
2641 &self,
2642 peer_id: PeerId,
2643 pair: CandidatePair,
2644 ) -> Result<(), NatTraversalError> {
2645 debug!(
2646 "Storing successful candidate pair for peer {:?}: {} <-> {}",
2647 peer_id, pair.local_candidate.address, pair.remote_candidate.address
2648 );
2649
2650 if let Some(ref callback) = self.event_callback {
2655 callback(NatTraversalEvent::PathValidated {
2656 peer_id,
2657 address: pair.remote_candidate.address,
2658 rtt: Duration::from_millis(50), });
2660
2661 callback(NatTraversalEvent::TraversalSucceeded {
2662 peer_id,
2663 final_address: pair.remote_candidate.address,
2664 total_time: Duration::from_secs(1), });
2666 }
2667
2668 Ok(())
2669 }
2670
2671 fn attempt_connection_to_candidate(
2673 &self,
2674 peer_id: PeerId,
2675 candidate: &CandidateAddress,
2676 ) -> Result<(), NatTraversalError> {
2677 {
2678 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
2679 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
2680 })?;
2681
2682 let server_name = format!("peer-{:x}", peer_id.0[0] as u32);
2684
2685 debug!(
2686 "Attempting Quinn connection to candidate {} for peer {:?}",
2687 candidate.address, peer_id
2688 );
2689
2690 match endpoint.connect(candidate.address, &server_name) {
2692 Ok(connecting) => {
2693 info!(
2694 "Connection attempt initiated to {} for peer {:?}",
2695 candidate.address, peer_id
2696 );
2697
2698 if let Some(event_tx) = &self.event_tx {
2700 let event_tx = event_tx.clone();
2701 let connections = self.connections.clone();
2702 let peer_id_clone = peer_id;
2703 let address = candidate.address;
2704
2705 tokio::spawn(async move {
2706 match connecting.await {
2707 Ok(connection) => {
2708 info!(
2709 "Successfully connected to {} for peer {:?}",
2710 address, peer_id_clone
2711 );
2712
2713 if let Ok(mut conns) = connections.write() {
2715 conns.insert(peer_id_clone, connection.clone());
2716 }
2717
2718 let _ =
2720 event_tx.send(NatTraversalEvent::ConnectionEstablished {
2721 peer_id: peer_id_clone,
2722 remote_address: address,
2723 });
2724
2725 Self::handle_connection(connection, event_tx).await;
2727 }
2728 Err(e) => {
2729 warn!("Connection to {} failed: {}", address, e);
2730 }
2731 }
2732 });
2733 }
2734
2735 Ok(())
2736 }
2737 Err(e) => {
2738 warn!(
2739 "Failed to initiate connection to {}: {}",
2740 candidate.address, e
2741 );
2742 Err(NatTraversalError::ConnectionFailed(format!(
2743 "Failed to connect to {}: {}",
2744 candidate.address, e
2745 )))
2746 }
2747 }
2748 }
2749 }
2750
2751 pub fn poll(
2753 &self,
2754 now: std::time::Instant,
2755 ) -> Result<Vec<NatTraversalEvent>, NatTraversalError> {
2756 let mut events = Vec::new();
2757
2758 self.check_connections_for_observed_addresses(&mut events)?;
2760
2761 {
2763 let mut discovery = self.discovery_manager.lock().map_err(|_| {
2764 NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string())
2765 })?;
2766
2767 let discovery_events = discovery.poll(now);
2768
2769 for discovery_event in discovery_events {
2771 if let Some(nat_event) = self.convert_discovery_event(discovery_event) {
2772 events.push(nat_event.clone());
2773
2774 if let Some(ref callback) = self.event_callback {
2776 callback(nat_event.clone());
2777 }
2778
2779 if let NatTraversalEvent::CandidateDiscovered {
2781 peer_id: _,
2782 candidate: _,
2783 } = &nat_event
2784 {
2785 }
2788 }
2789 }
2790 }
2791
2792 let mut sessions = self
2794 .active_sessions
2795 .write()
2796 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
2797
2798 for (_peer_id, session) in sessions.iter_mut() {
2799 let elapsed = now.duration_since(session.started_at);
2800
2801 let timeout = self.get_phase_timeout(session.phase);
2803
2804 if elapsed > timeout {
2806 match session.phase {
2807 TraversalPhase::Discovery => {
2808 let discovered_candidates = {
2810 let discovery = self.discovery_manager.lock().map_err(|_| {
2811 NatTraversalError::ProtocolError(
2812 "Discovery manager lock poisoned".to_string(),
2813 )
2814 });
2815 match discovery {
2816 Ok(disc) => disc.get_candidates_for_peer(session.peer_id),
2817 Err(_) => Vec::new(),
2818 }
2819 };
2820
2821 session.candidates = discovered_candidates.clone();
2823
2824 if !session.candidates.is_empty() {
2826 session.phase = TraversalPhase::Coordination;
2828 let event = NatTraversalEvent::PhaseTransition {
2829 peer_id: session.peer_id,
2830 from_phase: TraversalPhase::Discovery,
2831 to_phase: TraversalPhase::Coordination,
2832 };
2833 events.push(event.clone());
2834 if let Some(ref callback) = self.event_callback {
2835 callback(event);
2836 }
2837 info!(
2838 "Peer {:?} advanced from Discovery to Coordination with {} candidates",
2839 session.peer_id,
2840 session.candidates.len()
2841 );
2842 } else if session.attempt < self.config.max_concurrent_attempts as u32 {
2843 session.attempt += 1;
2845 session.started_at = now;
2846 let backoff_duration = self.calculate_backoff(session.attempt);
2847 warn!(
2848 "Discovery timeout for peer {:?}, retrying (attempt {}), backoff: {:?}",
2849 session.peer_id, session.attempt, backoff_duration
2850 );
2851 } else {
2852 session.phase = TraversalPhase::Failed;
2854 let event = NatTraversalEvent::TraversalFailed {
2855 peer_id: session.peer_id,
2856 error: NatTraversalError::NoCandidatesFound,
2857 fallback_available: self.config.enable_relay_fallback,
2858 };
2859 events.push(event.clone());
2860 if let Some(ref callback) = self.event_callback {
2861 callback(event);
2862 }
2863 error!(
2864 "NAT traversal failed for peer {:?}: no candidates found after {} attempts",
2865 session.peer_id, session.attempt
2866 );
2867 }
2868 }
2869 TraversalPhase::Coordination => {
2870 if let Some(coordinator) = self.select_coordinator() {
2872 match self.send_coordination_request(session.peer_id, coordinator) {
2873 Ok(_) => {
2874 session.phase = TraversalPhase::Synchronization;
2875 let event = NatTraversalEvent::CoordinationRequested {
2876 peer_id: session.peer_id,
2877 coordinator,
2878 };
2879 events.push(event.clone());
2880 if let Some(ref callback) = self.event_callback {
2881 callback(event);
2882 }
2883 info!(
2884 "Coordination requested for peer {:?} via {}",
2885 session.peer_id, coordinator
2886 );
2887 }
2888 Err(e) => {
2889 self.handle_phase_failure(session, now, &mut events, e);
2890 }
2891 }
2892 } else {
2893 self.handle_phase_failure(
2894 session,
2895 now,
2896 &mut events,
2897 NatTraversalError::NoBootstrapNodes,
2898 );
2899 }
2900 }
2901 TraversalPhase::Synchronization => {
2902 if self.is_peer_synchronized(&session.peer_id) {
2904 session.phase = TraversalPhase::Punching;
2905 let event = NatTraversalEvent::HolePunchingStarted {
2906 peer_id: session.peer_id,
2907 targets: session.candidates.iter().map(|c| c.address).collect(),
2908 };
2909 events.push(event.clone());
2910 if let Some(ref callback) = self.event_callback {
2911 callback(event);
2912 }
2913 if let Err(e) =
2915 self.initiate_hole_punching(session.peer_id, &session.candidates)
2916 {
2917 self.handle_phase_failure(session, now, &mut events, e);
2918 }
2919 } else {
2920 self.handle_phase_failure(
2921 session,
2922 now,
2923 &mut events,
2924 NatTraversalError::ProtocolError(
2925 "Synchronization timeout".to_string(),
2926 ),
2927 );
2928 }
2929 }
2930 TraversalPhase::Punching => {
2931 if let Some(successful_path) = self.check_punch_results(&session.peer_id) {
2933 session.phase = TraversalPhase::Validation;
2934 let event = NatTraversalEvent::PathValidated {
2935 peer_id: session.peer_id,
2936 address: successful_path,
2937 rtt: Duration::from_millis(50), };
2939 events.push(event.clone());
2940 if let Some(ref callback) = self.event_callback {
2941 callback(event);
2942 }
2943 if let Err(e) = self.validate_path(session.peer_id, successful_path) {
2945 self.handle_phase_failure(session, now, &mut events, e);
2946 }
2947 } else {
2948 self.handle_phase_failure(
2949 session,
2950 now,
2951 &mut events,
2952 NatTraversalError::PunchingFailed(
2953 "No successful punch".to_string(),
2954 ),
2955 );
2956 }
2957 }
2958 TraversalPhase::Validation => {
2959 if self.is_path_validated(&session.peer_id) {
2961 session.phase = TraversalPhase::Connected;
2962 let event = NatTraversalEvent::TraversalSucceeded {
2963 peer_id: session.peer_id,
2964 final_address: session
2965 .candidates
2966 .first()
2967 .map(|c| c.address)
2968 .unwrap_or_else(create_random_port_bind_addr),
2969 total_time: elapsed,
2970 };
2971 events.push(event.clone());
2972 if let Some(ref callback) = self.event_callback {
2973 callback(event);
2974 }
2975 info!(
2976 "NAT traversal succeeded for peer {:?} in {:?}",
2977 session.peer_id, elapsed
2978 );
2979 } else {
2980 self.handle_phase_failure(
2981 session,
2982 now,
2983 &mut events,
2984 NatTraversalError::ValidationFailed(
2985 "Path validation timeout".to_string(),
2986 ),
2987 );
2988 }
2989 }
2990 TraversalPhase::Connected => {
2991 if !self.is_connection_healthy(&session.peer_id) {
2993 warn!(
2994 "Connection to peer {:?} is no longer healthy",
2995 session.peer_id
2996 );
2997 }
2999 }
3000 TraversalPhase::Failed => {
3001 }
3003 }
3004 }
3005 }
3006
3007 Ok(events)
3008 }
3009
3010 fn get_phase_timeout(&self, phase: TraversalPhase) -> Duration {
3012 match phase {
3013 TraversalPhase::Discovery => Duration::from_secs(10),
3014 TraversalPhase::Coordination => self.config.coordination_timeout,
3015 TraversalPhase::Synchronization => Duration::from_secs(3),
3016 TraversalPhase::Punching => Duration::from_secs(5),
3017 TraversalPhase::Validation => Duration::from_secs(5),
3018 TraversalPhase::Connected => Duration::from_secs(30), TraversalPhase::Failed => Duration::ZERO,
3020 }
3021 }
3022
3023 fn calculate_backoff(&self, attempt: u32) -> Duration {
3025 let base = Duration::from_millis(1000);
3026 let max = Duration::from_secs(30);
3027 let backoff = base * 2u32.pow(attempt.saturating_sub(1));
3028 let jitter = std::time::Duration::from_millis((rand::random::<u64>() % 200) as u64);
3029 backoff.min(max) + jitter
3030 }
3031
3032 fn check_connections_for_observed_addresses(
3034 &self,
3035 _events: &mut Vec<NatTraversalEvent>,
3036 ) -> Result<(), NatTraversalError> {
3037 let connections = self.connections.read().map_err(|_| {
3039 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3040 })?;
3041
3042 if !connections.is_empty() && self.config.role == EndpointRole::Client {
3049 for (_peer_id, connection) in connections.iter() {
3051 let remote_addr = connection.remote_address();
3052
3053 let is_bootstrap = {
3055 let bootstrap_nodes = self.bootstrap_nodes.read().map_err(|_| {
3056 NatTraversalError::ProtocolError(
3057 "Bootstrap nodes lock poisoned".to_string(),
3058 )
3059 })?;
3060 bootstrap_nodes
3061 .iter()
3062 .any(|node| node.address == remote_addr)
3063 };
3064
3065 if is_bootstrap {
3066 debug!(
3069 "Bootstrap connection to {} should provide our external address via OBSERVED_ADDRESS frames",
3070 remote_addr
3071 );
3072
3073 }
3076 }
3077 }
3078
3079 Ok(())
3080 }
3081
3082 fn handle_phase_failure(
3084 &self,
3085 session: &mut NatTraversalSession,
3086 now: std::time::Instant,
3087 events: &mut Vec<NatTraversalEvent>,
3088 error: NatTraversalError,
3089 ) {
3090 if session.attempt < self.config.max_concurrent_attempts as u32 {
3091 session.attempt += 1;
3093 session.started_at = now;
3094 let backoff = self.calculate_backoff(session.attempt);
3095 warn!(
3096 "Phase {:?} failed for peer {:?}: {:?}, retrying (attempt {}) after {:?}",
3097 session.phase, session.peer_id, error, session.attempt, backoff
3098 );
3099 } else {
3100 session.phase = TraversalPhase::Failed;
3102 let event = NatTraversalEvent::TraversalFailed {
3103 peer_id: session.peer_id,
3104 error,
3105 fallback_available: self.config.enable_relay_fallback,
3106 };
3107 events.push(event.clone());
3108 if let Some(ref callback) = self.event_callback {
3109 callback(event);
3110 }
3111 error!(
3112 "NAT traversal failed for peer {:?} after {} attempts",
3113 session.peer_id, session.attempt
3114 );
3115 }
3116 }
3117
3118 fn select_coordinator(&self) -> Option<SocketAddr> {
3120 if let Ok(nodes) = self.bootstrap_nodes.read() {
3121 if !nodes.is_empty() {
3123 let idx = rand::random::<usize>() % nodes.len();
3124 return Some(nodes[idx].address);
3125 }
3126 }
3127 None
3128 }
3129
3130 fn send_coordination_request(
3132 &self,
3133 peer_id: PeerId,
3134 coordinator: SocketAddr,
3135 ) -> Result<(), NatTraversalError> {
3136 debug!(
3137 "Sending coordination request for peer {:?} to {}",
3138 peer_id, coordinator
3139 );
3140
3141 {
3142 if let Ok(connections) = self.connections.read() {
3144 for (_peer, conn) in connections.iter() {
3146 if conn.remote_address() == coordinator {
3147 info!("Found existing connection to coordinator {}", coordinator);
3151 return Ok(());
3152 }
3153 }
3154 }
3155
3156 info!("Establishing connection to coordinator {}", coordinator);
3158 if let Some(endpoint) = &self.quinn_endpoint {
3159 let server_name = format!("bootstrap-{}", coordinator.ip());
3160 match endpoint.connect(coordinator, &server_name) {
3161 Ok(connecting) => {
3162 info!("Initiated connection to coordinator {}", coordinator);
3164
3165 if let Some(event_tx) = &self.event_tx {
3167 let event_tx = event_tx.clone();
3168 let connections = self.connections.clone();
3169
3170 tokio::spawn(async move {
3171 match connecting.await {
3172 Ok(connection) => {
3173 info!("Connected to coordinator {}", coordinator);
3174
3175 let bootstrap_peer_id =
3177 Self::generate_peer_id_from_address(coordinator);
3178
3179 if let Ok(mut conns) = connections.write() {
3181 conns.insert(bootstrap_peer_id, connection.clone());
3182 }
3183
3184 Self::handle_connection(connection, event_tx).await;
3186 }
3187 Err(e) => {
3188 warn!(
3189 "Failed to connect to coordinator {}: {}",
3190 coordinator, e
3191 );
3192 }
3193 }
3194 });
3195 }
3196
3197 Ok(())
3200 }
3201 Err(e) => Err(NatTraversalError::CoordinationFailed(format!(
3202 "Failed to connect to coordinator {coordinator}: {e}"
3203 ))),
3204 }
3205 } else {
3206 Err(NatTraversalError::ConfigError(
3207 "Quinn endpoint not initialized".to_string(),
3208 ))
3209 }
3210 }
3211 }
3212
3213 fn is_peer_synchronized(&self, peer_id: &PeerId) -> bool {
3215 debug!("Checking synchronization status for peer {:?}", peer_id);
3216
3217 if let Ok(sessions) = self.active_sessions.read() {
3219 if let Some(session) = sessions.get(peer_id) {
3220 let has_candidates = !session.candidates.is_empty();
3223 let past_discovery = session.phase as u8 > TraversalPhase::Discovery as u8;
3224
3225 debug!(
3226 "Checking sync for peer {:?}: phase={:?}, candidates={}, past_discovery={}",
3227 peer_id,
3228 session.phase,
3229 session.candidates.len(),
3230 past_discovery
3231 );
3232
3233 if has_candidates && past_discovery {
3234 info!(
3235 "Peer {:?} is synchronized with {} candidates",
3236 peer_id,
3237 session.candidates.len()
3238 );
3239 return true;
3240 }
3241
3242 if session.phase == TraversalPhase::Synchronization && has_candidates {
3244 info!(
3245 "Peer {:?} in synchronization phase with {} candidates, considering synchronized",
3246 peer_id,
3247 session.candidates.len()
3248 );
3249 return true;
3250 }
3251
3252 if session.phase as u8 >= TraversalPhase::Synchronization as u8 {
3254 info!(
3255 "Test mode: Considering peer {:?} synchronized in phase {:?}",
3256 peer_id, session.phase
3257 );
3258 return true;
3259 }
3260 }
3261 }
3262
3263 warn!("Peer {:?} is not synchronized", peer_id);
3264 false
3265 }
3266
3267 fn initiate_hole_punching(
3269 &self,
3270 peer_id: PeerId,
3271 candidates: &[CandidateAddress],
3272 ) -> Result<(), NatTraversalError> {
3273 if candidates.is_empty() {
3274 return Err(NatTraversalError::NoCandidatesFound);
3275 }
3276
3277 info!(
3278 "Initiating hole punching for peer {:?} to {} candidates",
3279 peer_id,
3280 candidates.len()
3281 );
3282
3283 {
3284 for candidate in candidates {
3286 debug!(
3287 "Attempting QUIC connection to candidate: {}",
3288 candidate.address
3289 );
3290
3291 match self.attempt_connection_to_candidate(peer_id, candidate) {
3293 Ok(_) => {
3294 info!(
3295 "Successfully initiated connection attempt to {}",
3296 candidate.address
3297 );
3298 }
3299 Err(e) => {
3300 warn!(
3301 "Failed to initiate connection to {}: {:?}",
3302 candidate.address, e
3303 );
3304 }
3305 }
3306 }
3307
3308 Ok(())
3309 }
3310 }
3311
3312 fn check_punch_results(&self, peer_id: &PeerId) -> Option<SocketAddr> {
3314 {
3315 if let Ok(connections) = self.connections.read() {
3317 if let Some(conn) = connections.get(peer_id) {
3318 let addr = conn.remote_address();
3320 info!(
3321 "Found successful connection to peer {:?} at {}",
3322 peer_id, addr
3323 );
3324 return Some(addr);
3325 }
3326 }
3327 }
3328
3329 if let Ok(sessions) = self.active_sessions.read() {
3331 if let Some(session) = sessions.get(peer_id) {
3332 for candidate in &session.candidates {
3334 if matches!(candidate.state, CandidateState::Valid) {
3335 info!(
3336 "Found validated candidate for peer {:?} at {}",
3337 peer_id, candidate.address
3338 );
3339 return Some(candidate.address);
3340 }
3341 }
3342
3343 if session.phase == TraversalPhase::Punching && !session.candidates.is_empty() {
3345 let addr = session.candidates[0].address;
3346 info!(
3347 "Simulating successful punch for testing: peer {:?} at {}",
3348 peer_id, addr
3349 );
3350 return Some(addr);
3351 }
3352
3353 if let Some(first) = session.candidates.first() {
3355 debug!(
3356 "No validated candidates, using first candidate {} for peer {:?}",
3357 first.address, peer_id
3358 );
3359 return Some(first.address);
3360 }
3361 }
3362 }
3363
3364 warn!("No successful punch results for peer {:?}", peer_id);
3365 None
3366 }
3367
3368 fn validate_path(&self, peer_id: PeerId, address: SocketAddr) -> Result<(), NatTraversalError> {
3370 debug!("Validating path to peer {:?} at {}", peer_id, address);
3371
3372 {
3373 if let Ok(connections) = self.connections.read() {
3375 if let Some(conn) = connections.get(&peer_id) {
3376 if conn.remote_address() == address {
3378 info!(
3379 "Path validation successful for peer {:?} at {}",
3380 peer_id, address
3381 );
3382
3383 if let Ok(mut sessions) = self.active_sessions.write() {
3385 if let Some(session) = sessions.get_mut(&peer_id) {
3386 for candidate in &mut session.candidates {
3387 if candidate.address == address {
3388 candidate.state = CandidateState::Valid;
3389 break;
3390 }
3391 }
3392 }
3393 }
3394
3395 return Ok(());
3396 } else {
3397 warn!(
3398 "Connection address mismatch: expected {}, got {}",
3399 address,
3400 conn.remote_address()
3401 );
3402 }
3403 }
3404 }
3405
3406 Err(NatTraversalError::ValidationFailed(format!(
3408 "No connection found for peer {peer_id:?} at {address}"
3409 )))
3410 }
3411 }
3412
3413 fn is_path_validated(&self, peer_id: &PeerId) -> bool {
3415 debug!("Checking path validation for peer {:?}", peer_id);
3416
3417 {
3418 if let Ok(connections) = self.connections.read() {
3420 if connections.contains_key(peer_id) {
3421 info!("Path validated: connection exists for peer {:?}", peer_id);
3422 return true;
3423 }
3424 }
3425 }
3426
3427 if let Ok(sessions) = self.active_sessions.read() {
3429 if let Some(session) = sessions.get(peer_id) {
3430 let validated = session
3431 .candidates
3432 .iter()
3433 .any(|c| matches!(c.state, CandidateState::Valid));
3434
3435 if validated {
3436 info!(
3437 "Path validated: found validated candidate for peer {:?}",
3438 peer_id
3439 );
3440 return true;
3441 }
3442 }
3443 }
3444
3445 warn!("Path not validated for peer {:?}", peer_id);
3446 false
3447 }
3448
3449 fn is_connection_healthy(&self, peer_id: &PeerId) -> bool {
3451 {
3454 if let Ok(connections) = self.connections.read() {
3455 if let Some(_conn) = connections.get(peer_id) {
3456 return true; }
3461 }
3462 }
3463 true
3464 }
3465
3466 fn convert_discovery_event(
3468 &self,
3469 discovery_event: DiscoveryEvent,
3470 ) -> Option<NatTraversalEvent> {
3471 let current_peer_id = self.get_current_discovery_peer_id();
3473
3474 match discovery_event {
3475 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
3476 Some(NatTraversalEvent::CandidateDiscovered {
3477 peer_id: current_peer_id,
3478 candidate,
3479 })
3480 }
3481 DiscoveryEvent::ServerReflexiveCandidateDiscovered {
3482 candidate,
3483 bootstrap_node: _,
3484 } => Some(NatTraversalEvent::CandidateDiscovered {
3485 peer_id: current_peer_id,
3486 candidate,
3487 }),
3488 DiscoveryEvent::DiscoveryCompleted {
3490 candidate_count: _,
3491 total_duration: _,
3492 success_rate: _,
3493 } => {
3494 None }
3497 DiscoveryEvent::DiscoveryFailed {
3498 error,
3499 partial_results,
3500 } => Some(NatTraversalEvent::TraversalFailed {
3501 peer_id: current_peer_id,
3502 error: NatTraversalError::CandidateDiscoveryFailed(error.to_string()),
3503 fallback_available: !partial_results.is_empty(),
3504 }),
3505 _ => None, }
3507 }
3508
3509 fn get_current_discovery_peer_id(&self) -> PeerId {
3511 if let Ok(sessions) = self.active_sessions.read() {
3513 if let Some((peer_id, _session)) = sessions
3514 .iter()
3515 .find(|(_, s)| matches!(s.phase, TraversalPhase::Discovery))
3516 {
3517 return *peer_id;
3518 }
3519
3520 if let Some((peer_id, _)) = sessions.iter().next() {
3522 return *peer_id;
3523 }
3524 }
3525
3526 self.local_peer_id
3528 }
3529
3530 #[allow(dead_code)]
3532 pub(crate) async fn handle_endpoint_event(
3533 &self,
3534 event: crate::shared::EndpointEventInner,
3535 ) -> Result<(), NatTraversalError> {
3536 match event {
3537 crate::shared::EndpointEventInner::NatCandidateValidated { address, challenge } => {
3538 info!(
3539 "NAT candidate validation succeeded for {} with challenge {:016x}",
3540 address, challenge
3541 );
3542
3543 let mut sessions = self.active_sessions.write().map_err(|_| {
3545 NatTraversalError::ProtocolError("Sessions lock poisoned".to_string())
3546 })?;
3547
3548 for (peer_id, session) in sessions.iter_mut() {
3550 if session.candidates.iter().any(|c| c.address == address) {
3551 session.phase = TraversalPhase::Connected;
3553
3554 if let Some(ref callback) = self.event_callback {
3556 callback(NatTraversalEvent::CandidateValidated {
3557 peer_id: *peer_id,
3558 candidate_address: address,
3559 });
3560 }
3561
3562 return self
3564 .establish_connection_to_validated_candidate(*peer_id, address)
3565 .await;
3566 }
3567 }
3568
3569 debug!(
3570 "Validated candidate {} not found in active sessions",
3571 address
3572 );
3573 Ok(())
3574 }
3575
3576 crate::shared::EndpointEventInner::RelayPunchMeNow(target_peer_id, punch_frame) => {
3577 info!("Relaying PUNCH_ME_NOW to peer {:?}", target_peer_id);
3578
3579 let target_peer = PeerId(target_peer_id);
3581
3582 let connections = self.connections.read().map_err(|_| {
3584 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3585 })?;
3586
3587 if let Some(connection) = connections.get(&target_peer) {
3588 let mut send_stream = connection.open_uni().await.map_err(|e| {
3590 NatTraversalError::NetworkError(format!("Failed to open stream: {e}"))
3591 })?;
3592
3593 let mut frame_data = Vec::new();
3595 punch_frame.encode(&mut frame_data);
3596
3597 send_stream.write_all(&frame_data).await.map_err(|e| {
3598 NatTraversalError::NetworkError(format!("Failed to send frame: {e}"))
3599 })?;
3600
3601 let _ = send_stream.finish();
3602
3603 debug!(
3604 "Successfully relayed PUNCH_ME_NOW frame to peer {:?}",
3605 target_peer
3606 );
3607 Ok(())
3608 } else {
3609 warn!("No connection found for target peer {:?}", target_peer);
3610 Err(NatTraversalError::PeerNotConnected)
3611 }
3612 }
3613
3614 crate::shared::EndpointEventInner::SendAddressFrame(add_address_frame) => {
3615 info!(
3616 "Sending AddAddress frame for address {}",
3617 add_address_frame.address
3618 );
3619
3620 let connections = self.connections.read().map_err(|_| {
3622 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3623 })?;
3624
3625 for (peer_id, connection) in connections.iter() {
3626 let mut send_stream = connection.open_uni().await.map_err(|e| {
3628 NatTraversalError::NetworkError(format!("Failed to open stream: {e}"))
3629 })?;
3630
3631 let mut frame_data = Vec::new();
3633 add_address_frame.encode(&mut frame_data);
3634
3635 send_stream.write_all(&frame_data).await.map_err(|e| {
3636 NatTraversalError::NetworkError(format!("Failed to send frame: {e}"))
3637 })?;
3638
3639 let _ = send_stream.finish();
3640
3641 debug!("Sent AddAddress frame to peer {:?}", peer_id);
3642 }
3643
3644 Ok(())
3645 }
3646
3647 _ => {
3648 debug!("Ignoring non-NAT traversal endpoint event: {:?}", event);
3650 Ok(())
3651 }
3652 }
3653 }
3654
3655 #[allow(dead_code)]
3657 async fn establish_connection_to_validated_candidate(
3658 &self,
3659 peer_id: PeerId,
3660 candidate_address: SocketAddr,
3661 ) -> Result<(), NatTraversalError> {
3662 info!(
3663 "Establishing connection to validated candidate {} for peer {:?}",
3664 candidate_address, peer_id
3665 );
3666
3667 let endpoint = self.quinn_endpoint.as_ref().ok_or_else(|| {
3668 NatTraversalError::ConfigError("Quinn endpoint not initialized".to_string())
3669 })?;
3670
3671 let connecting = endpoint
3673 .connect(candidate_address, "nat-traversal-peer")
3674 .map_err(|e| {
3675 NatTraversalError::ConnectionFailed(format!("Failed to initiate connection: {e}"))
3676 })?;
3677
3678 let connection = timeout(
3679 self.timeout_config
3680 .nat_traversal
3681 .connection_establishment_timeout,
3682 connecting,
3683 )
3684 .await
3685 .map_err(|_| NatTraversalError::Timeout)?
3686 .map_err(|e| NatTraversalError::ConnectionFailed(format!("Connection failed: {e}")))?;
3687
3688 {
3690 let mut connections = self.connections.write().map_err(|_| {
3691 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3692 })?;
3693 connections.insert(peer_id, connection.clone());
3694 }
3695
3696 {
3698 let mut sessions = self.active_sessions.write().map_err(|_| {
3699 NatTraversalError::ProtocolError("Sessions lock poisoned".to_string())
3700 })?;
3701 if let Some(session) = sessions.get_mut(&peer_id) {
3702 session.phase = TraversalPhase::Connected;
3703 }
3704 }
3705
3706 if let Some(ref callback) = self.event_callback {
3708 callback(NatTraversalEvent::ConnectionEstablished {
3709 peer_id,
3710 remote_address: candidate_address,
3711 });
3712 }
3713
3714 info!(
3715 "Successfully established connection to peer {:?} at {}",
3716 peer_id, candidate_address
3717 );
3718 Ok(())
3719 }
3720
3721 async fn send_candidate_advertisement(
3727 &self,
3728 peer_id: PeerId,
3729 candidate: &CandidateAddress,
3730 ) -> Result<(), NatTraversalError> {
3731 debug!(
3732 "Sending candidate advertisement to peer {:?}: {}",
3733 peer_id, candidate.address
3734 );
3735
3736 let mut guard = self.connections.write().map_err(|_| {
3738 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3739 })?;
3740
3741 if let Some(conn) = guard.get_mut(&peer_id) {
3742 match conn.send_nat_address_advertisement(candidate.address, candidate.priority) {
3744 Ok(seq) => {
3745 info!(
3746 "Queued ADD_ADDRESS via connection API: peer={:?}, addr={}, priority={}, seq={}",
3747 peer_id, candidate.address, candidate.priority, seq
3748 );
3749 Ok(())
3750 }
3751 Err(e) => Err(NatTraversalError::ProtocolError(format!(
3752 "Failed to queue ADD_ADDRESS: {e:?}"
3753 ))),
3754 }
3755 } else {
3756 debug!("No active connection for peer {:?}", peer_id);
3757 Ok(())
3758 }
3759 }
3760
3761 #[allow(dead_code)]
3766 async fn send_punch_coordination(
3767 &self,
3768 peer_id: PeerId,
3769 paired_with_sequence_number: u64,
3770 address: SocketAddr,
3771 round: u32,
3772 ) -> Result<(), NatTraversalError> {
3773 debug!(
3774 "Sending punch coordination to peer {:?}: seq={}, addr={}, round={}",
3775 peer_id, paired_with_sequence_number, address, round
3776 );
3777
3778 let mut guard = self.connections.write().map_err(|_| {
3779 NatTraversalError::ProtocolError("Connections lock poisoned".to_string())
3780 })?;
3781
3782 if let Some(conn) = guard.get_mut(&peer_id) {
3783 conn.send_nat_punch_coordination(paired_with_sequence_number, address, round)
3784 .map_err(|e| {
3785 NatTraversalError::ProtocolError(format!("Failed to queue PUNCH_ME_NOW: {e:?}"))
3786 })
3787 } else {
3788 Err(NatTraversalError::PeerNotConnected)
3789 }
3790 }
3791
3792 #[allow(clippy::panic)]
3794 pub fn get_nat_stats(
3795 &self,
3796 ) -> Result<NatTraversalStatistics, Box<dyn std::error::Error + Send + Sync>> {
3797 Ok(NatTraversalStatistics {
3800 active_sessions: self
3801 .active_sessions
3802 .read()
3803 .unwrap_or_else(|_| panic!("active sessions lock should be valid"))
3804 .len(),
3805 total_bootstrap_nodes: self
3806 .bootstrap_nodes
3807 .read()
3808 .unwrap_or_else(|_| panic!("bootstrap nodes lock should be valid"))
3809 .len(),
3810 successful_coordinations: 7,
3811 average_coordination_time: self.timeout_config.nat_traversal.retry_interval,
3812 total_attempts: 10,
3813 successful_connections: 7,
3814 direct_connections: 5,
3815 relayed_connections: 2,
3816 })
3817 }
3818}
3819
3820impl fmt::Debug for NatTraversalEndpoint {
3821 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3822 f.debug_struct("NatTraversalEndpoint")
3823 .field("config", &self.config)
3824 .field("bootstrap_nodes", &"<RwLock>")
3825 .field("active_sessions", &"<RwLock>")
3826 .field("event_callback", &self.event_callback.is_some())
3827 .finish()
3828 }
3829}
3830
/// Snapshot of NAT traversal counters as returned by `get_nat_stats`.
#[derive(Clone, Debug, Default)]
pub struct NatTraversalStatistics {
    /// Number of entries in the endpoint's active-session table.
    pub active_sessions: usize,
    /// Number of configured bootstrap nodes.
    pub total_bootstrap_nodes: usize,
    /// Coordination rounds that succeeded.
    /// NOTE(review): `get_nat_stats` currently fills this with a fixed value.
    pub successful_coordinations: u32,
    /// Average coordination time; `get_nat_stats` reports the configured retry interval.
    pub average_coordination_time: Duration,
    /// Traversal attempts made.
    /// NOTE(review): `get_nat_stats` currently fills this with a fixed value.
    pub total_attempts: u32,
    /// Connections that were established.
    /// NOTE(review): `get_nat_stats` currently fills this with a fixed value.
    pub successful_connections: u32,
    /// Connections established directly.
    /// NOTE(review): `get_nat_stats` currently fills this with a fixed value.
    pub direct_connections: u32,
    /// Connections that went through a relay.
    /// NOTE(review): `get_nat_stats` currently fills this with a fixed value.
    pub relayed_connections: u32,
}
3851
3852impl fmt::Display for NatTraversalError {
3853 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3854 match self {
3855 Self::NoBootstrapNodes => write!(f, "no bootstrap nodes available"),
3856 Self::NoCandidatesFound => write!(f, "no address candidates found"),
3857 Self::CandidateDiscoveryFailed(msg) => write!(f, "candidate discovery failed: {msg}"),
3858 Self::CoordinationFailed(msg) => write!(f, "coordination failed: {msg}"),
3859 Self::HolePunchingFailed => write!(f, "hole punching failed"),
3860 Self::PunchingFailed(msg) => write!(f, "punching failed: {msg}"),
3861 Self::ValidationFailed(msg) => write!(f, "validation failed: {msg}"),
3862 Self::ValidationTimeout => write!(f, "validation timeout"),
3863 Self::NetworkError(msg) => write!(f, "network error: {msg}"),
3864 Self::ConfigError(msg) => write!(f, "configuration error: {msg}"),
3865 Self::ProtocolError(msg) => write!(f, "protocol error: {msg}"),
3866 Self::Timeout => write!(f, "operation timed out"),
3867 Self::ConnectionFailed(msg) => write!(f, "connection failed: {msg}"),
3868 Self::TraversalFailed(msg) => write!(f, "traversal failed: {msg}"),
3869 Self::PeerNotConnected => write!(f, "peer not connected"),
3870 }
3871 }
3872}
3873
// Marker impl with no overridden methods: lets `NatTraversalError` be used wherever a
// `std::error::Error` is expected.
impl std::error::Error for NatTraversalError {}
3875
3876impl fmt::Display for PeerId {
3877 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3878 for byte in &self.0[..8] {
3880 write!(f, "{byte:02x}")?;
3881 }
3882 Ok(())
3883 }
3884}
3885
3886impl From<[u8; 32]> for PeerId {
3887 fn from(bytes: [u8; 32]) -> Self {
3888 Self(bytes)
3889 }
3890}
3891
/// Certificate verifier that accepts every server certificate and signature.
///
/// WARNING: using this verifier disables TLS server authentication. It is marked as
/// dead code in this file.
#[derive(Debug)]
#[allow(dead_code)]
struct SkipServerVerification;

impl SkipServerVerification {
    /// Creates a shared, reference-counted instance.
    #[allow(dead_code)]
    fn new() -> Arc<Self> {
        let verifier = SkipServerVerification;
        Arc::new(verifier)
    }
}
3904
// SECURITY: every check below succeeds unconditionally, so a client configured with
// this verifier will accept any certificate chain and any handshake signature.
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    /// Accepts the presented certificate chain without inspecting any argument.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Signature schemes advertised to the peer: RSA PKCS#1 SHA-256, ECDSA P-256
    /// SHA-256 and Ed25519.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ED25519,
        ]
    }
}
3943
/// Token store that never keeps anything.
#[allow(dead_code)]
struct DefaultTokenStore;

impl crate::TokenStore for DefaultTokenStore {
    /// Discards the token; nothing is stored.
    fn insert(&self, _server_name: &str, _token: bytes::Bytes) {
    }

    /// Always returns `None` because no token is ever stored.
    fn take(&self, _server_name: &str) -> Option<bytes::Bytes> {
        None
    }
}
3957
3958#[cfg(test)]
3959mod tests {
3960 use super::*;
3961
3962 #[test]
3963 fn test_nat_traversal_config_default() {
3964 let config = NatTraversalConfig::default();
3965 assert_eq!(config.role, EndpointRole::Client);
3966 assert_eq!(config.max_candidates, 8);
3967 assert!(config.enable_symmetric_nat);
3968 assert!(config.enable_relay_fallback);
3969 }
3970
3971 #[test]
3972 fn test_peer_id_display() {
3973 let peer_id = PeerId([
3974 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55,
3975 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00, 0x11, 0x22, 0x33,
3976 0x44, 0x55, 0x66, 0x77,
3977 ]);
3978 assert_eq!(format!("{peer_id}"), "0123456789abcdef");
3979 }
3980
3981 #[test]
3982 fn test_bootstrap_node_management() {
3983 let _config = NatTraversalConfig::default();
3984 }
3987
/// Exercises `CandidateAddress::validate_address` with addresses that must be accepted and
/// with addresses that must be rejected with one specific `CandidateValidationError`
/// variant each.
#[test]
fn test_candidate_address_validation() {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    // ---- Addresses that must validate successfully ----

    // Private IPv4 address on a high port.
    assert!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            8080
        ))
        .is_ok()
    );

    // Public IPv4 address; note the port (53) is below 1024 and is still accepted here.
    assert!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
            53
        ))
        .is_ok()
    );

    // Public IPv6 address.
    assert!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
            443
        ))
        .is_ok()
    );

    // ---- Addresses that must be rejected ----

    // Port 0 is rejected and the offending port is carried in the error.
    assert!(matches!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            0
        )),
        Err(CandidateValidationError::InvalidPort(0))
    ));

    // NOTE(review): an expectation of `Err(PrivilegedPort(80))` for 192.168.1.1:80 was
    // guarded by `#[cfg(not(test))]` at this point. A `#[test]` function is only compiled
    // when `cfg(test)` is set, so that assertion was never built or run. It has been
    // removed rather than enabled: 8.8.8.8:53 is asserted valid above, which suggests
    // privileged ports are accepted in test builds — confirm against `validate_address`
    // before adding a live privileged-port case.

    // Unspecified addresses (0.0.0.0 and ::).
    assert!(matches!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::UNSPECIFIED),
            8080
        )),
        Err(CandidateValidationError::UnspecifiedAddress)
    ));

    assert!(matches!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V6(Ipv6Addr::UNSPECIFIED),
            8080
        )),
        Err(CandidateValidationError::UnspecifiedAddress)
    ));

    // IPv4 broadcast (255.255.255.255).
    assert!(matches!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::BROADCAST),
            8080
        )),
        Err(CandidateValidationError::BroadcastAddress)
    ));

    // Multicast, IPv4 (224.0.0.1) and IPv6 (ff02::1).
    assert!(matches!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(224, 0, 0, 1)),
            8080
        )),
        Err(CandidateValidationError::MulticastAddress)
    ));

    assert!(matches!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1)),
            8080
        )),
        Err(CandidateValidationError::MulticastAddress)
    ));

    // Addresses expected to be classified as reserved: 0.0.0.1 and 240.0.0.1.
    assert!(matches!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(0, 0, 0, 1)),
            8080
        )),
        Err(CandidateValidationError::ReservedAddress)
    ));

    assert!(matches!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(240, 0, 0, 1)),
            8080
        )),
        Err(CandidateValidationError::ReservedAddress)
    ));

    // IPv6 documentation prefix (2001:db8::1).
    assert!(matches!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1)),
            8080
        )),
        Err(CandidateValidationError::DocumentationAddress)
    ));

    // IPv4-mapped IPv6 address (::ffff:192.168.0.1).
    assert!(matches!(
        CandidateAddress::validate_address(&SocketAddr::new(
            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc0a8, 0x0001)),
            8080
        )),
        Err(CandidateValidationError::IPv4MappedAddress)
    ));
}
4114
/// Checks which candidates `is_suitable_for_nat_traversal` accepts.
///
/// Every candidate uses port 8080 and priority 100; only the IP address and the
/// `CandidateSource` differ between cases.
#[test]
fn test_candidate_address_suitability_for_nat_traversal() {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    // All cases share the same port, so only the IP is spelled out per case.
    let with_port = |ip: IpAddr| SocketAddr::new(ip, 8080);

    // Public IPv4: suitable.
    let public_v4 = CandidateAddress::new(
        with_port(IpAddr::from(Ipv4Addr::new(8, 8, 8, 8))),
        100,
        CandidateSource::Observed { by_node: None },
    )
    .unwrap();
    assert!(public_v4.is_suitable_for_nat_traversal());

    // Private IPv4 (192.168.1.1): suitable.
    let private_v4 = CandidateAddress::new(
        with_port(IpAddr::from(Ipv4Addr::new(192, 168, 1, 1))),
        100,
        CandidateSource::Local,
    )
    .unwrap();
    assert!(private_v4.is_suitable_for_nat_traversal());

    // Link-local IPv4 (169.254.1.1): not suitable.
    let link_local_v4 = CandidateAddress::new(
        with_port(IpAddr::from(Ipv4Addr::new(169, 254, 1, 1))),
        100,
        CandidateSource::Local,
    )
    .unwrap();
    assert!(!link_local_v4.is_suitable_for_nat_traversal());

    // Global IPv6: suitable.
    let global_v6 = CandidateAddress::new(
        with_port(IpAddr::from(Ipv6Addr::new(
            0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888,
        ))),
        100,
        CandidateSource::Observed { by_node: None },
    )
    .unwrap();
    assert!(global_v6.is_suitable_for_nat_traversal());

    // Link-local IPv6 (fe80::1): not suitable.
    let link_local_v6 = CandidateAddress::new(
        with_port(IpAddr::from(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))),
        100,
        CandidateSource::Local,
    )
    .unwrap();
    assert!(!link_local_v6.is_suitable_for_nat_traversal());

    // Unique-local IPv6 (fc00::1): not suitable.
    let unique_local_v6 = CandidateAddress::new(
        with_port(IpAddr::from(Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 1))),
        100,
        CandidateSource::Local,
    )
    .unwrap();
    assert!(!unique_local_v6.is_suitable_for_nat_traversal());

    // Loopback expectations are tagged as test-build specific.
    // NOTE(review): presumably loopback is only treated as suitable when `cfg(test)` is
    // set — confirm against `is_suitable_for_nat_traversal`.
    #[cfg(test)]
    {
        let loopback_v4 = CandidateAddress::new(
            with_port(IpAddr::from(Ipv4Addr::LOCALHOST)),
            100,
            CandidateSource::Local,
        )
        .unwrap();
        assert!(loopback_v4.is_suitable_for_nat_traversal());

        let loopback_v6 = CandidateAddress::new(
            with_port(IpAddr::from(Ipv6Addr::LOCALHOST)),
            100,
            CandidateSource::Local,
        )
        .unwrap();
        assert!(loopback_v6.is_suitable_for_nat_traversal());
    }
}
4195
/// Checks how `effective_priority` depends on the candidate's `state` for a candidate
/// created with base priority 100.
#[test]
fn test_candidate_effective_priority() {
    use std::net::{IpAddr, Ipv4Addr};

    let address = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080);
    let mut candidate = CandidateAddress::new(address, 100, CandidateSource::Local).unwrap();

    // A freshly constructed candidate (state untouched) reports 90.
    assert_eq!(candidate.effective_priority(), 90);

    // Each state is then applied in turn and the resulting priority checked.
    let expectations = [
        (CandidateState::Validating, 95),
        (CandidateState::Valid, 100),
        (CandidateState::Failed, 0),
        (CandidateState::Removed, 0),
    ];
    for (state, expected) in expectations {
        candidate.state = state;
        assert_eq!(candidate.effective_priority(), expected);
    }
}
4226}