ant_quic/
connection_establishment_simple.rs

1//! Simplified Connection Establishment with Automatic NAT Traversal
2//!
3//! This module provides a simplified but complete connection establishment
4//! system that automatically handles NAT traversal with fallback mechanisms.
5
6use std::{
7    collections::HashMap,
8    net::SocketAddr,
9    sync::Arc,
10    time::{Duration, Instant},
11};
12
13use tracing::{debug, info, warn};
14
15use crate::{
16    candidate_discovery::{CandidateDiscoveryManager, DiscoveryError, DiscoveryEvent},
17    high_level::{Connection as QuinnConnection, Endpoint as QuinnEndpoint},
18    nat_traversal_api::{BootstrapNode, CandidateAddress, PeerId},
19};
20
/// Simplified connection establishment manager.
///
/// Tracks at most one in-flight connection attempt per peer. An attempt is
/// created by `connect_to_peer` and advanced by repeatedly calling `poll`:
/// direct QUIC dials are tried first, and candidate discovery is used as the
/// fallback when NAT traversal is enabled in the configuration.
pub struct SimpleConnectionEstablishmentManager {
    /// Timeouts and feature switches applied to every attempt
    config: SimpleEstablishmentConfig,
    /// In-flight connection attempts, keyed by the peer being connected to
    active_attempts: HashMap<PeerId, SimpleConnectionAttempt>,
    /// Shared candidate discovery manager; locked whenever discovery is
    /// started or polled
    discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
    /// Bootstrap nodes handed to the discovery manager each time discovery
    /// is started
    bootstrap_nodes: Vec<BootstrapNode>,
    /// Optional callback invoked synchronously by `emit_event`
    event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
    /// Quinn endpoint used to dial QUIC connections; when `None`,
    /// `connect_to_peer` dials nothing and only emits
    /// `DirectConnectionTried` events
    quinn_endpoint: Option<Arc<QuinnEndpoint>>,
}
36
/// Configuration for [`SimpleConnectionEstablishmentManager`].
///
/// The `Default` implementation uses a 5 second direct timeout, a 30 second
/// NAT traversal timeout, NAT traversal enabled and 3 retries.
#[derive(Debug, Clone)]
pub struct SimpleEstablishmentConfig {
    /// How long an attempt may stay in the direct-connection phase before it
    /// falls back to candidate discovery (or fails, when NAT traversal is
    /// disabled). Measured from the start of the attempt.
    pub direct_timeout: Duration,
    /// How long an attempt may run in any later phase before it is failed.
    /// Also measured from the start of the attempt, not from the phase change.
    pub nat_traversal_timeout: Duration,
    /// Whether attempts may fall back to candidate discovery / NAT traversal
    pub enable_nat_traversal: bool,
    /// Maximum retry attempts.
    /// NOTE(review): not read by any code visible in this file — confirm
    /// whether retries are implemented elsewhere.
    pub max_retries: u32,
}
49
/// Book-keeping for a single in-flight connection attempt to one peer.
#[derive(Debug)]
struct SimpleConnectionAttempt {
    /// Peer this attempt is for (same value as its key in `active_attempts`)
    peer_id: PeerId,
    /// Current phase of the attempt
    state: SimpleAttemptState,
    /// When the attempt was created; every timeout is measured from here
    started_at: Instant,
    /// Attempt counter; initialised to 1 and not changed anywhere in this file
    attempt_number: u32,
    /// Addresses supplied by the caller of `connect_to_peer`
    known_addresses: Vec<SocketAddr>,
    /// Candidates appended from discovery events while the attempt is in the
    /// `CandidateDiscovery` state
    discovered_candidates: Vec<CandidateAddress>,
    /// Description of the most recent failure, if any
    last_error: Option<String>,
    /// Handles of the spawned QUIC dial tasks that have not been harvested yet
    connection_handles: Vec<tokio::task::JoinHandle<Result<QuinnConnection, String>>>,
    /// Address dialed by each entry of `connection_handles`; the two vectors
    /// are pushed to and removed from at the same index
    target_addresses: Vec<SocketAddr>,
    /// Connection produced by the first dial task that succeeded
    established_connection: Option<QuinnConnection>,
}
67
/// Phase of a [`SimpleConnectionAttempt`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SimpleAttemptState {
    /// Initial phase: dialing the caller-supplied addresses directly
    DirectConnection,
    /// Waiting for the discovery manager to report candidates or completion
    CandidateDiscovery,
    /// Entered when a `DiscoveryCompleted` event is handled
    NatTraversal,
    /// A dial task produced a connection; the attempt is finished
    Connected,
    /// The attempt timed out or discovery failed; the attempt is finished
    Failed,
}
77
/// Events describing the progress of connection attempts.
///
/// Events raised while an attempt is being started are passed to the
/// manager's event callback; events raised while polling are returned from
/// `poll`.
#[derive(Debug, Clone)]
pub enum SimpleConnectionEvent {
    /// A new attempt was registered for the peer
    AttemptStarted {
        peer_id: PeerId,
    },
    /// A direct connection to `address` was started (or, without an endpoint,
    /// merely announced)
    DirectConnectionTried {
        peer_id: PeerId,
        address: SocketAddr,
    },
    /// Candidate discovery was started for the peer
    CandidateDiscoveryStarted {
        peer_id: PeerId,
    },
    /// Discovery completed and the attempt moved to the NAT traversal phase
    NatTraversalStarted {
        peer_id: PeerId,
    },
    /// The dial task for `address` produced a connection
    DirectConnectionSucceeded {
        peer_id: PeerId,
        address: SocketAddr,
    },
    /// The dial task for `address` failed, timed out, panicked or was cancelled
    DirectConnectionFailed {
        peer_id: PeerId,
        address: SocketAddr,
        error: String,
    },
    /// The attempt finished with an established connection
    ConnectionEstablished {
        peer_id: PeerId,
    },
    /// The attempt finished without a connection; `error` says why
    ConnectionFailed {
        peer_id: PeerId,
        error: String,
    },
}
111
112impl Default for SimpleEstablishmentConfig {
113    fn default() -> Self {
114        Self {
115            direct_timeout: Duration::from_secs(5),
116            nat_traversal_timeout: Duration::from_secs(30),
117            enable_nat_traversal: true,
118            max_retries: 3,
119        }
120    }
121}
122
123impl SimpleConnectionEstablishmentManager {
124    /// Create a new simplified connection establishment manager
125    pub fn new(
126        config: SimpleEstablishmentConfig,
127        discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
128        bootstrap_nodes: Vec<BootstrapNode>,
129        event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
130        quinn_endpoint: Option<Arc<QuinnEndpoint>>,
131    ) -> Self {
132        Self {
133            config,
134            active_attempts: HashMap::new(),
135            discovery_manager,
136            bootstrap_nodes,
137            event_callback,
138            quinn_endpoint,
139        }
140    }
141
142    /// Start connection to peer
143    pub fn connect_to_peer(
144        &mut self,
145        peer_id: PeerId,
146        known_addresses: Vec<SocketAddr>,
147    ) -> Result<(), String> {
148        // Check if already attempting
149        if self.active_attempts.contains_key(&peer_id) {
150            return Err("Connection attempt already in progress".to_string());
151        }
152
153        // Create new attempt
154        let attempt = SimpleConnectionAttempt {
155            peer_id,
156            state: SimpleAttemptState::DirectConnection,
157            started_at: Instant::now(),
158            attempt_number: 1,
159            known_addresses: known_addresses.clone(),
160            discovered_candidates: Vec::new(),
161            last_error: None,
162            connection_handles: Vec::new(),
163            target_addresses: Vec::new(),
164            established_connection: None,
165        };
166
167        self.active_attempts.insert(peer_id, attempt);
168
169        // Emit event
170        self.emit_event(SimpleConnectionEvent::AttemptStarted { peer_id });
171
172        // Try direct connection first if we have addresses
173        if !known_addresses.is_empty() {
174            info!("Starting direct connection attempt to peer {:?}", peer_id);
175
176            // Start direct connections if we have a Quinn endpoint
177            if let Some(ref quinn_endpoint) = self.quinn_endpoint {
178                self.start_direct_connections(peer_id, &known_addresses, quinn_endpoint.clone())?;
179            } else {
180                // Just emit events if no real endpoint (for testing)
181                for address in &known_addresses {
182                    self.emit_event(SimpleConnectionEvent::DirectConnectionTried {
183                        peer_id,
184                        address: *address,
185                    });
186                }
187            }
188        } else if self.config.enable_nat_traversal {
189            // Start candidate discovery immediately
190            self.start_candidate_discovery(peer_id)?;
191        } else {
192            return Err("No known addresses and NAT traversal disabled".to_string());
193        }
194
195        Ok(())
196    }
197
198    /// Start direct QUIC connections to known addresses
199    fn start_direct_connections(
200        &mut self,
201        peer_id: PeerId,
202        addresses: &[SocketAddr],
203        endpoint: Arc<QuinnEndpoint>,
204    ) -> Result<(), String> {
205        let attempt = self
206            .active_attempts
207            .get_mut(&peer_id)
208            .ok_or("Attempt not found")?;
209
210        // Collect events to emit after the loop
211        let mut events_to_emit = Vec::new();
212
213        for &address in addresses {
214            let server_name = format!("peer-{:x}", peer_id.0[0] as u32);
215            let endpoint_clone = endpoint.clone();
216
217            // Spawn a task to handle the connection attempt
218            let handle = tokio::spawn(async move {
219                let connecting = endpoint_clone
220                    .connect(address, &server_name)
221                    .map_err(|e| format!("Failed to start connection: {}", e))?;
222
223                // Apply a timeout to the connection attempt
224                match tokio::time::timeout(Duration::from_secs(10), connecting).await {
225                    Ok(connection_result) => {
226                        connection_result.map_err(|e| format!("Connection failed: {}", e))
227                    }
228                    Err(_) => Err("Connection timed out".to_string()),
229                }
230            });
231
232            attempt.connection_handles.push(handle);
233            attempt.target_addresses.push(address);
234
235            // Collect event to emit later
236            events_to_emit.push(SimpleConnectionEvent::DirectConnectionTried { peer_id, address });
237        }
238
239        // Emit events after borrowing is done
240        for event in events_to_emit {
241            self.emit_event(event);
242        }
243
244        debug!(
245            "Started {} direct connections for peer {:?}",
246            addresses.len(),
247            peer_id
248        );
249        Ok(())
250    }
251
252    /// Poll for progress
253    pub fn poll(&mut self, now: Instant) -> Vec<SimpleConnectionEvent> {
254        let mut events = Vec::new();
255
256        // Process discovery events
257        let discovery_events = match self.discovery_manager.lock() {
258            Ok(mut discovery) => discovery.poll(now),
259            _ => Vec::new(),
260        };
261
262        for discovery_event in discovery_events {
263            self.handle_discovery_event(discovery_event, &mut events);
264        }
265
266        // Process active attempts
267        let peer_ids: Vec<_> = self.active_attempts.keys().copied().collect();
268        let mut completed = Vec::new();
269
270        for peer_id in peer_ids {
271            if self.poll_attempt(peer_id, now, &mut events) {
272                completed.push(peer_id);
273            }
274        }
275
276        // Remove completed attempts
277        for peer_id in completed {
278            self.active_attempts.remove(&peer_id);
279        }
280
281        events
282    }
283
284    /// Cancel connection attempt
285    pub fn cancel_connection(&mut self, peer_id: PeerId) -> bool {
286        self.active_attempts.remove(&peer_id).is_some()
287    }
288
289    // Private methods
290
291    fn start_candidate_discovery(&mut self, peer_id: PeerId) -> Result<(), String> {
292        if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
293            attempt.state = SimpleAttemptState::CandidateDiscovery;
294
295            match self.discovery_manager.lock() {
296                Ok(mut discovery) => {
297                    discovery
298                        .start_discovery(peer_id, self.bootstrap_nodes.clone())
299                        .map_err(|e| format!("Discovery failed: {e:?}"))?;
300                }
301                _ => {
302                    return Err("Failed to lock discovery manager".to_string());
303                }
304            }
305
306            self.emit_event(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
307        }
308
309        Ok(())
310    }
311
312    fn poll_attempt(
313        &mut self,
314        peer_id: PeerId,
315        now: Instant,
316        events: &mut Vec<SimpleConnectionEvent>,
317    ) -> bool {
318        let attempt = match self.active_attempts.get_mut(&peer_id) {
319            Some(a) => a,
320            None => return true,
321        };
322
323        let elapsed = now.duration_since(attempt.started_at);
324        let timeout = match attempt.state {
325            SimpleAttemptState::DirectConnection => self.config.direct_timeout,
326            _ => self.config.nat_traversal_timeout,
327        };
328
329        // Check timeout
330        if elapsed > timeout {
331            match attempt.state {
332                SimpleAttemptState::DirectConnection if self.config.enable_nat_traversal => {
333                    // Fallback to NAT traversal
334                    info!(
335                        "Direct connection timed out for peer {:?}, starting NAT traversal",
336                        peer_id
337                    );
338                    attempt.state = SimpleAttemptState::CandidateDiscovery;
339
340                    // Start discovery outside of the borrow
341                    let discovery_result = match self.discovery_manager.lock() {
342                        Ok(mut discovery) => {
343                            discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
344                        }
345                        _ => Err(DiscoveryError::InternalError(
346                            "Failed to lock discovery manager".to_string(),
347                        )),
348                    };
349
350                    if let Err(e) = discovery_result {
351                        attempt.state = SimpleAttemptState::Failed;
352                        attempt.last_error = Some(format!("Discovery failed: {e:?}"));
353                        events.push(SimpleConnectionEvent::ConnectionFailed {
354                            peer_id,
355                            error: format!("Discovery failed: {e:?}"),
356                        });
357                        return true;
358                    }
359
360                    events.push(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
361                    return false;
362                }
363                _ => {
364                    // Timeout, mark as failed
365                    attempt.state = SimpleAttemptState::Failed;
366                    attempt.last_error = Some("Timeout exceeded".to_string());
367                    events.push(SimpleConnectionEvent::ConnectionFailed {
368                        peer_id,
369                        error: "Timeout exceeded".to_string(),
370                    });
371                    return true;
372                }
373            }
374        }
375
376        // Check real QUIC connection attempts
377        let has_connection_handles = !attempt.connection_handles.is_empty();
378
379        if has_connection_handles {
380            // Extract data needed for polling to avoid double mutable borrow
381            let mut connection_handles = std::mem::take(&mut attempt.connection_handles);
382            let mut target_addresses = std::mem::take(&mut attempt.target_addresses);
383            let mut established_connection = attempt.established_connection.take();
384
385            Self::poll_connection_handles_extracted(
386                peer_id,
387                &mut connection_handles,
388                &mut target_addresses,
389                &mut established_connection,
390                events,
391            );
392
393            // Put data back
394            attempt.connection_handles = connection_handles;
395            attempt.target_addresses = target_addresses;
396            attempt.established_connection = established_connection;
397
398            // Check if we have a successful connection
399            if attempt.established_connection.is_some() {
400                attempt.state = SimpleAttemptState::Connected;
401                events.push(SimpleConnectionEvent::ConnectionEstablished { peer_id });
402                return true;
403            }
404        }
405
406        match attempt.state {
407            SimpleAttemptState::DirectConnection => {
408                // Check if all direct connections failed
409                if attempt.connection_handles.is_empty()
410                    || attempt.connection_handles.iter().all(|h| h.is_finished())
411                {
412                    // All direct attempts finished, check for success
413                    if attempt.established_connection.is_none() && self.config.enable_nat_traversal
414                    {
415                        // No connection established, try NAT traversal
416                        debug!(
417                            "Direct connections failed for peer {:?}, trying NAT traversal",
418                            peer_id
419                        );
420                        attempt.state = SimpleAttemptState::CandidateDiscovery;
421                        // Start discovery will be handled in next poll cycle
422                    }
423                }
424            }
425            SimpleAttemptState::CandidateDiscovery => {
426                // Wait for discovery events
427            }
428            SimpleAttemptState::NatTraversal => {
429                // Poll NAT traversal connection attempts
430                debug!("Polling NAT traversal attempts for peer {:?}", peer_id);
431            }
432            SimpleAttemptState::Connected | SimpleAttemptState::Failed => {
433                return true;
434            }
435        }
436
437        false
438    }
439
440    /// Poll connection handles to check for completed connections (extracted version)
441    fn poll_connection_handles_extracted(
442        peer_id: PeerId,
443        connection_handles: &mut Vec<tokio::task::JoinHandle<Result<QuinnConnection, String>>>,
444        target_addresses: &mut Vec<SocketAddr>,
445        established_connection: &mut Option<QuinnConnection>,
446        events: &mut Vec<SimpleConnectionEvent>,
447    ) -> bool {
448        let mut completed_indices = Vec::new();
449
450        for (index, handle) in connection_handles.iter_mut().enumerate() {
451            if handle.is_finished() {
452                completed_indices.push(index);
453            }
454        }
455
456        // Process completed handles
457        for &index in completed_indices.iter().rev() {
458            let handle = connection_handles.remove(index);
459            let target_address = target_addresses.remove(index);
460
461            match tokio::runtime::Handle::try_current() {
462                Ok(runtime_handle) => {
463                    match runtime_handle.block_on(handle) {
464                        Ok(Ok(connection)) => {
465                            // Connection succeeded
466                            info!(
467                                "QUIC connection established to {} for peer {:?}",
468                                target_address, peer_id
469                            );
470                            *established_connection = Some(connection);
471
472                            events.push(SimpleConnectionEvent::DirectConnectionSucceeded {
473                                peer_id,
474                                address: target_address,
475                            });
476
477                            // Cancel remaining connection attempts
478                            for remaining_handle in connection_handles.drain(..) {
479                                remaining_handle.abort();
480                            }
481                            target_addresses.clear();
482
483                            return true; // Exit early on success
484                        }
485                        Ok(Err(e)) => {
486                            // Connection failed
487                            warn!("QUIC connection to {} failed: {}", target_address, e);
488
489                            events.push(SimpleConnectionEvent::DirectConnectionFailed {
490                                peer_id,
491                                address: target_address,
492                                error: e,
493                            });
494                        }
495                        Err(join_error) => {
496                            // Task panic or cancellation
497                            warn!("QUIC connection task failed: {}", join_error);
498
499                            events.push(SimpleConnectionEvent::DirectConnectionFailed {
500                                peer_id,
501                                address: target_address,
502                                error: format!("Task failed: {}", join_error),
503                            });
504                        }
505                    }
506                }
507                Err(_) => {
508                    // No tokio runtime available, can't check result
509                    warn!("Unable to check connection result without tokio runtime");
510                }
511            }
512        }
513
514        false
515    }
516
517    fn handle_discovery_event(
518        &mut self,
519        discovery_event: DiscoveryEvent,
520        events: &mut Vec<SimpleConnectionEvent>,
521    ) {
522        match discovery_event {
523            DiscoveryEvent::LocalCandidateDiscovered { candidate }
524            | DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. }
525            | DiscoveryEvent::PredictedCandidateGenerated { candidate, .. } => {
526                // Add candidate to relevant attempts
527                for attempt in self.active_attempts.values_mut() {
528                    if attempt.state == SimpleAttemptState::CandidateDiscovery {
529                        attempt.discovered_candidates.push(candidate.clone());
530                    }
531                }
532            }
533            DiscoveryEvent::DiscoveryCompleted { .. } => {
534                // Transition attempts to NAT traversal
535                let peer_ids: Vec<_> = self
536                    .active_attempts
537                    .iter()
538                    .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
539                    .map(|(peer_id, _)| *peer_id)
540                    .collect();
541
542                for peer_id in peer_ids {
543                    if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
544                        attempt.state = SimpleAttemptState::NatTraversal;
545                        events.push(SimpleConnectionEvent::NatTraversalStarted { peer_id });
546                    }
547                }
548            }
549            DiscoveryEvent::DiscoveryFailed { error, .. } => {
550                warn!("Discovery failed: {:?}", error);
551                // Mark relevant attempts as failed
552                let peer_ids: Vec<_> = self
553                    .active_attempts
554                    .iter()
555                    .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
556                    .map(|(peer_id, _)| *peer_id)
557                    .collect();
558
559                for peer_id in peer_ids {
560                    if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
561                        attempt.state = SimpleAttemptState::Failed;
562                        attempt.last_error = Some(format!("Discovery failed: {error:?}"));
563                        events.push(SimpleConnectionEvent::ConnectionFailed {
564                            peer_id,
565                            error: format!("Discovery failed: {error:?}"),
566                        });
567                    }
568                }
569            }
570            _ => {
571                // Handle other events as needed
572            }
573        }
574    }
575
576    fn emit_event(&self, event: SimpleConnectionEvent) {
577        if let Some(ref callback) = self.event_callback {
578            callback(event);
579        }
580    }
581}