1use std::{
7 collections::HashMap,
8 net::SocketAddr,
9 sync::Arc,
10 time::{Duration, Instant},
11};
12
13use tracing::{debug, info, warn};
14
15use crate::{
16 candidate_discovery::{CandidateDiscoveryManager, DiscoveryError, DiscoveryEvent},
17 nat_traversal_api::{BootstrapNode, CandidateAddress, PeerId},
18};
19
20pub struct SimpleConnectionEstablishmentManager {
22 config: SimpleEstablishmentConfig,
24 active_attempts: HashMap<PeerId, SimpleConnectionAttempt>,
26 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
28 bootstrap_nodes: Vec<BootstrapNode>,
30 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
32}
33
/// Tunables for establishing a connection to a peer.
#[derive(Clone, Debug)]
pub struct SimpleEstablishmentConfig {
    /// Time allowed for the direct-connection phase before it is abandoned.
    pub direct_timeout: Duration,
    /// Timeout applied while an attempt is in candidate discovery or NAT traversal.
    pub nat_traversal_timeout: Duration,
    /// When `false`, a timed-out direct connection fails instead of falling back to
    /// candidate discovery / NAT traversal.
    pub enable_nat_traversal: bool,
    /// Maximum number of retries.
    ///
    /// NOTE(review): not read by the connection manager code visible alongside this type;
    /// confirm whether retries are implemented elsewhere.
    pub max_retries: u32,
}
46
47#[derive(Debug)]
49struct SimpleConnectionAttempt {
50 peer_id: PeerId,
51 state: SimpleAttemptState,
52 started_at: Instant,
53 attempt_number: u32,
54 known_addresses: Vec<SocketAddr>,
55 discovered_candidates: Vec<CandidateAddress>,
56 last_error: Option<String>,
57}
58
/// Phase of a connection attempt.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SimpleAttemptState {
    /// Trying the caller-supplied addresses directly.
    DirectConnection,
    /// Waiting for the discovery manager to gather candidates.
    CandidateDiscovery,
    /// Discovery finished; traversing the NAT using the gathered candidates.
    NatTraversal,
    /// Terminal: the peer is connected.
    Connected,
    /// Terminal: the attempt gave up.
    Failed,
}
68
69#[derive(Debug, Clone)]
71pub enum SimpleConnectionEvent {
72 AttemptStarted {
73 peer_id: PeerId,
74 },
75 DirectConnectionTried {
76 peer_id: PeerId,
77 address: SocketAddr,
78 },
79 CandidateDiscoveryStarted {
80 peer_id: PeerId,
81 },
82 NatTraversalStarted {
83 peer_id: PeerId,
84 },
85 ConnectionEstablished {
86 peer_id: PeerId,
87 address: SocketAddr,
88 },
89 ConnectionFailed {
90 peer_id: PeerId,
91 error: String,
92 },
93}
94
95impl Default for SimpleEstablishmentConfig {
96 fn default() -> Self {
97 Self {
98 direct_timeout: Duration::from_secs(5),
99 nat_traversal_timeout: Duration::from_secs(30),
100 enable_nat_traversal: true,
101 max_retries: 3,
102 }
103 }
104}
105
106impl SimpleConnectionEstablishmentManager {
107 pub fn new(
109 config: SimpleEstablishmentConfig,
110 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
111 bootstrap_nodes: Vec<BootstrapNode>,
112 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
113 ) -> Self {
114 Self {
115 config,
116 active_attempts: HashMap::new(),
117 discovery_manager,
118 bootstrap_nodes,
119 event_callback,
120 }
121 }
122
123 pub fn connect_to_peer(
125 &mut self,
126 peer_id: PeerId,
127 known_addresses: Vec<SocketAddr>,
128 ) -> Result<(), String> {
129 if self.active_attempts.contains_key(&peer_id) {
131 return Err("Connection attempt already in progress".to_string());
132 }
133
134 let attempt = SimpleConnectionAttempt {
136 peer_id,
137 state: SimpleAttemptState::DirectConnection,
138 started_at: Instant::now(),
139 attempt_number: 1,
140 known_addresses: known_addresses.clone(),
141 discovered_candidates: Vec::new(),
142 last_error: None,
143 };
144
145 self.active_attempts.insert(peer_id, attempt);
146
147 self.emit_event(SimpleConnectionEvent::AttemptStarted { peer_id });
149
150 if !known_addresses.is_empty() {
152 info!("Starting direct connection attempt to peer {:?}", peer_id);
153 for address in &known_addresses {
154 self.emit_event(SimpleConnectionEvent::DirectConnectionTried {
155 peer_id,
156 address: *address,
157 });
158 }
159 } else if self.config.enable_nat_traversal {
160 self.start_candidate_discovery(peer_id)?;
162 } else {
163 return Err("No known addresses and NAT traversal disabled".to_string());
164 }
165
166 Ok(())
167 }
168
169 pub fn poll(&mut self, now: Instant) -> Vec<SimpleConnectionEvent> {
171 let mut events = Vec::new();
172
173 let discovery_events = if let Ok(mut discovery) = self.discovery_manager.lock() {
175 discovery.poll(now)
176 } else {
177 Vec::new()
178 };
179
180 for discovery_event in discovery_events {
181 self.handle_discovery_event(discovery_event, &mut events);
182 }
183
184 let peer_ids: Vec<_> = self.active_attempts.keys().copied().collect();
186 let mut completed = Vec::new();
187
188 for peer_id in peer_ids {
189 if self.poll_attempt(peer_id, now, &mut events) {
190 completed.push(peer_id);
191 }
192 }
193
194 for peer_id in completed {
196 self.active_attempts.remove(&peer_id);
197 }
198
199 events
200 }
201
202 pub fn cancel_connection(&mut self, peer_id: PeerId) -> bool {
204 self.active_attempts.remove(&peer_id).is_some()
205 }
206
207 fn start_candidate_discovery(&mut self, peer_id: PeerId) -> Result<(), String> {
210 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
211 attempt.state = SimpleAttemptState::CandidateDiscovery;
212
213 if let Ok(mut discovery) = self.discovery_manager.lock() {
214 discovery
215 .start_discovery(peer_id, self.bootstrap_nodes.clone())
216 .map_err(|e| format!("Discovery failed: {:?}", e))?;
217 } else {
218 return Err("Failed to lock discovery manager".to_string());
219 }
220
221 self.emit_event(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
222 }
223
224 Ok(())
225 }
226
227 fn poll_attempt(
228 &mut self,
229 peer_id: PeerId,
230 now: Instant,
231 events: &mut Vec<SimpleConnectionEvent>,
232 ) -> bool {
233 let attempt = match self.active_attempts.get_mut(&peer_id) {
234 Some(a) => a,
235 None => return true,
236 };
237
238 let elapsed = now.duration_since(attempt.started_at);
239 let timeout = match attempt.state {
240 SimpleAttemptState::DirectConnection => self.config.direct_timeout,
241 _ => self.config.nat_traversal_timeout,
242 };
243
244 if elapsed > timeout {
246 match attempt.state {
247 SimpleAttemptState::DirectConnection if self.config.enable_nat_traversal => {
248 info!(
250 "Direct connection timed out for peer {:?}, starting NAT traversal",
251 peer_id
252 );
253 attempt.state = SimpleAttemptState::CandidateDiscovery;
254
255 let discovery_result = if let Ok(mut discovery) = self.discovery_manager.lock()
257 {
258 discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
259 } else {
260 Err(DiscoveryError::InternalError(
261 "Failed to lock discovery manager".to_string(),
262 ))
263 };
264
265 if let Err(e) = discovery_result {
266 attempt.state = SimpleAttemptState::Failed;
267 attempt.last_error = Some(format!("Discovery failed: {:?}", e));
268 events.push(SimpleConnectionEvent::ConnectionFailed {
269 peer_id,
270 error: format!("Discovery failed: {:?}", e),
271 });
272 return true;
273 }
274
275 events.push(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
276 return false;
277 }
278 _ => {
279 attempt.state = SimpleAttemptState::Failed;
281 attempt.last_error = Some("Timeout exceeded".to_string());
282 events.push(SimpleConnectionEvent::ConnectionFailed {
283 peer_id,
284 error: "Timeout exceeded".to_string(),
285 });
286 return true;
287 }
288 }
289 }
290
291 match attempt.state {
293 SimpleAttemptState::DirectConnection => {
294 debug!("Simulating direct connection attempt to peer {:?}", peer_id);
296 }
297 SimpleAttemptState::CandidateDiscovery => {
298 }
300 SimpleAttemptState::NatTraversal => {
301 debug!("Simulating NAT traversal attempt to peer {:?}", peer_id);
303 }
304 SimpleAttemptState::Connected | SimpleAttemptState::Failed => {
305 return true;
306 }
307 }
308
309 false
310 }
311
312 fn handle_discovery_event(
313 &mut self,
314 discovery_event: DiscoveryEvent,
315 events: &mut Vec<SimpleConnectionEvent>,
316 ) {
317 match discovery_event {
318 DiscoveryEvent::LocalCandidateDiscovered { candidate }
319 | DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. }
320 | DiscoveryEvent::PredictedCandidateGenerated { candidate, .. } => {
321 for attempt in self.active_attempts.values_mut() {
323 if attempt.state == SimpleAttemptState::CandidateDiscovery {
324 attempt.discovered_candidates.push(candidate.clone());
325 }
326 }
327 }
328 DiscoveryEvent::DiscoveryCompleted { .. } => {
329 let peer_ids: Vec<_> = self
331 .active_attempts
332 .iter()
333 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
334 .map(|(peer_id, _)| *peer_id)
335 .collect();
336
337 for peer_id in peer_ids {
338 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
339 attempt.state = SimpleAttemptState::NatTraversal;
340 events.push(SimpleConnectionEvent::NatTraversalStarted { peer_id });
341 }
342 }
343 }
344 DiscoveryEvent::DiscoveryFailed { error, .. } => {
345 warn!("Discovery failed: {:?}", error);
346 let peer_ids: Vec<_> = self
348 .active_attempts
349 .iter()
350 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
351 .map(|(peer_id, _)| *peer_id)
352 .collect();
353
354 for peer_id in peer_ids {
355 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
356 attempt.state = SimpleAttemptState::Failed;
357 attempt.last_error = Some(format!("Discovery failed: {:?}", error));
358 events.push(SimpleConnectionEvent::ConnectionFailed {
359 peer_id,
360 error: format!("Discovery failed: {:?}", error),
361 });
362 }
363 }
364 }
365 _ => {
366 }
368 }
369 }
370
371 fn emit_event(&self, event: SimpleConnectionEvent) {
372 if let Some(ref callback) = self.event_callback {
373 callback(event);
374 }
375 }
376}