1use std::{
7 collections::HashMap,
8 net::SocketAddr,
9 sync::Arc,
10 time::{Duration, Instant},
11};
12
13use tracing::{debug, info, warn};
14
15use crate::{
16 candidate_discovery::{CandidateDiscoveryManager, DiscoveryError, DiscoveryEvent},
17 nat_traversal_api::{BootstrapNode, CandidateAddress, PeerId},
18};
19
20pub struct SimpleConnectionEstablishmentManager {
22 config: SimpleEstablishmentConfig,
24 active_attempts: HashMap<PeerId, SimpleConnectionAttempt>,
26 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
28 bootstrap_nodes: Vec<BootstrapNode>,
30 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
32}
33
/// Tuning knobs for [`SimpleConnectionEstablishmentManager`].
#[derive(Debug, Clone)]
pub struct SimpleEstablishmentConfig {
    /// Time allowed for the direct-connection phase before the attempt either falls
    /// back to NAT traversal or, if that is disabled, fails.
    pub direct_timeout: Duration,
    /// Time limit applied to attempts in the candidate-discovery and NAT-traversal phases.
    pub nat_traversal_timeout: Duration,
    /// Whether attempts may fall back to candidate discovery and NAT traversal.
    pub enable_nat_traversal: bool,
    /// Maximum number of retries.
    /// NOTE(review): not read by the manager code in this file — confirm where it is enforced.
    pub max_retries: u32,
}
46
47#[derive(Debug)]
49struct SimpleConnectionAttempt {
50 peer_id: PeerId,
51 state: SimpleAttemptState,
52 started_at: Instant,
53 attempt_number: u32,
54 known_addresses: Vec<SocketAddr>,
55 discovered_candidates: Vec<CandidateAddress>,
56 last_error: Option<String>,
57}
58
/// Phase of a [`SimpleConnectionAttempt`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SimpleAttemptState {
    /// Trying the caller-supplied addresses directly.
    DirectConnection,
    /// Waiting for the discovery manager to report candidates.
    CandidateDiscovery,
    /// Discovery finished; traversing the NAT with the gathered candidates.
    NatTraversal,
    /// Terminal: the peer is reachable.
    Connected,
    /// Terminal: the attempt gave up.
    Failed,
}
68
69#[derive(Debug, Clone)]
71pub enum SimpleConnectionEvent {
72 AttemptStarted {
73 peer_id: PeerId,
74 },
75 DirectConnectionTried {
76 peer_id: PeerId,
77 address: SocketAddr,
78 },
79 CandidateDiscoveryStarted {
80 peer_id: PeerId,
81 },
82 NatTraversalStarted {
83 peer_id: PeerId,
84 },
85 ConnectionEstablished {
86 peer_id: PeerId,
87 address: SocketAddr,
88 },
89 ConnectionFailed {
90 peer_id: PeerId,
91 error: String,
92 },
93}
94
95impl Default for SimpleEstablishmentConfig {
96 fn default() -> Self {
97 Self {
98 direct_timeout: Duration::from_secs(5),
99 nat_traversal_timeout: Duration::from_secs(30),
100 enable_nat_traversal: true,
101 max_retries: 3,
102 }
103 }
104}
105
106impl SimpleConnectionEstablishmentManager {
107 pub fn new(
109 config: SimpleEstablishmentConfig,
110 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
111 bootstrap_nodes: Vec<BootstrapNode>,
112 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
113 ) -> Self {
114 Self {
115 config,
116 active_attempts: HashMap::new(),
117 discovery_manager,
118 bootstrap_nodes,
119 event_callback,
120 }
121 }
122
123 pub fn connect_to_peer(
125 &mut self,
126 peer_id: PeerId,
127 known_addresses: Vec<SocketAddr>,
128 ) -> Result<(), String> {
129 if self.active_attempts.contains_key(&peer_id) {
131 return Err("Connection attempt already in progress".to_string());
132 }
133
134 let attempt = SimpleConnectionAttempt {
136 peer_id,
137 state: SimpleAttemptState::DirectConnection,
138 started_at: Instant::now(),
139 attempt_number: 1,
140 known_addresses: known_addresses.clone(),
141 discovered_candidates: Vec::new(),
142 last_error: None,
143 };
144
145 self.active_attempts.insert(peer_id, attempt);
146
147 self.emit_event(SimpleConnectionEvent::AttemptStarted { peer_id });
149
150 if !known_addresses.is_empty() {
152 info!("Starting direct connection attempt to peer {:?}", peer_id);
153 for address in &known_addresses {
154 self.emit_event(SimpleConnectionEvent::DirectConnectionTried {
155 peer_id,
156 address: *address,
157 });
158 }
159 } else if self.config.enable_nat_traversal {
160 self.start_candidate_discovery(peer_id)?;
162 } else {
163 return Err("No known addresses and NAT traversal disabled".to_string());
164 }
165
166 Ok(())
167 }
168
169 pub fn poll(&mut self, now: Instant) -> Vec<SimpleConnectionEvent> {
171 let mut events = Vec::new();
172
173 let discovery_events = match self.discovery_manager.lock() {
175 Ok(mut discovery) => discovery.poll(now),
176 _ => Vec::new(),
177 };
178
179 for discovery_event in discovery_events {
180 self.handle_discovery_event(discovery_event, &mut events);
181 }
182
183 let peer_ids: Vec<_> = self.active_attempts.keys().copied().collect();
185 let mut completed = Vec::new();
186
187 for peer_id in peer_ids {
188 if self.poll_attempt(peer_id, now, &mut events) {
189 completed.push(peer_id);
190 }
191 }
192
193 for peer_id in completed {
195 self.active_attempts.remove(&peer_id);
196 }
197
198 events
199 }
200
201 pub fn cancel_connection(&mut self, peer_id: PeerId) -> bool {
203 self.active_attempts.remove(&peer_id).is_some()
204 }
205
206 fn start_candidate_discovery(&mut self, peer_id: PeerId) -> Result<(), String> {
209 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
210 attempt.state = SimpleAttemptState::CandidateDiscovery;
211
212 match self.discovery_manager.lock() {
213 Ok(mut discovery) => {
214 discovery
215 .start_discovery(peer_id, self.bootstrap_nodes.clone())
216 .map_err(|e| format!("Discovery failed: {e:?}"))?;
217 }
218 _ => {
219 return Err("Failed to lock discovery manager".to_string());
220 }
221 }
222
223 self.emit_event(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
224 }
225
226 Ok(())
227 }
228
229 fn poll_attempt(
230 &mut self,
231 peer_id: PeerId,
232 now: Instant,
233 events: &mut Vec<SimpleConnectionEvent>,
234 ) -> bool {
235 let attempt = match self.active_attempts.get_mut(&peer_id) {
236 Some(a) => a,
237 None => return true,
238 };
239
240 let elapsed = now.duration_since(attempt.started_at);
241 let timeout = match attempt.state {
242 SimpleAttemptState::DirectConnection => self.config.direct_timeout,
243 _ => self.config.nat_traversal_timeout,
244 };
245
246 if elapsed > timeout {
248 match attempt.state {
249 SimpleAttemptState::DirectConnection if self.config.enable_nat_traversal => {
250 info!(
252 "Direct connection timed out for peer {:?}, starting NAT traversal",
253 peer_id
254 );
255 attempt.state = SimpleAttemptState::CandidateDiscovery;
256
257 let discovery_result = match self.discovery_manager.lock() {
259 Ok(mut discovery) => {
260 discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
261 }
262 _ => Err(DiscoveryError::InternalError(
263 "Failed to lock discovery manager".to_string(),
264 )),
265 };
266
267 if let Err(e) = discovery_result {
268 attempt.state = SimpleAttemptState::Failed;
269 attempt.last_error = Some(format!("Discovery failed: {e:?}"));
270 events.push(SimpleConnectionEvent::ConnectionFailed {
271 peer_id,
272 error: format!("Discovery failed: {e:?}"),
273 });
274 return true;
275 }
276
277 events.push(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
278 return false;
279 }
280 _ => {
281 attempt.state = SimpleAttemptState::Failed;
283 attempt.last_error = Some("Timeout exceeded".to_string());
284 events.push(SimpleConnectionEvent::ConnectionFailed {
285 peer_id,
286 error: "Timeout exceeded".to_string(),
287 });
288 return true;
289 }
290 }
291 }
292
293 match attempt.state {
295 SimpleAttemptState::DirectConnection => {
296 debug!("Simulating direct connection attempt to peer {:?}", peer_id);
298 }
299 SimpleAttemptState::CandidateDiscovery => {
300 }
302 SimpleAttemptState::NatTraversal => {
303 debug!("Simulating NAT traversal attempt to peer {:?}", peer_id);
305 }
306 SimpleAttemptState::Connected | SimpleAttemptState::Failed => {
307 return true;
308 }
309 }
310
311 false
312 }
313
314 fn handle_discovery_event(
315 &mut self,
316 discovery_event: DiscoveryEvent,
317 events: &mut Vec<SimpleConnectionEvent>,
318 ) {
319 match discovery_event {
320 DiscoveryEvent::LocalCandidateDiscovered { candidate }
321 | DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. }
322 | DiscoveryEvent::PredictedCandidateGenerated { candidate, .. } => {
323 for attempt in self.active_attempts.values_mut() {
325 if attempt.state == SimpleAttemptState::CandidateDiscovery {
326 attempt.discovered_candidates.push(candidate.clone());
327 }
328 }
329 }
330 DiscoveryEvent::DiscoveryCompleted { .. } => {
331 let peer_ids: Vec<_> = self
333 .active_attempts
334 .iter()
335 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
336 .map(|(peer_id, _)| *peer_id)
337 .collect();
338
339 for peer_id in peer_ids {
340 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
341 attempt.state = SimpleAttemptState::NatTraversal;
342 events.push(SimpleConnectionEvent::NatTraversalStarted { peer_id });
343 }
344 }
345 }
346 DiscoveryEvent::DiscoveryFailed { error, .. } => {
347 warn!("Discovery failed: {:?}", error);
348 let peer_ids: Vec<_> = self
350 .active_attempts
351 .iter()
352 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
353 .map(|(peer_id, _)| *peer_id)
354 .collect();
355
356 for peer_id in peer_ids {
357 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
358 attempt.state = SimpleAttemptState::Failed;
359 attempt.last_error = Some(format!("Discovery failed: {error:?}"));
360 events.push(SimpleConnectionEvent::ConnectionFailed {
361 peer_id,
362 error: format!("Discovery failed: {error:?}"),
363 });
364 }
365 }
366 }
367 _ => {
368 }
370 }
371 }
372
373 fn emit_event(&self, event: SimpleConnectionEvent) {
374 if let Some(ref callback) = self.event_callback {
375 callback(event);
376 }
377 }
378}