ant_quic/masque/
integration.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8//! MASQUE Relay Integration
9//!
10//! Provides integration between the MASQUE relay system and the NAT traversal API.
11//! This module acts as the bridge that enables automatic relay fallback when
12//! direct NAT traversal fails.
13//!
14//! # Overview
15//!
16//! The integration layer:
17//! - Manages a pool of relay connections to known peers
18//! - Automatically attempts relay fallback when direct connection fails
19//! - Coordinates context registration for efficient datagram forwarding
20//! - Tracks relay usage statistics
21//!
22//! # Example
23//!
24//! ```rust,ignore
25//! use ant_quic::masque::integration::{RelayManager, RelayManagerConfig};
26//! use std::net::SocketAddr;
27//!
28//! let config = RelayManagerConfig::default();
29//! let manager = RelayManager::new(config);
30//!
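//! // Register a relay node the manager may use
//! manager.add_relay_node(relay_addr).await;
//!
//! // Send the CONNECT-UDP request to the relay, then hand its response back
//! let request = manager.create_connect_request();
//! // ... transmit `request` to `relay_addr` and receive `response` ...
//! let public_addr = manager.handle_connect_response(relay_addr, response).await?;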
36//! ```
37
38use std::collections::HashMap;
39use std::net::SocketAddr;
40use std::sync::Arc;
41use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
42use std::time::{Duration, Instant};
43use tokio::sync::RwLock;
44
45use bytes::Bytes;
46
47use crate::masque::{
48    ConnectUdpRequest, ConnectUdpResponse, MasqueRelayClient, RelayClientConfig,
49    RelayConnectionState,
50};
51use crate::relay::error::{RelayError, RelayResult, SessionErrorKind};
52
53/// Configuration for the relay manager
54#[derive(Debug, Clone)]
55pub struct RelayManagerConfig {
56    /// Maximum number of relay connections to maintain
57    pub max_relays: usize,
58    /// Relay connection timeout
59    pub connect_timeout: Duration,
60    /// Time to wait before retrying a failed relay
61    pub retry_delay: Duration,
62    /// Maximum retries per relay
63    pub max_retries: u32,
64    /// Client configuration for relay connections
65    pub client_config: RelayClientConfig,
66}
67
68impl Default for RelayManagerConfig {
69    fn default() -> Self {
70        Self {
71            max_relays: 5,
72            connect_timeout: Duration::from_secs(10),
73            retry_delay: Duration::from_secs(30),
74            max_retries: 3,
75            client_config: RelayClientConfig::default(),
76        }
77    }
78}
79
/// Counters describing relay activity.
///
/// Every counter is an independent atomic updated with `Ordering::Relaxed`.
/// Reading several fields therefore does not yield a mutually consistent
/// snapshot, but each individual counter is always exact.
#[derive(Debug, Default)]
pub struct RelayManagerStats {
    /// Total relay connection attempts
    pub connection_attempts: AtomicU64,
    /// Successful relay connections
    pub successful_connections: AtomicU64,
    /// Failed relay connections
    pub failed_connections: AtomicU64,
    /// Bytes sent through relays
    pub bytes_sent: AtomicU64,
    /// Bytes received through relays
    pub bytes_received: AtomicU64,
    /// Datagrams relayed (incremented once per `record_sent` call)
    pub datagrams_relayed: AtomicU64,
    /// Currently active relay connections (gauge, never below zero)
    pub active_relays: AtomicU64,
}

impl RelayManagerStats {
    /// Create a zeroed set of statistics.
    pub fn new() -> Self {
        Self::default()
    }

    /// Record the outcome of a relay connection attempt.
    ///
    /// A successful attempt also increments the active-relay gauge.
    pub fn record_attempt(&self, success: bool) {
        self.connection_attempts.fetch_add(1, Ordering::Relaxed);
        if success {
            self.successful_connections.fetch_add(1, Ordering::Relaxed);
            self.active_relays.fetch_add(1, Ordering::Relaxed);
        } else {
            self.failed_connections.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Record that an active relay connection went away.
    ///
    /// Decrements the active-relay gauge, saturating at zero.
    pub fn record_disconnect(&self) {
        // A separate `load` followed by `fetch_sub` lets two concurrent callers
        // both observe 1 and both decrement, wrapping the gauge to u64::MAX.
        // `fetch_update` performs the check and the decrement as one atomic
        // read-modify-write; returning `None` (already 0) leaves it unchanged.
        let _ = self
            .active_relays
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }

    /// Record a datagram of `bytes` bytes sent through a relay.
    pub fn record_sent(&self, bytes: u64) {
        self.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
        self.datagrams_relayed.fetch_add(1, Ordering::Relaxed);
    }

    /// Record `bytes` bytes received through a relay.
    pub fn record_received(&self, bytes: u64) {
        self.bytes_received.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Current value of the active-relay gauge.
    pub fn active_count(&self) -> u64 {
        self.active_relays.load(Ordering::Relaxed)
    }
}
140
141/// Information about a relay node
142#[derive(Debug)]
143struct RelayNodeInfo {
144    /// Relay server address
145    address: SocketAddr,
146    /// Connected client (if any)
147    client: Option<MasqueRelayClient>,
148    /// Last connection attempt
149    last_attempt: Option<Instant>,
150    /// Number of consecutive failures
151    failure_count: u32,
152    /// Whether the relay is currently usable
153    available: bool,
154}
155
156impl RelayNodeInfo {
157    fn new(address: SocketAddr) -> Self {
158        Self {
159            address,
160            client: None,
161            last_attempt: None,
162            failure_count: 0,
163            available: true,
164        }
165    }
166
167    fn mark_failed(&mut self) {
168        self.last_attempt = Some(Instant::now());
169        self.failure_count = self.failure_count.saturating_add(1);
170    }
171
172    fn mark_connected(&mut self, client: MasqueRelayClient) {
173        self.client = Some(client);
174        self.failure_count = 0;
175        self.available = true;
176    }
177
178    fn can_retry(&self, retry_delay: Duration, max_retries: u32) -> bool {
179        if self.failure_count >= max_retries {
180            return false;
181        }
182        match self.last_attempt {
183            Some(t) => t.elapsed() >= retry_delay,
184            None => true,
185        }
186    }
187}
188
/// Outcome of a relay operation.
#[derive(Debug)]
pub enum RelayOperationResult {
    /// A relay carried the operation successfully
    Success {
        /// Address of the relay that was used
        relay: SocketAddr,
        /// Public address the relay assigned to us, when it reported one
        public_address: Option<SocketAddr>,
    },
    /// Every relay that was tried failed
    AllRelaysFailed {
        /// How many relays were tried
        attempted: usize,
    },
    /// There was no relay to try
    NoRelaysAvailable,
}
207
208/// Manages relay connections for NAT traversal fallback
209#[derive(Debug)]
210pub struct RelayManager {
211    /// Configuration
212    config: RelayManagerConfig,
213    /// Known relay nodes
214    relays: RwLock<HashMap<SocketAddr, RelayNodeInfo>>,
215    /// Whether the manager is active
216    active: AtomicBool,
217    /// Statistics
218    stats: Arc<RelayManagerStats>,
219}
220
221impl RelayManager {
222    /// Create a new relay manager
223    pub fn new(config: RelayManagerConfig) -> Self {
224        Self {
225            config,
226            relays: RwLock::new(HashMap::new()),
227            active: AtomicBool::new(true),
228            stats: Arc::new(RelayManagerStats::new()),
229        }
230    }
231
232    /// Get statistics
233    pub fn stats(&self) -> Arc<RelayManagerStats> {
234        Arc::clone(&self.stats)
235    }
236
237    /// Add a potential relay node
238    pub async fn add_relay_node(&self, address: SocketAddr) {
239        let mut relays = self.relays.write().await;
240        if !relays.contains_key(&address) && relays.len() < self.config.max_relays {
241            relays.insert(address, RelayNodeInfo::new(address));
242            tracing::debug!(relay = %address, "Added relay node");
243        }
244    }
245
246    /// Remove a relay node
247    pub async fn remove_relay_node(&self, address: SocketAddr) {
248        let mut relays = self.relays.write().await;
249        if let Some(info) = relays.remove(&address) {
250            if info.client.is_some() {
251                self.stats.record_disconnect();
252            }
253            tracing::debug!(relay = %address, "Removed relay node");
254        }
255    }
256
257    /// Get list of available relay addresses
258    pub async fn available_relays(&self) -> Vec<SocketAddr> {
259        let relays = self.relays.read().await;
260        relays
261            .iter()
262            .filter(|(_, info)| {
263                info.available && info.can_retry(self.config.retry_delay, self.config.max_retries)
264            })
265            .map(|(addr, _)| *addr)
266            .collect()
267    }
268
269    /// Get a connected relay client for a specific relay
270    pub async fn get_relay_client(&self, relay: SocketAddr) -> Option<SocketAddr> {
271        let relays = self.relays.read().await;
272        let info = relays.get(&relay)?;
273        let client = info.client.as_ref()?;
274
275        // Check if still connected
276        if matches!(client.state().await, RelayConnectionState::Connected) {
277            Some(info.address)
278        } else {
279            None
280        }
281    }
282
283    /// Initiate relay connection (returns request to send)
284    pub fn create_connect_request(&self) -> ConnectUdpRequest {
285        ConnectUdpRequest::bind_any()
286    }
287
288    /// Handle relay connection response
289    pub async fn handle_connect_response(
290        &self,
291        relay: SocketAddr,
292        response: ConnectUdpResponse,
293    ) -> RelayResult<Option<SocketAddr>> {
294        if !response.is_success() {
295            let mut relays = self.relays.write().await;
296            if let Some(info) = relays.get_mut(&relay) {
297                info.mark_failed();
298            }
299            self.stats.record_attempt(false);
300            return Err(RelayError::SessionError {
301                session_id: None,
302                kind: SessionErrorKind::InvalidState {
303                    current_state: format!("HTTP {}", response.status),
304                    expected_state: "HTTP 200".into(),
305                },
306            });
307        }
308
309        // Create new client for this relay
310        let client = MasqueRelayClient::new(relay, self.config.client_config.clone());
311        client.handle_connect_response(response.clone()).await?;
312
313        let public_addr = response.proxy_public_address;
314
315        // Store the client
316        {
317            let mut relays = self.relays.write().await;
318            if let Some(info) = relays.get_mut(&relay) {
319                info.mark_connected(client);
320            }
321        }
322
323        self.stats.record_attempt(true);
324
325        tracing::info!(
326            relay = %relay,
327            public_addr = ?public_addr,
328            "Relay connection established"
329        );
330
331        Ok(public_addr)
332    }
333
334    /// Get our public address from any connected relay
335    pub async fn public_address(&self) -> Option<SocketAddr> {
336        let relays = self.relays.read().await;
337        for info in relays.values() {
338            if let Some(ref client) = info.client {
339                if let Some(addr) = client.public_address().await {
340                    return Some(addr);
341                }
342            }
343        }
344        None
345    }
346
347    /// Send datagram through relay
348    pub async fn send_via_relay(
349        &self,
350        relay: SocketAddr,
351        target: SocketAddr,
352        payload: Bytes,
353    ) -> RelayResult<()> {
354        let relays = self.relays.read().await;
355        let info = relays.get(&relay).ok_or(RelayError::SessionError {
356            session_id: None,
357            kind: SessionErrorKind::NotFound,
358        })?;
359
360        let _client = info.client.as_ref().ok_or(RelayError::SessionError {
361            session_id: None,
362            kind: SessionErrorKind::InvalidState {
363                current_state: "not connected".into(),
364                expected_state: "connected".into(),
365            },
366        })?;
367
368        // Note: In a full implementation, we would:
369        // 1. Get or create context for target
370        // 2. Send COMPRESSION_ASSIGN capsule if needed
371        // 3. Encode datagram with context ID
372        // 4. Send over QUIC datagram
373
374        self.stats.record_sent(payload.len() as u64);
375
376        tracing::trace!(
377            relay = %relay,
378            target = %target,
379            bytes = payload.len(),
380            "Sent datagram via relay"
381        );
382
383        Ok(())
384    }
385
386    /// Close all relay connections
387    pub async fn close_all(&self) {
388        self.active.store(false, Ordering::SeqCst);
389
390        let mut relays = self.relays.write().await;
391        for info in relays.values_mut() {
392            if let Some(ref client) = info.client {
393                client.close().await;
394            }
395            info.client = None;
396        }
397
398        tracing::info!("Closed all relay connections");
399    }
400
401    /// Get number of active relay connections
402    pub async fn active_relay_count(&self) -> usize {
403        let relays = self.relays.read().await;
404        relays.values().filter(|info| info.client.is_some()).count()
405    }
406
407    /// Check if relay fallback is available
408    pub async fn has_available_relay(&self) -> bool {
409        !self.available_relays().await.is_empty()
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Address in the 203.0.113.0/24 documentation range, used as a fake relay.
    fn relay_addr(id: u8) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, id)), 9000)
    }

    /// Public address a relay reports back in successful test responses.
    fn public_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 12345)
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let manager = RelayManager::new(RelayManagerConfig::default());

        assert_eq!(manager.active_relay_count().await, 0);
        assert!(!manager.has_available_relay().await);
        assert!(manager.available_relays().await.is_empty());
    }

    #[tokio::test]
    async fn test_add_relay_node() {
        let manager = RelayManager::new(RelayManagerConfig::default());

        manager.add_relay_node(relay_addr(1)).await;
        assert!(manager.has_available_relay().await);

        let available = manager.available_relays().await;
        assert_eq!(available, vec![relay_addr(1)]);
    }

    #[tokio::test]
    async fn test_add_duplicate_relay_node_is_ignored() {
        let manager = RelayManager::new(RelayManagerConfig::default());

        manager.add_relay_node(relay_addr(1)).await;
        manager.add_relay_node(relay_addr(1)).await;

        assert_eq!(manager.available_relays().await.len(), 1);
    }

    #[tokio::test]
    async fn test_remove_relay_node() {
        let manager = RelayManager::new(RelayManagerConfig::default());

        manager.add_relay_node(relay_addr(1)).await;
        assert!(manager.has_available_relay().await);

        manager.remove_relay_node(relay_addr(1)).await;
        assert!(!manager.has_available_relay().await);

        // Removing an unknown relay is a no-op.
        manager.remove_relay_node(relay_addr(2)).await;
        assert!(!manager.has_available_relay().await);
    }

    #[tokio::test]
    async fn test_remove_connected_relay_updates_stats() {
        let manager = RelayManager::new(RelayManagerConfig::default());
        let relay = relay_addr(1);
        manager.add_relay_node(relay).await;

        let response = ConnectUdpResponse::success(Some(public_addr()));
        assert!(manager.handle_connect_response(relay, response).await.is_ok());
        assert_eq!(manager.stats().active_count(), 1);

        manager.remove_relay_node(relay).await;
        assert_eq!(manager.stats().active_count(), 0);
        assert_eq!(manager.active_relay_count().await, 0);
    }

    #[tokio::test]
    async fn test_relay_limit() {
        let config = RelayManagerConfig {
            max_relays: 2,
            ..Default::default()
        };
        let manager = RelayManager::new(config);

        manager.add_relay_node(relay_addr(1)).await;
        manager.add_relay_node(relay_addr(2)).await;
        manager.add_relay_node(relay_addr(3)).await; // Should be ignored

        let available = manager.available_relays().await;
        assert_eq!(available.len(), 2);
        assert!(!available.contains(&relay_addr(3)));
    }

    #[tokio::test]
    async fn test_handle_success_response() {
        let manager = RelayManager::new(RelayManagerConfig::default());

        let relay = relay_addr(1);
        manager.add_relay_node(relay).await;

        let response = ConnectUdpResponse::success(Some(public_addr()));

        let result = manager.handle_connect_response(relay, response).await;
        assert!(result.is_ok());
        assert_eq!(result.ok().flatten(), Some(public_addr()));
        assert_eq!(manager.active_relay_count().await, 1);

        let stats = manager.stats();
        assert_eq!(stats.successful_connections.load(Ordering::Relaxed), 1);
        assert_eq!(stats.active_count(), 1);
    }

    #[tokio::test]
    async fn test_handle_error_response() {
        let manager = RelayManager::new(RelayManagerConfig::default());

        let relay = relay_addr(1);
        manager.add_relay_node(relay).await;

        let response = ConnectUdpResponse::error(503, "Server busy");

        let result = manager.handle_connect_response(relay, response).await;
        assert!(result.is_err());

        let stats = manager.stats();
        assert_eq!(stats.failed_connections.load(Ordering::Relaxed), 1);
        assert_eq!(stats.active_count(), 0);
        assert_eq!(manager.active_relay_count().await, 0);

        // A fresh failure puts the relay into the retry back-off
        // (default retry_delay is 30 s).
        assert!(!manager.has_available_relay().await);
    }

    #[tokio::test]
    async fn test_stats() {
        let manager = RelayManager::new(RelayManagerConfig::default());

        let stats = manager.stats();
        assert_eq!(stats.active_count(), 0);

        // Disconnecting with nothing active must not underflow.
        stats.record_disconnect();
        assert_eq!(stats.active_count(), 0);

        stats.record_attempt(true);
        assert_eq!(stats.active_count(), 1);

        stats.record_disconnect();
        assert_eq!(stats.active_count(), 0);
    }

    #[tokio::test]
    async fn test_close_all() {
        let manager = RelayManager::new(RelayManagerConfig::default());

        manager.add_relay_node(relay_addr(1)).await;
        manager.add_relay_node(relay_addr(2)).await;

        let response = ConnectUdpResponse::success(Some(public_addr()));
        assert!(manager
            .handle_connect_response(relay_addr(1), response)
            .await
            .is_ok());
        assert_eq!(manager.active_relay_count().await, 1);

        manager.close_all().await;

        // Clients are gone, but the relay nodes stay registered.
        assert_eq!(manager.active_relay_count().await, 0);
        assert_eq!(manager.available_relays().await.len(), 2);
    }
}