ant_quic/
connection_establishment_simple.rs

1//! Simplified Connection Establishment with Automatic NAT Traversal
2//!
3//! This module provides a simplified but complete connection establishment
4//! system that automatically handles NAT traversal with fallback mechanisms.
5
6use std::{
7    collections::HashMap,
8    net::SocketAddr,
9    sync::Arc,
10    time::{Duration, Instant},
11};
12
13use tracing::{debug, info, warn};
14
15use crate::{
16    candidate_discovery::{CandidateDiscoveryManager, DiscoveryError, DiscoveryEvent},
17    nat_traversal_api::{BootstrapNode, CandidateAddress, PeerId},
18    high_level::{Endpoint as QuinnEndpoint, Connection as QuinnConnection},
19};
20
21/// Simplified connection establishment manager
22pub struct SimpleConnectionEstablishmentManager {
23    /// Configuration for connection establishment
24    config: SimpleEstablishmentConfig,
25    /// Active connection attempts  
26    active_attempts: HashMap<PeerId, SimpleConnectionAttempt>,
27    /// Candidate discovery manager
28    discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
29    /// Known bootstrap nodes
30    bootstrap_nodes: Vec<BootstrapNode>,
31    /// Event callback
32    event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
33    /// Quinn endpoint for making QUIC connections
34    quinn_endpoint: Option<Arc<QuinnEndpoint>>,
35}
36
/// Tunables for [`SimpleConnectionEstablishmentManager`].
#[derive(Debug, Clone)]
pub struct SimpleEstablishmentConfig {
    /// How long an attempt may stay in the direct-connection phase.
    pub direct_timeout: Duration,
    /// How long an attempt may stay in any later (discovery / NAT) phase.
    pub nat_traversal_timeout: Duration,
    /// Whether falling back to candidate discovery / NAT traversal is allowed.
    pub enable_nat_traversal: bool,
    /// Upper bound on retry attempts.
    pub max_retries: u32,
}
49
50/// Simplified connection attempt state
51#[derive(Debug)]
52struct SimpleConnectionAttempt {
53    peer_id: PeerId,
54    state: SimpleAttemptState,
55    started_at: Instant,
56    attempt_number: u32,
57    known_addresses: Vec<SocketAddr>,
58    discovered_candidates: Vec<CandidateAddress>,
59    last_error: Option<String>,
60    /// Active QUIC connection attempts
61    connection_handles: Vec<tokio::task::JoinHandle<Result<QuinnConnection, String>>>,
62    /// Current target addresses being attempted
63    target_addresses: Vec<SocketAddr>,
64    /// Established connection (if successful)
65    established_connection: Option<QuinnConnection>,
66}
67
/// Phases a connection attempt moves through.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SimpleAttemptState {
    /// Connecting directly to the caller-supplied addresses.
    DirectConnection,
    /// Waiting for the discovery manager to report candidates.
    CandidateDiscovery,
    /// Discovery has completed; NAT traversal is in progress.
    NatTraversal,
    /// Terminal: a connection was established.
    Connected,
    /// Terminal: the attempt failed.
    Failed,
}
77
78/// Simplified events
79#[derive(Debug, Clone)]
80pub enum SimpleConnectionEvent {
81    AttemptStarted {
82        peer_id: PeerId,
83    },
84    DirectConnectionTried {
85        peer_id: PeerId,
86        address: SocketAddr,
87    },
88    CandidateDiscoveryStarted {
89        peer_id: PeerId,
90    },
91    NatTraversalStarted {
92        peer_id: PeerId,
93    },
94    DirectConnectionSucceeded {
95        peer_id: PeerId,
96        address: SocketAddr,
97    },
98    DirectConnectionFailed {
99        peer_id: PeerId,
100        address: SocketAddr,
101        error: String,
102    },
103    ConnectionEstablished {
104        peer_id: PeerId,
105    },
106    ConnectionFailed {
107        peer_id: PeerId,
108        error: String,
109    },
110}
111
112impl Default for SimpleEstablishmentConfig {
113    fn default() -> Self {
114        Self {
115            direct_timeout: Duration::from_secs(5),
116            nat_traversal_timeout: Duration::from_secs(30),
117            enable_nat_traversal: true,
118            max_retries: 3,
119        }
120    }
121}
122
123impl SimpleConnectionEstablishmentManager {
124    /// Create a new simplified connection establishment manager
125    pub fn new(
126        config: SimpleEstablishmentConfig,
127        discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
128        bootstrap_nodes: Vec<BootstrapNode>,
129        event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
130        quinn_endpoint: Option<Arc<QuinnEndpoint>>,
131    ) -> Self {
132        Self {
133            config,
134            active_attempts: HashMap::new(),
135            discovery_manager,
136            bootstrap_nodes,
137            event_callback,
138            quinn_endpoint,
139        }
140    }
141
142    /// Start connection to peer
143    pub fn connect_to_peer(
144        &mut self,
145        peer_id: PeerId,
146        known_addresses: Vec<SocketAddr>,
147    ) -> Result<(), String> {
148        // Check if already attempting
149        if self.active_attempts.contains_key(&peer_id) {
150            return Err("Connection attempt already in progress".to_string());
151        }
152
153        // Create new attempt
154        let attempt = SimpleConnectionAttempt {
155            peer_id,
156            state: SimpleAttemptState::DirectConnection,
157            started_at: Instant::now(),
158            attempt_number: 1,
159            known_addresses: known_addresses.clone(),
160            discovered_candidates: Vec::new(),
161            last_error: None,
162            connection_handles: Vec::new(),
163            target_addresses: Vec::new(),
164            established_connection: None,
165        };
166
167        self.active_attempts.insert(peer_id, attempt);
168
169        // Emit event
170        self.emit_event(SimpleConnectionEvent::AttemptStarted { peer_id });
171
172        // Try direct connection first if we have addresses
173        if !known_addresses.is_empty() {
174            info!("Starting direct connection attempt to peer {:?}", peer_id);
175            
176            // Start direct connections if we have a Quinn endpoint
177            if let Some(ref quinn_endpoint) = self.quinn_endpoint {
178                self.start_direct_connections(peer_id, &known_addresses, quinn_endpoint.clone())?;
179            } else {
180                // Just emit events if no real endpoint (for testing)
181                for address in &known_addresses {
182                    self.emit_event(SimpleConnectionEvent::DirectConnectionTried {
183                        peer_id,
184                        address: *address,
185                    });
186                }
187            }
188        } else if self.config.enable_nat_traversal {
189            // Start candidate discovery immediately
190            self.start_candidate_discovery(peer_id)?;
191        } else {
192            return Err("No known addresses and NAT traversal disabled".to_string());
193        }
194
195        Ok(())
196    }
197
198    /// Start direct QUIC connections to known addresses
199    fn start_direct_connections(
200        &mut self,
201        peer_id: PeerId,
202        addresses: &[SocketAddr],
203        endpoint: Arc<QuinnEndpoint>,
204    ) -> Result<(), String> {
205        let attempt = self.active_attempts.get_mut(&peer_id)
206            .ok_or("Attempt not found")?;
207        
208        // Collect events to emit after the loop
209        let mut events_to_emit = Vec::new();
210        
211        for &address in addresses {
212            let server_name = format!("peer-{:x}", peer_id.0[0] as u32);
213            let endpoint_clone = endpoint.clone();
214            
215            // Spawn a task to handle the connection attempt
216            let handle = tokio::spawn(async move {
217                let connecting = endpoint_clone.connect(address, &server_name)
218                    .map_err(|e| format!("Failed to start connection: {}", e))?;
219                    
220                // Apply a timeout to the connection attempt
221                match tokio::time::timeout(Duration::from_secs(10), connecting).await {
222                    Ok(connection_result) => connection_result
223                        .map_err(|e| format!("Connection failed: {}", e)),
224                    Err(_) => Err("Connection timed out".to_string()),
225                }
226            });
227            
228            attempt.connection_handles.push(handle);
229            attempt.target_addresses.push(address);
230            
231            // Collect event to emit later
232            events_to_emit.push(SimpleConnectionEvent::DirectConnectionTried {
233                peer_id,
234                address,
235            });
236        }
237        
238        // Emit events after borrowing is done
239        for event in events_to_emit {
240            self.emit_event(event);
241        }
242        
243        debug!("Started {} direct connections for peer {:?}", addresses.len(), peer_id);
244        Ok(())
245    }
246
247    /// Poll for progress
248    pub fn poll(&mut self, now: Instant) -> Vec<SimpleConnectionEvent> {
249        let mut events = Vec::new();
250
251        // Process discovery events
252        let discovery_events = match self.discovery_manager.lock() {
253            Ok(mut discovery) => discovery.poll(now),
254            _ => Vec::new(),
255        };
256
257        for discovery_event in discovery_events {
258            self.handle_discovery_event(discovery_event, &mut events);
259        }
260
261        // Process active attempts
262        let peer_ids: Vec<_> = self.active_attempts.keys().copied().collect();
263        let mut completed = Vec::new();
264
265        for peer_id in peer_ids {
266            if self.poll_attempt(peer_id, now, &mut events) {
267                completed.push(peer_id);
268            }
269        }
270
271        // Remove completed attempts
272        for peer_id in completed {
273            self.active_attempts.remove(&peer_id);
274        }
275
276        events
277    }
278
279    /// Cancel connection attempt
280    pub fn cancel_connection(&mut self, peer_id: PeerId) -> bool {
281        self.active_attempts.remove(&peer_id).is_some()
282    }
283
284    // Private methods
285
286    fn start_candidate_discovery(&mut self, peer_id: PeerId) -> Result<(), String> {
287        if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
288            attempt.state = SimpleAttemptState::CandidateDiscovery;
289
290            match self.discovery_manager.lock() {
291                Ok(mut discovery) => {
292                    discovery
293                        .start_discovery(peer_id, self.bootstrap_nodes.clone())
294                        .map_err(|e| format!("Discovery failed: {e:?}"))?;
295                }
296                _ => {
297                    return Err("Failed to lock discovery manager".to_string());
298                }
299            }
300
301            self.emit_event(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
302        }
303
304        Ok(())
305    }
306
307    fn poll_attempt(
308        &mut self,
309        peer_id: PeerId,
310        now: Instant,
311        events: &mut Vec<SimpleConnectionEvent>,
312    ) -> bool {
313        let attempt = match self.active_attempts.get_mut(&peer_id) {
314            Some(a) => a,
315            None => return true,
316        };
317
318        let elapsed = now.duration_since(attempt.started_at);
319        let timeout = match attempt.state {
320            SimpleAttemptState::DirectConnection => self.config.direct_timeout,
321            _ => self.config.nat_traversal_timeout,
322        };
323
324        // Check timeout
325        if elapsed > timeout {
326            match attempt.state {
327                SimpleAttemptState::DirectConnection if self.config.enable_nat_traversal => {
328                    // Fallback to NAT traversal
329                    info!(
330                        "Direct connection timed out for peer {:?}, starting NAT traversal",
331                        peer_id
332                    );
333                    attempt.state = SimpleAttemptState::CandidateDiscovery;
334
335                    // Start discovery outside of the borrow
336                    let discovery_result = match self.discovery_manager.lock() {
337                        Ok(mut discovery) => {
338                            discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
339                        }
340                        _ => Err(DiscoveryError::InternalError(
341                            "Failed to lock discovery manager".to_string(),
342                        )),
343                    };
344
345                    if let Err(e) = discovery_result {
346                        attempt.state = SimpleAttemptState::Failed;
347                        attempt.last_error = Some(format!("Discovery failed: {e:?}"));
348                        events.push(SimpleConnectionEvent::ConnectionFailed {
349                            peer_id,
350                            error: format!("Discovery failed: {e:?}"),
351                        });
352                        return true;
353                    }
354
355                    events.push(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
356                    return false;
357                }
358                _ => {
359                    // Timeout, mark as failed
360                    attempt.state = SimpleAttemptState::Failed;
361                    attempt.last_error = Some("Timeout exceeded".to_string());
362                    events.push(SimpleConnectionEvent::ConnectionFailed {
363                        peer_id,
364                        error: "Timeout exceeded".to_string(),
365                    });
366                    return true;
367                }
368            }
369        }
370
371        // Check real QUIC connection attempts
372        let has_connection_handles = !attempt.connection_handles.is_empty();
373        
374        if has_connection_handles {
375            // Extract data needed for polling to avoid double mutable borrow
376            let mut connection_handles = std::mem::take(&mut attempt.connection_handles);
377            let mut target_addresses = std::mem::take(&mut attempt.target_addresses);
378            let mut established_connection = attempt.established_connection.take();
379            
380            Self::poll_connection_handles_extracted(
381                peer_id,
382                &mut connection_handles,
383                &mut target_addresses,
384                &mut established_connection,
385                events,
386            );
387            
388            // Put data back
389            attempt.connection_handles = connection_handles;
390            attempt.target_addresses = target_addresses;
391            attempt.established_connection = established_connection;
392            
393            // Check if we have a successful connection
394            if attempt.established_connection.is_some() {
395                attempt.state = SimpleAttemptState::Connected;
396                events.push(SimpleConnectionEvent::ConnectionEstablished { peer_id });
397                return true;
398            }
399        }
400        
401        match attempt.state {
402            SimpleAttemptState::DirectConnection => {
403                // Check if all direct connections failed
404                if attempt.connection_handles.is_empty() || 
405                   attempt.connection_handles.iter().all(|h| h.is_finished()) {
406                    // All direct attempts finished, check for success
407                    if attempt.established_connection.is_none() && self.config.enable_nat_traversal {
408                        // No connection established, try NAT traversal
409                        debug!("Direct connections failed for peer {:?}, trying NAT traversal", peer_id);
410                        attempt.state = SimpleAttemptState::CandidateDiscovery;
411                        // Start discovery will be handled in next poll cycle
412                    }
413                }
414            }
415            SimpleAttemptState::CandidateDiscovery => {
416                // Wait for discovery events
417            }
418            SimpleAttemptState::NatTraversal => {
419                // Poll NAT traversal connection attempts
420                debug!("Polling NAT traversal attempts for peer {:?}", peer_id);
421            }
422            SimpleAttemptState::Connected | SimpleAttemptState::Failed => {
423                return true;
424            }
425        }
426
427        false
428    }
429
430    /// Poll connection handles to check for completed connections (extracted version)
431    fn poll_connection_handles_extracted(
432        peer_id: PeerId,
433        connection_handles: &mut Vec<tokio::task::JoinHandle<Result<QuinnConnection, String>>>,
434        target_addresses: &mut Vec<SocketAddr>,
435        established_connection: &mut Option<QuinnConnection>,
436        events: &mut Vec<SimpleConnectionEvent>,
437    ) -> bool {
438        let mut completed_indices = Vec::new();
439        
440        for (index, handle) in connection_handles.iter_mut().enumerate() {
441            if handle.is_finished() {
442                completed_indices.push(index);
443            }
444        }
445        
446        // Process completed handles
447        for &index in completed_indices.iter().rev() {
448            let handle = connection_handles.remove(index);
449            let target_address = target_addresses.remove(index);
450            
451            match tokio::runtime::Handle::try_current() {
452                Ok(runtime_handle) => {
453                    match runtime_handle.block_on(handle) {
454                        Ok(Ok(connection)) => {
455                            // Connection succeeded
456                            info!("QUIC connection established to {} for peer {:?}", target_address, peer_id);
457                            *established_connection = Some(connection);
458                            
459                            events.push(SimpleConnectionEvent::DirectConnectionSucceeded {
460                                peer_id,
461                                address: target_address,
462                            });
463                            
464                            // Cancel remaining connection attempts
465                            for remaining_handle in connection_handles.drain(..) {
466                                remaining_handle.abort();
467                            }
468                            target_addresses.clear();
469                            
470                            return true; // Exit early on success
471                        }
472                        Ok(Err(e)) => {
473                            // Connection failed
474                            warn!("QUIC connection to {} failed: {}", target_address, e);
475                            
476                            events.push(SimpleConnectionEvent::DirectConnectionFailed {
477                                peer_id,
478                                address: target_address,
479                                error: e,
480                            });
481                        }
482                        Err(join_error) => {
483                            // Task panic or cancellation
484                            warn!("QUIC connection task failed: {}", join_error);
485                            
486                            events.push(SimpleConnectionEvent::DirectConnectionFailed {
487                                peer_id,
488                                address: target_address,
489                                error: format!("Task failed: {}", join_error),
490                            });
491                        }
492                    }
493                }
494                Err(_) => {
495                    // No tokio runtime available, can't check result
496                    warn!("Unable to check connection result without tokio runtime");
497                }
498            }
499        }
500        
501        false
502    }
503
504    fn handle_discovery_event(
505        &mut self,
506        discovery_event: DiscoveryEvent,
507        events: &mut Vec<SimpleConnectionEvent>,
508    ) {
509        match discovery_event {
510            DiscoveryEvent::LocalCandidateDiscovered { candidate }
511            | DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. }
512            | DiscoveryEvent::PredictedCandidateGenerated { candidate, .. } => {
513                // Add candidate to relevant attempts
514                for attempt in self.active_attempts.values_mut() {
515                    if attempt.state == SimpleAttemptState::CandidateDiscovery {
516                        attempt.discovered_candidates.push(candidate.clone());
517                    }
518                }
519            }
520            DiscoveryEvent::DiscoveryCompleted { .. } => {
521                // Transition attempts to NAT traversal
522                let peer_ids: Vec<_> = self
523                    .active_attempts
524                    .iter()
525                    .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
526                    .map(|(peer_id, _)| *peer_id)
527                    .collect();
528
529                for peer_id in peer_ids {
530                    if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
531                        attempt.state = SimpleAttemptState::NatTraversal;
532                        events.push(SimpleConnectionEvent::NatTraversalStarted { peer_id });
533                    }
534                }
535            }
536            DiscoveryEvent::DiscoveryFailed { error, .. } => {
537                warn!("Discovery failed: {:?}", error);
538                // Mark relevant attempts as failed
539                let peer_ids: Vec<_> = self
540                    .active_attempts
541                    .iter()
542                    .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
543                    .map(|(peer_id, _)| *peer_id)
544                    .collect();
545
546                for peer_id in peer_ids {
547                    if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
548                        attempt.state = SimpleAttemptState::Failed;
549                        attempt.last_error = Some(format!("Discovery failed: {error:?}"));
550                        events.push(SimpleConnectionEvent::ConnectionFailed {
551                            peer_id,
552                            error: format!("Discovery failed: {error:?}"),
553                        });
554                    }
555                }
556            }
557            _ => {
558                // Handle other events as needed
559            }
560        }
561    }
562
563    fn emit_event(&self, event: SimpleConnectionEvent) {
564        if let Some(ref callback) = self.event_callback {
565            callback(event);
566        }
567    }
568}