1use std::{
7 collections::HashMap,
8 net::SocketAddr,
9 sync::Arc,
10 time::{Duration, Instant},
11};
12
13use tracing::{debug, info, warn};
14
15use crate::{
16 candidate_discovery::{CandidateDiscoveryManager, DiscoveryError, DiscoveryEvent},
17 nat_traversal_api::{BootstrapNode, CandidateAddress, PeerId},
18 high_level::{Endpoint as QuinnEndpoint, Connection as QuinnConnection},
19};
20
21pub struct SimpleConnectionEstablishmentManager {
23 config: SimpleEstablishmentConfig,
25 active_attempts: HashMap<PeerId, SimpleConnectionAttempt>,
27 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
29 bootstrap_nodes: Vec<BootstrapNode>,
31 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
33 quinn_endpoint: Option<Arc<QuinnEndpoint>>,
35}
36
/// Tunables for `SimpleConnectionEstablishmentManager`.
#[derive(Debug, Clone)]
pub struct SimpleEstablishmentConfig {
    /// How long an attempt may stay in the direct-connection stage before it
    /// falls back to NAT traversal (or fails, when NAT traversal is disabled).
    pub direct_timeout: Duration,
    /// Deadline for the later stages, measured from the start of the attempt
    /// (not from the start of the NAT traversal stage).
    pub nat_traversal_timeout: Duration,
    /// Whether to fall back to candidate discovery / NAT traversal.
    pub enable_nat_traversal: bool,
    /// Maximum number of retries.
    /// NOTE(review): not consulted by the manager in this file — confirm intent.
    pub max_retries: u32,
}
49
50#[derive(Debug)]
52struct SimpleConnectionAttempt {
53 peer_id: PeerId,
54 state: SimpleAttemptState,
55 started_at: Instant,
56 attempt_number: u32,
57 known_addresses: Vec<SocketAddr>,
58 discovered_candidates: Vec<CandidateAddress>,
59 last_error: Option<String>,
60 connection_handles: Vec<tokio::task::JoinHandle<Result<QuinnConnection, String>>>,
62 target_addresses: Vec<SocketAddr>,
64 established_connection: Option<QuinnConnection>,
66}
67
/// Stage a connection attempt is currently in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SimpleAttemptState {
    /// Dialling the caller-supplied addresses directly.
    DirectConnection,
    /// Waiting for candidate discovery to finish.
    CandidateDiscovery,
    /// Discovery finished; the attempt is in the NAT traversal stage.
    NatTraversal,
    /// A direct connection was established (terminal).
    Connected,
    /// The attempt gave up (terminal).
    Failed,
}
77
78#[derive(Debug, Clone)]
80pub enum SimpleConnectionEvent {
81 AttemptStarted {
82 peer_id: PeerId,
83 },
84 DirectConnectionTried {
85 peer_id: PeerId,
86 address: SocketAddr,
87 },
88 CandidateDiscoveryStarted {
89 peer_id: PeerId,
90 },
91 NatTraversalStarted {
92 peer_id: PeerId,
93 },
94 DirectConnectionSucceeded {
95 peer_id: PeerId,
96 address: SocketAddr,
97 },
98 DirectConnectionFailed {
99 peer_id: PeerId,
100 address: SocketAddr,
101 error: String,
102 },
103 ConnectionEstablished {
104 peer_id: PeerId,
105 },
106 ConnectionFailed {
107 peer_id: PeerId,
108 error: String,
109 },
110}
111
112impl Default for SimpleEstablishmentConfig {
113 fn default() -> Self {
114 Self {
115 direct_timeout: Duration::from_secs(5),
116 nat_traversal_timeout: Duration::from_secs(30),
117 enable_nat_traversal: true,
118 max_retries: 3,
119 }
120 }
121}
122
123impl SimpleConnectionEstablishmentManager {
124 pub fn new(
126 config: SimpleEstablishmentConfig,
127 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
128 bootstrap_nodes: Vec<BootstrapNode>,
129 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
130 quinn_endpoint: Option<Arc<QuinnEndpoint>>,
131 ) -> Self {
132 Self {
133 config,
134 active_attempts: HashMap::new(),
135 discovery_manager,
136 bootstrap_nodes,
137 event_callback,
138 quinn_endpoint,
139 }
140 }
141
142 pub fn connect_to_peer(
144 &mut self,
145 peer_id: PeerId,
146 known_addresses: Vec<SocketAddr>,
147 ) -> Result<(), String> {
148 if self.active_attempts.contains_key(&peer_id) {
150 return Err("Connection attempt already in progress".to_string());
151 }
152
153 let attempt = SimpleConnectionAttempt {
155 peer_id,
156 state: SimpleAttemptState::DirectConnection,
157 started_at: Instant::now(),
158 attempt_number: 1,
159 known_addresses: known_addresses.clone(),
160 discovered_candidates: Vec::new(),
161 last_error: None,
162 connection_handles: Vec::new(),
163 target_addresses: Vec::new(),
164 established_connection: None,
165 };
166
167 self.active_attempts.insert(peer_id, attempt);
168
169 self.emit_event(SimpleConnectionEvent::AttemptStarted { peer_id });
171
172 if !known_addresses.is_empty() {
174 info!("Starting direct connection attempt to peer {:?}", peer_id);
175
176 if let Some(ref quinn_endpoint) = self.quinn_endpoint {
178 self.start_direct_connections(peer_id, &known_addresses, quinn_endpoint.clone())?;
179 } else {
180 for address in &known_addresses {
182 self.emit_event(SimpleConnectionEvent::DirectConnectionTried {
183 peer_id,
184 address: *address,
185 });
186 }
187 }
188 } else if self.config.enable_nat_traversal {
189 self.start_candidate_discovery(peer_id)?;
191 } else {
192 return Err("No known addresses and NAT traversal disabled".to_string());
193 }
194
195 Ok(())
196 }
197
198 fn start_direct_connections(
200 &mut self,
201 peer_id: PeerId,
202 addresses: &[SocketAddr],
203 endpoint: Arc<QuinnEndpoint>,
204 ) -> Result<(), String> {
205 let attempt = self.active_attempts.get_mut(&peer_id)
206 .ok_or("Attempt not found")?;
207
208 let mut events_to_emit = Vec::new();
210
211 for &address in addresses {
212 let server_name = format!("peer-{:x}", peer_id.0[0] as u32);
213 let endpoint_clone = endpoint.clone();
214
215 let handle = tokio::spawn(async move {
217 let connecting = endpoint_clone.connect(address, &server_name)
218 .map_err(|e| format!("Failed to start connection: {}", e))?;
219
220 match tokio::time::timeout(Duration::from_secs(10), connecting).await {
222 Ok(connection_result) => connection_result
223 .map_err(|e| format!("Connection failed: {}", e)),
224 Err(_) => Err("Connection timed out".to_string()),
225 }
226 });
227
228 attempt.connection_handles.push(handle);
229 attempt.target_addresses.push(address);
230
231 events_to_emit.push(SimpleConnectionEvent::DirectConnectionTried {
233 peer_id,
234 address,
235 });
236 }
237
238 for event in events_to_emit {
240 self.emit_event(event);
241 }
242
243 debug!("Started {} direct connections for peer {:?}", addresses.len(), peer_id);
244 Ok(())
245 }
246
247 pub fn poll(&mut self, now: Instant) -> Vec<SimpleConnectionEvent> {
249 let mut events = Vec::new();
250
251 let discovery_events = match self.discovery_manager.lock() {
253 Ok(mut discovery) => discovery.poll(now),
254 _ => Vec::new(),
255 };
256
257 for discovery_event in discovery_events {
258 self.handle_discovery_event(discovery_event, &mut events);
259 }
260
261 let peer_ids: Vec<_> = self.active_attempts.keys().copied().collect();
263 let mut completed = Vec::new();
264
265 for peer_id in peer_ids {
266 if self.poll_attempt(peer_id, now, &mut events) {
267 completed.push(peer_id);
268 }
269 }
270
271 for peer_id in completed {
273 self.active_attempts.remove(&peer_id);
274 }
275
276 events
277 }
278
279 pub fn cancel_connection(&mut self, peer_id: PeerId) -> bool {
281 self.active_attempts.remove(&peer_id).is_some()
282 }
283
284 fn start_candidate_discovery(&mut self, peer_id: PeerId) -> Result<(), String> {
287 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
288 attempt.state = SimpleAttemptState::CandidateDiscovery;
289
290 match self.discovery_manager.lock() {
291 Ok(mut discovery) => {
292 discovery
293 .start_discovery(peer_id, self.bootstrap_nodes.clone())
294 .map_err(|e| format!("Discovery failed: {e:?}"))?;
295 }
296 _ => {
297 return Err("Failed to lock discovery manager".to_string());
298 }
299 }
300
301 self.emit_event(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
302 }
303
304 Ok(())
305 }
306
307 fn poll_attempt(
308 &mut self,
309 peer_id: PeerId,
310 now: Instant,
311 events: &mut Vec<SimpleConnectionEvent>,
312 ) -> bool {
313 let attempt = match self.active_attempts.get_mut(&peer_id) {
314 Some(a) => a,
315 None => return true,
316 };
317
318 let elapsed = now.duration_since(attempt.started_at);
319 let timeout = match attempt.state {
320 SimpleAttemptState::DirectConnection => self.config.direct_timeout,
321 _ => self.config.nat_traversal_timeout,
322 };
323
324 if elapsed > timeout {
326 match attempt.state {
327 SimpleAttemptState::DirectConnection if self.config.enable_nat_traversal => {
328 info!(
330 "Direct connection timed out for peer {:?}, starting NAT traversal",
331 peer_id
332 );
333 attempt.state = SimpleAttemptState::CandidateDiscovery;
334
335 let discovery_result = match self.discovery_manager.lock() {
337 Ok(mut discovery) => {
338 discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
339 }
340 _ => Err(DiscoveryError::InternalError(
341 "Failed to lock discovery manager".to_string(),
342 )),
343 };
344
345 if let Err(e) = discovery_result {
346 attempt.state = SimpleAttemptState::Failed;
347 attempt.last_error = Some(format!("Discovery failed: {e:?}"));
348 events.push(SimpleConnectionEvent::ConnectionFailed {
349 peer_id,
350 error: format!("Discovery failed: {e:?}"),
351 });
352 return true;
353 }
354
355 events.push(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
356 return false;
357 }
358 _ => {
359 attempt.state = SimpleAttemptState::Failed;
361 attempt.last_error = Some("Timeout exceeded".to_string());
362 events.push(SimpleConnectionEvent::ConnectionFailed {
363 peer_id,
364 error: "Timeout exceeded".to_string(),
365 });
366 return true;
367 }
368 }
369 }
370
371 let has_connection_handles = !attempt.connection_handles.is_empty();
373
374 if has_connection_handles {
375 let mut connection_handles = std::mem::take(&mut attempt.connection_handles);
377 let mut target_addresses = std::mem::take(&mut attempt.target_addresses);
378 let mut established_connection = attempt.established_connection.take();
379
380 Self::poll_connection_handles_extracted(
381 peer_id,
382 &mut connection_handles,
383 &mut target_addresses,
384 &mut established_connection,
385 events,
386 );
387
388 attempt.connection_handles = connection_handles;
390 attempt.target_addresses = target_addresses;
391 attempt.established_connection = established_connection;
392
393 if attempt.established_connection.is_some() {
395 attempt.state = SimpleAttemptState::Connected;
396 events.push(SimpleConnectionEvent::ConnectionEstablished { peer_id });
397 return true;
398 }
399 }
400
401 match attempt.state {
402 SimpleAttemptState::DirectConnection => {
403 if attempt.connection_handles.is_empty() ||
405 attempt.connection_handles.iter().all(|h| h.is_finished()) {
406 if attempt.established_connection.is_none() && self.config.enable_nat_traversal {
408 debug!("Direct connections failed for peer {:?}, trying NAT traversal", peer_id);
410 attempt.state = SimpleAttemptState::CandidateDiscovery;
411 }
413 }
414 }
415 SimpleAttemptState::CandidateDiscovery => {
416 }
418 SimpleAttemptState::NatTraversal => {
419 debug!("Polling NAT traversal attempts for peer {:?}", peer_id);
421 }
422 SimpleAttemptState::Connected | SimpleAttemptState::Failed => {
423 return true;
424 }
425 }
426
427 false
428 }
429
430 fn poll_connection_handles_extracted(
432 peer_id: PeerId,
433 connection_handles: &mut Vec<tokio::task::JoinHandle<Result<QuinnConnection, String>>>,
434 target_addresses: &mut Vec<SocketAddr>,
435 established_connection: &mut Option<QuinnConnection>,
436 events: &mut Vec<SimpleConnectionEvent>,
437 ) -> bool {
438 let mut completed_indices = Vec::new();
439
440 for (index, handle) in connection_handles.iter_mut().enumerate() {
441 if handle.is_finished() {
442 completed_indices.push(index);
443 }
444 }
445
446 for &index in completed_indices.iter().rev() {
448 let handle = connection_handles.remove(index);
449 let target_address = target_addresses.remove(index);
450
451 match tokio::runtime::Handle::try_current() {
452 Ok(runtime_handle) => {
453 match runtime_handle.block_on(handle) {
454 Ok(Ok(connection)) => {
455 info!("QUIC connection established to {} for peer {:?}", target_address, peer_id);
457 *established_connection = Some(connection);
458
459 events.push(SimpleConnectionEvent::DirectConnectionSucceeded {
460 peer_id,
461 address: target_address,
462 });
463
464 for remaining_handle in connection_handles.drain(..) {
466 remaining_handle.abort();
467 }
468 target_addresses.clear();
469
470 return true; }
472 Ok(Err(e)) => {
473 warn!("QUIC connection to {} failed: {}", target_address, e);
475
476 events.push(SimpleConnectionEvent::DirectConnectionFailed {
477 peer_id,
478 address: target_address,
479 error: e,
480 });
481 }
482 Err(join_error) => {
483 warn!("QUIC connection task failed: {}", join_error);
485
486 events.push(SimpleConnectionEvent::DirectConnectionFailed {
487 peer_id,
488 address: target_address,
489 error: format!("Task failed: {}", join_error),
490 });
491 }
492 }
493 }
494 Err(_) => {
495 warn!("Unable to check connection result without tokio runtime");
497 }
498 }
499 }
500
501 false
502 }
503
504 fn handle_discovery_event(
505 &mut self,
506 discovery_event: DiscoveryEvent,
507 events: &mut Vec<SimpleConnectionEvent>,
508 ) {
509 match discovery_event {
510 DiscoveryEvent::LocalCandidateDiscovered { candidate }
511 | DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. }
512 | DiscoveryEvent::PredictedCandidateGenerated { candidate, .. } => {
513 for attempt in self.active_attempts.values_mut() {
515 if attempt.state == SimpleAttemptState::CandidateDiscovery {
516 attempt.discovered_candidates.push(candidate.clone());
517 }
518 }
519 }
520 DiscoveryEvent::DiscoveryCompleted { .. } => {
521 let peer_ids: Vec<_> = self
523 .active_attempts
524 .iter()
525 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
526 .map(|(peer_id, _)| *peer_id)
527 .collect();
528
529 for peer_id in peer_ids {
530 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
531 attempt.state = SimpleAttemptState::NatTraversal;
532 events.push(SimpleConnectionEvent::NatTraversalStarted { peer_id });
533 }
534 }
535 }
536 DiscoveryEvent::DiscoveryFailed { error, .. } => {
537 warn!("Discovery failed: {:?}", error);
538 let peer_ids: Vec<_> = self
540 .active_attempts
541 .iter()
542 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
543 .map(|(peer_id, _)| *peer_id)
544 .collect();
545
546 for peer_id in peer_ids {
547 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
548 attempt.state = SimpleAttemptState::Failed;
549 attempt.last_error = Some(format!("Discovery failed: {error:?}"));
550 events.push(SimpleConnectionEvent::ConnectionFailed {
551 peer_id,
552 error: format!("Discovery failed: {error:?}"),
553 });
554 }
555 }
556 }
557 _ => {
558 }
560 }
561 }
562
563 fn emit_event(&self, event: SimpleConnectionEvent) {
564 if let Some(ref callback) = self.event_callback {
565 callback(event);
566 }
567 }
568}