ant_quic/
connection_establishment_simple.rs

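//! Simplified Connection Establishment with Automatic NAT Traversal
//!
//! This module provides a simplified connection establishment state machine
//! that tries direct connections first and automatically falls back to
//! candidate discovery and NAT traversal. The individual connection steps are
//! currently simulated rather than performed on the network.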
5
6use std::{
7    collections::HashMap,
8    net::SocketAddr,
9    sync::Arc,
10    time::{Duration, Instant},
11};
12
13use tracing::{debug, info, warn};
14
15use crate::{
16    candidate_discovery::{CandidateDiscoveryManager, DiscoveryError, DiscoveryEvent},
17    nat_traversal_api::{BootstrapNode, CandidateAddress, PeerId},
18};
19
20/// Simplified connection establishment manager
21pub struct SimpleConnectionEstablishmentManager {
22    /// Configuration for connection establishment
23    config: SimpleEstablishmentConfig,
24    /// Active connection attempts  
25    active_attempts: HashMap<PeerId, SimpleConnectionAttempt>,
26    /// Candidate discovery manager
27    discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
28    /// Known bootstrap nodes
29    bootstrap_nodes: Vec<BootstrapNode>,
30    /// Event callback
31    event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
32}
33
/// Tunables for [`SimpleConnectionEstablishmentManager`].
///
/// See the `Default` impl for the stock values. The config is comparable
/// (`PartialEq`/`Eq`), so callers and tests can check which settings are in use.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimpleEstablishmentConfig {
    /// How long an attempt may stay in the direct-connection phase. When it
    /// expires, the attempt falls back to NAT traversal (if enabled) or fails.
    pub direct_timeout: Duration,
    /// Timeout applied while an attempt is in candidate discovery or NAT traversal.
    pub nat_traversal_timeout: Duration,
    /// Whether to fall back to candidate discovery / NAT traversal when there are
    /// no addresses to dial, or when the direct phase times out.
    pub enable_nat_traversal: bool,
    /// Maximum retry attempts.
    ///
    /// NOTE(review): not currently consulted by the connection manager; attempts
    /// that time out are not retried.
    pub max_retries: u32,
}
46
47/// Simplified connection attempt state
48#[derive(Debug)]
49struct SimpleConnectionAttempt {
50    peer_id: PeerId,
51    state: SimpleAttemptState,
52    started_at: Instant,
53    attempt_number: u32,
54    known_addresses: Vec<SocketAddr>,
55    discovered_candidates: Vec<CandidateAddress>,
56    last_error: Option<String>,
57}
58
/// Phases of a simplified connection attempt.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SimpleAttemptState {
    /// Dialing caller-supplied addresses; bounded by the direct-connection timeout.
    DirectConnection,
    /// Waiting for the discovery manager to report candidates or completion.
    CandidateDiscovery,
    /// Entered once discovery completes; bounded by the NAT traversal timeout.
    NatTraversal,
    /// Terminal success state (not entered by the simulated flow in this module).
    Connected,
    /// Terminal failure state; the attempt is removed on the next poll.
    Failed,
}
68
69/// Simplified events
70#[derive(Debug, Clone)]
71pub enum SimpleConnectionEvent {
72    AttemptStarted {
73        peer_id: PeerId,
74    },
75    DirectConnectionTried {
76        peer_id: PeerId,
77        address: SocketAddr,
78    },
79    CandidateDiscoveryStarted {
80        peer_id: PeerId,
81    },
82    NatTraversalStarted {
83        peer_id: PeerId,
84    },
85    ConnectionEstablished {
86        peer_id: PeerId,
87        address: SocketAddr,
88    },
89    ConnectionFailed {
90        peer_id: PeerId,
91        error: String,
92    },
93}
94
95impl Default for SimpleEstablishmentConfig {
96    fn default() -> Self {
97        Self {
98            direct_timeout: Duration::from_secs(5),
99            nat_traversal_timeout: Duration::from_secs(30),
100            enable_nat_traversal: true,
101            max_retries: 3,
102        }
103    }
104}
105
106impl SimpleConnectionEstablishmentManager {
107    /// Create a new simplified connection establishment manager
108    pub fn new(
109        config: SimpleEstablishmentConfig,
110        discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
111        bootstrap_nodes: Vec<BootstrapNode>,
112        event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
113    ) -> Self {
114        Self {
115            config,
116            active_attempts: HashMap::new(),
117            discovery_manager,
118            bootstrap_nodes,
119            event_callback,
120        }
121    }
122
123    /// Start connection to peer
124    pub fn connect_to_peer(
125        &mut self,
126        peer_id: PeerId,
127        known_addresses: Vec<SocketAddr>,
128    ) -> Result<(), String> {
129        // Check if already attempting
130        if self.active_attempts.contains_key(&peer_id) {
131            return Err("Connection attempt already in progress".to_string());
132        }
133
134        // Create new attempt
135        let attempt = SimpleConnectionAttempt {
136            peer_id,
137            state: SimpleAttemptState::DirectConnection,
138            started_at: Instant::now(),
139            attempt_number: 1,
140            known_addresses: known_addresses.clone(),
141            discovered_candidates: Vec::new(),
142            last_error: None,
143        };
144
145        self.active_attempts.insert(peer_id, attempt);
146
147        // Emit event
148        self.emit_event(SimpleConnectionEvent::AttemptStarted { peer_id });
149
150        // Try direct connection first if we have addresses
151        if !known_addresses.is_empty() {
152            info!("Starting direct connection attempt to peer {:?}", peer_id);
153            for address in &known_addresses {
154                self.emit_event(SimpleConnectionEvent::DirectConnectionTried {
155                    peer_id,
156                    address: *address,
157                });
158            }
159        } else if self.config.enable_nat_traversal {
160            // Start candidate discovery immediately
161            self.start_candidate_discovery(peer_id)?;
162        } else {
163            return Err("No known addresses and NAT traversal disabled".to_string());
164        }
165
166        Ok(())
167    }
168
169    /// Poll for progress
170    pub fn poll(&mut self, now: Instant) -> Vec<SimpleConnectionEvent> {
171        let mut events = Vec::new();
172
173        // Process discovery events
174        let discovery_events = if let Ok(mut discovery) = self.discovery_manager.lock() {
175            discovery.poll(now)
176        } else {
177            Vec::new()
178        };
179
180        for discovery_event in discovery_events {
181            self.handle_discovery_event(discovery_event, &mut events);
182        }
183
184        // Process active attempts
185        let peer_ids: Vec<_> = self.active_attempts.keys().copied().collect();
186        let mut completed = Vec::new();
187
188        for peer_id in peer_ids {
189            if self.poll_attempt(peer_id, now, &mut events) {
190                completed.push(peer_id);
191            }
192        }
193
194        // Remove completed attempts
195        for peer_id in completed {
196            self.active_attempts.remove(&peer_id);
197        }
198
199        events
200    }
201
202    /// Cancel connection attempt
203    pub fn cancel_connection(&mut self, peer_id: PeerId) -> bool {
204        self.active_attempts.remove(&peer_id).is_some()
205    }
206
207    // Private methods
208
209    fn start_candidate_discovery(&mut self, peer_id: PeerId) -> Result<(), String> {
210        if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
211            attempt.state = SimpleAttemptState::CandidateDiscovery;
212
213            if let Ok(mut discovery) = self.discovery_manager.lock() {
214                discovery
215                    .start_discovery(peer_id, self.bootstrap_nodes.clone())
216                    .map_err(|e| format!("Discovery failed: {:?}", e))?;
217            } else {
218                return Err("Failed to lock discovery manager".to_string());
219            }
220
221            self.emit_event(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
222        }
223
224        Ok(())
225    }
226
227    fn poll_attempt(
228        &mut self,
229        peer_id: PeerId,
230        now: Instant,
231        events: &mut Vec<SimpleConnectionEvent>,
232    ) -> bool {
233        let attempt = match self.active_attempts.get_mut(&peer_id) {
234            Some(a) => a,
235            None => return true,
236        };
237
238        let elapsed = now.duration_since(attempt.started_at);
239        let timeout = match attempt.state {
240            SimpleAttemptState::DirectConnection => self.config.direct_timeout,
241            _ => self.config.nat_traversal_timeout,
242        };
243
244        // Check timeout
245        if elapsed > timeout {
246            match attempt.state {
247                SimpleAttemptState::DirectConnection if self.config.enable_nat_traversal => {
248                    // Fallback to NAT traversal
249                    info!(
250                        "Direct connection timed out for peer {:?}, starting NAT traversal",
251                        peer_id
252                    );
253                    attempt.state = SimpleAttemptState::CandidateDiscovery;
254
255                    // Start discovery outside of the borrow
256                    let discovery_result = if let Ok(mut discovery) = self.discovery_manager.lock()
257                    {
258                        discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
259                    } else {
260                        Err(DiscoveryError::InternalError(
261                            "Failed to lock discovery manager".to_string(),
262                        ))
263                    };
264
265                    if let Err(e) = discovery_result {
266                        attempt.state = SimpleAttemptState::Failed;
267                        attempt.last_error = Some(format!("Discovery failed: {:?}", e));
268                        events.push(SimpleConnectionEvent::ConnectionFailed {
269                            peer_id,
270                            error: format!("Discovery failed: {:?}", e),
271                        });
272                        return true;
273                    }
274
275                    events.push(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
276                    return false;
277                }
278                _ => {
279                    // Timeout, mark as failed
280                    attempt.state = SimpleAttemptState::Failed;
281                    attempt.last_error = Some("Timeout exceeded".to_string());
282                    events.push(SimpleConnectionEvent::ConnectionFailed {
283                        peer_id,
284                        error: "Timeout exceeded".to_string(),
285                    });
286                    return true;
287                }
288            }
289        }
290
291        // Simulate connection establishment for testing
292        match attempt.state {
293            SimpleAttemptState::DirectConnection => {
294                // Simulate direct connection attempt
295                debug!("Simulating direct connection attempt to peer {:?}", peer_id);
296            }
297            SimpleAttemptState::CandidateDiscovery => {
298                // Wait for discovery events
299            }
300            SimpleAttemptState::NatTraversal => {
301                // Simulate NAT traversal attempt
302                debug!("Simulating NAT traversal attempt to peer {:?}", peer_id);
303            }
304            SimpleAttemptState::Connected | SimpleAttemptState::Failed => {
305                return true;
306            }
307        }
308
309        false
310    }
311
312    fn handle_discovery_event(
313        &mut self,
314        discovery_event: DiscoveryEvent,
315        events: &mut Vec<SimpleConnectionEvent>,
316    ) {
317        match discovery_event {
318            DiscoveryEvent::LocalCandidateDiscovered { candidate }
319            | DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. }
320            | DiscoveryEvent::PredictedCandidateGenerated { candidate, .. } => {
321                // Add candidate to relevant attempts
322                for attempt in self.active_attempts.values_mut() {
323                    if attempt.state == SimpleAttemptState::CandidateDiscovery {
324                        attempt.discovered_candidates.push(candidate.clone());
325                    }
326                }
327            }
328            DiscoveryEvent::DiscoveryCompleted { .. } => {
329                // Transition attempts to NAT traversal
330                let peer_ids: Vec<_> = self
331                    .active_attempts
332                    .iter()
333                    .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
334                    .map(|(peer_id, _)| *peer_id)
335                    .collect();
336
337                for peer_id in peer_ids {
338                    if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
339                        attempt.state = SimpleAttemptState::NatTraversal;
340                        events.push(SimpleConnectionEvent::NatTraversalStarted { peer_id });
341                    }
342                }
343            }
344            DiscoveryEvent::DiscoveryFailed { error, .. } => {
345                warn!("Discovery failed: {:?}", error);
346                // Mark relevant attempts as failed
347                let peer_ids: Vec<_> = self
348                    .active_attempts
349                    .iter()
350                    .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
351                    .map(|(peer_id, _)| *peer_id)
352                    .collect();
353
354                for peer_id in peer_ids {
355                    if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
356                        attempt.state = SimpleAttemptState::Failed;
357                        attempt.last_error = Some(format!("Discovery failed: {:?}", error));
358                        events.push(SimpleConnectionEvent::ConnectionFailed {
359                            peer_id,
360                            error: format!("Discovery failed: {:?}", error),
361                        });
362                    }
363                }
364            }
365            _ => {
366                // Handle other events as needed
367            }
368        }
369    }
370
371    fn emit_event(&self, event: SimpleConnectionEvent) {
372        if let Some(ref callback) = self.event_callback {
373            callback(event);
374        }
375    }
376}