1use std::collections::HashMap;
39use std::net::SocketAddr;
40use std::sync::Arc;
41use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
42use std::time::{Duration, Instant};
43use tokio::sync::RwLock;
44
45use bytes::Bytes;
46
47use crate::masque::{
48 ConnectUdpRequest, ConnectUdpResponse, MasqueRelayClient, RelayClientConfig,
49 RelayConnectionState,
50};
51use crate::relay::error::{RelayError, RelayResult, SessionErrorKind};
52
53#[derive(Debug, Clone)]
55pub struct RelayManagerConfig {
56 pub max_relays: usize,
58 pub connect_timeout: Duration,
60 pub retry_delay: Duration,
62 pub max_retries: u32,
64 pub client_config: RelayClientConfig,
66}
67
68impl Default for RelayManagerConfig {
69 fn default() -> Self {
70 Self {
71 max_relays: 5,
72 connect_timeout: Duration::from_secs(10),
73 retry_delay: Duration::from_secs(30),
74 max_retries: 3,
75 client_config: RelayClientConfig::default(),
76 }
77 }
78}
79
/// Lock-free counters describing relay connection activity.
///
/// Every counter is updated with `Ordering::Relaxed`: each value is individually
/// consistent, but reads of different counters are not synchronized with each other.
#[derive(Debug, Default)]
pub struct RelayManagerStats {
    /// Total connection attempts recorded, successful or not.
    pub connection_attempts: AtomicU64,
    /// Attempts recorded as successful.
    pub successful_connections: AtomicU64,
    /// Attempts recorded as failed.
    pub failed_connections: AtomicU64,
    /// Sum of the byte counts passed to [`RelayManagerStats::record_sent`].
    pub bytes_sent: AtomicU64,
    /// Sum of the byte counts passed to [`RelayManagerStats::record_received`].
    pub bytes_received: AtomicU64,
    /// Number of [`RelayManagerStats::record_sent`] calls.
    pub datagrams_relayed: AtomicU64,
    /// Connections currently counted as open: incremented by a successful attempt,
    /// decremented (never below zero) by a disconnect.
    pub active_relays: AtomicU64,
}

impl RelayManagerStats {
    /// Creates a set of counters, all starting at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one connection attempt; a success also counts as a new active relay.
    pub fn record_attempt(&self, success: bool) {
        self.connection_attempts.fetch_add(1, Ordering::Relaxed);
        if success {
            self.successful_connections.fetch_add(1, Ordering::Relaxed);
            self.active_relays.fetch_add(1, Ordering::Relaxed);
        } else {
            self.failed_connections.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Records that an active relay connection went away.
    ///
    /// The active count saturates at zero. The check and the decrement happen in a
    /// single atomic read-modify-write, so concurrent callers cannot both observe
    /// `1` and wrap the counter to `u64::MAX`.
    pub fn record_disconnect(&self) {
        // `checked_sub` yields `None` at zero, which makes `fetch_update` leave the
        // value untouched; the returned `Err` in that case is expected and ignored.
        let _ = self
            .active_relays
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }

    /// Records one outgoing datagram of `bytes` payload bytes.
    pub fn record_sent(&self, bytes: u64) {
        self.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
        self.datagrams_relayed.fetch_add(1, Ordering::Relaxed);
    }

    /// Records `bytes` payload bytes received.
    pub fn record_received(&self, bytes: u64) {
        self.bytes_received.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Returns the number of connections currently counted as active.
    pub fn active_count(&self) -> u64 {
        self.active_relays.load(Ordering::Relaxed)
    }
}
140
141#[derive(Debug)]
143struct RelayNodeInfo {
144 address: SocketAddr,
146 client: Option<MasqueRelayClient>,
148 last_attempt: Option<Instant>,
150 failure_count: u32,
152 available: bool,
154}
155
156impl RelayNodeInfo {
157 fn new(address: SocketAddr) -> Self {
158 Self {
159 address,
160 client: None,
161 last_attempt: None,
162 failure_count: 0,
163 available: true,
164 }
165 }
166
167 fn mark_failed(&mut self) {
168 self.last_attempt = Some(Instant::now());
169 self.failure_count = self.failure_count.saturating_add(1);
170 }
171
172 fn mark_connected(&mut self, client: MasqueRelayClient) {
173 self.client = Some(client);
174 self.failure_count = 0;
175 self.available = true;
176 }
177
178 fn can_retry(&self, retry_delay: Duration, max_retries: u32) -> bool {
179 if self.failure_count >= max_retries {
180 return false;
181 }
182 match self.last_attempt {
183 Some(t) => t.elapsed() >= retry_delay,
184 None => true,
185 }
186 }
187}
188
/// Outcome of trying to obtain a relay connection.
///
/// NOTE(review): this enum is not constructed in the code shown with it; the
/// variant descriptions below are inferred from their names and fields — confirm
/// against the call sites.
#[derive(Debug)]
pub enum RelayOperationResult {
    /// A relay connection was obtained.
    Success {
        /// Relay that served the connection.
        relay: SocketAddr,
        /// Public address reported through the relay, if one was provided.
        public_address: Option<SocketAddr>,
    },
    /// Every candidate was tried and none succeeded.
    AllRelaysFailed {
        /// How many relays were tried.
        attempted: usize,
    },
    /// No relay was eligible to be tried.
    NoRelaysAvailable,
}
207
208#[derive(Debug)]
210pub struct RelayManager {
211 config: RelayManagerConfig,
213 relays: RwLock<HashMap<SocketAddr, RelayNodeInfo>>,
215 active: AtomicBool,
217 stats: Arc<RelayManagerStats>,
219}
220
221impl RelayManager {
222 pub fn new(config: RelayManagerConfig) -> Self {
224 Self {
225 config,
226 relays: RwLock::new(HashMap::new()),
227 active: AtomicBool::new(true),
228 stats: Arc::new(RelayManagerStats::new()),
229 }
230 }
231
232 pub fn stats(&self) -> Arc<RelayManagerStats> {
234 Arc::clone(&self.stats)
235 }
236
237 pub async fn add_relay_node(&self, address: SocketAddr) {
239 let mut relays = self.relays.write().await;
240 if !relays.contains_key(&address) && relays.len() < self.config.max_relays {
241 relays.insert(address, RelayNodeInfo::new(address));
242 tracing::debug!(relay = %address, "Added relay node");
243 }
244 }
245
246 pub async fn remove_relay_node(&self, address: SocketAddr) {
248 let mut relays = self.relays.write().await;
249 if let Some(info) = relays.remove(&address) {
250 if info.client.is_some() {
251 self.stats.record_disconnect();
252 }
253 tracing::debug!(relay = %address, "Removed relay node");
254 }
255 }
256
257 pub async fn available_relays(&self) -> Vec<SocketAddr> {
259 let relays = self.relays.read().await;
260 relays
261 .iter()
262 .filter(|(_, info)| {
263 info.available && info.can_retry(self.config.retry_delay, self.config.max_retries)
264 })
265 .map(|(addr, _)| *addr)
266 .collect()
267 }
268
269 pub async fn get_relay_client(&self, relay: SocketAddr) -> Option<SocketAddr> {
271 let relays = self.relays.read().await;
272 let info = relays.get(&relay)?;
273 let client = info.client.as_ref()?;
274
275 if matches!(client.state().await, RelayConnectionState::Connected) {
277 Some(info.address)
278 } else {
279 None
280 }
281 }
282
283 pub fn create_connect_request(&self) -> ConnectUdpRequest {
285 ConnectUdpRequest::bind_any()
286 }
287
288 pub async fn handle_connect_response(
290 &self,
291 relay: SocketAddr,
292 response: ConnectUdpResponse,
293 ) -> RelayResult<Option<SocketAddr>> {
294 if !response.is_success() {
295 let mut relays = self.relays.write().await;
296 if let Some(info) = relays.get_mut(&relay) {
297 info.mark_failed();
298 }
299 self.stats.record_attempt(false);
300 return Err(RelayError::SessionError {
301 session_id: None,
302 kind: SessionErrorKind::InvalidState {
303 current_state: format!("HTTP {}", response.status),
304 expected_state: "HTTP 200".into(),
305 },
306 });
307 }
308
309 let client = MasqueRelayClient::new(relay, self.config.client_config.clone());
311 client.handle_connect_response(response.clone()).await?;
312
313 let public_addr = response.proxy_public_address;
314
315 {
317 let mut relays = self.relays.write().await;
318 if let Some(info) = relays.get_mut(&relay) {
319 info.mark_connected(client);
320 }
321 }
322
323 self.stats.record_attempt(true);
324
325 tracing::info!(
326 relay = %relay,
327 public_addr = ?public_addr,
328 "Relay connection established"
329 );
330
331 Ok(public_addr)
332 }
333
334 pub async fn public_address(&self) -> Option<SocketAddr> {
336 let relays = self.relays.read().await;
337 for info in relays.values() {
338 if let Some(ref client) = info.client {
339 if let Some(addr) = client.public_address().await {
340 return Some(addr);
341 }
342 }
343 }
344 None
345 }
346
347 pub async fn send_via_relay(
349 &self,
350 relay: SocketAddr,
351 target: SocketAddr,
352 payload: Bytes,
353 ) -> RelayResult<()> {
354 let relays = self.relays.read().await;
355 let info = relays.get(&relay).ok_or(RelayError::SessionError {
356 session_id: None,
357 kind: SessionErrorKind::NotFound,
358 })?;
359
360 let _client = info.client.as_ref().ok_or(RelayError::SessionError {
361 session_id: None,
362 kind: SessionErrorKind::InvalidState {
363 current_state: "not connected".into(),
364 expected_state: "connected".into(),
365 },
366 })?;
367
368 self.stats.record_sent(payload.len() as u64);
375
376 tracing::trace!(
377 relay = %relay,
378 target = %target,
379 bytes = payload.len(),
380 "Sent datagram via relay"
381 );
382
383 Ok(())
384 }
385
386 pub async fn close_all(&self) {
388 self.active.store(false, Ordering::SeqCst);
389
390 let mut relays = self.relays.write().await;
391 for info in relays.values_mut() {
392 if let Some(ref client) = info.client {
393 client.close().await;
394 }
395 info.client = None;
396 }
397
398 tracing::info!("Closed all relay connections");
399 }
400
401 pub async fn active_relay_count(&self) -> usize {
403 let relays = self.relays.read().await;
404 relays.values().filter(|info| info.client.is_some()).count()
405 }
406
407 pub async fn has_available_relay(&self) -> bool {
409 !self.available_relays().await.is_empty()
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Builds a relay address in the 203.0.113.0/24 documentation range on port 9000.
    fn relay_addr(id: u8) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, id)), 9000)
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let config = RelayManagerConfig::default();
        let manager = RelayManager::new(config);

        assert_eq!(manager.active_relay_count().await, 0);
        assert!(!manager.has_available_relay().await);
    }

    #[tokio::test]
    async fn test_add_relay_node() {
        let config = RelayManagerConfig::default();
        let manager = RelayManager::new(config);

        manager.add_relay_node(relay_addr(1)).await;
        assert!(manager.has_available_relay().await);

        let available = manager.available_relays().await;
        assert_eq!(available.len(), 1);
        assert_eq!(available[0], relay_addr(1));
    }

    #[tokio::test]
    async fn test_add_duplicate_relay_node() {
        let manager = RelayManager::new(RelayManagerConfig::default());

        manager.add_relay_node(relay_addr(1)).await;
        manager.add_relay_node(relay_addr(1)).await;

        assert_eq!(manager.available_relays().await, vec![relay_addr(1)]);
    }

    #[tokio::test]
    async fn test_remove_relay_node() {
        let config = RelayManagerConfig::default();
        let manager = RelayManager::new(config);

        manager.add_relay_node(relay_addr(1)).await;
        assert!(manager.has_available_relay().await);

        manager.remove_relay_node(relay_addr(1)).await;
        assert!(!manager.has_available_relay().await);
    }

    #[tokio::test]
    async fn test_relay_limit() {
        let config = RelayManagerConfig {
            max_relays: 2,
            ..Default::default()
        };
        let manager = RelayManager::new(config);

        manager.add_relay_node(relay_addr(1)).await;
        manager.add_relay_node(relay_addr(2)).await;
        manager.add_relay_node(relay_addr(3)).await;

        // The third node exceeds `max_relays` and must be ignored.
        let available = manager.available_relays().await;
        assert_eq!(available.len(), 2);
        assert!(!available.contains(&relay_addr(3)));
    }

    #[tokio::test]
    async fn test_handle_success_response() {
        let config = RelayManagerConfig::default();
        let manager = RelayManager::new(config);

        let relay = relay_addr(1);
        manager.add_relay_node(relay).await;

        let response = ConnectUdpResponse::success(Some(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            12345,
        )));

        let result = manager.handle_connect_response(relay, response).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());

        let stats = manager.stats();
        assert_eq!(stats.successful_connections.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_handle_error_response() {
        let config = RelayManagerConfig::default();
        let manager = RelayManager::new(config);

        let relay = relay_addr(1);
        manager.add_relay_node(relay).await;

        let response = ConnectUdpResponse::error(503, "Server busy");

        let result = manager.handle_connect_response(relay, response).await;
        assert!(result.is_err());

        let stats = manager.stats();
        assert_eq!(stats.failed_connections.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_stats() {
        let config = RelayManagerConfig::default();
        let manager = RelayManager::new(config);

        let stats = manager.stats();
        assert_eq!(stats.active_count(), 0);

        stats.record_attempt(true);
        assert_eq!(stats.active_count(), 1);

        stats.record_disconnect();
        assert_eq!(stats.active_count(), 0);
    }

    #[test]
    fn test_record_disconnect_does_not_underflow() {
        let stats = RelayManagerStats::new();
        stats.record_disconnect();
        assert_eq!(stats.active_count(), 0);
    }

    #[tokio::test]
    async fn test_close_all() {
        let config = RelayManagerConfig::default();
        let manager = RelayManager::new(config);

        manager.add_relay_node(relay_addr(1)).await;
        manager.add_relay_node(relay_addr(2)).await;

        manager.close_all().await;

        // Closing drops connections but keeps the nodes registered.
        assert_eq!(manager.active_relay_count().await, 0);
        assert_eq!(manager.available_relays().await.len(), 2);
    }
}