1use std::{
7 collections::HashMap,
8 net::SocketAddr,
9 sync::Arc,
10 time::{Duration, Instant},
11};
12
13use tracing::{debug, info, warn};
14
15use crate::{
16 candidate_discovery::{CandidateDiscoveryManager, DiscoveryError, DiscoveryEvent},
17 high_level::{Connection as QuinnConnection, Endpoint as QuinnEndpoint},
18 nat_traversal_api::{BootstrapNode, CandidateAddress, PeerId},
19};
20
21pub struct SimpleConnectionEstablishmentManager {
23 config: SimpleEstablishmentConfig,
25 active_attempts: HashMap<PeerId, SimpleConnectionAttempt>,
27 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
29 bootstrap_nodes: Vec<BootstrapNode>,
31 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
33 quinn_endpoint: Option<Arc<QuinnEndpoint>>,
35}
36
/// Tunables for [`SimpleConnectionEstablishmentManager`].
#[derive(Clone, Debug)]
pub struct SimpleEstablishmentConfig {
    /// Budget for the direct-connection phase, measured from the start of the attempt.
    pub direct_timeout: Duration,
    /// Budget for every later phase (discovery / NAT traversal), also measured from the
    /// start of the attempt.
    pub nat_traversal_timeout: Duration,
    /// Whether to fall back to candidate discovery when direct dials do not succeed.
    pub enable_nat_traversal: bool,
    /// Maximum number of retries.
    ///
    /// NOTE(review): not read by the manager code in this file — confirm whether retry
    /// logic lives elsewhere.
    pub max_retries: u32,
}
49
50#[derive(Debug)]
52struct SimpleConnectionAttempt {
53 peer_id: PeerId,
54 state: SimpleAttemptState,
55 started_at: Instant,
56 attempt_number: u32,
57 known_addresses: Vec<SocketAddr>,
58 discovered_candidates: Vec<CandidateAddress>,
59 last_error: Option<String>,
60 connection_handles: Vec<tokio::task::JoinHandle<Result<QuinnConnection, String>>>,
62 target_addresses: Vec<SocketAddr>,
64 established_connection: Option<QuinnConnection>,
66}
67
/// Phase an attempt is currently in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SimpleAttemptState {
    /// Dialling the peer's known addresses directly.
    DirectConnection,
    /// Waiting for the discovery manager to gather candidates.
    CandidateDiscovery,
    /// Discovery finished; NAT traversal phase.
    NatTraversal,
    /// A connection was established (terminal).
    Connected,
    /// The attempt gave up (terminal).
    Failed,
}
77
78#[derive(Debug, Clone)]
80pub enum SimpleConnectionEvent {
81 AttemptStarted {
82 peer_id: PeerId,
83 },
84 DirectConnectionTried {
85 peer_id: PeerId,
86 address: SocketAddr,
87 },
88 CandidateDiscoveryStarted {
89 peer_id: PeerId,
90 },
91 NatTraversalStarted {
92 peer_id: PeerId,
93 },
94 DirectConnectionSucceeded {
95 peer_id: PeerId,
96 address: SocketAddr,
97 },
98 DirectConnectionFailed {
99 peer_id: PeerId,
100 address: SocketAddr,
101 error: String,
102 },
103 ConnectionEstablished {
104 peer_id: PeerId,
105 },
106 ConnectionFailed {
107 peer_id: PeerId,
108 error: String,
109 },
110}
111
112impl Default for SimpleEstablishmentConfig {
113 fn default() -> Self {
114 Self {
115 direct_timeout: Duration::from_secs(5),
116 nat_traversal_timeout: Duration::from_secs(30),
117 enable_nat_traversal: true,
118 max_retries: 3,
119 }
120 }
121}
122
123impl SimpleConnectionEstablishmentManager {
124 pub fn new(
126 config: SimpleEstablishmentConfig,
127 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
128 bootstrap_nodes: Vec<BootstrapNode>,
129 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
130 quinn_endpoint: Option<Arc<QuinnEndpoint>>,
131 ) -> Self {
132 Self {
133 config,
134 active_attempts: HashMap::new(),
135 discovery_manager,
136 bootstrap_nodes,
137 event_callback,
138 quinn_endpoint,
139 }
140 }
141
142 pub fn connect_to_peer(
144 &mut self,
145 peer_id: PeerId,
146 known_addresses: Vec<SocketAddr>,
147 ) -> Result<(), String> {
148 if self.active_attempts.contains_key(&peer_id) {
150 return Err("Connection attempt already in progress".to_string());
151 }
152
153 let attempt = SimpleConnectionAttempt {
155 peer_id,
156 state: SimpleAttemptState::DirectConnection,
157 started_at: Instant::now(),
158 attempt_number: 1,
159 known_addresses: known_addresses.clone(),
160 discovered_candidates: Vec::new(),
161 last_error: None,
162 connection_handles: Vec::new(),
163 target_addresses: Vec::new(),
164 established_connection: None,
165 };
166
167 self.active_attempts.insert(peer_id, attempt);
168
169 self.emit_event(SimpleConnectionEvent::AttemptStarted { peer_id });
171
172 if !known_addresses.is_empty() {
174 info!("Starting direct connection attempt to peer {:?}", peer_id);
175
176 if let Some(ref quinn_endpoint) = self.quinn_endpoint {
178 self.start_direct_connections(peer_id, &known_addresses, quinn_endpoint.clone())?;
179 } else {
180 for address in &known_addresses {
182 self.emit_event(SimpleConnectionEvent::DirectConnectionTried {
183 peer_id,
184 address: *address,
185 });
186 }
187 }
188 } else if self.config.enable_nat_traversal {
189 self.start_candidate_discovery(peer_id)?;
191 } else {
192 return Err("No known addresses and NAT traversal disabled".to_string());
193 }
194
195 Ok(())
196 }
197
198 fn start_direct_connections(
200 &mut self,
201 peer_id: PeerId,
202 addresses: &[SocketAddr],
203 endpoint: Arc<QuinnEndpoint>,
204 ) -> Result<(), String> {
205 let attempt = self
206 .active_attempts
207 .get_mut(&peer_id)
208 .ok_or("Attempt not found")?;
209
210 let mut events_to_emit = Vec::new();
212
213 for &address in addresses {
214 let server_name = format!("peer-{:x}", peer_id.0[0] as u32);
215 let endpoint_clone = endpoint.clone();
216
217 let handle = tokio::spawn(async move {
219 let connecting = endpoint_clone
220 .connect(address, &server_name)
221 .map_err(|e| format!("Failed to start connection: {}", e))?;
222
223 match tokio::time::timeout(Duration::from_secs(10), connecting).await {
225 Ok(connection_result) => {
226 connection_result.map_err(|e| format!("Connection failed: {}", e))
227 }
228 Err(_) => Err("Connection timed out".to_string()),
229 }
230 });
231
232 attempt.connection_handles.push(handle);
233 attempt.target_addresses.push(address);
234
235 events_to_emit.push(SimpleConnectionEvent::DirectConnectionTried { peer_id, address });
237 }
238
239 for event in events_to_emit {
241 self.emit_event(event);
242 }
243
244 debug!(
245 "Started {} direct connections for peer {:?}",
246 addresses.len(),
247 peer_id
248 );
249 Ok(())
250 }
251
252 pub fn poll(&mut self, now: Instant) -> Vec<SimpleConnectionEvent> {
254 let mut events = Vec::new();
255
256 let discovery_events = match self.discovery_manager.lock() {
258 Ok(mut discovery) => discovery.poll(now),
259 _ => Vec::new(),
260 };
261
262 for discovery_event in discovery_events {
263 self.handle_discovery_event(discovery_event, &mut events);
264 }
265
266 let peer_ids: Vec<_> = self.active_attempts.keys().copied().collect();
268 let mut completed = Vec::new();
269
270 for peer_id in peer_ids {
271 if self.poll_attempt(peer_id, now, &mut events) {
272 completed.push(peer_id);
273 }
274 }
275
276 for peer_id in completed {
278 self.active_attempts.remove(&peer_id);
279 }
280
281 events
282 }
283
284 pub fn cancel_connection(&mut self, peer_id: PeerId) -> bool {
286 self.active_attempts.remove(&peer_id).is_some()
287 }
288
289 fn start_candidate_discovery(&mut self, peer_id: PeerId) -> Result<(), String> {
292 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
293 attempt.state = SimpleAttemptState::CandidateDiscovery;
294
295 match self.discovery_manager.lock() {
296 Ok(mut discovery) => {
297 discovery
298 .start_discovery(peer_id, self.bootstrap_nodes.clone())
299 .map_err(|e| format!("Discovery failed: {e:?}"))?;
300 }
301 _ => {
302 return Err("Failed to lock discovery manager".to_string());
303 }
304 }
305
306 self.emit_event(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
307 }
308
309 Ok(())
310 }
311
312 fn poll_attempt(
313 &mut self,
314 peer_id: PeerId,
315 now: Instant,
316 events: &mut Vec<SimpleConnectionEvent>,
317 ) -> bool {
318 let attempt = match self.active_attempts.get_mut(&peer_id) {
319 Some(a) => a,
320 None => return true,
321 };
322
323 let elapsed = now.duration_since(attempt.started_at);
324 let timeout = match attempt.state {
325 SimpleAttemptState::DirectConnection => self.config.direct_timeout,
326 _ => self.config.nat_traversal_timeout,
327 };
328
329 if elapsed > timeout {
331 match attempt.state {
332 SimpleAttemptState::DirectConnection if self.config.enable_nat_traversal => {
333 info!(
335 "Direct connection timed out for peer {:?}, starting NAT traversal",
336 peer_id
337 );
338 attempt.state = SimpleAttemptState::CandidateDiscovery;
339
340 let discovery_result = match self.discovery_manager.lock() {
342 Ok(mut discovery) => {
343 discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
344 }
345 _ => Err(DiscoveryError::InternalError(
346 "Failed to lock discovery manager".to_string(),
347 )),
348 };
349
350 if let Err(e) = discovery_result {
351 attempt.state = SimpleAttemptState::Failed;
352 attempt.last_error = Some(format!("Discovery failed: {e:?}"));
353 events.push(SimpleConnectionEvent::ConnectionFailed {
354 peer_id,
355 error: format!("Discovery failed: {e:?}"),
356 });
357 return true;
358 }
359
360 events.push(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
361 return false;
362 }
363 _ => {
364 attempt.state = SimpleAttemptState::Failed;
366 attempt.last_error = Some("Timeout exceeded".to_string());
367 events.push(SimpleConnectionEvent::ConnectionFailed {
368 peer_id,
369 error: "Timeout exceeded".to_string(),
370 });
371 return true;
372 }
373 }
374 }
375
376 let has_connection_handles = !attempt.connection_handles.is_empty();
378
379 if has_connection_handles {
380 let mut connection_handles = std::mem::take(&mut attempt.connection_handles);
382 let mut target_addresses = std::mem::take(&mut attempt.target_addresses);
383 let mut established_connection = attempt.established_connection.take();
384
385 Self::poll_connection_handles_extracted(
386 peer_id,
387 &mut connection_handles,
388 &mut target_addresses,
389 &mut established_connection,
390 events,
391 );
392
393 attempt.connection_handles = connection_handles;
395 attempt.target_addresses = target_addresses;
396 attempt.established_connection = established_connection;
397
398 if attempt.established_connection.is_some() {
400 attempt.state = SimpleAttemptState::Connected;
401 events.push(SimpleConnectionEvent::ConnectionEstablished { peer_id });
402 return true;
403 }
404 }
405
406 match attempt.state {
407 SimpleAttemptState::DirectConnection => {
408 if attempt.connection_handles.is_empty()
410 || attempt.connection_handles.iter().all(|h| h.is_finished())
411 {
412 if attempt.established_connection.is_none() && self.config.enable_nat_traversal
414 {
415 debug!(
417 "Direct connections failed for peer {:?}, trying NAT traversal",
418 peer_id
419 );
420 attempt.state = SimpleAttemptState::CandidateDiscovery;
421 }
423 }
424 }
425 SimpleAttemptState::CandidateDiscovery => {
426 }
428 SimpleAttemptState::NatTraversal => {
429 debug!("Polling NAT traversal attempts for peer {:?}", peer_id);
431 }
432 SimpleAttemptState::Connected | SimpleAttemptState::Failed => {
433 return true;
434 }
435 }
436
437 false
438 }
439
440 fn poll_connection_handles_extracted(
442 peer_id: PeerId,
443 connection_handles: &mut Vec<tokio::task::JoinHandle<Result<QuinnConnection, String>>>,
444 target_addresses: &mut Vec<SocketAddr>,
445 established_connection: &mut Option<QuinnConnection>,
446 events: &mut Vec<SimpleConnectionEvent>,
447 ) -> bool {
448 let mut completed_indices = Vec::new();
449
450 for (index, handle) in connection_handles.iter_mut().enumerate() {
451 if handle.is_finished() {
452 completed_indices.push(index);
453 }
454 }
455
456 for &index in completed_indices.iter().rev() {
458 let handle = connection_handles.remove(index);
459 let target_address = target_addresses.remove(index);
460
461 match tokio::runtime::Handle::try_current() {
462 Ok(runtime_handle) => {
463 match runtime_handle.block_on(handle) {
464 Ok(Ok(connection)) => {
465 info!(
467 "QUIC connection established to {} for peer {:?}",
468 target_address, peer_id
469 );
470 *established_connection = Some(connection);
471
472 events.push(SimpleConnectionEvent::DirectConnectionSucceeded {
473 peer_id,
474 address: target_address,
475 });
476
477 for remaining_handle in connection_handles.drain(..) {
479 remaining_handle.abort();
480 }
481 target_addresses.clear();
482
483 return true; }
485 Ok(Err(e)) => {
486 warn!("QUIC connection to {} failed: {}", target_address, e);
488
489 events.push(SimpleConnectionEvent::DirectConnectionFailed {
490 peer_id,
491 address: target_address,
492 error: e,
493 });
494 }
495 Err(join_error) => {
496 warn!("QUIC connection task failed: {}", join_error);
498
499 events.push(SimpleConnectionEvent::DirectConnectionFailed {
500 peer_id,
501 address: target_address,
502 error: format!("Task failed: {}", join_error),
503 });
504 }
505 }
506 }
507 Err(_) => {
508 warn!("Unable to check connection result without tokio runtime");
510 }
511 }
512 }
513
514 false
515 }
516
517 fn handle_discovery_event(
518 &mut self,
519 discovery_event: DiscoveryEvent,
520 events: &mut Vec<SimpleConnectionEvent>,
521 ) {
522 match discovery_event {
523 DiscoveryEvent::LocalCandidateDiscovered { candidate }
524 | DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. }
525 | DiscoveryEvent::PredictedCandidateGenerated { candidate, .. } => {
526 for attempt in self.active_attempts.values_mut() {
528 if attempt.state == SimpleAttemptState::CandidateDiscovery {
529 attempt.discovered_candidates.push(candidate.clone());
530 }
531 }
532 }
533 DiscoveryEvent::DiscoveryCompleted { .. } => {
534 let peer_ids: Vec<_> = self
536 .active_attempts
537 .iter()
538 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
539 .map(|(peer_id, _)| *peer_id)
540 .collect();
541
542 for peer_id in peer_ids {
543 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
544 attempt.state = SimpleAttemptState::NatTraversal;
545 events.push(SimpleConnectionEvent::NatTraversalStarted { peer_id });
546 }
547 }
548 }
549 DiscoveryEvent::DiscoveryFailed { error, .. } => {
550 warn!("Discovery failed: {:?}", error);
551 let peer_ids: Vec<_> = self
553 .active_attempts
554 .iter()
555 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
556 .map(|(peer_id, _)| *peer_id)
557 .collect();
558
559 for peer_id in peer_ids {
560 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
561 attempt.state = SimpleAttemptState::Failed;
562 attempt.last_error = Some(format!("Discovery failed: {error:?}"));
563 events.push(SimpleConnectionEvent::ConnectionFailed {
564 peer_id,
565 error: format!("Discovery failed: {error:?}"),
566 });
567 }
568 }
569 }
570 _ => {
571 }
573 }
574 }
575
576 fn emit_event(&self, event: SimpleConnectionEvent) {
577 if let Some(ref callback) = self.event_callback {
578 callback(event);
579 }
580 }
581}