ant_quic/
connection_establishment_simple.rs

1//! Simplified Connection Establishment with Automatic NAT Traversal
2//!
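//! This module provides a simplified connection establishment manager that first
//! tries a peer's known addresses directly and, when enabled, automatically falls
//! back to NAT traversal via candidate discovery.
//!
//! NOTE: connection success is currently simulated inside the polling logic rather
//! than driven by real connection state.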
5
6use std::{
7    collections::HashMap,
8    net::SocketAddr,
9    sync::Arc,
10    time::{Duration, Instant},
11};
12
13use tracing::{info, warn};
14
15use crate::{
16    candidate_discovery::{CandidateDiscoveryManager, DiscoveryEvent, DiscoveryError},
17    nat_traversal_api::{BootstrapNode, CandidateAddress, PeerId},
18};
19
20/// Simplified connection establishment manager
21pub struct SimpleConnectionEstablishmentManager {
22    /// Configuration for connection establishment
23    config: SimpleEstablishmentConfig,
24    /// Active connection attempts  
25    active_attempts: HashMap<PeerId, SimpleConnectionAttempt>,
26    /// Candidate discovery manager
27    discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
28    /// Known bootstrap nodes
29    bootstrap_nodes: Vec<BootstrapNode>,
30    /// Event callback
31    event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
32}
33
/// Tunables for [`SimpleConnectionEstablishmentManager`].
///
/// `Default` yields a 5 s direct timeout, a 30 s NAT-traversal timeout,
/// NAT traversal enabled and `max_retries = 3`.
#[derive(Debug, Clone)]
pub struct SimpleEstablishmentConfig {
    /// How long an attempt may stay in the direct-connection phase before
    /// falling back (or failing, when NAT traversal is disabled).
    pub direct_timeout: Duration,
    /// Time limit applied to the candidate-discovery and NAT-traversal phases.
    pub nat_traversal_timeout: Duration,
    /// Whether to fall back to candidate discovery / NAT traversal at all.
    pub enable_nat_traversal: bool,
    /// Maximum retry attempts.
    ///
    /// NOTE(review): not currently consulted by the manager in this module.
    pub max_retries: u32,
}
46
47/// Simplified connection attempt state
48#[derive(Debug)]
49struct SimpleConnectionAttempt {
50    peer_id: PeerId,
51    state: SimpleAttemptState,
52    started_at: Instant,
53    attempt_number: u32,
54    known_addresses: Vec<SocketAddr>,
55    discovered_candidates: Vec<CandidateAddress>,
56    last_error: Option<String>,
57}
58
/// Phases an attempt moves through; `Connected` and `Failed` are terminal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SimpleAttemptState {
    /// Trying the caller-supplied addresses directly.
    DirectConnection,
    /// Waiting for the discovery manager to report candidates.
    CandidateDiscovery,
    /// Discovery finished; attempting to connect via discovered candidates.
    NatTraversal,
    /// A connection was established (terminal).
    Connected,
    /// The attempt gave up (terminal).
    Failed,
}
68
69/// Simplified events
70#[derive(Debug, Clone)]
71pub enum SimpleConnectionEvent {
72    AttemptStarted { peer_id: PeerId },
73    DirectConnectionTried { peer_id: PeerId, address: SocketAddr },
74    CandidateDiscoveryStarted { peer_id: PeerId },
75    NatTraversalStarted { peer_id: PeerId },
76    ConnectionEstablished { peer_id: PeerId, address: SocketAddr },
77    ConnectionFailed { peer_id: PeerId, error: String },
78}
79
80impl Default for SimpleEstablishmentConfig {
81    fn default() -> Self {
82        Self {
83            direct_timeout: Duration::from_secs(5),
84            nat_traversal_timeout: Duration::from_secs(30),
85            enable_nat_traversal: true,
86            max_retries: 3,
87        }
88    }
89}
90
91impl SimpleConnectionEstablishmentManager {
92    /// Create a new simplified connection establishment manager
93    pub fn new(
94        config: SimpleEstablishmentConfig,
95        discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
96        bootstrap_nodes: Vec<BootstrapNode>,
97        event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
98    ) -> Self {
99        Self {
100            config,
101            active_attempts: HashMap::new(),
102            discovery_manager,
103            bootstrap_nodes,
104            event_callback,
105        }
106    }
107
108    /// Start connection to peer
109    pub fn connect_to_peer(
110        &mut self,
111        peer_id: PeerId,
112        known_addresses: Vec<SocketAddr>,
113    ) -> Result<(), String> {
114        // Check if already attempting
115        if self.active_attempts.contains_key(&peer_id) {
116            return Err("Connection attempt already in progress".to_string());
117        }
118
119        // Create new attempt
120        let attempt = SimpleConnectionAttempt {
121            peer_id,
122            state: SimpleAttemptState::DirectConnection,
123            started_at: Instant::now(),
124            attempt_number: 1,
125            known_addresses: known_addresses.clone(),
126            discovered_candidates: Vec::new(),
127            last_error: None,
128        };
129
130        self.active_attempts.insert(peer_id, attempt);
131
132        // Emit event
133        self.emit_event(SimpleConnectionEvent::AttemptStarted { peer_id });
134
135        // Try direct connection first if we have addresses
136        if !known_addresses.is_empty() {
137            info!("Starting direct connection attempt to peer {:?}", peer_id);
138            for address in &known_addresses {
139                self.emit_event(SimpleConnectionEvent::DirectConnectionTried {
140                    peer_id,
141                    address: *address,
142                });
143            }
144        } else if self.config.enable_nat_traversal {
145            // Start candidate discovery immediately
146            self.start_candidate_discovery(peer_id)?;
147        } else {
148            return Err("No known addresses and NAT traversal disabled".to_string());
149        }
150
151        Ok(())
152    }
153
154    /// Poll for progress
155    pub fn poll(&mut self, now: Instant) -> Vec<SimpleConnectionEvent> {
156        let mut events = Vec::new();
157
158        // Process discovery events
159        let discovery_events = if let Ok(mut discovery) = self.discovery_manager.lock() {
160            discovery.poll(now)
161        } else {
162            Vec::new()
163        };
164
165        for discovery_event in discovery_events {
166            self.handle_discovery_event(discovery_event, &mut events);
167        }
168
169        // Process active attempts
170        let peer_ids: Vec<_> = self.active_attempts.keys().copied().collect();
171        let mut completed = Vec::new();
172
173        for peer_id in peer_ids {
174            if self.poll_attempt(peer_id, now, &mut events) {
175                completed.push(peer_id);
176            }
177        }
178
179        // Remove completed attempts
180        for peer_id in completed {
181            self.active_attempts.remove(&peer_id);
182        }
183
184        events
185    }
186
187    /// Cancel connection attempt
188    pub fn cancel_connection(&mut self, peer_id: PeerId) -> bool {
189        self.active_attempts.remove(&peer_id).is_some()
190    }
191
192    // Private methods
193
194    fn start_candidate_discovery(&mut self, peer_id: PeerId) -> Result<(), String> {
195        if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
196            attempt.state = SimpleAttemptState::CandidateDiscovery;
197
198            if let Ok(mut discovery) = self.discovery_manager.lock() {
199                discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
200                    .map_err(|e| format!("Discovery failed: {:?}", e))?;
201            } else {
202                return Err("Failed to lock discovery manager".to_string());
203            }
204
205            self.emit_event(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
206        }
207
208        Ok(())
209    }
210
211    fn poll_attempt(
212        &mut self,
213        peer_id: PeerId,
214        now: Instant,
215        events: &mut Vec<SimpleConnectionEvent>,
216    ) -> bool {
217        let should_complete = {
218            let attempt = match self.active_attempts.get_mut(&peer_id) {
219                Some(a) => a,
220                None => return true,
221            };
222
223            let elapsed = now.duration_since(attempt.started_at);
224            let timeout = match attempt.state {
225                SimpleAttemptState::DirectConnection => self.config.direct_timeout,
226                _ => self.config.nat_traversal_timeout,
227            };
228
229            // Check timeout
230            if elapsed > timeout {
231                match attempt.state {
232                    SimpleAttemptState::DirectConnection if self.config.enable_nat_traversal => {
233                        // Fallback to NAT traversal
234                        info!("Direct connection timed out for peer {:?}, starting NAT traversal", peer_id);
235                        attempt.state = SimpleAttemptState::CandidateDiscovery;
236                        
237                        // Start discovery outside of the borrow
238                        let discovery_result = if let Ok(mut discovery) = self.discovery_manager.lock() {
239                            discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
240                        } else {
241                            Err(DiscoveryError::InternalError("Failed to lock discovery manager".to_string()))
242                        };
243                        
244                        if let Err(e) = discovery_result {
245                            attempt.state = SimpleAttemptState::Failed;
246                            attempt.last_error = Some(format!("Discovery failed: {:?}", e));
247                            events.push(SimpleConnectionEvent::ConnectionFailed {
248                                peer_id,
249                                error: format!("Discovery failed: {:?}", e),
250                            });
251                            return true;
252                        }
253                        
254                        events.push(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
255                        return false;
256                    }
257                    _ => {
258                        // Timeout, mark as failed
259                        attempt.state = SimpleAttemptState::Failed;
260                        attempt.last_error = Some("Timeout exceeded".to_string());
261                        events.push(SimpleConnectionEvent::ConnectionFailed {
262                            peer_id,
263                            error: "Timeout exceeded".to_string(),
264                        });
265                        return true;
266                    }
267                }
268            }
269
270            // Simulate connection success for testing
271            // In real implementation, this would check actual connection status
272            match attempt.state {
273                SimpleAttemptState::DirectConnection => {
274                    if elapsed > Duration::from_secs(2) {
275                        // Simulate some direct connections succeeding
276                        if !attempt.known_addresses.is_empty() {
277                            attempt.state = SimpleAttemptState::Connected;
278                            events.push(SimpleConnectionEvent::ConnectionEstablished {
279                                peer_id,
280                                address: attempt.known_addresses[0],
281                            });
282                            return true;
283                        }
284                    }
285                }
286                SimpleAttemptState::CandidateDiscovery => {
287                    // Wait for discovery events
288                }
289                SimpleAttemptState::NatTraversal => {
290                    if elapsed > Duration::from_secs(5) {
291                        // Simulate NAT traversal success
292                        if !attempt.discovered_candidates.is_empty() {
293                            attempt.state = SimpleAttemptState::Connected;
294                            events.push(SimpleConnectionEvent::ConnectionEstablished {
295                                peer_id,
296                                address: attempt.discovered_candidates[0].address,
297                            });
298                            return true;
299                        }
300                    }
301                }
302                SimpleAttemptState::Connected | SimpleAttemptState::Failed => {
303                    return true;
304                }
305            }
306
307            false
308        };
309
310        should_complete
311    }
312
313    fn handle_discovery_event(
314        &mut self,
315        discovery_event: DiscoveryEvent,
316        events: &mut Vec<SimpleConnectionEvent>,
317    ) {
318        match discovery_event {
319            DiscoveryEvent::LocalCandidateDiscovered { candidate } |
320            DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. } |
321            DiscoveryEvent::PredictedCandidateGenerated { candidate, .. } => {
322                // Add candidate to relevant attempts
323                for attempt in self.active_attempts.values_mut() {
324                    if attempt.state == SimpleAttemptState::CandidateDiscovery {
325                        attempt.discovered_candidates.push(candidate.clone());
326                    }
327                }
328            }
329            DiscoveryEvent::DiscoveryCompleted { .. } => {
330                // Transition attempts to NAT traversal
331                let peer_ids: Vec<_> = self.active_attempts.iter()
332                    .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
333                    .map(|(peer_id, _)| *peer_id)
334                    .collect();
335
336                for peer_id in peer_ids {
337                    if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
338                        attempt.state = SimpleAttemptState::NatTraversal;
339                        events.push(SimpleConnectionEvent::NatTraversalStarted { peer_id });
340                    }
341                }
342            }
343            DiscoveryEvent::DiscoveryFailed { error, .. } => {
344                warn!("Discovery failed: {:?}", error);
345                // Mark relevant attempts as failed
346                let peer_ids: Vec<_> = self.active_attempts.iter()
347                    .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
348                    .map(|(peer_id, _)| *peer_id)
349                    .collect();
350
351                for peer_id in peer_ids {
352                    if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
353                        attempt.state = SimpleAttemptState::Failed;
354                        attempt.last_error = Some(format!("Discovery failed: {:?}", error));
355                        events.push(SimpleConnectionEvent::ConnectionFailed {
356                            peer_id,
357                            error: format!("Discovery failed: {:?}", error),
358                        });
359                    }
360                }
361            }
362            _ => {
363                // Handle other events as needed
364            }
365        }
366    }
367
368    fn emit_event(&self, event: SimpleConnectionEvent) {
369        if let Some(ref callback) = self.event_callback {
370            callback(event);
371        }
372    }
373
374    /// Get current status
375    pub fn get_status(&self) -> SimpleConnectionStatus {
376        let mut direct_attempts = 0;
377        let mut nat_traversal_attempts = 0;
378        let mut connected = 0;
379        let mut failed = 0;
380
381        for attempt in self.active_attempts.values() {
382            match attempt.state {
383                SimpleAttemptState::DirectConnection => direct_attempts += 1,
384                SimpleAttemptState::CandidateDiscovery | SimpleAttemptState::NatTraversal => {
385                    nat_traversal_attempts += 1
386                }
387                SimpleAttemptState::Connected => connected += 1,
388                SimpleAttemptState::Failed => failed += 1,
389            }
390        }
391
392        SimpleConnectionStatus {
393            total_attempts: self.active_attempts.len(),
394            direct_attempts,
395            nat_traversal_attempts,
396            connected,
397            failed,
398        }
399    }
400}
401
/// Point-in-time counts of active connection attempts, grouped by phase.
///
/// Derives `Copy`, `Default`, `PartialEq` and `Eq` so snapshots are cheap to pass
/// around and easy to compare (e.g. in tests).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SimpleConnectionStatus {
    /// Total number of attempts currently tracked.
    pub total_attempts: usize,
    /// Attempts in the direct-connection phase.
    pub direct_attempts: usize,
    /// Attempts in candidate discovery or NAT traversal.
    pub nat_traversal_attempts: usize,
    /// Attempts that have connected but have not yet been removed.
    pub connected: usize,
    /// Attempts that have failed but have not yet been removed.
    pub failed: usize,
}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager with the default config, a client-role discovery manager,
    /// no bootstrap nodes and no event callback.
    fn default_manager() -> SimpleConnectionEstablishmentManager {
        let discovery = crate::candidate_discovery::CandidateDiscoveryManager::new(
            crate::candidate_discovery::DiscoveryConfig::default(),
            crate::connection::nat_traversal::NatTraversalRole::Client,
        );

        SimpleConnectionEstablishmentManager::new(
            SimpleEstablishmentConfig::default(),
            Arc::new(std::sync::Mutex::new(discovery)),
            Vec::new(),
            None,
        )
    }

    #[test]
    fn test_simple_connection_manager_creation() {
        let _manager = default_manager();
    }

    #[test]
    fn test_connect_to_peer() {
        let mut manager = default_manager();
        let peer_id = PeerId([1; 32]);
        let loopback = SocketAddr::from(([127, 0, 0, 1], 8080));

        assert!(manager.connect_to_peer(peer_id, vec![loopback]).is_ok());

        // A second attempt for the same peer is rejected while the first is active.
        assert!(manager.connect_to_peer(peer_id, Vec::new()).is_err());
    }
}
459}