1use std::{
8 collections::HashMap,
9 fmt,
10 net::SocketAddr,
11 sync::Arc,
12 time::Duration,
13};
14
15use tracing::{debug, info};
16
17use crate::{
18 candidate_discovery::{CandidateDiscoveryManager, DiscoveryConfig, DiscoveryEvent},
19 connection::nat_traversal::{CandidateSource, CandidateState, NatTraversalRole},
20 Endpoint, VarInt,
21};
22
23pub struct NatTraversalEndpoint {
25 endpoint: Endpoint,
27 config: NatTraversalConfig,
29 bootstrap_nodes: Arc<std::sync::RwLock<Vec<BootstrapNode>>>,
31 active_sessions: Arc<std::sync::RwLock<HashMap<PeerId, NatTraversalSession>>>,
33 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
35 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
37}
38
39#[derive(Debug, Clone)]
41pub struct NatTraversalConfig {
42 pub role: EndpointRole,
44 pub bootstrap_nodes: Vec<SocketAddr>,
46 pub max_candidates: usize,
48 pub coordination_timeout: Duration,
50 pub enable_symmetric_nat: bool,
52 pub enable_relay_fallback: bool,
54 pub max_concurrent_attempts: usize,
56}
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq)]
60pub enum EndpointRole {
61 Client,
63 Server { can_coordinate: bool },
65 Bootstrap,
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
71pub struct PeerId(pub [u8; 32]);
72
73#[derive(Debug, Clone)]
75pub struct BootstrapNode {
76 pub address: SocketAddr,
78 pub last_seen: std::time::Instant,
80 pub can_coordinate: bool,
82 pub rtt: Option<Duration>,
84 pub coordination_count: u32,
86}
87
88#[derive(Debug)]
90struct NatTraversalSession {
91 peer_id: PeerId,
93 coordinator: SocketAddr,
95 attempt: u32,
97 started_at: std::time::Instant,
99 phase: TraversalPhase,
101 candidates: Vec<CandidateAddress>,
103}
104
105#[derive(Debug, Clone, Copy, PartialEq, Eq)]
107enum TraversalPhase {
108 Discovery,
110 Coordination,
112 Synchronization,
114 Punching,
116 Validation,
118 Connected,
120 Failed,
122}
123
124#[derive(Debug, Clone)]
126pub struct CandidateAddress {
127 pub address: SocketAddr,
129 pub priority: u32,
131 pub source: CandidateSource,
133 pub state: CandidateState,
135}
136
137
138#[derive(Debug, Clone)]
140pub enum NatTraversalEvent {
141 CandidateDiscovered {
143 peer_id: PeerId,
144 candidate: CandidateAddress,
145 },
146 CoordinationRequested {
148 peer_id: PeerId,
149 coordinator: SocketAddr,
150 },
151 CoordinationSynchronized {
153 peer_id: PeerId,
154 round_id: VarInt,
155 },
156 HolePunchingStarted {
158 peer_id: PeerId,
159 targets: Vec<SocketAddr>,
160 },
161 PathValidated {
163 peer_id: PeerId,
164 address: SocketAddr,
165 rtt: Duration,
166 },
167 TraversalSucceeded {
169 peer_id: PeerId,
170 final_address: SocketAddr,
171 total_time: Duration,
172 },
173 TraversalFailed {
175 peer_id: PeerId,
176 error: NatTraversalError,
177 fallback_available: bool,
178 },
179}
180
181#[derive(Debug, Clone)]
183pub enum NatTraversalError {
184 NoBootstrapNodes,
186 NoCandidatesFound,
188 CandidateDiscoveryFailed(String),
190 CoordinationFailed(String),
192 HolePunchingFailed,
194 ValidationTimeout,
196 NetworkError(String),
198 ConfigError(String),
200 ProtocolError(String),
202}
203
204impl Default for NatTraversalConfig {
205 fn default() -> Self {
206 Self {
207 role: EndpointRole::Client,
208 bootstrap_nodes: Vec::new(),
209 max_candidates: 8,
210 coordination_timeout: Duration::from_secs(10),
211 enable_symmetric_nat: true,
212 enable_relay_fallback: true,
213 max_concurrent_attempts: 3,
214 }
215 }
216}
217
218impl NatTraversalEndpoint {
219 pub fn new(
221 config: NatTraversalConfig,
222 event_callback: Option<Box<dyn Fn(NatTraversalEvent) + Send + Sync>>,
223 ) -> Result<Self, NatTraversalError> {
224 if config.bootstrap_nodes.is_empty() && config.role != EndpointRole::Bootstrap {
226 return Err(NatTraversalError::ConfigError(
227 "At least one bootstrap node required for non-bootstrap endpoints".to_string(),
228 ));
229 }
230
231 let bootstrap_nodes = Arc::new(std::sync::RwLock::new(
233 config
234 .bootstrap_nodes
235 .iter()
236 .map(|&address| BootstrapNode {
237 address,
238 last_seen: std::time::Instant::now(),
239 can_coordinate: true, rtt: None,
241 coordination_count: 0,
242 })
243 .collect(),
244 ));
245
246 let discovery_config = DiscoveryConfig {
248 total_timeout: config.coordination_timeout,
249 max_candidates: config.max_candidates,
250 enable_symmetric_prediction: config.enable_symmetric_nat,
251 ..DiscoveryConfig::default()
252 };
253
254 let nat_traversal_role = match config.role {
255 EndpointRole::Client => NatTraversalRole::Client,
256 EndpointRole::Server { can_coordinate } => NatTraversalRole::Server { can_relay: can_coordinate },
257 EndpointRole::Bootstrap => NatTraversalRole::Bootstrap,
258 };
259
260 let discovery_manager = Arc::new(std::sync::Mutex::new(
261 CandidateDiscoveryManager::new(discovery_config, nat_traversal_role)
262 ));
263
264 let endpoint = unsafe { std::mem::zeroed() }; Ok(Self {
271 endpoint,
272 config,
273 bootstrap_nodes,
274 active_sessions: Arc::new(std::sync::RwLock::new(HashMap::new())),
275 discovery_manager,
276 event_callback,
277 })
278 }
279
280 pub fn initiate_nat_traversal(
282 &self,
283 peer_id: PeerId,
284 coordinator: SocketAddr,
285 ) -> Result<(), NatTraversalError> {
286 info!("Starting NAT traversal to peer {:?} via coordinator {}", peer_id, coordinator);
287
288 let session = NatTraversalSession {
290 peer_id,
291 coordinator,
292 attempt: 1,
293 started_at: std::time::Instant::now(),
294 phase: TraversalPhase::Discovery,
295 candidates: Vec::new(),
296 };
297
298 {
300 let mut sessions = self.active_sessions.write()
301 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
302 sessions.insert(peer_id, session);
303 }
304
305 let bootstrap_nodes_vec = {
307 let bootstrap_nodes = self.bootstrap_nodes.read()
308 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
309 bootstrap_nodes.clone()
310 };
311
312 {
313 let mut discovery = self.discovery_manager.lock()
314 .map_err(|_| NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string()))?;
315
316 discovery.start_discovery(peer_id, bootstrap_nodes_vec)
317 .map_err(|e| NatTraversalError::CandidateDiscoveryFailed(e.to_string()))?;
318 }
319
320 if let Some(ref callback) = self.event_callback {
322 callback(NatTraversalEvent::CoordinationRequested {
323 peer_id,
324 coordinator,
325 });
326 }
327
328 Ok(())
330 }
331
332 pub fn get_statistics(&self) -> Result<NatTraversalStatistics, NatTraversalError> {
334 let sessions = self.active_sessions.read()
335 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
336 let bootstrap_nodes = self.bootstrap_nodes.read()
337 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
338
339 Ok(NatTraversalStatistics {
340 active_sessions: sessions.len(),
341 total_bootstrap_nodes: bootstrap_nodes.len(),
342 successful_coordinations: bootstrap_nodes.iter().map(|b| b.coordination_count).sum(),
343 average_coordination_time: Duration::from_millis(500), })
345 }
346
347 pub fn add_bootstrap_node(&self, address: SocketAddr) -> Result<(), NatTraversalError> {
349 let mut bootstrap_nodes = self.bootstrap_nodes.write()
350 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
351
352 if !bootstrap_nodes.iter().any(|b| b.address == address) {
354 bootstrap_nodes.push(BootstrapNode {
355 address,
356 last_seen: std::time::Instant::now(),
357 can_coordinate: true,
358 rtt: None,
359 coordination_count: 0,
360 });
361 info!("Added bootstrap node: {}", address);
362 }
363 Ok(())
364 }
365
366 pub fn remove_bootstrap_node(&self, address: SocketAddr) -> Result<(), NatTraversalError> {
368 let mut bootstrap_nodes = self.bootstrap_nodes.write()
369 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
370 bootstrap_nodes.retain(|b| b.address != address);
371 info!("Removed bootstrap node: {}", address);
372 Ok(())
373 }
374
375 fn discover_candidates(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
378 debug!("Discovering address candidates for peer {:?}", peer_id);
379
380 Ok(())
386 }
387
388 fn coordinate_with_bootstrap(
389 &self,
390 peer_id: PeerId,
391 coordinator: SocketAddr,
392 ) -> Result<(), NatTraversalError> {
393 debug!("Coordinating with bootstrap {} for peer {:?}", coordinator, peer_id);
394
395 Ok(())
401 }
402
403 fn attempt_hole_punching(&self, peer_id: PeerId) -> Result<(), NatTraversalError> {
404 debug!("Attempting hole punching for peer {:?}", peer_id);
405
406 Ok(())
412 }
413
414 pub fn poll(&self, now: std::time::Instant) -> Result<Vec<NatTraversalEvent>, NatTraversalError> {
416 let mut events = Vec::new();
417
418 {
420 let mut discovery = self.discovery_manager.lock()
421 .map_err(|_| NatTraversalError::ProtocolError("Discovery manager lock poisoned".to_string()))?;
422
423 let discovery_events = discovery.poll(now);
424
425 for discovery_event in discovery_events {
427 if let Some(nat_event) = self.convert_discovery_event(discovery_event) {
428 events.push(nat_event.clone());
429
430 if let Some(ref callback) = self.event_callback {
432 callback(nat_event);
433 }
434 }
435 }
436 }
437
438 let mut sessions = self.active_sessions.write()
440 .map_err(|_| NatTraversalError::ProtocolError("Lock poisoned".to_string()))?;
441
442 for (_peer_id, session) in sessions.iter_mut() {
443 let _elapsed = now.duration_since(session.started_at);
448 }
449
450 Ok(events)
451 }
452
453 fn convert_discovery_event(&self, discovery_event: DiscoveryEvent) -> Option<NatTraversalEvent> {
455 match discovery_event {
456 DiscoveryEvent::LocalCandidateDiscovered { candidate } => {
457 Some(NatTraversalEvent::CandidateDiscovered {
458 peer_id: PeerId([0; 32]), candidate,
460 })
461 },
462 DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, bootstrap_node: _ } => {
463 Some(NatTraversalEvent::CandidateDiscovered {
464 peer_id: PeerId([0; 32]), candidate,
466 })
467 },
468 DiscoveryEvent::PredictedCandidateGenerated { candidate, confidence: _ } => {
469 Some(NatTraversalEvent::CandidateDiscovered {
470 peer_id: PeerId([0; 32]), candidate,
472 })
473 },
474 DiscoveryEvent::DiscoveryCompleted { candidate_count: _, total_duration: _, success_rate: _ } => {
475 None },
478 DiscoveryEvent::DiscoveryFailed { error, partial_results } => {
479 Some(NatTraversalEvent::TraversalFailed {
480 peer_id: PeerId([0; 32]), error: NatTraversalError::CandidateDiscoveryFailed(error.to_string()),
482 fallback_available: !partial_results.is_empty(),
483 })
484 },
485 _ => None, }
487 }
488}
489
490impl fmt::Debug for NatTraversalEndpoint {
491 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
492 f.debug_struct("NatTraversalEndpoint")
493 .field("config", &self.config)
494 .field("bootstrap_nodes", &"<RwLock>")
495 .field("active_sessions", &"<RwLock>")
496 .field("event_callback", &self.event_callback.is_some())
497 .finish()
498 }
499}
500
/// Snapshot of endpoint counters returned by `NatTraversalEndpoint::get_statistics`.
///
/// `Default` yields an all-zero snapshot; `PartialEq`/`Eq` allow snapshots to
/// be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NatTraversalStatistics {
    /// Number of traversal sessions currently tracked.
    pub active_sessions: usize,
    /// Number of known bootstrap nodes.
    pub total_bootstrap_nodes: usize,
    /// Sum of `coordination_count` over all bootstrap nodes.
    pub successful_coordinations: u32,
    /// Average coordination time. NOTE(review): currently a fixed placeholder value.
    pub average_coordination_time: Duration,
}
513
514impl fmt::Display for NatTraversalError {
515 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
516 match self {
517 Self::NoBootstrapNodes => write!(f, "no bootstrap nodes available"),
518 Self::NoCandidatesFound => write!(f, "no address candidates found"),
519 Self::CandidateDiscoveryFailed(msg) => write!(f, "candidate discovery failed: {}", msg),
520 Self::CoordinationFailed(msg) => write!(f, "coordination failed: {}", msg),
521 Self::HolePunchingFailed => write!(f, "hole punching failed"),
522 Self::ValidationTimeout => write!(f, "validation timeout"),
523 Self::NetworkError(msg) => write!(f, "network error: {}", msg),
524 Self::ConfigError(msg) => write!(f, "configuration error: {}", msg),
525 Self::ProtocolError(msg) => write!(f, "protocol error: {}", msg),
526 }
527 }
528}
529
/// Marks `NatTraversalError` as a standard error type. Its message comes from
/// the `Display` impl, and the default `source()` is kept (no underlying cause).
impl std::error::Error for NatTraversalError {}
531
532impl fmt::Display for PeerId {
533 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
534 for byte in &self.0[..8] {
536 write!(f, "{:02x}", byte)?;
537 }
538 Ok(())
539 }
540}
541
542impl From<[u8; 32]> for PeerId {
543 fn from(bytes: [u8; 32]) -> Self {
544 Self(bytes)
545 }
546}
547
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default configuration has its documented value.
    #[test]
    fn test_nat_traversal_config_default() {
        let config = NatTraversalConfig::default();
        assert_eq!(config.role, EndpointRole::Client);
        assert!(config.bootstrap_nodes.is_empty());
        assert_eq!(config.max_candidates, 8);
        assert_eq!(config.coordination_timeout, Duration::from_secs(10));
        assert!(config.enable_symmetric_nat);
        assert!(config.enable_relay_fallback);
        assert_eq!(config.max_concurrent_attempts, 3);
    }

    /// `Display` shows only the first 8 bytes, in lowercase hex.
    #[test]
    fn test_peer_id_display() {
        let peer_id = PeerId([
            0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55,
            0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00, 0x11, 0x22, 0x33,
            0x44, 0x55, 0x66, 0x77,
        ]);
        assert_eq!(format!("{}", peer_id), "0123456789abcdef");
    }

    /// `From<[u8; 32]>` wraps the bytes unchanged.
    #[test]
    fn test_peer_id_from_bytes() {
        assert_eq!(PeerId::from([7; 32]), PeerId([7; 32]));
    }

    /// A client with no bootstrap nodes is rejected before any endpoint state
    /// is built.
    #[test]
    fn test_new_rejects_client_without_bootstrap_nodes() {
        let config = NatTraversalConfig::default();
        let result = NatTraversalEndpoint::new(config, None);
        assert!(matches!(result, Err(NatTraversalError::ConfigError(_))));
    }

    /// Error messages are stable, human-readable strings.
    #[test]
    fn test_error_display() {
        assert_eq!(
            NatTraversalError::NoBootstrapNodes.to_string(),
            "no bootstrap nodes available"
        );
        assert_eq!(
            NatTraversalError::ConfigError("bad".to_string()).to_string(),
            "configuration error: bad"
        );
    }
}