1use std::{
7 collections::HashMap,
8 net::SocketAddr,
9 sync::Arc,
10 time::{Duration, Instant},
11};
12
13use tracing::{info, warn};
14
15use crate::{
16 candidate_discovery::{CandidateDiscoveryManager, DiscoveryEvent, DiscoveryError},
17 nat_traversal_api::{BootstrapNode, CandidateAddress, PeerId},
18};
19
20pub struct SimpleConnectionEstablishmentManager {
22 config: SimpleEstablishmentConfig,
24 active_attempts: HashMap<PeerId, SimpleConnectionAttempt>,
26 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
28 bootstrap_nodes: Vec<BootstrapNode>,
30 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
32}
33
/// Tunables for [`SimpleConnectionEstablishmentManager`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared (e.g. in tests or when deciding
/// whether a reconfiguration is a no-op).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimpleEstablishmentConfig {
    /// Maximum time an attempt may spend in the direct-connection phase, measured from the
    /// moment the attempt was started.
    pub direct_timeout: Duration,
    /// Maximum age of an attempt in any later phase (candidate discovery / NAT traversal).
    /// Also measured from the start of the attempt, not from the start of that phase.
    pub nat_traversal_timeout: Duration,
    /// Whether to fall back to candidate discovery and NAT traversal when no addresses are
    /// known or the direct phase times out.
    pub enable_nat_traversal: bool,
    /// Maximum number of retries. NOTE(review): not read anywhere in this module.
    pub max_retries: u32,
}
46
47#[derive(Debug)]
49struct SimpleConnectionAttempt {
50 peer_id: PeerId,
51 state: SimpleAttemptState,
52 started_at: Instant,
53 attempt_number: u32,
54 known_addresses: Vec<SocketAddr>,
55 discovered_candidates: Vec<CandidateAddress>,
56 last_error: Option<String>,
57}
58
/// Phase of a connection attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SimpleAttemptState {
    /// Trying the caller-supplied addresses directly.
    DirectConnection,
    /// Waiting for the discovery manager to report candidates.
    CandidateDiscovery,
    /// Discovery completed; traversing towards a discovered candidate.
    NatTraversal,
    /// Terminal: the connection was established.
    Connected,
    /// Terminal: the attempt gave up.
    Failed,
}
68
69#[derive(Debug, Clone)]
71pub enum SimpleConnectionEvent {
72 AttemptStarted { peer_id: PeerId },
73 DirectConnectionTried { peer_id: PeerId, address: SocketAddr },
74 CandidateDiscoveryStarted { peer_id: PeerId },
75 NatTraversalStarted { peer_id: PeerId },
76 ConnectionEstablished { peer_id: PeerId, address: SocketAddr },
77 ConnectionFailed { peer_id: PeerId, error: String },
78}
79
80impl Default for SimpleEstablishmentConfig {
81 fn default() -> Self {
82 Self {
83 direct_timeout: Duration::from_secs(5),
84 nat_traversal_timeout: Duration::from_secs(30),
85 enable_nat_traversal: true,
86 max_retries: 3,
87 }
88 }
89}
90
91impl SimpleConnectionEstablishmentManager {
92 pub fn new(
94 config: SimpleEstablishmentConfig,
95 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
96 bootstrap_nodes: Vec<BootstrapNode>,
97 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
98 ) -> Self {
99 Self {
100 config,
101 active_attempts: HashMap::new(),
102 discovery_manager,
103 bootstrap_nodes,
104 event_callback,
105 }
106 }
107
108 pub fn connect_to_peer(
110 &mut self,
111 peer_id: PeerId,
112 known_addresses: Vec<SocketAddr>,
113 ) -> Result<(), String> {
114 if self.active_attempts.contains_key(&peer_id) {
116 return Err("Connection attempt already in progress".to_string());
117 }
118
119 let attempt = SimpleConnectionAttempt {
121 peer_id,
122 state: SimpleAttemptState::DirectConnection,
123 started_at: Instant::now(),
124 attempt_number: 1,
125 known_addresses: known_addresses.clone(),
126 discovered_candidates: Vec::new(),
127 last_error: None,
128 };
129
130 self.active_attempts.insert(peer_id, attempt);
131
132 self.emit_event(SimpleConnectionEvent::AttemptStarted { peer_id });
134
135 if !known_addresses.is_empty() {
137 info!("Starting direct connection attempt to peer {:?}", peer_id);
138 for address in &known_addresses {
139 self.emit_event(SimpleConnectionEvent::DirectConnectionTried {
140 peer_id,
141 address: *address,
142 });
143 }
144 } else if self.config.enable_nat_traversal {
145 self.start_candidate_discovery(peer_id)?;
147 } else {
148 return Err("No known addresses and NAT traversal disabled".to_string());
149 }
150
151 Ok(())
152 }
153
154 pub fn poll(&mut self, now: Instant) -> Vec<SimpleConnectionEvent> {
156 let mut events = Vec::new();
157
158 let discovery_events = if let Ok(mut discovery) = self.discovery_manager.lock() {
160 discovery.poll(now)
161 } else {
162 Vec::new()
163 };
164
165 for discovery_event in discovery_events {
166 self.handle_discovery_event(discovery_event, &mut events);
167 }
168
169 let peer_ids: Vec<_> = self.active_attempts.keys().copied().collect();
171 let mut completed = Vec::new();
172
173 for peer_id in peer_ids {
174 if self.poll_attempt(peer_id, now, &mut events) {
175 completed.push(peer_id);
176 }
177 }
178
179 for peer_id in completed {
181 self.active_attempts.remove(&peer_id);
182 }
183
184 events
185 }
186
187 pub fn cancel_connection(&mut self, peer_id: PeerId) -> bool {
189 self.active_attempts.remove(&peer_id).is_some()
190 }
191
192 fn start_candidate_discovery(&mut self, peer_id: PeerId) -> Result<(), String> {
195 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
196 attempt.state = SimpleAttemptState::CandidateDiscovery;
197
198 if let Ok(mut discovery) = self.discovery_manager.lock() {
199 discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
200 .map_err(|e| format!("Discovery failed: {:?}", e))?;
201 } else {
202 return Err("Failed to lock discovery manager".to_string());
203 }
204
205 self.emit_event(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
206 }
207
208 Ok(())
209 }
210
211 fn poll_attempt(
212 &mut self,
213 peer_id: PeerId,
214 now: Instant,
215 events: &mut Vec<SimpleConnectionEvent>,
216 ) -> bool {
217 let should_complete = {
218 let attempt = match self.active_attempts.get_mut(&peer_id) {
219 Some(a) => a,
220 None => return true,
221 };
222
223 let elapsed = now.duration_since(attempt.started_at);
224 let timeout = match attempt.state {
225 SimpleAttemptState::DirectConnection => self.config.direct_timeout,
226 _ => self.config.nat_traversal_timeout,
227 };
228
229 if elapsed > timeout {
231 match attempt.state {
232 SimpleAttemptState::DirectConnection if self.config.enable_nat_traversal => {
233 info!("Direct connection timed out for peer {:?}, starting NAT traversal", peer_id);
235 attempt.state = SimpleAttemptState::CandidateDiscovery;
236
237 let discovery_result = if let Ok(mut discovery) = self.discovery_manager.lock() {
239 discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
240 } else {
241 Err(DiscoveryError::InternalError("Failed to lock discovery manager".to_string()))
242 };
243
244 if let Err(e) = discovery_result {
245 attempt.state = SimpleAttemptState::Failed;
246 attempt.last_error = Some(format!("Discovery failed: {:?}", e));
247 events.push(SimpleConnectionEvent::ConnectionFailed {
248 peer_id,
249 error: format!("Discovery failed: {:?}", e),
250 });
251 return true;
252 }
253
254 events.push(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
255 return false;
256 }
257 _ => {
258 attempt.state = SimpleAttemptState::Failed;
260 attempt.last_error = Some("Timeout exceeded".to_string());
261 events.push(SimpleConnectionEvent::ConnectionFailed {
262 peer_id,
263 error: "Timeout exceeded".to_string(),
264 });
265 return true;
266 }
267 }
268 }
269
270 match attempt.state {
273 SimpleAttemptState::DirectConnection => {
274 if elapsed > Duration::from_secs(2) {
275 if !attempt.known_addresses.is_empty() {
277 attempt.state = SimpleAttemptState::Connected;
278 events.push(SimpleConnectionEvent::ConnectionEstablished {
279 peer_id,
280 address: attempt.known_addresses[0],
281 });
282 return true;
283 }
284 }
285 }
286 SimpleAttemptState::CandidateDiscovery => {
287 }
289 SimpleAttemptState::NatTraversal => {
290 if elapsed > Duration::from_secs(5) {
291 if !attempt.discovered_candidates.is_empty() {
293 attempt.state = SimpleAttemptState::Connected;
294 events.push(SimpleConnectionEvent::ConnectionEstablished {
295 peer_id,
296 address: attempt.discovered_candidates[0].address,
297 });
298 return true;
299 }
300 }
301 }
302 SimpleAttemptState::Connected | SimpleAttemptState::Failed => {
303 return true;
304 }
305 }
306
307 false
308 };
309
310 should_complete
311 }
312
313 fn handle_discovery_event(
314 &mut self,
315 discovery_event: DiscoveryEvent,
316 events: &mut Vec<SimpleConnectionEvent>,
317 ) {
318 match discovery_event {
319 DiscoveryEvent::LocalCandidateDiscovered { candidate } |
320 DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. } |
321 DiscoveryEvent::PredictedCandidateGenerated { candidate, .. } => {
322 for attempt in self.active_attempts.values_mut() {
324 if attempt.state == SimpleAttemptState::CandidateDiscovery {
325 attempt.discovered_candidates.push(candidate.clone());
326 }
327 }
328 }
329 DiscoveryEvent::DiscoveryCompleted { .. } => {
330 let peer_ids: Vec<_> = self.active_attempts.iter()
332 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
333 .map(|(peer_id, _)| *peer_id)
334 .collect();
335
336 for peer_id in peer_ids {
337 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
338 attempt.state = SimpleAttemptState::NatTraversal;
339 events.push(SimpleConnectionEvent::NatTraversalStarted { peer_id });
340 }
341 }
342 }
343 DiscoveryEvent::DiscoveryFailed { error, .. } => {
344 warn!("Discovery failed: {:?}", error);
345 let peer_ids: Vec<_> = self.active_attempts.iter()
347 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
348 .map(|(peer_id, _)| *peer_id)
349 .collect();
350
351 for peer_id in peer_ids {
352 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
353 attempt.state = SimpleAttemptState::Failed;
354 attempt.last_error = Some(format!("Discovery failed: {:?}", error));
355 events.push(SimpleConnectionEvent::ConnectionFailed {
356 peer_id,
357 error: format!("Discovery failed: {:?}", error),
358 });
359 }
360 }
361 }
362 _ => {
363 }
365 }
366 }
367
368 fn emit_event(&self, event: SimpleConnectionEvent) {
369 if let Some(ref callback) = self.event_callback {
370 callback(event);
371 }
372 }
373
374 pub fn get_status(&self) -> SimpleConnectionStatus {
376 let mut direct_attempts = 0;
377 let mut nat_traversal_attempts = 0;
378 let mut connected = 0;
379 let mut failed = 0;
380
381 for attempt in self.active_attempts.values() {
382 match attempt.state {
383 SimpleAttemptState::DirectConnection => direct_attempts += 1,
384 SimpleAttemptState::CandidateDiscovery | SimpleAttemptState::NatTraversal => {
385 nat_traversal_attempts += 1
386 }
387 SimpleAttemptState::Connected => connected += 1,
388 SimpleAttemptState::Failed => failed += 1,
389 }
390 }
391
392 SimpleConnectionStatus {
393 total_attempts: self.active_attempts.len(),
394 direct_attempts,
395 nat_traversal_attempts,
396 connected,
397 failed,
398 }
399 }
400}
401
/// Snapshot of the manager's active attempts, grouped by phase.
///
/// Derives `Default` (all zeros) and `PartialEq`/`Eq` so snapshots can be built and compared
/// directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SimpleConnectionStatus {
    /// Number of attempts currently tracked.
    pub total_attempts: usize,
    /// Attempts in the direct-connection phase.
    pub direct_attempts: usize,
    /// Attempts in candidate discovery or NAT traversal.
    pub nat_traversal_attempts: usize,
    /// Attempts that reached the connected state but have not been removed yet.
    pub connected: usize,
    /// Attempts that failed but have not been removed yet.
    pub failed: usize,
}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::candidate_discovery::DiscoveryConfig;
    use crate::connection::nat_traversal::NatTraversalRole;

    /// Builds a manager with the default config, a fresh client-side discovery manager,
    /// no bootstrap nodes, and no event callback.
    fn build_manager() -> SimpleConnectionEstablishmentManager {
        let discovery =
            CandidateDiscoveryManager::new(DiscoveryConfig::default(), NatTraversalRole::Client);

        SimpleConnectionEstablishmentManager::new(
            SimpleEstablishmentConfig::default(),
            Arc::new(std::sync::Mutex::new(discovery)),
            Vec::new(),
            None,
        )
    }

    #[test]
    fn test_simple_connection_manager_creation() {
        let _manager = build_manager();
    }

    #[test]
    fn test_connect_to_peer() {
        let mut manager = build_manager();
        let peer_id = PeerId([1; 32]);
        let loopback = SocketAddr::from(([127, 0, 0, 1], 8080));

        assert!(manager.connect_to_peer(peer_id, vec![loopback]).is_ok());

        // A second attempt for the same peer is rejected while the first is still active.
        assert!(manager.connect_to_peer(peer_id, Vec::new()).is_err());
    }
}