1use std::{
7 collections::HashMap,
8 net::SocketAddr,
9 sync::Arc,
10 time::{Duration, Instant},
11};
12
13use tracing::{info, warn, debug};
14
15use crate::{
16 candidate_discovery::{CandidateDiscoveryManager, DiscoveryEvent, DiscoveryError},
17 nat_traversal_api::{BootstrapNode, CandidateAddress, PeerId},
18};
19
20pub struct SimpleConnectionEstablishmentManager {
22 config: SimpleEstablishmentConfig,
24 active_attempts: HashMap<PeerId, SimpleConnectionAttempt>,
26 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
28 bootstrap_nodes: Vec<BootstrapNode>,
30 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
32}
33
/// Tunable parameters for [`SimpleConnectionEstablishmentManager`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly
/// (for example in tests or when detecting a configuration change).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimpleEstablishmentConfig {
    /// How long an attempt may stay in the direct-connection phase.
    pub direct_timeout: Duration,
    /// Timeout applied to every phase other than direct connection. It is
    /// measured from the start of the attempt, not from the phase change.
    pub nat_traversal_timeout: Duration,
    /// Whether to fall back to candidate discovery / NAT traversal when a
    /// direct connection is not possible or times out.
    pub enable_nat_traversal: bool,
    /// Maximum number of retries.
    // NOTE(review): not consulted by the manager code in this file — confirm
    // whether retry handling lives elsewhere or is still to be implemented.
    pub max_retries: u32,
}
46
47#[derive(Debug)]
49struct SimpleConnectionAttempt {
50 peer_id: PeerId,
51 state: SimpleAttemptState,
52 started_at: Instant,
53 attempt_number: u32,
54 known_addresses: Vec<SocketAddr>,
55 discovered_candidates: Vec<CandidateAddress>,
56 last_error: Option<String>,
57}
58
/// Phase of a connection attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SimpleAttemptState {
    /// Initial phase: connecting directly to the caller-supplied addresses.
    DirectConnection,
    /// Candidate discovery has been requested and is still running.
    CandidateDiscovery,
    /// Discovery completed; NAT traversal is in progress.
    NatTraversal,
    /// Terminal: the connection was established.
    Connected,
    /// Terminal: the attempt timed out or discovery failed.
    Failed,
}
68
69#[derive(Debug, Clone)]
71pub enum SimpleConnectionEvent {
72 AttemptStarted { peer_id: PeerId },
73 DirectConnectionTried { peer_id: PeerId, address: SocketAddr },
74 CandidateDiscoveryStarted { peer_id: PeerId },
75 NatTraversalStarted { peer_id: PeerId },
76 ConnectionEstablished { peer_id: PeerId, address: SocketAddr },
77 ConnectionFailed { peer_id: PeerId, error: String },
78}
79
80impl Default for SimpleEstablishmentConfig {
81 fn default() -> Self {
82 Self {
83 direct_timeout: Duration::from_secs(5),
84 nat_traversal_timeout: Duration::from_secs(30),
85 enable_nat_traversal: true,
86 max_retries: 3,
87 }
88 }
89}
90
91impl SimpleConnectionEstablishmentManager {
92 pub fn new(
94 config: SimpleEstablishmentConfig,
95 discovery_manager: Arc<std::sync::Mutex<CandidateDiscoveryManager>>,
96 bootstrap_nodes: Vec<BootstrapNode>,
97 event_callback: Option<Box<dyn Fn(SimpleConnectionEvent) + Send + Sync>>,
98 ) -> Self {
99 Self {
100 config,
101 active_attempts: HashMap::new(),
102 discovery_manager,
103 bootstrap_nodes,
104 event_callback,
105 }
106 }
107
108 pub fn connect_to_peer(
110 &mut self,
111 peer_id: PeerId,
112 known_addresses: Vec<SocketAddr>,
113 ) -> Result<(), String> {
114 if self.active_attempts.contains_key(&peer_id) {
116 return Err("Connection attempt already in progress".to_string());
117 }
118
119 let attempt = SimpleConnectionAttempt {
121 peer_id,
122 state: SimpleAttemptState::DirectConnection,
123 started_at: Instant::now(),
124 attempt_number: 1,
125 known_addresses: known_addresses.clone(),
126 discovered_candidates: Vec::new(),
127 last_error: None,
128 };
129
130 self.active_attempts.insert(peer_id, attempt);
131
132 self.emit_event(SimpleConnectionEvent::AttemptStarted { peer_id });
134
135 if !known_addresses.is_empty() {
137 info!("Starting direct connection attempt to peer {:?}", peer_id);
138 for address in &known_addresses {
139 self.emit_event(SimpleConnectionEvent::DirectConnectionTried {
140 peer_id,
141 address: *address,
142 });
143 }
144 } else if self.config.enable_nat_traversal {
145 self.start_candidate_discovery(peer_id)?;
147 } else {
148 return Err("No known addresses and NAT traversal disabled".to_string());
149 }
150
151 Ok(())
152 }
153
154 pub fn poll(&mut self, now: Instant) -> Vec<SimpleConnectionEvent> {
156 let mut events = Vec::new();
157
158 let discovery_events = if let Ok(mut discovery) = self.discovery_manager.lock() {
160 discovery.poll(now)
161 } else {
162 Vec::new()
163 };
164
165 for discovery_event in discovery_events {
166 self.handle_discovery_event(discovery_event, &mut events);
167 }
168
169 let peer_ids: Vec<_> = self.active_attempts.keys().copied().collect();
171 let mut completed = Vec::new();
172
173 for peer_id in peer_ids {
174 if self.poll_attempt(peer_id, now, &mut events) {
175 completed.push(peer_id);
176 }
177 }
178
179 for peer_id in completed {
181 self.active_attempts.remove(&peer_id);
182 }
183
184 events
185 }
186
187 pub fn cancel_connection(&mut self, peer_id: PeerId) -> bool {
189 self.active_attempts.remove(&peer_id).is_some()
190 }
191
192 fn start_candidate_discovery(&mut self, peer_id: PeerId) -> Result<(), String> {
195 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
196 attempt.state = SimpleAttemptState::CandidateDiscovery;
197
198 if let Ok(mut discovery) = self.discovery_manager.lock() {
199 discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
200 .map_err(|e| format!("Discovery failed: {:?}", e))?;
201 } else {
202 return Err("Failed to lock discovery manager".to_string());
203 }
204
205 self.emit_event(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
206 }
207
208 Ok(())
209 }
210
211 fn poll_attempt(
212 &mut self,
213 peer_id: PeerId,
214 now: Instant,
215 events: &mut Vec<SimpleConnectionEvent>,
216 ) -> bool {
217 let attempt = match self.active_attempts.get_mut(&peer_id) {
218 Some(a) => a,
219 None => return true,
220 };
221
222 let elapsed = now.duration_since(attempt.started_at);
223 let timeout = match attempt.state {
224 SimpleAttemptState::DirectConnection => self.config.direct_timeout,
225 _ => self.config.nat_traversal_timeout,
226 };
227
228 if elapsed > timeout {
230 match attempt.state {
231 SimpleAttemptState::DirectConnection if self.config.enable_nat_traversal => {
232 info!("Direct connection timed out for peer {:?}, starting NAT traversal", peer_id);
234 attempt.state = SimpleAttemptState::CandidateDiscovery;
235
236 let discovery_result = if let Ok(mut discovery) = self.discovery_manager.lock() {
238 discovery.start_discovery(peer_id, self.bootstrap_nodes.clone())
239 } else {
240 Err(DiscoveryError::InternalError("Failed to lock discovery manager".to_string()))
241 };
242
243 if let Err(e) = discovery_result {
244 attempt.state = SimpleAttemptState::Failed;
245 attempt.last_error = Some(format!("Discovery failed: {:?}", e));
246 events.push(SimpleConnectionEvent::ConnectionFailed {
247 peer_id,
248 error: format!("Discovery failed: {:?}", e),
249 });
250 return true;
251 }
252
253 events.push(SimpleConnectionEvent::CandidateDiscoveryStarted { peer_id });
254 return false;
255 }
256 _ => {
257 attempt.state = SimpleAttemptState::Failed;
259 attempt.last_error = Some("Timeout exceeded".to_string());
260 events.push(SimpleConnectionEvent::ConnectionFailed {
261 peer_id,
262 error: "Timeout exceeded".to_string(),
263 });
264 return true;
265 }
266 }
267 }
268
269 match attempt.state {
271 SimpleAttemptState::DirectConnection => {
272 debug!("Simulating direct connection attempt to peer {:?}", peer_id);
274 }
275 SimpleAttemptState::CandidateDiscovery => {
276 }
278 SimpleAttemptState::NatTraversal => {
279 debug!("Simulating NAT traversal attempt to peer {:?}", peer_id);
281 }
282 SimpleAttemptState::Connected | SimpleAttemptState::Failed => {
283 return true;
284 }
285 }
286
287 false
288 }
289
290 fn handle_discovery_event(
291 &mut self,
292 discovery_event: DiscoveryEvent,
293 events: &mut Vec<SimpleConnectionEvent>,
294 ) {
295 match discovery_event {
296 DiscoveryEvent::LocalCandidateDiscovered { candidate } |
297 DiscoveryEvent::ServerReflexiveCandidateDiscovered { candidate, .. } |
298 DiscoveryEvent::PredictedCandidateGenerated { candidate, .. } => {
299 for attempt in self.active_attempts.values_mut() {
301 if attempt.state == SimpleAttemptState::CandidateDiscovery {
302 attempt.discovered_candidates.push(candidate.clone());
303 }
304 }
305 }
306 DiscoveryEvent::DiscoveryCompleted { .. } => {
307 let peer_ids: Vec<_> = self.active_attempts.iter()
309 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
310 .map(|(peer_id, _)| *peer_id)
311 .collect();
312
313 for peer_id in peer_ids {
314 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
315 attempt.state = SimpleAttemptState::NatTraversal;
316 events.push(SimpleConnectionEvent::NatTraversalStarted { peer_id });
317 }
318 }
319 }
320 DiscoveryEvent::DiscoveryFailed { error, .. } => {
321 warn!("Discovery failed: {:?}", error);
322 let peer_ids: Vec<_> = self.active_attempts.iter()
324 .filter(|(_, a)| a.state == SimpleAttemptState::CandidateDiscovery)
325 .map(|(peer_id, _)| *peer_id)
326 .collect();
327
328 for peer_id in peer_ids {
329 if let Some(attempt) = self.active_attempts.get_mut(&peer_id) {
330 attempt.state = SimpleAttemptState::Failed;
331 attempt.last_error = Some(format!("Discovery failed: {:?}", error));
332 events.push(SimpleConnectionEvent::ConnectionFailed {
333 peer_id,
334 error: format!("Discovery failed: {:?}", error),
335 });
336 }
337 }
338 }
339 _ => {
340 }
342 }
343 }
344
345 fn emit_event(&self, event: SimpleConnectionEvent) {
346 if let Some(ref callback) = self.event_callback {
347 callback(event);
348 }
349 }
350}